""" Evidence Retriever – FAISS-based semantic search over user-provided contexts. Encodes context sentences with all-MiniLM-L6-v2 and retrieves the top-k most similar evidence sentences for each claim. """ import re import logging import numpy as np import faiss from sentence_transformers import SentenceTransformer from facteval import suppress_stdout from facteval.config import DEFAULT_TOP_K, EMBEDDING_MODEL, MIN_EVIDENCE_SCORE from facteval.models import Claim, Evidence, ClaimWithEvidence logger = logging.getLogger(__name__) class EvidenceRetriever: """Build a FAISS index over context sentences and retrieve evidence for claims.""" def __init__( self, model_name: str = EMBEDDING_MODEL, device: str | None = None, ): self.device = device or ("cuda" if __import__("torch").cuda.is_available() else "cpu") logger.info("Loading embedding model: %s", model_name) with suppress_stdout(): self.embedder = SentenceTransformer(model_name, device=self.device) # Populated by .index() self._sentences: list[str] = [] self._sentence_to_context: dict[int, str] = {} self._index: faiss.IndexFlatIP | None = None def index(self, contexts: list[str]) -> "EvidenceRetriever": """ Build a FAISS index from a list of context passages. Each context is split into individual sentences before indexing. Args: contexts: List of context passages (strings). Returns: self (for chaining: `retriever.index(ctx).retrieve(claim)`). """ if not contexts: logger.warning("No contexts provided; retriever will return empty results.") self._sentences = [] self._index = None return self self._sentences = [] self._sentence_to_context = {} for ctx in contexts: for sent in self._split_sentences(ctx): idx = len(self._sentences) self._sentences.append(sent) self._sentence_to_context[idx] = ctx if not self._sentences: logger.warning("No sentences extracted from contexts.") self._index = None return self logger.info("Indexing %d evidence sentences.", len(self._sentences)) embeddings = self.embedder.encode( self._sentences, convert_to_numpy=True, normalize_embeddings=True ).astype(np.float32) dim = embeddings.shape[1] self._index = faiss.IndexFlatIP(dim) # Cosine similarity (normalized) self._index.add(embeddings) return self def retrieve( self, claim: Claim | str, top_k: int = DEFAULT_TOP_K, min_score: float = MIN_EVIDENCE_SCORE, ) -> list[Evidence]: """ Retrieve the top-k most relevant evidence sentences for a claim. Args: claim: A Claim object or plain string. top_k: Number of evidence sentences to return. min_score: Minimum cosine similarity to include. Returns: List of Evidence objects, sorted by score descending. """ if self._index is None or not self._sentences: return [] query_text = claim.text if isinstance(claim, Claim) else claim q_emb = self.embedder.encode( [query_text], convert_to_numpy=True, normalize_embeddings=True ).astype(np.float32) scores, indices = self._index.search(q_emb, top_k) results: list[Evidence] = [] for score, idx in zip(scores[0], indices[0]): if idx < 0 or idx >= len(self._sentences): continue clamped_score = float(min(max(score, 0.0), 1.0)) if clamped_score < min_score: continue results.append( Evidence( sentence=self._sentences[idx], score=clamped_score, source_context=self._sentence_to_context.get(idx, ""), ) ) return results def retrieve_for_claims( self, claims: list[Claim], top_k: int = DEFAULT_TOP_K, min_score: float = MIN_EVIDENCE_SCORE, ) -> list[ClaimWithEvidence]: """ Batch-retrieve evidence for a list of claims. Returns: List of ClaimWithEvidence objects. """ return [ ClaimWithEvidence( claim=claim, evidence=self.retrieve(claim, top_k=top_k, min_score=min_score), ) for claim in claims ] @staticmethod def _split_sentences(text: str) -> list[str]: """Split text into sentences on sentence-ending punctuation.""" raw = re.split(r"(?<=[.!?])\s+", text) return [s.strip() for s in raw if s.strip() and len(s.strip()) > 3]