RetailTalk / backend /models /classifier.py
Dashm
Initial commit — RetailTalk backend for HuggingFace Spaces
26d82f3
"""
Classifier Service — loads the trained QueryProductClassifier model
and classifies (query, product) pairs into E/S/C/I labels.
"""
import os
import torch
import torch.nn as nn
import numpy as np
from config import CLASSIFIER_MODEL_PATH, BERT_EMBEDDING_DIM
class QueryProductClassifier(nn.Module):
    """
    Feed-forward head that scores a (query, product) embedding pair as E/S/C/I.

    The two embeddings are concatenated along dim 1, passed through an
    embedding dropout, then through ``num_dense_layers`` blocks of
    Linear -> BatchNorm1d -> ReLU -> Dropout, and finally a Linear projection
    to the output logits.

    This is a copy of the trained model architecture from the ESCI project and
    must stay in sync with
    classification_identification/query_product/classifier_model.py: the
    attribute names (``dropout_embedding``, ``fc``) and the layer order inside
    ``fc`` determine the state_dict keys that saved weights are loaded into.

    Args:
        size_pretrained: width of ONE input embedding; the first Linear layer
            receives twice this width (query + product).
        dense_hidden_dim: width of every hidden layer.
        num_dense_layers: number of hidden Linear/BatchNorm/ReLU/Dropout blocks.
        num_labels: number of classes; any value <= 2 collapses to a single
            output logit.
        dropout_rate: probability used by the embedding and hidden dropouts.
    """

    def __init__(self, size_pretrained=768, dense_hidden_dim=256, num_dense_layers=2, num_labels=4, dropout_rate=0.1):
        super().__init__()
        # Binary problems use one logit; multi-class keeps one logit per class.
        self.num_labels = num_labels if num_labels > 2 else 1
        # The input is the concatenation of the query and product vectors.
        self.size_pretrained = 2 * size_pretrained
        self.dropout_embedding = nn.Dropout(dropout_rate)

        layers = []
        in_dim = self.size_pretrained
        for _ in range(num_dense_layers):
            layers += [
                nn.Linear(in_dim, dense_hidden_dim, bias=True),
                nn.BatchNorm1d(dense_hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
            in_dim = dense_hidden_dim
        layers.append(nn.Linear(in_dim, self.num_labels))
        self.fc = nn.Sequential(*layers)

    def forward(self, query_embedding, product_embedding):
        """
        Return the raw logits for a batch of (query, product) pairs.

        For 2-D (batch, dim) inputs the result is (batch, num_labels); the
        trailing axis is squeezed away when the model has a single logit.
        """
        pair = torch.cat([query_embedding, product_embedding], dim=1)
        return self.fc(self.dropout_embedding(pair)).squeeze(-1)
# Model output index -> E/S/C/I label. Index i of the classifier's probability
# vector is reported under CLASS_ID_TO_LABEL[i].
CLASS_ID_TO_LABEL = dict(enumerate(("Exact", "Substitute", "Complement", "Irrelevant")))

# Display priority per label (lower = more relevant = shown first). Kept as its
# own ranking rather than derived from the class ids, so ranking can change
# without touching the model's output order.
LABEL_PRIORITY = {
    label: rank
    for rank, label in enumerate(("Exact", "Substitute", "Complement", "Irrelevant"))
}
class ClassifierService:
    """
    Service wrapper around QueryProductClassifier for E/S/C/I inference.

    Call ``load()`` once at app startup. Until weights have been loaded
    successfully, ``classify`` / ``classify_batch`` return "Unknown"
    placeholder results (class_id -1) instead of raising.
    """

    def __init__(self):
        self.model = None  # QueryProductClassifier, set only after load() succeeds
        self.device = None  # torch.device chosen in load()
        self._loaded = False  # True once trained weights are in place

    @staticmethod
    def _unknown_result() -> dict:
        """Return a fresh placeholder result used while no model is loaded."""
        return {"label": "Unknown", "confidence": 0.0, "class_id": -1}

    def _probabilities(self, queries: torch.Tensor, products: torch.Tensor) -> torch.Tensor:
        """Run the model on matching 2-D query/product batches; return softmax probabilities on the CPU."""
        with torch.no_grad():
            logits = self.model(queries.to(self.device), products.to(self.device))
            return torch.softmax(logits, dim=1).cpu()

    def load(self):
        """
        Load the trained classifier model. Call once at app startup.

        If the weights file does not exist, a warning is printed and the
        service stays unloaded, so a later call tries again. Once loaded,
        further calls are no-ops.
        """
        if self._loaded:
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_path = CLASSIFIER_MODEL_PATH
        if not os.path.exists(model_path):
            print(f"[ClassifierService] WARNING: Model file not found at {model_path}")
            print("[ClassifierService] Search will use similarity-only ranking (no E/S/C/I classification)")
            return

        print(f"[ClassifierService] Loading classifier from {model_path}...")
        model = QueryProductClassifier(
            size_pretrained=BERT_EMBEDDING_DIM,
            num_labels=4,
        )
        # NOTE(review): torch.load unpickles the file, so only point
        # CLASSIFIER_MODEL_PATH at trusted weights. On torch >= 1.13,
        # weights_only=True would restrict loading to tensors.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.to(self.device)
        # Inference mode: BatchNorm uses its running stats and Dropout is disabled.
        model.eval()
        # Publish the model only after the weights loaded without error.
        self.model = model
        self._loaded = True
        print(f"[ClassifierService] Classifier loaded on {self.device}")

    def classify(self, query_embedding: np.ndarray, product_embedding: np.ndarray) -> dict:
        """
        Classify a single (query, product) pair.

        Each argument is one embedding vector; a batch axis of size 1 is added.
        Returns e.g. {"label": "Exact", "confidence": 0.92, "class_id": 0},
        where confidence is the softmax probability of the predicted class
        rounded to 4 decimals. Returns an "Unknown" placeholder when no model
        is loaded.
        """
        if not self._loaded:
            return self._unknown_result()

        query = torch.tensor(query_embedding).float().unsqueeze(0)
        product = torch.tensor(product_embedding).float().unsqueeze(0)
        probabilities = self._probabilities(query, product)[0]
        class_id = int(torch.argmax(probabilities).item())
        return {
            "label": CLASS_ID_TO_LABEL[class_id],
            "confidence": round(probabilities[class_id].item(), 4),
            "class_id": class_id,
        }

    def classify_batch(self, query_embedding: np.ndarray, product_embeddings: np.ndarray) -> list[dict]:
        """
        Classify one query against many products in a single forward pass.

        query_embedding: one embedding vector, e.g. shape (768,)
        product_embeddings: one row per product (array or array-like), e.g. shape (N, 768)
        Returns one dict per product, in input order, with "label",
        "confidence", "class_id" and the per-class probabilities "exact_prob",
        "substitute_prob", "complement_prob", "irrelevant_prob" (each rounded
        to 4 decimals). An empty product batch yields []. When no model is
        loaded, returns one independent "Unknown" placeholder per product
        (without the *_prob keys).
        """
        if not self._loaded:
            # One fresh dict per product: `[d] * n` would alias a single dict
            # n times, so mutating one result would change every result.
            return [self._unknown_result() for _ in range(len(product_embeddings))]

        products = torch.tensor(np.asarray(product_embeddings)).float()
        n = products.shape[0]
        if n == 0:
            # Nothing to score; skip the forward pass on an empty batch.
            return []

        # Repeat the single query across the batch as a view instead of tiling a copy.
        queries = torch.tensor(query_embedding).float().reshape(1, -1).expand(n, -1)
        probabilities = self._probabilities(queries, products)
        class_ids = probabilities.argmax(dim=1).tolist()
        rows = probabilities.tolist()
        return [
            {
                "label": CLASS_ID_TO_LABEL[class_id],
                "confidence": round(row[class_id], 4),
                "class_id": class_id,
                "exact_prob": round(row[0], 4),
                "substitute_prob": round(row[1], 4),
                "complement_prob": round(row[2], 4),
                "irrelevant_prob": round(row[3], 4),
            }
            for class_id, row in zip(class_ids, rows)
        ]
# Global shared instance of ClassifierService, created at import time.
# Nothing here prevents other instances from being created; callers presumably
# import and reuse this one (verify against callers). Until its load() succeeds,
# classify / classify_batch return "Unknown" placeholder results.
classifier_service = ClassifierService()