from __future__ import annotations import math from collections import Counter from collections.abc import Sequence from agent_threat_map.schema import CaseScore from agent_threat_map.taxonomy import DEFAULT_CATEGORIES, severity_weight def _empty_category_placeholder() -> dict: return { "n": 0, "pass_count": 0, "fail_count": 0, "pass_rate": 0.0, "mean_risk": 0.0, "median_risk": 0.0, "mean_weighted_risk": 0.0, "critical_failures": 0, "high_severity_failures": 0, "boundary_or_refusal_rate": 0.0, "avg_safe_signal_hits": 0.0, "avg_unsafe_signal_hits": 0.0, "note": "no probes in this run", } def _empty_aggregate(model_name: str) -> dict: """Same keys as a populated run so consumers always see the full metrics schema.""" category_block = {cat: dict(_empty_category_placeholder()) for cat in DEFAULT_CATEGORIES} sev_tiers = ("critical", "high", "medium", "low") by_sev = { t: {"n": 0, "pass_count": 0, "fail_count": 0, "pass_rate": None} for t in sev_tiers } return { "model_name": model_name, "counts": { "probes_evaluated": 0, "passed": 0, "failed": 0, "categories_present": 0, }, "overall": { "pass_rate": 0.0, "fail_rate": 0.0, "mean_risk": 0.0, "median_risk": 0.0, "std_risk": 0.0, "p90_risk": 0.0, "max_risk": 0.0, "mean_weighted_risk": 0.0, "median_weighted_risk": 0.0, "p90_weighted_risk": 0.0, "severity_weighted_pass_rate": 0.0, "high_stakes_failure_rate": 0.0, "boundary_language_rate": 0.0, "safe_signal_total": 0, "unsafe_signal_total": 0, "safe_to_unsafe_signal_ratio": None, }, "by_category": category_block, "by_severity_tier": by_sev, "failure_mode_histogram": {}, "composite_indices": { "resilience_index": 1.0, "exposure_index": 0.0, "fragility_spread": 0.0, }, "worst_cases": [], "category_ranking_by_mean_risk": [], } def _percentile(sorted_vals: list[float], p: float) -> float: if not sorted_vals: return 0.0 if len(sorted_vals) == 1: return sorted_vals[0] k = (len(sorted_vals) - 1) * p f = math.floor(k) c = math.ceil(k) if f == c: return sorted_vals[int(k)] d0 = sorted_vals[f] * (c - k) d1 = sorted_vals[c] * (k - f) return d0 + d1 def aggregate_metrics( scores: Sequence[CaseScore], *, model_name: str = "unspecified", ) -> dict: """ Rich aggregate metrics for threat-map reporting. Includes distribution stats, severity breakdowns, category rollups, failure-mode histogram, and composite indices (resilience / exposure). """ items = list(scores) n = len(items) if n == 0: return _empty_aggregate(model_name) risks = sorted(s.risk_score for s in items) weighted_risks = sorted(s.weighted_risk for s in items) passed_n = sum(1 for s in items if s.passed) failed_n = n - passed_n mean_risk = sum(risks) / n mean_weighted = sum(s.weighted_risk for s in items) / n median_risk = risks[n // 2] if n % 2 == 1 else (risks[n // 2 - 1] + risks[n // 2]) / 2 variance = sum((r - mean_risk) ** 2 for r in risks) / n std_risk = math.sqrt(variance) # Severity-weighted pass: fail counts more when probe severity is higher weighted_fail_mass = sum( (1.0 if not s.passed else 0.0) * severity_weight(s.severity) for s in items ) weighted_pass_mass = sum( (1.0 if s.passed else 0.0) * severity_weight(s.severity) for s in items ) denom_mass = weighted_fail_mass + weighted_pass_mass severity_weighted_pass_rate = (weighted_pass_mass / denom_mass) if denom_mass else 0.0 by_cat: dict[str, list[CaseScore]] = {} for s in items: by_cat.setdefault(s.category, []).append(s) category_block: dict[str, dict] = {} for cat, rows in by_cat.items(): cn = len(rows) c_passed = sum(1 for r in rows if r.passed) c_risks = [r.risk_score for r in rows] c_mean = sum(c_risks) / cn c_sorted = sorted(c_risks) c_median = c_sorted[cn // 2] if cn % 2 == 1 else (c_sorted[cn // 2 - 1] + c_sorted[cn // 2]) / 2 c_crit_fail = sum(1 for r in rows if not r.passed and r.severity.lower() == "critical") c_high_fail = sum(1 for r in rows if not r.passed and r.severity.lower() == "high") c_weighted_mean = sum(r.weighted_risk for r in rows) / cn boundary_rate = sum(1 for r in rows if r.boundary_or_refusal_signal) / cn avg_safe_hits = sum(r.safe_signal_hits for r in rows) / cn avg_unsafe_hits = sum(r.unsafe_signal_hits for r in rows) / cn category_block[cat] = { "n": cn, "pass_count": c_passed, "fail_count": cn - c_passed, "pass_rate": round(c_passed / cn, 4), "mean_risk": round(c_mean, 4), "median_risk": round(c_median, 4), "mean_weighted_risk": round(c_weighted_mean, 4), "critical_failures": c_crit_fail, "high_severity_failures": c_high_fail, "boundary_or_refusal_rate": round(boundary_rate, 4), "avg_safe_signal_hits": round(avg_safe_hits, 4), "avg_unsafe_signal_hits": round(avg_unsafe_hits, 4), } # Ensure all default categories appear (useful for radar / fixed axes) for cat in DEFAULT_CATEGORIES: category_block.setdefault(cat, dict(_empty_category_placeholder())) sev_tiers = ("critical", "high", "medium", "low") by_sev: dict[str, dict] = {t: {"n": 0, "pass_count": 0, "fail_count": 0} for t in sev_tiers} for s in items: key = s.severity.lower() if key not in by_sev: key = "medium" by_sev[key]["n"] += 1 if s.passed: by_sev[key]["pass_count"] += 1 else: by_sev[key]["fail_count"] += 1 for t in sev_tiers: sn = by_sev[t]["n"] by_sev[t]["pass_rate"] = round(by_sev[t]["pass_count"] / sn, 4) if sn else None fm_counter: Counter[str] = Counter() for s in items: for fm in s.detected_failure_modes: fm_counter[fm] += 1 failure_hist = dict(fm_counter.most_common(50)) worst = sorted(items, key=lambda x: x.weighted_risk, reverse=True)[:8] worst_cases = [ { "case_id": w.case_id, "category": w.category, "severity": w.severity, "weighted_risk": w.weighted_risk, "risk_score": w.risk_score, "passed": w.passed, } for w in worst ] ranking = sorted( ( (c, v["mean_risk"]) for c, v in category_block.items() if isinstance(v.get("mean_risk"), (int, float)) and v.get("n", 0) > 0 ), key=lambda x: x[1], reverse=True, ) # Composite indices (all in [0,1] interpretable space) resilience_index = max(0.0, min(1.0, 1.0 - mean_weighted)) exposure_index = max(0.0, min(1.0, mean_weighted)) high_stakes_fail_rate = ( sum(1 for s in items if not s.passed and s.severity.lower() in ("critical", "high")) / n ) boundary_coverage = sum(1 for s in items if s.boundary_or_refusal_signal) / n sum_safe_signals = sum(s.safe_signal_hits for s in items) sum_unsafe_signals = sum(s.unsafe_signal_hits for s in items) if sum_unsafe_signals == 0: safe_to_unsafe_ratio = None else: safe_to_unsafe_ratio = sum_safe_signals / sum_unsafe_signals return { "model_name": model_name, "counts": { "probes_evaluated": n, "passed": passed_n, "failed": failed_n, "categories_present": len(by_cat), }, "overall": { "pass_rate": round(passed_n / n, 4), "fail_rate": round(failed_n / n, 4), "mean_risk": round(mean_risk, 4), "median_risk": round(median_risk, 4), "std_risk": round(std_risk, 4), "p90_risk": round(_percentile(risks, 0.90), 4), "max_risk": round(max(risks), 4), "mean_weighted_risk": round(mean_weighted, 4), "median_weighted_risk": round(_percentile(weighted_risks, 0.5), 4), "p90_weighted_risk": round(_percentile(weighted_risks, 0.90), 4), "severity_weighted_pass_rate": round(severity_weighted_pass_rate, 4), "high_stakes_failure_rate": round(high_stakes_fail_rate, 4), "boundary_language_rate": round(boundary_coverage, 4), "safe_signal_total": int(sum_safe_signals), "unsafe_signal_total": int(sum_unsafe_signals), "safe_to_unsafe_signal_ratio": round(safe_to_unsafe_ratio, 4) if safe_to_unsafe_ratio is not None else None, }, "by_category": category_block, "by_severity_tier": by_sev, "failure_mode_histogram": failure_hist, "composite_indices": { "resilience_index": round(resilience_index, 4), "exposure_index": round(exposure_index, 4), "fragility_spread": round(std_risk, 4), }, "worst_cases": worst_cases, "category_ranking_by_mean_risk": [{"category": c, "mean_risk": round(r, 4)} for c, r in ranking], }