import re

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from api.retriever import ChunkRetriever

# Softmax temperature (>1 flattens the NLI distribution to reduce overconfidence).
TEMPERATURE = 1.5
# Minimum probability before a label is trusted; below this we fall back to neutral.
CONFIDENCE_THRESHOLD = 0.60
# Word-level chunking parameters for the sliding window over the source document.
CHUNK_SIZE = 400
CHUNK_OVERLAP = 50


def sliding_window_chunker(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split a large text into overlapping word-level chunks.

    Args:
        text: Raw document text.
        chunk_size: Number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of space-joined word chunks; empty list for empty input.
    """
    words = text.split()
    chunks: list[str] = []
    if not words:
        return chunks
    step = chunk_size - overlap
    if step <= 0:
        # Degenerate config (overlap >= chunk_size) would loop forever; force progress.
        step = 1
    for i in range(0, len(words), step):
        chunk_words = words[i:i + chunk_size]
        chunks.append(" ".join(chunk_words))
        # Stop once the window has reached the end of the document so we don't
        # emit trailing chunks that are pure suffixes of the previous one.
        if i + chunk_size >= len(words):
            break
    return chunks


def split_into_claims(text: str) -> list[str]:
    """Break LLM output into individual sentences so each factual claim
    gets scored independently (avoids filler diluting scores).

    Sentences shorter than 3 words are dropped as non-claims; if nothing
    survives filtering but the text is non-empty, the whole text is used
    as a single claim.
    """
    raw_sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    valid_claims = []
    for s in raw_sentences:
        clean = s.strip()
        if len(clean.split()) >= 3:
            valid_claims.append(clean)
    if not valid_claims and text.strip():
        valid_claims = [text.strip()]
    return valid_claims


def normalize_scores(contradiction: float, entailment: float, neutral: float) -> tuple[float, float, float]:
    """Rescale three non-negative scores so they sum to exactly 100.0.

    Returns:
        (contradiction_pct, entailment_pct, neutral_pct), each rounded to
        2 decimals; neutral absorbs the rounding remainder so the sum is
        exactly 100.0. All-zero input maps to (0.0, 0.0, 100.0).
    """
    total = contradiction + entailment + neutral
    if total == 0:
        return (0.0, 0.0, 100.0)
    c = round((contradiction / total) * 100.0, 2)
    e = round((entailment / total) * 100.0, 2)
    n = round(100.0 - c - e, 2)
    return (c, e, n)


class HallucinationDetector:
    """Two-stage hallucination detector.

    Stage 1: a lightweight bi-encoder (ChunkRetriever) pre-filters document
    chunks relevant to the LLM response. Stage 2: a cross-encoder NLI model
    scores each response claim against those chunks.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = "cross-encoder/nli-deberta-v3-base"
        print(f"Initializing Detector on {self.device.type.upper()}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
        # Inference-only: disable dropout etc. explicitly.
        self.model.eval()
        print("Detector Ready!")
        # Stage 1 retriever — lightweight bi-encoder for pre-filtering chunks
        self.retriever = ChunkRetriever()

    def _infer_chunk(self, chunk: str, claim: str) -> dict:
        """Stage 2: run the heavy cross-encoder on a single (chunk, claim) pair.

        Returns:
            dict with raw probabilities for "contradiction", "entailment",
            "neutral" (summing to ~1.0) and an empty "spans" placeholder.
        """
        inputs = self.tokenizer(
            chunk,
            claim,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Temperature scaling softens the distribution before softmax.
        scaled_logits = outputs.logits / TEMPERATURE
        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
        # Label order for cross-encoder/nli-deberta-v3-base:
        # 0 = contradiction, 1 = entailment, 2 = neutral.
        c_raw = probs[0][0].item()
        e_raw = probs[0][1].item()
        n_raw = probs[0][2].item()
        # If the model isn't confident about anything, default to neutral.
        max_score = max(c_raw, e_raw, n_raw)
        if max_score < CONFIDENCE_THRESHOLD:
            c_raw, e_raw, n_raw = 0.0, 0.0, 1.0
        return {
            "contradiction": c_raw,
            "entailment": e_raw,
            "neutral": n_raw,
            "spans": [],  # placeholder for Captum attributions
        }

    def analyze(self, context: str, llm_response: str) -> dict:
        """Two-stage pipeline:
        1) Chunk the document → retrieve top relevant chunks (bi-encoder)
        2) Score each claim against those top chunks (cross-encoder)
        3) Aggregate with priority resolution

        Returns:
            dict with percentage scores ("contradiction_score",
            "entailment_score", "neutral_score" summing to 100), an
            "is_hallucination" flag, and "attribution_spans".
        """
        all_chunks = sliding_window_chunker(context)
        if not all_chunks:
            all_chunks = [""]

        # Stage 1: narrow down to the most relevant chunks.
        relevant_chunks = self.retriever.get_top_chunks(llm_response, all_chunks)

        claims = split_into_claims(llm_response)
        # Guard: an empty/whitespace-only response yields no claims — without
        # this, the max()/average aggregation below would raise. Report fully
        # neutral instead of crashing.
        if not claims:
            return {
                "contradiction_score": 0.0,
                "entailment_score": 0.0,
                "neutral_score": 100.0,
                "is_hallucination": False,
                "attribution_spans": [],
            }

        sentence_scores = []
        for claim in claims:
            # Stage 2: cross-encoder only runs on the pre-filtered chunks.
            chunk_results = [self._infer_chunk(chunk, claim) for chunk in relevant_chunks]
            s_max_e = max(r["entailment"] for r in chunk_results)
            s_max_c = max(r["contradiction"] for r in chunk_results)
            s_max_n = max(r["neutral"] for r in chunk_results)

            # Priority resolution — if the fact exists somewhere, entailment wins.
            if s_max_e >= CONFIDENCE_THRESHOLD and s_max_e >= s_max_c:
                final_s_e = s_max_e
                final_s_c = s_max_c * 0.25  # dampen the losing label
                final_s_n = max(0.0, 1.0 - final_s_e - final_s_c)
                winning_spans = max(chunk_results, key=lambda x: x["entailment"])["spans"]
            elif s_max_c >= CONFIDENCE_THRESHOLD and s_max_c > s_max_e:
                final_s_c = s_max_c
                final_s_e = s_max_e * 0.25  # dampen the losing label
                final_s_n = max(0.0, 1.0 - final_s_c - final_s_e)
                winning_spans = max(chunk_results, key=lambda x: x["contradiction"])["spans"]
            else:
                final_s_c = s_max_c
                final_s_e = s_max_e
                final_s_n = s_max_n
                winning_spans = []

            sentence_scores.append({
                "c": final_s_c,
                "e": final_s_e,
                "n": final_s_n,
                "spans": winning_spans,
            })

        # Document-level aggregation:
        # contradiction uses max (one-strike rule) —
        doc_c = max(s["c"] for s in sentence_scores)
        # — entailment and neutral use average across claims.
        doc_e = sum(s["e"] for s in sentence_scores) / len(sentence_scores)
        doc_n = sum(s["n"] for s in sentence_scores) / len(sentence_scores)

        doc_c = max(doc_c, 0.0)
        doc_e = max(doc_e, 0.0)
        doc_n = max(doc_n, 0.0)

        c_pct, e_pct, n_pct = normalize_scores(doc_c, doc_e, doc_n)

        # Grab attribution spans from the highest-severity claim.
        if doc_c > doc_e:
            best_spans = max(sentence_scores, key=lambda x: x["c"])["spans"]
        else:
            best_spans = max(sentence_scores, key=lambda x: x["e"])["spans"]

        is_hallucination = (c_pct > e_pct) and (doc_c >= CONFIDENCE_THRESHOLD)

        return {
            "contradiction_score": c_pct,
            "entailment_score": e_pct,
            "neutral_score": n_pct,
            "is_hallucination": is_hallucination,
            "attribution_spans": best_spans,
        }