Spaces:
Sleeping
Sleeping
| import torch | |
| from backend.pipeline.state import RetrievedChunk | |
| from backend.retrieval.vector_store import get_device, get_embedder | |
def build_context_vector(
    raw_query: str,
    history: list[dict] | None,
    last_n_turns: int,
    weight_current: float,
) -> torch.Tensor:
    """Embed the current query, optionally blended with recent user turns.

    The query is always embedded. When history blending applies, the mean
    embedding of the last ``last_n_turns`` non-empty user messages is mixed
    in and the result is rescaled to unit length.

    Args:
        raw_query: Text of the current user query.
        history: Conversation turns as dicts; only entries whose ``"role"``
            is ``"user"`` and whose ``"content"`` is non-blank are used.
            ``None`` or an empty list disables blending.
        last_n_turns: How many of the most recent user turns to blend in.
            Values ``<= 0`` disable blending.
        weight_current: Weight of the current query in the blend. Clamped to
            ``[0.05, 1.0]``; at ``1.0`` the history is ignored entirely.

    Returns:
        The query embedding as returned by the embedder when no blending
        happens, otherwise the weighted, re-normalised blend.
    """
    # The 0.05 floor keeps the current query in the mix: a weight of 0 would
    # silently drop it and let the reranker drift onto stale conversation
    # topics.
    weight_current = max(0.05, min(1.0, weight_current))

    embedder = get_embedder()
    device = get_device()

    def _embed(texts: list[str]) -> torch.Tensor:
        # One place for the encode options so query and history embeddings
        # are produced identically.
        return embedder.encode(
            texts,
            convert_to_tensor=True,
            normalize_embeddings=True,
            device=device,
        )

    q_vec = _embed([raw_query])[0]

    blending_possible = (
        bool(history) and last_n_turns > 0 and weight_current < 1.0
    )
    if not blending_possible:
        return q_vec

    # Collect the stripped text of every user turn that actually says
    # something; missing or None content counts as empty.
    user_texts: list[str] = []
    for turn in history:
        if turn.get("role") != "user":
            continue
        text = (turn.get("content") or "").strip()
        if text:
            user_texts.append(text)

    recent = user_texts[-last_n_turns:]
    if not recent:
        return q_vec

    history_mean = _embed(recent).mean(dim=0)
    blended = weight_current * q_vec + (1.0 - weight_current) * history_mean
    # Rescale to unit length; the clamp guards against dividing by a
    # (near-)zero norm when the two parts cancel out.
    return blended / blended.norm().clamp_min(1e-12)
def mmr_rerank(
    query_vec: torch.Tensor,
    candidate_vecs: torch.Tensor,
    candidate_chunks: list[RetrievedChunk],
    top_k: int,
    lambda_: float,
) -> list[RetrievedChunk]:
    """Greedily pick up to ``top_k`` chunks by Maximal Marginal Relevance.

    Each step selects the still-available candidate that maximises
    ``lambda_ * relevance - (1 - lambda_) * max_similarity_to_selected``,
    where relevance is the dot product with ``query_vec`` and similarity is
    the dot product between candidate vectors.

    Args:
        query_vec: Query embedding; must be multipliable with the rows of
            ``candidate_vecs`` (``candidate_vecs @ query_vec``).
        candidate_vecs: One embedding per candidate, indexed along dim 0.
        candidate_chunks: Chunks aligned index-for-index with
            ``candidate_vecs``.
        top_k: Maximum number of chunks to return.
        lambda_: Relevance/diversity trade-off. Clamped to ``[0, 1]``:
            ``1`` is pure relevance, ``0`` is pure diversity after the first
            pick. Out-of-range values would otherwise flip the sign of one
            of the two terms.

    Returns:
        Shallow copies of the selected chunks, in selection order, with
        ``"score"`` overwritten by the score that won the step. The first
        pick is made on raw relevance, so its score is not scaled by
        ``lambda_``. Exception: when there are no more candidates than
        ``top_k`` and ``lambda_`` is 1, the input chunks are returned as-is
        (original order, original scores, no copies). An empty list is
        returned when there are no candidates or ``top_k <= 0``.
    """
    n = candidate_vecs.shape[0]
    if n == 0 or top_k <= 0:
        return []

    # Mirror the clamping applied to the blend weight in this module: outside
    # [0, 1] the formula would reward redundancy (lambda_ > 1) or penalise
    # relevance (lambda_ < 0).
    lambda_ = max(0.0, min(1.0, lambda_))

    # Pure-relevance shortcut: nothing to drop and no diversity term, so the
    # caller's candidates are handed back untouched.
    if n <= top_k and lambda_ >= 1.0:
        return candidate_chunks[:top_k]

    rel = candidate_vecs @ query_vec  # (N,)
    target = min(top_k, n)

    neg_inf = torch.full(
        (), float("-inf"), device=rel.device, dtype=rel.dtype
    )
    available_mask = torch.ones(n, dtype=torch.bool, device=rel.device)
    # Highest similarity of each candidate to anything selected so far;
    # None until the first pick has been made.
    max_sim_to_selected: torch.Tensor | None = None

    selected: list[int] = []
    selected_scores: list[float] = []

    for step in range(target):
        if max_sim_to_selected is None:
            # Nothing selected yet: rank on relevance alone.
            scores = rel
        else:
            scores = lambda_ * rel - (1.0 - lambda_) * max_sim_to_selected

        # Already-picked candidates are pushed to -inf so argmax skips them.
        masked = torch.where(available_mask, scores, neg_inf)
        idx = int(torch.argmax(masked).item())

        selected.append(idx)
        selected_scores.append(float(scores[idx].item()))
        available_mask[idx] = False

        if step + 1 < target:
            # Only the similarities to the item just picked are needed, so
            # compute that single row instead of the full N x N matrix.
            sim_to_pick = candidate_vecs @ candidate_vecs[idx]  # (N,)
            if max_sim_to_selected is None:
                max_sim_to_selected = sim_to_pick
            else:
                max_sim_to_selected = torch.maximum(
                    max_sim_to_selected, sim_to_pick
                )

    out: list[RetrievedChunk] = []
    for idx, score in zip(selected, selected_scores):
        # Copy so the caller's chunk keeps its original score.
        chunk = dict(candidate_chunks[idx])
        chunk["score"] = score
        out.append(chunk)  # type: ignore[arg-type]
    return out