# sre-incident-env / server / grader.py
# aamrinder's picture
# Upload server/grader.py with huggingface_hub
# edf8eb5 verified
"""Deterministic grading for SRE incident response tasks."""
from __future__ import annotations
from typing import Any, Dict, List
# Reference step count per task id; the efficiency score compares the number
# of steps an episode actually took against this baseline.
OPTIMAL_STEPS = {
    "easy": 5,
    "medium": 10,
    "hard": 15,
}
def grade_task(task_id: str, cluster_snapshot: Dict[str, Any], action_history: List[Dict]) -> Dict[str, Any]:
    """Score a finished episode for the given task.

    Dispatches to the grader registered for ``task_id`` ("easy", "medium"
    or "hard") and returns that grader's result, a dict of the form
    ``{"reward": float, "metadata": dict}``.

    An unrecognised ``task_id`` does not raise: it yields a fixed reward of
    0.5 with the problem described under ``metadata["error"]``.
    """
    graders = {
        "easy": _grade_easy,
        "medium": _grade_medium,
        "hard": _grade_hard,
    }
    if task_id in graders:
        return graders[task_id](cluster_snapshot, action_history)
    return {"reward": 0.5, "metadata": {"error": f"Unknown task: {task_id}"}}
def _efficiency_score(task_id: str, steps_taken: int) -> float:
    """Return an efficiency score that shrinks as more steps are used.

    The score is ``optimal / (steps + optimal)``, where ``optimal`` is the
    ``OPTIMAL_STEPS`` entry for ``task_id`` (10 when the task is unknown).
    Taking exactly the optimal number of steps therefore scores 0.5.

    ``steps_taken`` is floored at 1 so the result stays strictly inside
    (0, 1): an unclamped 0 would produce exactly 1.0, and a negative count
    could exceed 1.0 or divide by zero.
    """
    optimal = OPTIMAL_STEPS.get(task_id, 10)
    # Floor at one step to keep the documented open-interval guarantee.
    steps = max(steps_taken, 1)
    return optimal / (steps + optimal)
def _weighted_score(results: List[Dict]) -> float:
    """Return the weighted mean of the ``"score"`` values in ``results``.

    Each entry must provide ``"weight"``; ``"score"`` is only read once the
    weights are known to have a non-zero sum. When the weights sum to zero
    (which includes an empty list) the fixed value 0.5 is returned instead
    of dividing by zero.
    """
    weights = [entry["weight"] for entry in results]
    denominator = sum(weights)
    if denominator == 0:
        return 0.5
    numerator = sum(entry["score"] * entry["weight"] for entry in results)
    return numerator / denominator
# ═══════════════════════════════════════════════════════════════════
# EASY — Memory Leak in API Gateway
# ═══════════════════════════════════════════════════════════════════
def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    """Grade the easy scenario (memory leak in ``api-gateway``).

    Five weighted checks are combined into a single reward:

    * 0.2 - a diagnostic command was run against ``api-gateway``
    * 0.3 - ``api-gateway`` was restarted
    * 0.2 - ``api-gateway`` reports ``"healthy"`` in the final snapshot
    * 0.1 - nothing other than ``api-gateway`` / ``frontend`` was restarted
    * 0.2 - step efficiency as computed by ``_efficiency_score``

    History entries are read via ``entry["command"]`` and, when the command
    matches, ``entry["target"]``. The step count is
    ``snapshot["step_count"]``, falling back to ``len(history)``.

    Returns ``{"reward": float, "metadata": {"evaluations": [...]}}`` with
    the reward rounded to four decimals.
    """
    target_service = "api-gateway"
    cluster = snapshot.get("services", {})
    step_total = snapshot.get("step_count", len(history))
    evaluations = []

    def record(label, passed, weight):
        # Pass/fail checks earn either full credit or none.
        evaluations.append({"name": label, "score": 1.0 if passed else 0.0, "weight": weight})

    # 1. The root-cause service was inspected with a diagnostic command.
    inspected = False
    for action in history:
        if (
            action["command"] in ("check_logs", "get_metrics", "check_processes")
            and action["target"] == target_service
        ):
            inspected = True
            break
    record("Investigated api-gateway", inspected, 0.2)

    # 2. The root-cause service was restarted.
    was_restarted = False
    for action in history:
        if action["command"] == "restart_service" and action["target"] == target_service:
            was_restarted = True
            break
    record("Restarted api-gateway", was_restarted, 0.3)

    # 3. The root-cause service ends the episode healthy.
    gateway_state = cluster.get(target_service, {})
    record("api-gateway is healthy", gateway_state.get("status") == "healthy", 0.2)

    # 4. No restarts outside api-gateway / frontend (every entry is counted).
    stray_restarts = 0
    for action in history:
        if action["command"] != "restart_service":
            continue
        if action["target"] not in (target_service, "frontend"):
            stray_restarts += 1
    record("No unnecessary restarts", stray_restarts == 0, 0.1)

    # 5. Efficiency — fewer steps score higher (see _efficiency_score).
    efficiency = _efficiency_score("easy", step_total)
    evaluations.append({"name": "Resolution efficiency", "score": round(efficiency, 4), "weight": 0.2})

    reward = round(_weighted_score(evaluations), 4)
    return {"reward": reward, "metadata": {"evaluations": evaluations}}
# ═══════════════════════════════════════════════════════════════════
# MEDIUM — Cascading Database Failure
# ═══════════════════════════════════════════════════════════════════
def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    """Grade the medium scenario (cascading database failure).

    Six weighted checks are combined into a single reward:

    * 0.10 - ``check_dependencies`` was run at least once
    * 0.15 - logs or metrics of ``postgres-primary`` were inspected
    * 0.20 - ``postgres-primary`` was restarted or reconfigured AND is
      ``"healthy"`` in the final snapshot
    * 0.20 - fraction of ``user-service``, ``order-service`` and
      ``payment-service`` that are ``"healthy"``
    * 0.15 - ``cache-service`` was NOT restarted, scaled or rolled back
    * 0.20 - step efficiency as computed by ``_efficiency_score``

    History entries are read via ``entry["command"]`` and, when the command
    matches, ``entry["target"]``. The step count is
    ``snapshot["step_count"]``, falling back to ``len(history)``.

    Returns ``{"reward": float, "metadata": {"evaluations": [...]}}`` with
    the reward rounded to four decimals.
    """
    database = "postgres-primary"
    cluster = snapshot.get("services", {})
    step_total = snapshot.get("step_count", len(history))
    evaluations = []

    def performed(commands, target):
        # True as soon as one history entry matches both command and target.
        for action in history:
            if action["command"] in commands and action["target"] == target:
                return True
        return False

    def record(label, score, weight):
        evaluations.append({"name": label, "score": score, "weight": weight})

    # 1. The dependency graph was traced (target is irrelevant here).
    traced = "check_dependencies" in (action["command"] for action in history)
    record("Traced dependency graph", 1.0 if traced else 0.0, 0.1)

    # 2. The root-cause database was investigated.
    looked_at_db = performed(("check_logs", "get_metrics"), database)
    record("Investigated postgres-primary", 1.0 if looked_at_db else 0.0, 0.15)

    # 3. A fix was attempted on the database and it ended up healthy.
    fix_attempted = performed(("restart_service", "update_config"), database)
    db_state = cluster.get(database, {})
    db_is_healthy = db_state.get("status") == "healthy"
    if fix_attempted and db_is_healthy:
        fix_score = 1.0
    else:
        fix_score = 0.0
    record("Fixed postgres-primary", fix_score, 0.2)

    # 4. Partial credit for each downstream service that recovered.
    dependants = ["user-service", "order-service", "payment-service"]
    healthy_count = 0
    for service_name in dependants:
        if cluster.get(service_name, {}).get("status") == "healthy":
            healthy_count += 1
    record("Downstream services recovered", round(healthy_count / len(dependants), 4), 0.2)

    # 5. The red herring (cache-service) was left alone.
    chased_noise = performed(("restart_service", "scale_service", "rollback_service"), "cache-service")
    record("Ignored red herring alerts", 0.0 if chased_noise else 1.0, 0.15)

    # 6. Efficiency — fewer steps score higher (see _efficiency_score).
    efficiency = _efficiency_score("medium", step_total)
    record("Resolution efficiency", round(efficiency, 4), 0.2)

    reward = round(_weighted_score(evaluations), 4)
    return {"reward": reward, "metadata": {"evaluations": evaluations}}
# ═══════════════════════════════════════════════════════════════════
# HARD — Crypto-Mining Attack + Disk Full
# ═══════════════════════════════════════════════════════════════════
def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    """Grade the hard scenario (crypto-mining compromise plus a full disk).

    Eight weighted checks are combined into a single reward:

    * 0.10 - a diagnostic command was run against ``worker-service``
    * 0.10 - a miner process (``xmrig`` / ``kworker/u8:2``) on
      ``worker-service`` has status ``"killed"``
    * 0.10 - ``worker-service`` has a truthy ``was_rolled_back`` flag
    * 0.10 - ``worker-service`` has a truthy ``credentials_rotated`` flag
    * 0.10 - ``log-aggregator`` disk usage is below 80 percent
    * 0.15 - fraction of all services in the snapshot that are ``"healthy"``
    * 0.15 - the submitted root cause mentions one of the expected keywords
    * 0.20 - step efficiency as computed by ``_efficiency_score``

    Malformed snapshot data is scored as a failed check rather than raising:
    a missing/``None`` process list, a ``None`` disk usage, and a diagnosis
    whose ``root_cause`` is absent, ``None`` or not a string all earn zero
    for their check. A diagnosis given as a bare string is treated as the
    root-cause text itself.

    History entries are read via ``entry["command"]`` and, when the command
    matches, ``entry["target"]``. The step count is
    ``snapshot["step_count"]``, falling back to ``len(history)``.

    Returns ``{"reward": float, "metadata": {"evaluations": [...]}}`` with
    the reward rounded to four decimals.
    """
    services = snapshot.get("services", {})
    steps = snapshot.get("step_count", len(history))
    results = []

    # 1. Investigated worker-service (the compromised service).
    investigated = any(
        h["command"] in ("check_logs", "get_metrics", "check_processes", "check_network")
        and h["target"] == "worker-service"
        for h in history
    )
    results.append({"name": "Investigated worker-service", "score": 1.0 if investigated else 0.0, "weight": 0.1})

    # 2. Killed the crypto miner process.
    worker = services.get("worker-service", {})
    # `or []` tolerates an explicit None process list in the snapshot.
    miner_killed = any(
        p.get("name") in ("xmrig", "kworker/u8:2") and p.get("status") == "killed"
        for p in (worker.get("processes") or [])
    )
    results.append({"name": "Killed crypto miner", "score": 1.0 if miner_killed else 0.0, "weight": 0.1})

    # 3. Rolled back the compromised deployment.
    rolled_back = worker.get("was_rolled_back", False)
    results.append({"name": "Rolled back worker-service", "score": 1.0 if rolled_back else 0.0, "weight": 0.1})

    # 4. Rotated credentials.
    creds_rotated = worker.get("credentials_rotated", False)
    results.append({"name": "Rotated credentials", "score": 1.0 if creds_rotated else 0.0, "weight": 0.1})

    # 5. Fixed the disk issue on log-aggregator. A missing or None reading is
    #    treated as "not fixed" instead of failing the comparison.
    log_agg = services.get("log-aggregator", {})
    disk_usage = log_agg.get("disk_usage_percent")
    disk_ok = disk_usage is not None and disk_usage < 80
    results.append({"name": "Cleared log-aggregator disk", "score": 1.0 if disk_ok else 0.0, "weight": 0.1})

    # 6. Cluster health: partial credit equal to the healthy fraction.
    #    max(..., 1) avoids dividing by zero when no services are reported.
    healthy_fraction = sum(1 for s in services.values() if s.get("status") == "healthy") / max(len(services), 1)
    results.append({"name": "Cluster health restored", "score": round(healthy_fraction, 4), "weight": 0.15})

    # 7. Submitted a correct diagnosis (case-insensitive keyword match).
    diagnosis = snapshot.get("diagnosis_submitted")
    if isinstance(diagnosis, dict):
        root_cause = diagnosis.get("root_cause")
    else:
        root_cause = diagnosis
    # Only real text can match; None or other types count as no diagnosis.
    rc = root_cause.lower() if isinstance(root_cause, str) else ""
    correct_diagnosis = any(
        kw in rc
        for kw in ("crypto", "mining", "xmrig", "malicious", "compromised", "unauthorized")
    )
    results.append({"name": "Correct diagnosis submitted", "score": 1.0 if correct_diagnosis else 0.0, "weight": 0.15})

    # 8. Efficiency — fewer steps score higher (see _efficiency_score).
    eff = _efficiency_score("hard", steps)
    results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})

    return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}