# slot_service.py — RetailTalk backend (HuggingFace Spaces)
"""
Slot Extraction Service — runs the trained SlotExtractor (NER) model
for real-time inference on user queries.
Loaded once at startup, reused for all search requests.
Extracts: PRODUCT1, PRODUCT2, BRAND, COLOR, PRICE_MIN, PRICE_MAX,
PRICE_MOD, RATING_MIN, RATING_MOD, CONN, SIZE, etc.
"""
import os
import json
import torch
import torch.nn as nn
import numpy as np
from transformers import BertModel, BertTokenizerFast
from config import SLOT_MODEL_PATH, BERT_MODEL_NAME, SLOT_MAX_LENGTH
class SlotExtractor(nn.Module):
    """
    BERT encoder + token-level linear classification head for NER/slot
    extraction. Must match the architecture used during training.

    Args:
        bert_model_name: HuggingFace model id of the BERT encoder.
        num_tags: Size of the BIO tag set (classifier output dimension).
    """

    def __init__(self, bert_model_name="bert-base-multilingual-uncased", num_tags=20):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.3)
        # Read the hidden size from the loaded encoder's config instead of
        # hard-coding 768, so non-base BERT variants work unchanged.
        # (768 for bert-base models, so this is backward-compatible.)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)

    def forward(self, input_ids, attention_mask):
        """Return per-token tag logits of shape (batch, seq_len, num_tags)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        return self.classifier(sequence_output)
class SlotService:
    """Singleton service for slot/entity extraction at inference time.

    Call ``load()`` once at app startup; ``extract()`` is then called per
    search request. If the model artifacts are missing, the service degrades
    gracefully: ``extract()`` returns empty results instead of raising.
    """

    def __init__(self):
        self.model = None       # SlotExtractor instance, set by load()
        self.tokenizer = None   # BertTokenizerFast, set by load()
        self.device = None      # torch.device, set by load()
        self.tag2id = {}        # tag name -> id, as used during training
        self.id2tag = {}        # inverse mapping for decoding predictions
        self._loaded = False    # guards extract() before a successful load()

    def load(self):
        """Load the trained slot extractor. Call once at app startup.

        Expects ``SLOT_MODEL_PATH`` to contain ``model.pt`` plus either
        ``tag_map.json`` (tag -> id) or ``config.json`` with a ``tag_names``
        list. On any missing artifact, logs a warning and returns without
        marking the service loaded.
        """
        if self._loaded:
            return  # idempotent: repeated calls are no-ops

        model_dir = SLOT_MODEL_PATH
        model_path = os.path.join(model_dir, "model.pt")
        tag_map_path = os.path.join(model_dir, "tag_map.json")
        config_path = os.path.join(model_dir, "config.json")

        if not os.path.exists(model_path):
            print(f"[SlotService] WARNING: Model not found at {model_path}")
            print("[SlotService] Slot extraction will be unavailable.")
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the tag map: prefer the explicit tag_map.json, fall back to
        # reconstructing it from config.json's ordered tag_names list.
        if os.path.exists(tag_map_path):
            with open(tag_map_path, "r", encoding="utf-8") as f:
                self.tag2id = json.load(f)
        elif os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            tag_names = config.get("tag_names", [])
            self.tag2id = {tag: i for i, tag in enumerate(tag_names)}
        else:
            print("[SlotService] WARNING: No tag_map.json or config.json found")
            return

        self.id2tag = {v: k for k, v in self.tag2id.items()}
        num_tags = len(self.tag2id)
        print(f"[SlotService] Loading slot extractor ({num_tags} tags)...")

        self.model = SlotExtractor(
            bert_model_name=BERT_MODEL_NAME,
            num_tags=num_tags,
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # the checkpoint — only load model files you produced/trust.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)
        self._loaded = True
        print(f"[SlotService] Slot extractor loaded on {self.device}")
        print(f"[SlotService] Tags: {sorted(self.tag2id.keys())}")

    def extract(self, query: str) -> dict:
        """
        Extract slots/entities from a query using BIO tagging.

        Returns a dict with empty ``slots``/``tagged_tokens`` if the model
        is not loaded. Otherwise:
            {
                "slots": {
                    "PRODUCT1": "shoes",
                    "BRAND": "Nike",
                    "COLOR": "blue",
                    "PRICE_MAX": "3000"
                },
                "tagged_tokens": [
                    ("blue", "B-COLOR"),
                    ("Nike", "B-BRAND"),
                    ("shoes", "B-PRODUCT1"),
                    ("under", "B-PRICE_MOD"),
                    ("3000", "B-PRICE_MAX")
                ]
            }
        """
        if not self._loaded:
            return {"slots": {}, "tagged_tokens": []}

        # Tokenize on whitespace-split words so word_ids() maps subwords
        # back to the original word positions.
        words = query.split()
        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            padding="max_length",
            truncation=True,
            max_length=SLOT_MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        word_ids = encoding.word_ids(batch_index=0)

        # Predict tag ids for every subword token.
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()[0]

        # Decode: keep only the FIRST subword's prediction per word;
        # special tokens ([CLS], [SEP], [PAD]) have word_idx None.
        word_tags = {}
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            if word_idx not in word_tags:
                word_tags[word_idx] = self.id2tag.get(int(preds[token_idx]), "O")

        # Words truncated away by max_length default to "O".
        tagged_tokens = [(word, word_tags.get(i, "O")) for i, word in enumerate(words)]

        return {
            "slots": self._merge_bio_tags(words, tagged_tokens),
            "tagged_tokens": tagged_tokens,
        }

    def _merge_bio_tags(self, words: list, tagged_tokens: list) -> dict:
        """
        Merge BIO-tagged tokens into a slot dictionary.

        ``words`` is unused (the tokens are already paired with tags in
        ``tagged_tokens``) but kept for signature compatibility.

        Example:
            [("blue", "B-COLOR"), ("Nike", "B-BRAND"), ("running", "B-PRODUCT1"),
             ("shoes", "I-PRODUCT1")]
            -> {"COLOR": "blue", "BRAND": "Nike", "PRODUCT1": "running shoes"}
        """
        slots = {}
        current_entity = None
        current_tokens = []

        def flush():
            # Commit the in-progress entity span (if any) and reset state.
            nonlocal current_entity, current_tokens
            if current_entity and current_tokens:
                value = " ".join(current_tokens)
                # Repeated spans of the same type are concatenated
                # (e.g. two disjoint PRODUCT1 spans).
                if current_entity in slots:
                    slots[current_entity] += " " + value
                else:
                    slots[current_entity] = value
            current_entity = None
            current_tokens = []

        for word, tag in tagged_tokens:
            if tag.startswith("B-"):
                flush()
                current_entity = tag[2:]  # strip "B-"
                current_tokens = [word]
            elif tag.startswith("I-"):
                entity_type = tag[2:]
                if entity_type == current_entity:
                    current_tokens.append(word)
                else:
                    # Orphan I-tag (no matching B-): treat as a new span.
                    flush()
                    current_entity = entity_type
                    current_tokens = [word]
            else:
                # "O" tag ends any open span.
                flush()

        flush()  # commit a span that runs to the end of the query
        return slots
# Module-level singleton: importers share this one instance. Call
# slot_service.load() once at app startup before serving requests.
slot_service = SlotService()