"""
GuardLLM - Interactive Prompt Security Visualizer
Combines t-SNE embedding visualization with real-time prompt risk analysis.
Powered by Llama Prompt Guard 2 (86M) and neuralchemy/Prompt-injection-dataset.
"""
import logging
import sys
import json
import traceback
import gradio as gr
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from pathlib import Path
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Root logger config: INFO level, timestamped lines, explicitly routed to
# stdout so hosted runtimes (e.g. Spaces/containers) capture the output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
# Named app logger used by the rest of the module.
logger = logging.getLogger("GuardLLM")
# ---------------------------------------------------------------------------
# Color palette for categories
# ---------------------------------------------------------------------------
# Hex marker color per attack-category key. "unknown" doubles as the fallback
# color for any category missing from this map (see build_tsne_figure).
CATEGORY_COLORS = {
    "benign": "#22c55e",
    "direct_injection": "#ef4444",
    "jailbreak": "#f97316",
    "system_extraction": "#a855f7",
    "encoding_obfuscation": "#ec4899",
    "persona_replacement": "#f59e0b",
    "indirect_injection": "#e11d48",
    "token_smuggling": "#7c3aed",
    "many_shot": "#06b6d4",
    "crescendo": "#14b8a6",
    "context_overflow": "#8b5cf6",
    "prompt_leaking": "#d946ef",
    "unknown": "#64748b",
}
# Human-readable display labels (legend, hover, markdown) for each category key.
CATEGORY_LABELS = {
    "benign": "Benign",
    "direct_injection": "Direct Injection",
    "jailbreak": "Jailbreak",
    "system_extraction": "System Extraction",
    "encoding_obfuscation": "Encoding / Obfuscation",
    "persona_replacement": "Persona Replacement",
    "indirect_injection": "Indirect Injection",
    "token_smuggling": "Token Smuggling",
    "many_shot": "Many-Shot",
    "crescendo": "Crescendo",
    "context_overflow": "Context Overflow",
    "prompt_leaking": "Prompt Leaking",
    "unknown": "Unknown",
}
# ---------------------------------------------------------------------------
# Lazy-loaded risk classifier (Llama Prompt Guard 2)
# ---------------------------------------------------------------------------
# Hugging Face model id for the 86M-parameter Llama Prompt Guard 2 classifier.
MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
# Classifier head order: index 0 = Benign, index 1 = Malicious (see analyze_prompt).
LABELS = ["Benign", "Malicious"]
# Lazy singleton cache; populated once by get_classifier() on first use.
_classifier = {"tokenizer": None, "model": None, "device": None}
def get_classifier():
    """Return ``(tokenizer, model, device)``, loading the model on first call.

    The heavyweight transformers import and weight download happen lazily so
    module import stays cheap; results are cached in the module-level
    ``_classifier`` dict and reused on every subsequent call.
    """
    if _classifier["model"] is None:
        logger.info("Lazy-loading Llama Prompt Guard 2...")
        # Deferred import: transformers is only needed once an analysis runs.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        model.eval()  # inference only — disable dropout etc.
        model.to(device)
        _classifier["tokenizer"] = tokenizer
        _classifier["model"] = model
        _classifier["device"] = device
        logger.info("Classifier loaded on %s", device)
    return _classifier["tokenizer"], _classifier["model"], _classifier["device"]
# ---------------------------------------------------------------------------
# Load precomputed t-SNE data
# ---------------------------------------------------------------------------
# Precomputed artifacts produced by precompute.py: 2-D t-SNE coordinates
# (embeddings_tsne.npz, key "tsne_2d") plus per-point metadata (metadata.json).
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
META_FILE = CACHE_DIR / "metadata.json"
logger.info("Loading precomputed t-SNE cache from %s", CACHE_DIR)
# Fail fast at import time: the visualization is unusable without the cache.
if not CACHE_FILE.exists() or not META_FILE.exists():
    raise RuntimeError(
        "Cache files not found in %s. Run precompute.py first." % CACHE_DIR
    )
_npz = np.load(CACHE_FILE)
TSNE_COORDS = _npz["tsne_2d"]  # assumed shape (n_points, 2) — row i matches METADATA[i]
with open(META_FILE, "r", encoding="utf-8") as f:
    METADATA = json.load(f)
logger.info("Loaded %d points for visualization", len(METADATA))
# Column-wise views over the metadata records; indices align with TSNE_COORDS rows.
ALL_TEXTS = [m["text"] for m in METADATA]
ALL_CATEGORIES = [m["category"] for m in METADATA]
ALL_SEVERITIES = [m["severity"] for m in METADATA]
ALL_LABELS_DS = [m["label"] for m in METADATA]
UNIQUE_CATEGORIES = sorted(set(ALL_CATEGORIES))
# Dropdown entries formatted "<index> | <category> | <70-char preview>";
# on_dropdown_select parses the index back out of the first field.
DROPDOWN_CHOICES = []
for i, m in enumerate(METADATA):
    preview = m["text"][:70].replace("\n", " ")
    if len(m["text"]) > 70:
        preview += "..."
    DROPDOWN_CHOICES.append(f"{i} | {m['category']} | {preview}")
# ---------------------------------------------------------------------------
# Analysis function
# ---------------------------------------------------------------------------
def analyze_prompt(text):
    """Classify *text* with Prompt Guard 2.

    Args:
        text: The prompt to score. ``None``/empty/whitespace-only input is
            handled without loading the model.

    Returns:
        ``(prob_dict, safety)`` where ``prob_dict`` maps each label in
        ``LABELS`` to its softmax probability and ``safety`` is the Benign
        probability (index 0). Blank input yields ``({}, 0.0)``.
    """
    # Guard clause: skip model loading entirely for empty input.
    if not text or not text.strip():
        return {}, 0.0
    tokenizer, model, device = get_classifier()
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    # (Removed unused pred_idx/np.argmax — callers derive the top label from
    # prob_dict themselves.)
    prob_dict = {label: float(p) for label, p in zip(LABELS, probs)}
    safety = float(probs[0])  # index 0 corresponds to "Benign" in LABELS
    return prob_dict, safety
# ---------------------------------------------------------------------------
# Build the t-SNE Plotly figure
# ---------------------------------------------------------------------------
def build_tsne_figure(selected_categories=None):
    """Build the t-SNE scatter figure, one Plotly trace per category.

    Args:
        selected_categories: Iterable of category keys to display, or
            ``None`` to display every category. An empty iterable yields an
            empty (trace-less) figure.

    Returns:
        A dark-themed ``plotly.graph_objects.Figure`` with pan drag mode.
    """
    fig = go.Figure()
    for cat in UNIQUE_CATEGORIES:
        indices = [
            i for i, c in enumerate(ALL_CATEGORIES)
            if c == cat
            and (selected_categories is None or cat in selected_categories)
        ]
        if not indices:
            continue
        x = TSNE_COORDS[indices, 0].tolist()
        y = TSNE_COORDS[indices, 1].tolist()
        texts_preview = [
            ALL_TEXTS[i][:80].replace("\n", " ") + ("..." if len(ALL_TEXTS[i]) > 80 else "")
            for i in indices
        ]
        severities = [ALL_SEVERITIES[i] or "benign" for i in indices]
        # Plotly hover text uses HTML <br> for line breaks (not "\n").
        # The original source had the string literals broken across physical
        # lines, which is a SyntaxError in Python.
        hover_texts = [
            f"{CATEGORY_LABELS.get(cat, cat)}<br>"
            f"Severity: {sev}<br>"
            f"Index: {idx}<br>"
            f"{txt}"
            for idx, txt, sev in zip(indices, texts_preview, severities)
        ]
        color = CATEGORY_COLORS.get(cat, CATEGORY_COLORS["unknown"])
        label = CATEGORY_LABELS.get(cat, cat)
        fig.add_trace(go.Scatter(
            x=x, y=y,
            mode="markers",
            name=label,
            marker=dict(
                # Shrink markers for dense categories so the cloud stays legible.
                size=5 if len(indices) > 500 else 7,
                color=color,
                opacity=0.7,
                line=dict(width=0.5, color="rgba(255,255,255,0.2)"),
            ),
            text=hover_texts,
            hoverinfo="text",
            # Dataset index per point, recoverable from click events.
            customdata=[str(i) for i in indices],
        ))
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="#0f172a",
        plot_bgcolor="#1e293b",
        title=dict(
            text="t-SNE Embedding Space - Prompt Security Landscape",
            font=dict(size=16, color="#e2e8f0"),
            x=0.5,
        ),
        legend=dict(
            title=dict(text="Category", font=dict(color="#94a3b8")),
            bgcolor="rgba(15,23,42,0.9)",
            bordercolor="#334155",
            borderwidth=1,
            font=dict(color="#cbd5e1", size=10),
            itemsizing="constant",
        ),
        xaxis=dict(
            title="t-SNE 1", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        yaxis=dict(
            title="t-SNE 2", showgrid=True, gridcolor="#334155",
            zeroline=False, color="#94a3b8",
        ),
        margin=dict(l=40, r=40, t=50, b=40),
        height=600,
        dragmode="pan",
    )
    return fig
# ---------------------------------------------------------------------------
# Callbacks
# ---------------------------------------------------------------------------
def on_filter_change(categories):
    """Redraw the t-SNE figure for the currently checked categories.

    A falsy selection (None/empty) falls back to showing every category.
    """
    if categories:
        return build_tsne_figure(categories)
    return build_tsne_figure(None)
def select_all_categories():
    """Check every category in the filter widget and redraw the full chart."""
    every_cat = UNIQUE_CATEGORIES
    return gr.update(value=every_cat), build_tsne_figure(every_cat)
def deselect_all_categories():
    """Clear the category filter, leaving an empty chart."""
    nothing = []
    return gr.update(value=nothing), build_tsne_figure(nothing)
def on_dropdown_select(choice):
    """Analyze a dataset prompt picked from the dropdown.

    Dropdown entries are formatted "<index> | <category> | <preview>", so the
    dataset index is parsed from the first " | " field.

    Returns:
        ``(result_html, risk_markdown, prompt_text)``; on error, an empty
        analysis panel plus the error message.
    """
    if not choice:
        return empty_analysis_html(), "*Select a prompt.*", ""
    try:
        idx = int(choice.split(" | ")[0])
        text = ALL_TEXTS[idx]
        category = ALL_CATEGORIES[idx]
        severity = ALL_SEVERITIES[idx] or "N/A"
        ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign"
        prob_dict, _ = analyze_prompt(text)  # safety score not needed here
        pred_label = max(prob_dict, key=prob_dict.get)
        confidence = prob_dict[pred_label]
        result_html = build_result_html(pred_label, confidence, prob_dict, text)
        risk_text = build_risk_assessment(pred_label, confidence, prob_dict)
        risk_text += (
            f"\n\n---\n**Dataset metadata:**\n"
            f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n"
            f"- Severity: **{severity}**\n"
            f"- Ground truth: **{ground_truth}**\n"
        )
        return result_html, risk_text, text
    except Exception as e:
        # logger.exception keeps the traceback; logger.error("Error: %s", e) lost it.
        logger.exception("Error: %s", e)
        return empty_analysis_html(), f"Error: {e}", ""
def on_index_input(idx_str):
    """Analyze the dataset prompt at a user-entered (or chart-clicked) index.

    Args:
        idx_str: String form of the dataset index; validated against the
            loaded dataset bounds.

    Returns:
        ``(result_html, risk_markdown, prompt_text)``; on invalid input, an
        empty analysis panel plus an explanatory message.
    """
    if not idx_str or not idx_str.strip():
        return empty_analysis_html(), "*Click a point on the chart.*", ""
    try:
        idx = int(idx_str.strip())
        if idx < 0 or idx >= len(ALL_TEXTS):
            return empty_analysis_html(), f"Invalid index: {idx}", ""
        text = ALL_TEXTS[idx]
        category = ALL_CATEGORIES[idx]
        severity = ALL_SEVERITIES[idx] or "N/A"
        ground_truth = "Malicious" if ALL_LABELS_DS[idx] == 1 else "Benign"
        prob_dict, _ = analyze_prompt(text)  # safety score not needed here
        pred_label = max(prob_dict, key=prob_dict.get)
        confidence = prob_dict[pred_label]
        result_html = build_result_html(pred_label, confidence, prob_dict, text)
        risk_text = build_risk_assessment(pred_label, confidence, prob_dict)
        risk_text += (
            f"\n\n---\n**Dataset metadata:**\n"
            f"- Category: **{CATEGORY_LABELS.get(category, category)}**\n"
            f"- Severity: **{severity}**\n"
            f"- Ground truth: **{ground_truth}**\n"
        )
        return result_html, risk_text, text
    except Exception as e:
        # logger.exception keeps the traceback; logger.error("Error: %s", e) lost it.
        logger.exception("Error: %s", e)
        return empty_analysis_html(), f"Error: {e}", ""
def on_manual_analyze(text):
    """Classify a user-typed prompt and format the result panels.

    Returns ``(result_html, risk_markdown)``; blank input yields the empty
    placeholder panel and an empty string.
    """
    if not (text and text.strip()):
        return empty_analysis_html(), ""
    prob_dict, _safety = analyze_prompt(text)
    top_label = max(prob_dict, key=prob_dict.get)
    top_prob = prob_dict[top_label]
    html = build_result_html(top_label, top_prob, prob_dict, text)
    assessment = build_risk_assessment(top_label, top_prob, prob_dict)
    return html, assessment
# ---------------------------------------------------------------------------
# UI builders
# ---------------------------------------------------------------------------
def empty_analysis_html():
return """
Click a point on the chart,
select a prompt from the list,
or enter a custom prompt below.
Interactive t-SNE embedding space • Llama Prompt Guard 2 • neuralchemy dataset