Pramod Basavaraj Menasi commited on
Commit
2668702
·
1 Parent(s): e7cfcc2

updated app.py

Browse files
inference.py CHANGED
@@ -21,7 +21,7 @@ BENCHMARK = os.getenv("INCIDENTOPS_BENCHMARK", "incidentops_env")
21
  MAX_STEPS = int(os.getenv("MAX_STEPS", "12"))
22
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
23
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
24
- DIFFICULTY = os.getenv("DIFFICULTY", "easy")
25
 
26
  SYSTEM_PROMPT = """
27
  You are an incident-response policy.
@@ -113,7 +113,7 @@ async def main() -> None:
113
  log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
114
 
115
  try:
116
- result = await env.reset(difficulty=DIFFICULTY)
117
  obs = result.observation
118
 
119
  for step in range(1, MAX_STEPS + 1):
 
21
  MAX_STEPS = int(os.getenv("MAX_STEPS", "12"))
22
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
23
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
24
+ TASK_ID = os.getenv("TASK_ID", "incident_easy")
25
 
26
  SYSTEM_PROMPT = """
27
  You are an incident-response policy.
 
113
  log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
114
 
115
  try:
116
+ result = await env.reset(task_id=TASK_ID)
117
  obs = result.observation
118
 
119
  for step in range(1, MAX_STEPS + 1):
models.py CHANGED
@@ -37,6 +37,6 @@ class IncidentopsObservation(Observation):
37
  metadata: Dict[str, Any] = Field(default_factory=dict, description="Extra debug metadata")
38
  reward: float = Field(default=0.0, description="Reward returned by the last step")
39
  done: bool = Field(default=False, description="Whether the episode is finished")
40
- grader_score: float = Field(default=0.0, description="Grader score 0.0-1.0, set when done=True")
41
 
42
 
 
37
  metadata: Dict[str, Any] = Field(default_factory=dict, description="Extra debug metadata")
38
  reward: float = Field(default=0.0, description="Reward returned by the last step")
39
  done: bool = Field(default=False, description="Whether the episode is finished")
40
+
41
 
42
 
openenv.yaml CHANGED
@@ -3,14 +3,14 @@ name: incidentops_env
3
  type: space
4
  runtime: fastapi
5
  app: server.app:app
6
- port: 7860
7
 
8
  tasks:
9
  - id: incident_easy
10
  name: "Single Service Outage (Easy)"
11
  description: "Diagnose and resolve a payment-service latency spike caused by a bad deployment."
12
  reset_kwargs:
13
- difficulty: easy
14
  grader:
15
  type: class
16
  module: graders
@@ -18,9 +18,8 @@ tasks:
18
 
19
  - id: incident_medium
20
  name: "Dependency Failure (Medium)"
21
- description: "Identify a DB timeout causing API gateway failures with no logs initially available."
22
  reset_kwargs:
23
- difficulty: medium
24
  grader:
25
  type: class
26
  module: graders
@@ -28,9 +27,8 @@ tasks:
28
 
29
  - id: incident_hard
30
  name: "Multi-Service Root Cause (Hard)"
31
- description: "Trace EU checkout failures across auth, payment, checkout to a DNS issue."
32
  reset_kwargs:
33
- difficulty: hard
34
  grader:
35
  type: class
36
  module: graders
 
3
  type: space
4
  runtime: fastapi
5
  app: server.app:app
6
+ port: 8000
7
 
8
  tasks:
9
  - id: incident_easy
10
  name: "Single Service Outage (Easy)"
11
  description: "Diagnose and resolve a payment-service latency spike caused by a bad deployment."
12
  reset_kwargs:
13
+ task_id: incident_easy
14
  grader:
15
  type: class
16
  module: graders
 
18
 
19
  - id: incident_medium
20
  name: "Dependency Failure (Medium)"
 
21
  reset_kwargs:
22
+ task_id: incident_medium
23
  grader:
24
  type: class
25
  module: graders
 
27
 
28
  - id: incident_hard
29
  name: "Multi-Service Root Cause (Hard)"
 
30
  reset_kwargs:
31
+ task_id: incident_hard
32
  grader:
33
  type: class
34
  module: graders
server/app.py CHANGED
@@ -65,22 +65,44 @@ GRADERS = {
65
  @app.post("/grade")
66
  async def grade_endpoint(task_id: str = None, request: Request = None):
67
  try:
68
- if task_id and task_id in GRADERS:
69
- snapshot = _shared_env._snapshot
70
- if snapshot is None:
71
- # Return a zero score instead of erroring — validator just needs grader to respond
72
- return {"score": 0.0, "success": False, "grader": task_id, "detail": "no active episode"}
73
- trajectory = [
74
- {"action": a, "observation": {"incident_resolved": snapshot.resolved}}
75
- for a in snapshot.action_history
76
- ]
77
- score = GRADERS[task_id].grade(trajectory)
78
- return {"score": score, "success": score >= 0.5, "grader": task_id}
79
-
80
- # fallback to env's own grade()
81
- return _shared_env.grade()
82
- except AssertionError:
83
- return {"score": 0.0, "success": False, "detail": "no active episode"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  except Exception as e:
85
  raise HTTPException(status_code=500, detail=str(e))
86
 
@@ -94,7 +116,7 @@ async def list_tasks():
94
  ]
95
  }
96
 
97
- def main(host: str = "0.0.0.0", port: int = 7860) -> None:
98
  import uvicorn
99
  uvicorn.run(app, host=host, port=port)
100
 
 
65
  @app.post("/grade")
66
  async def grade_endpoint(task_id: str = None, request: Request = None):
67
  try:
68
+ # STRICT validation (important)
69
+ if not task_id or task_id not in GRADERS:
70
+ return {
71
+ "score": 0.0,
72
+ "success": False,
73
+ "detail": "invalid or missing task_id"
74
+ }
75
+
76
+ snapshot = _shared_env._snapshot
77
+
78
+ if snapshot is None:
79
+ return {
80
+ "score": 0.0,
81
+ "success": False,
82
+ "grader": task_id,
83
+ "detail": "no active episode"
84
+ }
85
+
86
+ # ✅ Build trajectory
87
+ trajectory = [
88
+ {
89
+ "action": a,
90
+ "observation": {
91
+ "incident_resolved": snapshot.resolved
92
+ }
93
+ }
94
+ for a in snapshot.action_history
95
+ ]
96
+
97
+ # ✅ Call correct grader
98
+ score = GRADERS[task_id].grade(trajectory)
99
+
100
+ return {
101
+ "score": score,
102
+ "success": score >= 0.5,
103
+ "grader": task_id
104
+ }
105
+
106
  except Exception as e:
107
  raise HTTPException(status_code=500, detail=str(e))
108
 
 
116
  ]
117
  }
118
 
119
+ def main(host: str = "0.0.0.0", port: int = 8000) -> None:
120
  import uvicorn
121
  uvicorn.run(app, host=host, port=port)
122
 
server/incidentops_env_environment.py CHANGED
@@ -50,10 +50,10 @@ class IncidentSnapshot:
50
 
51
 
52
  SCENARIOS: Dict[str, List[Dict[str, Any]]] = {
53
- "easy": [
54
  {
55
  "scenario_id": "easy_001",
56
- "task": "single_service_outage",
57
  "alert_text": "SEV-2: payment-service latency high after deploy.",
58
  "hidden_truth": "bad_deployment",
59
  "severity": "high",
@@ -62,15 +62,24 @@ SCENARIOS: Dict[str, List[Dict[str, Any]]] = {
62
  "log_snippet": "deploy at 14:32 UTC caused connection pool exhaustion",
63
  "likely_cause": "bad_deployment",
64
  "hf_confidence": 0.92,
65
- "available_actions": ["request_logs", "rollback_deploy", "restart_service", "resolve_incident"],
66
- "correct_action_sequence": ["rollback_deploy", "resolve_incident"],
 
 
 
 
 
 
 
 
67
  "sla_steps": 5,
68
  }
69
  ],
70
- "medium": [
 
71
  {
72
  "scenario_id": "medium_001",
73
- "task": "dependency_failure",
74
  "alert_text": "SEV-1: api-gateway 5xx errors; user-profile-service slow; no logs available.",
75
  "hidden_truth": "db_timeout",
76
  "severity": "critical",
@@ -87,14 +96,21 @@ SCENARIOS: Dict[str, List[Dict[str, Any]]] = {
87
  "restart_service",
88
  "resolve_incident",
89
  ],
90
- "correct_action_sequence": ["request_logs", "query_dependencies", "escalate_db_team", "restart_service", "resolve_incident"],
 
 
 
 
 
 
91
  "sla_steps": 8,
92
  }
93
  ],
94
- "hard": [
 
95
  {
96
  "scenario_id": "hard_001",
97
- "task": "multi_service_root_cause",
98
  "alert_text": "SEV-1: EU checkout failures. Auth and payment degraded. Logs incomplete.",
99
  "hidden_truth": "dns_issue",
100
  "severity": "critical",
@@ -245,15 +261,27 @@ class IncidentopsEnvironment(Environment):
245
  def reset(
246
  self,
247
  episode_id: str = None,
248
- difficulty: str = "easy",
249
  **kwargs
250
  ) -> IncidentopsObservation:
251
- scenario = self._pick_scenario(difficulty)
252
- self._difficulty = difficulty
253
- self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
 
 
 
 
 
 
 
 
 
254
  self._snapshot = IncidentSnapshot(**scenario)
255
  self._snapshot.action_history = []
 
 
256
  self._last_observation = self._build_observation()
 
257
  return self._last_observation
258
 
259
  def step(self, action: IncidentopsAction) -> IncidentopsObservation: # type: ignore[override]
 
50
 
51
 
52
  SCENARIOS: Dict[str, List[Dict[str, Any]]] = {
53
+ "incident_easy": [
54
  {
55
  "scenario_id": "easy_001",
56
+ "task": "incident_easy",
57
  "alert_text": "SEV-2: payment-service latency high after deploy.",
58
  "hidden_truth": "bad_deployment",
59
  "severity": "high",
 
62
  "log_snippet": "deploy at 14:32 UTC caused connection pool exhaustion",
63
  "likely_cause": "bad_deployment",
64
  "hf_confidence": 0.92,
65
+ "available_actions": [
66
+ "request_logs",
67
+ "rollback_deploy",
68
+ "restart_service",
69
+ "resolve_incident"
70
+ ],
71
+ "correct_action_sequence": [
72
+ "rollback_deploy",
73
+ "resolve_incident"
74
+ ],
75
  "sla_steps": 5,
76
  }
77
  ],
78
+
79
+ "incident_medium": [
80
  {
81
  "scenario_id": "medium_001",
82
+ "task": "incident_medium",
83
  "alert_text": "SEV-1: api-gateway 5xx errors; user-profile-service slow; no logs available.",
84
  "hidden_truth": "db_timeout",
85
  "severity": "critical",
 
96
  "restart_service",
97
  "resolve_incident",
98
  ],
99
+ "correct_action_sequence": [
100
+ "request_logs",
101
+ "query_dependencies",
102
+ "escalate_db_team",
103
+ "restart_service",
104
+ "resolve_incident"
105
+ ],
106
  "sla_steps": 8,
107
  }
108
  ],
109
+
110
+ "incident_hard": [
111
  {
112
  "scenario_id": "hard_001",
113
+ "task": "incident_hard",
114
  "alert_text": "SEV-1: EU checkout failures. Auth and payment degraded. Logs incomplete.",
115
  "hidden_truth": "dns_issue",
116
  "severity": "critical",
 
261
  def reset(
262
  self,
263
  episode_id: str = None,
264
+ task_id: str = "incident_easy",
265
  **kwargs
266
  ) -> IncidentopsObservation:
267
+
268
+ # Pick scenario based on task_id (not difficulty)
269
+ scenarios = SCENARIOS.get(task_id, SCENARIOS["incident_easy"])
270
+ scenario = scenarios[0]
271
+
272
+ # ✅ Initialize state
273
+ self._state = State(
274
+ episode_id=episode_id or str(uuid4()),
275
+ step_count=0
276
+ )
277
+
278
+ # ✅ Load scenario into snapshot
279
  self._snapshot = IncidentSnapshot(**scenario)
280
  self._snapshot.action_history = []
281
+
282
+ # ✅ Build first observation
283
  self._last_observation = self._build_observation()
284
+
285
  return self._last_observation
286
 
287
  def step(self, action: IncidentopsAction) -> IncidentopsObservation: # type: ignore[override]