# agent-threat-map / ui / visualizations.py
# obversarystudios's picture
# Threat-map metrics + observable geometry (embed/cluster/MI)
# 6c3043e verified
from __future__ import annotations
import io
from typing import Any
import matplotlib
matplotlib.use("Agg")
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def category_scores_dataframe(by_category: dict[str, Any]) -> pd.DataFrame:
    """Flatten per-category metric blocks into one DataFrame row per category.

    Categories are emitted in sorted key order. Entries whose value is not a
    dict are skipped, as are entries that have ``n == 0`` together with a
    truthy ``"note"``.

    Args:
        by_category: Mapping of category name to a metrics dict. Missing
            metric keys fall back to ``0`` / ``0.0``.

    Returns:
        A DataFrame with the columns ``category``, ``n``, ``pass_rate``,
        ``mean_risk``, ``mean_weighted_risk``, ``boundary_rate`` and
        ``critical_failures``. The columns are present even when no row
        qualifies, so callers can select them without checking ``df.empty``.
    """
    columns = [
        "category",
        "n",
        "pass_rate",
        "mean_risk",
        "mean_weighted_risk",
        "boundary_rate",
        "critical_failures",
    ]
    rows = []
    for cat, block in sorted(by_category.items()):
        if not isinstance(block, dict):
            continue
        if block.get("n", 0) == 0 and block.get("note"):
            continue
        rows.append(
            {
                "category": cat,
                "n": block.get("n", 0),
                "pass_rate": block.get("pass_rate", 0.0),
                "mean_risk": block.get("mean_risk", 0.0),
                "mean_weighted_risk": block.get("mean_weighted_risk", 0.0),
                # The source key is longer than the column name used downstream.
                "boundary_rate": block.get("boundary_or_refusal_rate", 0.0),
                "critical_failures": block.get("critical_failures", 0),
            }
        )
    # Passing ``columns`` keeps the schema stable for an empty ``rows`` list.
    return pd.DataFrame(rows, columns=columns)
def metrics_summary_markdown(metrics: dict) -> str:
    """Build the Markdown run summary shown above the charts.

    Reads the ``"overall"``, ``"counts"`` and ``"composite_indices"`` sections
    of ``metrics``. A section that is missing or ``None`` is treated as empty,
    and any missing value is rendered as an em dash.

    Args:
        metrics: Metrics dict for one run.

    Returns:
        A Markdown string with a "Run summary" and a "Composite indices"
        section, lines joined with ``"\\n"``.
    """
    dash = "\u2014"  # em dash placeholder for values that are not available
    overall = metrics.get("overall") or {}
    counts = metrics.get("counts") or {}
    composite = metrics.get("composite_indices") or {}

    # A missing or None ratio is reported as "no unsafe hits" rather than
    # as a dash.
    ratio = overall.get("safe_to_unsafe_signal_ratio")
    ratio_text = "n/a (no unsafe hits)" if ratio is None else ratio

    lines = [
        "### Run summary",
        f"- **Probes:** {counts.get('probes_evaluated', 0)} "
        f"(passed {counts.get('passed', dash)}, failed {counts.get('failed', dash)})",
        f"- **Pass rate:** {overall.get('pass_rate', dash)}",
        f"- **Severity-weighted pass rate:** {overall.get('severity_weighted_pass_rate', dash)}",
        f"- **Mean / median / P90 risk:** {overall.get('mean_risk', dash)} / "
        f"{overall.get('median_risk', dash)} / {overall.get('p90_risk', dash)}",
        f"- **Mean weighted risk:** {overall.get('mean_weighted_risk', dash)}",
        f"- **High-stakes failure rate:** {overall.get('high_stakes_failure_rate', dash)}",
        f"- **Boundary-language rate:** {overall.get('boundary_language_rate', dash)}",
        f"- **Safe:unsafe signal ratio:** "
        f"{ratio_text} "
        f"(totals {overall.get('safe_signal_total', dash)} / {overall.get('unsafe_signal_total', dash)})",
        "",
        "### Composite indices",
        f"- **Resilience index** (higher is better): {composite.get('resilience_index', dash)}",
        f"- **Exposure index** (higher is worse): {composite.get('exposure_index', dash)}",
        f"- **Fragility spread** (risk std dev): {composite.get('fragility_spread', dash)}",
    ]
    return "\n".join(lines)
def severity_table_markdown(by_sev: dict[str, Any]) -> str:
    """Render the per-severity breakdown as a Markdown table.

    Tiers are listed in the iteration order of ``by_sev``. Tiers with
    ``n == 0`` are omitted, and entries whose value is not a dict are skipped
    instead of raising.

    Args:
        by_sev: Mapping of severity tier name to a metrics dict.

    Returns:
        A Markdown table with one row per tier, or an italic placeholder
        sentence when no tier has any probes. A missing pass rate is shown as
        an em dash.
    """
    dash = "\u2014"  # em dash placeholder for a missing pass rate
    rows = []
    for tier, block in by_sev.items():
        if not isinstance(block, dict):
            continue
        n = block.get("n", 0)
        if n == 0:
            continue
        rows.append(
            f"| {tier} | {n} | {block.get('pass_count', 0)} | "
            f"{block.get('fail_count', 0)} | {block.get('pass_rate', dash)} |"
        )
    if not rows:
        return "_No severity breakdown (empty run)._"
    header = "| Tier | n | Passed | Failed | Pass rate |\n| --- | ---: | ---: | ---: | --- |"
    return header + "\n" + "\n".join(rows)
def plot_category_risk_bars(by_category: dict[str, Any]) -> np.ndarray:
    """Render per-category mean risk and pass rate as overlaid bars.

    The categories come from ``category_scores_dataframe``. When it yields no
    rows, a "No category data" placeholder is drawn instead of bars.

    Args:
        by_category: Mapping of category name to a metrics dict.

    Returns:
        The chart as an image array, decoded from an in-memory PNG.
    """
    df = category_scores_dataframe(by_category)
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        if df.empty:
            ax.text(0.5, 0.5, "No category data", ha="center", va="center")
        else:
            x = np.arange(len(df))
            # Pass rate is drawn second and more transparent so the risk bar
            # underneath stays visible.
            ax.bar(x, df["mean_risk"], color="#c0392b", alpha=0.85, label="Mean risk")
            ax.bar(x, df["pass_rate"], color="#27ae60", alpha=0.35, label="Pass rate")
            ax.set_ylim(0, 1.05)
            ax.set_ylabel("Score (0–1)")
            ax.set_xticks(x, list(df["category"]), rotation=35, ha="right")
            ax.legend(loc="upper right")
        ax.set_title("Category mean risk vs pass rate (overlay)")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the figure; pyplot keeps figures registered until
        # they are closed, even if rendering raised.
        plt.close(fig)
    buf.seek(0)
    return mpimg.imread(buf)
def plot_composite_radar(metrics: dict) -> np.ndarray:
    """Render a radar-style polygon of mean risk per category.

    One spoke is drawn for every entry of ``metrics["by_category"]`` that is a
    dict with a non-zero ``"n"``. Categories are taken in sorted key order and
    underscores in their names are displayed as spaces. With fewer than three
    usable categories a placeholder message is drawn instead of a polygon.

    Args:
        metrics: Metrics dict for one run; only ``"by_category"`` is read.

    Returns:
        The chart as an image array, decoded from an in-memory PNG.
    """
    by_cat = metrics.get("by_category") or {}
    labels: list[str] = []
    values: list[float] = []
    for cat in sorted(by_cat):
        block = by_cat[cat]
        if not isinstance(block, dict) or block.get("n", 0) == 0:
            continue
        labels.append(cat.replace("_", " "))
        values.append(float(block.get("mean_risk", 0.0)))

    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    try:
        if len(values) < 3:
            fig.text(0.5, 0.5, "Need \u22653 categories\nwith probes", ha="center", va="center")
        else:
            # Evenly spaced spokes over a full turn. ``endpoint=False`` keeps
            # the last spoke from landing on top of the first one.
            angles = np.linspace(0.0, 2.0 * np.pi, len(values), endpoint=False).tolist()
            # Repeat the first point so the outline closes.
            closed_angles = angles + angles[:1]
            closed_values = values + values[:1]
            ax.plot(closed_angles, closed_values, color="#8e44ad", linewidth=2)
            ax.fill(closed_angles, closed_values, color="#8e44ad", alpha=0.2)
            ax.set_xticks(angles)
            ax.set_xticklabels(labels, size=8)
            ax.set_ylim(0, 1)
        ax.set_title("Mean risk by category (radar)")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the figure, even if rendering raised.
        plt.close(fig)
    buf.seek(0)
    return mpimg.imread(buf)
# Categorical colours used by the MI bar chart and the cluster scatter plot.
_PALETTE_THREAT = [
    "#4C78A8",
    "#F58518",
    "#54A24B",
    "#E45756",
    "#72B7B2",
    "#B279A2",
]
def observability_markdown(obs: dict[str, Any]) -> str:
    """Build the Markdown summary for the observable-geometry section.

    Args:
        obs: Observability result dict. When ``"eligible"`` is falsy only the
            ``"message"`` is shown. Otherwise ``"n_cases"``,
            ``"n_clusters_used"``, ``"mutual_information"`` and
            ``"interpretation"`` are read.

    Returns:
        A Markdown string. Missing mutual-information scores are rendered as
        an em dash.
    """
    if not obs.get("eligible"):
        return f"### Observable geometry\n\n_{obs.get('message', 'Not eligible')}_"

    dash = "\u2014"  # em dash placeholder for a missing score
    mi = obs.get("mutual_information") or {}
    return "\n".join(
        [
            # \u2192 is a right arrow, \u00b7 a middle dot.
            "### Observable geometry (embed \u2192 cluster \u2192 MI)",
            f"- **Cases:** {obs.get('n_cases')} \u00b7 **Distinct clusters:** {obs.get('n_clusters_used')}",
            f"- **MI(cluster, category):** `{mi.get('MI(cluster, category)', dash)}`",
            f"- **MI(cluster, severity):** `{mi.get('MI(cluster, severity)', dash)}`",
            f"- **MI(cluster, pass_fail):** `{mi.get('MI(cluster, pass_fail)', dash)}`",
            "",
            str(obs.get("interpretation", "")),
        ]
    )
def plot_mi_threat_bars(mi_scores: dict[str, float]) -> np.ndarray:
    """Render mutual-information scores as a labelled bar chart.

    Each key of ``mi_scores`` becomes one bar with its value printed above it
    to three decimals. An empty mapping yields a "No MI scores" placeholder.

    Args:
        mi_scores: Mapping of score label to mutual-information value.

    Returns:
        The chart as an image array, decoded from an in-memory PNG.
    """
    labels = list(mi_scores)
    values = list(mi_scores.values())
    fig, ax = plt.subplots(figsize=(7.5, 3.8))
    try:
        if not labels:
            ax.text(0.5, 0.5, "No MI scores", ha="center", va="center")
        else:
            # The 0.01 floor keeps the y-range and label offset non-zero when
            # every score is 0.
            max_val = max(values + [0.01])
            # Wrap around the palette explicitly so there is one colour per
            # bar even with more bars than palette entries.
            colors = [_PALETTE_THREAT[i % len(_PALETTE_THREAT)] for i in range(len(labels))]
            bars = ax.bar(labels, values, color=colors, width=0.55, zorder=2)
            ax.set_ylim(0, max_val * 1.35)
            ax.set_ylabel("Mutual information (nats)", fontsize=10)
            ax.set_title("Threat case clusters \u00b7 mutual information", fontsize=11, pad=10)
            ax.grid(axis="y", linestyle="--", alpha=0.4, zorder=1)
            ax.tick_params(axis="x", labelsize=8)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=12, ha="right")
            for bar, value in zip(bars, values):
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + max_val * 0.03,
                    f"{value:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=9,
                    fontweight="bold",
                )
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the figure, even if rendering raised.
        plt.close(fig)
    buf.seek(0)
    return mpimg.imread(buf)
def plot_threat_cluster_scatter(case_clusters: list[dict[str, Any]]) -> np.ndarray:
    """Render threat cases as a 2-D scatter coloured by category.

    Each row is drawn at (``scatter_x``, ``scatter_y``), coloured by its
    ``category`` and labelled with its ``cluster_id``. Missing coordinates
    and cluster ids default to 0. An empty list yields a "No points"
    placeholder.

    Args:
        case_clusters: One dict per case with the keys described above.

    Returns:
        The chart as an image array, decoded from an in-memory PNG.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        if not case_clusters:
            ax.text(0.5, 0.5, "No points", ha="center", va="center", transform=ax.transAxes)
        else:
            cats = [str(r.get("category", "")) for r in case_clusters]
            # Colours follow sorted category order, so they do not depend on
            # the order of the rows.
            color_map = {
                c: _PALETTE_THREAT[i % len(_PALETTE_THREAT)]
                for i, c in enumerate(sorted(set(cats)))
            }
            # One legend entry per category, in order of first appearance.
            legend_handles: dict[str, Any] = {}
            for row in case_clusters:
                x = float(row.get("scatter_x", 0.0))
                y = float(row.get("scatter_y", 0.0))
                cid = int(row.get("cluster_id", 0))
                cat = str(row.get("category", ""))
                col = color_map.get(cat, "#888888")
                ax.scatter(x, y, c=col, s=72, alpha=0.85, edgecolors="white", linewidths=0.5, zorder=3)
                ax.text(x, y, str(cid), fontsize=7, ha="center", va="center", color="white", zorder=4)
                if cat not in legend_handles:
                    legend_handles[cat] = plt.Line2D(
                        [0],
                        [0],
                        marker="o",
                        color="w",
                        markerfacecolor=col,
                        markersize=8,
                        label=cat,
                    )
            ax.set_xlabel("SVD component 1", fontsize=9)
            ax.set_ylabel("SVD component 2", fontsize=9)
            ax.set_title("Threat scores in embedding space (colour = category, label = cluster)", fontsize=10)
            # At least one row was processed, so there is always a handle.
            ax.legend(handles=list(legend_handles.values()), fontsize=8, loc="best")
            ax.grid(linestyle="--", alpha=0.3, zorder=1)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the figure, even if a malformed row made rendering
        # raise.
        plt.close(fig)
    buf.seek(0)
    return mpimg.imread(buf)