# RetailTalk — backend/models/query_rewriter.py
# Initial commit — RetailTalk backend for HuggingFace Spaces (author: Dashm, commit 26d82f3)
"""
Query Rewriter — transforms raw user queries into structured search inputs
using intent classification and slot extraction.
Takes: raw query + intents + extracted slots
Returns: rewritten search text + structured filters + metadata
"""
import re
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class SearchGroup:
    """A single search sub-query with its own text and filters.

    Compound queries ("shoes under 500 and bags") produce one SearchGroup
    per product, each carrying its own structured filter dict.
    """
    search_text: str            # cleaned text for this sub-search
    filters: dict = field(default_factory=dict)  # per-group filter values
@dataclass
class RewrittenQuery:
    """Output of the query rewriter."""
    search_text: str  # Cleaned query for BERT embedding + keyword matching
    filters: dict  # Structured filters for Supabase WHERE clauses
    original_query: str  # The raw user input
    intents: list = field(default_factory=list)  # Detected intent labels
    slots: dict = field(default_factory=dict)  # All extracted slots (type -> value)
    is_rewritten: bool = False  # Whether any rewriting was applied
    search_groups: list = field(default_factory=list)  # List[SearchGroup] for compound queries
# Slot types that represent product names (included in search text).
# NOTE(review): neither constant is referenced inside this module's visible
# code (rewrite() uses literal slot names) — presumably imported elsewhere;
# verify before removing.
PRODUCT_SLOTS = {"PRODUCT1", "PRODUCT2"}
# Slot types that become Supabase filters (excluded from search text).
FILTER_SLOTS = {"PRICE_MIN", "PRICE_MAX", "PRICE_MOD", "BRAND", "COLOR",
                "SIZE", "RATING_MIN", "RATING_MOD"}
# Modifier words to strip from search text (common non-product words).
# Mixes English and Filipino/Tagalog stopwords and price/rating modifiers;
# used by rewrite() as a fallback cleanup when no product slots were found.
# Matching is done on lowercased whole words.
MODIFIER_WORDS = {
    "under", "below", "less", "than", "above", "over", "more",
    "at", "least", "most", "around", "between", "and",
    "cheaper", "cheapest", "expensive",
    "budget", "affordable", "cheap", "pricey",
    "minimum", "maximum", "max", "min",
    "rating", "rated", "stars", "star",
    "i", "want", "need", "looking", "for", "find", "show", "me",
    "the", "a", "an", "of", "with", "in", "na", "ng", "ang", "yung",
    "paano", "saan", "ano", "may", "gusto", "ko", "hanap",
    "magkano", "pesos", "peso", "php",
}
# Conjunction patterns for compound query splitting.
# These patterns are passed to re.split() in split_compound_query(), so any
# alternation/optional portion MUST use non-capturing groups: re.split()
# returns the text of capturing groups as extra fields, which would inject
# the matched conjunction (e.g. "na ") into the results as a bogus sub-query.
COMPOUND_CONJUNCTIONS = [
    r'\s+and\s+',
    r'\s+at\s+saka\s+',
    r'\s+tapos\s+',
    r'\s+tsaka\s+',
    # Fixed: was r'\s+pati\s+(na\s+)?' — the capturing group leaked "na "
    # into re.split() output.
    r'\s+pati\s+(?:na\s+)?',
]
def split_compound_query(query: str) -> list[str]:
    """
    Break a compound query into separate product searches by splitting on
    the first conjunction pattern that yields two or more non-empty pieces.

    "party items less than 300 and shoes for kids less 200"
        -> ["party items less than 300", "shoes for kids less 200"]
    "peanut butter and jelly under 200"
        -> ["peanut butter", "jelly under 200"]

    Returns [query] unchanged when no conjunction splits it.
    """
    for pattern in COMPOUND_CONJUNCTIONS:
        pieces = re.split(pattern, query, flags=re.IGNORECASE)
        if len(pieces) < 2:
            continue
        stripped = [piece.strip() for piece in pieces if piece and piece.strip()]
        if len(stripped) >= 2:
            return stripped
    return [query]
def split_sentences(query: str) -> list[str]:
    """
    Break a multi-sentence or compound query into individual sub-queries.

    Three passes:
      1. Sentence-ending punctuation: '.' followed by whitespace + a letter
         (so decimals like "3.5" survive), or '?' / '!' at a boundary.
      2. Commas, treated as product separators (only when a comma actually
         yields two or more non-empty chunks).
      3. Conjunctions ('and', 'at saka', ...) via split_compound_query().
    """
    raw_parts = re.split(r'\.(?=\s+[A-Za-z])|[?!](?:\s+|$)', query)

    sentences = []
    for piece in raw_parts:
        if not (piece and piece.strip()):
            continue
        trimmed = piece.strip().rstrip('.!?')
        if trimmed:
            sentences.append(trimmed)
    if not sentences:
        sentences = [query.strip()]

    # Comma pass: keep the sentence whole unless commas separate >= 2 chunks.
    segments = []
    for sentence in sentences:
        chunks = [c.strip() for c in sentence.split(',') if c and c.strip()]
        if len(chunks) >= 2:
            segments.extend(chunks)
        else:
            segments.append(sentence)

    # Conjunction pass on every remaining segment.
    result = []
    for segment in segments:
        result.extend(split_compound_query(segment))
    return result if result else [query.strip()]
def _merge_rewritten_queries(
    sub_queries: list[RewrittenQuery],
    original_query: str,
) -> RewrittenQuery:
    """Merge multiple per-sentence RewrittenQuery results into one.

    Two merge modes:
      - Compound (sub-queries have distinct search texts): each sub-query
        becomes its own SearchGroup with independent filters; top-level
        filters stay empty.
      - Single group (all texts identical): slots, filters, and search words
        are folded into one combined query (the original merge behavior).

    Note: when len(sub_queries) == 1 the single result is returned directly,
    mutated in place to carry exactly one SearchGroup.
    """
    if len(sub_queries) == 1:
        rq = sub_queries[0]
        rq.search_groups = [SearchGroup(search_text=rq.search_text, filters=rq.filters)]
        return rq
    # Union of all intents (deduplicated, preserving first-seen order).
    merged_intents = list(dict.fromkeys(
        intent for rq in sub_queries for intent in rq.intents
    ))
    # Each sub-query that originated from a separator (dot, comma, 'and')
    # represents a distinct product search — keep them as independent search
    # groups. Only collapse to a single group when all sub-queries share the
    # same search text (i.e., splitting produced no meaningful separation).
    distinct_texts = len({rq.search_text for rq in sub_queries}) > 1
    if distinct_texts:
        # ── COMPOUND QUERY ──
        # Each sub-query becomes its own SearchGroup with independent filters.
        search_groups = [
            SearchGroup(search_text=rq.search_text, filters=rq.filters)
            for rq in sub_queries
        ]
        # Renumber product slots sequentially (PRODUCT1, PRODUCT2, ...) so
        # products from different sub-queries don't overwrite each other;
        # non-product slots keep first-seen values.
        merged_slots = {}
        product_idx = 1
        for rq in sub_queries:
            for key, value in rq.slots.items():
                if key in ("PRODUCT1", "PRODUCT2"):
                    slot_key = f"PRODUCT{product_idx}"
                    if slot_key not in merged_slots:
                        merged_slots[slot_key] = value
                        product_idx += 1
                elif key not in merged_slots:
                    merged_slots[key] = value
        # Joined text is informational; the per-group texts drive searching.
        search_text = " | ".join(g.search_text for g in search_groups)
        return RewrittenQuery(
            search_text=search_text,
            filters={},
            original_query=original_query,
            intents=merged_intents,
            slots=merged_slots,
            is_rewritten=True,
            search_groups=search_groups,
        )
    # ── SINGLE GROUP (original merge behavior) ──
    # First-seen product values win PRODUCT1, then PRODUCT2; other slots
    # keep their first-seen value.
    merged_slots = {}
    for rq in sub_queries:
        for key, value in rq.slots.items():
            if key in ("PRODUCT1", "PRODUCT2"):
                if "PRODUCT1" not in merged_slots:
                    merged_slots["PRODUCT1"] = value
                elif "PRODUCT2" not in merged_slots:
                    merged_slots["PRODUCT2"] = value
            elif key not in merged_slots:
                merged_slots[key] = value
    # Combine filters, keeping the tightest bound on conflicts:
    # highest price_min / rating_min, lowest price_max.
    merged_filters = {}
    for rq in sub_queries:
        for key, value in rq.filters.items():
            if key not in merged_filters:
                merged_filters[key] = value
            elif key == "price_min":
                merged_filters[key] = max(merged_filters[key], value)
            elif key == "price_max":
                merged_filters[key] = min(merged_filters[key], value)
            elif key == "rating_min":
                merged_filters[key] = max(merged_filters[key], value)
    # Concatenate search words, dropping case-insensitive duplicates while
    # preserving first-seen order and original casing.
    seen = set()
    search_parts = []
    for rq in sub_queries:
        for word in rq.search_text.split():
            lower = word.lower()
            if lower not in seen:
                seen.add(lower)
                search_parts.append(word)
    search_text = " ".join(search_parts)
    return RewrittenQuery(
        search_text=search_text,
        filters=merged_filters,
        original_query=original_query,
        intents=merged_intents,
        slots=merged_slots,
        is_rewritten=any(rq.is_rewritten for rq in sub_queries),
        search_groups=[SearchGroup(search_text=search_text, filters=merged_filters)],
    )
def _parse_price(value: str) -> Optional[float]:
"""Try to parse a numeric value from a price slot."""
clean = re.sub(r"[^\d.]", "", value)
try:
return float(clean)
except ValueError:
return None
def _detect_price_direction(query: str) -> Optional[str]:
"""
Detect whether the user's price intent is a minimum or maximum
based on modifier words in the raw query.
Returns "min", "max", or None if ambiguous.
"""
q = query.lower()
# Patterns that indicate a MINIMUM price ("more than X", "above X", etc.)
min_patterns = [
r"\bmore\s+than\b", r"\babove\b", r"\bover\b", r"\bat\s+least\b",
r"\bhigher\s+than\b", r"\bstarting\b", r"\bfrom\b",
r"\bexpensive\b", r"\bpricey\b",
# Filipino
r"\bhigit\s+sa\b", r"\bmula\s+sa\b",
]
# Patterns that indicate a MAXIMUM price ("less than X", "under X", etc.)
max_patterns = [
r"\bless\s+than\b", r"\bunder\b", r"\bbelow\b", r"\bat\s+most\b",
r"\bcheaper\s+than\b", r"\bbudget\b", r"\bcheap\b", r"\baffordable\b",
# Filipino
r"\bmura\b", r"\bmababa\b",
]
for pat in min_patterns:
if re.search(pat, q):
return "min"
for pat in max_patterns:
if re.search(pat, q):
return "max"
return None
def rewrite(query: str, intents: list[str], slots: dict) -> RewrittenQuery:
    """
    Rewrite a user query based on detected intents and extracted slots.

    Logic:
      - For free_form queries (and no other intents): pass through as-is
      - For filtered_search: extract filter slots into structured filters,
        build search text from product slots only
      - For single_search / multi_search: build search text from product
        slots, include brand/color in search text too

    Args:
        query: Raw user input (a single sentence/segment).
        intents: Intent labels detected for this query.
        slots: Slot type -> value mapping from the NER model. The caller's
            dict is never mutated; direction corrections and regex
            fallbacks are applied to an internal copy.

    Returns:
        RewrittenQuery with cleaned search_text, structured filters, and
        is_rewritten set when anything changed.
    """
    # Fix: work on a copy so the price-direction swaps and regex fallbacks
    # below never mutate the caller's dict (the original popped keys from
    # the argument in place).
    slots = dict(slots)
    stripped = query.strip()

    # Default: use the original query as-is.
    result = RewrittenQuery(
        search_text=stripped,
        filters={},
        original_query=stripped,
        intents=intents,
        slots=slots,
    )
    # If no intents or slots were extracted, return original query.
    if not intents and not slots:
        return result
    # Free-form intent with no product slots: pass through as-is
    # (e.g., "pano magluto ng adobo" — not a product search).
    if "free_form" in intents and len(intents) == 1 and not slots:
        return result

    # --- Correct price slot direction ---
    # The NER model may tag the price value as PRICE_MAX when the user
    # actually means "more than X" (a minimum), or vice versa.
    # Use modifier words in the raw query to fix this.
    direction = _detect_price_direction(query)
    has_min = "PRICE_MIN" in slots
    has_max = "PRICE_MAX" in slots
    if direction == "min" and has_max and not has_min:
        # NER said PRICE_MAX but user said "more than" -> swap to PRICE_MIN
        slots["PRICE_MIN"] = slots.pop("PRICE_MAX")
    elif direction == "max" and has_min and not has_max:
        # NER said PRICE_MIN but user said "under" -> swap to PRICE_MAX
        slots["PRICE_MAX"] = slots.pop("PRICE_MIN")

    # --- Regex fallback: extract price if NER missed it ---
    # Covers patterns like "less than 30", "under 500", "more than 100".
    # Only fires when a direction word was seen but no price slot exists;
    # grabs the first number in the query.
    if direction is not None and "PRICE_MIN" not in slots and "PRICE_MAX" not in slots:
        price_match = re.search(r'(\d+(?:\.\d+)?)', query)
        if price_match:
            price_val = price_match.group(1)
            if direction == "max":
                slots["PRICE_MAX"] = price_val
                print(f"[QueryRewriter] Regex fallback: PRICE_MAX={price_val} (from '{query}')")
            elif direction == "min":
                slots["PRICE_MIN"] = price_val
                print(f"[QueryRewriter] Regex fallback: PRICE_MIN={price_val} (from '{query}')")

    # --- Build structured filters from slots ---
    filters = {}
    price_max = slots.get("PRICE_MAX")
    if price_max:
        parsed = _parse_price(price_max)
        if parsed is not None:
            filters["price_max"] = parsed
    price_min = slots.get("PRICE_MIN")
    if price_min:
        parsed = _parse_price(price_min)
        if parsed is not None:
            filters["price_min"] = parsed
    brand = slots.get("BRAND")
    if brand:
        filters["brand"] = brand.strip()
    color = slots.get("COLOR")
    if color:
        filters["color"] = color.strip()
    size = slots.get("SIZE")
    if size:
        filters["size"] = size.strip()
    rating_min = slots.get("RATING_MIN")
    if rating_min:
        # _parse_price just extracts a number; it works for ratings too.
        parsed = _parse_price(rating_min)
        if parsed is not None:
            filters["rating_min"] = parsed

    # --- Build search text ---
    search_parts = []
    # Include product names.
    for slot_type in ("PRODUCT1", "PRODUCT2"):
        if slot_type in slots:
            search_parts.append(slots[slot_type].strip())
    # Prepend brand then color (helps BERT + keyword matching); color ends
    # up first, then brand, then product names.
    if brand:
        search_parts.insert(0, brand.strip())
    if color:
        search_parts.insert(0, color.strip())

    if search_parts:
        # Product-related slots found — they become the search text.
        search_text = " ".join(search_parts)
    else:
        # No product slots found — clean the original query by removing
        # modifier words and bare numeric (price) tokens.
        words = stripped.split()
        cleaned = [
            w for w in words
            if w.lower() not in MODIFIER_WORDS
            and not re.match(r"^\d+$", w)
        ]
        search_text = " ".join(cleaned) if cleaned else stripped

    result.search_text = search_text
    result.filters = filters
    result.is_rewritten = bool(filters) or (search_text != stripped)
    return result
class QueryRewriterService:
    """
    Orchestrates intent classification + slot extraction + query rewriting.
    This is the main entry point called from the search route.
    """

    def __init__(self):
        # Both services are attached later via init(); until then
        # _process_single falls back to empty intent/slot results.
        self._intent_service = None
        self._slot_service = None

    def init(self, intent_service, slot_service):
        """Initialize with references to the intent and slot services."""
        self._intent_service = intent_service
        self._slot_service = slot_service

    def process(self, query: str) -> RewrittenQuery:
        """
        Full query rewriting pipeline with multi-sentence support:
        1. Split query into sentences
        2. Process each sentence (intent + slot + rewrite)
        3. Merge results into a single RewrittenQuery
        Returns RewrittenQuery with search_text, filters, intents, and slots.
        """
        sentences = split_sentences(query)

        if len(sentences) == 1:
            # Single sentence: skip the merge machinery entirely.
            single = self._process_single(sentences[0])
            single.original_query = query.strip()
            single.search_groups = [
                SearchGroup(search_text=single.search_text, filters=single.filters)
            ]
            self._log(query, single)
            return single

        # Multiple sentences: process each independently, then merge.
        merged = _merge_rewritten_queries(
            [self._process_single(sentence) for sentence in sentences],
            original_query=query.strip(),
        )
        if merged.is_rewritten:
            print(f"[QueryRewriter] '{query}' -> '{merged.search_text}'")
            print(f"[QueryRewriter] Sentences: {sentences}")
            print(f"[QueryRewriter] Intents: {merged.intents}")
            print(f"[QueryRewriter] Slots: {merged.slots}")
            print(f"[QueryRewriter] Filters: {merged.filters}")
        return merged

    def _process_single(self, sentence: str) -> RewrittenQuery:
        """Run one sentence through intent classification, slot extraction, rewrite."""
        # Step 1: intents (empty fallback when the service isn't loaded).
        intent_result = {"intents": [], "probabilities": {}}
        if self._intent_service and self._intent_service._loaded:
            intent_result = self._intent_service.predict(sentence)

        # Step 2: slots (same graceful fallback).
        slot_result = {"slots": {}, "tagged_tokens": []}
        if self._slot_service and self._slot_service._loaded:
            slot_result = self._slot_service.extract(sentence)

        # Step 3: rewrite using whatever was extracted.
        return rewrite(
            query=sentence,
            intents=intent_result["intents"],
            slots=slot_result["slots"],
        )

    def _log(self, query: str, result: RewrittenQuery):
        """Log rewriting details for debugging (only when a rewrite occurred)."""
        if result.is_rewritten:
            print(f"[QueryRewriter] '{query}' -> '{result.search_text}'")
            print(f"[QueryRewriter] Intents: {result.intents}")
            print(f"[QueryRewriter] Slots: {result.slots}")
            print(f"[QueryRewriter] Filters: {result.filters}")
# Global singleton — call query_rewriter.init(...) once at startup to attach
# the intent/slot services; until then _process_single uses empty fallbacks.
query_rewriter = QueryRewriterService()