File size: 5,791 Bytes
fbb0927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Grader for Task 2 — Cascading Failure (Medium)

Scoring breakdown:
  Correct severity (P1)                    → +0.20
  Correct root cause (user-db)             → +0.35
  Correct remediation (kill-query/restart) → +0.25
  Ordering bonus (no symptom fix first)    → +0.10
  Speed bonus (resolved ≤ 8 steps)         → +0.10
  ─────────────────────────────────────────────────
  Maximum possible score                   →  1.00

Penalties:
  Identified symptom as root cause         →  0.00 (no credit)
  Remediated symptom service first         → -0.10 (ordering penalty)
  Never resolved                           → -0.10
"""
from __future__ import annotations
from server.models import EpisodeState
from server.graders.base_grader import BaseGrader


class CascadeGrader(BaseGrader):
    """Official grader for Task 2 — Cascading Failure.

    Scores an episode on five components: severity classification,
    root-cause identification, remediation, remediation ordering, and
    resolution speed. The textual explanation of the most recent ``score``
    call is available through ``get_breakdown``.

    NOTE(review): ``_get_first_value``, ``_get_actions_of_type``,
    ``_episode_resolved``, ``_steps_used`` and ``_clamp`` are not defined in
    this class; they are presumably inherited from ``BaseGrader`` — confirm
    their exact contracts there.
    """

    CORRECT_SEVERITY = "P1"
    CORRECT_ROOT_CAUSE = "user-db"
    CORRECT_REMEDIATION_PREFIXES = {"kill-query", "restart"}
    CORRECT_REMEDIATION_SERVICE = "user-db"
    SYMPTOM_SERVICES = {"api-gateway", "auth-service"}  # wrong answers
    MAX_STEPS = 12  # not referenced inside this class
    SPEED_THRESHOLD = 8  # resolving in at most this many steps earns the speed bonus

    def score(self, state: EpisodeState) -> float:
        """Score the completed Task 2 episode.

        Agents that treat symptoms instead of the root cause are penalized.
        Malformed action values (missing, ``None``, non-string, or not of the
        form ``"<prefix>:<service>"``) are ignored rather than raising, so a
        badly formed agent action cannot abort grading.

        Args:
            state: The finished episode to grade.

        Returns:
            The accumulated score after being passed through ``self._clamp``.

        Side effects:
            Stores the per-component explanation on ``self._breakdown``.
        """
        total = 0.0
        breakdown = {}

        # ── 1. Severity classification ──
        severity_value = self._get_first_value(state, "classify_severity")
        if severity_value == self.CORRECT_SEVERITY:
            total += 0.20
            breakdown["severity"] = "+0.20 (correct: P1)"
        elif severity_value == "P2":
            total += 0.08
            breakdown["severity"] = "+0.08 (partial: P2 given, P1 expected)"
        elif severity_value is None:
            breakdown["severity"] = "+0.00 (never classified)"
        else:
            breakdown["severity"] = f"+0.00 (wrong: {severity_value})"

        # ── 2. Root cause identification ──
        root_cause_value = self._get_first_value(state, "identify_root_cause")
        # Only strings can be a symptom name or contain "db"; guarding here
        # keeps unhashable / non-iterable values from raising TypeError and
        # routes them to the generic "wrong" branch instead.
        root_cause_is_text = isinstance(root_cause_value, str)
        if root_cause_value == self.CORRECT_ROOT_CAUSE:
            total += 0.35
            breakdown["root_cause"] = "+0.35 (correct: user-db)"
        elif root_cause_is_text and root_cause_value in self.SYMPTOM_SERVICES:
            # Identified a symptom, not root cause — no credit
            breakdown["root_cause"] = f"+0.00 (wrong: {root_cause_value} is a symptom, not root cause)"
        elif root_cause_is_text and "db" in root_cause_value:
            total += 0.10  # right tier (database), wrong specific service
            breakdown["root_cause"] = f"+0.10 (partial: {root_cause_value}, right tier)"
        elif root_cause_value is None:
            breakdown["root_cause"] = "+0.00 (never identified)"
        else:
            breakdown["root_cause"] = f"+0.00 (wrong: {root_cause_value})"

        # ── 3. Remediation + Ordering ──
        remediation_actions = self._get_actions_of_type(state, "remediate")
        remediation_scored = False
        symptom_remediated_first = False

        for action in remediation_actions:
            value = action.get("value")
            # dict.get's default only applies when the key is absent; an
            # explicit None (or any non-string) must be skipped explicitly.
            if not isinstance(value, str):
                continue
            # Expected shape is exactly "<prefix>:<service>".
            if value.count(":") != 1:
                continue
            prefix, service = value.split(":")

            # A symptom service touched before the correct remediation was
            # credited (or when it never is) counts as treating the symptom first.
            if service in self.SYMPTOM_SERVICES and not remediation_scored:
                symptom_remediated_first = True

            # Credit the correct remediation once, on its first occurrence.
            if (
                prefix in self.CORRECT_REMEDIATION_PREFIXES
                and service == self.CORRECT_REMEDIATION_SERVICE
                and not remediation_scored
            ):
                total += 0.25
                breakdown["remediation"] = f"+0.25 (correct: {value})"
                remediation_scored = True

        if not remediation_scored:
            breakdown["remediation"] = "+0.00 (no correct remediation)"

        # ── 4. Ordering bonus ──
        if not symptom_remediated_first and remediation_scored:
            total += 0.10
            breakdown["ordering"] = "+0.10 (correctly targeted root cause, not symptoms)"
        elif symptom_remediated_first:
            total -= 0.10
            breakdown["ordering"] = "-0.10 (remediated symptom service before root cause)"

        # ── 5. Speed bonus ──
        if self._episode_resolved(state):
            steps = self._steps_used(state)
            if steps <= self.SPEED_THRESHOLD:
                total += 0.10
                breakdown["speed"] = f"+0.10 (resolved in {steps} steps)"
            else:
                breakdown["speed"] = f"+0.00 (resolved but used {steps} steps)"
        else:
            total -= 0.10
            breakdown["resolution"] = "-0.10 (never resolved)"

        self._breakdown = breakdown
        return self._clamp(total)

    def get_breakdown(self) -> dict:
        """Return the breakdown dict stored by the last ``score`` call.

        Returns an empty dict when ``score`` has not been called yet on this
        instance.
        """
        return getattr(self, "_breakdown", {})