Spaces: Running
Sahil al farib committed on
Commit · 8fb73f8
Parent(s): b29689f
Deploy FactEval: claim-level hallucination detection with Gradio demo

Browse files
- README.md +20 -6
- app.py +5 -0
- demo/app.py +202 -0
- facteval/__init__.py +61 -0
- facteval/calibrator.py +90 -0
- facteval/claim_extractor.py +138 -0
- facteval/cli.py +127 -0
- facteval/config.py +29 -0
- facteval/core.py +328 -0
- facteval/models.py +46 -0
- facteval/retriever.py +151 -0
- facteval/verifier.py +235 -0
- requirements.txt +9 -0
README.md CHANGED
@@ -1,15 +1,29 @@
 ---
 title: FactEval
-emoji:
-colorFrom:
-colorTo:
+emoji: 🔍
+colorFrom: blue
+colorTo: red
 sdk: gradio
-sdk_version:
-python_version: '3.13'
+sdk_version: 4.44.1
 app_file: app.py
 pinned: false
 license: mit
 short_description: Find exactly which parts of your LLM output are hallucinated
 ---
 
-
+# 🔍 FactEval
+
+**Find exactly which parts of your LLM output are hallucinated.**
+
+Debug hallucinations in RAG and LLM pipelines with claim-level verification.
+
+Paste an LLM-generated answer and reference contexts. FactEval highlights ✅ **supported**, ❌ **contradicted**, and ❓ **unverifiable** claims with human-readable reasons and pipeline diagnostics.
+
+## How it works
+
+1. **Claim Extraction** — Breaks the answer into atomic claims (Qwen2.5-1.5B)
+2. **Evidence Retrieval** — Finds the most relevant sentences from your contexts (MiniLM + FAISS)
+3. **NLI Verification** — Checks each claim against evidence (DeBERTa-v3)
+4. **Calibration** — Produces trustworthy confidence scores (Isotonic Regression)
+
+📦 [GitHub Repository](https://github.com/sahilaf/FactEval)
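A minimal usage sketch of the pipeline described in the README, written against the check() function added in facteval/core.py below (the example strings are illustrative, not from the repo):

# Run the full pipeline from Python and inspect the per-claim verdicts.
from facteval import check

result = check(
    answer="Python was created by Guido van Rossum and first released in 2005.",
    contexts=["Python was created by Guido van Rossum and first released in 1991."],
)

print(result["summary"]["hallucination_rate"])    # fraction of contradicted claims
for claim in result["claims"]:
    print(claim["label"], "-", claim["claim"])    # supported / contradicted / unverifiable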
app.py ADDED
@@ -0,0 +1,5 @@
"""HF Spaces entry point — imports the Gradio demo from demo/app.py."""
from demo.app import demo

if __name__ == "__main__":
    demo.launch()
demo/app.py ADDED
@@ -0,0 +1,202 @@
"""
FactEval Gradio Demo – Interactive factuality checker.

Run locally: python demo/app.py
Run on Colab: Upload facteval/ folder, then run this file.
"""

import json
import gradio as gr
from facteval import check, verify

EXAMPLES = [
    [
        "Paris is the capital of Germany and has 5 million people.",
        "Paris is the capital of France. Paris has approximately 2.2 million inhabitants.\nGermany's capital is Berlin.",
    ],
    [
        "Python was created by Guido van Rossum and first released in 2005.",
        "Python was created by Guido van Rossum and first released in 1991.",
    ],
    [
        "The Amazon rainforest produces 20% of the world's oxygen and spans across nine countries.",
        "The Amazon rainforest produces about 6% of the world's oxygen.\nThe Amazon rainforest spans across nine countries in South America.",
    ],
    [
        "Albert Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921 for his work on the photoelectric effect.",
        "Albert Einstein developed the theory of relativity. He won the Nobel Prize in Physics in 1921 for his explanation of the photoelectric effect.",
    ],
]


def run_check(answer: str, contexts: str, calibrator_path: str = ""):
    """Run FactEval pipeline and format results for Gradio."""
    if not answer.strip():
        return "⚠️ Please enter an answer to check.", "", "", ""

    context_list = [c.strip() for c in contexts.strip().split("\n") if c.strip()]
    if not context_list:
        return "⚠️ Please enter at least one context passage.", "", "", ""

    cal_path = calibrator_path.strip() if calibrator_path.strip() else None
    result = check(answer, context_list, calibrator_path=cal_path)

    # 1. Highlighted answer (the viral feature)
    highlighted_html = f"""
    <div style="font-family: Inter, sans-serif; font-size: 18px; line-height: 2;
                padding: 20px; border-radius: 12px; background: #0f172a; color: #e2e8f0;">
        {result.get("highlighted_answer", answer)}
    </div>
    """

    # 2. Per-claim verdicts with reasons
    details_parts = []
    for c in result["claims"]:
        label = c["label"]
        colors = {"supported": "#22c55e", "contradicted": "#ef4444", "unverifiable": "#f59e0b"}
        emojis = {"supported": "✅", "contradicted": "❌", "unverifiable": "❓"}
        color = colors.get(label, "#94a3b8")
        emoji = emojis.get(label, "")
        conf = c.get("calibrated_confidence", c["confidence"])

        diag = c.get("diagnostics", {})
        diag_type = diag.get("failure_type", "")
        diag_badge_colors = {
            "verified": "#22c55e", "hallucination": "#ef4444", "possible_hallucination": "#f97316",
            "no_evidence": "#6b7280", "retrieval_gap": "#8b5cf6", "inconclusive": "#f59e0b",
        }
        badge_color = diag_badge_colors.get(diag_type, "#64748b")
        suggestion = diag.get("suggestion", "")

        details_parts.append(f"""
        <div style="padding: 12px; margin: 8px 0; border-left: 4px solid {color};
                    background: {color}10; border-radius: 0 8px 8px 0; font-family: Inter, sans-serif;">
            <div style="font-weight: 600; font-size: 15px; color: #f1f5f9;">
                {emoji} {c["claim"]}
                <span style="font-size: 11px; padding: 2px 8px; border-radius: 12px;
                             background: {badge_color}30; color: {badge_color}; margin-left: 8px;">
                    {diag_type.replace("_", " ")}
                </span>
            </div>
            <div style="font-size: 13px; color: #94a3b8; margin-top: 4px;">
                {c.get("reason", "")}
            </div>
            {'<div style="font-size: 12px; color: #f59e0b; margin-top: 4px; font-style: italic;">💡 ' + suggestion + '</div>' if suggestion else ''}
            <div style="font-size: 12px; color: #64748b; margin-top: 4px;">
                Confidence: {conf:.1%}
                {"• Evidence score: " + f"{c['evidence_score']:.3f}" if c.get("evidence_score") else ""}
                • Retrieval: {diag.get("retrieval_quality", "n/a")}
            </div>
        </div>
        """)

    details_html = '<div>' + ''.join(details_parts) + '</div>'

    # 3. Summary card
    s = result["summary"]
    summary_html = f"""
    <div style="font-family: Inter, sans-serif; padding: 16px; border-radius: 12px;
                background: linear-gradient(135deg, #1e293b, #334155); color: white;">
        <h3 style="margin: 0 0 12px 0; color: #e2e8f0;">📊 Summary</h3>
        <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 8px;">
            <div style="padding: 8px; background: #ffffff10; border-radius: 8px;">
                <div style="font-size: 24px; font-weight: bold;">{s['total_claims']}</div>
                <div style="font-size: 12px; color: #94a3b8;">Total Claims</div>
            </div>
            <div style="padding: 8px; background: #22c55e20; border-radius: 8px;">
                <div style="font-size: 24px; font-weight: bold; color: #22c55e;">{s['supported']}</div>
                <div style="font-size: 12px; color: #94a3b8;">Supported</div>
            </div>
            <div style="padding: 8px; background: #ef444420; border-radius: 8px;">
                <div style="font-size: 24px; font-weight: bold; color: #ef4444;">{s['contradicted']}</div>
                <div style="font-size: 12px; color: #94a3b8;">Contradicted</div>
            </div>
            <div style="padding: 8px; background: #f59e0b20; border-radius: 8px;">
                <div style="font-size: 24px; font-weight: bold; color: #f59e0b;">{s['unverifiable']}</div>
                <div style="font-size: 12px; color: #94a3b8;">Unverifiable</div>
            </div>
        </div>
        <div style="margin-top: 12px; padding: 8px; background: #ffffff10; border-radius: 8px; text-align: center;">
            <span style="font-size: 14px; color: #94a3b8;">Hallucination Rate</span><br>
            <span style="font-size: 28px; font-weight: bold;
                  color: {'#22c55e' if s['hallucination_rate'] < 0.3 else '#ef4444'};">
                {s['hallucination_rate']:.0%}
            </span>
        </div>
        <div style="margin-top: 8px; font-size: 11px; color: #64748b; text-align: right;">
            ⏱ {result['pipeline_time_seconds']:.1f}s
            {'• 📐 calibrated' if result.get('calibrated') else '• raw scores'}
        </div>
    </div>
    """

    # 4. Raw JSON
    json_output = json.dumps(result, indent=2, ensure_ascii=False)

    return highlighted_html, details_html, summary_html, json_output


# ── Gradio Interface ─────────────────────────────────────────────────────────

with gr.Blocks(
    title="FactEval – Hallucination Detector",
    theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
    css="""
    .gradio-container { max-width: 960px !important; }
    footer { display: none !important; }
    """,
) as demo:
    gr.Markdown(
        """
        # 🔍 FactEval – Find Exactly Which Parts Are Hallucinated
        Paste an LLM-generated answer and reference contexts.
        FactEval highlights ✅ **supported**, ❌ **contradicted**, and ❓ **unverifiable** claims.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            answer_input = gr.Textbox(
                label="LLM Answer",
                placeholder="Enter the text to fact-check...",
                lines=4,
            )
            context_input = gr.Textbox(
                label="Reference Contexts (one per line)",
                placeholder="Enter ground truth passages, one per line...",
                lines=5,
            )
            calibrator_input = gr.Textbox(
                label="Calibrator Path (optional)",
                placeholder="Path to calibrator.pkl",
                lines=1,
            )
            check_btn = gr.Button("🔍 Check Factuality", variant="primary", size="lg")

    gr.Markdown("### 📝 Highlighted Answer")
    highlighted_output = gr.HTML()

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📋 Claim Details")
            details_output = gr.HTML()
        with gr.Column(scale=1):
            summary_output = gr.HTML()

    with gr.Accordion("Raw JSON Output", open=False):
        json_output = gr.Code(language="json")

    check_btn.click(
        fn=run_check,
        inputs=[answer_input, context_input, calibrator_input],
        outputs=[highlighted_output, details_output, summary_output, json_output],
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[answer_input, context_input],
        label="Try these examples",
    )

if __name__ == "__main__":
    demo.launch(share=True)
facteval/__init__.py ADDED
@@ -0,0 +1,61 @@
"""FactEval – Find exactly which parts of your LLM output are hallucinated."""

# Suppress known harmless warnings from dependencies before any imports
import os as _os
import sys as _sys
import warnings as _warnings
import logging as _logging
import contextlib as _contextlib
import io as _io

# Suppress safetensors / accelerate noise
_os.environ.setdefault("SAFETENSORS_LOG_LEVEL", "error")
_os.environ.setdefault("ACCELERATE_LOG_LEVEL", "error")
_logging.getLogger("safetensors").setLevel(_logging.ERROR)
_logging.getLogger("accelerate").setLevel(_logging.ERROR)

# Suppress HF Hub unauthenticated request warnings
_logging.getLogger("huggingface_hub.utils._http").setLevel(_logging.ERROR)
_logging.getLogger("huggingface_hub").setLevel(_logging.ERROR)

# Suppress transformers info-level noise
_logging.getLogger("transformers.modeling_utils").setLevel(_logging.ERROR)
_logging.getLogger("transformers.generation.configuration_utils").setLevel(_logging.ERROR)

# Suppress FutureWarning about clean_up_tokenization_spaces
_warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")


@_contextlib.contextmanager
def suppress_loading_noise():
    """Suppress stdout + stderr noise during model loading (LOAD REPORT, sharding info)."""
    old_stdout, old_stderr = _sys.stdout, _sys.stderr
    _sys.stdout = _io.StringIO()
    _sys.stderr = _io.StringIO()
    try:
        yield
    finally:
        _sys.stdout = old_stdout
        _sys.stderr = old_stderr


# Backward compat alias
suppress_stdout = suppress_loading_noise


# ── Public API ───────────────────────────────────────────────────────────────
from facteval.core import check, verify
from facteval.models import Claim, Evidence, ClaimWithEvidence
from facteval.verifier import FactLabel, VerificationResult

__version__ = "0.1.0"
__all__ = [
    "check",
    "verify",
    "Claim",
    "Evidence",
    "ClaimWithEvidence",
    "FactLabel",
    "VerificationResult",
    "suppress_loading_noise",
]
facteval/calibrator.py ADDED
@@ -0,0 +1,90 @@
"""
Calibrator – Transforms raw NLI scores into calibrated probabilities.

Uses isotonic regression models (fitted in Week 0) to produce trustworthy
confidence scores and calibration error estimates.

Falls back gracefully to raw scores if no calibrator file is available.
"""

import logging
import pickle
from pathlib import Path

logger = logging.getLogger(__name__)


class Calibrator:
    """Apply isotonic regression calibration to raw NLI probabilities."""

    def __init__(self, calibrator_path: str | Path | None = None):
        """
        Load a pre-fitted calibrator from a pickle file.

        Args:
            calibrator_path: Path to the pickle file containing a dict of
                {label_name: IsotonicRegression} objects.
                If None or file doesn't exist, falls back to raw scores.
        """
        self._calibrators: dict | None = None

        if calibrator_path is not None:
            path = Path(calibrator_path)
            if path.exists():
                with open(path, "rb") as f:
                    self._calibrators = pickle.load(f)
                logger.info(
                    "Loaded calibrator from %s (labels: %s)",
                    path, list(self._calibrators.keys()),
                )
            else:
                logger.warning("Calibrator file not found: %s. Using raw scores.", path)

    @property
    def is_calibrated(self) -> bool:
        """Whether a calibrator is loaded."""
        return self._calibrators is not None

    def calibrate(self, raw_scores: dict[str, float]) -> tuple[float, float]:
        """
        Calibrate raw NLI probabilities.

        Args:
            raw_scores: Dict mapping label names to raw probabilities
                (e.g. {"entailment": 0.95, "neutral": 0.03, "contradiction": 0.02}).

        Returns:
            (calibrated_confidence, calibration_error)
            - calibrated_confidence: The calibrated probability for the predicted label.
            - calibration_error: Absolute difference between raw and calibrated confidence.
        """
        if not raw_scores:
            return 0.0, 0.0

        # Find the predicted label (highest raw score)
        predicted_label = max(raw_scores, key=raw_scores.get)
        raw_confidence = raw_scores[predicted_label]

        if not self.is_calibrated:
            # Fallback: return raw confidence with an estimated error
            return raw_confidence, self._estimate_error(raw_confidence)

        # Apply isotonic regression for each label
        calibrated_scores = {}
        for label, raw_prob in raw_scores.items():
            if label in self._calibrators:
                cal_prob = float(self._calibrators[label].predict([[raw_prob]])[0])
                calibrated_scores[label] = max(0.0, min(1.0, cal_prob))
            else:
                calibrated_scores[label] = raw_prob

        calibrated_confidence = calibrated_scores.get(predicted_label, raw_confidence)
        calibration_error = abs(raw_confidence - calibrated_confidence)

        return round(calibrated_confidence, 4), round(calibration_error, 4)

    @staticmethod
    def _estimate_error(raw_confidence: float) -> float:
        """Rough error estimate when no calibrator is available."""
        # Higher confidence → lower estimated error, but never zero
        return round(max(0.02, (1.0 - raw_confidence) * 0.3), 4)
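The pickle format above (a dict of per-label IsotonicRegression objects) is expected but not produced in this commit; a minimal sketch, assuming scikit-learn, of how such a calibrator.pkl could be fitted (the toy arrays are illustrative placeholders for real validation data):

# Fit one IsotonicRegression per NLI label and save them in the format Calibrator loads.
import pickle
from sklearn.isotonic import IsotonicRegression

# Per-label validation data you would collect yourself:
# raw NLI probability for that label, and whether that label was actually correct (1/0).
raw_probs = {"entailment": [0.9, 0.6, 0.3], "contradiction": [0.8, 0.4, 0.2], "neutral": [0.7, 0.5, 0.1]}
correct   = {"entailment": [1, 1, 0],       "contradiction": [1, 0, 0],       "neutral": [1, 0, 0]}

calibrators = {
    label: IsotonicRegression(out_of_bounds="clip").fit(raw_probs[label], correct[label])
    for label in raw_probs
}
with open("calibrator.pkl", "wb") as f:
    pickle.dump(calibrators, f)   # pass this path as calibrator_path=...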
facteval/claim_extractor.py ADDED
@@ -0,0 +1,138 @@
"""
Claim Extractor – Breaks text into atomic, verifiable claims.

Uses Qwen2.5-1.5B-Instruct (chosen in Week 0 for speed and output quality)
with the model's chat template to produce clean numbered lists.
"""

import re
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from facteval import suppress_stdout

from facteval.config import (
    CLAIM_MODEL,
    CLAIM_SYSTEM_PROMPT,
    CLAIM_USER_PROMPT,
    MAX_CLAIMS,
    MAX_NEW_TOKENS,
)
from facteval.models import Claim

logger = logging.getLogger(__name__)


class ClaimExtractor:
    """Extract atomic claims from text using a causal LM with chat prompting."""

    def __init__(
        self,
        model_name: str = CLAIM_MODEL,
        device: str | None = None,
        dtype: torch.dtype | None = None,
    ):
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype or (torch.float16 if self.device == "cuda" else torch.float32)

        logger.info("Loading claim extractor: %s on %s", model_name, self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        with suppress_stdout():
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                dtype=self.dtype,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True,
            )
        if self.device == "cpu":
            self.model = self.model.to(self.device)
        self.model.eval()

        # Clear sampling params from generation_config to avoid
        # "generation flags are not valid" warnings with do_sample=False
        gen_cfg = self.model.generation_config
        for attr in ("temperature", "top_p", "top_k"):
            if hasattr(gen_cfg, attr):
                setattr(gen_cfg, attr, None)

        logger.info("Claim extractor ready.")

    def extract(
        self,
        text: str,
        max_claims: int = MAX_CLAIMS,
        max_new_tokens: int = MAX_NEW_TOKENS,
    ) -> list[Claim]:
        """
        Extract atomic claims from *text*.

        Args:
            text: The text to decompose into claims.
            max_claims: Maximum number of claims to return.
            max_new_tokens: Generation length cap (prevents rambling).

        Returns:
            A deduplicated list of Claim objects.
        """
        if not text or not text.strip():
            return []

        raw_output = self._generate(text, max_new_tokens)
        claims = self._parse_claims(raw_output, text, max_claims)
        logger.info("Extracted %d claims from %d-char text.", len(claims), len(text))
        return claims

    # ── Private helpers ──────────────────────────────────────────────────────

    def _generate(self, text: str, max_new_tokens: int) -> str:
        """Run the LLM to generate claim text."""
        messages = [
            {"role": "system", "content": CLAIM_SYSTEM_PROMPT},
            {"role": "user", "content": CLAIM_USER_PROMPT.format(text=text)},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # Decode only the newly generated tokens
        generated = output_ids[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True).strip()

    @staticmethod
    def _parse_claims(
        raw: str, source_text: str, max_claims: int
    ) -> list[Claim]:
        """Parse numbered/bulleted list into deduplicated Claim objects."""
        seen: set[str] = set()
        claims: list[Claim] = []

        for line in raw.split("\n"):
            # Strip numbering (e.g. "1.", "1)", "- ", "• ")
            cleaned = re.sub(r"^[\d.\)\-•\s]+", "", line).strip()
            if len(cleaned) <= 5:
                continue

            # Normalize for dedup (lowercase, collapse whitespace)
            key = re.sub(r"\s+", " ", cleaned.lower())
            if key in seen:
                continue
            seen.add(key)

            claims.append(Claim(text=cleaned, source_text=source_text))
            if len(claims) >= max_claims:
                break

        return claims
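A minimal standalone sketch of the extractor described above (downloads Qwen2.5-1.5B-Instruct on first run; the input string is illustrative):

from facteval.claim_extractor import ClaimExtractor

extractor = ClaimExtractor()   # loads Qwen2.5-1.5B-Instruct on CPU or GPU
claims = extractor.extract("Paris is the capital of Germany and has 5 million people.")
for claim in claims:
    print("-", claim.text)     # one atomic, deduplicated claim per line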
facteval/cli.py ADDED
@@ -0,0 +1,127 @@
"""
CLI – Command-line interface for FactEval.

Usage:
    facteval check input.json
    facteval check input.json --output output.json
    facteval check --answer "..." --context "ctx1" --context "ctx2"
"""

import argparse
import json
import sys
import logging


def main():
    """Entry point for the facteval CLI."""
    parser = argparse.ArgumentParser(
        prog="facteval",
        description="FactEval – Claim-level factuality evaluation with calibrated confidence.",
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # ── facteval check ───────────────────────────────────────────────────────
    check_parser = subparsers.add_parser(
        "check", help="Check an answer for factual accuracy against provided contexts."
    )
    check_parser.add_argument(
        "input_file", nargs="?", default=None,
        help='JSON file with "answer" and "contexts" keys.',
    )
    check_parser.add_argument(
        "--answer", "-a", type=str, default=None,
        help="The answer text to check (alternative to input file).",
    )
    check_parser.add_argument(
        "--context", "-c", action="append", default=None,
        help="Context passage (can be repeated). Alternative to input file.",
    )
    check_parser.add_argument(
        "--output", "-o", type=str, default=None,
        help="Output file path. If not provided, prints to stdout.",
    )
    check_parser.add_argument(
        "--calibrator", type=str, default=None,
        help="Path to a pre-fitted calibrator pickle file.",
    )
    check_parser.add_argument(
        "--top-k", type=int, default=3,
        help="Number of evidence sentences to retrieve per claim (default: 3).",
    )
    check_parser.add_argument(
        "--max-claims", type=int, default=10,
        help="Maximum number of claims to extract (default: 10).",
    )
    check_parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Enable verbose logging.",
    )

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(0)

    if args.command == "check":
        _run_check(args)


def _run_check(args):
    """Execute the check command."""
    # Configure logging
    level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=level, format="%(name)s | %(message)s")

    # Parse input
    answer, contexts = _parse_input(args)
    if answer is None:
        print("Error: Provide either an input JSON file or --answer + --context flags.", file=sys.stderr)
        sys.exit(1)

    # Import here to avoid slow import on --help
    from facteval.core import check

    # Run pipeline
    result = check(
        answer=answer,
        contexts=contexts,
        top_k=args.top_k,
        max_claims=args.max_claims,
        calibrator_path=args.calibrator,
    )

    # Output
    output_json = json.dumps(result, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output_json)
        print(f"Results saved to {args.output}", file=sys.stderr)
    else:
        print(output_json)


def _parse_input(args) -> tuple[str | None, list[str]]:
    """Parse answer and contexts from file or CLI flags."""
    # Option 1: JSON file
    if args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return data.get("answer"), data.get("contexts", [])

    # Option 2: CLI flags
    if args.answer:
        return args.answer, args.context or []

    # Option 3: stdin
    if not sys.stdin.isatty():
        data = json.load(sys.stdin)
        return data.get("answer"), data.get("contexts", [])

    return None, []


if __name__ == "__main__":
    main()
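A sketch of the input file the check subcommand reads, driven from Python (assumes the facteval entry point is installed; the payload strings are illustrative):

# Write the JSON that `facteval check` expects ("answer" + "contexts"), then invoke the CLI.
import json
import subprocess

payload = {
    "answer": "Python was first released in 2005.",        # text to fact-check
    "contexts": ["Python was first released in 1991."],    # reference passages
}
with open("input.json", "w", encoding="utf-8") as f:
    json.dump(payload, f)

# Equivalent to running: facteval check input.json --output output.json
subprocess.run(["facteval", "check", "input.json", "--output", "output.json"], check=True)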
facteval/config.py ADDED
@@ -0,0 +1,29 @@
"""
Default configuration for FactEval models and parameters.
"""

# ── Model IDs (Hugging Face Hub) ─────────────────────────────────────────────
# Claim extraction – chosen in Week 0: 1.5B was 3.5x faster with cleaner output
CLAIM_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

# Sentence embeddings for evidence retrieval
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# NLI verification (used in Week 2)
NLI_MODEL = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"

# ── Retrieval defaults ───────────────────────────────────────────────────────
DEFAULT_TOP_K = 3
MIN_EVIDENCE_SCORE = 0.3  # Below this, evidence is too weak to use

# ── Claim extraction defaults ────────────────────────────────────────────────
MAX_NEW_TOKENS = 200
MAX_CLAIMS = 10

CLAIM_SYSTEM_PROMPT = (
    "You are a claim extraction engine. Given a text, break it into atomic, "
    "independently verifiable claims. Each claim states exactly ONE fact. "
    "Return ONLY a numbered list. No explanations, no commentary."
)

CLAIM_USER_PROMPT = "Break this into atomic claims:\n\n{text}"
facteval/core.py ADDED
@@ -0,0 +1,328 @@
"""
Core – The main check() and verify() functions that wire the FactEval pipeline.

Usage:
    from facteval import check, verify

    # Full pipeline (extract + retrieve + verify)
    result = check(answer, contexts)

    # Lightweight mode (skip extraction, bring your own claims)
    result = verify(claims=["claim 1", "claim 2"], contexts=docs)
"""

import re
import logging
import time
from pathlib import Path

import numpy as np

from facteval.calibrator import Calibrator
from facteval.claim_extractor import ClaimExtractor
from facteval.retriever import EvidenceRetriever
from facteval.verifier import Verifier, FactLabel
from facteval.models import Claim

logger = logging.getLogger(__name__)

# Module-level singletons (lazy-loaded)
_extractor: ClaimExtractor | None = None
_retriever: EvidenceRetriever | None = None
_verifier: Verifier | None = None
_calibrator: Calibrator | None = None
_calibrator_path: str | None = None


def _get_extractor() -> ClaimExtractor:
    global _extractor
    if _extractor is None:
        _extractor = ClaimExtractor()
    return _extractor


def _get_retriever() -> EvidenceRetriever:
    global _retriever
    if _retriever is None:
        _retriever = EvidenceRetriever()
    return _retriever


def _get_verifier() -> Verifier:
    global _verifier
    if _verifier is None:
        _verifier = Verifier()
    return _verifier


def _get_calibrator(path: str | None = None) -> Calibrator:
    global _calibrator, _calibrator_path
    if _calibrator is None or path != _calibrator_path:
        _calibrator = Calibrator(calibrator_path=path)
        _calibrator_path = path
    return _calibrator


# ── Full pipeline ────────────────────────────────────────────────────────────

def check(
    answer: str,
    contexts: list[str],
    top_k: int = 3,
    max_claims: int = 10,
    calibrator_path: str | Path | None = None,
) -> dict:
    """
    Run the full FactEval pipeline on an answer + contexts.

    Stages: extract claims → retrieve evidence → NLI verify → calibrate.

    Args:
        answer: The LLM-generated text to evaluate.
        contexts: List of reference passages (ground truth).
        top_k: Number of evidence sentences to retrieve per claim.
        max_claims: Maximum claims to extract.
        calibrator_path: Path to a pre-fitted calibrator pickle file.

    Returns:
        A dict with claims, summary, highlighted_answer, and pipeline_time.
    """
    t0 = time.perf_counter()

    # 1. Extract claims
    extractor = _get_extractor()
    claims = extractor.extract(answer, max_claims=max_claims)
    logger.info("Extracted %d claims.", len(claims))

    if not claims:
        return {
            "claims": [],
            "summary": _build_summary([]),
            "highlighted_answer": answer,
            "calibrated": False,
            "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
        }

    # 2–5. Shared pipeline
    return _run_pipeline(claims, contexts, answer, top_k, calibrator_path, t0)


# ── Lightweight mode ─────────────────────────────────────────────────────────

def verify(
    claims: list[str],
    contexts: list[str],
    top_k: int = 3,
    calibrator_path: str | Path | None = None,
) -> dict:
    """
    Verify pre-extracted claims against contexts. Skips claim extraction.

    Use this when you already have claims and want faster results
    (avoids the ~1s extraction step and the Qwen model entirely).

    Args:
        claims: List of claim strings to verify.
        contexts: List of reference passages (ground truth).
        top_k: Number of evidence sentences to retrieve per claim.
        calibrator_path: Path to a pre-fitted calibrator pickle file.

    Returns:
        Same output format as check().
    """
    t0 = time.perf_counter()

    claim_objs = [Claim(text=c) for c in claims if c.strip()]

    if not claim_objs:
        return {
            "claims": [],
            "summary": _build_summary([]),
            "highlighted_answer": "",
            "calibrated": False,
            "pipeline_time_seconds": round(time.perf_counter() - t0, 3),
        }

    answer = " ".join(claims)  # reconstruct for highlighting
    return _run_pipeline(claim_objs, contexts, answer, top_k, calibrator_path, t0)


# ── Shared pipeline ──────────────────────────────────────────────────────────

def _run_pipeline(
    claims: list[Claim],
    contexts: list[str],
    answer: str,
    top_k: int,
    calibrator_path: str | Path | None,
    t0: float,
) -> dict:
    """Shared pipeline: retrieve → verify → calibrate → diagnose → highlight."""

    # 2. Retrieve evidence
    retriever = _get_retriever()
    retriever.index(contexts)
    claims_with_evidence = retriever.retrieve_for_claims(claims, top_k=top_k)

    # 3. Verify (batch NLI)
    verifier = _get_verifier()
    results = verifier.verify_batch(claims_with_evidence)

    # 4. Calibrate
    calibrator = _get_calibrator(str(calibrator_path) if calibrator_path else None)
    for r in results:
        if r.raw_scores:
            cal_conf, cal_err = calibrator.calibrate(r.raw_scores)
            r.calibrated_confidence = cal_conf
            r.calibration_error = cal_err

    # 5. Build output with diagnostics
    elapsed = time.perf_counter() - t0
    claim_dicts = [r.to_dict() for r in results]

    # Add diagnostics to each claim
    for cd in claim_dicts:
        cd["diagnostics"] = _diagnose(cd)

    return {
        "claims": claim_dicts,
        "summary": _build_summary(results),
        "highlighted_answer": _highlight_answer_semantic(
            answer, claim_dicts, retriever.embedder
        ),
        "calibrated": calibrator.is_calibrated,
        "pipeline_time_seconds": round(elapsed, 3),
    }


# ── Diagnostics ──────────────────────────────────────────────────────────────

def _diagnose(claim_dict: dict) -> dict:
    """
    Generate pipeline diagnostics for a claim.

    Tells the developer *why* a claim got its label —
    was it a retrieval failure or a genuine hallucination?
    """
    label = claim_dict["label"]
    ev_score = claim_dict.get("evidence_score")
    confidence = claim_dict.get("confidence", 0)

    # Retrieval quality assessment
    if ev_score is None:
        retrieval_quality = "none"
    elif ev_score >= 0.7:
        retrieval_quality = "strong"
    elif ev_score >= 0.4:
        retrieval_quality = "moderate"
    else:
        retrieval_quality = "weak"

    # Failure type classification
    if label == "supported":
        failure_type = "verified"
        suggestion = None
    elif label == "contradicted":
        if retrieval_quality in ("strong", "moderate"):
            failure_type = "hallucination"
            suggestion = "Claim directly contradicts the evidence. This is a factual error in the LLM output."
        else:
            failure_type = "possible_hallucination"
            suggestion = "Claim contradicts weak evidence. Consider adding more specific context for reliable verification."
    elif ev_score is None:
        failure_type = "no_evidence"
        suggestion = "No relevant context was provided. Add reference passages covering this topic."
    elif retrieval_quality == "weak":
        failure_type = "retrieval_gap"
        suggestion = "Evidence was found but too dissimilar to trust. The context may not cover this claim."
    else:
        failure_type = "inconclusive"
        suggestion = "Evidence exists but is neutral — neither confirms nor denies the claim."

    d = {
        "failure_type": failure_type,
        "retrieval_quality": retrieval_quality,
    }
    if suggestion:
        d["suggestion"] = suggestion
    return d


# ── Summary ──────────────────────────────────────────────────────────────────

def _build_summary(results: list) -> dict:
    """Build summary statistics from verification results."""
    total = len(results)
    supported = sum(1 for r in results if r.label == FactLabel.SUPPORTED)
    contradicted = sum(1 for r in results if r.label == FactLabel.CONTRADICTED)
    unverifiable = total - supported - contradicted

    return {
        "total_claims": total,
        "supported": supported,
        "contradicted": contradicted,
        "unverifiable": unverifiable,
        "hallucination_rate": round(contradicted / max(total, 1), 4),
    }


# ── Semantic Highlighting ────────────────────────────────────────────────────

_LABEL_EMOJI = {"supported": "✅", "contradicted": "❌", "unverifiable": "❓"}
_LABEL_COLOR = {"supported": "#22c55e", "contradicted": "#ef4444", "unverifiable": "#f59e0b"}


def _highlight_answer_semantic(answer: str, claim_dicts: list[dict], embedder) -> str:
    """
    Map claims to source sentences using embedding similarity (not Jaccard).

    Uses the retriever's SentenceTransformer to compute cosine similarity
    between each claim and each sentence in the original answer. This handles
    paraphrasing, reordering, and partial overlaps much better than token overlap.
    """
    if not answer.strip() or not claim_dicts:
        return answer

    # Split answer into sentences with positions
    sentences = []
    for m in re.finditer(r'[^.!?]+[.!?]*', answer):
        text = m.group().strip()
        if text:
            sentences.append(text)

    if not sentences:
        return answer

    # Compute embedding similarity
    claim_texts = [c["claim"] for c in claim_dicts]
    claim_labels = [c["label"] for c in claim_dicts]

    sent_embeddings = embedder.encode(sentences, normalize_embeddings=True)
    claim_embeddings = embedder.encode(claim_texts, normalize_embeddings=True)

    # Similarity matrix: sentences × claims
    sim_matrix = np.dot(sent_embeddings, claim_embeddings.T)

    # For each sentence, find best matching claim
    sentence_labels: dict[str, str] = {}
    for i, sent_text in enumerate(sentences):
        best_j = int(sim_matrix[i].argmax())
        best_sim = float(sim_matrix[i, best_j])

        if best_sim > 0.35:  # Semantic similarity threshold
            sentence_labels[sent_text] = claim_labels[best_j]

    # Build highlighted text (longest matches first to avoid partial replacements)
    highlighted = answer
    for sent_text in sorted(sentence_labels, key=len, reverse=True):
        label = sentence_labels[sent_text]
        color = _LABEL_COLOR.get(label, "#94a3b8")
        emoji = _LABEL_EMOJI.get(label, "")
        highlighted = highlighted.replace(
            sent_text,
            f'<mark style="background:{color}30;padding:2px 4px;border-radius:3px">'
            f'{sent_text} {emoji}</mark>',
            1,
        )

    return highlighted
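A short sketch of the lightweight mode described in the verify() docstring above (claims and contexts are illustrative; the keys read here come from the dict built in _run_pipeline):

from facteval import verify

result = verify(
    claims=["The Eiffel Tower is in Berlin.", "The Eiffel Tower is in Paris."],
    contexts=["The Eiffel Tower is a wrought-iron tower in Paris, France."],
)

print(result["summary"])                  # counts plus hallucination_rate
print(result["calibrated"])               # False unless a calibrator.pkl path was passed
print(result["pipeline_time_seconds"])    # wall-clock time for retrieve + verify + calibrate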
facteval/models.py ADDED
@@ -0,0 +1,46 @@
"""
Pydantic data models for FactEval's pipeline objects.
"""

from pydantic import BaseModel, Field


class Claim(BaseModel):
    """A single atomic, verifiable claim extracted from text."""

    text: str = Field(..., description="The claim statement.")
    source_text: str = Field(
        default="",
        description="The original text this claim was extracted from.",
    )

    def __str__(self) -> str:
        return self.text


class Evidence(BaseModel):
    """A single piece of evidence retrieved for a claim."""

    sentence: str = Field(..., description="The evidence sentence.")
    score: float = Field(
        ..., ge=0.0, description="Cosine similarity score (may slightly exceed 1.0 due to float precision)."
    )
    source_context: str = Field(
        default="",
        description="The full context passage this sentence came from.",
    )

    def __str__(self) -> str:
        return f"[{self.score:.3f}] {self.sentence}"


class ClaimWithEvidence(BaseModel):
    """A claim paired with its retrieved evidence."""

    claim: Claim
    evidence: list[Evidence] = Field(default_factory=list)

    @property
    def best_evidence(self) -> Evidence | None:
        """Return the highest-scoring evidence, or None."""
        return self.evidence[0] if self.evidence else None
facteval/retriever.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evidence Retriever – FAISS-based semantic search over user-provided contexts.
|
| 3 |
+
|
| 4 |
+
Encodes context sentences with all-MiniLM-L6-v2 and retrieves the top-k
|
| 5 |
+
most similar evidence sentences for each claim.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import faiss
|
| 13 |
+
from sentence_transformers import SentenceTransformer
|
| 14 |
+
|
| 15 |
+
from facteval import suppress_stdout
|
| 16 |
+
from facteval.config import DEFAULT_TOP_K, EMBEDDING_MODEL, MIN_EVIDENCE_SCORE
|
| 17 |
+
from facteval.models import Claim, Evidence, ClaimWithEvidence
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class EvidenceRetriever:
|
| 23 |
+
"""Build a FAISS index over context sentences and retrieve evidence for claims."""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
model_name: str = EMBEDDING_MODEL,
|
| 28 |
+
device: str | None = None,
|
| 29 |
+
):
|
| 30 |
+
self.device = device or ("cuda" if __import__("torch").cuda.is_available() else "cpu")
|
| 31 |
+
logger.info("Loading embedding model: %s", model_name)
|
| 32 |
+
with suppress_stdout():
|
| 33 |
+
self.embedder = SentenceTransformer(model_name, device=self.device)
|
| 34 |
+
|
| 35 |
+
# Populated by .index()
|
| 36 |
+
self._sentences: list[str] = []
|
| 37 |
+
self._sentence_to_context: dict[int, str] = {}
|
| 38 |
+
        self._index: faiss.IndexFlatIP | None = None

    def index(self, contexts: list[str]) -> "EvidenceRetriever":
        """
        Build a FAISS index from a list of context passages.

        Each context is split into individual sentences before indexing.

        Args:
            contexts: List of context passages (strings).

        Returns:
            self (for chaining: `retriever.index(ctx).retrieve(claim)`).
        """
        if not contexts:
            logger.warning("No contexts provided; retriever will return empty results.")
            self._sentences = []
            self._index = None
            return self

        self._sentences = []
        self._sentence_to_context = {}

        for ctx in contexts:
            for sent in self._split_sentences(ctx):
                idx = len(self._sentences)
                self._sentences.append(sent)
                self._sentence_to_context[idx] = ctx

        if not self._sentences:
            logger.warning("No sentences extracted from contexts.")
            self._index = None
            return self

        logger.info("Indexing %d evidence sentences.", len(self._sentences))
        embeddings = self.embedder.encode(
            self._sentences, convert_to_numpy=True, normalize_embeddings=True
        ).astype(np.float32)

        dim = embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dim)  # Cosine similarity (normalized)
        self._index.add(embeddings)

        return self

    def retrieve(
        self,
        claim: Claim | str,
        top_k: int = DEFAULT_TOP_K,
        min_score: float = MIN_EVIDENCE_SCORE,
    ) -> list[Evidence]:
        """
        Retrieve the top-k most relevant evidence sentences for a claim.

        Args:
            claim: A Claim object or plain string.
            top_k: Number of evidence sentences to return.
            min_score: Minimum cosine similarity to include.

        Returns:
            List of Evidence objects, sorted by score descending.
        """
        if self._index is None or not self._sentences:
            return []

        query_text = claim.text if isinstance(claim, Claim) else claim
        q_emb = self.embedder.encode(
            [query_text], convert_to_numpy=True, normalize_embeddings=True
        ).astype(np.float32)

        scores, indices = self._index.search(q_emb, top_k)
        results: list[Evidence] = []

        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or idx >= len(self._sentences):
                continue
            clamped_score = float(min(max(score, 0.0), 1.0))
            if clamped_score < min_score:
                continue
            results.append(
                Evidence(
                    sentence=self._sentences[idx],
                    score=clamped_score,
                    source_context=self._sentence_to_context.get(idx, ""),
                )
            )

        return results

    def retrieve_for_claims(
        self,
        claims: list[Claim],
        top_k: int = DEFAULT_TOP_K,
        min_score: float = MIN_EVIDENCE_SCORE,
    ) -> list[ClaimWithEvidence]:
        """
        Batch-retrieve evidence for a list of claims.

        Returns:
            List of ClaimWithEvidence objects.
        """
        return [
            ClaimWithEvidence(
                claim=claim,
                evidence=self.retrieve(claim, top_k=top_k, min_score=min_score),
            )
            for claim in claims
        ]

    @staticmethod
    def _split_sentences(text: str) -> list[str]:
        """Split text into sentences on sentence-ending punctuation."""
        raw = re.split(r"(?<=[.!?])\s+", text)
        return [s.strip() for s in raw if s.strip() and len(s.strip()) > 3]
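For orientation, the retriever above is meant to be used as index-then-retrieve. The sketch below is illustrative only and not part of the committed files; it assumes `EvidenceRetriever()` can be constructed with its default embedding model and no required arguments.

# Illustrative usage sketch (assumption: EvidenceRetriever() takes no required args).
from facteval.retriever import EvidenceRetriever

contexts = [
    "Marie Curie won the Nobel Prize in Physics in 1903.",
    "The Eiffel Tower was completed in 1889 for the Paris World's Fair.",
]

# index() returns self, so indexing and retrieval can be chained.
retriever = EvidenceRetriever().index(contexts)
evidence = retriever.retrieve("Marie Curie won a Nobel Prize in 1903.", top_k=3)

for ev in evidence:
    print(f"{ev.score:.2f}  {ev.sentence}")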
facteval/verifier.py
ADDED
@@ -0,0 +1,235 @@
"""
Verifier – NLI-based factual verification of claims against evidence.

Uses DeBERTa-v3 fine-tuned on MNLI+FEVER+ANLI to classify each
claim/evidence pair as entailment, contradiction, or neutral.
Maps NLI labels to FactEval labels: supported, contradicted, unverifiable.
"""

import logging
from enum import Enum

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from facteval import suppress_stdout
from facteval.config import NLI_MODEL, MIN_EVIDENCE_SCORE
from facteval.models import Claim, Evidence, ClaimWithEvidence

logger = logging.getLogger(__name__)


class FactLabel(str, Enum):
    """FactEval verdict labels."""
    SUPPORTED = "supported"
    CONTRADICTED = "contradicted"
    UNVERIFIABLE = "unverifiable"


# Map DeBERTa NLI labels → FactEval labels
_NLI_TO_FACT = {
    "entailment": FactLabel.SUPPORTED,
    "contradiction": FactLabel.CONTRADICTED,
    "neutral": FactLabel.UNVERIFIABLE,
}


class VerificationResult:
    """Result of verifying a single claim."""

    def __init__(
        self,
        claim: str,
        label: FactLabel,
        confidence: float,
        evidence: str | None,
        evidence_score: float | None,
        raw_scores: dict[str, float],
        reason: str = "",
        calibrated_confidence: float | None = None,
        calibration_error: float | None = None,
    ):
        self.claim = claim
        self.label = label
        self.confidence = confidence
        self.evidence = evidence
        self.evidence_score = evidence_score
        self.raw_scores = raw_scores
        self.reason = reason
        self.calibrated_confidence = calibrated_confidence
        self.calibration_error = calibration_error

    def to_dict(self) -> dict:
        d = {
            "claim": self.claim,
            "label": self.label.value,
            "confidence": round(self.confidence, 4),
            "reason": self.reason,
            "evidence": self.evidence,
            "evidence_score": round(self.evidence_score, 4) if self.evidence_score else None,
            "raw_nli_scores": {k: round(v, 4) for k, v in self.raw_scores.items()},
        }
        if self.calibrated_confidence is not None:
            d["calibrated_confidence"] = round(self.calibrated_confidence, 4)
        if self.calibration_error is not None:
            d["calibration_error"] = round(self.calibration_error, 4)
        return d


class Verifier:
    """Verify claims against evidence using NLI."""

    def __init__(
        self,
        model_name: str = NLI_MODEL,
        device: str | None = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Loading NLI model: %s on %s", model_name, self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        with suppress_stdout():
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name
            ).to(self.device)
        self.model.eval()

        self.id2label = self.model.config.id2label
        logger.info("Verifier ready. Labels: %s", self.id2label)

    def verify(
        self,
        claim_with_evidence: ClaimWithEvidence,
        min_evidence_score: float = MIN_EVIDENCE_SCORE,
    ) -> VerificationResult:
        """
        Verify a single claim against its retrieved evidence.

        If no evidence meets the min_score threshold, returns 'unverifiable'
        with zero confidence.
        """
        claim_text = claim_with_evidence.claim.text
        best = claim_with_evidence.best_evidence

        # Fallback: no usable evidence
        if best is None or best.score < min_evidence_score:
            logger.debug("No evidence for claim: %s", claim_text)
            return VerificationResult(
                claim=claim_text,
                label=FactLabel.UNVERIFIABLE,
                confidence=0.0,
                evidence=None,
                evidence_score=None,
                raw_scores={},
                reason="No relevant evidence found in the provided context.",
            )

        # Run NLI: premise=evidence, hypothesis=claim
        return self._run_nli(claim_text, best.sentence, best.score)

    def verify_batch(
        self,
        claims_with_evidence: list[ClaimWithEvidence],
        min_evidence_score: float = MIN_EVIDENCE_SCORE,
    ) -> list[VerificationResult]:
        """
        Verify a batch of claims using batched NLI inference.

        Claims without evidence are immediately marked unverifiable.
        Remaining claims are processed in a single forward pass for speed.
        """
        results: list[VerificationResult | None] = [None] * len(claims_with_evidence)
        nli_pairs: list[tuple[int, str, str, float]] = []

        for i, cwe in enumerate(claims_with_evidence):
            claim_text = cwe.claim.text
            best = cwe.best_evidence

            if best is None or best.score < min_evidence_score:
                results[i] = VerificationResult(
                    claim=claim_text,
                    label=FactLabel.UNVERIFIABLE,
                    confidence=0.0,
                    evidence=None,
                    evidence_score=None,
                    raw_scores={},
                    reason="No relevant evidence found in the provided context.",
                )
            else:
                nli_pairs.append((i, claim_text, best.sentence, best.score))

        # Batch NLI inference for all claims with evidence
        if nli_pairs:
            indices, claims, evidences, scores = zip(*nli_pairs)
            inputs = self.tokenizer(
                list(evidences), list(claims),
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                logits = self.model(**inputs).logits

            all_probs = torch.softmax(logits, dim=-1).cpu()

            for idx, probs_t, claim, evidence, ev_score in zip(
                indices, all_probs, claims, evidences, scores
            ):
                probs = probs_t.tolist()
                label_probs = {self.id2label[i]: float(p) for i, p in enumerate(probs)}
                predicted_nli = self.id2label[probs_t.argmax().item()]
                fact_label = _NLI_TO_FACT.get(predicted_nli, FactLabel.UNVERIFIABLE)

                results[idx] = VerificationResult(
                    claim=claim,
                    label=fact_label,
                    confidence=max(probs),
                    evidence=evidence,
                    evidence_score=ev_score,
                    raw_scores=label_probs,
                    reason=self._make_reason(fact_label, evidence),
                )

        return results

    def _run_nli(
        self, claim: str, evidence: str, evidence_score: float
    ) -> VerificationResult:
        """Run NLI inference on a single claim/evidence pair."""
        inputs = self.tokenizer(
            evidence, claim,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()
        label_probs = {self.id2label[i]: float(p) for i, p in enumerate(probs)}
        predicted_nli = self.id2label[logits.argmax().item()]
        fact_label = _NLI_TO_FACT.get(predicted_nli, FactLabel.UNVERIFIABLE)

        return VerificationResult(
            claim=claim,
            label=fact_label,
            confidence=max(probs),
            evidence=evidence,
            evidence_score=evidence_score,
            raw_scores=label_probs,
            reason=self._make_reason(fact_label, evidence),
        )

    @staticmethod
    def _make_reason(label: FactLabel, evidence: str) -> str:
        """Generate a human-readable reason for the verdict."""
        ev_short = evidence[:80] + "..." if len(evidence) > 80 else evidence
        if label == FactLabel.SUPPORTED:
            return f"Supported by evidence: \"{ev_short}\""
        elif label == FactLabel.CONTRADICTED:
            return f"Contradicts evidence: \"{ev_short}\""
        else:
            return f"Evidence is neutral — neither confirms nor denies: \"{ev_short}\""
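As a quick end-to-end illustration (again not part of this commit), the verifier slots in directly after retrieval. The `Claim(text=...)` constructor is assumed from `facteval.models`, and the NLI model is downloaded on first use.

# Illustrative sketch: assumes Claim(text=...) exists and the retriever behaves as defined above.
from facteval.models import Claim
from facteval.retriever import EvidenceRetriever
from facteval.verifier import Verifier

contexts = ["The Great Wall of China is roughly 13,000 miles long."]
claims = [Claim(text="The Great Wall is about 13,000 miles long.")]

# Retrieve evidence for each claim, then verify the whole batch in one forward pass.
pairs = EvidenceRetriever().index(contexts).retrieve_for_claims(claims, top_k=3)

verifier = Verifier()  # loads the DeBERTa NLI model; uses GPU if available
for result in verifier.verify_batch(pairs):
    d = result.to_dict()
    print(d["label"], "|", d["confidence"], "|", d["reason"])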
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch>=2.0
transformers>=4.36
sentence-transformers>=2.2
faiss-cpu>=1.7
scikit-learn>=1.3
pydantic>=2.0
accelerate>=0.25
numpy>=1.24
gradio>=4.0
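Reproducing this environment locally should only require `pip install -r requirements.txt`; note that `faiss-cpu` keeps retrieval CPU-only, and the embedding and NLI models are downloaded on first run.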