# aac-chatbot / backend/retrieval/reranker.py
# Author: shwetangisingh
# Add MMR reranking with conversation-context query fusion
# Commit: c511e09
import torch
from backend.pipeline.state import RetrievedChunk
from backend.retrieval.vector_store import get_device, get_embedder
def build_context_vector(
    raw_query: str,
    history: list[dict] | None,
    last_n_turns: int,
    weight_current: float,
) -> torch.Tensor:
    """Embed ``raw_query``, optionally blended with recent user turns.

    The current query is always embedded. If ``history`` contains usable
    user turns and blending is enabled, the mean embedding of the last
    ``last_n_turns`` of those turns is mixed in with weight
    ``1 - weight_current``. The blend is then rescaled to unit L2 length.

    Args:
        raw_query: Text of the current query.
        history: Prior conversation messages. Only entries whose ``"role"``
            is ``"user"`` and whose ``"content"`` is non-blank after
            stripping are considered.
        last_n_turns: Number of most recent qualifying user turns to blend
            in. ``<= 0`` disables blending.
        weight_current: Share of the current query in the blend. Clamped to
            ``[0.05, 1.0]``; a value of ``1.0`` disables blending.

    Returns:
        The query embedding (first row of the embedder's batch output,
        encoded with ``normalize_embeddings=True``) when no blending occurs;
        otherwise the re-normalized weighted mix of query and history.
    """
    # The current query must never drop out of the blend entirely: with a
    # weight of 0 the reranker would drift onto stale conversation topics.
    alpha = max(0.05, min(1.0, weight_current))

    model = get_embedder()
    target_device = get_device()

    def _embed(texts):
        # Shared encoder settings for both the query and the history turns.
        return model.encode(
            texts,
            convert_to_tensor=True,
            normalize_embeddings=True,
            device=target_device,
        )

    query_vec = _embed([raw_query])[0]

    blending_disabled = not history or last_n_turns <= 0 or alpha >= 1.0
    if blending_disabled:
        return query_vec

    usable_turns = []
    for message in history:
        if message.get("role") != "user":
            continue
        text = (message.get("content") or "").strip()
        if text:
            usable_turns.append(text)

    recent_turns = usable_turns[-last_n_turns:]
    if not recent_turns:
        return query_vec

    history_centroid = _embed(recent_turns).mean(dim=0)
    blended = alpha * query_vec + (1.0 - alpha) * history_centroid
    # clamp_min avoids dividing by zero if the query and history embeddings
    # cancel each other out.
    return blended / blended.norm().clamp_min(1e-12)
def mmr_rerank(
    query_vec: torch.Tensor,
    candidate_vecs: torch.Tensor,
    candidate_chunks: list[RetrievedChunk],
    top_k: int,
    lambda_: float,
) -> list[RetrievedChunk]:
    """Select up to ``top_k`` chunks using Maximal Marginal Relevance (MMR).

    Candidates are picked greedily. The first pick is the candidate with the
    highest relevance ``candidate_vecs @ query_vec``. Each later pick
    maximizes::

        lambda_ * relevance - (1 - lambda_) * max_sim_to_already_selected

    where similarities are dot products between candidate rows.

    Args:
        query_vec: Query embedding. Presumably a 1-D vector whose width
            matches ``candidate_vecs``' rows; TODO confirm at the caller.
        candidate_vecs: Candidate embeddings, one row per candidate (``N``
            rows). NOTE(review): dot products are treated as similarities,
            so rows are presumably L2-normalized; confirm at the caller.
        candidate_chunks: Chunks aligned row-for-row with ``candidate_vecs``.
        top_k: Maximum number of chunks to return; ``<= 0`` returns ``[]``.
        lambda_: Relevance/diversity trade-off, clamped to ``[0, 1]``.
            ``1`` ranks purely by relevance; smaller values penalize
            redundancy with already-selected chunks more heavily.

    Returns:
        Shallow copies of the selected chunks in selection order. Each copy's
        ``"score"`` is the score it was selected with; for the first pick
        this is its raw relevance.
    """
    n = candidate_vecs.shape[0]
    if n == 0 or top_k <= 0:
        return []

    # Outside [0, 1] the objective is meaningless: lambda_ > 1 would *reward*
    # redundancy and lambda_ < 0 would favor irrelevant chunks.
    lambda_ = max(0.0, min(1.0, lambda_))

    # There is deliberately no shortcut for "pure relevance and everything fits".
    # Such a shortcut returned chunks in input order with stale scores, unlike
    # every other call. The loop below is cheap and keeps ordering and
    # "score" consistent.
    relevance = candidate_vecs @ query_vec  # (N,)
    pairwise = candidate_vecs @ candidate_vecs.T  # (N, N)
    target = min(top_k, n)

    neg_inf = torch.full(
        (), float("-inf"), device=relevance.device, dtype=relevance.dtype
    )
    available = torch.ones(n, dtype=torch.bool, device=relevance.device)
    # Highest similarity of each candidate to any already-selected chunk.
    # Starts at -inf so the first torch.maximum adopts pairwise[idx] exactly.
    max_sim_to_selected = torch.full(
        (n,), float("-inf"), device=relevance.device, dtype=relevance.dtype
    )

    picks: list[tuple[int, float]] = []
    for step in range(target):
        if step == 0:
            # max_sim_to_selected is still -inf here, and (1 - lambda_) * -inf
            # would give +inf (or NaN when lambda_ == 1). Pick by relevance alone.
            scores = relevance
        else:
            scores = lambda_ * relevance - (1.0 - lambda_) * max_sim_to_selected
        masked = torch.where(available, scores, neg_inf)
        idx = int(torch.argmax(masked).item())
        picks.append((idx, float(scores[idx].item())))
        available[idx] = False
        max_sim_to_selected = torch.maximum(max_sim_to_selected, pairwise[idx])

    out: list[RetrievedChunk] = []
    for idx, score in picks:
        # Copy so the caller's candidate chunks are not mutated.
        chunk = dict(candidate_chunks[idx])
        chunk["score"] = score
        out.append(chunk)  # type: ignore[arg-type]
    return out