File size: 16,508 Bytes
c452421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""End-of-episode graders for each task.

Each grader evaluates the full trajectory and produces a score (clamped to
the range [0.01, 0.99]) with a detailed breakdown.  Grading is deterministic
given the same trajectory.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from src.models import (
    EnvironmentState,
    GraderResult,
    IncidentSeverity,
)
from src.scenarios import Scenario


def _keyword_match(text: str, keywords: List[str]) -> bool:
    text_lower = text.lower()
    return any(kw.lower() in text_lower for kw in keywords)


def _severity_distance(a: Optional[IncidentSeverity], b: IncidentSeverity) -> int:
    if a is None:
        return 4  # worst case
    order = list(IncidentSeverity)
    return abs(order.index(a) - order.index(b))


# --------------------------------------------------------------------------
# Task 1 – Easy: Severity Classification
# --------------------------------------------------------------------------

def grade_severity_classification(state: EnvironmentState, scenario: Scenario) -> GraderResult:
    """Grade the easy severity-classification task.

    Scoring buckets (sum clamped to [0.01, 0.99]):
      * severity_accuracy     (0.50) - exact match full credit, off-by-one half
      * investigation_quality (0.25) - looked at scenario-relevant services
      * efficiency            (0.25) - fewer steps is better; no steps = no credit
    """
    breakdown: Dict[str, float] = {}

    # Severity accuracy (0.50): graded by distance from the correct level.
    distance = _severity_distance(state.severity_classified, scenario.correct_severity)
    breakdown["severity_accuracy"] = {0: 0.50, 1: 0.25}.get(distance, 0.0)

    # Investigation quality (0.25): any investigation earns partial credit;
    # touching at least one scenario-relevant service earns full credit.
    investigated = set(state.investigated_services)
    if not investigated:
        breakdown["investigation_quality"] = 0.0
    elif investigated & set(scenario.relevant_services):
        breakdown["investigation_quality"] = 0.25
    else:
        breakdown["investigation_quality"] = 0.10

    # Efficiency (0.25): tiered by step count against the scenario budget.
    used = state.total_steps_taken
    if used == 0:
        breakdown["efficiency"] = 0.0
    elif used <= 3:
        breakdown["efficiency"] = 0.25
    elif used <= 5:
        breakdown["efficiency"] = 0.20
    elif used <= scenario.max_steps // 2:
        breakdown["efficiency"] = 0.10
    else:
        breakdown["efficiency"] = 0.0

    total = sum(breakdown.values())
    return GraderResult(
        task_id=scenario.task_id,
        score=round(max(0.01, min(0.99, total)), 4),
        breakdown={key: round(val, 4) for key, val in breakdown.items()},
        feedback=_severity_feedback(breakdown),
    )


def _severity_feedback(bd: Dict[str, float]) -> str:
    parts = []
    sa = bd.get("severity_accuracy", 0)
    iq = bd.get("investigation_quality", 0)
    ef = bd.get("efficiency", 0)
    if sa >= 0.50:
        parts.append("βœ“ Severity classification correct. The connection pool saturation and partial error rate (~12%) indicate a degraded-but-not-down P2 incident.")
    elif sa > 0:
        parts.append("~ Severity off by one level. Review the alert signals: 98% connection pool utilisation and 12% error rate indicate degraded service (P2), not a full outage (P1) or minor issue (P3).")
    else:
        parts.append("βœ— Severity classification missing or wrong. Examine alert severity levels and error rates before classifying. A P2 is correct: significant service degradation but not a full outage.")
    if iq >= 0.25:
        parts.append("βœ“ Good investigation β€” examined relevant services before classifying. Always investigate before classifying.")
    elif iq > 0:
        parts.append("~ Investigated services, but not the most relevant ones. postgres-primary (connection pool alert) and user-service (latency alert) are the critical paths.")
    else:
        parts.append("βœ— No investigation performed before classification. Investigate postgres-primary and user-service first to confirm the root cause.")
    if ef >= 0.25:
        parts.append("βœ“ Efficient resolution β€” completed in 3 steps or fewer.")
    elif ef > 0:
        parts.append("~ Resolved but used more steps than optimal. Target: investigate 2 services β†’ classify (3 steps total).")
    else:
        parts.append("βœ— Too many steps or no actions taken. Optimal path: INVESTIGATE postgres-primary β†’ INVESTIGATE user-service β†’ CLASSIFY P2.")
    return " ".join(parts)


# --------------------------------------------------------------------------
# Task 2 – Medium: Root Cause Analysis
# --------------------------------------------------------------------------

def grade_root_cause_analysis(state: EnvironmentState, scenario: Scenario) -> GraderResult:
    """Grade the medium root-cause-analysis task.

    Scoring buckets (sum clamped to [0.01, 0.99]):
      * severity_accuracy               (0.15) - exact match full, off-by-one half
      * investigated_root_cause_service (0.15) - all-or-nothing
      * diagnosis_accuracy              (0.30) - keyword match on the diagnosis
      * remediation_quality             (0.20) - at least one valid action applied
      * efficiency                      (0.20) - fraction of the step budget used
    """
    breakdown: Dict[str, float] = {}

    # 1. Correct severity (0.15).
    sev_dist = _severity_distance(state.severity_classified, scenario.correct_severity)
    breakdown["severity_accuracy"] = 0.15 if sev_dist == 0 else (0.08 if sev_dist == 1 else 0.0)

    # 2. Investigated root-cause service (0.15).
    if scenario.correct_root_cause_service in state.investigated_services:
        breakdown["investigated_root_cause_service"] = 0.15
    else:
        breakdown["investigated_root_cause_service"] = 0.0

    # 3. Correct diagnosis (0.30): a diagnosis that misses every expected
    # keyword still earns a token 0.05 for attempting one.
    if state.diagnosis and _keyword_match(state.diagnosis, scenario.correct_root_cause_keywords):
        breakdown["diagnosis_accuracy"] = 0.30
    elif state.diagnosis:
        breakdown["diagnosis_accuracy"] = 0.05
    else:
        breakdown["diagnosis_accuracy"] = 0.0

    # 4. Correct remediation (0.20): remediations are compared as
    # "action:service" keys against the scenario's valid actions.
    valid_keys = {
        f"{va['action']}:{va['service']}" for va in scenario.valid_remediation_actions
    }
    applied_valid = len(set(state.remediations_applied) & valid_keys)
    if applied_valid > 0:
        breakdown["remediation_quality"] = 0.20
    elif len(state.remediations_applied) > 0:
        breakdown["remediation_quality"] = 0.05
    else:
        breakdown["remediation_quality"] = 0.0

    # 5. Efficiency (0.20): tiered by the fraction of the step budget
    # consumed; taking no steps earns nothing.
    used = state.total_steps_taken
    max_s = scenario.max_steps
    if used == 0:
        breakdown["efficiency"] = 0.0
    else:
        # Guard a zero step budget (previously raised ZeroDivisionError);
        # treat it as having fully exhausted the budget.
        ratio = used / max_s if max_s > 0 else 1.0
        if ratio <= 0.4:
            breakdown["efficiency"] = 0.20
        elif ratio <= 0.6:
            breakdown["efficiency"] = 0.15
        elif ratio <= 0.8:
            breakdown["efficiency"] = 0.08
        else:
            breakdown["efficiency"] = 0.0

    score = sum(breakdown.values())
    return GraderResult(
        task_id=scenario.task_id,
        score=round(max(0.01, min(0.99, score)), 4),
        breakdown={k: round(v, 4) for k, v in breakdown.items()},
        feedback=_rca_feedback(breakdown),
    )


def _rca_feedback(bd: Dict[str, float]) -> str:
    parts = []
    da = bd.get("diagnosis_accuracy", 0)
    rq = bd.get("remediation_quality", 0)
    ir = bd.get("investigated_root_cause_service", 0)
    sa = bd.get("severity_accuracy", 0)
    ef = bd.get("efficiency", 0)
    if da >= 0.30:
        parts.append("βœ“ Root cause correctly identified: Redis session store hit maxmemory, causing active payment session tokens to be evicted before payment completion.")
    elif da > 0:
        parts.append("~ Diagnosis attempted but inaccurate. The root cause is Redis memory exhaustion (maxmemory reached β†’ allkeys-lru eviction β†’ payment tokens evicted mid-transaction). Not a payment-gateway issue.")
    else:
        parts.append("βœ— No diagnosis provided. Root cause: redis-session hit its 4 GB maxmemory limit. Eviction policy evicted active payment session tokens faster than transactions completed.")
    if ir >= 0.15:
        parts.append("βœ“ Correctly investigated the root-cause service (redis-session).")
    else:
        parts.append("βœ— Did not investigate redis-session. The eviction spike alert (ALT-012) is the key signal β€” always investigate the service generating eviction alerts.")
    if rq >= 0.20:
        parts.append("βœ“ Correct remediation: scaling redis-session memory clears the eviction pressure.")
    elif rq > 0:
        parts.append("~ Remediation attempted but suboptimal. Valid fixes: scale redis-session capacity, config_change to increase maxmemory, or restart to temporarily clear the eviction queue.")
    else:
        parts.append("βœ— No remediation applied. Fix: REMEDIATE redis-session with action=scale or action=config_change to increase maxmemory.")
    if sa < 0.15:
        parts.append("βœ— Severity wrong. Payment processing at 45% success rate is a P1 β€” revenue-impacting, customer-facing, requires immediate escalation.")
    if ef < 0.15:
        parts.append("~ Efficiency: optimal path is 5 steps (investigate Γ—2 β†’ classify β†’ diagnose β†’ remediate).")
    return " ".join(parts)


# --------------------------------------------------------------------------
# Task 3 – Hard: Full Incident Management
# --------------------------------------------------------------------------

def grade_full_incident_management(state: EnvironmentState, scenario: Scenario) -> GraderResult:
    """Grade the hard full-incident-management task.

    Eight scoring buckets (weights sum to 1.00; final score is clamped to
    [0.01, 0.99]):
      * severity_accuracy          (0.12)
      * diagnosis_accuracy         (0.20) - capped at 0.06 without investigation
      * remediation_quality        (0.18)
      * escalation_quality         (0.15)
      * communication              (0.10)
      * investigation_thoroughness (0.12)
      * investigation_precision    (0.03)
      * efficiency                 (0.10)
    """
    breakdown: Dict[str, float] = {}

    # Pre-compute whether agent actually investigated root-cause service;
    # this gates how much credit a correct diagnosis can earn below.
    investigated_root_cause = scenario.correct_root_cause_service in state.investigated_services

    # 1. Correct severity (0.12): exact match full credit, off-by-one half.
    sev_dist = _severity_distance(state.severity_classified, scenario.correct_severity)
    breakdown["severity_accuracy"] = 0.12 if sev_dist == 0 else (0.06 if sev_dist == 1 else 0.0)

    # 2. Root cause identified (0.20)
    # Diagnosis without investigation gets at most 0.06 (prompted guess, not evidence-based)
    if state.diagnosis and _keyword_match(state.diagnosis, scenario.correct_root_cause_keywords):
        if investigated_root_cause:
            breakdown["diagnosis_accuracy"] = 0.20
        else:
            breakdown["diagnosis_accuracy"] = 0.06   # guessed correctly but didn't look
    elif state.diagnosis:
        breakdown["diagnosis_accuracy"] = 0.03  # wrong diagnosis: token credit only
    else:
        breakdown["diagnosis_accuracy"] = 0.0

    # 3. Remediation (0.18): remediations are compared as "action:service"
    # keys; two or more valid ones earn full credit, one earns partial.
    valid_keys = {
        f"{va['action']}:{va['service']}" for va in scenario.valid_remediation_actions
    }
    applied_valid = set(state.remediations_applied) & valid_keys
    if len(applied_valid) >= 2:
        breakdown["remediation_quality"] = 0.18
    elif len(applied_valid) == 1:
        breakdown["remediation_quality"] = 0.12
    elif len(state.remediations_applied) > 0:
        breakdown["remediation_quality"] = 0.03
    else:
        breakdown["remediation_quality"] = 0.0

    # 4. Escalation (0.15) β€” raised from 0.08/0.14 cap.
    # Team names are compared case-insensitively.
    expected_lower = {t.lower() for t in scenario.expected_escalation_teams}
    escalated_lower = {e.lower() for e in state.escalations_made}
    matched = len(escalated_lower & expected_lower)
    if matched >= 2:
        breakdown["escalation_quality"] = 0.15
    elif matched == 1:
        breakdown["escalation_quality"] = 0.09
    else:
        breakdown["escalation_quality"] = 0.0

    # 5. Communication (0.10): counts messages sent, not their content.
    if len(state.communications_sent) >= 2:
        breakdown["communication"] = 0.10
    elif len(state.communications_sent) == 1:
        breakdown["communication"] = 0.06
    else:
        breakdown["communication"] = 0.0

    # 6. Investigation thoroughness (0.12): linear in the fraction of
    # scenario-relevant services that were investigated.
    relevant_investigated = len(
        set(state.investigated_services) & set(scenario.relevant_services)
    )
    total_relevant = len(scenario.relevant_services)
    if total_relevant > 0:
        inv_ratio = relevant_investigated / total_relevant
    else:
        inv_ratio = 0.0
    breakdown["investigation_thoroughness"] = round(0.12 * inv_ratio, 4)

    # 7. Investigation precision (0.03) β€” penalise unfocused investigation
    # (time spent on services that are irrelevant to the scenario).
    irrelevant_investigated = len(
        set(state.investigated_services) - set(scenario.relevant_services)
    )
    if irrelevant_investigated == 0:
        breakdown["investigation_precision"] = 0.03
    elif irrelevant_investigated <= 1:
        breakdown["investigation_precision"] = 0.01
    else:
        breakdown["investigation_precision"] = 0.0

    # 8. Efficiency (0.10): tiered by fraction of the step budget used;
    # 0 steps = no credit.
    max_s = scenario.max_steps
    used = state.total_steps_taken
    if used == 0:
        breakdown["efficiency"] = 0.0
    elif used <= int(max_s * 0.5):
        breakdown["efficiency"] = 0.10
    elif used <= int(max_s * 0.7):
        breakdown["efficiency"] = 0.07
    elif used <= int(max_s * 0.85):
        breakdown["efficiency"] = 0.04
    else:
        breakdown["efficiency"] = 0.0

    score = sum(breakdown.values())
    return GraderResult(
        task_id=scenario.task_id,
        score=round(max(0.01, min(0.99, score)), 4),
        breakdown={k: round(v, 4) for k, v in breakdown.items()},
        feedback=_full_feedback(breakdown),
    )


def _full_feedback(bd: Dict[str, float]) -> str:
    parts = []
    da = bd.get("diagnosis_accuracy", 0)
    rq = bd.get("remediation_quality", 0)
    eq = bd.get("escalation_quality", 0)
    comm = bd.get("communication", 0)
    it = bd.get("investigation_thoroughness", 0)
    ip = bd.get("investigation_precision", 0)
    sa = bd.get("severity_accuracy", 0)
    ef = bd.get("efficiency", 0)
    if da >= 0.20:
        parts.append("βœ“ Root cause correctly identified: auth-service v3.1.0 introduced an unbounded in-memory token cache causing OOMKill and cascading failures across all auth-dependent services.")
    elif da >= 0.06:
        parts.append("~ Root cause guessed but not confirmed via investigation. Investigate auth-service first β€” the deployment timestamp (v3.1.0 at 13:47) and memory climb logs are the definitive evidence.")
    else:
        parts.append("βœ— Root cause not identified. Key signals: auth-service memory 45%β†’97% after v3.1.0 deploy at 13:47, changelog note 'Refactored token cache to in-memory store', 3 OOMKills in 5 min.")
    if rq >= 0.18:
        parts.append("βœ“ Comprehensive remediation: rolled back auth-service AND scaled order-service to drain the 15k+ message backlog.")
    elif rq >= 0.12:
        parts.append("~ Partial remediation. Also remediate order-service (scale) to drain the queue that built up during the auth outage.")
    elif rq > 0:
        parts.append("~ Some remediation applied but not optimal. Correct actions: rollback auth-service to v3.0.9 (primary fix) + scale order-service (queue drain).")
    else:
        parts.append("βœ— No remediation applied. Critical: REMEDIATE auth-service action=rollback (roll back v3.1.0). Then REMEDIATE order-service action=scale to clear the queue backlog.")
    if eq >= 0.15:
        parts.append("βœ“ Correct teams escalated (platform-team + auth-team).")
    elif eq > 0:
        parts.append("~ Escalation partial. Escalate to both platform-team (infrastructure impact) and auth-team (owns the service with the bug).")
    else:
        parts.append("βœ— No escalation. This is a P1 cascading outage β€” escalate to platform-team (urgent) and auth-team (owns the buggy deployment).")
    if comm > 0:
        parts.append("βœ“ Status communication sent.")
    else:
        parts.append("βœ— No status communication. Send a COMMUNICATE action to status_page with root cause, impact, and ETA.")
    if it < 0.08:
        parts.append("~ Investigation incomplete. Key services to investigate: auth-service (root cause), api-gateway (circuit breaker), redis-auth-cache (cache bypass evidence), order-service (queue depth).")
    if ip < 0.03:
        parts.append("~ Investigation spread too wide. cdn-static and postgres-primary are red herrings β€” normal metrics, no alerts. Focus on auth-dependent services.")
    if sa < 0.12:
        parts.append("βœ— Severity wrong. Multi-service cascading outage affecting auth, API gateway, orders, users = P1.")
    if ef < 0.04:
        parts.append("~ Efficiency: optimal path completes in 11 steps. Avoid re-investigating services or applying wrong remediations.")
    return " ".join(parts)


# --------------------------------------------------------------------------
# Dispatcher
# --------------------------------------------------------------------------

# Registry mapping task_id -> grader function; `grade` dispatches through it.
_GRADERS = {
    "severity_classification": grade_severity_classification,
    "root_cause_analysis": grade_root_cause_analysis,
    "full_incident_management": grade_full_incident_management,
}


def grade(task_id: str, state: EnvironmentState, scenario: Scenario) -> GraderResult:
    """Run the grader registered for *task_id* and return its result.

    Raises:
        ValueError: if no grader is registered for *task_id*.
    """
    if task_id not in _GRADERS:
        raise ValueError(f"No grader for task_id '{task_id}'")
    return _GRADERS[task_id](state, scenario)