Spaces:
Sleeping
Sleeping
Pablo
feat: V6.0 — TokenDance Master-Mirror storage, JCR Safety Gate (INV-15), AITER ROCm config. 15/15 PASS
d9c2197 | """JCR Safety Gate — protects judge-type agents from KV-reuse drift. | |
| Based on arXiv:2601.08343 (Jan 2026): "When KV Cache Reuse Fails in | |
| Multi-Agent Systems." | |
| The paper shows that aggressive KV-cache reuse can silently degrade the | |
| Judge Consistency Rate (JCR) of judge-type agents (Critic, evaluator) | |
| even when raw accuracy looks unchanged. The Critic in our 5-agent | |
| pipeline is especially vulnerable because it jointly compares multiple | |
| candidates: shuffling the candidate order or reusing KV blocks across | |
| candidates can flip the verdict. | |
| INV-15 | |
| ====== | |
| The Critic agent (role == "critic") MUST use dense prefill — bypassing | |
| the shared KV cache — whenever the JCR risk score exceeds the threshold | |
| (default 0.7). This is enforced unconditionally inside both | |
| should_use_dense_prefill() and gate_decision(). | |
| """ | |
| from __future__ import annotations | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
# Agent roles treated as "judge-type"; only these callers are protected
# by the gate.
JUDGE_ROLES = frozenset(("critic",))

# Risk score above which a judge role is forced onto dense prefill.
DEFAULT_JCR_THRESHOLD = 0.7

# Risk-model constants (from arXiv:2601.08343 Sec. 4 table 2).
# Starting risk, chosen by whether the role is in JUDGE_ROLES.
_BASE_RISK_JUDGE = 0.6
_BASE_RISK_OTHER = 0.1
# Additive penalties applied on top of the starting risk.
_RISK_PER_EXTRA_CANDIDATE = 0.1  # added once per candidate beyond 2
_RISK_LAYOUT_SHUFFLED = 0.2      # added when the candidate order changed
_RISK_HIGH_REUSE = 0.15          # added when reuse_rate > _HIGH_REUSE_THRESHOLD
_HIGH_REUSE_THRESHOLD = 0.8
@dataclass
class JCRDecision:
    """A single gate decision, captured for telemetry / dashboard.

    The ``@dataclass`` decorator is required: the class relies on
    ``field(default_factory=...)`` for ``timestamp`` and is constructed
    with keyword arguments, neither of which works on a plain class.

    Attributes:
        agent_role: Role name the decision was made for.
        risk_score: JCR risk score computed for the step.
        use_dense: True when dense prefill was selected instead of KV reuse.
        reason: Human-readable explanation of the outcome.
        timestamp: ``time.time()`` value taken when the record is created,
            unless the caller supplies one explicitly.
    """

    agent_role: str
    risk_score: float
    use_dense: bool
    reason: str
    # default_factory so each record gets its own creation time rather than
    # a single value frozen at class-definition time.
    timestamp: float = field(default_factory=time.time)
class JCRSafetyGate:
    """Decide when a judge-type agent must bypass KV-cache reuse.

    Each request is scored by compute_jcr_risk(). When the caller's role is
    in JUDGE_ROLES and the score is strictly greater than ``jcr_threshold``,
    the gate demands dense prefill (INV-15). The same rule is evaluated by
    should_use_dense_prefill() and by gate_decision(); only the latter
    appends its result to ``gate_log``.
    """

    def __init__(self, jcr_threshold: float = DEFAULT_JCR_THRESHOLD):
        """Create a gate with an empty decision log.

        Args:
            jcr_threshold: Risk level a judge role must exceed before dense
                prefill is mandated.

        Raises:
            ValueError: If ``jcr_threshold`` is not within [0, 1].
        """
        threshold_ok = 0.0 <= jcr_threshold <= 1.0
        if not threshold_ok:
            raise ValueError(
                f"jcr_threshold must be in [0, 1]; got {jcr_threshold}"
            )
        self.gate_log: list[JCRDecision] = []
        self.jcr_threshold: float = jcr_threshold

    # ---- Risk scoring ------------------------------------------------- #

    def compute_jcr_risk(
        self,
        agent_role: str,
        candidate_count: int,
        reuse_rate: float,
        layout_shuffled: bool,
    ) -> float:
        """Score how risky KV reuse is for the upcoming agent step.

        The score starts from a per-role base value and is increased for
        each candidate beyond two, for a shuffled layout, and for a reuse
        rate above _HIGH_REUSE_THRESHOLD. The result is clamped to
        [0.0, 1.0].

        Raises:
            ValueError: If ``candidate_count`` is negative or ``reuse_rate``
                is outside [0, 1].
        """
        # Validate in this order so callers see the same error first.
        if candidate_count < 0:
            raise ValueError("candidate_count must be non-negative")
        if not (0.0 <= reuse_rate <= 1.0):
            raise ValueError("reuse_rate must be in [0, 1]")

        normalized = (agent_role or "").lower()
        if normalized in JUDGE_ROLES:
            score = _BASE_RISK_JUDGE
        else:
            score = _BASE_RISK_OTHER

        # Penalties are added one at a time, in a fixed order, so the
        # floating-point result is reproducible.
        if candidate_count > 2:
            score += _RISK_PER_EXTRA_CANDIDATE * (candidate_count - 2)
        if layout_shuffled:
            score += _RISK_LAYOUT_SHUFFLED
        if reuse_rate > _HIGH_REUSE_THRESHOLD:
            score += _RISK_HIGH_REUSE

        return min(1.0, max(0.0, score))

    # ---- Gate decision (INV-15 enforcement) --------------------------- #

    def should_use_dense_prefill(
        self,
        agent_role: str,
        candidate_count: int,
        reuse_rate: float,
        layout_shuffled: bool,
    ) -> bool:
        """INV-15 check: True only for a judge role whose risk is too high.

        Roles outside JUDGE_ROLES always get False. The risk score is still
        computed for them, so invalid inputs raise ValueError regardless of
        role. Nothing is written to ``gate_log``.
        """
        score = self.compute_jcr_risk(
            agent_role, candidate_count, reuse_rate, layout_shuffled
        )
        if (agent_role or "").lower() not in JUDGE_ROLES:
            return False
        return bool(score > self.jcr_threshold)

    def gate_decision(
        self,
        agent_role: str,
        candidate_count: int,
        reuse_rate: float,
        layout_shuffled: bool,
    ) -> JCRDecision:
        """Evaluate the gate, record the outcome in ``gate_log``, return it.

        The returned JCRDecision carries the lower-cased role, the risk
        score, the dense/reuse verdict and a textual reason.
        """
        score = self.compute_jcr_risk(
            agent_role, candidate_count, reuse_rate, layout_shuffled
        )
        normalized = (agent_role or "").lower()
        limit = self.jcr_threshold
        judge = normalized in JUDGE_ROLES
        dense = judge and score > limit

        if dense:
            # dense implies the role is a judge role.
            why = (
                f"INV-15: judge role={normalized!r} risk={score:.2f} > "
                f"threshold={limit:.2f} → dense prefill mandated"
            )
        elif judge:
            why = (
                f"judge role={normalized!r} risk={score:.2f} ≤ "
                f"threshold={limit:.2f} → reuse permitted"
            )
        else:
            why = f"role={normalized!r} not judge-type → reuse OK"

        record = JCRDecision(
            agent_role=normalized,
            risk_score=score,
            use_dense=dense,
            reason=why,
        )
        self.gate_log.append(record)
        return record

    # ---- Telemetry ---------------------------------------------------- #

    def summary(self) -> dict[str, float | int]:
        """Return aggregate statistics over every logged decision.

        Keys: ``total_decisions``, ``dense_fallback_count``,
        ``avg_risk_score`` and ``critic_dense_rate``. All values are zero
        when nothing has been logged; ``critic_dense_rate`` is also 0.0
        when no entry has the role "critic".
        """
        entries = self.gate_log
        if not entries:
            return {
                "total_decisions": 0,
                "dense_fallback_count": 0,
                "avg_risk_score": 0.0,
                "critic_dense_rate": 0.0,
            }

        total = len(entries)
        dense_total = len([entry for entry in entries if entry.use_dense])
        mean_risk = sum(entry.risk_score for entry in entries) / total

        critic_entries = [
            entry for entry in entries if entry.agent_role == "critic"
        ]
        if critic_entries:
            critic_dense = len(
                [entry for entry in critic_entries if entry.use_dense]
            )
            critic_rate = critic_dense / len(critic_entries)
        else:
            critic_rate = 0.0

        return {
            "total_decisions": total,
            "dense_fallback_count": dense_total,
            "avg_risk_score": mean_risk,
            "critic_dense_rate": critic_rate,
        }

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        """Render the threshold together with the current summary stats."""
        stats = self.summary()
        pieces = (
            f"threshold={self.jcr_threshold:.2f}, ",
            f"decisions={stats['total_decisions']}, ",
            f"dense={stats['dense_fallback_count']}, ",
            f"avg_risk={stats['avg_risk_score']:.2f}, ",
            f"critic_dense_rate={stats['critic_dense_rate']:.2f}",
        )
        return "JCRSafetyGate(" + "".join(pieces) + ")"