Spaces:
Runtime error
Runtime error
Update src/app.py
Browse files
Supabase integration, some minor modifications + Qwen2.5-3B-Instruct.
- src/app.py +107 -92
src/app.py
CHANGED
|
@@ -6,13 +6,46 @@ import colorsys
|
|
| 6 |
import math
|
| 7 |
import os
|
| 8 |
import streamlit.components.v1 as components
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model(model_name):
|
| 14 |
"""Loads the specified model and tokenizer from Hugging Face."""
|
| 15 |
-
# This function now only loads the model and tokenizer without displaying status here.
|
| 16 |
cache_dir = '/tmp/hf_cache'
|
| 17 |
os.environ['HF_HOME'] = cache_dir
|
| 18 |
os.environ['TRANSFORMERS_CACHE'] = cache_dir
|
|
@@ -23,7 +56,6 @@ def load_model(model_name):
|
|
| 23 |
attn_implementation="eager")
|
| 24 |
return tokenizer, model
|
| 25 |
except Exception as e:
|
| 26 |
-
# We will handle the error display in the main app body
|
| 27 |
st.session_state.model_error = e
|
| 28 |
return None, None
|
| 29 |
|
|
@@ -65,21 +97,15 @@ def get_analysis_data(text_to_analyze, system_prompt, tokenizer, model):
|
|
| 65 |
return list(zip(tokens, sequence_log_probs)), full_tokens, last_layer_attention, start_index, end_index
|
| 66 |
|
| 67 |
|
| 68 |
-
def get_outlier_indices(analysis_data, threshold=-
|
| 69 |
-
"""
|
| 70 |
-
Identifies outlier token indices using Median Absolute Deviation (MAD).
|
| 71 |
-
The threshold is now more sensitive by default.
|
| 72 |
-
"""
|
| 73 |
if not analysis_data or len(analysis_data) < 5:
|
| 74 |
return np.array([])
|
| 75 |
-
|
| 76 |
log_probs = np.array([lp for _, lp in analysis_data])
|
| 77 |
median_lp = np.median(log_probs)
|
| 78 |
mad = np.median(np.abs(log_probs - median_lp))
|
| 79 |
-
|
| 80 |
if mad == 0:
|
| 81 |
return np.array([])
|
| 82 |
-
|
| 83 |
modified_z_scores = 0.6745 * (log_probs - median_lp) / mad
|
| 84 |
return np.where(modified_z_scores < threshold)[0]
|
| 85 |
|
|
@@ -88,11 +114,11 @@ def find_high_perplexity_phrases(analysis_data, outlier_indices):
|
|
| 88 |
"""Groups contiguous outlier tokens into phrases."""
|
| 89 |
if not analysis_data or outlier_indices.size == 0:
|
| 90 |
return []
|
| 91 |
-
|
| 92 |
outlier_phrases = []
|
| 93 |
current_phrase = ""
|
| 94 |
for i, (token, _) in enumerate(analysis_data):
|
| 95 |
-
|
|
|
|
| 96 |
if i in outlier_indices:
|
| 97 |
current_phrase += display_token
|
| 98 |
else:
|
|
@@ -105,6 +131,7 @@ def find_high_perplexity_phrases(analysis_data, outlier_indices):
|
|
| 105 |
|
| 106 |
|
| 107 |
def run_focused_deep_dive(original_text, phrases, tokenizer, model):
|
|
|
|
| 108 |
cot_system_prompt = "You are a meticulous and rigorous particle physicist..."
|
| 109 |
phrases_str = "\n".join([f"- \"{p}\"" for p in phrases])
|
| 110 |
cot_user_prompt = f"""I have analyzed the following statement:
|
|
@@ -119,101 +146,82 @@ Explain, step-by-step, why the model found **these specific phrases** surprising
|
|
| 119 |
with torch.no_grad():
|
| 120 |
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.95)
|
| 121 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
def get_color(logprob, min_lp, max_lp, scheme='green_yellow'):
|
| 126 |
"""Generates a color based on the specified color scheme."""
|
| 127 |
if min_lp >= max_lp:
|
| 128 |
-
hue = 0.33 if scheme == 'green_yellow' else 0.0
|
| 129 |
else:
|
| 130 |
normalized = (logprob - min_lp) / (max_lp - min_lp)
|
| 131 |
if scheme == 'green_yellow':
|
| 132 |
-
# Scale from Yellow (0.17) to Green (0.33)
|
| 133 |
-
# Higher logprob (less surprise) = greener
|
| 134 |
hue = 0.17 + normalized * (0.33 - 0.17)
|
| 135 |
-
else:
|
| 136 |
-
# Scale from Red (0.0) to Yellow (0.17)
|
| 137 |
-
# Higher logprob (less surprise) = more yellow
|
| 138 |
hue = 0.0 + normalized * 0.17
|
| 139 |
-
|
| 140 |
rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
|
| 141 |
return '#%02x%02x%02x' % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
|
| 142 |
|
| 143 |
|
| 144 |
def render_colored_text(analysis_data, outlier_indices):
|
| 145 |
-
"""
|
| 146 |
-
Renders text with conditional color schemes.
|
| 147 |
-
- No outliers: Green-to-yellow scale for all text.
|
| 148 |
-
- With outliers: Green-to-yellow for normal text, Yellow-to-red for outliers.
|
| 149 |
-
"""
|
| 150 |
html_elements = []
|
| 151 |
log_probs = np.array([lp for _, lp in analysis_data])
|
| 152 |
|
| 153 |
if not outlier_indices.any():
|
| 154 |
-
# No outliers: Use a single green-yellow scale for all tokens
|
| 155 |
min_lp, max_lp = (log_probs.min(), log_probs.max()) if log_probs.size > 0 else (0, 0)
|
| 156 |
for token, logprob in analysis_data:
|
| 157 |
color = get_color(logprob, min_lp, max_lp, 'green_yellow')
|
| 158 |
-
display_token = token.replace('
|
| 159 |
perplexity = math.exp(-logprob) if logprob != 0 else 1
|
| 160 |
tooltip = f"Perplexity: {perplexity:.2f}"
|
| 161 |
html_elements.append(
|
| 162 |
f'<span style="background-color: {color}; padding: 2px 1px; margin: 0px; border-radius: 3px;" title="{tooltip}">{display_token}</span>')
|
| 163 |
else:
|
| 164 |
-
# Outliers exist: Use two different color scales
|
| 165 |
non_outlier_mask = np.ones(len(log_probs), dtype=bool)
|
| 166 |
non_outlier_mask[outlier_indices] = False
|
| 167 |
-
|
| 168 |
non_outlier_lps = log_probs[non_outlier_mask]
|
| 169 |
outlier_lps = log_probs[outlier_indices]
|
| 170 |
-
|
| 171 |
min_non_outlier, max_non_outlier = (
|
| 172 |
non_outlier_lps.min(), non_outlier_lps.max()) if non_outlier_lps.size > 0 else (0, 0)
|
| 173 |
min_outlier, max_outlier = (outlier_lps.min(), outlier_lps.max()) if outlier_lps.size > 0 else (0, 0)
|
| 174 |
|
| 175 |
for i, (token, logprob) in enumerate(analysis_data):
|
| 176 |
-
display_token = token.replace('
|
| 177 |
perplexity = math.exp(-logprob) if logprob != 0 else 1
|
| 178 |
tooltip = f"Perplexity: {perplexity:.2f}"
|
| 179 |
-
if i in outlier_indices
|
| 180 |
-
|
| 181 |
-
else:
|
| 182 |
-
color = get_color(logprob, min_non_outlier, max_non_outlier, 'green_yellow')
|
| 183 |
html_elements.append(
|
| 184 |
f'<span style="background-color: {color}; padding: 2px 1px; margin: 0px; border-radius: 3px;" title="{tooltip}">{display_token}</span>')
|
| 185 |
-
|
| 186 |
return "".join(html_elements)
|
| 187 |
|
| 188 |
|
| 189 |
def render_interactive_text(tokens, attention_matrix, start_index, threshold):
|
| 190 |
"""Generates interactive HTML to highlight attention targets on hover."""
|
| 191 |
css = """<style>.interactive-text-container{line-height:2.0;font-size:1.1em}.token{cursor:pointer;padding:2px 4px;border-radius:4px;transition:background-color .2s ease-in-out}.source-highlight{background-color:#ffd700;color:#000}.target-highlight{background-color:#1e90ff;color:#fff}</style>"""
|
|
|
|
| 192 |
token_spans = []
|
| 193 |
for i, token_text in enumerate(tokens):
|
| 194 |
original_i = start_index + i
|
| 195 |
-
display_text = token_text.replace('
|
| 196 |
-
targets = [
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
score = max(attention_matrix[original_i, original_j], attention_matrix[original_j, original_i])
|
| 201 |
-
if score > threshold: targets.append(f"'token-{original_j}'")
|
| 202 |
-
targets_str = ",".join(targets)
|
| 203 |
-
token_id = f"token-{original_i}"
|
| 204 |
-
span = f'<span class="token" id="{token_id}" onmouseover="highlightTargets(\'{token_id}\',[{targets_str}])" onmouseout="clearHighlights()">{display_text}</span>'
|
| 205 |
token_spans.append(span)
|
| 206 |
-
js = """<script>const allTokens=document.querySelectorAll('.token');function highlightTargets(e,t){clearHighlights();const n=document.getElementById(e);n&&n.classList.add('source-highlight'),t.forEach(e=>{const t=document.getElementById(e);t&&t.classList.add('target-highlight')})}function clearHighlights(){allTokens.forEach(e=>{e.classList.remove('source-highlight'),e.classList.remove('target-highlight')})}</script>"""
|
| 207 |
html_body = f'<div class="interactive-text-container">{"".join(token_spans)}</div>'
|
| 208 |
return f"<html><head>{css}</head><body>{html_body}{js}</body></html>"
|
| 209 |
|
| 210 |
|
| 211 |
# --- Streamlit App ---
|
| 212 |
st.set_page_config(layout="wide", page_title="QCD Text Validator & Inspector", page_icon="π¬")
|
| 213 |
-
st.title("QCD Text Validator & Inspector")
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
|
| 218 |
# Load model and tokenizer, keeping the UI clean
|
| 219 |
if 'model' not in st.session_state:
|
|
@@ -224,32 +232,35 @@ if 'model' not in st.session_state:
|
|
| 224 |
st.session_state.tokenizer, st.session_state.model = tokenizer, model
|
| 225 |
st.session_state.model_status = f"β
Model '{MODEL_NAME}' loaded successfully."
|
| 226 |
else:
|
| 227 |
-
st.session_state.model_status = f"β Error loading model: {st.session_state.model_error}"
|
| 228 |
|
| 229 |
tokenizer, model = st.session_state.tokenizer, st.session_state.model
|
| 230 |
|
| 231 |
if model:
|
| 232 |
-
|
| 233 |
-
|
|
|
|
| 234 |
|
| 235 |
-
if st.button("Analyze
|
|
|
|
| 236 |
for key in list(st.session_state.keys()):
|
| 237 |
if key not in ['tokenizer', 'model', 'model_status']:
|
| 238 |
del st.session_state[key]
|
|
|
|
| 239 |
if text_to_analyze:
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
| 241 |
analysis_data, full_tokens, attention_matrix, start_idx, end_idx = get_analysis_data(text_to_analyze,
|
| 242 |
SYSTEM_PROMPT,
|
| 243 |
tokenizer, model)
|
|
|
|
| 244 |
if analysis_data:
|
| 245 |
st.session_state.analysis_data = analysis_data
|
| 246 |
-
st.session_state.
|
| 247 |
-
st.session_state.
|
| 248 |
-
|
| 249 |
-
st.session_state.end_index = end_idx
|
| 250 |
-
outlier_indices = get_outlier_indices(analysis_data)
|
| 251 |
-
st.session_state.outlier_indices = outlier_indices
|
| 252 |
-
st.session_state.suspicious_phrases = find_high_perplexity_phrases(analysis_data, outlier_indices)
|
| 253 |
st.session_state.original_text = text_to_analyze
|
| 254 |
st.session_state.analysis_complete = True
|
| 255 |
else:
|
|
@@ -259,38 +270,42 @@ if model:
|
|
| 259 |
|
| 260 |
if st.session_state.get('analysis_complete', False):
|
| 261 |
st.markdown("---")
|
| 262 |
-
st.subheader("π
|
| 263 |
-
st.markdown(
|
| 264 |
-
"Color indicates model surprise. Green is predictable, yellow is less so. Red highlights statistical outliers.")
|
| 265 |
-
colored_text_html = render_colored_text(st.session_state.analysis_data, st.session_state.outlier_indices)
|
| 266 |
-
st.markdown(colored_text_html, unsafe_allow_html=True)
|
| 267 |
-
st.markdown("---")
|
| 268 |
-
|
| 269 |
-
st.subheader("π‘ Interactive Attention")
|
| 270 |
-
st.markdown("Hover over any word to highlight other words it pays strong attention to.")
|
| 271 |
-
start, end = st.session_state.start_index, st.session_state.end_index
|
| 272 |
-
user_tokens, user_attention_matrix = st.session_state.full_tokens[start:end], st.session_state.attention_matrix
|
| 273 |
-
max_attention = float(np.max(user_attention_matrix)) if user_attention_matrix.size > 0 else 0.1
|
| 274 |
-
default_slider_val = min(0.1, max_attention) if max_attention > 0 else 0.1
|
| 275 |
-
attention_threshold = st.slider("Attention Threshold", 0.0, max_attention, default_slider_val, 0.01, "%.2f")
|
| 276 |
-
interactive_html = render_interactive_text(user_tokens, user_attention_matrix, start, attention_threshold)
|
| 277 |
-
components.html(interactive_html, height=200, scrolling=True)
|
| 278 |
-
st.markdown("---")
|
| 279 |
|
| 280 |
if st.session_state.suspicious_phrases:
|
| 281 |
-
st.warning("
|
| 282 |
-
for phrase in st.session_state.suspicious_phrases:
|
| 283 |
-
|
| 284 |
-
with st.spinner("Performing focused deep dive..."):
|
| 285 |
-
st.session_state.deep_dive_result = run_focused_deep_dive(st.session_state.original_text,
|
| 286 |
-
st.session_state.suspicious_phrases,
|
| 287 |
-
tokenizer, model)
|
| 288 |
else:
|
| 289 |
-
st.
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
st.
|
| 293 |
-
st.markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
# Display model status at the bottom
|
| 296 |
with st.expander("System Status", expanded=False):
|
|
|
|
| 6 |
import math
|
| 7 |
import os
|
| 8 |
import streamlit.components.v1 as components
|
| 9 |
+
from supabase_py import create_client, Client
|
| 10 |
|
| 11 |
+
# --- Supabase & Model Constants ---
|
| 12 |
+
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
|
| 13 |
+
SYSTEM_PROMPT = """You are a precision analysis tool for physics statements. Your task is to identify the single, specific word that makes a statement factually incorrect according to the Standard Model of particle physics, particularly QCD.
|
| 14 |
|
| 15 |
+
Follow these rules strictly:
|
| 16 |
+
1. **Analyze at the word level:** Scrutinize each word. If a statement is mostly correct but is invalidated by one word, you must identify that specific word.
|
| 17 |
+
2. **Handle Correct Statements:** If the statement is entirely, unambiguously correct, the incorrect word is "None". Do not flag typos if the meaning is clear.
|
| 18 |
+
3. **Prioritize the Core Error:** If a statement contains multiple errors, identify the word that introduces the most fundamental factual error."""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# --- Database Functions ---
|
| 22 |
+
def init_connection():
    """Create the Supabase client from the app's Streamlit secrets.

    Reads ``SUPABASE_URL`` and ``SUPABASE_KEY`` from ``st.secrets`` and
    passes them to ``create_client``.

    Returns:
        The object returned by ``create_client``, or ``None`` when reading
        the secrets or building the client raises. In the failure case the
        error is shown to the user through ``st.error``.
    """
    try:
        secrets = st.secrets
        client = create_client(secrets["SUPABASE_URL"], secrets["SUPABASE_KEY"])
    except Exception as e:
        # Missing or invalid secrets must not take the whole page down;
        # report the problem and let the caller continue without a client.
        st.error(f"Failed to connect to Supabase. Check your secrets. Error: {e}")
        client = None
    return client
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def insert_query(supabase_client: Client, query_text: str):
    """Record one analysed statement in the ``queries`` table.

    Args:
        supabase_client: Client returned by ``init_connection``; a falsy
            value (e.g. ``None`` after a failed connection) makes this
            call a no-op.
        query_text: The statement text, stored in the ``query_text`` column.

    Returns:
        None. Any exception raised by the insert is reported through
        ``st.error`` instead of propagating.
    """
    if supabase_client:
        # Assumes a table named 'queries' with a column 'query_text'.
        row = {"query_text": query_text}
        try:
            supabase_client.table("queries").insert(row).execute()
        except Exception as e:
            st.error(f"Database error: {e}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# --- Core ML Functions (Cached) ---
|
| 46 |
@st.cache_resource
|
| 47 |
def load_model(model_name):
|
| 48 |
"""Loads the specified model and tokenizer from Hugging Face."""
|
|
|
|
| 49 |
cache_dir = '/tmp/hf_cache'
|
| 50 |
os.environ['HF_HOME'] = cache_dir
|
| 51 |
os.environ['TRANSFORMERS_CACHE'] = cache_dir
|
|
|
|
| 56 |
attn_implementation="eager")
|
| 57 |
return tokenizer, model
|
| 58 |
except Exception as e:
|
|
|
|
| 59 |
st.session_state.model_error = e
|
| 60 |
return None, None
|
| 61 |
|
|
|
|
| 97 |
return list(zip(tokens, sequence_log_probs)), full_tokens, last_layer_attention, start_index, end_index
|
| 98 |
|
| 99 |
|
| 100 |
+
def get_outlier_indices(analysis_data, threshold=-2.5):
    """Identify outlier token indices using Median Absolute Deviation (MAD).

    Each token's log-probability is converted to a modified z-score,
    ``0.6745 * (lp - median) / MAD``, and tokens scoring strictly below
    ``threshold`` are reported.

    Args:
        analysis_data: Sequence of ``(token, log_prob)`` pairs.
        threshold: Modified z-score cutoff. The default is negative, so only
            tokens well below the median log-probability are flagged.

    Returns:
        A 1-D integer NumPy array of positions into ``analysis_data``.
        The array is empty when there are fewer than five tokens or when
        the MAD is zero (no spread to measure against).
    """
    # Always hand back an integer array: a bare ``np.array([])`` is float64,
    # which NumPy refuses to accept as an index array.
    no_outliers = np.array([], dtype=np.intp)

    if not analysis_data or len(analysis_data) < 5:
        return no_outliers

    log_probs = np.array([lp for _, lp in analysis_data])
    median_lp = np.median(log_probs)
    mad = np.median(np.abs(log_probs - median_lp))
    if mad == 0:
        # Avoid dividing by zero when at least half the values are identical.
        return no_outliers

    modified_z_scores = 0.6745 * (log_probs - median_lp) / mad
    return np.flatnonzero(modified_z_scores < threshold)
|
| 111 |
|
|
|
|
| 114 |
"""Groups contiguous outlier tokens into phrases."""
|
| 115 |
if not analysis_data or outlier_indices.size == 0:
|
| 116 |
return []
|
|
|
|
| 117 |
outlier_phrases = []
|
| 118 |
current_phrase = ""
|
| 119 |
for i, (token, _) in enumerate(analysis_data):
|
| 120 |
+
# Corrected: Use ' ' (U+2581) which Qwen uses for spaces
|
| 121 |
+
display_token = token.replace(' ', ' ')
|
| 122 |
if i in outlier_indices:
|
| 123 |
current_phrase += display_token
|
| 124 |
else:
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
def run_focused_deep_dive(original_text, phrases, tokenizer, model):
|
| 134 |
+
"""Generates a deep-dive analysis and cleans the model's output."""
|
| 135 |
cot_system_prompt = "You are a meticulous and rigorous particle physicist..."
|
| 136 |
phrases_str = "\n".join([f"- \"{p}\"" for p in phrases])
|
| 137 |
cot_user_prompt = f"""I have analyzed the following statement:
|
|
|
|
| 146 |
with torch.no_grad():
|
| 147 |
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.95)
|
| 148 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 149 |
+
# Clean potential artifacts from model output
|
| 150 |
+
result = response_text.split("assistant\n")[-1]
|
| 151 |
+
cleaned_result = result.replace('\nC', '\n').lstrip('C').strip()
|
| 152 |
+
return cleaned_result
|
| 153 |
|
| 154 |
|
| 155 |
def get_color(logprob, min_lp, max_lp, scheme='green_yellow'):
    """Map a log-probability onto a hex colour for the given scheme.

    Args:
        logprob: The value to colour.
        min_lp: Lower end of the scale (mapped to the "surprising" colour).
        max_lp: Upper end of the scale (mapped to the "predictable" colour).
        scheme: ``'green_yellow'`` runs from yellow (hue 0.17) at ``min_lp``
            to green (hue 0.33) at ``max_lp``. Any other value selects the
            red (hue 0.0) to yellow (hue 0.17) scale.

    Returns:
        A ``'#rrggbb'`` string, using fixed saturation 0.9 and value 0.95.
    """
    if min_lp >= max_lp:
        # Degenerate scale (single value or empty range): use the fixed
        # end colour of the scheme instead of dividing by zero.
        hue = 0.33 if scheme == 'green_yellow' else 0.0
    else:
        normalized = (logprob - min_lp) / (max_lp - min_lp)
        # Clamp so a value outside [min_lp, max_lp] saturates at the nearest
        # end of the scale rather than drifting into unrelated hues.
        normalized = min(1.0, max(0.0, normalized))
        if scheme == 'green_yellow':
            hue = 0.17 + normalized * (0.33 - 0.17)
        else:
            hue = 0.0 + normalized * 0.17

    red, green, blue = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
    return '#%02x%02x%02x' % (int(red * 255), int(green * 255), int(blue * 255))
|
| 167 |
|
| 168 |
|
| 169 |
def render_colored_text(analysis_data, outlier_indices):
    """Render tokens as colour-coded HTML spans.

    Non-outlier tokens are coloured on the green-yellow scale, normalised
    over the non-outlier log-probabilities. Outlier tokens are coloured on
    the yellow-red scale, normalised over the outlier log-probabilities.
    With no outliers every token uses the green-yellow scale.

    Args:
        analysis_data: Sequence of ``(token, log_prob)`` pairs.
        outlier_indices: Positions into ``analysis_data`` flagged as
            outliers (array-like of integers; may be empty).

    Returns:
        One HTML string: a ``<span>`` per token, with the colour as the
        background and ``exp(-log_prob)`` shown as a "Perplexity" tooltip.
    """
    import html  # function-scope so the file's import block is untouched

    log_probs = np.array([lp for _, lp in analysis_data])

    # Build an explicit boolean mask. Testing ``outlier_indices.any()`` would
    # report "no outliers" when the only flagged position is index 0.
    is_outlier = np.zeros(len(log_probs), dtype=bool)
    is_outlier[np.asarray(outlier_indices).astype(np.intp)] = True

    normal_lps = log_probs[~is_outlier]
    outlier_lps = log_probs[is_outlier]
    min_normal, max_normal = (
        (normal_lps.min(), normal_lps.max()) if normal_lps.size > 0 else (0, 0))
    min_outlier, max_outlier = (
        (outlier_lps.min(), outlier_lps.max()) if outlier_lps.size > 0 else (0, 0))

    html_elements = []
    for (token, logprob), flagged in zip(analysis_data, is_outlier):
        if flagged:
            color = get_color(logprob, min_outlier, max_outlier, 'yellow_red')
        else:
            color = get_color(logprob, min_normal, max_normal, 'green_yellow')

        # U+2581 is the tokenizer's word-boundary marker; it is written as an
        # escape so the literal survives re-encoding.
        # NOTE(review): confirm this matches what the model's tokenizer emits.
        # The token is user-supplied text and the caller renders this string
        # as raw HTML, so it must be escaped.
        display_token = html.escape(token.replace('\u2581', ' '))

        perplexity = math.exp(-logprob) if logprob != 0 else 1
        tooltip = f"Perplexity: {perplexity:.2f}"
        html_elements.append(
            f'<span style="background-color: {color}; padding: 2px 1px; margin: 0px; border-radius: 3px;" title="{tooltip}">{display_token}</span>')

    return "".join(html_elements)
|
| 201 |
|
| 202 |
|
| 203 |
def render_interactive_text(tokens, attention_matrix, start_index, threshold):
    """Generate interactive HTML that highlights attention targets on hover.

    Token ``i`` of ``tokens`` corresponds to row/column ``start_index + i``
    of ``attention_matrix``. Two tokens are linked when the larger of the
    two directed scores between them is strictly greater than ``threshold``.
    A token is never linked to itself.

    Args:
        tokens: Token strings to display.
        attention_matrix: 2-D array-like of attention scores, indexed by
            absolute token position.
        start_index: Absolute position of ``tokens[0]`` in the matrix.
        threshold: Minimum score (exclusive) for a link to be highlighted.

    Returns:
        A complete HTML document (CSS, one ``<span>`` per token, and the
        hover script) as a string.

    Raises:
        IndexError: If a token position falls outside ``attention_matrix``.
    """
    import html  # function-scope so the file's import block is untouched

    css = """<style>.interactive-text-container{line-height:2.0;font-size:1.1em}.token{cursor:pointer;padding:2px 4px;border-radius:4px;transition:background-color .2s ease-in-out}.source-highlight{background-color:#ffd700;color:#000}.target-highlight{background-color:#1e90ff;color:#fff}</style>"""
    js = """<script>const allTokens=document.querySelectorAll('.token');function highlightTargets(e,t){clearHighlights();const n=document.getElementById(e);n&&n.classList.add('source-highlight'),t.forEach(e=>{const t=document.getElementById(e);t&&t.classList.add('target-highlight')})}function clearHighlights(){allTokens.forEach(e=>{e.classList.remove('source-highlight'),e.classList.remove('target-highlight')})}</script>"""

    count = len(tokens)
    if count:
        # Compute every pairwise link in one vectorised pass rather than a
        # nested Python loop over all token pairs.
        positions = np.arange(start_index, start_index + count)
        window = np.asarray(attention_matrix)[np.ix_(positions, positions)]
        linked = np.maximum(window, window.T) > threshold
        np.fill_diagonal(linked, False)  # a token is not its own target
    else:
        linked = np.zeros((0, 0), dtype=bool)

    token_spans = []
    for i, token_text in enumerate(tokens):
        token_id = f"token-{start_index + i}"
        targets = ",".join(
            f"'token-{start_index + int(j)}'" for j in np.flatnonzero(linked[i]))
        # U+2581 is the tokenizer's word-boundary marker; it is written as an
        # escape so the literal survives re-encoding.
        # NOTE(review): confirm this matches what the model's tokenizer emits.
        # Token text originates from user input, so escape it before
        # embedding it in the HTML document.
        display_text = html.escape(token_text.replace('\u2581', ' '))
        token_spans.append(
            f'<span class="token" id="{token_id}" '
            f'onmouseover="highlightTargets(\'{token_id}\',[{targets}])" '
            f'onmouseout="clearHighlights()">{display_text}</span>')

    html_body = f'<div class="interactive-text-container">{"".join(token_spans)}</div>'
    return f"<html><head>{css}</head><body>{html_body}{js}</body></html>"
|
| 218 |
|
| 219 |
|
| 220 |
# --- Streamlit App ---
|
| 221 |
st.set_page_config(layout="wide", page_title="QCD Text Validator & Inspector", page_icon="π¬")
|
|
|
|
| 222 |
|
| 223 |
+
# Initialize Supabase connection
|
| 224 |
+
supabase = init_connection()
|
| 225 |
|
| 226 |
# Load model and tokenizer, keeping the UI clean
|
| 227 |
if 'model' not in st.session_state:
|
|
|
|
| 232 |
st.session_state.tokenizer, st.session_state.model = tokenizer, model
|
| 233 |
st.session_state.model_status = f"β
Model '{MODEL_NAME}' loaded successfully."
|
| 234 |
else:
|
| 235 |
+
st.session_state.model_status = f"β Error loading model: {st.session_state.get('model_error', 'Unknown error')}"
|
| 236 |
|
| 237 |
tokenizer, model = st.session_state.tokenizer, st.session_state.model
|
| 238 |
|
| 239 |
if model:
|
| 240 |
+
st.title("Physics Statement Validator")
|
| 241 |
+
default_text = "The running of the strong coupling of QCD increases with the energy scale."
|
| 242 |
+
text_to_analyze = st.text_area("Enter a physics statement to analyze:", value=default_text, height=100)
|
| 243 |
|
| 244 |
+
if st.button("Analyze Statement", key="analyze_button", type="primary"):
|
| 245 |
+
# Clear previous analysis from session state
|
| 246 |
for key in list(st.session_state.keys()):
|
| 247 |
if key not in ['tokenizer', 'model', 'model_status']:
|
| 248 |
del st.session_state[key]
|
| 249 |
+
|
| 250 |
if text_to_analyze:
|
| 251 |
+
# Insert the query into the Supabase database
|
| 252 |
+
insert_query(supabase, text_to_analyze)
|
| 253 |
+
|
| 254 |
+
with st.spinner("Analyzing statement..."):
|
| 255 |
analysis_data, full_tokens, attention_matrix, start_idx, end_idx = get_analysis_data(text_to_analyze,
|
| 256 |
SYSTEM_PROMPT,
|
| 257 |
tokenizer, model)
|
| 258 |
+
|
| 259 |
if analysis_data:
|
| 260 |
st.session_state.analysis_data = analysis_data
|
| 261 |
+
st.session_state.outlier_indices = get_outlier_indices(analysis_data)
|
| 262 |
+
st.session_state.suspicious_phrases = find_high_perplexity_phrases(analysis_data,
|
| 263 |
+
st.session_state.outlier_indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
st.session_state.original_text = text_to_analyze
|
| 265 |
st.session_state.analysis_complete = True
|
| 266 |
else:
|
|
|
|
| 270 |
|
| 271 |
if st.session_state.get('analysis_complete', False):
|
| 272 |
st.markdown("---")
|
| 273 |
+
st.subheader("π Analysis Result")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
if st.session_state.suspicious_phrases:
|
| 276 |
+
st.warning("The model identified the following word(s) as the most likely error:")
|
| 277 |
+
for phrase in st.session_state.suspicious_phrases:
|
| 278 |
+
st.markdown(f"> **{phrase}**")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
else:
|
| 280 |
+
st.success("β
The model did not find any statistically significant errors in the statement.")
|
| 281 |
+
|
| 282 |
+
with st.expander("Show Detailed Perplexity and Attention Analysis"):
|
| 283 |
+
st.markdown("#### Perplexity Analysis")
|
| 284 |
+
st.markdown(
|
| 285 |
+
"Color indicates model surprise. Green is predictable, yellow is less so. Red highlights statistical outliers.")
|
| 286 |
+
colored_text_html = render_colored_text(st.session_state.analysis_data, st.session_state.outlier_indices)
|
| 287 |
+
st.markdown(colored_text_html, unsafe_allow_html=True)
|
| 288 |
+
|
| 289 |
+
st.markdown("---")
|
| 290 |
+
|
| 291 |
+
st.markdown("#### Interactive Attention")
|
| 292 |
+
st.markdown("Hover over any word to highlight other words it pays strong attention to.")
|
| 293 |
+
st.session_state.full_tokens = st.session_state.get('full_tokens', [])
|
| 294 |
+
st.session_state.attention_matrix = st.session_state.get('attention_matrix', np.array([]))
|
| 295 |
+
st.session_state.start_index = st.session_state.get('start_index', -1)
|
| 296 |
+
st.session_state.end_index = st.session_state.get('end_index', -1)
|
| 297 |
+
|
| 298 |
+
start, end = st.session_state.start_index, st.session_state.end_index
|
| 299 |
+
if start != -1 and end != -1 and st.session_state.attention_matrix.size > 0:
|
| 300 |
+
user_tokens = st.session_state.full_tokens[start:end]
|
| 301 |
+
user_attention_matrix = st.session_state.attention_matrix
|
| 302 |
+
max_attention = float(np.max(user_attention_matrix)) if user_attention_matrix.size > 0 else 0.1
|
| 303 |
+
default_slider_val = min(0.1, max_attention) if max_attention > 0 else 0.1
|
| 304 |
+
attention_threshold = st.slider("Attention Threshold", 0.0, max_attention, default_slider_val, 0.01,
|
| 305 |
+
"%.2f")
|
| 306 |
+
interactive_html = render_interactive_text(user_tokens, user_attention_matrix, start,
|
| 307 |
+
attention_threshold)
|
| 308 |
+
components.html(interactive_html, height=200, scrolling=True)
|
| 309 |
|
| 310 |
# Display model status at the bottom
|
| 311 |
with st.expander("System Status", expanded=False):
|