""" Enterprise Hybrid Retrieval Engine. Combines Dense Vectors (ChromaDB) and Sparse Keywords (BM25). Features a custom Reciprocal Rank Fusion (RRF) implementation with company-balanced output. UPGRADES vs previous version: - Corrected RRF math: weights now scale the full RRF fraction, not just the numerator. Previous: weight / (rank + K) — mathematically inconsistent with standard RRF. Fixed: weight * (1 / (rank + K)) — weight scales the entire rank-fusion score. - RRF_K constant imported from config (no hardcoding). - Atomic BM25 serialization via tempfile + shutil.move() — prevents partial-write corruption that left the index in an unrecoverable split state. - SHA-256 integrity check on BM25 load — detects file tampering or corruption. - Company-balanced retrieval — prevents one company dominating the context window (was 71.4% Meta in the original, revealed by the telemetry dashboard). - Full type annotations on all methods. """ import os import json import pickle import hashlib import tempfile import shutil from typing import Optional from langchain_chroma import Chroma from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document from src.config import ( logger, VECTOR_DB_DIR, TOP_K_VECTORS, RRF_K, MAX_CHUNKS_PER_COMPANY, EMBEDDING_MODEL_NAME, ) # ── Reciprocal Rank Fusion ──────────────────────────────────────────────────── class CustomHybridRetriever: """ Reciprocal Rank Fusion (RRF) over dense + sparse retrieval results. RRF Formula (per retriever r, per document d at rank i): score(d) += weight_r * (1 / (rank_i + K)) Where K=60 is the standard smoothing constant that prevents top-ranked documents from receiving disproportionately high scores. Weighting: dense_weight + sparse_weight should equal 1.0. These scale each retriever's contribution to the fused score independently, allowing domain tuning (e.g., keyword-heavy financial text may warrant higher sparse_weight for exact figure matching). """ def __init__( self, dense_retriever, sparse_retriever, dense_weight: float = 0.5, sparse_weight: float = 0.5, ) -> None: if not (0.0 < dense_weight <= 1.0 and 0.0 < sparse_weight <= 1.0): raise ValueError("Retriever weights must be in (0.0, 1.0].") self.dense_retriever = dense_retriever self.sparse_retriever = sparse_retriever self.dense_weight = dense_weight self.sparse_weight = sparse_weight def _compute_rrf_scores( self, docs: list[Document], weight: float, rrf_scores: dict, doc_map: dict, ) -> None: """ Accumulate weighted RRF scores into rrf_scores in-place. Args: docs: Ranked list of documents from one retriever. weight: Scalar weight for this retriever's contribution. rrf_scores: Mutable dict mapping chunk_id → cumulative score. doc_map: Mutable dict mapping chunk_id → Document object. """ for rank, doc in enumerate(docs): # Use deterministic chunk_id from metadata (set during ingestion). # Fall back to a hash of content if metadata is missing — this # prevents KeyError but also signals an ingestion configuration issue. chunk_id: str = doc.metadata.get( "chunk_id", hashlib.sha256(doc.page_content.encode()).hexdigest()[:16], ) doc_map[chunk_id] = doc # CORRECTED: weight * (1 / (rank + K)) — weight scales the full fraction. rrf_scores[chunk_id] = ( rrf_scores.get(chunk_id, 0.0) + weight * (1.0 / (rank + RRF_K)) ) def _balance_by_company( self, sorted_docs: list[tuple[str, float]], doc_map: dict ) -> list[Document]: """ Return top-K documents with per-company diversity enforcement. Prevents any single company from dominating the context window. A query about "Google vs Meta R&D" should surface chunks from both, not 70%+ from whichever company had more keyword matches. Algorithm: iterate RRF-ranked list; admit a doc only if its company hasn't exceeded MAX_CHUNKS_PER_COMPANY. Continue until TOP_K_VECTORS are collected or the list is exhausted. Args: sorted_docs: List of (chunk_id, rrf_score) sorted desc by score. doc_map: Mapping from chunk_id to Document. Returns: Balanced list of up to TOP_K_VECTORS Document objects. """ company_counts: dict[str, int] = {} balanced: list[Document] = [] for chunk_id, _ in sorted_docs: if len(balanced) >= TOP_K_VECTORS: break doc = doc_map[chunk_id] company = doc.metadata.get("company", "unknown") count = company_counts.get(company, 0) if count < MAX_CHUNKS_PER_COMPANY: company_counts[company] = count + 1 balanced.append(doc) # Safety fallback: if balancing left us short (e.g., only 1 company # in the corpus), fill remaining slots from the unfiltered ranked list. if len(balanced) < TOP_K_VECTORS: seen_ids = {d.metadata.get("chunk_id") for d in balanced} for chunk_id, _ in sorted_docs: if len(balanced) >= TOP_K_VECTORS: break if chunk_id not in seen_ids: balanced.append(doc_map[chunk_id]) seen_ids.add(chunk_id) return balanced def invoke(self, query: str) -> list[Document]: """ Execute hybrid search and return RRF-fused, company-balanced results. Args: query: The user's natural language query string. Returns: List of up to TOP_K_VECTORS Document objects ranked by fused score. """ dense_docs: list[Document] = self.dense_retriever.invoke(query) sparse_docs: list[Document] = self.sparse_retriever.invoke(query) rrf_scores: dict[str, float] = {} doc_map: dict[str, Document] = {} self._compute_rrf_scores(dense_docs, self.dense_weight, rrf_scores, doc_map) self._compute_rrf_scores(sparse_docs, self.sparse_weight, rrf_scores, doc_map) sorted_docs: list[tuple[str, float]] = sorted( rrf_scores.items(), key=lambda x: x[1], reverse=True ) return self._balance_by_company(sorted_docs, doc_map) # ── Hybrid Retrieval Engine ─────────────────────────────────────────────────── class HybridRetrievalEngine: """ Orchestrates ChromaDB (dense) + BM25 (sparse) index construction and loading. Smart Load Logic: - If both indexes exist on disk → load without re-embedding (warm start). - If either is missing → build from scratch using document_chunks. - Atomic BM25 writes prevent split-state corruption. - SHA-256 integrity verification on BM25 load detects tampering/corruption. """ def __init__(self) -> None: self.vector_db_dir: str = VECTOR_DB_DIR self.bm25_path: str = os.path.join(VECTOR_DB_DIR, "bm25_index.pkl") self.bm25_hash_path: str = os.path.join(VECTOR_DB_DIR, "bm25_index.sha256") logger.info("Loading embedding model: %s", EMBEDDING_MODEL_NAME) self.embedding_model = HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": "cpu"}, # explicit; avoids silent GPU fallback encode_kwargs={"normalize_embeddings": True}, # required for cosine similarity ) self.ensemble_retriever: Optional[CustomHybridRetriever] = None # ── Internal: Integrity Helpers ─────────────────────────────────────────── @staticmethod def _compute_file_sha256(path: str) -> str: """Compute SHA-256 hex digest of a file's binary contents.""" sha256 = hashlib.sha256() with open(path, "rb") as f: for block in iter(lambda: f.read(65536), b""): sha256.update(block) return sha256.hexdigest() def _write_bm25_with_integrity(self, sparse_retriever) -> None: """ Serialize BM25 retriever atomically with integrity hash. Steps: 1. Write pickle to a temp file in the same directory (same filesystem as destination → rename is atomic on POSIX). 2. Compute SHA-256 of the temp file. 3. Rename temp file to final path (atomic). 4. Write hash to .sha256 sidecar file. This prevents the partial-write state where Chroma exists but BM25 doesn't, causing an unrecoverable ValueError on the next cold start. """ dir_path = os.path.dirname(self.bm25_path) with tempfile.NamedTemporaryFile( dir=dir_path, suffix=".pkl", delete=False ) as tmp: pickle.dump(sparse_retriever, tmp) tmp_path = tmp.name file_hash = self._compute_file_sha256(tmp_path) shutil.move(tmp_path, self.bm25_path) # atomic rename with open(self.bm25_hash_path, "w") as hf: hf.write(file_hash) logger.info( "BM25 index serialized. SHA-256: %s", file_hash ) def _load_bm25_with_integrity(self): """ Deserialize BM25 and verify SHA-256 integrity before returning. Raises: RuntimeError: If the hash sidecar is missing or digest mismatches, indicating file corruption or tampering. """ if not os.path.exists(self.bm25_hash_path): raise RuntimeError( f"BM25 integrity file missing: {self.bm25_hash_path}. " "Delete the vector_db directory and rebuild indexes." ) with open(self.bm25_hash_path, "r") as hf: expected_hash = hf.read().strip() actual_hash = self._compute_file_sha256(self.bm25_path) if actual_hash != expected_hash: raise RuntimeError( f"BM25 index integrity check FAILED. " f"Expected {expected_hash}, got {actual_hash}. " "Index may be corrupted. Delete vector_db/ and rebuild." ) with open(self.bm25_path, "rb") as f: sparse_retriever = pickle.load(f) # safe: integrity verified above logger.info("BM25 index loaded and integrity verified.") return sparse_retriever # ── Public: Build or Load Indexes ──────────────────────────────────────── def build_indexes( self, document_chunks: Optional[list[Document]] = None, ) -> CustomHybridRetriever: """ Build indexes from chunks (cold start) or load them from disk (warm start). Args: document_chunks: Required for cold start (first run). Ignored on warm start (indexes already on disk). Returns: Initialized CustomHybridRetriever ready for querying. Raises: ValueError: If cold start is required but no chunks provided. RuntimeError: If BM25 integrity check fails on warm start. """ chroma_exists: bool = ( os.path.exists(self.vector_db_dir) and len(os.listdir(self.vector_db_dir)) > 0 ) bm25_exists: bool = os.path.exists(self.bm25_path) if chroma_exists and bm25_exists: logger.info( "[2/4] Smart Load: Found existing indexes on disk. " "Bypassing embedding compute..." ) vector_store = Chroma( persist_directory=self.vector_db_dir, embedding_function=self.embedding_model, ) sparse_retriever = self._load_bm25_with_integrity() else: if not document_chunks: raise ValueError( "No existing indexes found on disk and no document_chunks provided. " "Pass document_chunks=load_and_chunk_pdfs() for the initial build." ) logger.info( "[2/4] Cold start: Building Dense Vector Database (ChromaDB)..." ) vector_store = Chroma.from_documents( documents=document_chunks, embedding=self.embedding_model, persist_directory=self.vector_db_dir, ) logger.info("[2/4] Building Sparse Keyword Index (BM25)...") sparse_retriever = BM25Retriever.from_documents(document_chunks) sparse_retriever.k = TOP_K_VECTORS logger.info("[2/4] Serializing BM25 index with integrity hash...") self._write_bm25_with_integrity(sparse_retriever) dense_retriever = vector_store.as_retriever( search_kwargs={"k": TOP_K_VECTORS} ) logger.info("[2/4] Initializing Reciprocal Rank Fusion engine...") self.ensemble_retriever = CustomHybridRetriever( dense_retriever=dense_retriever, sparse_retriever=sparse_retriever, dense_weight=0.5, sparse_weight=0.5, ) logger.info("Hybrid Retrieval Engine ready.") return self.ensemble_retriever