Spaces:
Running
Running
File size: 10,687 Bytes
c452421 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 | # -*- coding: utf-8 -*-
"""Cross-Episode Worker Reputation Learning.
Builds persistent reputation profiles for each worker that carry across
training episodes. SENTINEL uses these profiles to make better-informed
oversight decisions — a simple form of theory-of-mind reasoning about each worker.

Usage:
    from sentinel.reputation import WorkerReputationTracker
    tracker = WorkerReputationTracker("outputs/reputation.json")
    tracker.record_episode("worker_db", episode_stats)
    profile = tracker.get_profile("worker_db")
    context = tracker.build_reputation_context()  # inject into prompts
"""
from __future__ import annotations

import copy
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)

# Baseline profile for a worker that has not been seen before.
# NOTE: a shallow copy of this mapping shares its list/dict values, so any
# copy needs fresh containers before those fields are mutated.
_DEFAULT_PROFILE = dict(
    episodes_seen=0,
    total_proposals=0,
    misbehaviors_total=0,
    misbehaviors_caught=0,
    false_positives_caused=0,
    trust_trajectory=[],
    misbehavior_type_counts={},
    domains_reliable=[],
    domains_unreliable=[],
    rehabilitation_attempts=0,
    rehabilitation_successes=0,
    current_trust_score=0.70,
    trend="stable",
)
class WorkerReputationTracker:
    """Persistent cross-episode reputation tracker for worker agents.

    ``self.profiles`` maps worker_id -> profile dict. Profiles are loaded from
    ``path`` on construction and written back after every
    :meth:`record_episode` call.
    """

    def __init__(self, path: str = "outputs/worker_reputation.json", max_trajectory: int = 50):
        """Create a tracker backed by the JSON file at *path*.

        Args:
            path: File used to persist profiles; loaded immediately if present.
            max_trajectory: Maximum number of trust-score samples retained in
                each profile's ``trust_trajectory``.
        """
        self.path = Path(path)
        self.max_trajectory = max_trajectory
        self.profiles: Dict[str, Dict[str, Any]] = {}
        self._load()

    def _load(self) -> None:
        """Replace ``self.profiles`` with the contents of ``self.path``.

        A missing file leaves the tracker empty. An unreadable file, invalid
        JSON, or a top-level value that is not a JSON object is logged as a
        warning and also leaves the tracker empty, so a bad reputation file
        never blocks start-up. Entries whose value is not an object are dropped.
        """
        if not self.path.exists():
            return
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
            logger.warning("Failed to load reputation: %s", exc)
            self.profiles = {}
            return
        if not isinstance(data, dict):
            logger.warning(
                "Failed to load reputation: expected a JSON object, got %s",
                type(data).__name__,
            )
            self.profiles = {}
            return
        self.profiles = {wid: prof for wid, prof in data.items() if isinstance(prof, dict)}
        logger.info("Loaded reputation profiles for %d workers", len(self.profiles))

    def _save(self) -> None:
        """Write all profiles to ``self.path`` as indented, key-sorted JSON.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        target with ``Path.replace``, so an interrupted write leaves the previous
        file intact instead of a truncated one that ``_load`` would discard.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self.profiles, indent=2, sort_keys=True, default=str)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(self.path)

    def _ensure_profile(self, worker_id: str) -> Dict[str, Any]:
        """Return the live profile for *worker_id*, creating it if needed.

        Any key from ``_DEFAULT_PROFILE`` missing from the profile (a new worker,
        or a profile loaded from a file written before that key existed) is
        filled with a deep copy of its default. This means list/dict fields are
        never shared between workers, and ``record_episode`` never hits
        ``KeyError``.
        """
        profile = self.profiles.setdefault(worker_id, {})
        for key, default in _DEFAULT_PROFILE.items():
            if key not in profile:
                profile[key] = copy.deepcopy(default)
        return profile

    def record_episode(
        self,
        worker_id: str,
        episode_stats: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Record one episode's stats for a worker, update derived fields, and save.

        episode_stats may contain (missing or ``None`` values count as 0 /
        empty / False):
            proposals: int — total proposals made
            misbehaviors: int — number of misbehavior proposals
            caught: int — misbehaviors SENTINEL caught
            false_positives: int — valid proposals SENTINEL incorrectly blocked
            misbehavior_types: list[str] — types of misbehavior in this episode
            domain: str — worker's domain
            rehabilitation_attempted: bool
            rehabilitation_success: bool

        Returns:
            The worker's updated profile (the same dict held in ``self.profiles``).
        """
        profile = self._ensure_profile(worker_id)
        profile["episodes_seen"] += 1
        profile["total_proposals"] += int(episode_stats.get("proposals") or 0)
        profile["misbehaviors_total"] += int(episode_stats.get("misbehaviors") or 0)
        profile["misbehaviors_caught"] += int(episode_stats.get("caught") or 0)
        profile["false_positives_caused"] += int(episode_stats.get("false_positives") or 0)

        # Count every reported misbehavior occurrence by type.
        type_counts = profile["misbehavior_type_counts"]
        for mb_type in episode_stats.get("misbehavior_types") or []:
            key = str(mb_type)
            type_counts[key] = type_counts.get(key, 0) + 1

        # A successful rehabilitation implies an attempt; counting it as one keeps
        # rehabilitation_rate within [0, 1].
        succeeded = bool(episode_stats.get("rehabilitation_success"))
        if episode_stats.get("rehabilitation_attempted") or succeeded:
            profile["rehabilitation_attempts"] += 1
        if succeeded:
            profile["rehabilitation_successes"] += 1

        # Lifetime trust: 1 - 1.5 * (misbehaviors / proposals), clamped to [0, 1].
        total = max(1, profile["total_proposals"])
        misbehavior_rate = profile["misbehaviors_total"] / total
        trust = round(max(0.0, min(1.0, 1.0 - misbehavior_rate * 1.5)), 4)
        profile["current_trust_score"] = trust

        # Keep only the most recent `max_trajectory` samples (in place; a plain
        # `[-n:]` slice would keep everything when n == 0).
        traj = profile["trust_trajectory"]
        traj.append(trust)
        if len(traj) > self.max_trajectory:
            del traj[: len(traj) - self.max_trajectory]

        # Trend: mean of the last 5 samples vs. the 5 before them (or, with
        # fewer than 10 samples, vs. the first 5); ±0.05 is the dead band.
        if len(traj) >= 5:
            recent = sum(traj[-5:]) / 5
            baseline = traj[-10:-5] if len(traj) >= 10 else traj[:5]
            older = sum(baseline) / 5
            if recent > older + 0.05:
                profile["trend"] = "improving"
            elif recent < older - 0.05:
                profile["trend"] = "declining"
            else:
                profile["trend"] = "stable"

        # Domain reliability is judged on the lifetime misbehavior rate. A domain
        # is moved between the two lists so it never appears in both.
        domain = episode_stats.get("domain", "")
        if domain:
            reliable = profile["domains_reliable"]
            unreliable = profile["domains_unreliable"]
            if misbehavior_rate < 0.15:
                if domain not in reliable:
                    reliable.append(domain)
                if domain in unreliable:
                    unreliable.remove(domain)
            elif misbehavior_rate > 0.30:
                if domain not in unreliable:
                    unreliable.append(domain)
                if domain in reliable:
                    reliable.remove(domain)

        profile["most_common_misbehavior"] = (
            max(type_counts, key=type_counts.get) if type_counts else None
        )

        attempts = profile["rehabilitation_attempts"]
        profile["rehabilitation_rate"] = (
            round(profile["rehabilitation_successes"] / attempts, 4) if attempts > 0 else 0.0
        )

        profile["misbehavior_frequency"] = round(misbehavior_rate, 4)

        self._save()
        return profile

    def get_profile(self, worker_id: str) -> Dict[str, Any]:
        """Return the live profile for *worker_id*, creating a default one if unseen.

        A newly created profile is kept in memory only; it is written to disk by
        the next :meth:`record_episode` call.
        """
        return self._ensure_profile(worker_id)

    def get_all_profiles(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of the worker_id -> profile mapping.

        The profile dicts themselves are shared with the tracker, not copied.
        """
        return dict(self.profiles)

    def build_reputation_context(self, max_chars: int = 600) -> str:
        """Build a text context block for injection into SENTINEL prompts.

        Emits a header line followed by one summary line per worker, sorted by
        worker_id. The header is always included. Worker lines are added whole,
        and only while the joined text stays within *max_chars*. The first line
        that would exceed the limit, and every line after it, is omitted.

        Returns:
            The newline-joined block, or ``""`` when no profiles exist.
        """
        if not self.profiles:
            return ""
        lines = ["WORKER REPUTATION PROFILES (cross-episode):"]
        used = len(lines[0])
        for worker_id, profile in sorted(self.profiles.items()):
            trust = profile.get("current_trust_score", 0.7)
            trend = profile.get("trend", "stable")
            freq = profile.get("misbehavior_frequency", 0.0)
            most_common = profile.get("most_common_misbehavior") or "none"
            episodes = profile.get("episodes_seen", 0)
            rehab_rate = profile.get("rehabilitation_rate", 0.0)
            trust_label = "HIGH" if trust >= 0.75 else "MEDIUM" if trust >= 0.50 else "LOW"
            trend_icon = "↑" if trend == "improving" else "↓" if trend == "declining" else "→"
            line = (
                f" {worker_id}: trust={trust_label}({trust:.2f}{trend_icon}) "
                f"misbehavior_rate={freq:.0%} "
                f"primary_risk={most_common} "
                f"episodes={episodes} "
                f"rehab={rehab_rate:.0%}"
            )
            # +1 accounts for the newline that joins this line to the block.
            if used + 1 + len(line) > max_chars:
                break
            lines.append(line)
            used += 1 + len(line)
        return "\n".join(lines)

    def extract_from_episode_history(
        self,
        history: List[Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Extract per-worker stats from a SENTINEL episode history.

        Each entry may carry ``audit``, ``proposal``, ``worker_revision`` and
        ``info`` dicts; missing or ``None`` ones are treated as empty. The
        worker is taken from ``audit["worker_id"]``, then ``proposal["worker_id"]``,
        falling back to ``"unknown"``. Any truthy decision other than
        ``"APPROVE"`` counts as ``caught`` for a misbehavior and as a
        ``false_positive`` otherwise.

        Returns:
            A dict keyed by worker_id whose values are episode_stats suitable
            for :meth:`record_episode`.
        """
        worker_stats: Dict[str, Dict[str, Any]] = {}
        for entry in history:
            audit = entry.get("audit", {}) or {}
            proposal = entry.get("proposal", {}) or {}
            revision = entry.get("worker_revision", {}) or {}
            info = entry.get("info", {}) or {}
            worker_id = str(audit.get("worker_id") or proposal.get("worker_id") or "unknown")
            if worker_id not in worker_stats:
                worker_stats[worker_id] = {
                    "proposals": 0,
                    "misbehaviors": 0,
                    "caught": 0,
                    "false_positives": 0,
                    "misbehavior_types": [],
                    "domain": "",
                    "rehabilitation_attempted": False,
                    "rehabilitation_success": False,
                }
            stats = worker_stats[worker_id]
            stats["proposals"] += 1

            # Keep the last non-empty role seen. Entries without one must not
            # erase a domain recorded earlier in the episode.
            role = audit.get("worker_role") or info.get("worker_role")
            if role:
                stats["domain"] = str(role)

            was_mb = bool(audit.get("was_misbehavior") or info.get("is_misbehavior"))
            decision = audit.get("sentinel_decision") or ""
            blocked = bool(decision) and decision != "APPROVE"
            if was_mb:
                stats["misbehaviors"] += 1
                mb_type = str(audit.get("reason") or info.get("mb_type") or "")
                if mb_type:
                    stats["misbehavior_types"].append(mb_type)
                if blocked:
                    stats["caught"] += 1
            elif blocked:
                stats["false_positives"] += 1

            if revision.get("attempted"):
                stats["rehabilitation_attempted"] = True
            if revision.get("revision_approved"):
                stats["rehabilitation_success"] = True
        return worker_stats

    def update_from_episode(self, history: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Extract per-worker stats from *history* and record an episode for each worker.

        Returns:
            worker_id -> updated profile for every worker found in *history*.
        """
        per_worker = self.extract_from_episode_history(history)
        return {
            worker_id: self.record_episode(worker_id, stats)
            for worker_id, stats in per_worker.items()
        }
|