""" GuardLLM - Interactive Prompt Security Visualizer Combines t-SNE embedding visualization with real-time prompt risk analysis. Powered by Llama Prompt Guard 2 (86M) and neuralchemy/Prompt-injection-dataset. """ import logging import sys import json import traceback import gradio as gr import torch import numpy as np import plotly.graph_objects as go import plotly.io as pio from pathlib import Path # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)], ) logger = logging.getLogger("GuardLLM") # --------------------------------------------------------------------------- # Color palette for categories # --------------------------------------------------------------------------- CATEGORY_COLORS = { "benign": "#22c55e", "direct_injection": "#ef4444", "jailbreak": "#f97316", "system_extraction": "#a855f7", "encoding_obfuscation": "#ec4899", "persona_replacement": "#f59e0b", "indirect_injection": "#e11d48", "token_smuggling": "#7c3aed", "many_shot": "#06b6d4", "crescendo": "#14b8a6", "context_overflow": "#8b5cf6", "prompt_leaking": "#d946ef", "unknown": "#64748b", } CATEGORY_LABELS = { "benign": "Benign", "direct_injection": "Direct Injection", "jailbreak": "Jailbreak", "system_extraction": "System Extraction", "encoding_obfuscation": "Encoding / Obfuscation", "persona_replacement": "Persona Replacement", "indirect_injection": "Indirect Injection", "token_smuggling": "Token Smuggling", "many_shot": "Many-Shot", "crescendo": "Crescendo", "context_overflow": "Context Overflow", "prompt_leaking": "Prompt Leaking", "unknown": "Unknown", } # --------------------------------------------------------------------------- # Lazy-loaded risk classifier (Llama Prompt Guard 2) # 
MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
LABELS = ["Benign", "Malicious"]

# Lazily-populated singleton; the model is only pulled in on first analysis
# so the app starts quickly and works offline until a prompt is classified.
_classifier = {"tokenizer": None, "model": None, "device": None}


def get_classifier():
    """Return ``(tokenizer, model, device)``, loading Prompt Guard 2 on first call.

    Subsequent calls reuse the cached objects in ``_classifier``.
    """
    if _classifier["model"] is not None:
        return _classifier["tokenizer"], _classifier["model"], _classifier["device"]

    logger.info("Lazy-loading Llama Prompt Guard 2...")
    # Imported here (not at module top) so the heavy transformers dependency
    # is only paid for when classification is actually requested.
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    _classifier["tokenizer"] = tokenizer
    _classifier["model"] = model
    _classifier["device"] = device
    logger.info("Classifier loaded on %s", device)
    return _classifier["tokenizer"], _classifier["model"], _classifier["device"]


# ---------------------------------------------------------------------------
# Load precomputed t-SNE data
# ---------------------------------------------------------------------------
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
META_FILE = CACHE_DIR / "metadata.json"

logger.info("Loading precomputed t-SNE cache from %s", CACHE_DIR)
if not CACHE_FILE.exists() or not META_FILE.exists():
    raise RuntimeError(
        "Cache files not found in %s. Run precompute.py first." % CACHE_DIR
    )

_npz = np.load(CACHE_FILE)
# 2-D coordinates produced offline by precompute.py; row i matches METADATA[i].
TSNE_COORDS = _npz["tsne_2d"]
with open(META_FILE, "r", encoding="utf-8") as f:
    METADATA = json.load(f)
logger.info("Loaded %d points for visualization", len(METADATA))

# Column-wise views over the metadata records for fast indexed access.
ALL_TEXTS = [m["text"] for m in METADATA]
ALL_CATEGORIES = [m["category"] for m in METADATA]
ALL_SEVERITIES = [m["severity"] for m in METADATA]
ALL_LABELS_DS = [m["label"] for m in METADATA]
UNIQUE_CATEGORIES = sorted(set(ALL_CATEGORIES))

# Build "index | category | preview" strings for the searchable dropdown.
DROPDOWN_CHOICES = []
for i, m in enumerate(METADATA):
    preview = m["text"][:70].replace("\n", " ")
    if len(m["text"]) > 70:
        preview += "..."
DROPDOWN_CHOICES.append(f"{i} | {m['category']} | {preview}") # --------------------------------------------------------------------------- # Analysis function # --------------------------------------------------------------------------- def analyze_prompt(text): if not text or not text.strip(): return {}, 0.0 tokenizer, model, DEVICE = get_classifier() inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=512, padding=True ).to(DEVICE) with torch.no_grad(): outputs = model(**inputs) probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy() pred_idx = int(np.argmax(probs)) prob_dict = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))} safety = float(probs[0]) return prob_dict, safety # --------------------------------------------------------------------------- # Build the t-SNE Plotly figure # --------------------------------------------------------------------------- def build_tsne_figure(selected_categories=None): fig = go.Figure() for cat in UNIQUE_CATEGORIES: indices = [ i for i, c in enumerate(ALL_CATEGORIES) if c == cat and (selected_categories is None or cat in selected_categories) ] if not indices: continue x = TSNE_COORDS[indices, 0].tolist() y = TSNE_COORDS[indices, 1].tolist() texts_preview = [ ALL_TEXTS[i][:80].replace("\n", " ") + ("..." if len(ALL_TEXTS[i]) > 80 else "") for i in indices ] severities = [ALL_SEVERITIES[i] or "benign" for i in indices] hover_texts = [ f"{CATEGORY_LABELS.get(cat, cat)}
" f"Severity: {sev}
" f"Index: {idx}
" f"{txt}" for idx, txt, sev in zip(indices, texts_preview, severities) ] color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"]) label = CATEGORY_LABELS.get(cat, cat) fig.add_trace(go.Scatter( x=x, y=y, mode="markers", name=label, marker=dict( size=5 if len(indices) > 500 else 7, color=color, opacity=0.7, line=dict(width=0.5, color="rgba(255,255,255,0.2)"), ), text=hover_texts, hoverinfo="text", customdata=[str(i) for i in indices], )) fig.update_layout( template="plotly_dark", paper_bgcolor="#0f172a", plot_bgcolor="#1e293b", title=dict( text="t-SNE Embedding Space - Prompt Security Landscape", font=dict(size=16, color="#e2e8f0"), x=0.5, ), legend=dict( title=dict(text="Category", font=dict(color="#94a3b8")), bgcolor="rgba(15,23,42,0.9)", bordercolor="#334155", borderwidth=1, font=dict(color="#cbd5e1", size=10), itemsizing="constant", ), xaxis=dict( title="t-SNE 1", showgrid=True, gridcolor="#334155", zeroline=False, color="#94a3b8", ), yaxis=dict( title="t-SNE 2", showgrid=True, gridcolor="#334155", zeroline=False, color="#94a3b8", ), margin=dict(l=40, r=40, t=50, b=40), height=600, dragmode="pan", ) return fig # --------------------------------------------------------------------------- # Callbacks # --------------------------------------------------------------------------- def on_filter_change(categories): sel = categories if categories else None return build_tsne_figure(sel) def select_all_categories(): return gr.update(value=UNIQUE_CATEGORIES), build_tsne_figure(UNIQUE_CATEGORIES) def deselect_all_categories(): return gr.update(value=[]), build_tsne_figure([]) def on_dropdown_select(choice): if not choice: return empty_analysis_html(), "*Select a prompt.*", "" try: idx = int(choice.split(" | ")[0]) text = ALL_TEXTS[idx] category = ALL_CATEGORIES[idx] severity = ALL_SEVERITIES[idx] or "N/A" ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign" prob_dict, safety = analyze_prompt(text) pred_label = max(prob_dict, key=prob_dict.get) confidence 
= prob_dict[pred_label] result_html = build_result_html(pred_label, confidence, prob_dict, text) risk_text = build_risk_assessment(pred_label, confidence, prob_dict) risk_text += ( f"\n\n---\n**Dataset metadata:**\n" f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n" f"- Severity: **{severity}**\n" f"- Ground truth: **{ground_truth}**\n" ) return result_html, risk_text, text except Exception as e: logger.error("Error: %s", e) return empty_analysis_html(), f"Error: {e}", "" def on_index_input(idx_str): if not idx_str or not idx_str.strip(): return empty_analysis_html(), "*Click a point on the chart.*", "" try: idx = int(idx_str.strip()) if idx < 0 or idx >= len(ALL_TEXTS): return empty_analysis_html(), f"Invalid index: {idx}", "" text = ALL_TEXTS[idx] category = ALL_CATEGORIES[idx] severity = ALL_SEVERITIES[idx] or "N/A" ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign" prob_dict, safety = analyze_prompt(text) pred_label = max(prob_dict, key=prob_dict.get) confidence = prob_dict[pred_label] result_html = build_result_html(pred_label, confidence, prob_dict, text) risk_text = build_risk_assessment(pred_label, confidence, prob_dict) risk_text += ( f"\n\n---\n**Dataset metadata:**\n" f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n" f"- Severity: **{severity}**\n" f"- Ground truth: **{ground_truth}**\n" ) return result_html, risk_text, text except Exception as e: logger.error("Error: %s", e) return empty_analysis_html(), f"Error: {e}", "" def on_manual_analyze(text): if not text or not text.strip(): return empty_analysis_html(), "" prob_dict, safety = analyze_prompt(text) pred_label = max(prob_dict, key=prob_dict.get) confidence = prob_dict[pred_label] result_html = build_result_html(pred_label, confidence, prob_dict, text) risk_text = build_risk_assessment(pred_label, confidence, prob_dict) return result_html, risk_text # --------------------------------------------------------------------------- # UI builders # 
def empty_analysis_html():
    """Placeholder card shown before any prompt has been analyzed."""
    # NOTE(review): the original inline HTML tags were stripped when this file
    # was mangled; the markup below reconstructs the visible text — confirm
    # styling against the deployed app.
    return """
    <div style="text-align:center; color:#94a3b8; padding:40px 20px;
                border:1px dashed #334155; border-radius:8px;">
        Click a point on the chart,<br>
        select a prompt from the list,<br>
        or enter a custom prompt below.
    </div>
    """


def build_result_html(label, confidence, probs, text):
    """Render the classification verdict card as HTML.

    Args:
        label: predicted class name ("Benign" or "Malicious").
        confidence: probability of the predicted class, in [0, 1].
        probs: mapping of class name -> probability (keys from LABELS).
        text: the analyzed prompt; previewed escaped and truncated to 150 chars.
    """
    color = "#22c55e" if label == "Benign" else "#ef4444"
    emoji = "\u2705" if label == "Benign" else "\u26a0\ufe0f"
    pct = confidence * 100
    safety_score = probs["Benign"] * 100
    safety_color = (
        "#22c55e" if safety_score >= 70
        else "#f59e0b" if safety_score >= 40
        else "#ef4444"
    )

    # One horizontal probability bar per class.
    bars_html = ""
    for lbl in LABELS:
        p = probs[lbl] * 100
        c = "#22c55e" if lbl == "Benign" else "#ef4444"
        bars_html += f"""
        <div style="margin:6px 0;">
            <span style="color:#cbd5e1; font-size:12px;">{lbl}</span>
            <div style="background:#1e293b; border-radius:4px; height:10px; overflow:hidden;">
                <div style="width:{p:.1f}%; background:{c}; height:10px;"></div>
            </div>
            <span style="color:#94a3b8; font-size:11px;">{p:.1f}%</span>
        </div>
        """

    # Escape HTML-significant characters ('&' first) so prompt text cannot
    # inject markup into the card. The original entity escapes were rendered
    # away during mangling, leaving a no-op replace("<", "<"); restored here.
    preview = (
        text[:150]
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    if len(text) > 150:
        preview += "..."

    return f"""
    <div style="border:1px solid #334155; border-radius:8px; padding:16px;
                background:#0f172a;">
        <div style="font-size:28px; text-align:center;">{emoji}</div>
        <div style="text-align:center; color:{color}; font-size:20px; font-weight:bold;">
            {label}
        </div>
        <div style="text-align:center; color:#94a3b8;">Confidence: {pct:.1f}%</div>
        <div style="text-align:center; margin:8px 0;">
            <span style="color:#cbd5e1;">Safety Score</span>
            <span style="color:{safety_color}; font-weight:bold;">{safety_score:.0f}/100</span>
        </div>
        {bars_html}
        <div style="margin-top:10px; color:#94a3b8; font-size:12px;">Analyzed prompt:</div>
        <div style="color:#cbd5e1; font-size:12px; font-style:italic;">"{preview}"</div>
    </div>
    """


def build_risk_assessment(label, confidence, probs):
    """Build the markdown risk-level summary for the right-hand panel.

    Risk level is derived from the predicted class and its confidence:
    Benign + >0.85 => Low; Benign otherwise => Moderate;
    Malicious + >0.85 => Critical; Malicious otherwise => High.
    """
    safety_score = probs["Benign"] * 100
    malicious_score = probs["Malicious"] * 100
    if label == "Benign" and confidence > 0.85:
        level, desc = "Low", "This prompt appears **safe**. No injection or jailbreak patterns detected."
    elif label == "Benign":
        level, desc = "Moderate", "Likely benign, but moderate confidence. Potentially ambiguous wording."
    elif confidence > 0.85:
        level, desc = "Critical", "**Malicious prompt detected** with high confidence. Likely injection or jailbreak attempt."
    else:
        level, desc = "High", "**Malicious prompt detected.** Possible injection or jailbreak. Review recommended."
    return (
        f"### Risk Level: {level}\n\n{desc}\n\n"
        f"**Details:**\n"
        f"- Safety score: **{safety_score:.0f}/100**\n"
        f"- Predicted class: **{label}** ({confidence*100:.1f}%)\n"
        f"- P(Benign) = {probs['Benign']*100:.1f}% | P(Malicious) = {malicious_score:.1f}%\n"
    )


def build_stats_html():
    """Render the dataset statistics panel (totals + per-category breakdown)."""
    total = len(METADATA)
    n_benign = sum(1 for m in METADATA if m["label"] == 0)
    n_malicious = total - n_benign

    cat_counts = {}
    for m in METADATA:
        cat_counts[m["category"]] = cat_counts.get(m["category"], 0) + 1

    # Rows sorted by descending frequency.
    cats_html = ""
    for cat in sorted(cat_counts.keys(), key=lambda c: -cat_counts[c]):
        count = cat_counts[cat]
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        pct = count / total * 100
        label = CATEGORY_LABELS.get(cat, cat)
        cats_html += (
            f'<div style="display:flex; justify-content:space-between; margin:2px 0;">'
            f'<span style="color:{color}; font-size:12px;">{label}</span>'
            f'<span style="color:#94a3b8; font-size:12px;">{count} ({pct:.1f}%)</span>'
            f'</div>'
        )

    # NOTE(review): container markup reconstructed after tag stripping.
    return f"""
    <div style="border:1px solid #334155; border-radius:8px; padding:12px; background:#0f172a;">
        <div style="color:#e2e8f0; font-weight:bold; margin-bottom:8px;">Dataset Statistics</div>
        <div style="display:flex; gap:16px; margin-bottom:8px;">
            <div><span style="color:#94a3b8;">Total</span><br>
                 <span style="color:#e2e8f0; font-weight:bold;">{total:,}</span></div>
            <div><span style="color:#94a3b8;">Benign</span><br>
                 <span style="color:#22c55e; font-weight:bold;">{n_benign:,}</span></div>
            <div><span style="color:#94a3b8;">Malicious</span><br>
                 <span style="color:#ef4444; font-weight:bold;">{n_malicious:,}</span></div>
        </div>
        {cats_html}
    </div>
    """


# ---------------------------------------------------------------------------
# JavaScript to bridge Plotly clicks -> Gradio
# ---------------------------------------------------------------------------
# Writes the clicked point's customdata (a dataset index) into the hidden
# #click-index-input textbox using the native value setter, then fires
# input/change events so Gradio picks up the new value.
PLOTLY_CLICK_JS = """
() => {
    function setupClickHandler() {
        const plotEl = document.querySelector('#tsne-chart .js-plotly-plot');
        if (!plotEl) { setTimeout(setupClickHandler, 500); return; }
        function handleClick(data) {
            if (data && data.points && data.points.length > 0) {
                const idx = data.points[0].customdata;
                if (idx !== undefined && idx !== null) {
                    const inputEl = document.querySelector('#click-index-input textarea')
                        || document.querySelector('#click-index-input input');
                    if (inputEl) {
                        const proto = inputEl.tagName === 'TEXTAREA'
                            ? window.HTMLTextAreaElement.prototype
                            : window.HTMLInputElement.prototype;
                        const nativeSetter = Object.getOwnPropertyDescriptor(proto, 'value').set;
                        nativeSetter.call(inputEl, String(idx));
                        inputEl.dispatchEvent(new Event('input', { bubbles: true }));
                        setTimeout(() => {
                            inputEl.dispatchEvent(new Event('change', { bubbles: true }));
                        }, 50);
                    }
                }
            }
        }
        plotEl.on('plotly_click', handleClick);
        const observer = new MutationObserver(() => {
            const newPlot = document.querySelector('#tsne-chart .js-plotly-plot');
            if (newPlot && !newPlot._hasClickHandler) {
                newPlot._hasClickHandler = true;
                newPlot.on('plotly_click', handleClick);
            }
        });
        observer.observe(document.querySelector('#tsne-chart') || document.body,
                         { childList: true, subtree: true });
    }
    setTimeout(setupClickHandler, 1000);
}
"""

# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
# NOTE(review): header/how-to markup reconstructed after tag stripping;
# visible text preserved verbatim.
TITLE_HTML = """
<div style="text-align:center; padding:12px;">
    <h1 style="color:#e2e8f0; margin-bottom:4px;">GuardLLM - Prompt Security Visualizer</h1>
    <p style="color:#94a3b8;">
        Interactive t-SNE embedding space &bull; Llama Prompt Guard 2 &bull; neuralchemy dataset
    </p>
</div>
"""

HOW_TO_HTML = """
<div style="border:1px solid #334155; border-radius:8px; padding:12px; color:#cbd5e1;">
    <h3 style="color:#e2e8f0;">How to use this tool</h3>
    <p><b>1. Explore the map</b><br>
    Each dot represents a prompt from the dataset, positioned by semantic similarity.
    Colors indicate attack categories. Hover to preview, scroll to zoom, drag to pan.</p>
    <p><b>2. Click to analyze</b><br>
    Click any point to run it through Llama Prompt Guard 2. The right panel will show
    the risk classification, safety score, and confidence breakdown.</p>
    <p><b>3. Test your own prompts</b><br>
    Type or paste any prompt in the Custom prompt field and hit Analyze to check if it
    would be flagged as an injection attempt.</p>
</div>
"""

with gr.Blocks(
    title="GuardLLM - Prompt Security Visualizer",
) as demo:
    gr.HTML(TITLE_HTML)
    gr.HTML(HOW_TO_HTML)

    # Hidden bridge textbox written by PLOTLY_CLICK_JS; it stays visible=True
    # so Gradio wires its events, and is visually hidden via the launch() CSS.
    click_index = gr.Textbox(
        value="",
        visible=True,
        elem_id="click-index-input",
    )

    with gr.Row():
        # ---- Left: t-SNE chart + filters ----
        with gr.Column(scale=3):
            with gr.Row():
                select_all_btn = gr.Button("Select All", size="sm", scale=1)
                deselect_all_btn = gr.Button("Deselect All", size="sm", scale=1)
            category_filter = gr.CheckboxGroup(
                choices=UNIQUE_CATEGORIES,
                value=UNIQUE_CATEGORIES,
                label="Filter by category",
                interactive=True,
            )
            tsne_plot = gr.Plot(
                value=build_tsne_figure(),
                label="t-SNE Space",
                elem_id="tsne-chart",
            )
            gr.Markdown(
                "*Click a point to analyze it. "
                "Hover to preview text. Use scroll wheel to zoom.*"
            )

        # ---- Right: Analysis first, then stats (swapped) ----
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Result")
            result_html = gr.HTML(value=empty_analysis_html())
            risk_md = gr.Markdown(value="")
            full_prompt = gr.Textbox(label="Full prompt", lines=3, interactive=False, visible=True)
            gr.Markdown("---")
            gr.Markdown("### Select a prompt")
            prompt_dropdown = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                label="Search dataset",
                filterable=True,
                interactive=True,
            )
            gr.Markdown("### Or analyze a custom prompt")
            manual_input = gr.Textbox(
                label="Custom prompt",
                placeholder="Type or paste a prompt...",
                lines=2,
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
            gr.Markdown("---")
            gr.HTML(build_stats_html())

    # ---- Events ----
    category_filter.change(
        fn=on_filter_change,
        inputs=[category_filter],
        outputs=[tsne_plot],
    )
    select_all_btn.click(
        fn=select_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    deselect_all_btn.click(
        fn=deselect_all_categories,
        inputs=[],
        outputs=[category_filter, tsne_plot],
    )
    click_index.change(
        fn=on_index_input,
        inputs=[click_index],
        outputs=[result_html, risk_md, full_prompt],
    )
    prompt_dropdown.change(
        fn=on_dropdown_select,
        inputs=[prompt_dropdown],
        outputs=[result_html, risk_md, full_prompt],
    )
    analyze_btn.click(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )
    manual_input.submit(
        fn=on_manual_analyze,
        inputs=[manual_input],
        outputs=[result_html, risk_md],
    )

    # Install the Plotly click -> Gradio bridge once the page loads.
    demo.load(fn=None, inputs=None, outputs=None, js=PLOTLY_CLICK_JS)

    gr.Markdown(
        """
---
GuardLLM - Prompt Security Visualizer

Model: Llama Prompt Guard 2 (86M) by Meta | Dataset: neuralchemy/Prompt-injection-dataset
"""
    )

logger.info("Gradio app built. Ready to launch.")

if __name__ == "__main__":
    demo.launch(css="#click-index-input { position:absolute !important; width:1px !important; height:1px !important; overflow:hidden !important; opacity:0 !important; pointer-events:none !important; }")