Sahil al farib committed
Commit 8fb73f8 · 1 Parent(s): b29689f

Deploy FactEval: claim-level hallucination detection with Gradio demo

README.md CHANGED
@@ -1,15 +1,29 @@
1
  ---
2
  title: FactEval
3
- emoji: 🏆
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
11
  license: mit
12
  short_description: Find exactly which parts of your LLM output are hallucinated
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: FactEval
3
+ emoji: 🔍
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  short_description: Find exactly which parts of your LLM output are hallucinated
12
  ---
13
 
14
+ # 🔍 FactEval
15
+
16
+ **Find exactly which parts of your LLM output are hallucinated.**
17
+
18
+ Debug hallucinations in RAG and LLM pipelines with claim-level verification.
19
+
20
+ Paste an LLM-generated answer and reference contexts. FactEval highlights ✅ **supported**, ❌ **contradicted**, and ❓ **unverifiable** claims with human-readable reasons and pipeline diagnostics.
21
+
22
+ ## How it works
23
+
24
+ 1. **Claim Extraction** — Breaks the answer into atomic claims (Qwen2.5-1.5B)
25
+ 2. **Evidence Retrieval** — Finds the most relevant sentences from your contexts (MiniLM + FAISS)
26
+ 3. **NLI Verification** — Checks each claim against evidence (DeBERTa-v3)
27
+ 4. **Calibration** — Produces trustworthy confidence scores (Isotonic Regression)
28
+
29
+ 📦 [GitHub Repository](https://github.com/sahilaf/FactEval)
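
For quick reference, a minimal usage sketch of the Python API this commit adds (the example text is borrowed from the demo's built-in examples; result keys follow `facteval/core.py` below, and actual verdicts depend on the models):

```python
# Minimal sketch of the check() API added in facteval/core.py in this commit.
from facteval import check

result = check(
    answer="Python was created by Guido van Rossum and first released in 2005.",
    contexts=["Python was created by Guido van Rossum and first released in 1991."],
)

# Each claim dict carries a label ("supported" / "contradicted" / "unverifiable"),
# a confidence, the best evidence sentence, and pipeline diagnostics.
for claim in result["claims"]:
    print(claim["label"], "-", claim["claim"])

# Summary counts plus the contradicted-claim fraction reported as hallucination_rate.
print(result["summary"])
```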
app.py ADDED
@@ -0,0 +1,5 @@
1
+ """HF Spaces entry point — imports the Gradio demo from demo/app.py."""
2
+ from demo.app import demo
3
+
4
+ if __name__ == "__main__":
5
+ demo.launch()
demo/app.py ADDED
@@ -0,0 +1,202 @@
1
+ """
2
+ FactEval Gradio Demo – Interactive factuality checker.
3
+
4
+ Run locally: python demo/app.py
5
+ Run on Colab: Upload the facteval/ folder, then run this file.
6
+ """
7
+
8
+ import json
9
+ import gradio as gr
10
+ from facteval import check, verify
11
+
12
+ EXAMPLES = [
13
+ [
14
+ "Paris is the capital of Germany and has 5 million people.",
15
+ "Paris is the capital of France. Paris has approximately 2.2 million inhabitants.\nGermany's capital is Berlin.",
16
+ ],
17
+ [
18
+ "Python was created by Guido van Rossum and first released in 2005.",
19
+ "Python was created by Guido van Rossum and first released in 1991.",
20
+ ],
21
+ [
22
+ "The Amazon rainforest produces 20% of the world's oxygen and spans across nine countries.",
23
+ "The Amazon rainforest produces about 6% of the world's oxygen.\nThe Amazon rainforest spans across nine countries in South America.",
24
+ ],
25
+ [
26
+ "Albert Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921 for his work on the photoelectric effect.",
27
+ "Albert Einstein developed the theory of relativity. He won the Nobel Prize in Physics in 1921 for his explanation of the photoelectric effect.",
28
+ ],
29
+ ]
30
+
31
+
32
+ def run_check(answer: str, contexts: str, calibrator_path: str = ""):
33
+ """Run FactEval pipeline and format results for Gradio."""
34
+ if not answer.strip():
35
+ return "⚠️ Please enter an answer to check.", "", "", ""
36
+
37
+ context_list = [c.strip() for c in contexts.strip().split("\n") if c.strip()]
38
+ if not context_list:
39
+ return "⚠️ Please enter at least one context passage.", "", "", ""
40
+
41
+ cal_path = calibrator_path.strip() if calibrator_path.strip() else None
42
+ result = check(answer, context_list, calibrator_path=cal_path)
43
+
44
+ # 1. Highlighted answer (the viral feature)
45
+ highlighted_html = f"""
46
+ <div style="font-family: Inter, sans-serif; font-size: 18px; line-height: 2;
47
+ padding: 20px; border-radius: 12px; background: #0f172a; color: #e2e8f0;">
48
+ {result.get("highlighted_answer", answer)}
49
+ </div>
50
+ """
51
+
52
+ # 2. Per-claim verdicts with reasons
53
+ details_parts = []
54
+ for c in result["claims"]:
55
+ label = c["label"]
56
+ colors = {"supported": "#22c55e", "contradicted": "#ef4444", "unverifiable": "#f59e0b"}
57
+ emojis = {"supported": "✅", "contradicted": "❌", "unverifiable": "❓"}
58
+ color = colors.get(label, "#94a3b8")
59
+ emoji = emojis.get(label, "")
60
+ conf = c.get("calibrated_confidence", c["confidence"])
61
+
62
+ diag = c.get("diagnostics", {})
63
+ diag_type = diag.get("failure_type", "")
64
+ diag_badge_colors = {
65
+ "verified": "#22c55e", "hallucination": "#ef4444", "possible_hallucination": "#f97316",
66
+ "no_evidence": "#6b7280", "retrieval_gap": "#8b5cf6", "inconclusive": "#f59e0b",
67
+ }
68
+ badge_color = diag_badge_colors.get(diag_type, "#64748b")
69
+ suggestion = diag.get("suggestion", "")
70
+
71
+ details_parts.append(f"""
72
+ <div style="padding: 12px; margin: 8px 0; border-left: 4px solid {color};
73
+ background: {color}10; border-radius: 0 8px 8px 0; font-family: Inter, sans-serif;">
74
+ <div style="font-weight: 600; font-size: 15px; color: #f1f5f9;">
75
+ {emoji} {c["claim"]}
76
+ <span style="font-size: 11px; padding: 2px 8px; border-radius: 12px;
77
+ background: {badge_color}30; color: {badge_color}; margin-left: 8px;">
78
+ {diag_type.replace("_", " ")}
79
+ </span>
80
+ </div>
81
+ <div style="font-size: 13px; color: #94a3b8; margin-top: 4px;">
82
+ {c.get("reason", "")}
83
+ </div>
84
+ {'<div style="font-size: 12px; color: #f59e0b; margin-top: 4px; font-style: italic;">💡 ' + suggestion + '</div>' if suggestion else ''}
85
+ <div style="font-size: 12px; color: #64748b; margin-top: 4px;">
86
+ Confidence: {conf:.1%}
87
+ {"• Evidence score: " + f"{c['evidence_score']:.3f}" if c.get("evidence_score") else ""}
88
+ • Retrieval: {diag.get("retrieval_quality", "n/a")}
89
+ </div>
90
+ </div>
91
+ """)
92
+
93
+ details_html = '<div>' + ''.join(details_parts) + '</div>'
94
+
95
+ # 3. Summary card
96
+ s = result["summary"]
97
+ summary_html = f"""
98
+ <div style="font-family: Inter, sans-serif; padding: 16px; border-radius: 12px;
99
+ background: linear-gradient(135deg, #1e293b, #334155); color: white;">
100
+ <h3 style="margin: 0 0 12px 0; color: #e2e8f0;">📊 Summary</h3>
101
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 8px;">
102
+ <div style="padding: 8px; background: #ffffff10; border-radius: 8px;">
103
+ <div style="font-size: 24px; font-weight: bold;">{s['total_claims']}</div>
104
+ <div style="font-size: 12px; color: #94a3b8;">Total Claims</div>
105
+ </div>
106
+ <div style="padding: 8px; background: #22c55e20; border-radius: 8px;">
107
+ <div style="font-size: 24px; font-weight: bold; color: #22c55e;">{s['supported']}</div>
108
+ <div style="font-size: 12px; color: #94a3b8;">Supported</div>
109
+ </div>
110
+ <div style="padding: 8px; background: #ef444420; border-radius: 8px;">
111
+ <div style="font-size: 24px; font-weight: bold; color: #ef4444;">{s['contradicted']}</div>
112
+ <div style="font-size: 12px; color: #94a3b8;">Contradicted</div>
113
+ </div>
114
+ <div style="padding: 8px; background: #f59e0b20; border-radius: 8px;">
115
+ <div style="font-size: 24px; font-weight: bold; color: #f59e0b;">{s['unverifiable']}</div>
116
+ <div style="font-size: 12px; color: #94a3b8;">Unverifiable</div>
117
+ </div>
118
+ </div>
119
+ <div style="margin-top: 12px; padding: 8px; background: #ffffff10; border-radius: 8px; text-align: center;">
120
+ <span style="font-size: 14px; color: #94a3b8;">Hallucination Rate</span><br>
121
+ <span style="font-size: 28px; font-weight: bold;
122
+ color: {'#22c55e' if s['hallucination_rate'] < 0.3 else '#ef4444'};">
123
+ {s['hallucination_rate']:.0%}
124
+ </span>
125
+ </div>
126
+ <div style="margin-top: 8px; font-size: 11px; color: #64748b; text-align: right;">
127
+ ⏱ {result['pipeline_time_seconds']:.1f}s
128
+ {'• 📐 calibrated' if result.get('calibrated') else '• raw scores'}
129
+ </div>
130
+ </div>
131
+ """
132
+
133
+ # 4. Raw JSON
134
+ json_output = json.dumps(result, indent=2, ensure_ascii=False)
135
+
136
+ return highlighted_html, details_html, summary_html, json_output
137
+
138
+
139
+ # ── Gradio Interface ─────────────────────────────────────────────────────────
140
+
141
+ with gr.Blocks(
142
+ title="FactEval – Hallucination Detector",
143
+ theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
144
+ css="""
145
+ .gradio-container { max-width: 960px !important; }
146
+ footer { display: none !important; }
147
+ """,
148
+ ) as demo:
149
+ gr.Markdown(
150
+ """
151
+ # 🔍 FactEval – Find Exactly Which Parts Are Hallucinated
152
+ Paste an LLM-generated answer and reference contexts.
153
+ FactEval highlights ✅ **supported**, ❌ **contradicted**, and ❓ **unverifiable** claims.
154
+ """
155
+ )
156
+
157
+ with gr.Row():
158
+ with gr.Column(scale=1):
159
+ answer_input = gr.Textbox(
160
+ label="LLM Answer",
161
+ placeholder="Enter the text to fact-check...",
162
+ lines=4,
163
+ )
164
+ context_input = gr.Textbox(
165
+ label="Reference Contexts (one per line)",
166
+ placeholder="Enter ground truth passages, one per line...",
167
+ lines=5,
168
+ )
169
+ calibrator_input = gr.Textbox(
170
+ label="Calibrator Path (optional)",
171
+ placeholder="Path to calibrator.pkl",
172
+ lines=1,
173
+ )
174
+ check_btn = gr.Button("🔍 Check Factuality", variant="primary", size="lg")
175
+
176
+ gr.Markdown("### 📝 Highlighted Answer")
177
+ highlighted_output = gr.HTML()
178
+
179
+ with gr.Row():
180
+ with gr.Column(scale=2):
181
+ gr.Markdown("### 📋 Claim Details")
182
+ details_output = gr.HTML()
183
+ with gr.Column(scale=1):
184
+ summary_output = gr.HTML()
185
+
186
+ with gr.Accordion("Raw JSON Output", open=False):
187
+ json_output = gr.Code(language="json")
188
+
189
+ check_btn.click(
190
+ fn=run_check,
191
+ inputs=[answer_input, context_input, calibrator_input],
192
+ outputs=[highlighted_output, details_output, summary_output, json_output],
193
+ )
194
+
195
+ gr.Examples(
196
+ examples=EXAMPLES,
197
+ inputs=[answer_input, context_input],
198
+ label="Try these examples",
199
+ )
200
+
201
+ if __name__ == "__main__":
202
+ demo.launch(share=True)
facteval/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ """FactEval – Find exactly which parts of your LLM output are hallucinated."""
2
+
3
+ # Suppress known harmless warnings from dependencies before any imports
4
+ import os as _os
5
+ import sys as _sys
6
+ import warnings as _warnings
7
+ import logging as _logging
8
+ import contextlib as _contextlib
9
+ import io as _io
10
+
11
+ # Suppress safetensors / accelerate noise
12
+ _os.environ.setdefault("SAFETENSORS_LOG_LEVEL", "error")
13
+ _os.environ.setdefault("ACCELERATE_LOG_LEVEL", "error")
14
+ _logging.getLogger("safetensors").setLevel(_logging.ERROR)
15
+ _logging.getLogger("accelerate").setLevel(_logging.ERROR)
16
+
17
+ # Suppress HF Hub unauthenticated request warnings
18
+ _logging.getLogger("huggingface_hub.utils._http").setLevel(_logging.ERROR)
19
+ _logging.getLogger("huggingface_hub").setLevel(_logging.ERROR)
20
+
21
+ # Suppress transformers info-level noise
22
+ _logging.getLogger("transformers.modeling_utils").setLevel(_logging.ERROR)
23
+ _logging.getLogger("transformers.generation.configuration_utils").setLevel(_logging.ERROR)
24
+
25
+ # Suppress FutureWarning about clean_up_tokenization_spaces
26
+ _warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
27
+
28
+
29
+ @_contextlib.contextmanager
30
+ def suppress_loading_noise():
31
+ """Suppress stdout + stderr noise during model loading (LOAD REPORT, sharding info)."""
32
+ old_stdout, old_stderr = _sys.stdout, _sys.stderr
33
+ _sys.stdout = _io.StringIO()
34
+ _sys.stderr = _io.StringIO()
35
+ try:
36
+ yield
37
+ finally:
38
+ _sys.stdout = old_stdout
39
+ _sys.stderr = old_stderr
40
+
41
+
42
+ # Backward compat alias
43
+ suppress_stdout = suppress_loading_noise
44
+
45
+
46
+ # ── Public API ───────────────────────────────────────────────────────────────
47
+ from facteval.core import check, verify
48
+ from facteval.models import Claim, Evidence, ClaimWithEvidence
49
+ from facteval.verifier import FactLabel, VerificationResult
50
+
51
+ __version__ = "0.1.0"
52
+ __all__ = [
53
+ "check",
54
+ "verify",
55
+ "Claim",
56
+ "Evidence",
57
+ "ClaimWithEvidence",
58
+ "FactLabel",
59
+ "VerificationResult",
60
+ "suppress_loading_noise",
61
+ ]
facteval/calibrator.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+ Calibrator – Transforms raw NLI scores into calibrated probabilities.
3
+
4
+ Uses isotonic regression models (fitted in Week 0) to produce trustworthy
5
+ confidence scores and calibration error estimates.
6
+
7
+ Falls back gracefully to raw scores if no calibrator file is available.
8
+ """
9
+
10
+ import logging
11
+ import pickle
12
+ from pathlib import Path
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Calibrator:
18
+ """Apply isotonic regression calibration to raw NLI probabilities."""
19
+
20
+ def __init__(self, calibrator_path: str | Path | None = None):
21
+ """
22
+ Load a pre-fitted calibrator from a pickle file.
23
+
24
+ Args:
25
+ calibrator_path: Path to the pickle file containing a dict of
26
+ {label_name: IsotonicRegression} objects.
27
+ If None or file doesn't exist, falls back to raw scores.
28
+ """
29
+ self._calibrators: dict | None = None
30
+
31
+ if calibrator_path is not None:
32
+ path = Path(calibrator_path)
33
+ if path.exists():
34
+ with open(path, "rb") as f:
35
+ self._calibrators = pickle.load(f)
36
+ logger.info(
37
+ "Loaded calibrator from %s (labels: %s)",
38
+ path, list(self._calibrators.keys()),
39
+ )
40
+ else:
41
+ logger.warning("Calibrator file not found: %s. Using raw scores.", path)
42
+
43
+ @property
44
+ def is_calibrated(self) -> bool:
45
+ """Whether a calibrator is loaded."""
46
+ return self._calibrators is not None
47
+
48
+ def calibrate(self, raw_scores: dict[str, float]) -> tuple[float, float]:
49
+ """
50
+ Calibrate raw NLI probabilities.
51
+
52
+ Args:
53
+ raw_scores: Dict mapping label names to raw probabilities
54
+ (e.g. {"entailment": 0.95, "neutral": 0.03, "contradiction": 0.02}).
55
+
56
+ Returns:
57
+ (calibrated_confidence, calibration_error)
58
+ - calibrated_confidence: The calibrated probability for the predicted label.
59
+ - calibration_error: Absolute difference between raw and calibrated confidence.
60
+ """
61
+ if not raw_scores:
62
+ return 0.0, 0.0
63
+
64
+ # Find the predicted label (highest raw score)
65
+ predicted_label = max(raw_scores, key=raw_scores.get)
66
+ raw_confidence = raw_scores[predicted_label]
67
+
68
+ if not self.is_calibrated:
69
+ # Fallback: return raw confidence with an estimated error
70
+ return raw_confidence, self._estimate_error(raw_confidence)
71
+
72
+ # Apply isotonic regression for each label
73
+ calibrated_scores = {}
74
+ for label, raw_prob in raw_scores.items():
75
+ if label in self._calibrators:
76
+ cal_prob = float(self._calibrators[label].predict([[raw_prob]])[0])
77
+ calibrated_scores[label] = max(0.0, min(1.0, cal_prob))
78
+ else:
79
+ calibrated_scores[label] = raw_prob
80
+
81
+ calibrated_confidence = calibrated_scores.get(predicted_label, raw_confidence)
82
+ calibration_error = abs(raw_confidence - calibrated_confidence)
83
+
84
+ return round(calibrated_confidence, 4), round(calibration_error, 4)
85
+
86
+ @staticmethod
87
+ def _estimate_error(raw_confidence: float) -> float:
88
+ """Rough error estimate when no calibrator is available."""
89
+ # Higher confidence → lower estimated error, but never zero
90
+ return round(max(0.02, (1.0 - raw_confidence) * 0.3), 4)
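
The loader above expects a pickle holding a dict of `{label_name: IsotonicRegression}` objects. A purely illustrative sketch of producing a compatible `calibrator.pkl` with scikit-learn follows; the synthetic scores are placeholders, and a real calibrator would be fitted on held-out raw NLI probabilities with 0/1 correctness targets:

```python
# Illustrative only: fit one IsotonicRegression per NLI label and pickle them
# in the {label_name: regressor} format that Calibrator(calibrator_path=...) loads.
# The random data stands in for held-out raw NLI scores and correctness labels.
import pickle
import numpy as np
from sklearn.isotonic import IsotonicRegression

rng = np.random.default_rng(0)
calibrators = {}
for label in ("entailment", "neutral", "contradiction"):
    raw_scores = rng.uniform(0.0, 1.0, size=500)                              # placeholder raw probabilities
    correct = (rng.uniform(0.0, 1.0, size=500) < raw_scores).astype(float)    # placeholder 0/1 targets
    iso = IsotonicRegression(out_of_bounds="clip")
    iso.fit(raw_scores, correct)
    calibrators[label] = iso

with open("calibrator.pkl", "wb") as f:
    pickle.dump(calibrators, f)
```

The resulting file can then be passed as `calibrator_path` to `check()` or via the CLI's `--calibrator` flag.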
facteval/claim_extractor.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Claim Extractor – Breaks text into atomic, verifiable claims.
3
+
4
+ Uses Qwen2.5-1.5B-Instruct (chosen in Week 0 for speed and output quality)
5
+ with the model's chat template to produce clean numbered lists.
6
+ """
7
+
8
+ import re
9
+ import logging
10
+
11
+ import torch
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+
14
+ from facteval import suppress_stdout
15
+
16
+ from facteval.config import (
17
+ CLAIM_MODEL,
18
+ CLAIM_SYSTEM_PROMPT,
19
+ CLAIM_USER_PROMPT,
20
+ MAX_CLAIMS,
21
+ MAX_NEW_TOKENS,
22
+ )
23
+ from facteval.models import Claim
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class ClaimExtractor:
29
+ """Extract atomic claims from text using a causal LM with chat prompting."""
30
+
31
+ def __init__(
32
+ self,
33
+ model_name: str = CLAIM_MODEL,
34
+ device: str | None = None,
35
+ dtype: torch.dtype | None = None,
36
+ ):
37
+ self.model_name = model_name
38
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
39
+ self.dtype = dtype or (torch.float16 if self.device == "cuda" else torch.float32)
40
+
41
+ logger.info("Loading claim extractor: %s on %s", model_name, self.device)
42
+ self.tokenizer = AutoTokenizer.from_pretrained(
43
+ model_name, trust_remote_code=True
44
+ )
45
+ with suppress_stdout():
46
+ self.model = AutoModelForCausalLM.from_pretrained(
47
+ model_name,
48
+ dtype=self.dtype,
49
+ device_map="auto" if self.device == "cuda" else None,
50
+ trust_remote_code=True,
51
+ )
52
+ if self.device == "cpu":
53
+ self.model = self.model.to(self.device)
54
+ self.model.eval()
55
+
56
+ # Clear sampling params from generation_config to avoid
57
+ # "generation flags are not valid" warnings with do_sample=False
58
+ gen_cfg = self.model.generation_config
59
+ for attr in ("temperature", "top_p", "top_k"):
60
+ if hasattr(gen_cfg, attr):
61
+ setattr(gen_cfg, attr, None)
62
+
63
+ logger.info("Claim extractor ready.")
64
+
65
+ def extract(
66
+ self,
67
+ text: str,
68
+ max_claims: int = MAX_CLAIMS,
69
+ max_new_tokens: int = MAX_NEW_TOKENS,
70
+ ) -> list[Claim]:
71
+ """
72
+ Extract atomic claims from *text*.
73
+
74
+ Args:
75
+ text: The text to decompose into claims.
76
+ max_claims: Maximum number of claims to return.
77
+ max_new_tokens: Generation length cap (prevents rambling).
78
+
79
+ Returns:
80
+ A deduplicated list of Claim objects.
81
+ """
82
+ if not text or not text.strip():
83
+ return []
84
+
85
+ raw_output = self._generate(text, max_new_tokens)
86
+ claims = self._parse_claims(raw_output, text, max_claims)
87
+ logger.info("Extracted %d claims from %d-char text.", len(claims), len(text))
88
+ return claims
89
+
90
+ # ── Private helpers ──────────────────────────────────────────────────────
91
+
92
+ def _generate(self, text: str, max_new_tokens: int) -> str:
93
+ """Run the LLM to generate claim text."""
94
+ messages = [
95
+ {"role": "system", "content": CLAIM_SYSTEM_PROMPT},
96
+ {"role": "user", "content": CLAIM_USER_PROMPT.format(text=text)},
97
+ ]
98
+ prompt = self.tokenizer.apply_chat_template(
99
+ messages, tokenize=False, add_generation_prompt=True
100
+ )
101
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
102
+
103
+ with torch.no_grad():
104
+ output_ids = self.model.generate(
105
+ **inputs,
106
+ max_new_tokens=max_new_tokens,
107
+ do_sample=False,
108
+ )
109
+
110
+ # Decode only the newly generated tokens
111
+ generated = output_ids[0][inputs["input_ids"].shape[1]:]
112
+ return self.tokenizer.decode(generated, skip_special_tokens=True).strip()
113
+
114
+ @staticmethod
115
+ def _parse_claims(
116
+ raw: str, source_text: str, max_claims: int
117
+ ) -> list[Claim]:
118
+ """Parse numbered/bulleted list into deduplicated Claim objects."""
119
+ seen: set[str] = set()
120
+ claims: list[Claim] = []
121
+
122
+ for line in raw.split("\n"):
123
+ # Strip numbering (e.g. "1.", "1)", "- ", "• ")
124
+ cleaned = re.sub(r"^[\d.\)\-•\s]+", "", line).strip()
125
+ if len(cleaned) <= 5:
126
+ continue
127
+
128
+ # Normalize for dedup (lowercase, collapse whitespace)
129
+ key = re.sub(r"\s+", " ", cleaned.lower())
130
+ if key in seen:
131
+ continue
132
+ seen.add(key)
133
+
134
+ claims.append(Claim(text=cleaned, source_text=source_text))
135
+ if len(claims) >= max_claims:
136
+ break
137
+
138
+ return claims
facteval/cli.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ CLI – Command-line interface for FactEval.
3
+
4
+ Usage:
5
+ facteval check input.json
6
+ facteval check input.json --output output.json
7
+ facteval check --answer "..." --context "ctx1" --context "ctx2"
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+ import sys
13
+ import logging
14
+
15
+
16
+ def main():
17
+ """Entry point for the facteval CLI."""
18
+ parser = argparse.ArgumentParser(
19
+ prog="facteval",
20
+ description="FactEval – Claim-level factuality evaluation with calibrated confidence.",
21
+ )
22
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
23
+
24
+ # ── facteval check ───────────────────────────────────────────────────────
25
+ check_parser = subparsers.add_parser(
26
+ "check", help="Check an answer for factual accuracy against provided contexts."
27
+ )
28
+ check_parser.add_argument(
29
+ "input_file", nargs="?", default=None,
30
+ help='JSON file with "answer" and "contexts" keys.',
31
+ )
32
+ check_parser.add_argument(
33
+ "--answer", "-a", type=str, default=None,
34
+ help="The answer text to check (alternative to input file).",
35
+ )
36
+ check_parser.add_argument(
37
+ "--context", "-c", action="append", default=None,
38
+ help="Context passage (can be repeated). Alternative to input file.",
39
+ )
40
+ check_parser.add_argument(
41
+ "--output", "-o", type=str, default=None,
42
+ help="Output file path. If not provided, prints to stdout.",
43
+ )
44
+ check_parser.add_argument(
45
+ "--calibrator", type=str, default=None,
46
+ help="Path to a pre-fitted calibrator pickle file.",
47
+ )
48
+ check_parser.add_argument(
49
+ "--top-k", type=int, default=3,
50
+ help="Number of evidence sentences to retrieve per claim (default: 3).",
51
+ )
52
+ check_parser.add_argument(
53
+ "--max-claims", type=int, default=10,
54
+ help="Maximum number of claims to extract (default: 10).",
55
+ )
56
+ check_parser.add_argument(
57
+ "--verbose", "-v", action="store_true",
58
+ help="Enable verbose logging.",
59
+ )
60
+
61
+ args = parser.parse_args()
62
+
63
+ if args.command is None:
64
+ parser.print_help()
65
+ sys.exit(0)
66
+
67
+ if args.command == "check":
68
+ _run_check(args)
69
+
70
+
71
+ def _run_check(args):
72
+ """Execute the check command."""
73
+ # Configure logging
74
+ level = logging.INFO if args.verbose else logging.WARNING
75
+ logging.basicConfig(level=level, format="%(name)s | %(message)s")
76
+
77
+ # Parse input
78
+ answer, contexts = _parse_input(args)
79
+ if answer is None:
80
+ print("Error: Provide either an input JSON file or --answer + --context flags.", file=sys.stderr)
81
+ sys.exit(1)
82
+
83
+ # Import here to avoid slow import on --help
84
+ from facteval.core import check
85
+
86
+ # Run pipeline
87
+ result = check(
88
+ answer=answer,
89
+ contexts=contexts,
90
+ top_k=args.top_k,
91
+ max_claims=args.max_claims,
92
+ calibrator_path=args.calibrator,
93
+ )
94
+
95
+ # Output
96
+ output_json = json.dumps(result, indent=2, ensure_ascii=False)
97
+
98
+ if args.output:
99
+ with open(args.output, "w", encoding="utf-8") as f:
100
+ f.write(output_json)
101
+ print(f"Results saved to {args.output}", file=sys.stderr)
102
+ else:
103
+ print(output_json)
104
+
105
+
106
+ def _parse_input(args) -> tuple[str | None, list[str]]:
107
+ """Parse answer and contexts from file or CLI flags."""
108
+ # Option 1: JSON file
109
+ if args.input_file:
110
+ with open(args.input_file, "r", encoding="utf-8") as f:
111
+ data = json.load(f)
112
+ return data.get("answer"), data.get("contexts", [])
113
+
114
+ # Option 2: CLI flags
115
+ if args.answer:
116
+ return args.answer, args.context or []
117
+
118
+ # Option 3: stdin
119
+ if not sys.stdin.isatty():
120
+ data = json.load(sys.stdin)
121
+ return data.get("answer"), data.get("contexts", [])
122
+
123
+ return None, []
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
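
A hedged end-to-end sketch of driving this CLI (assuming dependencies are installed and the repo root is on `PYTHONPATH`; if the `facteval` console script from the docstring isn't installed, the module can be invoked directly):

```python
# Illustrative sketch: write the {"answer": ..., "contexts": [...]} payload that
# _parse_input() expects, then invoke the check subcommand via the module path.
import json
import subprocess

payload = {
    "answer": "Python was created by Guido van Rossum and first released in 2005.",
    "contexts": ["Python was created by Guido van Rossum and first released in 1991."],
}
with open("input.json", "w", encoding="utf-8") as f:
    json.dump(payload, f)

subprocess.run(
    ["python", "-m", "facteval.cli", "check", "input.json", "--output", "result.json"],
    check=True,
)
```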
facteval/config.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Default configuration for FactEval models and parameters.
3
+ """
4
+
5
+ # ── Model IDs (Hugging Face Hub) ─────────────────────────────────────────────
6
+ # Claim extraction – chosen in Week 0: 1.5B was 3.5x faster with cleaner output
7
+ CLAIM_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
8
+
9
+ # Sentence embeddings for evidence retrieval
10
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
11
+
12
+ # NLI verification (used in Week 2)
13
+ NLI_MODEL = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
14
+
15
+ # ── Retrieval defaults ───────────────────────────────────────────────────────
16
+ DEFAULT_TOP_K = 3
17
+ MIN_EVIDENCE_SCORE = 0.3 # Below this, evidence is too weak to use
18
+
19
+ # ── Claim extraction defaults ────────────────────────────────────────────────
20
+ MAX_NEW_TOKENS = 200
21
+ MAX_CLAIMS = 10
22
+
23
+ CLAIM_SYSTEM_PROMPT = (
24
+ "You are a claim extraction engine. Given a text, break it into atomic, "
25
+ "independently verifiable claims. Each claim states exactly ONE fact. "
26
+ "Return ONLY a numbered list. No explanations, no commentary."
27
+ )
28
+
29
+ CLAIM_USER_PROMPT = "Break this into atomic claims:\n\n{text}"
facteval/core.py ADDED
@@ -0,0 +1,328 @@
1
+ """
2
+ Core – The main check() and verify() functions that wire the FactEval pipeline.
3
+
4
+ Usage:
5
+ from facteval import check, verify
6
+
7
+ # Full pipeline (extract + retrieve + verify)
8
+ result = check(answer, contexts)
9
+
10
+ # Lightweight mode (skip extraction, bring your own claims)
11
+ result = verify(claims=["claim 1", "claim 2"], contexts=docs)
12
+ """
13
+
14
+ import re
15
+ import logging
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+
21
+ from facteval.calibrator import Calibrator
22
+ from facteval.claim_extractor import ClaimExtractor
23
+ from facteval.retriever import EvidenceRetriever
24
+ from facteval.verifier import Verifier, FactLabel
25
+ from facteval.models import Claim
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Module-level singletons (lazy-loaded)
30
+ _extractor: ClaimExtractor | None = None
31
+ _retriever: EvidenceRetriever | None = None
32
+ _verifier: Verifier | None = None
33
+ _calibrator: Calibrator | None = None
34
+ _calibrator_path: str | None = None
35
+
36
+
37
+ def _get_extractor() -> ClaimExtractor:
38
+ global _extractor
39
+ if _extractor is None:
40
+ _extractor = ClaimExtractor()
41
+ return _extractor
42
+
43
+
44
+ def _get_retriever() -> EvidenceRetriever:
45
+ global _retriever
46
+ if _retriever is None:
47
+ _retriever = EvidenceRetriever()
48
+ return _retriever
49
+
50
+
51
+ def _get_verifier() -> Verifier:
52
+ global _verifier
53
+ if _verifier is None:
54
+ _verifier = Verifier()
55
+ return _verifier
56
+
57
+
58
+ def _get_calibrator(path: str | None = None) -> Calibrator:
59
+ global _calibrator, _calibrator_path
60
+ if _calibrator is None or path != _calibrator_path:
61
+ _calibrator = Calibrator(calibrator_path=path)
62
+ _calibrator_path = path
63
+ return _calibrator
64
+
65
+
66
+ # ── Full pipeline ────────────────────────────────────────────────────────────
67
+
68
+ def check(
69
+ answer: str,
70
+ contexts: list[str],
71
+ top_k: int = 3,
72
+ max_claims: int = 10,
73
+ calibrator_path: str | Path | None = None,
74
+ ) -> dict:
75
+ """
76
+ Run the full FactEval pipeline on an answer + contexts.
77
+
78
+ Stages: extract claims → retrieve evidence → NLI verify → calibrate.
79
+
80
+ Args:
81
+ answer: The LLM-generated text to evaluate.
82
+ contexts: List of reference passages (ground truth).
83
+ top_k: Number of evidence sentences to retrieve per claim.
84
+ max_claims: Maximum claims to extract.
85
+ calibrator_path: Path to a pre-fitted calibrator pickle file.
86
+
87
+ Returns:
88
+ A dict with claims, summary, highlighted_answer, and pipeline_time.
89
+ """
90
+ t0 = time.perf_counter()
91
+
92
+ # 1. Extract claims
93
+ extractor = _get_extractor()
94
+ claims = extractor.extract(answer, max_claims=max_claims)
95
+ logger.info("Extracted %d claims.", len(claims))
96
+
97
+ if not claims:
98
+ return {
99
+ "claims": [],
100
+ "summary": _build_summary([]),
101
+ "highlighted_answer": answer,
102
+ "calibrated": False,
103
+ "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
104
+ }
105
+
106
+ # 2–5. Shared pipeline
107
+ return _run_pipeline(claims, contexts, answer, top_k, calibrator_path, t0)
108
+
109
+
110
+ # ── Lightweight mode ─────────────────────────────────────────────────────────
111
+
112
+ def verify(
113
+ claims: list[str],
114
+ contexts: list[str],
115
+ top_k: int = 3,
116
+ calibrator_path: str | Path | None = None,
117
+ ) -> dict:
118
+ """
119
+ Verify pre-extracted claims against contexts. Skips claim extraction.
120
+
121
+ Use this when you already have claims and want faster results
122
+ (avoids the ~1s extraction step and the Qwen model entirely).
123
+
124
+ Args:
125
+ claims: List of claim strings to verify.
126
+ contexts: List of reference passages (ground truth).
127
+ top_k: Number of evidence sentences to retrieve per claim.
128
+ calibrator_path: Path to a pre-fitted calibrator pickle file.
129
+
130
+ Returns:
131
+ Same output format as check().
132
+ """
133
+ t0 = time.perf_counter()
134
+
135
+ claim_objs = [Claim(text=c) for c in claims if c.strip()]
136
+
137
+ if not claim_objs:
138
+ return {
139
+ "claims": [],
140
+ "summary": _build_summary([]),
141
+ "highlighted_answer": "",
142
+ "calibrated": False,
143
+ "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
144
+ }
145
+
146
+ answer = " ".join(claims) # reconstruct for highlighting
147
+ return _run_pipeline(claim_objs, contexts, answer, top_k, calibrator_path, t0)
148
+
149
+
150
+ # ── Shared pipeline ──────────────────────────────────────────────────────────
151
+
152
+ def _run_pipeline(
153
+ claims: list[Claim],
154
+ contexts: list[str],
155
+ answer: str,
156
+ top_k: int,
157
+ calibrator_path: str | Path | None,
158
+ t0: float,
159
+ ) -> dict:
160
+ """Shared pipeline: retrieve → verify → calibrate → diagnose → highlight."""
161
+
162
+ # 2. Retrieve evidence
163
+ retriever = _get_retriever()
164
+ retriever.index(contexts)
165
+ claims_with_evidence = retriever.retrieve_for_claims(claims, top_k=top_k)
166
+
167
+ # 3. Verify (batch NLI)
168
+ verifier = _get_verifier()
169
+ results = verifier.verify_batch(claims_with_evidence)
170
+
171
+ # 4. Calibrate
172
+ calibrator = _get_calibrator(str(calibrator_path) if calibrator_path else None)
173
+ for r in results:
174
+ if r.raw_scores:
175
+ cal_conf, cal_err = calibrator.calibrate(r.raw_scores)
176
+ r.calibrated_confidence = cal_conf
177
+ r.calibration_error = cal_err
178
+
179
+ # 5. Build output with diagnostics
180
+ elapsed = time.perf_counter() - t0
181
+ claim_dicts = [r.to_dict() for r in results]
182
+
183
+ # Add diagnostics to each claim
184
+ for cd in claim_dicts:
185
+ cd["diagnostics"] = _diagnose(cd)
186
+
187
+ return {
188
+ "claims": claim_dicts,
189
+ "summary": _build_summary(results),
190
+ "highlighted_answer": _highlight_answer_semantic(
191
+ answer, claim_dicts, retriever.embedder
192
+ ),
193
+ "calibrated": calibrator.is_calibrated,
194
+ "pipeline_time_seconds": round(elapsed, 3),
195
+ }
196
+
197
+
198
+ # ── Diagnostics ──────────────────────────────────────────────────────────────
199
+
200
+ def _diagnose(claim_dict: dict) -> dict:
201
+ """
202
+ Generate pipeline diagnostics for a claim.
203
+
204
+ Tells the developer *why* a claim got its label —
205
+ was it a retrieval failure or a genuine hallucination?
206
+ """
207
+ label = claim_dict["label"]
208
+ ev_score = claim_dict.get("evidence_score")
209
+ confidence = claim_dict.get("confidence", 0)
210
+
211
+ # Retrieval quality assessment
212
+ if ev_score is None:
213
+ retrieval_quality = "none"
214
+ elif ev_score >= 0.7:
215
+ retrieval_quality = "strong"
216
+ elif ev_score >= 0.4:
217
+ retrieval_quality = "moderate"
218
+ else:
219
+ retrieval_quality = "weak"
220
+
221
+ # Failure type classification
222
+ if label == "supported":
223
+ failure_type = "verified"
224
+ suggestion = None
225
+ elif label == "contradicted":
226
+ if retrieval_quality in ("strong", "moderate"):
227
+ failure_type = "hallucination"
228
+ suggestion = "Claim directly contradicts the evidence. This is a factual error in the LLM output."
229
+ else:
230
+ failure_type = "possible_hallucination"
231
+ suggestion = "Claim contradicts weak evidence. Consider adding more specific context for reliable verification."
232
+ elif ev_score is None:
233
+ failure_type = "no_evidence"
234
+ suggestion = "No relevant context was provided. Add reference passages covering this topic."
235
+ elif retrieval_quality == "weak":
236
+ failure_type = "retrieval_gap"
237
+ suggestion = "Evidence was found but too dissimilar to trust. The context may not cover this claim."
238
+ else:
239
+ failure_type = "inconclusive"
240
+ suggestion = "Evidence exists but is neutral — neither confirms nor denies the claim."
241
+
242
+ d = {
243
+ "failure_type": failure_type,
244
+ "retrieval_quality": retrieval_quality,
245
+ }
246
+ if suggestion:
247
+ d["suggestion"] = suggestion
248
+ return d
249
+
250
+
251
+ # ── Summary ──────────────────────────────────────────────────────────────────
252
+
253
+ def _build_summary(results: list) -> dict:
254
+ """Build summary statistics from verification results."""
255
+ total = len(results)
256
+ supported = sum(1 for r in results if r.label == FactLabel.SUPPORTED)
257
+ contradicted = sum(1 for r in results if r.label == FactLabel.CONTRADICTED)
258
+ unverifiable = total - supported - contradicted
259
+
260
+ return {
261
+ "total_claims": total,
262
+ "supported": supported,
263
+ "contradicted": contradicted,
264
+ "unverifiable": unverifiable,
265
+ "hallucination_rate": round(contradicted / max(total, 1), 4),
266
+ }
267
+
268
+
269
+ # ── Semantic Highlighting ────────────────────────────────────────────────────
270
+
271
+ _LABEL_EMOJI = {"supported": "✅", "contradicted": "❌", "unverifiable": "❓"}
272
+ _LABEL_COLOR = {"supported": "#22c55e", "contradicted": "#ef4444", "unverifiable": "#f59e0b"}
273
+
274
+
275
+ def _highlight_answer_semantic(answer: str, claim_dicts: list[dict], embedder) -> str:
276
+ """
277
+ Map claims to source sentences using embedding similarity (not Jaccard).
278
+
279
+ Uses the retriever's SentenceTransformer to compute cosine similarity
280
+ between each claim and each sentence in the original answer. This handles
281
+ paraphrasing, reordering, and partial overlaps much better than token overlap.
282
+ """
283
+ if not answer.strip() or not claim_dicts:
284
+ return answer
285
+
286
+ # Split answer into sentences with positions
287
+ sentences = []
288
+ for m in re.finditer(r'[^.!?]+[.!?]*', answer):
289
+ text = m.group().strip()
290
+ if text:
291
+ sentences.append(text)
292
+
293
+ if not sentences:
294
+ return answer
295
+
296
+ # Compute embedding similarity
297
+ claim_texts = [c["claim"] for c in claim_dicts]
298
+ claim_labels = [c["label"] for c in claim_dicts]
299
+
300
+ sent_embeddings = embedder.encode(sentences, normalize_embeddings=True)
301
+ claim_embeddings = embedder.encode(claim_texts, normalize_embeddings=True)
302
+
303
+ # Similarity matrix: sentences × claims
304
+ sim_matrix = np.dot(sent_embeddings, claim_embeddings.T)
305
+
306
+ # For each sentence, find best matching claim
307
+ sentence_labels: dict[str, str] = {}
308
+ for i, sent_text in enumerate(sentences):
309
+ best_j = int(sim_matrix[i].argmax())
310
+ best_sim = float(sim_matrix[i, best_j])
311
+
312
+ if best_sim > 0.35: # Semantic similarity threshold
313
+ sentence_labels[sent_text] = claim_labels[best_j]
314
+
315
+ # Build highlighted text (longest matches first to avoid partial replacements)
316
+ highlighted = answer
317
+ for sent_text in sorted(sentence_labels, key=len, reverse=True):
318
+ label = sentence_labels[sent_text]
319
+ color = _LABEL_COLOR.get(label, "#94a3b8")
320
+ emoji = _LABEL_EMOJI.get(label, "")
321
+ highlighted = highlighted.replace(
322
+ sent_text,
323
+ f'<mark style="background:{color}30;padding:2px 4px;border-radius:3px">'
324
+ f'{sent_text} {emoji}</mark>',
325
+ 1,
326
+ )
327
+
328
+ return highlighted
facteval/models.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ Pydantic data models for FactEval's pipeline objects.
3
+ """
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class Claim(BaseModel):
9
+ """A single atomic, verifiable claim extracted from text."""
10
+
11
+ text: str = Field(..., description="The claim statement.")
12
+ source_text: str = Field(
13
+ default="",
14
+ description="The original text this claim was extracted from.",
15
+ )
16
+
17
+ def __str__(self) -> str:
18
+ return self.text
19
+
20
+
21
+ class Evidence(BaseModel):
22
+ """A single piece of evidence retrieved for a claim."""
23
+
24
+ sentence: str = Field(..., description="The evidence sentence.")
25
+ score: float = Field(
26
+ ..., ge=0.0, description="Cosine similarity score (may slightly exceed 1.0 due to float precision)."
27
+ )
28
+ source_context: str = Field(
29
+ default="",
30
+ description="The full context passage this sentence came from.",
31
+ )
32
+
33
+ def __str__(self) -> str:
34
+ return f"[{self.score:.3f}] {self.sentence}"
35
+
36
+
37
+ class ClaimWithEvidence(BaseModel):
38
+ """A claim paired with its retrieved evidence."""
39
+
40
+ claim: Claim
41
+ evidence: list[Evidence] = Field(default_factory=list)
42
+
43
+ @property
44
+ def best_evidence(self) -> Evidence | None:
45
+ """Return the highest-scoring evidence, or None."""
46
+ return self.evidence[0] if self.evidence else None
facteval/retriever.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Evidence Retriever – FAISS-based semantic search over user-provided contexts.
3
+
4
+ Encodes context sentences with all-MiniLM-L6-v2 and retrieves the top-k
5
+ most similar evidence sentences for each claim.
6
+ """
7
+
8
+ import re
9
+ import logging
10
+
11
+ import numpy as np
12
+ import faiss
13
+ from sentence_transformers import SentenceTransformer
14
+
15
+ from facteval import suppress_stdout
16
+ from facteval.config import DEFAULT_TOP_K, EMBEDDING_MODEL, MIN_EVIDENCE_SCORE
17
+ from facteval.models import Claim, Evidence, ClaimWithEvidence
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class EvidenceRetriever:
23
+ """Build a FAISS index over context sentences and retrieve evidence for claims."""
24
+
25
+ def __init__(
26
+ self,
27
+ model_name: str = EMBEDDING_MODEL,
28
+ device: str | None = None,
29
+ ):
30
+ self.device = device or ("cuda" if __import__("torch").cuda.is_available() else "cpu")
31
+ logger.info("Loading embedding model: %s", model_name)
32
+ with suppress_stdout():
33
+ self.embedder = SentenceTransformer(model_name, device=self.device)
34
+
35
+ # Populated by .index()
36
+ self._sentences: list[str] = []
37
+ self._sentence_to_context: dict[int, str] = {}
38
+ self._index: faiss.IndexFlatIP | None = None
39
+
40
+ def index(self, contexts: list[str]) -> "EvidenceRetriever":
41
+ """
42
+ Build a FAISS index from a list of context passages.
43
+
44
+ Each context is split into individual sentences before indexing.
45
+
46
+ Args:
47
+ contexts: List of context passages (strings).
48
+
49
+ Returns:
50
+ self (for chaining: `retriever.index(ctx).retrieve(claim)`).
51
+ """
52
+ if not contexts:
53
+ logger.warning("No contexts provided; retriever will return empty results.")
54
+ self._sentences = []
55
+ self._index = None
56
+ return self
57
+
58
+ self._sentences = []
59
+ self._sentence_to_context = {}
60
+
61
+ for ctx in contexts:
62
+ for sent in self._split_sentences(ctx):
63
+ idx = len(self._sentences)
64
+ self._sentences.append(sent)
65
+ self._sentence_to_context[idx] = ctx
66
+
67
+ if not self._sentences:
68
+ logger.warning("No sentences extracted from contexts.")
69
+ self._index = None
70
+ return self
71
+
72
+ logger.info("Indexing %d evidence sentences.", len(self._sentences))
73
+ embeddings = self.embedder.encode(
74
+ self._sentences, convert_to_numpy=True, normalize_embeddings=True
75
+ ).astype(np.float32)
76
+
77
+ dim = embeddings.shape[1]
78
+ self._index = faiss.IndexFlatIP(dim) # Cosine similarity (normalized)
79
+ self._index.add(embeddings)
80
+
81
+ return self
82
+
83
+ def retrieve(
84
+ self,
85
+ claim: Claim | str,
86
+ top_k: int = DEFAULT_TOP_K,
87
+ min_score: float = MIN_EVIDENCE_SCORE,
88
+ ) -> list[Evidence]:
89
+ """
90
+ Retrieve the top-k most relevant evidence sentences for a claim.
91
+
92
+ Args:
93
+ claim: A Claim object or plain string.
94
+ top_k: Number of evidence sentences to return.
95
+ min_score: Minimum cosine similarity to include.
96
+
97
+ Returns:
98
+ List of Evidence objects, sorted by score descending.
99
+ """
100
+ if self._index is None or not self._sentences:
101
+ return []
102
+
103
+ query_text = claim.text if isinstance(claim, Claim) else claim
104
+ q_emb = self.embedder.encode(
105
+ [query_text], convert_to_numpy=True, normalize_embeddings=True
106
+ ).astype(np.float32)
107
+
108
+ scores, indices = self._index.search(q_emb, top_k)
109
+ results: list[Evidence] = []
110
+
111
+ for score, idx in zip(scores[0], indices[0]):
112
+ if idx < 0 or idx >= len(self._sentences):
113
+ continue
114
+ clamped_score = float(min(max(score, 0.0), 1.0))
115
+ if clamped_score < min_score:
116
+ continue
117
+ results.append(
118
+ Evidence(
119
+ sentence=self._sentences[idx],
120
+ score=clamped_score,
121
+ source_context=self._sentence_to_context.get(idx, ""),
122
+ )
123
+ )
124
+
125
+ return results
126
+
127
+ def retrieve_for_claims(
128
+ self,
129
+ claims: list[Claim],
130
+ top_k: int = DEFAULT_TOP_K,
131
+ min_score: float = MIN_EVIDENCE_SCORE,
132
+ ) -> list[ClaimWithEvidence]:
133
+ """
134
+ Batch-retrieve evidence for a list of claims.
135
+
136
+ Returns:
137
+ List of ClaimWithEvidence objects.
138
+ """
139
+ return [
140
+ ClaimWithEvidence(
141
+ claim=claim,
142
+ evidence=self.retrieve(claim, top_k=top_k, min_score=min_score),
143
+ )
144
+ for claim in claims
145
+ ]
146
+
147
+ @staticmethod
148
+ def _split_sentences(text: str) -> list[str]:
149
+ """Split text into sentences on sentence-ending punctuation."""
150
+ raw = re.split(r"(?<=[.!?])\s+", text)
151
+ return [s.strip() for s in raw if s.strip() and len(s.strip()) > 3]
facteval/verifier.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Verifier – NLI-based factual verification of claims against evidence.
3
+
4
+ Uses DeBERTa-v3 fine-tuned on MNLI+FEVER+ANLI to classify each
5
+ claim/evidence pair as entailment, contradiction, or neutral.
6
+ Maps NLI labels to FactEval labels: supported, contradicted, unverifiable.
7
+ """
8
+
9
+ import logging
10
+ from enum import Enum
11
+
12
+ import torch
13
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
14
+
15
+ from facteval import suppress_stdout
16
+ from facteval.config import NLI_MODEL, MIN_EVIDENCE_SCORE
17
+ from facteval.models import Claim, Evidence, ClaimWithEvidence
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class FactLabel(str, Enum):
23
+ """FactEval verdict labels."""
24
+ SUPPORTED = "supported"
25
+ CONTRADICTED = "contradicted"
26
+ UNVERIFIABLE = "unverifiable"
27
+
28
+
29
+ # Map DeBERTa NLI labels → FactEval labels
30
+ _NLI_TO_FACT = {
31
+ "entailment": FactLabel.SUPPORTED,
32
+ "contradiction": FactLabel.CONTRADICTED,
33
+ "neutral": FactLabel.UNVERIFIABLE,
34
+ }
35
+
36
+
37
+ class VerificationResult:
38
+ """Result of verifying a single claim."""
39
+
40
+ def __init__(
41
+ self,
42
+ claim: str,
43
+ label: FactLabel,
44
+ confidence: float,
45
+ evidence: str | None,
46
+ evidence_score: float | None,
47
+ raw_scores: dict[str, float],
48
+ reason: str = "",
49
+ calibrated_confidence: float | None = None,
50
+ calibration_error: float | None = None,
51
+ ):
52
+ self.claim = claim
53
+ self.label = label
54
+ self.confidence = confidence
55
+ self.evidence = evidence
56
+ self.evidence_score = evidence_score
57
+ self.raw_scores = raw_scores
58
+ self.reason = reason
59
+ self.calibrated_confidence = calibrated_confidence
60
+ self.calibration_error = calibration_error
61
+
62
+ def to_dict(self) -> dict:
63
+ d = {
64
+ "claim": self.claim,
65
+ "label": self.label.value,
66
+ "confidence": round(self.confidence, 4),
67
+ "reason": self.reason,
68
+ "evidence": self.evidence,
69
+ "evidence_score": round(self.evidence_score, 4) if self.evidence_score else None,
70
+ "raw_nli_scores": {k: round(v, 4) for k, v in self.raw_scores.items()},
71
+ }
72
+ if self.calibrated_confidence is not None:
73
+ d["calibrated_confidence"] = round(self.calibrated_confidence, 4)
74
+ if self.calibration_error is not None:
75
+ d["calibration_error"] = round(self.calibration_error, 4)
76
+ return d
77
+
78
+
79
+ class Verifier:
80
+ """Verify claims against evidence using NLI."""
81
+
82
+ def __init__(
83
+ self,
84
+ model_name: str = NLI_MODEL,
85
+ device: str | None = None,
86
+ ):
87
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
88
+ logger.info("Loading NLI model: %s on %s", model_name, self.device)
89
+
90
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
91
+ with suppress_stdout():
92
+ self.model = AutoModelForSequenceClassification.from_pretrained(
93
+ model_name
94
+ ).to(self.device)
95
+ self.model.eval()
96
+
97
+ self.id2label = self.model.config.id2label
98
+ logger.info("Verifier ready. Labels: %s", self.id2label)
99
+
100
+ def verify(
101
+ self,
102
+ claim_with_evidence: ClaimWithEvidence,
103
+ min_evidence_score: float = MIN_EVIDENCE_SCORE,
104
+ ) -> VerificationResult:
105
+ """
106
+ Verify a single claim against its retrieved evidence.
107
+
108
+ If no evidence meets the min_score threshold, returns 'unverifiable'
109
+ with zero confidence.
110
+ """
111
+ claim_text = claim_with_evidence.claim.text
112
+ best = claim_with_evidence.best_evidence
113
+
114
+ # Fallback: no usable evidence
115
+ if best is None or best.score < min_evidence_score:
116
+ logger.debug("No evidence for claim: %s", claim_text)
117
+ return VerificationResult(
118
+ claim=claim_text,
119
+ label=FactLabel.UNVERIFIABLE,
120
+ confidence=0.0,
121
+ evidence=None,
122
+ evidence_score=None,
123
+ raw_scores={},
124
+ reason="No relevant evidence found in the provided context.",
125
+ )
126
+
127
+ # Run NLI: premise=evidence, hypothesis=claim
128
+ return self._run_nli(claim_text, best.sentence, best.score)
129
+
130
+ def verify_batch(
131
+ self,
132
+ claims_with_evidence: list[ClaimWithEvidence],
133
+ min_evidence_score: float = MIN_EVIDENCE_SCORE,
134
+ ) -> list[VerificationResult]:
135
+ """
136
+ Verify a batch of claims using batched NLI inference.
137
+
138
+ Claims without evidence are immediately marked unverifiable.
139
+ Remaining claims are processed in a single forward pass for speed.
140
+ """
141
+ results: list[VerificationResult | None] = [None] * len(claims_with_evidence)
142
+ nli_pairs: list[tuple[int, str, str, float]] = []
143
+
144
+ for i, cwe in enumerate(claims_with_evidence):
145
+ claim_text = cwe.claim.text
146
+ best = cwe.best_evidence
147
+
148
+ if best is None or best.score < min_evidence_score:
149
+ results[i] = VerificationResult(
150
+ claim=claim_text,
151
+ label=FactLabel.UNVERIFIABLE,
152
+ confidence=0.0,
153
+ evidence=None,
154
+ evidence_score=None,
155
+ raw_scores={},
156
+ reason="No relevant evidence found in the provided context.",
157
+ )
158
+ else:
159
+ nli_pairs.append((i, claim_text, best.sentence, best.score))
160
+
161
+ # Batch NLI inference for all claims with evidence
162
+ if nli_pairs:
163
+ indices, claims, evidences, scores = zip(*nli_pairs)
164
+ inputs = self.tokenizer(
165
+ list(evidences), list(claims),
166
+ return_tensors="pt",
167
+ padding=True,
168
+ truncation=True,
169
+ max_length=512,
170
+ ).to(self.device)
171
+
172
+ with torch.no_grad():
173
+ logits = self.model(**inputs).logits
174
+
175
+ all_probs = torch.softmax(logits, dim=-1).cpu()
176
+
177
+ for idx, probs_t, claim, evidence, ev_score in zip(
178
+ indices, all_probs, claims, evidences, scores
179
+ ):
180
+ probs = probs_t.tolist()
181
+ label_probs = {self.id2label[i]: float(p) for i, p in enumerate(probs)}
182
+ predicted_nli = self.id2label[probs_t.argmax().item()]
183
+ fact_label = _NLI_TO_FACT.get(predicted_nli, FactLabel.UNVERIFIABLE)
184
+
185
+ results[idx] = VerificationResult(
186
+ claim=claim,
187
+ label=fact_label,
188
+ confidence=max(probs),
189
+ evidence=evidence,
190
+ evidence_score=ev_score,
191
+ raw_scores=label_probs,
192
+ reason=self._make_reason(fact_label, evidence),
193
+ )
194
+
195
+ return results
196
+
197
+ def _run_nli(
198
+ self, claim: str, evidence: str, evidence_score: float
199
+ ) -> VerificationResult:
200
+ """Run NLI inference on a single claim/evidence pair."""
201
+ inputs = self.tokenizer(
202
+ evidence, claim,
203
+ return_tensors="pt",
204
+ truncation=True,
205
+ max_length=512,
206
+ ).to(self.device)
207
+
208
+ with torch.no_grad():
209
+ logits = self.model(**inputs).logits
210
+
211
+ probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()
212
+ label_probs = {self.id2label[i]: float(p) for i, p in enumerate(probs)}
213
+ predicted_nli = self.id2label[logits.argmax().item()]
214
+ fact_label = _NLI_TO_FACT.get(predicted_nli, FactLabel.UNVERIFIABLE)
215
+
216
+ return VerificationResult(
217
+ claim=claim,
218
+ label=fact_label,
219
+ confidence=max(probs),
220
+ evidence=evidence,
221
+ evidence_score=evidence_score,
222
+ raw_scores=label_probs,
223
+ reason=self._make_reason(fact_label, evidence),
224
+ )
225
+
226
+ @staticmethod
227
+ def _make_reason(label: FactLabel, evidence: str) -> str:
228
+ """Generate a human-readable reason for the verdict."""
229
+ ev_short = evidence[:80] + "..." if len(evidence) > 80 else evidence
230
+ if label == FactLabel.SUPPORTED:
231
+ return f"Supported by evidence: \"{ev_short}\""
232
+ elif label == FactLabel.CONTRADICTED:
233
+ return f"Contradicts evidence: \"{ev_short}\""
234
+ else:
235
+ return f"Evidence is neutral — neither confirms nor denies: \"{ev_short}\""
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch>=2.0
2
+ transformers>=4.36
3
+ sentence-transformers>=2.2
4
+ faiss-cpu>=1.7
5
+ scikit-learn>=1.3
6
+ pydantic>=2.0
7
+ accelerate>=0.25
8
+ numpy>=1.24
9
+ gradio>=4.0