"""
Classifier Service — loads the trained QueryProductClassifier model and
classifies (query, product) pairs into E/S/C/I labels.
"""

import os

import torch
import torch.nn as nn
import numpy as np

from config import CLASSIFIER_MODEL_PATH, BERT_EMBEDDING_DIM


class QueryProductClassifier(nn.Module):
    """
    Feed-forward classifier that takes concatenated query + product embeddings
    and classifies into E/S/C/I categories.

    This is a copy of the trained model architecture from the ESCI project.
    Must stay in sync with
    classification_identification/query_product/classifier_model.py

    NOTE: the attribute names (``dropout_embedding``, ``fc``) and the layer
    order inside ``fc`` determine the state_dict keys that
    ``ClassifierService.load`` feeds to ``load_state_dict``; changing them
    would break loading of existing checkpoints.
    """

    def __init__(self, size_pretrained=768, dense_hidden_dim=256,
                 num_dense_layers=2, num_labels=4, dropout_rate=0.1):
        super().__init__()
        # Binary setups (num_labels <= 2) use a single output logit.
        self.num_labels = 1 if num_labels <= 2 else num_labels
        self.size_pretrained = size_pretrained * 2  # query + product concatenated

        fc_layers = []
        prev_dim = self.size_pretrained

        # Dropout applied to the concatenated input before the dense stack.
        self.dropout_embedding = nn.Dropout(dropout_rate)

        # Each hidden stage is Linear -> BatchNorm -> ReLU -> Dropout.
        for _ in range(num_dense_layers):
            fc_layers.append(nn.Linear(prev_dim, dense_hidden_dim, bias=True))
            fc_layers.append(nn.BatchNorm1d(dense_hidden_dim))
            fc_layers.append(nn.ReLU())
            fc_layers.append(nn.Dropout(dropout_rate))
            prev_dim = dense_hidden_dim

        fc_layers.append(nn.Linear(prev_dim, self.num_labels))
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, query_embedding, product_embedding):
        """Return class logits for a batch of (query, product) embedding pairs.

        The two inputs are concatenated along dim 1, so both are expected to be
        2-D batches with the same batch size. When the model has a single
        output logit (binary setup) the trailing dimension is squeezed away;
        with more labels the squeeze is a no-op.
        """
        embedding = torch.cat((query_embedding, product_embedding), 1)
        embedding = self.dropout_embedding(embedding)
        logits = self.fc(embedding).squeeze(-1)
        return logits


# Label mapping
CLASS_ID_TO_LABEL = {
    0: "Exact",
    1: "Substitute",
    2: "Complement",
    3: "Irrelevant",
}

# Priority for sorting (lower = more relevant = shown first)
LABEL_PRIORITY = {
    "Exact": 0,
    "Substitute": 1,
    "Complement": 2,
    "Irrelevant": 3,
}


def _unknown_result() -> dict:
    """Return a NEW placeholder result, used when no classifier is loaded.

    A fresh dict is built on every call so results handed to different
    callers (or different list slots) never alias one another.
    """
    return {"label": "Unknown", "confidence": 0.0, "class_id": -1}


class ClassifierService:
    """Singleton service that classifies query-product pairs."""

    def __init__(self):
        self.model = None    # QueryProductClassifier once load() succeeds
        self.device = None   # torch.device chosen in load()
        self._loaded = False # True only after weights were loaded successfully

    def load(self):
        """Load the trained classifier model. Call once at app startup.

        Picks CUDA when available, otherwise CPU. If the checkpoint file at
        CLASSIFIER_MODEL_PATH does not exist, prints a warning and returns
        without marking the service as loaded, so classify()/classify_batch()
        fall back to "Unknown" results. Calling load() again after a
        successful load is a no-op.
        """
        if self._loaded:
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_path = CLASSIFIER_MODEL_PATH

        if not os.path.exists(model_path):
            print(f"[ClassifierService] WARNING: Model file not found at {model_path}")
            print("[ClassifierService] Search will use similarity-only ranking (no E/S/C/I classification)")
            return

        print(f"[ClassifierService] Loading classifier from {model_path}...")
        self.model = QueryProductClassifier(
            size_pretrained=BERT_EMBEDDING_DIM,
            num_labels=4,
        )
        # NOTE(security): torch.load unpickles the file; only point
        # CLASSIFIER_MODEL_PATH at trusted checkpoints. On torch >= 1.13,
        # consider passing weights_only=True since only a state_dict is needed.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()  # use BatchNorm running stats, disable Dropout
        self._loaded = True
        print(f"[ClassifierService] Classifier loaded on {self.device}")

    def classify(self, query_embedding: np.ndarray, product_embedding: np.ndarray) -> dict:
        """
        Classify a single (query, product) pair.

        Both embeddings are expected to be 1-D vectors of length
        BERT_EMBEDDING_DIM (the model is built with that input size).

        Returns: {"label": "Exact", "confidence": 0.92, "class_id": 0}
        If no model is loaded, returns label "Unknown", confidence 0.0 and
        class_id -1.
        """
        if not self._loaded:
            return _unknown_result()

        # Add a batch dimension of 1 so the model sees a (1, D) batch.
        q = torch.tensor(query_embedding).float().unsqueeze(0).to(self.device)
        p = torch.tensor(product_embedding).float().unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(q, p)
            probabilities = torch.softmax(logits, dim=1)

        class_id = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][class_id].item()

        return {
            "label": CLASS_ID_TO_LABEL[class_id],
            "confidence": round(confidence, 4),
            "class_id": class_id,
        }

    def classify_batch(self, query_embedding: np.ndarray, product_embeddings: np.ndarray) -> list[dict]:
        """
        Classify a query against multiple products at once.

        query_embedding: shape (D,), D = BERT_EMBEDDING_DIM (768 for BERT-base)
        product_embeddings: shape (N, D); any array-like accepted by np.asarray

        Returns a list of N classification dicts, in product order. Each dict
        has "label", "confidence", "class_id" plus the per-class
        probabilities "exact_prob", "substitute_prob", "complement_prob" and
        "irrelevant_prob". If no model is loaded, returns N independent
        "Unknown" dicts (without the per-class probability keys). An empty
        product batch yields [].
        """
        if not self._loaded:
            # One fresh dict per product: `[d] * n` would make every slot the
            # same object, so mutating one result would mutate all of them.
            return [_unknown_result() for _ in range(len(product_embeddings))]

        products = np.asarray(product_embeddings)
        n = products.shape[0]
        if n == 0:
            return []

        p = torch.as_tensor(products, dtype=torch.float32, device=self.device)
        # Broadcast the single query row across the batch without copying it
        # n times on the host (expand returns a view).
        q = (
            torch.as_tensor(np.asarray(query_embedding), dtype=torch.float32, device=self.device)
            .reshape(1, -1)
            .expand(n, -1)
        )

        with torch.no_grad():
            logits = self.model(q, p)
            probabilities = torch.softmax(logits, dim=1)

        # Single device->host transfer; argmax/max are then done in NumPy.
        all_probs = probabilities.cpu().numpy()
        class_ids = all_probs.argmax(axis=1)
        confidences = all_probs.max(axis=1)

        return [
            {
                "label": CLASS_ID_TO_LABEL[int(class_id)],
                "confidence": round(float(confidence), 4),
                "class_id": int(class_id),
                "exact_prob": round(float(probs[0]), 4),
                "substitute_prob": round(float(probs[1]), 4),
                "complement_prob": round(float(probs[2]), 4),
                "irrelevant_prob": round(float(probs[3]), 4),
            }
            for class_id, confidence, probs in zip(class_ids, confidences, all_probs)
        ]


# Global singleton instance
classifier_service = ClassifierService()