aamrinder commited on
Commit
edf8eb5
Β·
verified Β·
1 Parent(s): f3ea120

Upload server/grader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/grader.py +38 -17
server/grader.py CHANGED
@@ -4,6 +4,8 @@ from __future__ import annotations
4
 
5
  from typing import Any, Dict, List
6
 
 
 
7
 
8
  def grade_task(task_id: str, cluster_snapshot: Dict[str, Any], action_history: List[Dict]) -> Dict[str, Any]:
9
  """Grade a completed episode. Returns {"reward": float, "metadata": dict}."""
@@ -14,17 +16,22 @@ def grade_task(task_id: str, cluster_snapshot: Dict[str, Any], action_history: L
14
  }
15
  grader = grader_map.get(task_id)
16
  if not grader:
17
- return {"reward": 0.0, "metadata": {"error": f"Unknown task: {task_id}"}}
18
  return grader(cluster_snapshot, action_history)
19
 
20
 
 
 
 
 
 
 
 
21
  def _weighted_score(results: List[Dict]) -> float:
22
  total_weight = sum(r["weight"] for r in results)
23
  if total_weight == 0:
24
- return 0.001
25
- raw = sum(r["score"] * r["weight"] for r in results) / total_weight
26
- # Clamp to strict (0, 1) β€” validator rejects exactly 0.0 or 1.0
27
- return min(max(raw, 0.001), 0.999)
28
 
29
 
30
  # ═══════════════════════════════════════════════════════════════════
@@ -33,6 +40,7 @@ def _weighted_score(results: List[Dict]) -> float:
33
 
34
  def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
35
  services = snapshot.get("services", {})
 
36
  results = []
37
 
38
  # 1. Did agent investigate the root cause service?
@@ -48,12 +56,12 @@ def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
48
  h["command"] == "restart_service" and h["target"] == "api-gateway"
49
  for h in history
50
  )
51
- results.append({"name": "Restarted api-gateway", "score": 1.0 if restarted else 0.0, "weight": 0.4})
52
 
53
  # 3. Is api-gateway healthy now?
54
  gw = services.get("api-gateway", {})
55
  healthy = gw.get("status") == "healthy"
56
- results.append({"name": "api-gateway is healthy", "score": 1.0 if healthy else 0.0, "weight": 0.3})
57
 
58
  # 4. Didn't restart healthy services unnecessarily
59
  unnecessary_restarts = sum(
@@ -63,6 +71,10 @@ def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
63
  no_waste = unnecessary_restarts == 0
64
  results.append({"name": "No unnecessary restarts", "score": 1.0 if no_waste else 0.0, "weight": 0.1})
65
 
 
 
 
 
66
  return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}
67
 
68
 
@@ -72,6 +84,7 @@ def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
72
 
73
  def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
74
  services = snapshot.get("services", {})
 
75
  results = []
76
 
77
  # 1. Traced dependencies
@@ -79,7 +92,7 @@ def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
79
  h["command"] == "check_dependencies"
80
  for h in history
81
  )
82
- results.append({"name": "Traced dependency graph", "score": 1.0 if traced else 0.0, "weight": 0.15})
83
 
84
  # 2. Identified postgres as root cause (investigated it)
85
  db_investigated = any(
@@ -87,7 +100,7 @@ def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
87
  and h["target"] == "postgres-primary"
88
  for h in history
89
  )
90
- results.append({"name": "Investigated postgres-primary", "score": 1.0 if db_investigated else 0.0, "weight": 0.2})
91
 
92
  # 3. Fixed postgres (restarted or updated config)
93
  db_fixed = any(
@@ -97,13 +110,13 @@ def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
97
  )
98
  db_svc = services.get("postgres-primary", {})
99
  db_healthy = db_svc.get("status") == "healthy"
100
- results.append({"name": "Fixed postgres-primary", "score": 1.0 if (db_fixed and db_healthy) else 0.0, "weight": 0.25})
101
 
102
  # 4. Downstream services recovered
103
  downstream_names = ["user-service", "order-service", "payment-service"]
104
  recovered = sum(1 for n in downstream_names if services.get(n, {}).get("status") == "healthy")
105
  downstream_score = recovered / len(downstream_names)
106
- results.append({"name": "Downstream services recovered", "score": round(downstream_score, 2), "weight": 0.2})
107
 
108
  # 5. Didn't act on red herring alerts (didn't restart cache-service)
109
  acted_on_noise = any(
@@ -111,7 +124,11 @@ def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
111
  and h["target"] == "cache-service"
112
  for h in history
113
  )
114
- results.append({"name": "Ignored red herring alerts", "score": 0.0 if acted_on_noise else 1.0, "weight": 0.2})
 
 
 
 
115
 
116
  return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}
117
 
@@ -122,6 +139,7 @@ def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
122
 
123
  def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
124
  services = snapshot.get("services", {})
 
125
  results = []
126
 
127
  # 1. Investigated worker-service (the compromised service)
@@ -138,11 +156,11 @@ def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
138
  p.get("name") in ("xmrig", "kworker/u8:2") and p.get("status") == "killed"
139
  for p in worker.get("processes", [])
140
  )
141
- results.append({"name": "Killed crypto miner", "score": 1.0 if miner_killed else 0.0, "weight": 0.15})
142
 
143
  # 3. Rolled back compromised deployment
144
  rolled_back = worker.get("was_rolled_back", False)
145
- results.append({"name": "Rolled back worker-service", "score": 1.0 if rolled_back else 0.0, "weight": 0.15})
146
 
147
  # 4. Rotated credentials
148
  creds_rotated = worker.get("credentials_rotated", False)
@@ -151,11 +169,11 @@ def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
151
  # 5. Fixed disk issue on log-aggregator
152
  log_agg = services.get("log-aggregator", {})
153
  disk_ok = log_agg.get("disk_usage_percent", 100) < 80
154
- results.append({"name": "Cleared log-aggregator disk", "score": 1.0 if disk_ok else 0.0, "weight": 0.15})
155
 
156
  # 6. All services healthy
157
  mostly_healthy = sum(1 for s in services.values() if s.get("status") == "healthy") / max(len(services), 1)
158
- results.append({"name": "Cluster health restored", "score": round(mostly_healthy, 2), "weight": 0.2})
159
 
160
  # 7. Submitted correct diagnosis
161
  diagnosis = snapshot.get("diagnosis_submitted")
@@ -165,5 +183,8 @@ def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
165
  correct_diagnosis = any(kw in rc for kw in ["crypto", "mining", "xmrig", "malicious", "compromised", "unauthorized"])
166
  results.append({"name": "Correct diagnosis submitted", "score": 1.0 if correct_diagnosis else 0.0, "weight": 0.15})
167
 
168
- return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}
 
 
169
 
 
 
4
 
5
  from typing import Any, Dict, List
6
 
7
+ OPTIMAL_STEPS = {"easy": 5, "medium": 10, "hard": 15}
8
+
9
 
10
  def grade_task(task_id: str, cluster_snapshot: Dict[str, Any], action_history: List[Dict]) -> Dict[str, Any]:
11
  """Grade a completed episode. Returns {"reward": float, "metadata": dict}."""
 
16
  }
17
  grader = grader_map.get(task_id)
18
  if not grader:
19
+ return {"reward": 0.5, "metadata": {"error": f"Unknown task: {task_id}"}}
20
  return grader(cluster_snapshot, action_history)
21
 
22
 
23
+ def _efficiency_score(task_id: str, steps_taken: int) -> float:
24
+ """Score based on how efficiently the agent solved the task.
25
+ Returns value in (0, 1) β€” mathematically cannot be 0.0 or 1.0."""
26
+ optimal = OPTIMAL_STEPS.get(task_id, 10)
27
+ return optimal / (steps_taken + optimal)
28
+
29
+
30
  def _weighted_score(results: List[Dict]) -> float:
31
  total_weight = sum(r["weight"] for r in results)
32
  if total_weight == 0:
33
+ return 0.5
34
+ return sum(r["score"] * r["weight"] for r in results) / total_weight
 
 
35
 
36
 
37
  # ═══════════════════════════════════════════════════════════════════
 
40
 
41
  def _grade_easy(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
42
  services = snapshot.get("services", {})
43
+ steps = snapshot.get("step_count", len(history))
44
  results = []
45
 
46
  # 1. Did agent investigate the root cause service?
 
56
  h["command"] == "restart_service" and h["target"] == "api-gateway"
57
  for h in history
58
  )
59
+ results.append({"name": "Restarted api-gateway", "score": 1.0 if restarted else 0.0, "weight": 0.3})
60
 
61
  # 3. Is api-gateway healthy now?
62
  gw = services.get("api-gateway", {})
63
  healthy = gw.get("status") == "healthy"
64
+ results.append({"name": "api-gateway is healthy", "score": 1.0 if healthy else 0.0, "weight": 0.2})
65
 
66
  # 4. Didn't restart healthy services unnecessarily
67
  unnecessary_restarts = sum(
 
71
  no_waste = unnecessary_restarts == 0
72
  results.append({"name": "No unnecessary restarts", "score": 1.0 if no_waste else 0.0, "weight": 0.1})
73
 
74
+ # 5. Efficiency β€” always in (0, 1), prevents total from hitting 0.0 or 1.0
75
+ eff = _efficiency_score("easy", steps)
76
+ results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})
77
+
78
  return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}
79
 
80
 
 
84
 
85
  def _grade_medium(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
86
  services = snapshot.get("services", {})
87
+ steps = snapshot.get("step_count", len(history))
88
  results = []
89
 
90
  # 1. Traced dependencies
 
92
  h["command"] == "check_dependencies"
93
  for h in history
94
  )
95
+ results.append({"name": "Traced dependency graph", "score": 1.0 if traced else 0.0, "weight": 0.1})
96
 
97
  # 2. Identified postgres as root cause (investigated it)
98
  db_investigated = any(
 
100
  and h["target"] == "postgres-primary"
101
  for h in history
102
  )
103
+ results.append({"name": "Investigated postgres-primary", "score": 1.0 if db_investigated else 0.0, "weight": 0.15})
104
 
105
  # 3. Fixed postgres (restarted or updated config)
106
  db_fixed = any(
 
110
  )
111
  db_svc = services.get("postgres-primary", {})
112
  db_healthy = db_svc.get("status") == "healthy"
113
+ results.append({"name": "Fixed postgres-primary", "score": 1.0 if (db_fixed and db_healthy) else 0.0, "weight": 0.2})
114
 
115
  # 4. Downstream services recovered
116
  downstream_names = ["user-service", "order-service", "payment-service"]
117
  recovered = sum(1 for n in downstream_names if services.get(n, {}).get("status") == "healthy")
118
  downstream_score = recovered / len(downstream_names)
119
+ results.append({"name": "Downstream services recovered", "score": round(downstream_score, 4), "weight": 0.2})
120
 
121
  # 5. Didn't act on red herring alerts (didn't restart cache-service)
122
  acted_on_noise = any(
 
124
  and h["target"] == "cache-service"
125
  for h in history
126
  )
127
+ results.append({"name": "Ignored red herring alerts", "score": 0.0 if acted_on_noise else 1.0, "weight": 0.15})
128
+
129
+ # 6. Efficiency
130
+ eff = _efficiency_score("medium", steps)
131
+ results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})
132
 
133
  return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}
134
 
 
139
 
140
  def _grade_hard(snapshot: Dict, history: List[Dict]) -> Dict[str, Any]:
141
  services = snapshot.get("services", {})
142
+ steps = snapshot.get("step_count", len(history))
143
  results = []
144
 
145
  # 1. Investigated worker-service (the compromised service)
 
156
  p.get("name") in ("xmrig", "kworker/u8:2") and p.get("status") == "killed"
157
  for p in worker.get("processes", [])
158
  )
159
+ results.append({"name": "Killed crypto miner", "score": 1.0 if miner_killed else 0.0, "weight": 0.1})
160
 
161
  # 3. Rolled back compromised deployment
162
  rolled_back = worker.get("was_rolled_back", False)
163
+ results.append({"name": "Rolled back worker-service", "score": 1.0 if rolled_back else 0.0, "weight": 0.1})
164
 
165
  # 4. Rotated credentials
166
  creds_rotated = worker.get("credentials_rotated", False)
 
169
  # 5. Fixed disk issue on log-aggregator
170
  log_agg = services.get("log-aggregator", {})
171
  disk_ok = log_agg.get("disk_usage_percent", 100) < 80
172
+ results.append({"name": "Cleared log-aggregator disk", "score": 1.0 if disk_ok else 0.0, "weight": 0.1})
173
 
174
  # 6. All services healthy
175
  mostly_healthy = sum(1 for s in services.values() if s.get("status") == "healthy") / max(len(services), 1)
176
+ results.append({"name": "Cluster health restored", "score": round(mostly_healthy, 4), "weight": 0.15})
177
 
178
  # 7. Submitted correct diagnosis
179
  diagnosis = snapshot.get("diagnosis_submitted")
 
183
  correct_diagnosis = any(kw in rc for kw in ["crypto", "mining", "xmrig", "malicious", "compromised", "unauthorized"])
184
  results.append({"name": "Correct diagnosis submitted", "score": 1.0 if correct_diagnosis else 0.0, "weight": 0.15})
185
 
186
+ # 8. Efficiency
187
+ eff = _efficiency_score("hard", steps)
188
+ results.append({"name": "Resolution efficiency", "score": round(eff, 4), "weight": 0.2})
189
 
190
+ return {"reward": round(_weighted_score(results), 4), "metadata": {"evaluations": results}}