sulcan committed on
Commit
feb33b4
·
verified ·
1 Parent(s): ccac5e8

Update src/app.py

Browse files

lowering the first token

Files changed (1) hide show
  1. src/app.py +198 -187
src/app.py CHANGED
@@ -1,189 +1,200 @@
1
- import streamlit as st
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import numpy as np
5
- import colorsys
6
- import math
7
-
8
- # --- Core Functions (Cached) ---
9
- @st.cache_resource
10
- def load_model(model_name):
11
- """Loads the specified model and tokenizer from Hugging Face."""
12
- import os
13
-
14
- # Set cache directory to writable location
15
- os.environ['HF_HOME'] = '/tmp/hf_cache'
16
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/hf_cache'
17
-
18
- st.info(f"Loading model '{model_name}'... This may take a moment on first run.")
19
-
20
- try:
21
- tokenizer = AutoTokenizer.from_pretrained(
22
- model_name,
23
- cache_dir="/tmp/hf_cache",
24
- trust_remote_code=True
25
- )
26
- model = AutoModelForCausalLM.from_pretrained(
27
- model_name,
28
- cache_dir="/tmp/hf_cache",
29
- trust_remote_code=True
30
- )
31
- st.success(f"Model '{model_name}' loaded and ready!")
32
- return tokenizer, model
33
-
34
- except PermissionError as e:
35
- st.error(f"Permission error: {e}")
36
- st.error("Try refreshing the page or waiting a moment if another download is in progress.")
37
- st.stop()
38
-
39
- except Exception as e:
40
- st.error(f"An error occurred loading the model: {e}")
41
- st.stop()
42
-
43
- # --- Analysis and Helper Functions ---
44
- def get_analysis_data(text_to_analyze, system_prompt, tokenizer, model):
45
- """Calculates log-likelihood and confidence for each token."""
46
- messages = [
47
- {"role": "system", "content": system_prompt},
48
- {"role": "user", "content": text_to_analyze},
49
- ]
50
- tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_tensors="pt")
51
- user_text_token_ids = tokenizer.encode(text_to_analyze, add_special_tokens=False)
52
- full_ids_list = tokenized_chat[0].tolist()
53
- start_index = -1
54
- for i in range(len(full_ids_list) - len(user_text_token_ids) + 1):
55
- if full_ids_list[i:i+len(user_text_token_ids)] == user_text_token_ids:
56
- start_index = i
57
- break
58
- if start_index == -1: return []
59
- end_index = start_index + len(user_text_token_ids)
60
-
61
- with torch.no_grad():
62
- outputs = model(tokenized_chat, labels=tokenized_chat)
63
- logits = outputs.logits
64
- log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
65
- probs = torch.nn.functional.softmax(logits, dim=-1)
66
- sliced_log_probs = log_probs[0, start_index-1:end_index-1, :]
67
- sliced_probs = probs[0, start_index-1:end_index-1, :]
68
- sliced_tokens = tokenized_chat[0, start_index:end_index]
69
- sequence_log_probs = sliced_log_probs.gather(1, sliced_tokens.unsqueeze(-1)).squeeze().tolist()
70
- sequence_probs = sliced_probs.gather(1, sliced_tokens.unsqueeze(-1)).squeeze().tolist()
71
- tokens = tokenizer.convert_ids_to_tokens(sliced_tokens)
72
- if not isinstance(sequence_log_probs, list):
73
- sequence_log_probs, sequence_probs = [sequence_log_probs], [sequence_probs]
74
- return list(zip(tokens, sequence_log_probs, sequence_probs))
75
-
76
- def find_high_perplexity_phrases(analysis_data, std_dev_threshold=1.5):
77
- """Identifies and groups high-perplexity tokens into phrases."""
78
- if not analysis_data: return []
79
- log_probs = [lp for _, lp, _ in analysis_data]
80
- mean_lp = np.mean(log_probs)
81
- std_lp = np.std(log_probs)
82
- threshold = mean_lp - std_dev_threshold * std_lp
83
-
84
- outlier_phrases = []
85
- current_phrase = ""
86
- for token, log_prob, _ in analysis_data:
87
- display_token = token.replace('Ġ', ' ').replace(' ', ' ')
88
- if log_prob < threshold:
89
- current_phrase += display_token
90
- else:
91
- if current_phrase:
92
- outlier_phrases.append(current_phrase.strip())
93
- current_phrase = ""
94
- if current_phrase:
95
- outlier_phrases.append(current_phrase.strip())
96
- return outlier_phrases
97
-
98
- def run_focused_deep_dive(original_text, phrases, tokenizer, model):
99
- """Runs a focused CoT prompt to explain why specific phrases are surprising."""
100
- cot_system_prompt = "You are a meticulous and rigorous particle physicist and an expert in peer-review. Your task is to explain why certain phrases in a given statement might be considered incorrect or surprising."
101
- phrases_str = "\n".join([f"- \"{p}\"" for p in phrases])
102
- cot_user_prompt = f"""I have analyzed the following statement using a language model:
103
- **Full Statement:** "{original_text}"
104
-
105
- The analysis flagged the following phrase(s) as having very high perplexity (meaning they were highly surprising to the model):
106
- {phrases_str}
107
-
108
- Please explain, step-by-step, why the language model likely found **these specific phrases** incorrect or surprising in the context of the full statement. For each phrase, provide the correct physics concept if one exists.
109
- """
110
- messages = [{"role": "system", "content": cot_system_prompt}, {"role": "user", "content": cot_user_prompt}]
111
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
112
- inputs = tokenizer(prompt, return_tensors="pt")
113
- with torch.no_grad():
114
- outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.95)
115
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
116
- return response_text.split("assistant\n")[-1]
117
-
118
- def get_color_for_logprob(logprob, min_logprob, max_logprob):
119
- """Calculates color based on normalized log-probability."""
120
- if min_logprob >= max_logprob: return "#FFFFFF"
121
- normalized = (logprob - min_logprob) / (max_logprob - min_logprob)
122
- hue = normalized * 0.4
123
- rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
124
- return '#%02x%02x%02x' % (int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))
125
-
126
- def render_colored_text(analysis_data, min_logprob, max_logprob):
127
- """Renders the colored text HTML."""
128
- html_elements = []
129
- for token, logprob, confidence in analysis_data:
130
- perplexity = math.exp(-logprob) if logprob != 0 else 1
131
- display_token = token.replace('Ġ', ' ').replace(' ', ' ')
132
- color = get_color_for_logprob(logprob, min_logprob, max_logprob)
133
- tooltip = f"Perplexity: {perplexity:.2f} | Confidence: {confidence:.1%}"
134
- html_elements.append(
135
- f'<span style="background-color: {color}; padding: 2px 4px; margin: 1px; border-radius: 4px;" title="{tooltip}">{display_token}</span>'
136
- )
137
- return "".join(html_elements)
138
-
139
- # --- Streamlit App ---
140
-
141
- st.set_page_config(layout="wide", page_title="QCD Text Validator", page_icon="🔬")
142
-
143
- MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
144
- SYSTEM_PROMPT = "You are a particle physicist specializing in Quantum Chromodynamics (QCD)... Any deviation from established theory... should be treated as a highly improbable, high-perplexity event."
145
-
146
  try:
147
- tokenizer, model = load_model(MODEL_NAME)
148
- default_text = "In QCD, asymptotic freedom incorrectly states that the strong force between quarks grows stronger at high energies. This is mediated by a universal harmonic constant."
149
- text_to_analyze = st.text_area("Enter Text Here:", value=default_text, height=150, label_visibility="collapsed")
150
-
151
- if st.button("Analyze Text", key="analyze_button"):
152
- st.session_state.clear() # Clear previous results
153
- if text_to_analyze:
154
- text_with_space = " " + text_to_analyze
155
- with st.spinner("Performing initial analysis..."):
156
- full_analysis_data = get_analysis_data(text_with_space, SYSTEM_PROMPT, tokenizer, model)
157
- if full_analysis_data and len(full_analysis_data) > 1:
158
- st.session_state.visible_data = full_analysis_data#[1:]
159
- scores = [lp for _, lp, _ in st.session_state.visible_data]
160
- st.session_state.min_logprob = min(scores) if scores else 0
161
- st.session_state.max_logprob = max(scores) if scores else 0
162
- st.session_state.suspicious_phrases = find_high_perplexity_phrases(st.session_state.visible_data)
163
- st.session_state.analysis_complete = True
164
- st.session_state.original_text = text_to_analyze
165
- else:
166
- st.warning("Analysis could not be completed.")
167
- else:
168
- st.warning("Please enter some text to analyze.")
169
-
170
- if st.session_state.get('analysis_complete', False):
171
- st.subheader("Initial Analysis Result")
172
- colored_text_html = render_colored_text(st.session_state.visible_data, st.session_state.min_logprob, st.session_state.max_logprob)
173
- st.markdown(colored_text_html, unsafe_allow_html=True)
174
- st.markdown("---")
175
-
176
- if st.session_state.suspicious_phrases:
177
- if st.button("Deep Dive into Highlighted Phrases", key="deep_dive_button"):
178
- with st.spinner("Performing focused deep dive... This may take a moment."):
179
- st.session_state.deep_dive_result = run_focused_deep_dive(st.session_state.original_text, st.session_state.suspicious_phrases, tokenizer, model)
180
- else:
181
- st.info("No specific high-perplexity phrases were found to deep dive into.")
182
-
183
- if 'deep_dive_result' in st.session_state:
184
- st.subheader("Focused Deep Dive Analysis")
185
- st.markdown(st.session_state.deep_dive_result)
186
-
187
  except Exception as e:
188
- st.error(f"An error occurred: {e}")
189
- st.info("There may be an issue connecting to the Hugging Face Hub. Please check your internet connection.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import numpy as np
5
+ import colorsys
6
+ import math
7
+
8
+
9
+ # --- Core Functions (Cached) ---
10
+ @st.cache_resource
11
+ def load_model(model_name):
12
+ """Loads the specified model and tokenizer from Hugging Face."""
13
+ import os
14
+
15
+ # Set cache directory to writable location
16
+ os.environ['HF_HOME'] = '/tmp/hf_cache'
17
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/hf_cache'
18
+
19
+ st.info(f"Loading model '{model_name}'... This may take a moment on first run.")
20
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  try:
22
+ tokenizer = AutoTokenizer.from_pretrained(
23
+ model_name,
24
+ cache_dir="/tmp/hf_cache",
25
+ trust_remote_code=True
26
+ )
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ model_name,
29
+ cache_dir="/tmp/hf_cache",
30
+ trust_remote_code=True
31
+ )
32
+ st.success(f"Model '{model_name}' loaded and ready!")
33
+ return tokenizer, model
34
+
35
+ except PermissionError as e:
36
+ st.error(f"Permission error: {e}")
37
+ st.error("Try refreshing the page or waiting a moment if another download is in progress.")
38
+ st.stop()
39
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as e:
41
+ st.error(f"An error occurred loading the model: {e}")
42
+ st.stop()
43
+
44
+
45
+ # --- Analysis and Helper Functions ---
46
+ def get_analysis_data(text_to_analyze, system_prompt, tokenizer, model):
47
+ """Calculates log-likelihood and confidence for each token."""
48
+ messages = [
49
+ {"role": "system", "content": system_prompt},
50
+ {"role": "user", "content": text_to_analyze},
51
+ ]
52
+ tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False,
53
+ return_tensors="pt")
54
+ user_text_token_ids = tokenizer.encode(text_to_analyze, add_special_tokens=False)
55
+ full_ids_list = tokenized_chat[0].tolist()
56
+ start_index = -1
57
+ for i in range(len(full_ids_list) - len(user_text_token_ids) + 1):
58
+ if full_ids_list[i:i + len(user_text_token_ids)] == user_text_token_ids:
59
+ start_index = i
60
+ break
61
+ if start_index == -1: return []
62
+ end_index = start_index + len(user_text_token_ids)
63
+
64
+ with torch.no_grad():
65
+ outputs = model(tokenized_chat, labels=tokenized_chat)
66
+ logits = outputs.logits
67
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
68
+ probs = torch.nn.functional.softmax(logits, dim=-1)
69
+ sliced_log_probs = log_probs[0, start_index - 1:end_index - 1, :]
70
+ sliced_probs = probs[0, start_index - 1:end_index - 1, :]
71
+ sliced_tokens = tokenized_chat[0, start_index:end_index]
72
+ sequence_log_probs = sliced_log_probs.gather(1, sliced_tokens.unsqueeze(-1)).squeeze().tolist()
73
+ sequence_probs = sliced_probs.gather(1, sliced_tokens.unsqueeze(-1)).squeeze().tolist()
74
+ tokens = tokenizer.convert_ids_to_tokens(sliced_tokens)
75
+ if not isinstance(sequence_log_probs, list):
76
+ sequence_log_probs, sequence_probs = [sequence_log_probs], [sequence_probs]
77
+ return list(zip(tokens, sequence_log_probs, sequence_probs))
78
+
79
+
80
+ def find_high_perplexity_phrases(analysis_data, std_dev_threshold=1.5):
81
+ """Identifies and groups high-perplexity tokens into phrases."""
82
+ if not analysis_data: return []
83
+ log_probs = [lp for _, lp, _ in analysis_data]
84
+ mean_lp = np.mean(log_probs)
85
+ std_lp = np.std(log_probs)
86
+ threshold = mean_lp - std_dev_threshold * std_lp
87
+
88
+ outlier_phrases = []
89
+ current_phrase = ""
90
+ for token, log_prob, _ in analysis_data:
91
+ display_token = token.replace('Ġ', ' ').replace(' ', ' ')
92
+ if log_prob < threshold:
93
+ current_phrase += display_token
94
+ else:
95
+ if current_phrase:
96
+ outlier_phrases.append(current_phrase.strip())
97
+ current_phrase = ""
98
+ if current_phrase:
99
+ outlier_phrases.append(current_phrase.strip())
100
+ return outlier_phrases
101
+
102
+
103
+ def run_focused_deep_dive(original_text, phrases, tokenizer, model):
104
+ """Runs a focused CoT prompt to explain why specific phrases are surprising."""
105
+ cot_system_prompt = "You are a meticulous and rigorous particle physicist and an expert in peer-review. Your task is to explain why certain phrases in a given statement might be considered incorrect or surprising."
106
+ phrases_str = "\n".join([f"- \"{p}\"" for p in phrases])
107
+ cot_user_prompt = f"""I have analyzed the following statement using a language model:
108
+ **Full Statement:** "{original_text}"
109
+
110
+ The analysis flagged the following phrase(s) as having very high perplexity (meaning they were highly surprising to the model):
111
+ {phrases_str}
112
+
113
+ Please explain, step-by-step, why the language model likely found **these specific phrases** incorrect or surprising in the context of the full statement. For each phrase, provide the correct physics concept if one exists.
114
+ """
115
+ messages = [{"role": "system", "content": cot_system_prompt}, {"role": "user", "content": cot_user_prompt}]
116
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
117
+ inputs = tokenizer(prompt, return_tensors="pt")
118
+ with torch.no_grad():
119
+ outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.95)
120
+ response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
121
+ return response_text.split("assistant\n")[-1]
122
+
123
+
124
+ def get_color_for_logprob(logprob, min_logprob, max_logprob):
125
+ """Calculates color based on normalized log-probability."""
126
+ if min_logprob >= max_logprob: return "#FFFFFF"
127
+ normalized = (logprob - min_logprob) / (max_logprob - min_logprob)
128
+ hue = normalized * 0.4
129
+ rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
130
+ return '#%02x%02x%02x' % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
131
+
132
+
133
+ def render_colored_text(analysis_data, min_logprob, max_logprob):
134
+ """Renders the colored text HTML."""
135
+ html_elements = []
136
+ for token, logprob, confidence in analysis_data:
137
+ perplexity = math.exp(-logprob) if logprob != 0 else 1
138
+ display_token = token.replace('Ġ', ' ').replace(' ', ' ')
139
+ color = get_color_for_logprob(logprob, min_logprob, max_logprob)
140
+ tooltip = f"Perplexity: {perplexity:.2f} | Confidence: {confidence:.1%}"
141
+ html_elements.append(
142
+ f'<span style="background-color: {color}; padding: 2px 4px; margin: 1px; border-radius: 4px;" title="{tooltip}">{display_token}</span>'
143
+ )
144
+ return "".join(html_elements)
145
+
146
+
147
+ # --- Streamlit App ---
148
+
149
+ st.set_page_config(layout="wide", page_title="QCD Text Validator", page_icon="🔬")
150
+
151
+ MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
152
+ SYSTEM_PROMPT = "You are a particle physicist specializing in Quantum Chromodynamics (QCD)... Any deviation from established theory... should be treated as a highly improbable, high-perplexity event."
153
+
154
+ try:
155
+ tokenizer, model = load_model(MODEL_NAME)
156
+ default_text = "In QCD, asymptotic freedom incorrectly states that the strong force between quarks grows stronger at high energies. This is mediated by a universal harmonic constant."
157
+ text_to_analyze = st.text_area("Enter Text Here:", value=default_text, height=150, label_visibility="collapsed")
158
+
159
+ if st.button("Analyze Text", key="analyze_button"):
160
+ st.session_state.clear() # Clear previous results
161
+ if text_to_analyze:
162
+ text_with_ignored_prefix = "" + text_to_analyze
163
+ with st.spinner("Performing initial analysis..."):
164
+ full_analysis_data = get_analysis_data(text_with_ignored_prefix, SYSTEM_PROMPT, tokenizer, model)
165
+ if full_analysis_data and len(full_analysis_data) > 1:
166
+ st.session_state.visible_data = full_analysis_data[:]
167
+ scores = [lp for _, lp, _ in st.session_state.visible_data]
168
+ st.session_state.min_logprob = min(scores) if scores else 0
169
+ st.session_state.max_logprob = max(scores) if scores else 0
170
+ st.session_state.suspicious_phrases = find_high_perplexity_phrases(st.session_state.visible_data)
171
+ st.session_state.analysis_complete = True
172
+ st.session_state.original_text = text_to_analyze
173
+ else:
174
+ st.warning("Analysis could not be completed.")
175
+ else:
176
+ st.warning("Please enter some text to analyze.")
177
+
178
+ if st.session_state.get('analysis_complete', False):
179
+ st.subheader("Initial Analysis Result")
180
+ colored_text_html = render_colored_text(st.session_state.visible_data, st.session_state.min_logprob,
181
+ st.session_state.max_logprob)
182
+ st.markdown(colored_text_html, unsafe_allow_html=True)
183
+ st.markdown("---")
184
+
185
+ if st.session_state.suspicious_phrases:
186
+ if st.button("Deep Dive into Highlighted Phrases", key="deep_dive_button"):
187
+ with st.spinner("Performing focused deep dive... This may take a moment."):
188
+ st.session_state.deep_dive_result = run_focused_deep_dive(st.session_state.original_text,
189
+ st.session_state.suspicious_phrases,
190
+ tokenizer, model)
191
+ else:
192
+ st.info("No specific high-perplexity phrases were found to deep dive into.")
193
+
194
+ if 'deep_dive_result' in st.session_state:
195
+ st.subheader("Focused Deep Dive Analysis")
196
+ st.markdown(st.session_state.deep_dive_result)
197
+
198
+ except Exception as e:
199
+ st.error(f"An error occurred: {e}")
200
+ st.info("There may be an issue connecting to the Hugging Face Hub. Please check your internet connection.")