aamrinder committed on
Commit
3fbd071
·
verified ·
1 Parent(s): e1fcd2c

Upload server/sre_environment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/sre_environment.py +118 -0
server/sre_environment.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SRE Incident Response Environment — OpenEnv Environment subclass."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from typing import Optional
7
+
8
+ from openenv.core.env_server import Environment
9
+
10
+ from sre_incident_env.models import SREAction, SREObservation, SREState
11
+ from server.cluster import Cluster, AVAILABLE_COMMANDS
12
+ from server.scenarios import get_scenario
13
+ from server.grader import grade_task
14
+
15
+
16
class SREEnvironment(Environment):
    """SRE incident-response environment built on the OpenEnv ``Environment`` base.

    Owns a simulated ``Cluster``. ``reset`` loads a scenario into it, and
    ``step`` executes one agent command, advances the simulation by one tick
    and scores the resulting snapshot with ``grade_task``.
    """

    # Declared by this environment; how the server consumes the flag is not
    # visible here.
    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self):
        """Create the cluster and placeholder episode fields.

        The values set here are overwritten by ``reset``, which loads the
        actual scenario.
        """
        super().__init__()
        self.cluster = Cluster()
        self._state = SREState()
        self._scenario = {}
        self._task_id = "easy"
        self._max_steps = 15
        self._prev_grader_score = 0.0
        self._done = False

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> SREObservation:
        """Begin a new episode and return its initial observation.

        Args:
            seed: Forwarded to ``Cluster.reset``.
            episode_id: Identifier recorded in the episode state; a random
                UUID4 string is generated when this is falsy.
            **kwargs: ``task_id`` selects the scenario (default ``"easy"``);
                no other key is read.

        Returns:
            An observation with ``done=False``, ``reward=None`` and the
            scenario description as ``output``.
        """
        task_id = kwargs.get("task_id", "easy")
        self._task_id = task_id

        scenario = get_scenario(task_id)
        self._scenario = scenario
        self._max_steps = scenario["max_steps"]
        self._prev_grader_score = 0.0
        self._done = False

        cluster = self.cluster
        cluster.reset(scenario, seed=seed)

        if not episode_id:
            episode_id = str(uuid.uuid4())
        self._state = SREState(
            episode_id=episode_id,
            step_count=0,
            task_id=task_id,
            task_difficulty=scenario["difficulty"],
            cluster_snapshot=cluster.get_snapshot(),
        )

        return SREObservation(
            done=False,
            reward=None,
            output=scenario["description"],
            alerts=cluster.get_active_alerts(),
            system_health=cluster.get_system_health(),
            services_status=cluster.get_services_status(),
            step_count=0,
            max_steps=self._max_steps,
            available_commands=AVAILABLE_COMMANDS,
        )

    def step(self, action: SREAction, timeout_s: Optional[float] = None, **kwargs) -> SREObservation:
        """Run one agent command and return the resulting observation.

        Args:
            action: Supplies the ``command``, ``target`` and ``parameters``
                passed to ``Cluster.execute_command``.
            timeout_s: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            The observation after the command and one simulation tick. Its
            ``reward`` is the change in grader score since the previous step,
            rounded to 4 decimals. Once the episode has ended, a terminal
            observation with ``reward=0.0`` and no alerts is returned and the
            cluster is left untouched.
        """
        cluster = self.cluster

        if self._done:
            return SREObservation(
                done=True,
                reward=0.0,
                output="Episode already finished. Call reset() to start a new episode.",
                alerts=[],
                system_health=cluster.get_system_health(),
                services_status=cluster.get_services_status(),
                step_count=self._state.step_count,
                max_steps=self._max_steps,
                available_commands=AVAILABLE_COMMANDS,
            )

        # Apply the agent's command, then let the simulation advance
        # (degradation + cascading).
        command_output = cluster.execute_command(action.command, action.target, action.parameters)
        cluster.tick()

        # Record the step and the post-tick cluster snapshot.
        self._state.step_count += 1
        snapshot = cluster.get_snapshot()
        self._state.cluster_snapshot = snapshot

        # Reward is the delta between the current and previous grader score.
        # It can be negative when degradation made things worse or the agent
        # took a destructive action (e.g. restarted a healthy service), which
        # penalizes clearly undesirable behavior.
        grade = grade_task(self._task_id, snapshot, cluster.investigation_history)
        score_now = grade["reward"]
        reward_delta = score_now - self._prev_grader_score
        self._prev_grader_score = score_now

        # The episode ends when the step budget is spent, or when every
        # service is healthy and a diagnosis has been submitted.
        out_of_steps = self._state.step_count >= self._max_steps
        services = snapshot.get("services", {})
        every_service_healthy = all(
            svc.get("status") == "healthy" for svc in services.values()
        )
        diagnosis_submitted = snapshot.get("diagnosis_submitted") is not None
        self._done = out_of_steps or (every_service_healthy and diagnosis_submitted)

        return SREObservation(
            done=self._done,
            reward=round(reward_delta, 4),
            output=command_output,
            alerts=cluster.get_active_alerts(),
            system_health=cluster.get_system_health(),
            services_status=cluster.get_services_status(),
            step_count=self._state.step_count,
            max_steps=self._max_steps,
            available_commands=AVAILABLE_COMMANDS,
        )

    @property
    def state(self) -> SREState:
        """Return the current episode state object."""
        return self._state