File size: 8,980 Bytes
99ef18d
 
 
 
 
 
edf8eb5
 
99ef18d
 
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
edf8eb5
 
 
 
 
 
 
99ef18d
 
 
edf8eb5
 
99ef18d
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
 
edf8eb5
 
 
 
99ef18d
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
edf8eb5
 
 
 
 
99ef18d
 
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
edf8eb5
99ef18d
 
 
edf8eb5
99ef18d
 
 
 
 
 
 
 
 
edf8eb5
 
 
99ef18d
edf8eb5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Deterministic grading for SRE incident response tasks."""

from __future__ import annotations

from typing import Any, Dict, List

OPTIMAL_STEPS = {"easy": 5, "medium": 10, "hard": 15}


def grade_task(task_id: str, cluster_snapshot: Dict[str, Any], action_history: List[Dict]) -> Dict[str, Any]:
    """Grade a completed episode. Returns {"reward": float, "metadata": dict}."""
    grader_map = {
        "easy": _grade_easy,
        "medium": _grade_medium,
        "hard": _grade_hard,
    }
    grader = grader_map.get(task_id)
    if not grader:
        return {"reward": 0.5, "metadata": {"error": f"Unknown task: {task_id}"}}
    return grader(cluster_snapshot, action_history)


def _efficiency_score(task_id: str, steps_taken: int) -> float:
    """Score based on how efficiently the agent solved the task.
    Returns value in (0, 1) β€” mathematically cannot be 0.0 or 1.0."""
    optimal = OPTIMAL_STEPS.get(task_id, 10)
    return optimal / (steps_taken + optimal)


def _weighted_score(results: List[Dict]) -> float:
    total_weight = sum(r["weight"] for r in results)
    if total_weight == 0:
        return 0.5
    return sum(r["score"] * r["weight"] for r in results) / total_weight


# ═══════════════════════════════════════════════════════════════════
# EASY β€” Memory Leak in API Gateway
# ═══════════════════════════════════════════════════════════════════

def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    services = snapshot.get("services", {})
    steps = snapshot.get("step_count", len(history))
    results = []

    # 1. Did agent investigate the root cause service?
    investigated = any(
        h["command"] in ("check_logs", "get_metrics", "check_processes")
        and h["target"] == "api-gateway"
        for h in history
    )
    results.append({"name": "Investigated api-gateway", "score": 1.0 if investigated else 0.0, "weight": 0.2})

    # 2. Was api-gateway restarted?
    restarted = any(
        h["command"] == "restart_service" and h["target"] == "api-gateway"
        for h in history
    )
    results.append({"name": "Restarted api-gateway", "score": 1.0 if restarted else 0.0, "weight": 0.3})

    # 3. Is api-gateway healthy now?
    gw = services.get("api-gateway", {})
    healthy = gw.get("status") == "healthy"
    results.append({"name": "api-gateway is healthy", "score": 1.0 if healthy else 0.0, "weight": 0.2})

    # 4. Didn't restart healthy services unnecessarily
    unnecessary_restarts = sum(
        1 for h in history
        if h["command"] == "restart_service" and h["target"] not in ("api-gateway", "frontend")
    )
    no_waste = unnecessary_restarts == 0
    results.append({"name": "No unnecessary restarts", "score": 1.0 if no_waste else 0.0, "weight": 0.1})

    # 5. Efficiency β€” always in (0, 1), prevents total from hitting 0.0 or 1.0
    eff = _efficiency_score("easy", steps)
    results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})

    return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}


# ═══════════════════════════════════════════════════════════════════
# MEDIUM β€” Cascading Database Failure
# ═══════════════════════════════════════════════════════════════════

def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    services = snapshot.get("services", {})
    steps = snapshot.get("step_count", len(history))
    results = []

    # 1. Traced dependencies
    traced = any(
        h["command"] == "check_dependencies"
        for h in history
    )
    results.append({"name": "Traced dependency graph", "score": 1.0 if traced else 0.0, "weight": 0.1})

    # 2. Identified postgres as root cause (investigated it)
    db_investigated = any(
        h["command"] in ("check_logs", "get_metrics")
        and h["target"] == "postgres-primary"
        for h in history
    )
    results.append({"name": "Investigated postgres-primary", "score": 1.0 if db_investigated else 0.0, "weight": 0.15})

    # 3. Fixed postgres (restarted or updated config)
    db_fixed = any(
        h["command"] in ("restart_service", "update_config")
        and h["target"] == "postgres-primary"
        for h in history
    )
    db_svc = services.get("postgres-primary", {})
    db_healthy = db_svc.get("status") == "healthy"
    results.append({"name": "Fixed postgres-primary", "score": 1.0 if (db_fixed and db_healthy) else 0.0, "weight": 0.2})

    # 4. Downstream services recovered
    downstream_names = ["user-service", "order-service", "payment-service"]
    recovered = sum(1 for n in downstream_names if services.get(n, {}).get("status") == "healthy")
    downstream_score = recovered / len(downstream_names)
    results.append({"name": "Downstream services recovered", "score": round(downstream_score, 4), "weight": 0.2})

    # 5. Didn't act on red herring alerts (didn't restart cache-service)
    acted_on_noise = any(
        h["command"] in ("restart_service", "scale_service", "rollback_service")
        and h["target"] == "cache-service"
        for h in history
    )
    results.append({"name": "Ignored red herring alerts", "score": 0.0 if acted_on_noise else 1.0, "weight": 0.15})

    # 6. Efficiency
    eff = _efficiency_score("medium", steps)
    results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})

    return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}


# ═══════════════════════════════════════════════════════════════════
# HARD β€” Crypto-Mining Attack + Disk Full
# ═══════════════════════════════════════════════════════════════════

def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
    services = snapshot.get("services", {})
    steps = snapshot.get("step_count", len(history))
    results = []

    # 1. Investigated worker-service (the compromised service)
    investigated = any(
        h["command"] in ("check_logs", "get_metrics", "check_processes", "check_network")
        and h["target"] == "worker-service"
        for h in history
    )
    results.append({"name": "Investigated worker-service", "score": 1.0 if investigated else 0.0, "weight": 0.1})

    # 2. Killed the crypto miner process
    worker = services.get("worker-service", {})
    miner_killed = any(
        p.get("name") in ("xmrig", "kworker/u8:2") and p.get("status") == "killed"
        for p in worker.get("processes", [])
    )
    results.append({"name": "Killed crypto miner", "score": 1.0 if miner_killed else 0.0, "weight": 0.1})

    # 3. Rolled back compromised deployment
    rolled_back = worker.get("was_rolled_back", False)
    results.append({"name": "Rolled back worker-service", "score": 1.0 if rolled_back else 0.0, "weight": 0.1})

    # 4. Rotated credentials
    creds_rotated = worker.get("credentials_rotated", False)
    results.append({"name": "Rotated credentials", "score": 1.0 if creds_rotated else 0.0, "weight": 0.1})

    # 5. Fixed disk issue on log-aggregator
    log_agg = services.get("log-aggregator", {})
    disk_ok = log_agg.get("disk_usage_percent", 100) < 80
    results.append({"name": "Cleared log-aggregator disk", "score": 1.0 if disk_ok else 0.0, "weight": 0.1})

    # 6. All services healthy
    mostly_healthy = sum(1 for s in services.values() if s.get("status") == "healthy") / max(len(services), 1)
    results.append({"name": "Cluster health restored", "score": round(mostly_healthy, 4), "weight": 0.15})

    # 7. Submitted correct diagnosis
    diagnosis = snapshot.get("diagnosis_submitted")
    correct_diagnosis = False
    if diagnosis:
        rc = diagnosis.get("root_cause", "").lower()
        correct_diagnosis = any(kw in rc for kw in ["crypto", "mining", "xmrig", "malicious", "compromised", "unauthorized"])
    results.append({"name": "Correct diagnosis submitted", "score": 1.0 if correct_diagnosis else 0.0, "weight": 0.15})

    # 8. Efficiency
    eff = _efficiency_score("hard", steps)
    results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})

    return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}