logtriage-env / server /graders /cascade_grader.py
OGrohit's picture
Day 4: grader system complete
fbb0927
"""
Grader for Task 2 β€” Cascading Failure (Medium)
Scoring breakdown:
Correct severity (P1) β†’ +0.20
Correct root cause (user-db) β†’ +0.35
Correct remediation (kill-query/restart) β†’ +0.25
Ordering bonus (no symptom fix first) β†’ +0.10
Speed bonus (resolved ≀ 8 steps) β†’ +0.10
─────────────────────────────────────────────────
Maximum possible score β†’ 1.00
Penalties:
Identified symptom as root cause β†’ 0.00 (no credit)
Remediated symptom service first β†’ -0.10 (ordering penalty)
Never resolved β†’ -0.10
"""
from __future__ import annotations
from server.models import EpisodeState
from server.graders.base_grader import BaseGrader
class CascadeGrader(BaseGrader):
"""Official grader for Task 2 β€” Cascading Failure."""
CORRECT_SEVERITY = "P1"
CORRECT_ROOT_CAUSE = "user-db"
CORRECT_REMEDIATION_PREFIXES = {"kill-query", "restart"}
CORRECT_REMEDIATION_SERVICE = "user-db"
SYMPTOM_SERVICES = {"api-gateway", "auth-service"} # wrong answers
MAX_STEPS = 12
SPEED_THRESHOLD = 8
def score(self, state: EpisodeState) -> float:
"""
Score the completed Task 2 episode.
Penalizes agents that treat symptoms instead of root cause.
"""
total = 0.0
breakdown = {}
# ── 1. Severity classification ─────────────────────────────────────────
severity_value = self._get_first_value(state, "classify_severity")
if severity_value == self.CORRECT_SEVERITY:
total += 0.20
breakdown["severity"] = "+0.20 (correct: P1)"
elif severity_value == "P2":
total += 0.08
breakdown["severity"] = "+0.08 (partial: P2 given, P1 expected)"
elif severity_value is None:
breakdown["severity"] = "+0.00 (never classified)"
else:
breakdown["severity"] = f"+0.00 (wrong: {severity_value})"
# ── 2. Root cause identification ───────────────────────────────────────
root_cause_value = self._get_first_value(state, "identify_root_cause")
if root_cause_value == self.CORRECT_ROOT_CAUSE:
total += 0.35
breakdown["root_cause"] = "+0.35 (correct: user-db)"
elif root_cause_value in self.SYMPTOM_SERVICES:
# Identified a symptom, not root cause β€” no credit
breakdown["root_cause"] = f"+0.00 (wrong: {root_cause_value} is a symptom, not root cause)"
elif root_cause_value and "db" in root_cause_value:
total += 0.10 # right tier (database), wrong specific service
breakdown["root_cause"] = f"+0.10 (partial: {root_cause_value}, right tier)"
elif root_cause_value is None:
breakdown["root_cause"] = "+0.00 (never identified)"
else:
breakdown["root_cause"] = f"+0.00 (wrong: {root_cause_value})"
# ── 3. Remediation + Ordering ──────────────────────────────────────────
remediation_actions = self._get_actions_of_type(state, "remediate")
remediation_scored = False
symptom_remediated_first = False
for i, action in enumerate(remediation_actions):
value = action.get("value", "")
parts = value.split(":")
if len(parts) != 2:
continue
prefix, service = parts
# Check if agent remediated a symptom service before root cause
if service in self.SYMPTOM_SERVICES and not remediation_scored:
symptom_remediated_first = True
# Check for correct remediation
if (
prefix in self.CORRECT_REMEDIATION_PREFIXES
and service == self.CORRECT_REMEDIATION_SERVICE
and not remediation_scored
):
total += 0.25
breakdown["remediation"] = f"+0.25 (correct: {value})"
remediation_scored = True
if not remediation_scored:
breakdown["remediation"] = "+0.00 (no correct remediation)"
# ── 4. Ordering bonus ──────────────────────────────────────────────────
if not symptom_remediated_first and remediation_scored:
total += 0.10
breakdown["ordering"] = "+0.10 (correctly targeted root cause, not symptoms)"
elif symptom_remediated_first:
total -= 0.10
breakdown["ordering"] = "-0.10 (remediated symptom service before root cause)"
# ── 5. Speed bonus ─────────────────────────────────────────────────────
if self._episode_resolved(state):
if self._steps_used(state) <= self.SPEED_THRESHOLD:
total += 0.10
breakdown["speed"] = f"+0.10 (resolved in {self._steps_used(state)} steps)"
else:
breakdown["speed"] = f"+0.00 (resolved but used {self._steps_used(state)} steps)"
else:
total -= 0.10
breakdown["resolution"] = "-0.10 (never resolved)"
self._breakdown = breakdown
return self._clamp(total)
def get_breakdown(self) -> dict:
return getattr(self, "_breakdown", {})