| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel |
| from datetime import datetime |
| import json |
| import os |
| import math |
|
|
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Every artifact read or written by this script lives under MODEL_DIR.
MODEL_DIR = 'model'

# Fine-tuned weights, training config, tokenizer and (optional) local encoder copy.
FULL_MODEL_PATH = os.path.join(MODEL_DIR, 'cascaded_best.pt')
CONFIG_PATH, TOKENIZER_PATH, BASE_MODEL_PATH = (
    os.path.join(MODEL_DIR, 'model_config.json'),
    os.path.join(MODEL_DIR, 'tokenizer'),
    os.path.join(MODEL_DIR, 'base_model'),
)

# label -> id JSON dictionaries for the 2-, 4- and 6-digit levels.
DICT_2, DICT_4, DICT_6 = (
    os.path.join(MODEL_DIR, 'label2id_2.json'),
    os.path.join(MODEL_DIR, 'label2id_4.json'),
    os.path.join(MODEL_DIR, 'label2id_6.json'),
)

# Plain-text log that each interactive test result is appended to.
RESULTS_PATH = os.path.join(MODEL_DIR, 'test_results.txt')
|
|
|
|
class ArcMarginProduct(nn.Module):
    """ArcFace (additive angular margin) classification head.

    Each class is scored by the cosine similarity between the L2-normalised
    input features and the L2-normalised class weight rows, multiplied by the
    scale ``s``.  When the module is in training mode *and* ``label`` is given,
    the target-class cosine cos(theta) is replaced by cos(theta + m), with a
    linear fallback once theta + m would pass pi.  In every other case (e.g.
    after ``model.eval()``, or with ``label=None``) no margin is applied and
    the output is simply ``cosine * s``.

    Args:
        in_features: size of each input feature vector (last dim of ``x``).
        out_features: number of classes.
        s: scale applied to the cosine scores.
        m: additive angular margin in radians (training only).
    """

    # Lower bound for 1 - cos^2 before the square root.  Without it, sqrt's
    # gradient is infinite when a feature is (numerically) parallel to a class
    # weight, which turns into inf/NaN gradients in the margin branch.
    _SINE_EPS = 1e-7

    def __init__(self, in_features, out_features, s=30.0, m=0.30):
        super().__init__()
        self.s = s
        self.m = m
        # Shape (out_features, in_features) and the name "weight" are kept so
        # existing checkpoints load unchanged; torch.empty replaces the legacy
        # torch.FloatTensor constructor (values are set by xavier init below).
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # For cosine <= th (theta >= pi - m), theta + m would exceed pi and
        # cos(theta + m) would stop decreasing; the margin branch then uses the
        # linear penalty cosine - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label=None):
        """Return scaled class scores of shape (batch, out_features).

        ``label`` holds one integer class index per row of ``x``; it is only
        used while ``self.training`` is True.
        """
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        if label is None or not self.training:
            return cosine * self.s

        # sin(theta) >= 0 for theta in [0, pi]; clamp keeps the gradient finite.
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=self._SINE_EPS))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Apply the margin only to each row's target column.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = one_hot * phi + (1.0 - one_hot) * cosine
        return output * self.s
|
|
|
|
class CascadedClassifier(nn.Module):
    """Three-level cascaded classifier over 2-, 4- and 6-digit labels.

    The encoder's first-token (position 0) hidden state feeds three heads in
    sequence: level-2 logits are computed from it directly; the level-2
    softmax is concatenated back onto it to produce level-4 logits; the
    level-4 softmax is concatenated onto it again to produce level-6 scores
    through an ``ArcMarginProduct`` head.

    Args:
        base_model: encoder called with ``input_ids``/``attention_mask`` whose
            output exposes ``last_hidden_state``.
        hidden_size: width of that hidden state.
        n2, n4, n6: number of classes at each level.
        dropout: dropout probability used on the shared vector and in the heads.
        arc_s, arc_m: scale and margin forwarded to ``ArcMarginProduct``.
    """

    def __init__(self, base_model, hidden_size, n2, n4, n6,
                 dropout=0.15, arc_s=30.0, arc_m=0.3):
        super().__init__()
        self.base_model = base_model
        self.drop = nn.Dropout(dropout)

        # NOTE: submodules are created in this fixed order on purpose; layer
        # initialisation draws from the global RNG in creation order.

        # Level 2: shared vector -> 256 -> n2 logits.
        level2_layers = [
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, n2),
        ]
        self.head_2 = nn.Sequential(*level2_layers)

        # Level 4: [shared ; softmax(level 2)] -> hidden_size -> 256 -> n4 logits.
        self.head_4_fusion = nn.Linear(hidden_size + n2, hidden_size)
        level4_layers = [
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, n4),
        ]
        self.head_4 = nn.Sequential(*level4_layers)

        # Level 6: [shared ; softmax(level 4)] -> hidden_size -> 512 features -> ArcFace.
        self.head_6_fusion = nn.Linear(hidden_size + n4, hidden_size)
        level6_layers = [
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.GELU(),
        ]
        self.head_6_feat = nn.Sequential(*level6_layers)
        self.head_6_arc = ArcMarginProduct(512, n6, s=arc_s, m=arc_m)

    def forward(self, input_ids, attention_mask, label_6=None):
        """Return ``(level-2 logits, level-4 logits, level-6 ArcFace scores)``.

        ``label_6`` is handed to the ArcFace head, which applies its angular
        margin only in training mode.
        """
        encoded = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        shared = self.drop(encoded.last_hidden_state[:, 0, :])

        logits_2 = self.head_2(shared)
        fused_4 = self.head_4_fusion(torch.cat((shared, logits_2.softmax(dim=1)), dim=1))
        logits_4 = self.head_4(fused_4)

        fused_6 = self.head_6_fusion(torch.cat((shared, logits_4.softmax(dim=1)), dim=1))
        logits_6 = self.head_6_arc(self.head_6_feat(fused_6), label_6)
        return logits_2, logits_4, logits_6
|
|
|
|
def save_result(filepath, text, candidates, cascade_2, cascade_4):
    """Append a single test result to the results txt file.

    Args:
        filepath: log file path; opened in append mode as UTF-8.
        text: the description that was classified.
        candidates: ranked candidate dicts with keys 'code', 'score', 'p6',
            'p4' and 'p2'; at most the first five are written.
        cascade_2: predicted 2-digit code (may be an empty string).
        cascade_4: predicted 4-digit code (may be an empty string).
    """
    rule = f"{'-' * 80}\n"
    lines = [
        f"\n{'=' * 80}\n",
        f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
        f"Input: {text}\n",
        f"Cascade: {cascade_2} → {cascade_4}\n",
        rule,
        f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain\n",
        rule,
    ]
    for i, c in enumerate(candidates[:5], start=1):
        cd = c['code']
        ch = f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})→{cd[:6]}({c['p6']:.2f})"
        # Pad Score/P(6) to the header widths (10 and 8) so columns line up.
        lines.append(f"{i:<4} | {cd:<12} | {c['score']:<10.2e} | {c['p6']:<8.4f} | {ch}\n")
    lines.append(rule)

    # Guard: the verdict is based on the top candidate, which may not exist.
    if not candidates:
        lines.append("⚠️ No candidates.\n")
    elif candidates[0]['score'] > 1e-3:
        lines.append("✅ Strong match.\n")
    elif candidates[0]['p6'] < 0.1:
        lines.append("⚠️ Low confidence.\n")

    with open(filepath, 'a', encoding='utf-8') as f:
        f.writelines(lines)
|
|
|
|
def _load_json(path):
    """Return the parsed JSON document stored at *path*, closing the file."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _load_label_maps():
    """Load the label2id dictionaries for levels 2/4/6 together with their inverses.

    Returns:
        A list of three ``(label2id, id2label)`` tuples in 2-, 4-, 6-digit order.
    """
    maps = []
    for path in (DICT_2, DICT_4, DICT_6):
        label2id = _load_json(path)
        maps.append((label2id, {v: k for k, v in label2id.items()}))
    return maps


def _build_model(config):
    """Build the CascadedClassifier described by *config*, load weights, set eval mode."""
    counts = config['classes']
    if os.path.exists(BASE_MODEL_PATH):
        base_model = AutoModel.from_pretrained(BASE_MODEL_PATH)
    else:
        base_model = AutoModel.from_pretrained(config['model_name'])

    model = CascadedClassifier(
        base_model=base_model, hidden_size=config['hidden_size'],
        n2=counts['n2'], n4=counts['n4'], n6=counts['n6'],
        dropout=config.get('dropout', 0.15),
        arc_s=config.get('arcface_scale', 30.0),
        arc_m=config.get('arcface_margin', 0.3),
    ).to(device)

    if os.path.exists(FULL_MODEL_PATH):
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(FULL_MODEL_PATH, map_location=device)
        # strict=False tolerates small key drift, but a wholesale mismatch would
        # otherwise leave the heads randomly initialised with no visible sign.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            sample = (list(result.missing_keys) + list(result.unexpected_keys))[:3]
            print(f"⚠️ Checkpoint key mismatch: {len(result.missing_keys)} missing, "
                  f"{len(result.unexpected_keys)} unexpected (e.g. {sample})")
    else:
        print(f"⚠️ Checkpoint not found: {FULL_MODEL_PATH}. Heads are untrained.")

    model.eval()
    return model


def _predict(model, tokenizer, text, max_seq_len):
    """Run one description through the model; return float32 softmax probs per level."""
    enc = tokenizer(text, max_length=max_seq_len, padding='max_length',
                    truncation=True, return_tensors='pt')
    ids = enc['input_ids'].to(device)
    mask = enc['attention_mask'].to(device)
    # Mixed precision only on CUDA: a 'cuda' autocast on a CPU-only machine
    # warns and disables itself on every call.
    use_amp = device.type == 'cuda'
    with torch.no_grad(), torch.amp.autocast(device.type, enabled=use_amp):
        logits = model(ids, mask)
    # Upcast before softmax so small probabilities survive fp16 rounding.
    return [F.softmax(o.float(), dim=1) for o in logits]


def _label_prob(label, label2id, probs):
    """Return the probability of *label* in single-row *probs*, or 0.0 if unknown."""
    idx = label2id.get(label)
    return 0.0 if idx is None else probs[0][idx].item()


def _rank_candidates(p2, p4, p6, l2id_2, l2id_4, id2l_6, b2c, b4c, top_k=10):
    """Score the top-k 6-digit predictions against the 2-/4-digit cascade.

    score = P(6)^2 * sqrt(P(4) + eps) * sqrt(P(2) + eps), multiplied by 10
    when the code starts with the predicted 4-digit code, else by 5 when its
    first two digits equal the predicted 2-digit code.

    Returns:
        Candidate dicts (keys 'code', 'score', 'p6', 'p4', 'p2'), best first.
    """
    # torch.topk raises if k exceeds the number of level-6 classes.
    k = min(top_k, p6.size(1))
    top_p, top_i = torch.topk(p6, k, dim=1)
    eps = 1e-6
    candidates = []
    for prob6, idx in zip(top_p[0].tolist(), top_i[0].tolist()):
        code6 = id2l_6.get(idx, "Unk")
        pr2 = _label_prob(code6[:2], l2id_2, p2)
        pr4 = _label_prob(code6[:4], l2id_4, p4)
        score = (prob6 ** 2) * ((pr4 + eps) ** 0.5) * ((pr2 + eps) ** 0.5)
        # An empty b4c (unmapped id) would make startswith() match every code.
        if b4c and code6.startswith(b4c):
            score *= 10.0
        elif code6[:2] == b2c:
            score *= 5.0
        candidates.append({"code": code6, "score": score, "p6": prob6,
                           "p4": pr4, "p2": pr2})
    candidates.sort(key=lambda c: c["score"], reverse=True)
    return candidates


def _print_candidates(candidates, b2c, b4c, limit=5):
    """Print the cascade prediction and the first *limit* candidates as a table."""
    print(f"\n Cascade: {b2c} → {b4c}")
    print("-" * 80)
    print(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain")
    print("-" * 80)
    for i, c in enumerate(candidates[:limit], start=1):
        cd = c["code"]
        ch = f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})→{cd[:6]}({c['p6']:.2f})"
        # Pad Score/P(6) to the header widths (10 and 8) so columns line up.
        print(f"{i:<4} | {cd:<12} | {c['score']:<10.2e} | {c['p6']:<8.4f} | {ch}")
    print("-" * 80)

    if not candidates:
        return
    if candidates[0]['score'] > 1e-3:
        print("✅ Strong match.")
    elif candidates[0]['p6'] < 0.1:
        print("⚠️ Low confidence.")


def main():
    """Interactive HS-code classification loop.

    Loads the trained cascaded model and label dictionaries from MODEL_DIR,
    then repeatedly reads a description from stdin, prints the ranked 6-digit
    candidates and appends each result to RESULTS_PATH.  Blank input is
    ignored; 'q', 'quit', 'exit', Ctrl-C or EOF end the session.
    """
    print("Loading bert-base-uncased FULL FT + ArcFace model (3-level, 6-digit)...")

    if not os.path.exists(CONFIG_PATH):
        print(f"Config not found: {CONFIG_PATH}. Train first.")
        return

    try:
        config = _load_json(CONFIG_PATH)
        max_seq_len = config['max_seq_len']
        (l2id_2, id2l_2), (l2id_4, id2l_4), (_, id2l_6) = _load_label_maps()
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        model = _build_model(config)
        print(f"Loaded. Best val acc: {config.get('best_val_acc_6', 'N/A')}%")
        print(f"Mode: {config.get('training_mode', 'N/A')}")
    except Exception as e:  # top-level boundary: report the failure and stop
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return

    with open(RESULTS_PATH, 'a', encoding='utf-8') as f:
        f.write(f"\n{'#'*80}\n")
        f.write(f"Test session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {config.get('model_name', 'N/A')}\n")
        f.write(f"Architecture: {config.get('architecture', 'N/A')}\n")
        f.write(f"Best val acc (6-digit): {config.get('best_val_acc_6', 'N/A')}%\n")
        f.write(f"{'#'*80}\n")

    print(f"\n📝 Results will be saved to: {RESULTS_PATH}")
    print("\n--- HS Code Classification (3-level, 6-digit) ---")
    print("Type description or 'q' to quit.\n")

    while True:
        try:
            text = input("Description: ")
        except (KeyboardInterrupt, EOFError):
            break
        text = text.strip()
        if not text:
            continue
        if text.lower() in ('q', 'quit', 'exit'):
            break

        p2, p4, p6 = _predict(model, tokenizer, text, max_seq_len)
        b2c = id2l_2.get(p2.argmax(dim=1).item(), "")
        b4c = id2l_4.get(p4.argmax(dim=1).item(), "")
        candidates = _rank_candidates(p2, p4, p6, l2id_2, l2id_4, id2l_6, b2c, b4c)

        _print_candidates(candidates, b2c, b4c)
        save_result(RESULTS_PATH, text, candidates, b2c, b4c)
        print(f" 📝 Saved to {RESULTS_PATH}")
|
|
|
# Start the interactive loop only when this file is executed as a script,
# not when it is imported (e.g. to reuse the model classes).
if __name__ == "__main__":
    main()
|
|