# RetailMind / modules/drift.py
# Last commit 778855c by hodfa840: "fix: resolve 3 failing CI tests"
"""
Semantic drift detector for RetailMind.
Tracks the rolling semantic similarity of incoming user queries against
predefined *concept anchors* (e.g., price-sensitivity, seasonal shift,
eco-trend). When the exponentially-weighted moving average for any concept
exceeds a configurable threshold, the system flags an active drift — which
triggers the self-healing adapter to rewrite the LLM system prompt.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from modules.shared import get_embedding_model
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
@dataclass
class DriftEvent:
    """Record of a single drift measurement.

    This is a plain (mutable, non-frozen) dataclass; callers should treat
    instances as read-only once they have been stored in a detector's history.
    """

    timestamp: float  # wall-clock time the query was scored (seconds since epoch)
    query: str  # the raw user query that was scored
    scores: dict[str, float]  # concept name -> raw (un-smoothed) similarity score
    dominant: str  # concept flagged as drifting for this query, or "normal"
@dataclass
class DriftDetector:
    """
    Monitors semantic drift across configurable concept anchors.

    Each query embedding is scored by cosine similarity against one averaged
    anchor embedding per concept. The per-concept scores are smoothed with an
    EWMA (exponentially weighted moving average) so that single noisy queries
    become stable trend signals. A concept is reported as drifting once its
    smoothed score exceeds ``threshold``.

    Attributes:
        threshold: Minimum EWMA score for a concept to be reported as drift.
        ewma_alpha: Smoothing factor (expected in (0, 1]); higher reacts faster.
        history: Most recent ``DriftEvent`` records, capped at
            ``_MAX_HISTORY`` entries (oldest are dropped first).
    """

    # Maximum number of DriftEvent records kept in ``history``. Deliberately
    # unannotated so @dataclass treats it as a class constant, not a field.
    _MAX_HISTORY = 200

    threshold: float = 0.38
    ewma_alpha: float = 0.35  # smoothing factor (higher = more reactive)
    history: list[DriftEvent] = field(default_factory=list)
    _ewma: dict[str, float] = field(default_factory=dict)
    _concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)

    def __post_init__(self) -> None:
        """Embed the anchor phrases and reset the EWMA state for each concept."""
        model = get_embedding_model()
        # Multiple anchor phrases per concept; their embeddings are averaged
        # into a single reference vector for robustness.
        concept_phrases = {
            "price_sensitive": [
                "cheap budget discount low price clearance sale savings affordable",
                "what is the cheapest option under twenty dollars bargain deal",
                "I only have a limited budget, show me value picks",
            ],
            "summer_shift": [
                "summer heat warm weather sandals shorts sunscreen beach",
                "lightweight breathable sun protection hot climate UV",
                "vacation tropical poolside outdoor warm temperature",
            ],
            "eco_trend": [
                "eco-friendly sustainable organic recycled environment green",
                "plant-based carbon-neutral zero waste biodegradable vegan",
                "responsible sourcing ethical production renewable materials",
            ],
        }
        for concept, phrases in concept_phrases.items():
            embs = model.encode(phrases, show_progress_bar=False)
            self._concept_embs[concept] = np.mean(embs, axis=0)
            self._ewma[concept] = 0.0
        logger.info("DriftDetector initialized with %d concept anchors.", len(concept_phrases))

    # ── Public API ──────────────────────────────────────────────────────────
    def analyze_drift(
        self, query: str, query_emb=None
    ) -> tuple[str, dict[str, float]]:
        """
        Score *query* against all concept anchors and update the EWMA state.

        Args:
            query: Raw user query text (stored in the recorded DriftEvent).
            query_emb: Optional precomputed embedding of *query*; pass it to
                skip re-encoding. A single-row batch such as ``(1, dim)`` is
                accepted and flattened.

        Returns:
            ``(dominant_concept, raw_scores)``. *dominant_concept* is the
            concept with the highest EWMA score strictly above ``threshold``,
            or ``"normal"`` if none qualifies. *raw_scores* maps each concept
            to this query's un-smoothed cosine similarity.
        """
        if query_emb is None:
            query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]
        # Flatten so a (1, dim) row scores like a (dim,) vector and float()
        # below always receives a scalar rather than a 1-element array.
        q = np.asarray(query_emb).ravel()
        q_norm = np.linalg.norm(q)  # loop-invariant, computed once

        raw_scores: dict[str, float] = {}
        for concept, ref_emb in self._concept_embs.items():
            sim = float(np.dot(q, ref_emb) / (q_norm * np.linalg.norm(ref_emb) + 1e-10))
            raw_scores[concept] = sim
            # Update EWMA
            prev = self._ewma[concept]
            self._ewma[concept] = self.ewma_alpha * sim + (1 - self.ewma_alpha) * prev

        # Dominant drift = largest smoothed score strictly above the threshold.
        # Seeding the running max with the threshold (not 0.0) keeps this
        # correct even when a negative threshold is configured.
        detected = "normal"
        best = self.threshold
        for concept, smoothed in self._ewma.items():
            if smoothed > best:
                best = smoothed
                detected = concept

        event = DriftEvent(
            timestamp=time.time(),
            query=query,
            scores=raw_scores,
            dominant=detected,
        )
        self.history.append(event)
        # Trim in place so any external reference to ``history`` stays in sync.
        if len(self.history) > self._MAX_HISTORY:
            del self.history[: -self._MAX_HISTORY]
        logger.debug("Drift analysis: %s | scores=%s | ewma=%s", detected, raw_scores, self._ewma)
        return detected, raw_scores

    def get_ewma_scores(self) -> dict[str, float]:
        """Return a copy of the current EWMA-smoothed scores for dashboard display."""
        return dict(self._ewma)

    def get_recent_stats(self, n: int = 5) -> dict[str, float] | None:
        """Return per-concept raw scores averaged over the last *n* queries.

        Args:
            n: Number of most recent events to average (default 5).

        Returns:
            ``{concept: mean_raw_score}``, or ``None`` if no query has been
            recorded yet.

        Raises:
            ValueError: If *n* is not positive (``history[-0:]`` would
                otherwise silently select the entire history).
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        if not self.history:
            return None
        recent = self.history[-n:]
        return {
            c: float(np.mean([e.scores[c] for e in recent]))
            for c in self._concept_embs
        }

    def get_history_series(self) -> dict[str, list[float]]:
        """Return the EWMA time-series for each concept (for charts).

        The series is rebuilt from the raw scores stored in ``history``,
        starting from 0.0 and using the same ``ewma_alpha`` as the live state.
        When fewer than 5 events exist, the front is padded with a 0.15
        baseline so the chart renders a flat line on first load.
        """
        concepts = list(self._concept_embs)
        padding = max(0, 5 - len(self.history))
        # Neutral baseline padding so the chart always has something to show.
        series: dict[str, list[float]] = {c: [0.15] * padding for c in concepts}
        ewma_state = dict.fromkeys(concepts, 0.0)
        for event in self.history:
            for c in concepts:
                ewma_state[c] = self.ewma_alpha * event.scores[c] + (1 - self.ewma_alpha) * ewma_state[c]
                series[c].append(ewma_state[c])
        return series