github-actions[bot] committed on
Commit · 79f9b3a
Sync from GitHub 38cd8d69dc858672e22cd1448f7768fef87468b1
- .dockerignore +12 -0
- .gitattributes +10 -0
- .github/workflows/sync_to_hf_space.yml +90 -0
- .gitignore +34 -0
- Dockerfile +35 -0
- LICENSE +21 -0
- README.md +158 -0
- app.py +754 -0
- data/benchmark_cases.csv +79 -0
- data/harmonized-system/harmonized-system.csv +0 -0
- data/hs_codes_reference.json +3 -0
- data/hts/us_hts_lookup.json +3 -0
- data/sample_documents/customs_zh.png +0 -0
- data/sample_documents/invoice_en.png +0 -0
- data/sample_documents/invoice_vi.png +0 -0
- data/sample_documents/packing_list_th.png +0 -0
- dataset/ATTRIBUTION.md +19 -0
- dataset/README.md +66 -0
- field_extractor.py +358 -0
- hs_dataset.py +341 -0
- models/.gitkeep +0 -0
- requirements-dev.txt +2 -0
- requirements.txt +15 -0
- static/.gitkeep +0 -0
- templates/index.html +0 -0
.dockerignore
ADDED
@@ -0,0 +1,12 @@
.git
.github
venv
__pycache__
*.pyc
.DS_Store
uploads
scripts
README.md
LICENSE
data/training_data.json
data/sample_documents

.gitattributes
ADDED
@@ -0,0 +1,10 @@
data/training_data.json filter=lfs diff=lfs merge=lfs -text
models/umap_data.json filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
data/training_data.csv filter=lfs diff=lfs merge=lfs -text
data/training_data_indexed.csv filter=lfs diff=lfs merge=lfs -text
data/cargo_descriptions.csv filter=lfs diff=lfs merge=lfs -text
data/hts/us_hts_lookup.json filter=lfs diff=lfs merge=lfs -text
data/hts/*.csv filter=lfs diff=lfs merge=lfs -text
data/*.json filter=lfs diff=lfs merge=lfs -text

.github/workflows/sync_to_hf_space.yml
ADDED
@@ -0,0 +1,90 @@
name: Sync GitHub to Hugging Face Space

on:
  push:
    branches:
      - main
  workflow_dispatch:

jobs:
  sync-to-hf:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true

      - name: Pull LFS files
        run: git lfs pull

      - name: Push to Hugging Face Space
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          if [ -z "${HF_TOKEN}" ]; then
            echo "HF_TOKEN secret is not set."
            exit 1
          fi

          DEPLOY_DIR="/tmp/hf-deploy"
          rm -rf "${DEPLOY_DIR}"
          mkdir -p "${DEPLOY_DIR}"

          # Export the current tree (with real LFS file contents, not pointers)
          tar --exclude=.git -cf - . | (cd "${DEPLOY_DIR}" && tar -xf -)
          cd "${DEPLOY_DIR}"

          # Keep GitHub README clean; inject Space front matter only for HF deploy.
          if [ -f README.md ]; then
            awk '
              NR == 1 && $0 == "---" {in_yaml=1; next}
              in_yaml && $0 == "---" {in_yaml=0; next}
              !in_yaml {print}
            ' README.md > README.clean.md

            printf '%s\n' \
              '---' \
              'title: HS Code Classifier Micro' \
              'emoji: ⚡' \
              'colorFrom: pink' \
              'colorTo: blue' \
              'sdk: docker' \
              'app_port: 7860' \
              '---' \
              > README.frontmatter.md

            cat README.frontmatter.md README.clean.md > README.md
            rm -f README.frontmatter.md README.clean.md
          fi

          # HF rejects files >10MB without Git LFS.
          git lfs install
          git init
          git checkout -b main
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"

          # Track large files with LFS for HF (only files still bundled)
          git lfs track "data/*.json" "models/umap_data.json"
          git add .gitattributes

          # Remove files not needed at runtime to stay under HF Space 1GB limit.
          # Large artifacts (sentence model, embeddings, classifier, training data)
          # are hosted on HF Hub at $SENTENCE_MODEL_NAME and downloaded at startup.
          rm -rf scripts/
          rm -rf models/sentence_model/
          rm -f models/embeddings.npy models/knn_classifier.pkl models/label_encoder.pkl models/metadata.json models/umap_data.json
          touch models/.gitkeep
          rm -f data/training_data.csv data/training_data_indexed.csv
          rm -f data/hts/hts_*.csv
          rm -f data/cargo_descriptions.csv
          rm -f data/training_data.json
          rm -f data/hf_real_data.csv

          git add -A
          git commit -m "Sync from GitHub ${GITHUB_SHA}"

          git remote add hf "https://oauth2:${HF_TOKEN}@huggingface.co/spaces/Mead0w1ark/MicroHS"
          git push --force hf main

.gitignore
ADDED
@@ -0,0 +1,34 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
build/
*.egg
venv/
.venv/
env/

# Sentence model weights (downloaded from HF Hub at startup)
models/sentence_model/model.safetensors
models/sentence_model/tokenizer.json

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Jupyter
.ipynb_checkpoints/

# Local publish staging
.hf_dataset_release/

# Benchmark output
benchmark_results.json

Dockerfile
ADDED
@@ -0,0 +1,35 @@
# syntax=docker/dockerfile:1
FROM python:3.11-slim

WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    SENTENCE_MODEL_NAME=intfloat/multilingual-e5-small \
    HF_ARTIFACT_REPO=Mead0w1ark/multilingual-e5-small-hs-codes

# System deps for OCR endpoints:
# - tesseract for image OCR
# - poppler-utils for pdf2image PDF conversion
RUN apt-get update && apt-get install -y --no-install-recommends \
    tesseract-ocr \
    poppler-utils \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements.txt and install dependencies
COPY requirements.txt .
RUN pip install --upgrade pip && pip install -r requirements.txt

# Copy only runtime files (keeps build context and cache churn smaller)
COPY app.py field_extractor.py hs_dataset.py ./
COPY templates ./templates
COPY static ./static
COPY data ./data
COPY models ./models
RUN mkdir -p uploads

# Expose the port FastAPI will run on
EXPOSE 7860

# Command to run the FastAPI application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]

LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 James Ball

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md
ADDED
@@ -0,0 +1,158 @@
---
title: HS Code Classifier Micro
emoji: ⚡
colorFrom: pink
colorTo: blue
sdk: docker
app_port: 7860
---

# HSClassify_micro 🔍

[License: MIT](https://opensource.org/licenses/MIT)
[Python 3.11+](https://www.python.org/downloads/)

**Machine learning model for multilingual HS/HTS classification** for trade finance and customs workflows, built with FastAPI + OCR.

Classifies product descriptions into [Harmonized System (HS) codes](https://en.wikipedia.org/wiki/Harmonized_System) using sentence embeddings and k-NN search, with an interactive latent space visualization.

## Live Demo

- Hugging Face Space: [https://huggingface.co/spaces/Mead0w1ark/MicroHS](https://huggingface.co/spaces/Mead0w1ark/MicroHS)

## Features

- 🌍 **Multilingual** — supports English, Thai, Vietnamese, and Chinese product descriptions
- ⚡ **Real-time classification** — top-3 HS code predictions with confidence scores
- 📊 **Latent space visualization** — interactive UMAP plot showing embedding clusters
- 🎯 **KNN-based** — simple, interpretable nearest-neighbor approach using fine-tuned `multilingual-e5-small`
- 🧾 **Official HS coverage** — training generation incorporates the [datasets/harmonized-system](https://github.com/datasets/harmonized-system) 6-digit nomenclature

## Dataset Attribution

This project includes HS nomenclature content sourced from:

- [datasets/harmonized-system](https://github.com/datasets/harmonized-system)
- Upstream references listed by that dataset:
  - WCO HS nomenclature documentation
  - UN Comtrade data extraction API

Related datasets (evaluated during development):

- [Customs-Declaration-Datasets](https://github.com/Seondong/Customs-Declaration-Datasets) — 54,000 synthetic customs declaration records derived from 24.7M real Korean customs entries. Provides structured trade metadata (HS codes, country of origin, price, weight, fraud labels) but does not include free-text product descriptions. Cited as a reference for customs data research. See: *S. Kim et al., "DATE: Dual Attentive Tree-aware Embedding for Customs Fraud Detection," KDD 2020.*

Licensing:

- Upstream HS source data: **ODC Public Domain Dedication and License (PDDL) v1.0**
- Project-added synthetic multilingual examples and labels: **MIT** (this repo)

## Quick Start

```bash
# Clone
git clone https://github.com/JamesEBall/HSClassify_micro.git
cd HSClassify_micro

# Install dependencies
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

# Generate training data & train model
python scripts/generate_training_data.py
python scripts/train_model.py

# Run the web app
uvicorn app:app --reload --port 8000
```

Open [http://localhost:8000](http://localhost:8000) to classify products.

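The running app also exposes the classifier as a JSON API. A minimal sketch of calling the `/predict` endpoint from Python (endpoint and payload fields as defined in `app.py`; the product and values below are made-up examples, and a local server on port 8000 is assumed):

```python
# Sketch only: query the local /predict endpoint and print the top suggestions.
import requests

resp = requests.post(
    "http://localhost:8000/predict",
    json={
        "text": "frozen shrimp 500g bag",   # product description (required)
        "made_in": "Vietnam",                # optional structured context
        "currency": "USD",
        "item_price": 4.20,
    },
    timeout=30,
)
resp.raise_for_status()
for p in resp.json()["predictions"]:
    print(p["hs_code"], round(p["confidence"], 3), p["description"])
```
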
## Deployment

- The Space runs in Docker (`sdk: docker`, `app_port: 7860`).
- OCR endpoints require OS packages; `Dockerfile` installs:
  - `tesseract-ocr`
  - `poppler-utils` (for PDF conversion via `pdf2image`)
- Model and data loading is resilient in hosted environments:
  - Large artifacts (model weights, embeddings, classifier, training data) are hosted on [HF Hub](https://huggingface.co/Mead0w1ark/multilingual-e5-small-hs-codes) and downloaded automatically at startup if not present locally
  - Set `HF_ARTIFACT_REPO` to override the artifact repo (default: `Mead0w1ark/multilingual-e5-small-hs-codes`) and `SENTENCE_MODEL_NAME` to override the fallback sentence model (default: `intfloat/multilingual-e5-small`)

### Auto Sync (GitHub -> Hugging Face Space)

This repo includes a GitHub Action at `.github/workflows/sync_to_hf_space.yml` that syncs `main` to:

- `spaces/Mead0w1ark/MicroHS`

Required GitHub secret:

- `HF_TOKEN`: Hugging Face token with write access to the Space

## Publish Dataset to Hugging Face Datasets

Use the included publish helper:

```bash
bash scripts/publish_dataset_to_hf.sh <namespace>/<dataset-repo>
# Example:
bash scripts/publish_dataset_to_hf.sh Troglobyte/hsclassify-micro-dataset
```

The script creates/updates a Dataset repo and uploads:

- `training_data_indexed.csv`
- `harmonized-system.csv` (attributed source snapshot)
- `hs_codes_reference.json`
- Dataset card + attribution notes

## Model

The classifier uses [`multilingual-e5-small`](https://huggingface.co/intfloat/multilingual-e5-small) fine-tuned with contrastive learning (MultipleNegativesRankingLoss) on 9,829 curated HS-coded training pairs. Fine-tuned weights are hosted on HF Hub at [`Mead0w1ark/multilingual-e5-small-hs-codes`](https://huggingface.co/Mead0w1ark/multilingual-e5-small-hs-codes).

| Metric | Before Fine-Tuning | After Fine-Tuning |
|---|---|---|
| Training accuracy (80/20 split) | 77.2% | **87.0%** |
| Benchmark Top-1 (in-label-space) | 88.6% | **92.9%** |
| Benchmark Top-3 (in-label-space) | — | **97.1%** |

To fine-tune from scratch:

```bash
python scripts/train_model.py --finetune
```

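The project's actual fine-tuning lives in `scripts/train_model.py` (not bundled in this Space sync). As a rough illustration of the contrastive setup named above, a MultipleNegativesRankingLoss run with `sentence-transformers` looks roughly like this; the pair construction, example texts, and hyperparameters are assumptions, not the repo's real settings:

```python
# Hedged sketch of contrastive fine-tuning with MultipleNegativesRankingLoss.
# Pairing scheme and hyperparameters are illustrative, not the project's own.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("intfloat/multilingual-e5-small")

# Each positive pair: a product description and the text of its HS code.
train_examples = [
    InputExample(texts=["query: frozen shrimp 500g bag",
                        "passage: Frozen shrimps and prawns (HS 030617)"]),
    InputExample(texts=["query: white rice 25kg bag",
                        "passage: Semi-milled or wholly milled rice (HS 100630)"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)

# In-batch negatives: every other pair in the batch acts as a negative.
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
model.save("models/sentence_model")
```
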
## How It Works

1. **Embedding**: Product descriptions are encoded using fine-tuned `multilingual-e5-small` (384-dim sentence embeddings)
2. **Classification**: K-nearest neighbors (k=5) over pre-computed embeddings of HS-coded training examples
3. **Visualization**: UMAP reduction to 2D for interactive cluster exploration via Plotly

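A condensed sketch of steps 1-2, mirroring the runtime logic in `app.py` (artifact paths assume the files produced by `scripts/train_model.py`; the query string is just an example):

```python
# Sketch: encode a query with the e5 "query:" prefix and rank HS codes by k-NN.
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-small")          # or the fine-tuned weights
classifier = pickle.load(open("models/knn_classifier.pkl", "rb"))      # cosine k-NN, k=5
label_encoder = pickle.load(open("models/label_encoder.pkl", "rb"))

query_emb = model.encode(["query: lithium ion battery pack 48V"],
                         normalize_embeddings=True, convert_to_numpy=True)
probs = classifier.predict_proba(query_emb)[0]
for i in np.argsort(probs)[-3:][::-1]:                                  # top-3 codes
    print(str(label_encoder.classes_[i]).zfill(6), round(float(probs[i]), 3))
```
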
## Project Structure

```
├── app.py                          # FastAPI web application
├── dataset/
│   ├── README.md                   # HF dataset card (attribution + schema)
│   └── ATTRIBUTION.md              # Source and license attribution details
├── requirements.txt                # Python dependencies
├── scripts/
│   ├── generate_training_data.py   # Synthetic training data generator
│   ├── train_model.py              # Model training (embeddings + KNN)
│   └── publish_dataset_to_hf.sh    # Publish dataset artifacts to HF Datasets
├── data/
│   ├── hs_codes_reference.json     # HS code definitions
│   ├── harmonized-system/harmonized-system.csv  # Upstream HS source snapshot
│   ├── training_data.csv           # Generated training examples
│   └── training_data_indexed.csv   # App/latent-ready training examples
├── models/                         # Trained artifacts (generated)
│   ├── sentence_model/             # Cached sentence transformer
│   ├── embeddings.npy              # Pre-computed embeddings
│   ├── knn_classifier.pkl          # Trained KNN model
│   └── label_encoder.pkl           # Label encoder
└── templates/
    └── index.html                  # Web UI
```

## Context

Built as a rapid POC exploring whether multilingual sentence embeddings can simplify HS code classification for customs authorities.

## License

MIT — see [LICENSE](LICENSE)

app.py
ADDED
@@ -0,0 +1,754 @@
"""
HS Code Classifier Web App

FastAPI backend with:
- Real-time HS code prediction from text input
- Document upload with OCR (Tesseract) support
- Structured field extraction from trade documents
- HS (6-digit) and HTS (7-10 digit) code support
- Top-5 suggestions with confidence scores
- Latent space visualization with UMAP
- Multilingual support (EN, TH, VI, ZH)
"""

import json
import os
import re
import shutil
import tempfile
import threading
import time
import pickle
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from fastapi import FastAPI, Request, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder

from field_extractor import extract_fields, get_all_countries, get_all_currencies
from hs_dataset import get_dataset, get_hts_extensions, get_available_hts_countries

# Paths
PROJECT_DIR = Path(__file__).parent
MODEL_DIR = PROJECT_DIR / "models"
DATA_DIR = PROJECT_DIR / "data"
UPLOAD_DIR = PROJECT_DIR / "uploads"
UPLOAD_DIR.mkdir(exist_ok=True)

# Upload config
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", ".pdf"}

# Initialize FastAPI
from starlette.middleware.gzip import GZipMiddleware

app = FastAPI(title="HS Code Classifier", version="2.0.0")
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.mount("/static", StaticFiles(directory=str(PROJECT_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(PROJECT_DIR / "templates"))

# Global model state
model = None
classifier = None
label_encoder = None
hs_reference = None
training_data = None
embeddings = None
umap_data = None
umap_ready = False
hs_dataset = None
classifier_training_indices = None


def _download_hf_artifacts():
    """Download large artifacts from HF Hub if not present locally."""
    from huggingface_hub import hf_hub_download
    repo_id = os.getenv("HF_ARTIFACT_REPO", "Mead0w1ark/multilingual-e5-small-hs-codes")

    file_map = {
        MODEL_DIR / "embeddings.npy": "embeddings.npy",
        MODEL_DIR / "knn_classifier.pkl": "knn_classifier.pkl",
        MODEL_DIR / "label_encoder.pkl": "label_encoder.pkl",
        MODEL_DIR / "metadata.json": "metadata.json",
        MODEL_DIR / "umap_data.json": "umap_data.json",
        DATA_DIR / "training_data.csv": "training_data.csv",
    }
    for local_path, repo_filename in file_map.items():
        if not local_path.exists():
            print(f"Downloading {repo_filename} from {repo_id}...")
            try:
                downloaded = hf_hub_download(
                    repo_id=repo_id, filename=repo_filename,
                )
                local_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(downloaded, local_path)
                print(f"  -> {local_path}")
            except Exception as e:
                print(f"  Warning: could not download {repo_filename}: {e}")


def load_models():
    """Load all model artifacts on startup."""
    global model, classifier, label_encoder, hs_reference, training_data, embeddings, umap_data, hs_dataset, classifier_training_indices

    print("Loading models...")
    start = time.time()

    # Download large artifacts from HF Hub if missing locally.
    _download_hf_artifacts()

    # Load sentence transformer:
    # prefer local bundled model, fall back to Hub model when large files are not in repo.
    local_model_dir = MODEL_DIR / "sentence_model"
    has_local_weights = (
        (local_model_dir / "model.safetensors").exists()
        or (local_model_dir / "pytorch_model.bin").exists()
    )
    has_local_tokenizer = (local_model_dir / "tokenizer.json").exists()

    if local_model_dir.exists() and has_local_weights and has_local_tokenizer:
        model = SentenceTransformer(str(local_model_dir))
        print("Loaded local sentence model from models/sentence_model")
    else:
        fallback_model = os.getenv(
            "SENTENCE_MODEL_NAME",
            "intfloat/multilingual-e5-small",
        )
        model = SentenceTransformer(fallback_model)
        print(f"Loaded sentence model from Hugging Face Hub: {fallback_model}")

    # Load HS code reference
    with open(DATA_DIR / "hs_codes_reference.json") as f:
        hs_reference = json.load(f)

    # Load training data
    training_data_path = DATA_DIR / "training_data_indexed.csv"
    if not training_data_path.exists():
        training_data_path = DATA_DIR / "training_data.csv"
    training_data = pd.read_csv(training_data_path)
    training_data["hs_code"] = training_data["hs_code"].astype(str).str.zfill(6)

    classifier_path = MODEL_DIR / "knn_classifier.pkl"
    label_encoder_path = MODEL_DIR / "label_encoder.pkl"
    embeddings_path = MODEL_DIR / "embeddings.npy"
    embeddings_part_paths = sorted(MODEL_DIR.glob("embeddings_part*.npy"))
    core_codes = {str(k).zfill(6) for k in hs_reference.keys()}
    artifacts_exist = (
        classifier_path.exists()
        and label_encoder_path.exists()
        and (embeddings_path.exists() or len(embeddings_part_paths) > 0)
    )

    def load_cached_embeddings():
        if embeddings_path.exists():
            return np.load(embeddings_path)
        part_paths = sorted(MODEL_DIR.glob("embeddings_part*.npy"))
        if part_paths:
            parts = [np.load(p) for p in part_paths]
            return np.concatenate(parts, axis=0)
        return None

    def compute_full_embeddings():
        texts = training_data["text"].fillna("").astype(str).tolist()
        if not texts:
            raise RuntimeError("No training rows available to rebuild classifier.")
        return model.encode(
            [f"passage: {text}" for text in texts],
            normalize_embeddings=True,
            convert_to_numpy=True,
        )

    def rebuild_classifier_on_curated_codes():
        global classifier, label_encoder, classifier_training_indices
        classifier_df = training_data[training_data["hs_code"].isin(core_codes)].copy()
        if classifier_df.empty:
            classifier_df = training_data

        clf_indices = classifier_df.index.to_numpy()
        clf_embeddings = embeddings[clf_indices]
        hs_labels = classifier_df["hs_code"].tolist()
        label_encoder = LabelEncoder()
        y = label_encoder.fit_transform(hs_labels)

        classifier = KNeighborsClassifier(
            n_neighbors=min(5, len(classifier_df)),
            metric="cosine",
            weights="distance",
        )
        classifier.fit(clf_embeddings, y)
        classifier_training_indices = clf_indices
        print(
            f"Rebuilt classifier on {len(classifier_df)} rows "
            f"across {len(set(hs_labels))} curated HS codes"
        )

        try:
            np.save(embeddings_path, embeddings)
            with open(classifier_path, "wb") as f:
                pickle.dump(classifier, f)
            with open(label_encoder_path, "wb") as f:
                pickle.dump(label_encoder, f)
            print("Saved rebuilt classifier artifacts to models/")
        except Exception as e:
            print(f"Warning: could not cache rebuilt artifacts: {e}")

    if artifacts_exist:
        with open(classifier_path, "rb") as f:
            classifier = pickle.load(f)
        with open(label_encoder_path, "rb") as f:
            label_encoder = pickle.load(f)
        embeddings = load_cached_embeddings()
        print("Loaded classifier artifacts from models/")

        if embeddings is None or len(embeddings) != len(training_data):
            print(
                f"Embeddings size mismatch (embeddings={len(embeddings) if embeddings is not None else 0}, "
                f"data={len(training_data)}). "
                "Recomputing embeddings..."
            )
            embeddings = compute_full_embeddings()

        artifact_codes = {str(c).zfill(6) for c in getattr(label_encoder, "classes_", [])}
        invalid_artifacts = (
            not artifact_codes
            or not artifact_codes.issubset(core_codes)
            or len(artifact_codes) > len(core_codes)
        )
        if invalid_artifacts:
            print("Classifier artifacts not aligned with curated HS set; rebuilding classifier...")
            rebuild_classifier_on_curated_codes()
        else:
            # Map KNN fit row indices back to full training_data row indices for latent neighbors.
            classifier_df = training_data[training_data["hs_code"].isin(artifact_codes)].copy()
            classifier_training_indices = classifier_df.index.to_numpy()
            n_fit = int(getattr(classifier, "n_samples_fit_", 0))
            if n_fit <= 0:
                fit_x = getattr(classifier, "_fit_X", None)
                n_fit = int(fit_x.shape[0]) if fit_x is not None else 0

            if n_fit > 0 and len(classifier_training_indices) == n_fit:
                print(f"Mapped classifier indices to {len(classifier_training_indices)} training rows")
            else:
                print(
                    "Classifier index mapping mismatch "
                    f"(mapped={len(classifier_training_indices)}, fit={n_fit}); rebuilding classifier..."
                )
                rebuild_classifier_on_curated_codes()
    else:
        print("Classifier artifacts missing; rebuilding from training data...")
        embeddings = compute_full_embeddings()
        rebuild_classifier_on_curated_codes()

    # Load HS dataset (official harmonized-system data)
    hs_dataset = get_dataset()

    # UMAP data is loaded/computed in a background thread so the server
    # can start immediately and pass the HF Space health check.
    umap_data = []

    elapsed = time.time() - start
    print(f"All models loaded in {elapsed:.1f}s")


def _compute_umap_background():
    """Load UMAP data from cache or compute in background.

    Sets the global ``umap_data`` list and ``umap_ready`` flag when done.
    """
    global umap_data, umap_ready

    cache_path = MODEL_DIR / "umap_data.json"
    if cache_path.exists():
        try:
            with open(cache_path, encoding="utf-8") as f:
                cached = json.load(f)
            has_category_fields = (
                isinstance(cached, list)
                and len(cached) > 0
                and "chapter_name" in cached[0]
            )
            if isinstance(cached, list) and len(cached) == len(training_data) and has_category_fields:
                umap_data = cached
                umap_ready = True
                print(f"Loaded cached UMAP data: {len(umap_data)} points")
                return
            else:
                print(
                    f"Cached UMAP size mismatch (cache={len(cached)}, data={len(training_data)}). "
                    "Recomputing UMAP projection..."
                )
        except Exception as e:
            print(f"Warning: could not read UMAP cache: {e}")

    print("Computing UMAP projection (background)...")
    try:
        import umap
        reducer = umap.UMAP(
            n_neighbors=30,
            min_dist=0.0,
            n_components=2,
            metric='cosine',
            random_state=42,
        )
        umap_coords = reducer.fit_transform(embeddings)

        points = []
        for i, row in training_data.iterrows():
            hs_code = str(row["hs_code"]).zfill(6)
            chapter = row["hs_chapter"]
            chapter_name = str(row.get("hs_chapter_name", "")).strip()
            if not chapter_name or re.match(r"^HS\s\d{2}$", chapter_name):
                chapter_name = str(chapter).split(";")[0].strip()
            desc = hs_reference.get(hs_code, {}).get("desc", "Unknown")
            points.append({
                "x": float(umap_coords[i, 0]),
                "y": float(umap_coords[i, 1]),
                "text": row["text"][:80],
                "hs_code": hs_code,
                "chapter": chapter,
                "chapter_name": chapter_name,
                "hs_desc": desc,
                "language": row["language"],
            })

        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(points, f, ensure_ascii=False)

        umap_data = points
        umap_ready = True
        print(f"UMAP projection computed for {len(umap_data)} points")
    except Exception as e:
        print(f"UMAP computation failed: {e}")
        umap_ready = True  # mark ready so endpoints stop saying "computing"


@app.on_event("startup")
async def startup():
    load_models()
    threading.Thread(target=_compute_umap_background, daemon=True).start()


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Main page."""
    metadata = {}
    try:
        with open(MODEL_DIR / "metadata.json") as f:
            metadata = json.load(f)
    except:
        pass

    countries = get_all_countries()
    currencies = get_all_currencies()
    hts_countries = get_available_hts_countries()

    return templates.TemplateResponse("index.html", {
        "request": request,
        "metadata": metadata,
        "countries": countries,
        "currencies": currencies,
        "hts_countries": hts_countries,
    })


@app.post("/predict")
async def predict(request: Request):
    """Predict HS code for a product description with optional structured context."""
    body = await request.json()
    query_text = body.get("text", "").strip()
    made_in = body.get("made_in", "")
    ship_to = body.get("ship_to", "")
    item_price = body.get("item_price", None)
    currency = body.get("currency", "")

    if not query_text:
        return JSONResponse({"error": "No text provided"}, status_code=400)

    start = time.time()

    # Build enriched query using structured fields
    enriched_query = query_text
    context_parts = []
    if made_in:
        context_parts.append(f"origin: {made_in}")
    if ship_to:
        context_parts.append(f"destination: {ship_to}")
    if item_price and currency:
        context_parts.append(f"value: {currency} {item_price}")

    if context_parts:
        enriched_query = f"{query_text} ({', '.join(context_parts)})"

    # Encode query with e5 prefix
    query_emb = model.encode(
        [f"query: {enriched_query}"],
        normalize_embeddings=True,
        convert_to_numpy=True
    )

    # Get predictions with probabilities
    probs = classifier.predict_proba(query_emb)[0]
    top_k = 5
    top_indices = np.argsort(probs)[-top_k:][::-1]

    predictions = []
    for idx in top_indices:
        hs_code = label_encoder.classes_[idx]
        hs_code_padded = str(hs_code).zfill(6)
        confidence = float(probs[idx])
        if confidence < 0.01:
            continue

        info = hs_reference.get(hs_code_padded, {})
        chapter_code = hs_code_padded[:2]
        heading_code = hs_code_padded[:4]

        # Get official description from HS dataset if available
        official = hs_dataset.lookup(hs_code_padded) if hs_dataset else None
        official_desc = official['description'] if official else None

        # Validate against official dataset
        validation = hs_dataset.validate_hs_code(hs_code_padded) if hs_dataset else None

        predictions.append({
            "hs_code": hs_code_padded,
            "confidence": confidence,
            "description": info.get("desc", official_desc or "No description available"),
            "official_description": official_desc,
            "chapter": info.get("chapter", "Unknown"),
            "chapter_code": chapter_code,
            "heading_code": heading_code,
            "validated": validation['valid'] if validation else None,
        })

    # Find nearest training examples
    sims = embeddings @ query_emb.T
    top_sim_idx = np.argsort(sims.flatten())[-3:][::-1]
    similar_examples = []
    for idx in top_sim_idx:
        if idx < len(training_data):
            similar_examples.append({
                "text": training_data.iloc[idx]["text"],
                "hs_code": str(training_data.iloc[idx]["hs_code"]).zfill(6),
                "similarity": float(sims[idx][0]),
            })

    elapsed = time.time() - start

    return JSONResponse({
        "query": query_text,
        "enriched_query": enriched_query,
        "predictions": predictions,
        "similar_examples": similar_examples,
        "inference_time_ms": round(elapsed * 1000, 1),
    })


@app.post("/upload-document")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document (image/PDF) and extract text via OCR + structured fields."""
    # Validate file
    if not file.filename:
        return JSONResponse({"error": "No file provided"}, status_code=400)

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        return JSONResponse(
            {"error": f"Unsupported file type: {ext}. Allowed: {', '.join(ALLOWED_EXTENSIONS)}"},
            status_code=400
        )

    # Read file content
    content = await file.read()
    if len(content) > MAX_FILE_SIZE:
        return JSONResponse(
            {"error": f"File too large. Maximum: {MAX_FILE_SIZE // (1024*1024)}MB"},
            status_code=400
        )

    # Save to temp file
    file_id = str(uuid.uuid4())[:8]
    temp_path = UPLOAD_DIR / f"{file_id}{ext}"
    with open(temp_path, "wb") as f:
        f.write(content)

    try:
        import pytesseract
        from PIL import Image

        ocr_text = ""

        if ext == ".pdf":
            # Convert PDF to images, then OCR
            try:
                from pdf2image import convert_from_path
                images = convert_from_path(str(temp_path), dpi=300)
                texts = []
                for img in images:
                    texts.append(pytesseract.image_to_string(img))
                ocr_text = "\n\n".join(texts)
            except ImportError:
                return JSONResponse(
                    {"error": "PDF support requires pdf2image and poppler. Install with: pip install pdf2image"},
                    status_code=500
                )
            except Exception as e:
                return JSONResponse(
                    {"error": f"PDF processing error: {str(e)}"},
                    status_code=500
                )
        else:
            # Image OCR
            img = Image.open(temp_path)
            ocr_text = pytesseract.image_to_string(img)

        if not ocr_text.strip():
            return JSONResponse({
                "error": "OCR could not extract any text from this document. Please try a clearer image.",
                "raw_text": "",
                "fields": {},
            })

        # Extract structured fields
        fields = extract_fields(ocr_text)

        return JSONResponse({
            "success": True,
            "file_id": file_id,
            "filename": file.filename,
            "raw_text": ocr_text.strip(),
            "fields": fields,
        })

    except Exception as e:
        return JSONResponse(
            {"error": f"OCR processing failed: {str(e)}"},
            status_code=500
        )
    finally:
        # Clean up temp file
        if temp_path.exists():
            temp_path.unlink()


@app.post("/extract-fields")
async def extract_fields_endpoint(request: Request):
    """Extract structured fields from arbitrary text (no OCR needed)."""
    body = await request.json()
    text = body.get("text", "").strip()

    if not text:
        return JSONResponse({"error": "No text provided"}, status_code=400)

    fields = extract_fields(text)
    return JSONResponse({"fields": fields})


@app.get("/hts-extensions/{hs_code}")
async def get_hts(hs_code: str, country: str = "US"):
    """Get HTS (country-specific) extensions for a 6-digit HS code."""
    result = get_hts_extensions(hs_code, country)
    return JSONResponse(result)


@app.get("/hs-lookup/{hs_code}")
async def hs_lookup(hs_code: str):
    """Look up an HS code in the official dataset."""
    if not hs_dataset:
        return JSONResponse({"error": "HS dataset not loaded"}, status_code=500)

    result = hs_dataset.lookup(hs_code)
    if not result:
        # Try search instead
        search_results = hs_dataset.search(hs_code, max_results=5)
        return JSONResponse({
            "found": False,
            "message": f"Code {hs_code} not found. Did you mean one of these?",
            "suggestions": search_results,
        })

    return JSONResponse({"found": True, **result})


@app.get("/hs-search")
async def hs_search(q: str = "", limit: int = 20):
    """Search HS codes by description."""
    if not q:
        return JSONResponse({"error": "No query provided"}, status_code=400)

    results = hs_dataset.search(q, max_results=limit)
    return JSONResponse({"results": results, "query": q})


@app.get("/hs-validate/{hs_code}")
async def hs_validate(hs_code: str):
    """Validate whether an HS code exists."""
    result = hs_dataset.validate_hs_code(hs_code)
    return JSONResponse(result)


@app.get("/hts-countries")
async def hts_countries():
    """Get list of countries with HTS extensions available."""
    return JSONResponse({"countries": get_available_hts_countries()})


@app.get("/visualization-data")
async def get_visualization_data(request: Request):
    """Return UMAP projection data for visualization.

    Supports ``?max_points=N`` to subsample for faster initial load.
    The subsample is stratified by chapter so every category is represented.
    """
    max_points = int(request.query_params.get("max_points", "0"))

    points = umap_data
    if not points:
        cache_path = MODEL_DIR / "umap_data.json"
        if cache_path.exists():
            with open(cache_path, encoding="utf-8") as f:
                points = json.load(f)

    if not points:
        if not umap_ready:
            return JSONResponse({"points": [], "computing": True})
        return JSONResponse({"points": [], "error": "No UMAP data available"})

    total = len(points)
    if 0 < max_points < total:
        # Stratified subsample: keep proportional representation per chapter
        import random as _rng
        _rng.seed(42)
        by_chapter: dict[str, list] = {}
        for p in points:
            by_chapter.setdefault(p.get("chapter_name", "Other"), []).append(p)
        sampled: list = []
        for ch, ch_pts in by_chapter.items():
            n = max(1, round(len(ch_pts) / total * max_points))
            sampled.extend(_rng.sample(ch_pts, min(n, len(ch_pts))))
        _rng.shuffle(sampled)
        return JSONResponse({"points": sampled, "total": total, "sampled": True})

    return JSONResponse({"points": points, "total": total})


@app.get("/visualization-density")
async def get_visualization_density():
    """All UMAP points in compact columnar format for density/labels."""
    points = umap_data or []
    if not points:
        cache_path = MODEL_DIR / "umap_data.json"
        if cache_path.exists():
            with open(cache_path, encoding="utf-8") as f:
                points = json.load(f)
    if not points:
        if not umap_ready:
            return JSONResponse({"chapters": {}, "computing": True})
        return JSONResponse({"error": "No data"})

    by_chapter: dict[str, dict[str, list]] = {}
    for p in points:
        ch = p.get("chapter_name", "Unknown")
        if ch not in by_chapter:
            by_chapter[ch] = {"x": [], "y": []}
        by_chapter[ch]["x"].append(round(p["x"], 3))
        by_chapter[ch]["y"].append(round(p["y"], 3))

    return JSONResponse({"chapters": by_chapter})


@app.post("/embed-query")
async def embed_query(request: Request):
    """Get UMAP coordinates for a query."""
    body = await request.json()
    query_text = body.get("text", "").strip()

    if not query_text:
        return JSONResponse({"error": "No text provided"}, status_code=400)

    query_emb = model.encode(
        [f"query: {query_text}"],
        normalize_embeddings=True,
        convert_to_numpy=True
    )

    n_fit = int(getattr(classifier, "n_samples_fit_", 0))
    if n_fit <= 0:
        fit_x = getattr(classifier, "_fit_X", None)
        n_fit = int(fit_x.shape[0]) if fit_x is not None else 0
    if n_fit <= 0:
        return JSONResponse({"error": "Classifier has no fitted rows"}, status_code=500)

    n_neighbors = min(5, n_fit)
    distances, indices = classifier.kneighbors(query_emb, n_neighbors=n_neighbors)

    if umap_data and len(umap_data) > 0:
        weights = 1.0 / (distances[0] + 1e-6)
        weights = weights / weights.sum()

        mapped_indices = []
        for idx in indices[0]:
            mapped_idx = int(idx)
            if (
                classifier_training_indices is not None
                and mapped_idx < len(classifier_training_indices)
            ):
                mapped_idx = int(classifier_training_indices[mapped_idx])
            mapped_indices.append(mapped_idx)

        x = sum(
            umap_data[idx]["x"] * w
            for idx, w in zip(mapped_indices, weights)
            if 0 <= idx < len(umap_data)
        )
        y = sum(
            umap_data[idx]["y"] * w
            for idx, w in zip(mapped_indices, weights)
            if 0 <= idx < len(umap_data)
        )

        neighbors = []
        for idx, dist in zip(mapped_indices, distances[0]):
            if idx < len(umap_data):
                point = umap_data[idx]
                # cosine distance in [0, 2] for normalized vectors; lower is closer
                similarity = max(0.0, min(1.0, 1.0 - float(dist)))
                neighbors.append({
                    **point,
                    "distance": float(dist),
                    "similarity": similarity,
                })

        return JSONResponse({
            "x": float(x),
            "y": float(y),
            "neighbors": neighbors,
        })

    if not umap_ready:
        return JSONResponse({"error": "UMAP data is still computing", "computing": True})
    return JSONResponse({"error": "No UMAP data for projection"})


@app.get("/health")
async def health():
    """Health check."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "hs_dataset_loaded": hs_dataset._loaded if hs_dataset else False,
        "hs_codes_count": len(hs_dataset.subheadings) if hs_dataset else 0,
        "umap_ready": umap_ready,
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

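For reference, a minimal client call for the `/upload-document` OCR endpoint defined in `app.py` above might look like this (a sketch only; it assumes a locally running server and uses the bundled sample image):

```python
# Sketch: send a sample invoice image to /upload-document and print extracted fields.
import requests

with open("data/sample_documents/invoice_en.png", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/upload-document",
        files={"file": ("invoice_en.png", fh, "image/png")},  # field name matches UploadFile param
        timeout=60,
    )
print(resp.json().get("fields"))
```
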
data/benchmark_cases.csv
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text,expected_hs_code,category,language,notes
fresh boneless beef,020130,easy,en,common meat product
frozen boneless bovine meat for export,020230,easy,en,frozen meat variant
frozen shrimp 500g bag,030617,easy,en,common seafood
whole milk 3.5% fat,040120,easy,en,standard dairy
cheddar cheese block,040690,easy,en,common cheese
fresh tomatoes,070200,easy,en,basic vegetable
fresh red apples,080810,easy,en,common fruit
bananas fresh,080300,easy,en,top traded fruit
raw coffee beans unroasted,090111,easy,en,major commodity
white rice 25kg bag,100630,easy,en,staple grain
palm oil refined,151190,easy,en,major edible oil
cane sugar raw,170199,easy,en,basic commodity
sweet biscuits assorted,190531,easy,en,packaged food
bottled sparkling water flavored,220210,easy,en,common beverage
beer lager 330ml bottles,220300,easy,en,common alcohol
scotch whisky 700ml,220830,easy,en,spirits
crude petroleum oil,270900,easy,en,major commodity
polyethylene pellets LDPE,390110,easy,en,common plastic resin
car tyre 205/55R16 new,401110,easy,en,auto consumable
cotton t-shirts mens,610910,easy,en,basic garment
hot rolled steel coil 600mm,720839,easy,en,industrial steel
copper cathodes 99.99% purity,740311,easy,en,refined metal
laptop computer 14 inch,847130,easy,en,common electronics
smartphone Samsung Galaxy,851712,easy,en,ubiquitous device
lithium ion battery pack 48V,850760,easy,en,EV battery
sedan car 2000cc petrol engine,870323,easy,en,standard vehicle
wooden bedroom wardrobe,940350,easy,en,common furniture
tea,090210,edge_case,en,very short query - ambiguous
car parts,870899,edge_case,en,vague automotive
medicine,300490,edge_case,en,extremely generic
chips,854231,edge_case,en,ambiguous - food or electronics
oil,270900,edge_case,en,highly ambiguous
shoes,640399,edge_case,en,generic footwear
paper,480256,edge_case,en,very generic
plastic bags for groceries,392321,edge_case,en,everyday item
Galaxy S24 Ultra,851712,edge_case,en,brand name only
Nespresso coffee capsules,210111,edge_case,en,branded coffee product
Goodyear truck tyre 315/80R22.5,401120,edge_case,en,brand + specs
Jack Daniels Tennessee whiskey 750ml,220830,edge_case,en,brand name spirits
Nintendo Switch gaming console,950490,edge_case,en,brand - games vs electronics
surgical masks disposable,901890,edge_case,en,medical supply
USB-C charging cable,854239,edge_case,en,tech accessory
frozen cod fish fillet,030389,edge_case,en,specific fish species
stainless steel bolts M10,730890,edge_case,en,metal hardware
yoga pants women polyester,620462,edge_case,en,modern clothing description
aspirin tablets 500mg retail,300490,edge_case,en,OTC pharma
insecticide spray for mosquitoes,380891,edge_case,en,household chemical
PET bottles preform,390760,edge_case,en,industrial plastic
ข้าวหอมมะลิ,100630,multilingual,th,jasmine rice
กุ้งแช่แข็ง,030617,multilingual,th,frozen shrimp
รถยนต์ไฟฟ้า,870380,multilingual,th,electric car
โทรศัพท์มือถือ,851712,multilingual,th,mobile phone
ยางรถยนต์,401110,multilingual,th,car tyre
น้ำตาลทราย,170199,multilingual,th,granulated sugar
เสื้อยืดผ้าฝ้าย,610910,multilingual,th,cotton t-shirt
gạo trắng,100630,multilingual,vi,white rice
tôm đông lạnh,030617,multilingual,vi,frozen shrimp
cà phê nhân,090111,multilingual,vi,raw coffee beans
thép cuộn cán nóng,720839,multilingual,vi,hot rolled steel coil
điện thoại thông minh,851712,multilingual,vi,smartphone
xe ô tô điện,870380,multilingual,vi,electric car
dầu thô,270900,multilingual,vi,crude oil
笔记本电脑,847130,multilingual,zh,laptop computer
冷冻虾,030617,multilingual,zh,frozen shrimp
大米,100630,multilingual,zh,rice
锂电池,850760,multilingual,zh,lithium battery
棉质T恤,610910,multilingual,zh,cotton t-shirt
原油,270900,multilingual,zh,crude oil
English breakfast tea,090240,known_failure,en,black tea - code 090240 not in label space
matcha green tea powder,090210,known_failure,en,tea variant - model often confuses with other categories
oolong tea leaves 100g,090230,known_failure,en,semi-fermented tea - code 090230 not in label space
chamomile herbal tea bags,121190,known_failure,en,herbal infusion - not tea chapter - not in label space
fresh avocado,080440,known_failure,en,avocado code 080440 not in label space
quinoa grain organic,100850,known_failure,en,quinoa code 100850 not in label space
soy sauce 500ml bottle,210390,known_failure,en,soy sauce code 210390 not in label space
hand sanitizer gel 70% alcohol,380894,known_failure,en,sanitizer code 380894 not in label space
drone with 4K camera,880211,known_failure,en,UAV code not in label space
solar panel 400W monocrystalline,854140,known_failure,en,maps to photosensitive devices - often misclassified
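The benchmark is bucketed into `easy`, `edge_case`, `multilingual`, and `known_failure` cases, so per-bucket accuracy is the natural way to read it. A minimal scoring sketch, assuming pandas (already pinned in `requirements.txt`) and a hypothetical `classify_fn` standing in for the project's actual prediction call:

```python
# Illustrative scoring sketch for data/benchmark_cases.csv.
# `classify_fn` is a hypothetical stand-in: any callable that maps a product
# description string to a 6-digit HS code string.
import pandas as pd

def evaluate(classify_fn, path: str = "data/benchmark_cases.csv") -> pd.Series:
    # Keep HS codes as strings so leading zeros (e.g. 020130) survive.
    df = pd.read_csv(path, dtype={"expected_hs_code": str})
    df["predicted"] = df["text"].apply(classify_fn)
    df["correct"] = df["predicted"] == df["expected_hs_code"]
    # Mean accuracy per bucket: easy / edge_case / multilingual / known_failure
    return df.groupby("category")["correct"].mean()
```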
data/harmonized-system/harmonized-system.csv
ADDED
The diff for this file is too large to render.
data/hs_codes_reference.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f917717590bf6f457bbefb65b2827aff3a11d029988dae3e1d315a4dbb54134
size 10516
data/hts/us_hts_lookup.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f3fa20f7b7a0573536ad33fb7ba795b43ecc4e5d5ef544b2d605716291a44ad
size 6988357
data/sample_documents/customs_zh.png
ADDED
data/sample_documents/invoice_en.png
ADDED
data/sample_documents/invoice_vi.png
ADDED
data/sample_documents/packing_list_th.png
ADDED
dataset/ATTRIBUTION.md
ADDED
@@ -0,0 +1,19 @@
# Attribution

This dataset includes and builds on HS nomenclature content from:

- `datasets/harmonized-system`:
  - <https://github.com/datasets/harmonized-system>
  - Upstream references: WCO HS nomenclature materials and UN Comtrade API

## Licensing

- Upstream HS source data: **ODC Public Domain Dedication and License (PDDL) v1.0**
  - <https://opendatacommons.org/licenses/pddl/1-0/>
- Project-generated synthetic text examples and label normalization:
  - **MIT License** (see this repository's `LICENSE`)

## Notes

- HS codes are international nomenclature identifiers and may vary by country-level tariff schedules/extensions.
- Use this dataset for prototyping and research; verify classifications in official customs workflows.
dataset/README.md
ADDED
@@ -0,0 +1,66 @@
---
pretty_name: HSClassify Micro Training Dataset
license: pddl
language:
- en
- th
- vi
- zh
task_categories:
- text-classification
task_ids:
- multi-class-classification
size_categories:
- 10K<n<100K
configs:
- config_name: default
  data_files:
  - split: train
    path: training_data_indexed.csv
---

# Dataset Card for HSClassify Micro Training Dataset

## Dataset Summary

This dataset supports multilingual HS code classification for customs and trade workflows.
It combines:

- HS nomenclature records (6-digit level and hierarchy context)
- Synthetic product descriptions mapped to HS codes
- Human-readable chapter/category labels for UI and latent-space analysis

## Included Files

- `training_data_indexed.csv`: training rows with text, HS code, chapter metadata, and language.
- `harmonized-system.csv`: source HS table snapshot used for data generation and indexing.
- `hs_codes_reference.json`: curated HS reference used by the app and training pipeline.
- `ATTRIBUTION.md`: explicit source and license attribution.

## Data Fields (`training_data_indexed.csv`)

- `text`: product description text used for embedding/classification.
- `hs_code`: 6-digit HS code target.
- `hs_chapter`: chapter description text.
- `hs_chapter_code`: chapter ID (e.g., `HS 08`).
- `hs_chapter_name`: normalized human-readable category label.
- `hs_desc`: HS description aligned to `hs_code`.
- `language`: language code (`en`, `th`, `vi`, `zh`).

## Source Attribution

Core HS nomenclature content is sourced from the `datasets/harmonized-system` project:

- Repository: <https://github.com/datasets/harmonized-system>
- Declared source chain in upstream metadata:
  - WCO HS nomenclature documentation
  - UN Comtrade data extraction API
- Upstream data license: ODC Public Domain Dedication and License (PDDL) v1.0

Project-added synthetic texts and normalized labels are released under this project's MIT license.

## Limitations

- Language balance is intentionally skewed toward English in the current snapshot.
- Synthetic text patterns may not cover all commercial phrasing edge cases.
- This dataset is for research/prototyping and is not legal customs advice.
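For quick inspection, a minimal sketch of reading the train split declared in the card's YAML header, assuming the `datasets` package from `requirements-dev.txt` and a local copy of `training_data_indexed.csv` (the published Hub repo id is not assumed here):

```python
# Sketch only: load the CSV split locally with the `datasets` library.
# Adjust the path to wherever training_data_indexed.csv lives in your checkout.
from datasets import load_dataset

ds = load_dataset(
    "csv",
    data_files={"train": "training_data_indexed.csv"},
)["train"]

print(ds.column_names)                  # text, hs_code, hs_chapter, hs_chapter_code, ...
print(ds[0]["text"], ds[0]["hs_code"])  # one training row
```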
field_extractor.py
ADDED
@@ -0,0 +1,358 @@
"""
Structured field extraction from OCR text of trade documents.

Extracts:
- Made in (country of origin)
- Ship to (destination country)
- Item price (numeric value)
- Currency (USD, EUR, etc.)
- Product description
- Email addresses
"""

import re
from typing import Optional

# --- Country Matching ---

COUNTRIES = {
    # Major trading nations
    "china": "CN", "peoples republic of china": "CN", "prc": "CN", "中国": "CN",
    "united states": "US", "usa": "US", "u.s.a.": "US", "united states of america": "US",
    "japan": "JP", "日本": "JP",
    "germany": "DE", "deutschland": "DE",
    "united kingdom": "GB", "uk": "GB", "great britain": "GB", "england": "GB",
    "france": "FR",
    "italy": "IT", "italia": "IT",
    "south korea": "KR", "korea": "KR", "republic of korea": "KR", "한국": "KR",
    "india": "IN",
    "canada": "CA",
    "australia": "AU",
    "brazil": "BR",
    "mexico": "MX",
    "indonesia": "ID",
    "thailand": "TH", "ไทย": "TH",
    "vietnam": "VN", "viet nam": "VN", "việt nam": "VN",
    "malaysia": "MY",
    "singapore": "SG",
    "taiwan": "TW", "chinese taipei": "TW",
    "netherlands": "NL", "holland": "NL",
    "spain": "ES", "españa": "ES",
    "turkey": "TR", "türkiye": "TR",
    "switzerland": "CH",
    "saudi arabia": "SA",
    "united arab emirates": "AE", "uae": "AE",
    "poland": "PL",
    "sweden": "SE",
    "belgium": "BE",
    "argentina": "AR",
    "austria": "AT",
    "norway": "NO",
    "ireland": "IE",
    "israel": "IL",
    "denmark": "DK",
    "philippines": "PH",
    "colombia": "CO",
    "pakistan": "PK",
    "chile": "CL",
    "finland": "FI",
    "bangladesh": "BD",
    "egypt": "EG",
    "czech republic": "CZ", "czechia": "CZ",
    "portugal": "PT",
    "romania": "RO",
    "new zealand": "NZ",
    "greece": "GR",
    "peru": "PE",
    "south africa": "ZA",
    "hungary": "HU",
    "sri lanka": "LK",
    "cambodia": "KH",
    "myanmar": "MM", "burma": "MM",
    "nigeria": "NG",
    "kenya": "KE",
    "ghana": "GH",
    "ethiopia": "ET",
    "tanzania": "TZ",
    "morocco": "MA",
    "hong kong": "HK",
}

# Reverse map: code -> name
COUNTRY_CODE_TO_NAME = {}
for name, code in COUNTRIES.items():
    if code not in COUNTRY_CODE_TO_NAME:
        COUNTRY_CODE_TO_NAME[code] = name.title()

# Fix some names
COUNTRY_CODE_TO_NAME["US"] = "United States"
COUNTRY_CODE_TO_NAME["GB"] = "United Kingdom"
COUNTRY_CODE_TO_NAME["CN"] = "China"
COUNTRY_CODE_TO_NAME["KR"] = "South Korea"
COUNTRY_CODE_TO_NAME["AE"] = "United Arab Emirates"
COUNTRY_CODE_TO_NAME["NZ"] = "New Zealand"
COUNTRY_CODE_TO_NAME["ZA"] = "South Africa"
COUNTRY_CODE_TO_NAME["CZ"] = "Czech Republic"
COUNTRY_CODE_TO_NAME["HK"] = "Hong Kong"
COUNTRY_CODE_TO_NAME["TW"] = "Taiwan"
COUNTRY_CODE_TO_NAME["SA"] = "Saudi Arabia"
COUNTRY_CODE_TO_NAME["NL"] = "Netherlands"

# All country names for dropdown
ALL_COUNTRIES = sorted(set(COUNTRY_CODE_TO_NAME.values()))

# --- Currency Matching ---

CURRENCIES = {
    "USD": "US Dollar",
    "EUR": "Euro",
    "GBP": "British Pound",
    "JPY": "Japanese Yen",
    "CNY": "Chinese Yuan",
    "RMB": "Chinese Yuan",
    "KRW": "Korean Won",
    "THB": "Thai Baht",
    "VND": "Vietnamese Dong",
    "INR": "Indian Rupee",
    "CAD": "Canadian Dollar",
    "AUD": "Australian Dollar",
    "SGD": "Singapore Dollar",
    "MYR": "Malaysian Ringgit",
    "IDR": "Indonesian Rupiah",
    "PHP": "Philippine Peso",
    "BRL": "Brazilian Real",
    "MXN": "Mexican Peso",
    "CHF": "Swiss Franc",
    "SEK": "Swedish Krona",
    "NOK": "Norwegian Krone",
    "DKK": "Danish Krone",
    "HKD": "Hong Kong Dollar",
    "TWD": "Taiwan Dollar",
    "AED": "UAE Dirham",
    "SAR": "Saudi Riyal",
    "ZAR": "South African Rand",
    "NZD": "New Zealand Dollar",
    "TRY": "Turkish Lira",
    "PLN": "Polish Zloty",
}

CURRENCY_SYMBOLS = {
    "$": "USD",
    "€": "EUR",
    "£": "GBP",
    "¥": "JPY",
    "₹": "INR",
    "฿": "THB",
    "₫": "VND",
    "₩": "KRW",
    "R$": "BRL",
}


def find_country(text: str, context_keywords: list[str]) -> Optional[str]:
    """Find a country name near context keywords in the text."""
    text_lower = text.lower()

    # Try to find country near context keywords
    for keyword in context_keywords:
        # Search for keyword in text
        pattern = re.compile(
            rf'{keyword}\s*[:\-]?\s*(.{{2,50}})',
            re.IGNORECASE
        )
        match = pattern.search(text)
        if match:
            fragment = match.group(1).strip().lower()
            # Check if any country name is in the fragment
            for country_name, code in sorted(COUNTRIES.items(), key=lambda x: -len(x[0])):
                if country_name in fragment:
                    return code
            # Also check for ISO country codes (2 letters)
            code_match = re.match(r'^([A-Z]{2})\b', match.group(1).strip())
            if code_match:
                c = code_match.group(1)
                if c in COUNTRY_CODE_TO_NAME:
                    return c

    return None


def extract_fields(ocr_text: str) -> dict:
    """
    Extract structured fields from OCR text of a trade document.

    Returns dict with:
    - email: str or None
    - made_in: country code or None
    - ship_to: country code or None
    - item_price: float or None
    - currency: currency code or None
    - product_description: str or None
    - raw_text: the original OCR text
    - confidence: dict with confidence scores for each field
    """
    result = {
        "email": None,
        "made_in": None,
        "made_in_name": None,
        "ship_to": None,
        "ship_to_name": None,
        "item_price": None,
        "currency": None,
        "product_description": None,
        "raw_text": ocr_text,
        "confidence": {},
    }

    if not ocr_text or not ocr_text.strip():
        return result

    text = ocr_text.strip()

    # --- Extract Email ---
    email_pattern = re.compile(
        r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
        re.IGNORECASE
    )
    email_match = email_pattern.search(text)
    if email_match:
        result["email"] = email_match.group(0)
        result["confidence"]["email"] = 0.95

    # --- Extract Country of Origin (Made in) ---
    origin_keywords = [
        "made in", "manufactured in", "produced in", "origin",
        "country of origin", "country of manufacture",
        "mfg country", "mfg. country", "fabricated in",
        "assembled in", "place of origin", "product of",
        "sourced from", "shipped from", "exporting country",
        "from"
    ]

    origin_code = find_country(text, origin_keywords)
    if origin_code:
        result["made_in"] = origin_code
        result["made_in_name"] = COUNTRY_CODE_TO_NAME.get(origin_code, origin_code)
        result["confidence"]["made_in"] = 0.85

    # --- Extract Destination (Ship to) ---
    dest_keywords = [
        "ship to", "shipped to", "deliver to", "delivery to",
        "destination", "consignee", "import to", "importing country",
        "port of discharge", "port of destination", "final destination",
        "to country", "dest", "buyer country",
        "bill to", "sold to"
    ]

    dest_code = find_country(text, dest_keywords)
    if dest_code:
        result["ship_to"] = dest_code
        result["ship_to_name"] = COUNTRY_CODE_TO_NAME.get(dest_code, dest_code)
        result["confidence"]["ship_to"] = 0.80

    # --- Extract Currency ---
    # First check for currency symbols
    for symbol, curr_code in sorted(CURRENCY_SYMBOLS.items(), key=lambda x: -len(x[0])):
        if symbol in text:
            result["currency"] = curr_code
            result["confidence"]["currency"] = 0.90
            break

    # Then check for explicit currency codes
    if not result["currency"]:
        for curr_code in CURRENCIES:
            pattern = re.compile(rf'\b{curr_code}\b', re.IGNORECASE)
            if pattern.search(text):
                result["currency"] = curr_code
                result["confidence"]["currency"] = 0.95
                break

    # --- Extract Price ---
    price_patterns = [
        # "price: $123.45" or "amount: 123.45 USD"
        re.compile(
            r'(?:price|amount|total|value|unit price|item price|cost|fob value|cif value|invoice value)\s*[:\-]?\s*'
            r'(?:[A-Z]{3}\s*)?'
            r'[\$€£¥₹฿₫₩]?\s*'
            r'([\d,]+\.?\d*)',
            re.IGNORECASE
        ),
        # "$123.45" or "€99.99"
        re.compile(
            r'[\$€£¥₹฿₫₩]\s*([\d,]+\.?\d*)'
        ),
        # "123.45 USD" or "99.99 EUR"
        re.compile(
            r'([\d,]+\.?\d*)\s*(?:USD|EUR|GBP|JPY|CNY|RMB|THB|VND|INR|CAD|AUD|SGD|MYR)',
            re.IGNORECASE
        ),
    ]

    for pattern in price_patterns:
        match = pattern.search(text)
        if match:
            price_str = match.group(1).replace(",", "")
            try:
                price = float(price_str)
                if 0 < price < 1e12:  # Sanity check
                    result["item_price"] = price
                    result["confidence"]["item_price"] = 0.80
                    break
            except ValueError:
                continue

    # --- Extract Product Description ---
    desc_keywords = [
        "description", "product description", "item description",
        "goods description", "description of goods",
        "commodity", "product name", "item name",
        "goods", "merchandise", "articles"
    ]

    for keyword in desc_keywords:
        pattern = re.compile(
            rf'{keyword}\s*[:\-]?\s*(.{{10,300}}?)(?:\n|$)',
            re.IGNORECASE
        )
        match = pattern.search(text)
        if match:
            desc = match.group(1).strip()
            # Clean up
            desc = re.sub(r'\s+', ' ', desc)
            if len(desc) > 10:
                result["product_description"] = desc
                result["confidence"]["product_description"] = 0.75
                break

    # If no structured description found, use the longest non-header line
    if not result["product_description"]:
        lines = [l.strip() for l in text.split('\n') if l.strip() and len(l.strip()) > 15]
        # Filter out lines that look like headers/labels
        content_lines = [
            l for l in lines
            if not re.match(r'^(invoice|bill|date|ref|no\.|number|email|phone|fax|tel|address)', l, re.IGNORECASE)
            and not re.match(r'^[A-Z\s]{2,20}:$', l)
        ]
        if content_lines:
            # Pick the longest line as likely description
            best = max(content_lines, key=len)
            result["product_description"] = best[:300]
            result["confidence"]["product_description"] = 0.40

    return result


def get_all_countries() -> list[dict]:
    """Return list of all countries for dropdowns."""
    return [
        {"code": code, "name": name}
        for code, name in sorted(COUNTRY_CODE_TO_NAME.items(), key=lambda x: x[1])
    ]


def get_all_currencies() -> list[dict]:
    """Return list of all currencies for dropdowns."""
    return [
        {"code": code, "name": name}
        for code, name in sorted(CURRENCIES.items(), key=lambda x: x[1])
    ]
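A usage sketch for `extract_fields()`; the invoice text below is invented for illustration, and which values come back depends on which regex patterns fire first.

```python
# Illustrative call into field_extractor; the OCR text is made up.
from field_extractor import extract_fields

sample_ocr = """COMMERCIAL INVOICE
Description of goods: cotton t-shirts mens, 500 cartons
Country of origin: Vietnam
Ship to: United States
Unit price: $2.45
Contact: export@example.com"""

fields = extract_fields(sample_ocr)
print(fields["made_in"], fields["ship_to"])      # expected: VN US
print(fields["currency"], fields["item_price"])  # expected: USD 2.45
print(fields["product_description"])
print(fields["confidence"])
```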
hs_dataset.py
ADDED
@@ -0,0 +1,341 @@
"""
Harmonized System dataset integration.

Loads the official HS code dataset from:
https://github.com/datasets/harmonized-system

Provides:
- Full HS code lookup (2, 4, 6 digit)
- Section/chapter/heading/subheading hierarchy
- HTS extension support (country-specific 7-10 digit codes)
- Search by description
"""

import csv
import json
import os
import re
from pathlib import Path
from typing import Optional

PROJECT_DIR = Path(__file__).parent
HS_DATA_PATH = PROJECT_DIR / "data" / "harmonized-system" / "harmonized-system.csv"
US_HTS_LOOKUP_PATH = PROJECT_DIR / "data" / "hts" / "us_hts_lookup.json"


class HSDataset:
    """Harmonized System code dataset."""

    def __init__(self):
        self.codes = {}        # hscode -> {section, description, parent, level}
        self.sections = {}     # section number -> section name
        self.chapters = {}     # 2-digit -> description
        self.headings = {}     # 4-digit -> description
        self.subheadings = {}  # 6-digit -> description
        self._loaded = False

    def load(self) -> bool:
        """Load the HS dataset from CSV."""
        if self._loaded:
            return True

        if not HS_DATA_PATH.exists():
            print(f"HS dataset not found at {HS_DATA_PATH}")
            return False

        with open(HS_DATA_PATH, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                hscode = row['hscode'].strip()
                desc = row['description'].strip()
                section = row['section'].strip()
                parent = row['parent'].strip()
                level = int(row['level'])

                self.codes[hscode] = {
                    'section': section,
                    'description': desc,
                    'parent': parent,
                    'level': level,
                }

                if level == 2:
                    self.chapters[hscode] = desc
                elif level == 4:
                    self.headings[hscode] = desc
                elif level == 6:
                    self.subheadings[hscode] = desc

        self._loaded = True
        print(f"Loaded HS dataset: {len(self.chapters)} chapters, "
              f"{len(self.headings)} headings, {len(self.subheadings)} subheadings")
        return True

    def lookup(self, hscode: str) -> Optional[dict]:
        """Look up an HS code and return full hierarchy."""
        hscode = hscode.strip().replace('.', '').replace(' ', '')

        if hscode not in self.codes:
            return None

        entry = self.codes[hscode].copy()

        # Build hierarchy
        hierarchy = []
        current = hscode
        while current and current in self.codes and current != 'TOTAL':
            hierarchy.insert(0, {
                'code': current,
                'description': self.codes[current]['description'],
                'level': self.codes[current]['level'],
            })
            current = self.codes[current]['parent']

        entry['hierarchy'] = hierarchy
        entry['hscode'] = hscode

        # Get chapter and heading descriptions
        if len(hscode) >= 2:
            ch = hscode[:2]
            entry['chapter'] = self.chapters.get(ch, '')
            entry['chapter_code'] = ch
        if len(hscode) >= 4:
            hd = hscode[:4]
            entry['heading'] = self.headings.get(hd, '')
            entry['heading_code'] = hd
        if len(hscode) == 6:
            entry['subheading'] = self.subheadings.get(hscode, '')

        return entry

    def search(self, query: str, max_results: int = 20) -> list[dict]:
        """Search HS codes by description text."""
        query_lower = query.lower()
        query_words = set(query_lower.split())

        results = []
        for hscode, info in self.codes.items():
            if info['level'] != 6:
                continue

            desc_lower = info['description'].lower()

            # Score by word overlap
            desc_words = set(desc_lower.split())
            overlap = query_words & desc_words

            if overlap:
                score = len(overlap) / len(query_words)
                # Bonus for exact substring match
                if query_lower in desc_lower:
                    score += 1.0

                results.append({
                    'hscode': hscode,
                    'description': info['description'],
                    'section': info['section'],
                    'score': score,
                })

        results.sort(key=lambda x: -x['score'])
        return results[:max_results]

    def get_chapter_name(self, chapter_code: str) -> str:
        """Get chapter description from 2-digit code."""
        return self.chapters.get(chapter_code.zfill(2), 'Unknown')

    def validate_hs_code(self, hscode: str) -> dict:
        """Validate an HS code and return info about its validity."""
        hscode = hscode.strip().replace('.', '').replace(' ', '')

        result = {
            'valid': False,
            'code': hscode,
            'level': None,
            'description': None,
            'message': '',
        }

        if not re.match(r'^\d{2,6}$', hscode):
            result['message'] = 'HS code must be 2-6 digits'
            return result

        if hscode in self.codes:
            info = self.codes[hscode]
            result['valid'] = True
            result['level'] = info['level']
            result['description'] = info['description']
            result['message'] = f'Valid {info["level"]}-digit HS code'
        else:
            # Check if partial code is valid
            if len(hscode) == 6:
                heading = hscode[:4]
                chapter = hscode[:2]
                if heading in self.codes:
                    result['message'] = f'Heading {heading} exists but subheading {hscode} not found'
                elif chapter in self.codes:
                    result['message'] = f'Chapter {chapter} exists but code {hscode} not found'
                else:
                    result['message'] = f'Code {hscode} not found in HS nomenclature'

        return result

    def get_all_6digit_codes(self) -> list[dict]:
        """Return all 6-digit HS codes with descriptions."""
        return [
            {'hscode': code, 'description': info['description'], 'section': info['section']}
            for code, info in self.codes.items()
            if info['level'] == 6
        ]


# --- HTS Extensions ---
# HTS (Harmonized Tariff Schedule) adds country-specific digits (7-10) after the 6-digit HS code.
# This is a simplified reference for major trading partners.

def _load_us_hts_extensions() -> dict:
    """Load US HTS extensions from the pre-built JSON lookup table."""
    if not US_HTS_LOOKUP_PATH.exists():
        return {}
    with open(US_HTS_LOOKUP_PATH, "r", encoding="utf-8") as f:
        raw = json.load(f)
    # Convert from build_hts_lookup format to API format
    extensions = {}
    for hs6, entries in raw.items():
        extensions[hs6] = [
            {"hts": e["hts_code"], "description": e["description"],
             "general_duty": e.get("general_duty", ""),
             "special_duty": e.get("special_duty", ""),
             "unit": e.get("unit", "")}
            for e in entries
        ]
    return extensions


# Lazy-loaded cache for US HTS data
_us_hts_cache = None


def _get_us_hts_extensions() -> dict:
    global _us_hts_cache
    if _us_hts_cache is None:
        _us_hts_cache = _load_us_hts_extensions()
    return _us_hts_cache


HTS_EXTENSIONS = {
    "US": {
        "name": "United States HTS",
        "digits": 10,
        "format": "XXXX.XX.XXXX",
        # Extensions loaded lazily from us_hts_lookup.json
        "extensions": None,  # Sentinel, resolved in get_hts_extensions()
    },
    "EU": {
        "name": "EU Combined Nomenclature (CN)",
        "digits": 8,
        "format": "XXXX.XX.XX",
        "extensions": {
            "851712": [
                {"hts": "85171200", "description": "Telephones for cellular networks; smartphones"},
            ],
            "847130": [
                {"hts": "84713000", "description": "Portable digital automatic data-processing machines, ≤ 10 kg"},
            ],
            "870380": [
                {"hts": "87038000", "description": "Other vehicles, with electric motor for propulsion"},
            ],
        }
    },
    "CN": {
        "name": "China Customs Tariff",
        "digits": 10,
        "format": "XXXX.XXXX.XX",
        "extensions": {
            "851712": [
                {"hts": "8517120010", "description": "Smartphones, 5G capable"},
                {"hts": "8517120090", "description": "Other mobile phones"},
            ],
            "847130": [
                {"hts": "8471300000", "description": "Portable digital data processing machines"},
            ],
        }
    },
    "JP": {
        "name": "Japan HS Tariff",
        "digits": 9,
        "format": "XXXX.XX.XXX",
        "extensions": {
            "851712": [
                {"hts": "851712000", "description": "Telephones for cellular networks or wireless"},
            ],
            "870380": [
                {"hts": "870380000", "description": "Electric motor vehicles for passenger transport"},
            ],
        }
    },
}


def get_hts_extensions(hs_code: str, country_code: str) -> Optional[dict]:
    """
    Get HTS (country-specific) extensions for a 6-digit HS code.

    Args:
        hs_code: 6-digit HS code
        country_code: 2-letter country code (US, EU, CN, JP, etc.)

    Returns:
        Dict with country HTS info and available extensions, or None.
    """
    hs_code = hs_code.strip().replace('.', '').replace(' ', '')
    country_code = country_code.upper().strip()

    if country_code not in HTS_EXTENSIONS:
        return {
            "available": False,
            "country": country_code,
            "message": f"HTS extensions not available for {country_code}. "
                       f"Available: {', '.join(HTS_EXTENSIONS.keys())}",
            "extensions": [],
        }

    tariff = HTS_EXTENSIONS[country_code]
    # US extensions are lazy-loaded from JSON
    if country_code == "US":
        ext_dict = _get_us_hts_extensions()
    else:
        ext_dict = tariff["extensions"]
    extensions = ext_dict.get(hs_code, [])

    return {
        "available": True,
        "country": country_code,
        "tariff_name": tariff["name"],
        "total_digits": tariff["digits"],
        "format": tariff["format"],
        "extensions": extensions,
        "hs_code": hs_code,
        "message": f"Found {len(extensions)} HTS extension(s)" if extensions else
                   f"No specific extensions found for {hs_code} in {tariff['name']}. "
                   f"The base HS code {hs_code} applies.",
    }


def get_available_hts_countries() -> list[dict]:
    """Return list of countries with HTS extensions available."""
    return [
        {"code": code, "name": info["name"], "digits": info["digits"]}
        for code, info in HTS_EXTENSIONS.items()
    ]


# Singleton instance
_dataset = HSDataset()


def get_dataset() -> HSDataset:
    """Get the singleton HSDataset instance, loading if necessary."""
    if not _dataset._loaded:
        _dataset.load()
    return _dataset
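A usage sketch for the helpers above; the exact output depends on the `harmonized-system.csv` and `us_hts_lookup.json` snapshots shipped under `data/`.

```python
# Illustrative calls into hs_dataset; results depend on the shipped data snapshots.
from hs_dataset import get_dataset, get_hts_extensions

hs = get_dataset()                        # loads the CSV once (singleton)

entry = hs.lookup("851712")               # smartphones: full chapter/heading hierarchy
if entry:
    print(entry["chapter_code"], entry["description"])

print(hs.validate_hs_code("090240"))      # reports whether a code exists in the nomenclature
print(hs.search("frozen shrimp")[:3])     # keyword-overlap search over 6-digit descriptions

us = get_hts_extensions("851712", "US")   # country-specific 10-digit HTS lines, if any
print(us["tariff_name"], len(us["extensions"]))
```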
models/.gitkeep
ADDED
File without changes
requirements-dev.txt
ADDED
@@ -0,0 +1,2 @@
-r requirements.txt
datasets>=3.0
requirements.txt
ADDED
@@ -0,0 +1,15 @@
fastapi==0.129.0
uvicorn[standard]==0.41.0
sentence-transformers==5.2.3
transformers==5.2.0
torch==2.10.0
scikit-learn==1.8.0
pandas==3.0.1
numpy==2.3.5
plotly==6.5.2
umap-learn==0.5.11
jinja2==3.1.6
pytesseract==0.3.13
pdf2image==1.17.0
pillow==12.1.1
python-multipart==0.0.22
static/.gitkeep
ADDED
File without changes
templates/index.html
ADDED
The diff for this file is too large to render.