financial-intelligence-engine / src /retrieval_engine.py
GitHub Action
sync: deploy from GitHub
1fce89d
"""
Enterprise Hybrid Retrieval Engine.
Combines Dense Vectors (ChromaDB) and Sparse Keywords (BM25).
Features a custom Reciprocal Rank Fusion (RRF) implementation with company-balanced output.
UPGRADES vs previous version:
- Corrected RRF math: weights now scale the full RRF fraction, not just the numerator.
Previous: weight / (rank + K) β€” mathematically inconsistent with standard RRF.
Fixed: weight * (1 / (rank + K)) β€” weight scales the entire rank-fusion score.
- RRF_K constant imported from config (no hardcoding).
- Atomic BM25 serialization via tempfile + shutil.move() β€” prevents partial-write
corruption that left the index in an unrecoverable split state.
- SHA-256 integrity check on BM25 load β€” detects file tampering or corruption.
- Company-balanced retrieval β€” prevents one company dominating the context window
(was 71.4% Meta in the original, revealed by the telemetry dashboard).
- Full type annotations on all methods.
"""
import os
import json
import pickle
import hashlib
import tempfile
import shutil
from typing import Optional
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from src.config import (
logger,
VECTOR_DB_DIR,
TOP_K_VECTORS,
RRF_K,
MAX_CHUNKS_PER_COMPANY,
EMBEDDING_MODEL_NAME,
)
# ── Reciprocal Rank Fusion ────────────────────────────────────────────────────
class CustomHybridRetriever:
    """
    Reciprocal Rank Fusion (RRF) over dense + sparse retrieval results.

    RRF formula (per retriever r, per document d at rank i):
        score(d) += weight_r * (1 / (rank_i + K))
    where K (RRF_K, conventionally 60) is the smoothing constant that keeps
    top-ranked documents from receiving disproportionately high scores.

    Weighting: dense_weight + sparse_weight should equal 1.0. Each weight
    scales its retriever's contribution to the fused score independently,
    allowing domain tuning (e.g., keyword-heavy financial text may warrant
    a higher sparse_weight for exact figure matching).
    """

    def __init__(
        self,
        dense_retriever,
        sparse_retriever,
        dense_weight: float = 0.5,
        sparse_weight: float = 0.5,
    ) -> None:
        """
        Args:
            dense_retriever: Vector-similarity retriever (Chroma-backed).
            sparse_retriever: Keyword retriever (BM25).
            dense_weight: Contribution of dense ranks, in (0.0, 1.0].
            sparse_weight: Contribution of sparse ranks, in (0.0, 1.0].

        Raises:
            ValueError: If either weight falls outside (0.0, 1.0].
        """
        if not (0.0 < dense_weight <= 1.0 and 0.0 < sparse_weight <= 1.0):
            raise ValueError("Retriever weights must be in (0.0, 1.0].")
        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight

    def _compute_rrf_scores(
        self,
        docs: list[Document],
        weight: float,
        rrf_scores: dict[str, float],
        doc_map: dict[str, Document],
    ) -> None:
        """
        Accumulate weighted RRF scores into rrf_scores in-place.

        Args:
            docs: Ranked list of documents from one retriever.
            weight: Scalar weight for this retriever's contribution.
            rrf_scores: Mutable dict mapping chunk_id -> cumulative score.
            doc_map: Mutable dict mapping chunk_id -> Document object.
        """
        for rank, doc in enumerate(docs):
            # Use the deterministic chunk_id from metadata (set during
            # ingestion). Fall back to a content hash if metadata is missing —
            # prevents KeyError while signalling an ingestion config issue.
            chunk_id: str = doc.metadata.get(
                "chunk_id",
                hashlib.sha256(doc.page_content.encode()).hexdigest()[:16],
            )
            doc_map[chunk_id] = doc
            # weight * (1 / (rank + K)) — weight scales the full RRF fraction.
            rrf_scores[chunk_id] = (
                rrf_scores.get(chunk_id, 0.0) + weight * (1.0 / (rank + RRF_K))
            )

    def _balance_by_company(
        self,
        sorted_docs: list[tuple[str, float]],
        doc_map: dict[str, Document],
    ) -> list[Document]:
        """
        Return top-K documents with per-company diversity enforcement.

        Prevents any single company from dominating the context window.
        A query about "Google vs Meta R&D" should surface chunks from both,
        not 70%+ from whichever company had more keyword matches.

        Algorithm: iterate the RRF-ranked list; admit a doc only if its
        company hasn't exceeded MAX_CHUNKS_PER_COMPANY. Continue until
        TOP_K_VECTORS are collected or the list is exhausted.

        Args:
            sorted_docs: List of (chunk_id, rrf_score) sorted desc by score.
            doc_map: Mapping from chunk_id to Document.

        Returns:
            Balanced list of up to TOP_K_VECTORS Document objects.
        """
        company_counts: dict[str, int] = {}
        balanced: list[Document] = []
        # FIXED: track admitted docs by the same chunk_id keys that appear in
        # sorted_docs/doc_map. The previous version rebuilt the seen-set from
        # doc.metadata["chunk_id"], which is absent (None) for docs that fell
        # back to a content-hash id in _compute_rrf_scores — letting the
        # fallback loop re-append documents that were already admitted.
        admitted_ids: set[str] = set()
        for chunk_id, _ in sorted_docs:
            if len(balanced) >= TOP_K_VECTORS:
                break
            doc = doc_map[chunk_id]
            company = doc.metadata.get("company", "unknown")
            count = company_counts.get(company, 0)
            if count < MAX_CHUNKS_PER_COMPANY:
                company_counts[company] = count + 1
                balanced.append(doc)
                admitted_ids.add(chunk_id)
        # Safety fallback: if balancing left us short (e.g., only one company
        # in the corpus), fill remaining slots from the unfiltered ranked list.
        if len(balanced) < TOP_K_VECTORS:
            for chunk_id, _ in sorted_docs:
                if len(balanced) >= TOP_K_VECTORS:
                    break
                if chunk_id not in admitted_ids:
                    balanced.append(doc_map[chunk_id])
                    admitted_ids.add(chunk_id)
        return balanced

    def invoke(self, query: str) -> list[Document]:
        """
        Execute hybrid search and return RRF-fused, company-balanced results.

        Args:
            query: The user's natural language query string.

        Returns:
            List of up to TOP_K_VECTORS Document objects ranked by fused score.
        """
        dense_docs: list[Document] = self.dense_retriever.invoke(query)
        sparse_docs: list[Document] = self.sparse_retriever.invoke(query)
        rrf_scores: dict[str, float] = {}
        doc_map: dict[str, Document] = {}
        self._compute_rrf_scores(dense_docs, self.dense_weight, rrf_scores, doc_map)
        self._compute_rrf_scores(sparse_docs, self.sparse_weight, rrf_scores, doc_map)
        sorted_docs: list[tuple[str, float]] = sorted(
            rrf_scores.items(), key=lambda x: x[1], reverse=True
        )
        return self._balance_by_company(sorted_docs, doc_map)
# ── Hybrid Retrieval Engine ───────────────────────────────────────────────────
class HybridRetrievalEngine:
    """
    Orchestrates ChromaDB (dense) + BM25 (sparse) index construction and loading.

    Smart Load Logic:
    - If both indexes exist on disk -> load without re-embedding (warm start).
    - If either is missing -> build from scratch using document_chunks.
    - Atomic BM25 writes prevent split-state corruption.
    - SHA-256 integrity verification on BM25 load detects tampering/corruption.
    """

    def __init__(self) -> None:
        # bm25_path / bm25_hash_path live inside VECTOR_DB_DIR alongside Chroma.
        self.vector_db_dir: str = VECTOR_DB_DIR
        self.bm25_path: str = os.path.join(VECTOR_DB_DIR, "bm25_index.pkl")
        self.bm25_hash_path: str = os.path.join(VECTOR_DB_DIR, "bm25_index.sha256")
        logger.info("Loading embedding model: %s", EMBEDDING_MODEL_NAME)
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": "cpu"},  # explicit; avoids silent GPU fallback
            encode_kwargs={"normalize_embeddings": True},  # required for cosine similarity
        )
        self.ensemble_retriever: Optional[CustomHybridRetriever] = None

    # ── Internal: Integrity Helpers ───────────────────────────────────────────
    @staticmethod
    def _compute_file_sha256(path: str) -> str:
        """Compute SHA-256 hex digest of a file's binary contents."""
        sha256 = hashlib.sha256()
        with open(path, "rb") as f:
            # Read in 64 KiB blocks so arbitrarily large indexes hash in
            # constant memory.
            for block in iter(lambda: f.read(65536), b""):
                sha256.update(block)
        return sha256.hexdigest()

    def _write_bm25_with_integrity(self, sparse_retriever) -> None:
        """
        Serialize BM25 retriever atomically with integrity hash.

        Steps:
        1. Write pickle to a temp file in the same directory (same filesystem
           as destination -> rename is atomic on POSIX).
        2. Compute SHA-256 of the temp file.
        3. Write hash to the .sha256 sidecar file.
        4. Rename temp file to final path (atomic).

        FIXED vs previous version:
        - The sidecar is now written BEFORE the rename. Previously a crash
          between rename and sidecar-write left bm25_index.pkl present with
          no hash file, making _load_bm25_with_integrity raise forever —
          exactly the unrecoverable split state this method exists to prevent.
          With this ordering, a crash leaves the pickle unpublished and the
          next cold start rebuilds cleanly.
        - The delete=False temp file is now unlinked on failure instead of
          being leaked into the index directory.
        """
        dir_path = os.path.dirname(self.bm25_path)
        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(
                dir=dir_path, suffix=".pkl", delete=False
            ) as tmp:
                tmp_path = tmp.name
                pickle.dump(sparse_retriever, tmp)
            file_hash = self._compute_file_sha256(tmp_path)
            # Sidecar first: the pickle only becomes visible once its hash
            # is already durable.
            with open(self.bm25_hash_path, "w") as hf:
                hf.write(file_hash)
            shutil.move(tmp_path, self.bm25_path)  # atomic rename (same filesystem)
            tmp_path = None  # published; nothing to clean up
        finally:
            # Remove the orphaned temp file if any step above raised.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)
        logger.info(
            "BM25 index serialized. SHA-256: %s", file_hash
        )

    def _load_bm25_with_integrity(self) -> BM25Retriever:
        """
        Deserialize BM25 and verify SHA-256 integrity before returning.

        Returns:
            The deserialized BM25Retriever.

        Raises:
            RuntimeError: If the hash sidecar is missing or digest mismatches,
                indicating file corruption or tampering.
        """
        if not os.path.exists(self.bm25_hash_path):
            raise RuntimeError(
                f"BM25 integrity file missing: {self.bm25_hash_path}. "
                "Delete the vector_db directory and rebuild indexes."
            )
        with open(self.bm25_hash_path, "r") as hf:
            expected_hash = hf.read().strip()
        actual_hash = self._compute_file_sha256(self.bm25_path)
        if actual_hash != expected_hash:
            raise RuntimeError(
                f"BM25 index integrity check FAILED. "
                f"Expected {expected_hash}, got {actual_hash}. "
                "Index may be corrupted. Delete vector_db/ and rebuild."
            )
        with open(self.bm25_path, "rb") as f:
            sparse_retriever = pickle.load(f)  # safe: integrity verified above
        logger.info("BM25 index loaded and integrity verified.")
        return sparse_retriever

    # ── Public: Build or Load Indexes ────────────────────────────────────────
    def build_indexes(
        self,
        document_chunks: Optional[list[Document]] = None,
    ) -> CustomHybridRetriever:
        """
        Build indexes from chunks (cold start) or load them from disk (warm start).

        Args:
            document_chunks: Required for cold start (first run). Ignored on
                warm start (indexes already on disk).

        Returns:
            Initialized CustomHybridRetriever ready for querying.

        Raises:
            ValueError: If cold start is required but no chunks provided.
            RuntimeError: If BM25 integrity check fails on warm start.
        """
        # NOTE(review): the BM25 files live inside vector_db_dir, so the
        # directory being non-empty does not strictly prove Chroma data exists.
        # In practice BM25 is only written after Chroma builds successfully
        # (see below), so the combined check holds — confirm if write order
        # ever changes.
        chroma_exists: bool = (
            os.path.exists(self.vector_db_dir)
            and len(os.listdir(self.vector_db_dir)) > 0
        )
        bm25_exists: bool = os.path.exists(self.bm25_path)
        if chroma_exists and bm25_exists:
            logger.info(
                "[2/4] Smart Load: Found existing indexes on disk. "
                "Bypassing embedding compute..."
            )
            vector_store = Chroma(
                persist_directory=self.vector_db_dir,
                embedding_function=self.embedding_model,
            )
            sparse_retriever = self._load_bm25_with_integrity()
        else:
            if not document_chunks:
                raise ValueError(
                    "No existing indexes found on disk and no document_chunks provided. "
                    "Pass document_chunks=load_and_chunk_pdfs() for the initial build."
                )
            logger.info(
                "[2/4] Cold start: Building Dense Vector Database (ChromaDB)..."
            )
            vector_store = Chroma.from_documents(
                documents=document_chunks,
                embedding=self.embedding_model,
                persist_directory=self.vector_db_dir,
            )
            logger.info("[2/4] Building Sparse Keyword Index (BM25)...")
            sparse_retriever = BM25Retriever.from_documents(document_chunks)
            sparse_retriever.k = TOP_K_VECTORS
            logger.info("[2/4] Serializing BM25 index with integrity hash...")
            self._write_bm25_with_integrity(sparse_retriever)
        dense_retriever = vector_store.as_retriever(
            search_kwargs={"k": TOP_K_VECTORS}
        )
        logger.info("[2/4] Initializing Reciprocal Rank Fusion engine...")
        self.ensemble_retriever = CustomHybridRetriever(
            dense_retriever=dense_retriever,
            sparse_retriever=sparse_retriever,
            dense_weight=0.5,
            sparse_weight=0.5,
        )
        logger.info("Hybrid Retrieval Engine ready.")
        return self.ensemble_retriever