asulc commited on
Commit
a4ed7cf
·
1 Parent(s): 6712e9d

mad outlier detection for problematic phases

Browse files
Files changed (1) hide show
  1. src/app.py +92 -31
src/app.py CHANGED
@@ -19,7 +19,8 @@ def load_model(model_name):
19
  st.info(f"Loading model '{model_name}'... This may take a moment on the first run.")
20
  try:
21
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True)
22
- model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True, attn_implementation="eager")
 
23
  st.success(f"Model '{model_name}' loaded and ready!")
24
  return tokenizer, model
25
  except Exception as e:
@@ -64,25 +65,54 @@ def get_analysis_data(text_to_analyze, system_prompt, tokenizer, model):
64
  return list(zip(tokens, sequence_log_probs)), full_tokens, last_layer_attention, start_index, end_index
65
 
66
 
67
- def find_high_perplexity_phrases(analysis_data, std_dev_threshold=1.5):
68
- if not analysis_data: return []
69
- log_probs = [lp for _, lp in analysis_data]
70
- mean_lp = np.mean(log_probs)
71
- std_lp = np.std(log_probs)
72
- threshold = mean_lp - std_dev_threshold * std_lp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  outlier_phrases = []
75
  current_phrase = ""
76
- for token, log_prob in analysis_data:
 
 
77
  display_token = token.replace('Ġ', ' ')
78
- if log_prob < threshold:
79
  current_phrase += display_token
 
80
  else:
81
- if current_phrase:
82
  outlier_phrases.append(current_phrase.strip())
83
  current_phrase = ""
84
- if current_phrase:
 
 
85
  outlier_phrases.append(current_phrase.strip())
 
86
  return outlier_phrases
87
 
88
 
@@ -105,23 +135,39 @@ Explain, step-by-step, why the model found **these specific phrases** surprising
105
 
106
 
107
  def get_color_for_logprob(logprob, min_logprob, max_logprob):
108
- if min_logprob >= max_logprob: return "#FFFFFF"
109
  normalized = (logprob - min_logprob) / (max_logprob - min_logprob)
110
- hue = normalized * 0.4
111
  rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
112
  return '#%02x%02x%02x' % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
113
 
114
 
115
- def render_colored_text(analysis_data, min_logprob, max_logprob):
 
 
 
 
116
  html_elements = []
117
- for token, logprob in analysis_data:
 
 
 
 
 
 
118
  perplexity = math.exp(-logprob) if logprob != 0 else 1
119
  display_token = token.replace('Ġ', ' ')
120
- color = get_color_for_logprob(logprob, min_logprob, max_logprob)
121
  tooltip = f"Perplexity: {perplexity:.2f}"
122
- html_elements.append(
123
- f'<span style="background-color: {color}; padding: 2px 1px; margin: 0px; border-radius: 3px;" title="{tooltip}">{display_token}</span>'
124
- )
 
 
 
 
 
 
 
125
  return "".join(html_elements)
126
 
127
 
@@ -208,8 +254,7 @@ def render_interactive_text(tokens, attention_matrix, start_index, threshold):
208
  st.set_page_config(layout="wide", page_title="QCD Text Validator & Inspector", page_icon="🔬")
209
  st.title("QCD Text Validator & Inspector")
210
 
211
- MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
212
- # SYSTEM_PROMPT = "You are a particle physicist specializing in Quantum Chromodynamics (QCD)..."
213
  SYSTEM_PROMPT = """
214
  You are an expert peer reviewer for a top-tier physics journal, specializing in the theory of the strong interaction (Quantum Chromodynamics). Your task is to rigorously evaluate statements for their strict adherence to the established principles of the Standard Model.
215
 
@@ -220,11 +265,15 @@ Do not tolerate simplifications, analogies, or pop-science descriptions that are
220
 
221
  try:
222
  tokenizer, model = load_model(MODEL_NAME)
223
- default_text = "In QCD, asymptotic freedom incorrectly states that the strong force between quarks grows stronger at high energies."
224
  text_to_analyze = st.text_area("Enter Text to Analyze:", value=default_text, height=150)
225
 
226
  if st.button("Analyze Text", key="analyze_button", type="primary"):
227
- for key in list(st.session_state.keys()): del st.session_state[key]
 
 
 
 
228
  if text_to_analyze:
229
  with st.spinner("Performing analysis and calculating attention..."):
230
  analysis_data, full_tokens, attention_matrix, start_idx, end_idx = get_analysis_data(text_to_analyze,
@@ -236,11 +285,19 @@ try:
236
  st.session_state.attention_matrix = attention_matrix
237
  st.session_state.start_index = start_idx
238
  st.session_state.end_index = end_idx
239
- st.session_state.suspicious_phrases = find_high_perplexity_phrases(analysis_data)
 
 
 
 
 
 
 
 
240
  st.session_state.original_text = text_to_analyze
241
  st.session_state.analysis_complete = True
242
  else:
243
- st.warning("Analysis could not be completed.")
244
  else:
245
  st.warning("Please enter some text to analyze.")
246
 
@@ -249,10 +306,11 @@ try:
249
 
250
  # Perplexity Analysis Section
251
  st.subheader("📝 Perplexity Analysis")
252
- st.markdown("Color indicates model surprise (Red = High Surprise, Green = Low Surprise).")
253
- min_lp = min([lp for _, lp in st.session_state.analysis_data], default=0)
254
- max_lp = max([lp for _, lp in st.session_state.analysis_data], default=0)
255
- colored_text_html = render_colored_text(st.session_state.analysis_data, min_lp, max_lp)
 
256
  st.markdown(colored_text_html, unsafe_allow_html=True)
257
  st.markdown("---")
258
 
@@ -266,8 +324,10 @@ try:
266
  user_attention_matrix = st.session_state.attention_matrix
267
 
268
  max_attention = float(np.max(user_attention_matrix)) if user_attention_matrix.size > 0 else 0.1
 
 
269
  attention_threshold = st.slider("Attention Threshold for Highlighting", min_value=0.0, max_value=max_attention,
270
- value=min(0.1, max_attention), step=0.01, format="%.2f")
271
 
272
  # Render the new interactive text component
273
  interactive_html = render_interactive_text(user_tokens, user_attention_matrix, start, attention_threshold)
@@ -285,7 +345,7 @@ try:
285
  st.session_state.suspicious_phrases,
286
  tokenizer, model)
287
  else:
288
- st.info("No specific high-perplexity phrases were found.")
289
 
290
  if 'deep_dive_result' in st.session_state:
291
  st.subheader("🧠 Focused Deep Dive Analysis")
@@ -293,3 +353,4 @@ try:
293
 
294
  except Exception as e:
295
  st.error(f"A critical error occurred: {e}")
 
 
19
  st.info(f"Loading model '{model_name}'... This may take a moment on the first run.")
20
  try:
21
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True)
22
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True,
23
+ attn_implementation="eager")
24
  st.success(f"Model '{model_name}' loaded and ready!")
25
  return tokenizer, model
26
  except Exception as e:
 
65
  return list(zip(tokens, sequence_log_probs)), full_tokens, last_layer_attention, start_index, end_index
66
 
67
 
68
+ def get_outlier_indices(analysis_data, threshold=-3.0):
69
+ """
70
+ Identifies the indices of outlier tokens using the Median Absolute Deviation (MAD).
71
+ A lower log-probability is more surprising, so we look for large negative scores.
72
+ """
73
+ if not analysis_data or len(analysis_data) < 5:
74
+ return np.array([])
75
+
76
+ log_probs = np.array([lp for _, lp in analysis_data])
77
+ median_lp = np.median(log_probs)
78
+ # Calculate Median Absolute Deviation (MAD)
79
+ mad = np.median(np.abs(log_probs - median_lp))
80
+
81
+ if mad == 0: # Avoid division by zero if many values are the same
82
+ return np.array([])
83
+
84
+ # Calculate the modified Z-scores (robust against outliers)
85
+ modified_z_scores = 0.6745 * (log_probs - median_lp) / mad
86
+
87
+ # Return indices where the score is below the threshold
88
+ return np.where(modified_z_scores < threshold)[0]
89
+
90
+
91
+ def find_high_perplexity_phrases(analysis_data, outlier_indices):
92
+ """
93
+ Groups contiguous outlier tokens into phrases.
94
+ """
95
+ if not analysis_data or outlier_indices.size == 0:
96
+ return []
97
 
98
  outlier_phrases = []
99
  current_phrase = ""
100
+ on_outlier_streak = False
101
+
102
+ for i, (token, _) in enumerate(analysis_data):
103
  display_token = token.replace('Ġ', ' ')
104
+ if i in outlier_indices:
105
  current_phrase += display_token
106
+ on_outlier_streak = True
107
  else:
108
+ if on_outlier_streak:
109
  outlier_phrases.append(current_phrase.strip())
110
  current_phrase = ""
111
+ on_outlier_streak = False
112
+
113
+ if current_phrase: # Catch a phrase that ends the sentence
114
  outlier_phrases.append(current_phrase.strip())
115
+
116
  return outlier_phrases
117
 
118
 
 
135
 
136
 
137
  def get_color_for_logprob(logprob, min_logprob, max_logprob):
138
+ if min_logprob >= max_logprob: return "#FFB3B3" # A default soft red
139
  normalized = (logprob - min_logprob) / (max_logprob - min_logprob)
140
+ hue = normalized * 0.4 # Scale from Red (0.0) to Greenish-Yellow (0.4)
141
  rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
142
  return '#%02x%02x%02x' % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
143
 
144
 
145
+ def render_colored_text(analysis_data, outlier_indices):
146
+ """
147
+ Renders text, highlighting only the outlier tokens.
148
+ The color intensity of outliers is still scaled relative to each other.
149
+ """
150
  html_elements = []
151
+
152
+ # For scaling the color of the outliers themselves
153
+ outlier_log_probs = [analysis_data[i][1] for i in outlier_indices]
154
+ min_lp = min(outlier_log_probs) if outlier_log_probs else 0
155
+ max_lp = max(outlier_log_probs) if outlier_log_probs else 0
156
+
157
+ for i, (token, logprob) in enumerate(analysis_data):
158
  perplexity = math.exp(-logprob) if logprob != 0 else 1
159
  display_token = token.replace('Ġ', ' ')
 
160
  tooltip = f"Perplexity: {perplexity:.2f}"
161
+
162
+ if i in outlier_indices:
163
+ color = get_color_for_logprob(logprob, min_lp, max_lp)
164
+ html_elements.append(
165
+ f'<span style="background-color: {color}; padding: 2px 1px; margin: 0px; border-radius: 3px;" title="{tooltip}">{display_token}</span>'
166
+ )
167
+ else:
168
+ # Not an outlier, render with no background color
169
+ html_elements.append(f'<span title="{tooltip}">{display_token}</span>')
170
+
171
  return "".join(html_elements)
172
 
173
 
 
254
  st.set_page_config(layout="wide", page_title="QCD Text Validator & Inspector", page_icon="🔬")
255
  st.title("QCD Text Validator & Inspector")
256
 
257
+ MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
 
258
  SYSTEM_PROMPT = """
259
  You are an expert peer reviewer for a top-tier physics journal, specializing in the theory of the strong interaction (Quantum Chromodynamics). Your task is to rigorously evaluate statements for their strict adherence to the established principles of the Standard Model.
260
 
 
265
 
266
  try:
267
  tokenizer, model = load_model(MODEL_NAME)
268
+ default_text = "In QCD, asymptotic freedom incorrectly states that the strong force between quarks grows stronger at high energies, while in reality it gets weaker."
269
  text_to_analyze = st.text_area("Enter Text to Analyze:", value=default_text, height=150)
270
 
271
  if st.button("Analyze Text", key="analyze_button", type="primary"):
272
+ # Clear previous analysis from session state
273
+ for key in list(st.session_state.keys()):
274
+ if key not in ['tokenizer', 'model']: # Don't clear the loaded model
275
+ del st.session_state[key]
276
+
277
  if text_to_analyze:
278
  with st.spinner("Performing analysis and calculating attention..."):
279
  analysis_data, full_tokens, attention_matrix, start_idx, end_idx = get_analysis_data(text_to_analyze,
 
285
  st.session_state.attention_matrix = attention_matrix
286
  st.session_state.start_index = start_idx
287
  st.session_state.end_index = end_idx
288
+
289
+ # --- MODIFIED LOGIC ---
290
+ # 1. Get outlier indices first
291
+ outlier_indices = get_outlier_indices(analysis_data)
292
+ st.session_state.outlier_indices = outlier_indices
293
+
294
+ # 2. Find phrases based on these indices
295
+ st.session_state.suspicious_phrases = find_high_perplexity_phrases(analysis_data, outlier_indices)
296
+
297
  st.session_state.original_text = text_to_analyze
298
  st.session_state.analysis_complete = True
299
  else:
300
+ st.warning("Analysis could not be completed. The input text might be too short or unusual.")
301
  else:
302
  st.warning("Please enter some text to analyze.")
303
 
 
306
 
307
  # Perplexity Analysis Section
308
  st.subheader("📝 Perplexity Analysis")
309
+ st.markdown("Color indicates model surprise (**outliers only**). A lack of color means the text is plausible.")
310
+
311
+ # --- MODIFIED LOGIC ---
312
+ # 3. Pass indices to the rendering function
313
+ colored_text_html = render_colored_text(st.session_state.analysis_data, st.session_state.outlier_indices)
314
  st.markdown(colored_text_html, unsafe_allow_html=True)
315
  st.markdown("---")
316
 
 
324
  user_attention_matrix = st.session_state.attention_matrix
325
 
326
  max_attention = float(np.max(user_attention_matrix)) if user_attention_matrix.size > 0 else 0.1
327
+ # Set a sensible default value for the slider
328
+ default_slider_val = min(0.1, max_attention) if max_attention > 0 else 0.1
329
  attention_threshold = st.slider("Attention Threshold for Highlighting", min_value=0.0, max_value=max_attention,
330
+ value=default_slider_val, step=0.01, format="%.2f")
331
 
332
  # Render the new interactive text component
333
  interactive_html = render_interactive_text(user_tokens, user_attention_matrix, start, attention_threshold)
 
345
  st.session_state.suspicious_phrases,
346
  tokenizer, model)
347
  else:
348
+ st.info("No statistically significant high-perplexity phrases were found.")
349
 
350
  if 'deep_dive_result' in st.session_state:
351
  st.subheader("🧠 Focused Deep Dive Analysis")
 
353
 
354
  except Exception as e:
355
  st.error(f"A critical error occurred: {e}")
356
+ st.exception(e) # Provides a full traceback in the terminal for debugging