"""
Core – the analyze() and fast_check() entry points (aliased as check() and
verify()) that wire together the FactEval pipeline.

Usage:
    from facteval import check, verify

    # Full pipeline (extract + retrieve + verify)
    result = check(answer, contexts)

    # Lightweight mode (skip extraction, bring your own claims)
    result = verify(claims=["claim 1", "claim 2"], contexts=docs)
"""

import re
import logging
import time
from pathlib import Path

import numpy as np

from facteval.calibrator import Calibrator
from facteval.claim_extractor import ClaimExtractor
from facteval.retriever import EvidenceRetriever
from facteval.verifier import Verifier, FactLabel
from facteval.models import Claim

logger = logging.getLogger(__name__)

# Module-level singletons (lazy-loaded)
_extractor: ClaimExtractor | None = None
_retriever: EvidenceRetriever | None = None
_verifier: Verifier | None = None
_calibrator: Calibrator | None = None
_calibrator_path: str | None = None


def _get_extractor() -> ClaimExtractor:
    global _extractor
    if _extractor is None:
        print("⏳ Loading claim extractor (Qwen 1.5B)...", flush=True)
        _extractor = ClaimExtractor()
        print("βœ… Claim extractor ready.", flush=True)
    return _extractor


def _get_retriever() -> EvidenceRetriever:
    global _retriever
    if _retriever is None:
        print("⏳ Loading retriever (MiniLM + FAISS)...", flush=True)
        _retriever = EvidenceRetriever()
        print("βœ… Retriever ready.", flush=True)
    return _retriever


def _get_verifier() -> Verifier:
    global _verifier
    if _verifier is None:
        print("⏳ Loading verifier (DeBERTa NLI)...", flush=True)
        _verifier = Verifier()
        print("βœ… Verifier ready.", flush=True)
    return _verifier


def _get_calibrator(path: str | None = None) -> Calibrator:
    global _calibrator, _calibrator_path
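    # Re-create the calibrator on first use or when a different calibrator file is requested.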
    if _calibrator is None or path != _calibrator_path:
        _calibrator = Calibrator(calibrator_path=path)
        _calibrator_path = path
    return _calibrator


# ── Full pipeline ────────────────────────────────────────────────────────────

def analyze(
    answer: str,
    contexts: list[str],
    top_k: int = 3,
    max_claims: int = 10,
    calibrator_path: str | Path | None = None,
) -> dict:
    """
    Run the full FactEval pipeline on an answer + contexts.

    Stages: extract claims → retrieve evidence → NLI verify → calibrate.

    Args:
        answer:          The LLM-generated text to evaluate.
        contexts:        List of reference passages (ground truth).
        top_k:           Number of evidence sentences to retrieve per claim.
        max_claims:      Maximum claims to extract.
        calibrator_path: Path to a pre-fitted calibrator pickle file.

    Returns:
        A dict with "claims", "summary", "highlighted_answer", "calibrated",
        and "pipeline_time_seconds".
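
    Example (illustrative; exact claim fields depend on Verifier.to_dict()):

        result = analyze(answer, contexts)
        result["summary"]["hallucination_rate"]   # e.g. 0.25
        [c["label"] for c in result["claims"]]    # "supported" | "contradicted" | "unverifiable"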
    """
    t0 = time.perf_counter()

    # 1. Extract claims
    extractor = _get_extractor()
    claims = extractor.extract(answer, max_claims=max_claims)
    logger.info("Extracted %d claims.", len(claims))

    if not claims:
        return {
            "claims": [],
            "summary": _build_summary([]),
            "highlighted_answer": answer,
            "calibrated": False,
            "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
        }

    # 2–5. Shared pipeline
    return _run_pipeline(claims, contexts, answer, top_k, calibrator_path, t0)


# ── Lightweight mode ─────────────────────────────────────────────────────────

def fast_check(
    claims: list[str],
    contexts: list[str],
    top_k: int = 3,
    calibrator_path: str | Path | None = None,
) -> dict:
    """
    Verify pre-extracted claims against contexts. Skips claim extraction.

    Use this when you already have claims and want faster results
    (avoids the ~1s extraction step and the Qwen model entirely).

    Args:
        claims:          List of claim strings to verify.
        contexts:        List of reference passages (ground truth).
        top_k:           Number of evidence sentences to retrieve per claim.
        calibrator_path: Path to a pre-fitted calibrator pickle file.

    Returns:
        Same output format as analyze().
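
    Example (illustrative):

        result = fast_check(
            claims=["The Eiffel Tower is in Berlin."],
            contexts=["The Eiffel Tower is located in Paris, France."],
        )
        # result["summary"]["contradicted"] should be 1 if the verifier flags the mismatch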
    """
    t0 = time.perf_counter()

    claim_objs = [Claim(text=c) for c in claims if c.strip()]

    if not claim_objs:
        return {
            "claims": [],
            "summary": _build_summary([]),
            "highlighted_answer": "",
            "calibrated": False,
            "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
        }

    answer = " ".join(claims)  # reconstruct for highlighting
    return _run_pipeline(claim_objs, contexts, answer, top_k, calibrator_path, t0)


# ── Shared pipeline ──────────────────────────────────────────────────────────

def _run_pipeline(
    claims: list[Claim],
    contexts: list[str],
    answer: str,
    top_k: int,
    calibrator_path: str | Path | None,
    t0: float,
) -> dict:
    """Shared pipeline: retrieve β†’ verify β†’ calibrate β†’ diagnose β†’ highlight."""

    # 2. Retrieve evidence
    retriever = _get_retriever()
    retriever.index(contexts)
    claims_with_evidence = retriever.retrieve_for_claims(claims, top_k=top_k)

    # 3. Verify (batch NLI)
    verifier = _get_verifier()
    results = verifier.verify_batch(claims_with_evidence)

    # 4. Calibrate
    calibrator = _get_calibrator(str(calibrator_path) if calibrator_path else None)
    for r in results:
        if r.raw_scores:
            cal_conf, cal_err = calibrator.calibrate(r.raw_scores)
            r.calibrated_confidence = cal_conf
            r.calibration_error = cal_err

    # 5. Build output with diagnostics
    elapsed = time.perf_counter() - t0
    claim_dicts = [r.to_dict() for r in results]

    # Add diagnostics to each claim
    for cd in claim_dicts:
        cd["diagnostics"] = _diagnose(cd)

    summary = _build_summary(results)
    
    # Console feedback so interactive runs show the outcome at a glance
    hallucinations = summary.get("contradicted", 0)
    supported = summary.get("supported", 0)
    print(f"✔ Found {hallucinations} hallucination(s)")
    print(f"✔ {supported} supported claim(s)")

    return {
        "claims": claim_dicts,
        "summary": summary,
        "highlighted_answer": _highlight_answer_semantic(
            answer, claim_dicts, retriever.embedder
        ),
        "calibrated": calibrator.is_calibrated,
        "pipeline_time_seconds": round(elapsed, 3),
    }

# Backward compatibility aliases
check = analyze
verify = fast_check
evaluate = analyze


# ── Diagnostics ──────────────────────────────────────────────────────────────

def _diagnose(claim_dict: dict) -> dict:
    """
    Generate pipeline diagnostics for a claim.

    Tells the developer *why* a claim got its label:
    was it a retrieval failure or a genuine hallucination?
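
    Example (illustrative return value):

        {
            "failure_type": "retrieval_gap",
            "retrieval_quality": "weak",
            "suggestion": "Evidence was found but too dissimilar to trust. ...",
        }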
    """
    label = claim_dict["label"]
    ev_score = claim_dict.get("evidence_score")
    confidence = claim_dict.get("confidence", 0)

    # Retrieval quality assessment
    if ev_score is None:
        retrieval_quality = "none"
    elif ev_score >= 0.7:
        retrieval_quality = "strong"
    elif ev_score >= 0.4:
        retrieval_quality = "moderate"
    else:
        retrieval_quality = "weak"

    # Failure type classification
    if label == "supported":
        failure_type = "verified"
        suggestion = None
    elif label == "contradicted":
        if retrieval_quality in ("strong", "moderate"):
            failure_type = "hallucination"
            suggestion = "Claim directly contradicts the evidence. This is a factual error in the LLM output."
        else:
            failure_type = "possible_hallucination"
            suggestion = "Claim contradicts weak evidence. Consider adding more specific context for reliable verification."
    elif ev_score is None:
        failure_type = "no_evidence"
        suggestion = "No relevant context was provided. Add reference passages covering this topic."
    elif retrieval_quality == "weak":
        failure_type = "retrieval_gap"
        suggestion = "Evidence was found but too dissimilar to trust. The context may not cover this claim."
    else:
        failure_type = "inconclusive"
        suggestion = "Evidence exists but is neutral β€” neither confirms nor denies the claim."

    d = {
        "failure_type": failure_type,
        "retrieval_quality": retrieval_quality,
    }
    if suggestion:
        d["suggestion"] = suggestion
    return d


# ── Summary ──────────────────────────────────────────────────────────────────

def _build_summary(results: list) -> dict:
    """Build summary statistics from verification results."""
    total = len(results)
    supported = sum(1 for r in results if r.label == FactLabel.SUPPORTED)
    contradicted = sum(1 for r in results if r.label == FactLabel.CONTRADICTED)
    unverifiable = total - supported - contradicted

    return {
        "total_claims": total,
        "supported": supported,
        "contradicted": contradicted,
        "unverifiable": unverifiable,
        "hallucination_rate": round(contradicted / max(total, 1), 4),
    }


# ── Semantic Highlighting ────────────────────────────────────────────────────

_LABEL_EMOJI = {"supported": "✅", "contradicted": "❌", "unverifiable": "❓"}
_LABEL_COLOR = {"supported": "#22c55e", "contradicted": "#ef4444", "unverifiable": "#f59e0b"}


def _highlight_answer_semantic(answer: str, claim_dicts: list[dict], embedder) -> str:
    """
    Map claims to source sentences using embedding similarity (not Jaccard).

    Uses the retriever's SentenceTransformer to compute cosine similarity
    between each claim and each sentence in the original answer. This handles
    paraphrasing, reordering, and partial overlaps much better than token overlap.
    """
    if not answer.strip() or not claim_dicts:
        return answer

    # Split the answer into sentences (terminal punctuation stays attached)
    sentences = []
    for m in re.finditer(r'[^.!?]+[.!?]*', answer):
        text = m.group().strip()
        if text:
            sentences.append(text)

    if not sentences:
        return answer

    # Compute embedding similarity
    claim_texts = [c["claim"] for c in claim_dicts]
    claim_labels = [c["label"] for c in claim_dicts]

    sent_embeddings = embedder.encode(sentences, normalize_embeddings=True)
    claim_embeddings = embedder.encode(claim_texts, normalize_embeddings=True)

    # Similarity matrix: sentences × claims
    sim_matrix = np.dot(sent_embeddings, claim_embeddings.T)
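    # With normalize_embeddings=True both sets are unit vectors, so this dot product is cosine similarity.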

    # For each sentence, find best matching claim
    sentence_labels: dict[str, str] = {}
    for i, sent_text in enumerate(sentences):
        best_j = int(sim_matrix[i].argmax())
        best_sim = float(sim_matrix[i, best_j])

        if best_sim > 0.35:  # Semantic similarity threshold
            sentence_labels[sent_text] = claim_labels[best_j]

    # Build highlighted text (longest matches first to avoid partial replacements)
    highlighted = answer
    for sent_text in sorted(sentence_labels, key=len, reverse=True):
        label = sentence_labels[sent_text]
        color = _LABEL_COLOR.get(label, "#94a3b8")
        emoji = _LABEL_EMOJI.get(label, "")
        highlighted = highlighted.replace(
            sent_text,
            f'<mark style="background:{color}30;padding:2px 4px;border-radius:3px">'
            f'{sent_text} {emoji}</mark>',
            1,
        )

    return highlighted