File size: 8,370 Bytes
8fb73f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80ec89b
8fb73f8
80ec89b
8fb73f8
80ec89b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
Verifier – NLI-based factual verification of claims against evidence.

Uses DeBERTa-v3 fine-tuned on MNLI+FEVER+ANLI to classify each
claim/evidence pair as entailment, contradiction, or neutral.
Maps NLI labels to FactEval labels: supported, contradicted, unverifiable.
"""

import logging
from enum import Enum

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from facteval import suppress_stdout
from facteval.config import NLI_MODEL, MIN_EVIDENCE_SCORE
from facteval.models import Claim, Evidence, ClaimWithEvidence

# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)


class FactLabel(str, Enum):
    """Verdict that FactEval assigns to a single claim.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (e.g. ``FactLabel.SUPPORTED == "supported"``).
    """

    SUPPORTED = "supported"          # evidence entails the claim
    CONTRADICTED = "contradicted"    # evidence contradicts the claim
    UNVERIFIABLE = "unverifiable"    # no usable evidence, or evidence is neutral


# Translation table from the NLI model's label names (as reported in its
# config's ``id2label``) to FactEval verdicts. Any label not listed here is
# treated by callers as unverifiable.
_NLI_TO_FACT = dict(
    entailment=FactLabel.SUPPORTED,
    contradiction=FactLabel.CONTRADICTED,
    neutral=FactLabel.UNVERIFIABLE,
)


class VerificationResult:
    """Result of verifying a single claim."""

    def __init__(
        self,
        claim: str,
        label: FactLabel,
        confidence: float,
        evidence: str | None,
        evidence_score: float | None,
        raw_scores: dict[str, float],
        reason: str = "",
        calibrated_confidence: float | None = None,
        calibration_error: float | None = None,
    ):
        self.claim = claim
        self.label = label
        self.confidence = confidence
        self.evidence = evidence
        self.evidence_score = evidence_score
        self.raw_scores = raw_scores
        self.reason = reason
        self.calibrated_confidence = calibrated_confidence
        self.calibration_error = calibration_error

    def to_dict(self) -> dict:
        d = {
            "claim": self.claim,
            "label": self.label.value,
            "confidence": round(self.confidence, 4),
            "reason": self.reason,
            "evidence": self.evidence,
            "evidence_score": round(self.evidence_score, 4) if self.evidence_score else None,
            "raw_nli_scores": {k: round(v, 4) for k, v in self.raw_scores.items()},
        }
        if self.calibrated_confidence is not None:
            d["calibrated_confidence"] = round(self.calibrated_confidence, 4)
        if self.calibration_error is not None:
            d["calibration_error"] = round(self.calibration_error, 4)
        return d


class Verifier:
    """Verify claims against evidence using an NLI sequence classifier.

    The evidence sentence is the NLI premise and the claim is the
    hypothesis. The model's predicted NLI label is mapped to a
    ``FactLabel`` through ``_NLI_TO_FACT`` (case-insensitively); labels
    with no mapping become ``FactLabel.UNVERIFIABLE``.
    """

    # Token budget for one premise/hypothesis pair; longer pairs are truncated.
    _MAX_LENGTH = 512
    # Reason attached to claims that have no evidence above the threshold.
    _NO_EVIDENCE_REASON = "No relevant evidence found in the provided context."

    def __init__(
        self,
        model_name: str = NLI_MODEL,
        device: str | None = None,
    ):
        """Load the tokenizer and NLI model.

        Args:
            model_name: Hugging Face model id or local path of the NLI model.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Loading NLI model: %s on %s", model_name, self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # suppress_stdout presumably silences console noise printed while
        # the weights load.
        with suppress_stdout():
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name
            ).to(self.device)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

        self.id2label = self.model.config.id2label
        logger.info("Verifier ready. Labels: %s", self.id2label)

        # A label with no FactEval mapping would silently turn every claim
        # predicted with it into "unverifiable"; surface that at load time.
        unmapped = [
            name for name in self.id2label.values()
            if str(name).lower() not in _NLI_TO_FACT
        ]
        if unmapped:
            logger.warning(
                "NLI labels %s have no FactEval mapping; predictions with "
                "these labels will be reported as unverifiable.",
                unmapped,
            )

    def verify(
        self,
        claim_with_evidence: ClaimWithEvidence,
        min_evidence_score: float = MIN_EVIDENCE_SCORE,
    ) -> VerificationResult:
        """Verify a single claim against its best retrieved evidence.

        Args:
            claim_with_evidence: Claim plus its retrieved evidence; only
                ``best_evidence`` is used.
            min_evidence_score: Evidence scoring below this threshold is
                treated as absent.

        Returns:
            A ``VerificationResult``. If no evidence meets the threshold the
            label is ``UNVERIFIABLE`` with zero confidence.
        """
        claim_text = claim_with_evidence.claim.text
        best = claim_with_evidence.best_evidence

        if not self._has_usable_evidence(best, min_evidence_score):
            return self._no_evidence_result(claim_text)

        return self._run_nli(claim_text, best.sentence, best.score)

    def verify_batch(
        self,
        claims_with_evidence: list[ClaimWithEvidence],
        min_evidence_score: float = MIN_EVIDENCE_SCORE,
        *,
        batch_size: int = 32,
    ) -> list[VerificationResult]:
        """Verify many claims, batching NLI inference.

        Claims whose best evidence is missing or scores below
        ``min_evidence_score`` are marked unverifiable without running the
        model. The remaining claims are classified in forward passes of at
        most ``batch_size`` pairs, which bounds peak memory on long inputs.

        Args:
            claims_with_evidence: Claims with their retrieved evidence.
            min_evidence_score: Evidence scoring below this threshold is
                treated as absent.
            batch_size: Maximum number of claim/evidence pairs per forward
                pass.

        Returns:
            One ``VerificationResult`` per input, in input order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results: list[VerificationResult | None] = [None] * len(claims_with_evidence)
        # (input index, claim, evidence sentence, evidence score)
        pending: list[tuple[int, str, str, float]] = []

        for i, cwe in enumerate(claims_with_evidence):
            claim_text = cwe.claim.text
            best = cwe.best_evidence
            if self._has_usable_evidence(best, min_evidence_score):
                pending.append((i, claim_text, best.sentence, best.score))
            else:
                results[i] = self._no_evidence_result(claim_text)

        for start in range(0, len(pending), batch_size):
            chunk = pending[start:start + batch_size]
            indices, claims, evidences, scores = zip(*chunk)
            predictions = self._predict(list(claims), list(evidences), list(scores))
            for idx, result in zip(indices, predictions):
                results[idx] = result

        return results

    def _run_nli(
        self, claim: str, evidence: str, evidence_score: float
    ) -> VerificationResult:
        """Run NLI inference on a single claim/evidence pair."""
        return self._predict([claim], [evidence], [evidence_score])[0]

    def _predict(
        self,
        claims: list[str],
        evidences: list[str],
        evidence_scores: list[float],
    ) -> list[VerificationResult]:
        """Run one NLI forward pass over aligned claim/evidence/score lists.

        Returns one ``VerificationResult`` per pair, in input order.
        """
        # Premise = evidence, hypothesis = claim.
        inputs = self.tokenizer(
            evidences,
            claims,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._MAX_LENGTH,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        all_probs = torch.softmax(logits, dim=-1).cpu()

        results: list[VerificationResult] = []
        for probs_t, claim, evidence, ev_score in zip(
            all_probs, claims, evidences, evidence_scores
        ):
            probs = probs_t.tolist()
            label_probs = {self.id2label[j]: float(p) for j, p in enumerate(probs)}
            predicted_nli = self.id2label[probs_t.argmax().item()]
            fact_label = self._to_fact_label(predicted_nli)
            results.append(
                VerificationResult(
                    claim=claim,
                    label=fact_label,
                    confidence=max(probs),
                    evidence=evidence,
                    evidence_score=ev_score,
                    raw_scores=label_probs,
                    reason=self._make_reason(fact_label, evidence),
                )
            )
        return results

    @staticmethod
    def _has_usable_evidence(best: Evidence | None, min_evidence_score: float) -> bool:
        """Return True when ``best`` exists and is not below the threshold."""
        return not (best is None or best.score < min_evidence_score)

    def _no_evidence_result(self, claim_text: str) -> VerificationResult:
        """Build the zero-confidence 'unverifiable' result for a claim."""
        logger.debug("No evidence for claim: %s", claim_text)
        return VerificationResult(
            claim=claim_text,
            label=FactLabel.UNVERIFIABLE,
            confidence=0.0,
            evidence=None,
            evidence_score=None,
            raw_scores={},
            reason=self._NO_EVIDENCE_REASON,
        )

    @staticmethod
    def _to_fact_label(nli_label: str) -> FactLabel:
        """Map a model NLI label to a FactLabel, ignoring case.

        Unknown labels map to ``FactLabel.UNVERIFIABLE``.
        """
        return _NLI_TO_FACT.get(str(nli_label).lower(), FactLabel.UNVERIFIABLE)

    @staticmethod
    def _make_reason(label: FactLabel, evidence: str) -> str:
        """Generate a human-readable reason for the verdict.

        The evidence is quoted, truncated to 80 characters plus "..." when
        longer.
        """
        ev_short = evidence[:80] + "..." if len(evidence) > 80 else evidence
        if label == FactLabel.SUPPORTED:
            return f"Matched evidence: \"{ev_short}\""
        elif label == FactLabel.CONTRADICTED:
            return f"Contradicted by: \"{ev_short}\""
        else:
            return f"No strong match — evidence is neutral: \"{ev_short}\""