abdullah-113 commited on
Commit
f697d16
·
verified ·
1 Parent(s): 7bb55ea

Update api/detector.py

Browse files
Files changed (1) hide show
  1. api/detector.py +62 -60
api/detector.py CHANGED
@@ -1,126 +1,131 @@
1
  import torch
2
  import re
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
 
5
- # ── Configuration Constants ──
6
- TEMPERATURE = 1.5 # Logit smoothing factor (higher = softer distribution)
7
- CONFIDENCE_THRESHOLD = 0.60 # Minimum raw probability to trust a classification
8
- CHUNK_SIZE = 400 # Words per chunk
9
- CHUNK_OVERLAP = 50 # Overlapping words between chunks
10
 
11
 
12
  def sliding_window_chunker(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
13
- """Splits a large text into overlapping chunks of a specific word count."""
14
  words = text.split()
15
  chunks = []
16
-
17
  if not words:
18
  return chunks
19
-
20
  step = chunk_size - overlap
21
  if step <= 0:
22
  step = 1
23
-
24
  for i in range(0, len(words), step):
25
  chunk_words = words[i:i + chunk_size]
26
  chunks.append(" ".join(chunk_words))
27
-
28
  if i + chunk_size >= len(words):
29
  break
30
-
31
  return chunks
32
 
 
33
  def split_into_claims(text: str) -> list[str]:
34
- """Splits the LLM output into individual sentences/claims to prevent conversational filler from ruining factual scores."""
 
35
  raw_sentences = re.split(r'(?<=[.!?])\s+', text.strip())
36
-
37
  valid_claims = []
38
  for s in raw_sentences:
39
  clean = s.strip()
40
- # Only keep substantial claims to avoid evaluating numbering fragments (like "1.")
41
  if len(clean.split()) >= 3:
42
  valid_claims.append(clean)
43
-
44
  if not valid_claims and text.strip():
45
  valid_claims = [text.strip()]
46
-
47
  return valid_claims
48
 
 
49
  def normalize_scores(contradiction: float, entailment: float, neutral: float) -> tuple[float, float, float]:
50
- """Ensures the three scores sum to exactly 100.0%."""
51
  total = contradiction + entailment + neutral
52
  if total == 0:
53
  return (0.0, 0.0, 100.0)
54
-
55
  c = round((contradiction / total) * 100.0, 2)
56
  e = round((entailment / total) * 100.0, 2)
57
- n = round(100.0 - c - e, 2) # Assign remainder to neutral to guarantee sum = 100
58
  return (c, e, n)
59
 
60
 
61
  class HallucinationDetector:
62
  def __init__(self):
63
- """Initializes the model and tokenizer only once when the class is created."""
64
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
  self.model_name = "cross-encoder/nli-deberta-v3-base"
66
-
67
  print(f"Initializing Detector on {self.device.type.upper()}...")
68
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
69
  self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
70
  print("Detector Ready!")
71
 
 
 
 
72
  def _infer_chunk(self, chunk: str, claim: str) -> dict:
73
- """Runs NLI inference on a single chunk against a single claim."""
74
  inputs = self.tokenizer(
75
- chunk, claim,
76
  return_tensors="pt", truncation=True, max_length=512
77
  ).to(self.device)
78
-
79
  with torch.no_grad():
80
  outputs = self.model(**inputs)
81
- # Temperature Scaling
82
  scaled_logits = outputs.logits / TEMPERATURE
83
  probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
84
-
85
  c_raw = probs[0][0].item()
86
  e_raw = probs[0][1].item()
87
  n_raw = probs[0][2].item()
88
-
89
- # Confidence Thresholding
90
  max_score = max(c_raw, e_raw, n_raw)
91
  if max_score < CONFIDENCE_THRESHOLD:
92
- c_raw, e_raw, n_raw = 0.0, 0.0, 1.0 # Default to Neutral
93
-
94
  return {
95
  "contradiction": c_raw,
96
  "entailment": e_raw,
97
  "neutral": n_raw,
98
- "spans": [] # Placeholder for Captum
99
  }
100
 
101
  def analyze(self, context: str, llm_response: str) -> dict:
 
 
 
 
102
  """
103
- Hyper-Accurate Claim-by-Claim Analysis:
104
- Splits LLM output into sentences, evaluates each sentence against context chunks,
105
- and aggregates the results logically.
106
- """
107
- chunks = sliding_window_chunker(context)
108
- if not chunks:
109
- chunks = [""]
110
-
111
  claims = split_into_claims(llm_response)
112
  sentence_scores = []
113
- best_attribution_spans = []
114
-
115
  for claim in claims:
116
- # Score this claim against all context chunks
117
- chunk_results = [self._infer_chunk(chunk, claim) for chunk in chunks]
118
-
119
  s_max_e = max(r["entailment"] for r in chunk_results)
120
  s_max_c = max(r["contradiction"] for r in chunk_results)
121
  s_max_n = max(r["neutral"] for r in chunk_results)
122
-
123
- # Priority Resolution ("Truth Wins") for THIS specific claim
124
  if s_max_e >= CONFIDENCE_THRESHOLD and s_max_e >= s_max_c:
125
  final_s_e = s_max_e
126
  final_s_c = s_max_c * 0.25
@@ -136,38 +141,35 @@ class HallucinationDetector:
136
  final_s_e = s_max_e
137
  final_s_n = s_max_n
138
  winning_spans = []
139
-
140
  sentence_scores.append({
141
  "c": final_s_c,
142
  "e": final_s_e,
143
  "n": final_s_n,
144
  "spans": winning_spans
145
  })
146
-
147
- # ── Document-level Aggregation ──
148
- # 1. Contradiction runs on a "One Strike" rule: If ANY claim contradicts, the output is flawed.
149
  doc_c = max(s["c"] for s in sentence_scores)
150
-
151
- # 2. Entailment and Neutral run on an Average: Reflects the ratio of "Facts" vs "Neutral conversational filler".
152
  doc_e = sum(s["e"] for s in sentence_scores) / len(sentence_scores)
153
  doc_n = sum(s["n"] for s in sentence_scores) / len(sentence_scores)
154
-
155
- # Clamp negatives and purely normalize
156
  doc_c = max(doc_c, 0.0)
157
  doc_e = max(doc_e, 0.0)
158
  doc_n = max(doc_n, 0.0)
159
-
160
  c_pct, e_pct, n_pct = normalize_scores(doc_c, doc_e, doc_n)
161
-
162
- # Grab spans from the claim that scored the highest severity
163
  if doc_c > doc_e:
164
  best_spans = max(sentence_scores, key=lambda x: x["c"])["spans"]
165
  else:
166
  best_spans = max(sentence_scores, key=lambda x: x["e"])["spans"]
167
-
168
- # True Hallucination criteria
169
  is_hallucination = (c_pct > e_pct) and (doc_c >= CONFIDENCE_THRESHOLD)
170
-
171
  return {
172
  "contradiction_score": c_pct,
173
  "entailment_score": e_pct,
 
1
  import torch
2
  import re
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from api.retriever import ChunkRetriever
5
 
6
# ── Configuration constants ──
TEMPERATURE = 1.5            # logit smoothing factor (higher = softer distribution)
CONFIDENCE_THRESHOLD = 0.60  # minimum raw probability required to trust a classification
CHUNK_SIZE = 400             # words per context chunk
CHUNK_OVERLAP = 50           # words shared between consecutive chunks
 
10
 
11
 
12
def sliding_window_chunker(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split *text* into overlapping word-level chunks.

    Each chunk holds at most ``chunk_size`` words, and consecutive chunks
    share ``overlap`` words so evidence straddling a chunk boundary is not
    lost to either side.
    """
    words = text.split()
    if not words:
        return []

    # Guard against a non-positive stride when overlap >= chunk_size.
    stride = max(chunk_size - overlap, 1)

    pieces: list[str] = []
    start = 0
    while start < len(words):
        pieces.append(" ".join(words[start:start + chunk_size]))
        # Stop once this chunk reached the end of the text.
        if start + chunk_size >= len(words):
            break
        start += stride

    return pieces
32
 
33
+
34
def split_into_claims(text: str) -> list[str]:
    """Break LLM output into individual sentence-level claims.

    Scoring sentences independently keeps conversational filler from
    diluting the factual scores of the substantive claims.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())

    # Keep only substantial sentences (>= 3 words) so numbering fragments
    # such as "1." never get scored as claims.
    claims = [sentence.strip() for sentence in sentences if len(sentence.split()) >= 3]

    # Fall back to the whole text so non-empty input always yields a claim.
    if not claims and text.strip():
        claims = [text.strip()]

    return claims
49
 
50
+
51
def normalize_scores(contradiction: float, entailment: float, neutral: float) -> tuple[float, float, float]:
    """Rescale three raw scores into percentages guaranteed to sum to 100.0.

    Contradiction and entailment are rounded shares of the total; neutral
    absorbs the rounding remainder so the three always add to exactly 100.
    """
    total = contradiction + entailment + neutral
    if total == 0:
        # Degenerate input: report fully neutral instead of dividing by zero.
        return (0.0, 0.0, 100.0)

    c_pct = round((contradiction / total) * 100.0, 2)
    e_pct = round((entailment / total) * 100.0, 2)
    # Neutral takes whatever remains, so rounding never breaks the 100% sum.
    n_pct = round(100.0 - c_pct - e_pct, 2)
    return (c_pct, e_pct, n_pct)
61
 
62
 
63
  class HallucinationDetector:
64
  def __init__(self):
 
65
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  self.model_name = "cross-encoder/nli-deberta-v3-base"
67
+
68
  print(f"Initializing Detector on {self.device.type.upper()}...")
69
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
70
  self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
71
  print("Detector Ready!")
72
 
73
+ # Stage 1 retriever — lightweight bi-encoder for pre-filtering chunks
74
+ self.retriever = ChunkRetriever()
75
+
76
  def _infer_chunk(self, chunk: str, claim: str) -> dict:
77
+ """Stage 2: runs the heavy cross-encoder on a single (chunk, claim) pair."""
78
  inputs = self.tokenizer(
79
+ chunk, claim,
80
  return_tensors="pt", truncation=True, max_length=512
81
  ).to(self.device)
82
+
83
  with torch.no_grad():
84
  outputs = self.model(**inputs)
 
85
  scaled_logits = outputs.logits / TEMPERATURE
86
  probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
87
+
88
  c_raw = probs[0][0].item()
89
  e_raw = probs[0][1].item()
90
  n_raw = probs[0][2].item()
91
+
92
+ # if the model isn't confident about anything, default to neutral
93
  max_score = max(c_raw, e_raw, n_raw)
94
  if max_score < CONFIDENCE_THRESHOLD:
95
+ c_raw, e_raw, n_raw = 0.0, 0.0, 1.0
96
+
97
  return {
98
  "contradiction": c_raw,
99
  "entailment": e_raw,
100
  "neutral": n_raw,
101
+ "spans": [] # placeholder for Captum attributions
102
  }
103
 
104
  def analyze(self, context: str, llm_response: str) -> dict:
105
+ """Two-stage pipeline:
106
+ 1) Chunk the document → retrieve top-5 relevant chunks (bi-encoder)
107
+ 2) Score each claim against those top chunks (cross-encoder)
108
+ 3) Aggregate with priority resolution
109
  """
110
+ all_chunks = sliding_window_chunker(context)
111
+ if not all_chunks:
112
+ all_chunks = [""]
113
+
114
+ # Stage 1: narrow down to the most relevant chunks
115
+ relevant_chunks = self.retriever.get_top_chunks(llm_response, all_chunks)
116
+
 
117
  claims = split_into_claims(llm_response)
118
  sentence_scores = []
119
+
 
120
  for claim in claims:
121
+ # Stage 2: cross-encoder only runs on the pre-filtered chunks
122
+ chunk_results = [self._infer_chunk(chunk, claim) for chunk in relevant_chunks]
123
+
124
  s_max_e = max(r["entailment"] for r in chunk_results)
125
  s_max_c = max(r["contradiction"] for r in chunk_results)
126
  s_max_n = max(r["neutral"] for r in chunk_results)
127
+
128
+ # priority resolution if the fact exists somewhere, entailment wins
129
  if s_max_e >= CONFIDENCE_THRESHOLD and s_max_e >= s_max_c:
130
  final_s_e = s_max_e
131
  final_s_c = s_max_c * 0.25
 
141
  final_s_e = s_max_e
142
  final_s_n = s_max_n
143
  winning_spans = []
144
+
145
  sentence_scores.append({
146
  "c": final_s_c,
147
  "e": final_s_e,
148
  "n": final_s_n,
149
  "spans": winning_spans
150
  })
151
+
152
+ # document-level aggregation
153
+ # contradiction uses max (one-strike rule)
154
  doc_c = max(s["c"] for s in sentence_scores)
155
+ # entailment and neutral use average across claims
 
156
  doc_e = sum(s["e"] for s in sentence_scores) / len(sentence_scores)
157
  doc_n = sum(s["n"] for s in sentence_scores) / len(sentence_scores)
158
+
 
159
  doc_c = max(doc_c, 0.0)
160
  doc_e = max(doc_e, 0.0)
161
  doc_n = max(doc_n, 0.0)
162
+
163
  c_pct, e_pct, n_pct = normalize_scores(doc_c, doc_e, doc_n)
164
+
165
+ # grab attribution spans from the highest-severity claim
166
  if doc_c > doc_e:
167
  best_spans = max(sentence_scores, key=lambda x: x["c"])["spans"]
168
  else:
169
  best_spans = max(sentence_scores, key=lambda x: x["e"])["spans"]
170
+
 
171
  is_hallucination = (c_pct > e_pct) and (doc_c >= CONFIDENCE_THRESHOLD)
172
+
173
  return {
174
  "contradiction_score": c_pct,
175
  "entailment_score": e_pct,