File size: 11,875 Bytes
d5fc8a7
37204eb
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
 
d5fc8a7
 
37204eb
d5fc8a7
37204eb
d5fc8a7
 
37204eb
d5fc8a7
37204eb
d5fc8a7
 
37204eb
d5fc8a7
37204eb
d5fc8a7
37204eb
d5fc8a7
37204eb
d5fc8a7
37204eb
 
 
 
 
 
 
 
d5fc8a7
 
 
 
37204eb
 
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
 
 
 
 
 
 
 
 
 
 
d5fc8a7
37204eb
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
37204eb
 
 
 
 
 
 
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
 
d5fc8a7
 
 
 
 
 
 
 
 
 
 
37204eb
d5fc8a7
 
 
37204eb
d5fc8a7
 
 
 
 
 
 
 
37204eb
d5fc8a7
37204eb
 
 
d5fc8a7
 
 
37204eb
 
 
 
 
 
 
 
 
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
d5fc8a7
 
37204eb
 
 
 
d5fc8a7
37204eb
 
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37204eb
d5fc8a7
 
37204eb
 
 
 
 
d5fc8a7
 
37204eb
d5fc8a7
37204eb
 
 
 
 
 
d5fc8a7
37204eb
 
d5fc8a7
 
37204eb
 
 
d5fc8a7
37204eb
 
 
d5fc8a7
 
 
 
 
 
 
 
 
 
37204eb
d5fc8a7
 
37204eb
 
 
 
 
d5fc8a7
 
 
37204eb
d5fc8a7
 
 
37204eb
d5fc8a7
 
 
37204eb
d5fc8a7
37204eb
d5fc8a7
 
 
 
37204eb
d5fc8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
server/environment.py — Core OpenEnv environment for Cloud Incident Response.

Implements the full OpenEnv interface:
  reset(task_id, scenario_index) -> Observation
  step(action)                   -> (Observation, Reward, done, info)
  state()                        -> EpisodeState

All state is in-memory. Thread-safe via a lock.
"""

from __future__ import annotations

import uuid
import threading
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tasks import get_task, get_scenario
from graders import grade
from server.models import Action, ActionParameters, Observation, Reward, EpisodeState

# ── Action type classification ────────────────────────────────────────────────

# Action types that step() routes to the diagnostic handler.
_DIAGNOSTIC = frozenset((
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
))

# Action types that step() routes to the remediation handler.
_REMEDIATION = frozenset((
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
))

# Action types that step() routes to the submit handler (episode-ending).
_SUBMIT = frozenset((
    "submit_severity",
    "submit_root_cause",
    "submit_resolution",
))

# ── Reward constants ──────────────────────────────────────────────────────────

# Per-step reward shaping applied by IncidentEnvironment.step().
# Diagnostic queries
R_QUERY_FIRST = 0.05     # first (action, service) query against a known service
R_QUERY_REPEAT = 0.01    # same (action, service) pair queried again
R_QUERY_UNKNOWN = -0.05  # service name not among the scenario's known services
# Remediation actions
R_REM_GOOD = 0.10        # remediation not listed in the scenario's wrong actions
R_REM_WRONG = -0.10      # remediation listed in the scenario's wrong actions
# Episode-level adjustments
R_PAST_HALF = -0.02      # every step taken after max_steps // 2
R_TIMEOUT = -0.10        # max_steps reached without a submission
R_BAD_ACTION = -0.03     # action_type outside every known action set


class IncidentEnvironment:
    """
    OpenEnv environment for Cloud Incident Response.

    One instance tracks one episode at a time. The public methods
    (``reset``, ``step``, ``state``) each hold ``self._lock`` for their whole
    body, so concurrent callers are serialised. The private helpers assume
    the lock is already held and must not be called from outside.

    Lifecycle: ``reset()`` installs a fresh episode state, ``step()`` applies
    one action and returns the shaped step reward, and the grader runs once
    on the step that ends the episode (submission or timeout).
    """

    def __init__(self):
        # Non-reentrant lock: public methods take it, helpers never do.
        self._lock = threading.Lock()
        self._s: dict = {}          # mutable per-episode state, built by reset()
        self._scenario: dict = {}   # scenario dict returned by get_scenario()
        self._task_def: dict = {}   # task definition returned by get_task()
        self._ready = False         # flips to True after the first reset()

    # ── Public OpenEnv API ───────────────────────────────────────────────────

    def reset(self, task_id: str, scenario_index: int = 0) -> Observation:
        """Start a fresh episode and return the initial Observation.

        The task and scenario are looked up before any instance state is
        touched, so a failed lookup leaves the previous episode intact.
        """
        with self._lock:
            task_def = get_task(task_id)
            scenario = get_scenario(task_id, scenario_index)

            self._task_def = task_def
            self._scenario = scenario
            self._s = {
                "episode_id":        str(uuid.uuid4()),
                "task_id":           task_id,
                "scenario_id":       scenario["scenario_id"],
                "step_count":        0,
                "max_steps":         task_def["max_steps"],
                "action_history":    [],
                "queried_data":      {},
                # (action_type, service) pairs already rewarded as "first".
                "queried_keys":      set(),
                "submitted":         False,
                "resolved":          False,
                "done":              False,
                "cumulative_reward": 0.0,
                "feedback":          f"Episode started. {scenario['description']}",
            }
            self._ready = True
            return self._build_obs()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Process one agent action.

        Returns:
            ``(Observation, Reward, done, info)``. ``Reward.value`` is the
            shaped reward of this step only; on the terminal step the
            grader's ``total`` is added to ``Reward.cumulative`` but not to
            ``Reward.value``. Stepping a finished episode is a no-op that
            returns a zero reward and an empty ``info``.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("Call reset() before step().")

            s = self._s
            if s["done"]:
                return (
                    self._build_obs(),
                    Reward(value=0.0, reason="episode already done",
                           cumulative=s["cumulative_reward"]),
                    True,
                    {},
                )

            s["step_count"] += 1
            step_num = s["step_count"]
            at = action.action_type
            params = action.parameters

            # Record the action before handling it so the submit handler and
            # the grader see the complete history.
            s["action_history"].append({
                "action_type": at,
                "parameters":  params.model_dump(exclude_none=True),
                "step":        step_num,
            })

            r = 0.0
            fb: list[str] = []

            # Efficiency penalty past halfway
            if step_num > s["max_steps"] // 2:
                r += R_PAST_HALF
                fb.append("efficiency penalty")

            if at in _DIAGNOSTIC:
                r, fb = self._handle_diagnostic(at, params, r, fb)
            elif at in _REMEDIATION:
                r, fb = self._handle_remediation(at, params, r, fb)
            elif at in _SUBMIT:
                r, fb, terminal = self._handle_submit(at, params, r, fb)
                if terminal:
                    s["done"] = True
            else:
                r += R_BAD_ACTION
                fb.append(f"unknown action_type '{at}'")

            # Timeout: step budget exhausted without a submission.
            if step_num >= s["max_steps"] and not s["done"]:
                r += R_TIMEOUT
                fb.append("timeout — no submission made")
                s["done"] = True

            # Run grader on terminal step
            if s["done"]:
                result = grade(s["task_id"], s, self._scenario)
                s["cumulative_reward"] = round(
                    s["cumulative_reward"] + r + result["total"], 4
                )
                fb.append(f"grader={result['feedback']}")
            else:
                s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)

            s["feedback"] = " | ".join(fb) if fb else "ok"

            return (
                self._build_obs(),
                Reward(
                    value=round(r, 4),
                    reason=s["feedback"],
                    cumulative=s["cumulative_reward"],
                ),
                s["done"],
                {"step": step_num, "feedback": s["feedback"]},
            )

    def state(self) -> EpisodeState:
        """Return the full current episode state.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("No active episode — call reset() first.")
            s = self._s
            return EpisodeState(
                episode_id=s["episode_id"],
                task_id=s["task_id"],
                scenario_id=s["scenario_id"],
                step_count=s["step_count"],
                max_steps=s["max_steps"],
                action_history=list(s["action_history"]),
                queried_data=self._queried_snapshot(),
                submitted=s["submitted"],
                resolved=s["resolved"],
                done=s["done"],
                cumulative_reward=s["cumulative_reward"],
                feedback=s["feedback"],
            )

    # ── Action handlers ──────────────────────────────────────────────────────

    def _handle_diagnostic(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Apply a diagnostic action and return the updated ``(r, fb)``.

        A known service earns ``R_QUERY_FIRST`` the first time a given
        ``(action_type, service)`` pair is seen and ``R_QUERY_REPEAT`` after
        that; the tool response is stored in ``queried_data`` either way.
        An unknown service costs ``R_QUERY_UNKNOWN``. A missing service
        only adds a feedback note.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        known = {sv.lower() for sv in self._scenario.get("known_services", set())}
        tool_data = self._scenario.get("tool_responses", {}).get(at, {})
        key = (at, service)

        if service and service in known:
            if key not in s["queried_keys"]:
                r += R_QUERY_FIRST
                fb.append(f"queried {service} (+{R_QUERY_FIRST})")
                s["queried_keys"].add(key)
            else:
                r += R_QUERY_REPEAT
                fb.append(f"re-queried {service} (+{R_QUERY_REPEAT})")
            result = tool_data.get(service, f"No data for '{service}'.")
            s["queried_data"].setdefault(at, {})[service] = result

        elif service:
            r += R_QUERY_UNKNOWN
            fb.append(f"unknown service '{service}' ({R_QUERY_UNKNOWN})")
        else:
            fb.append(f"{at}: no service specified")

        return r, fb

    @staticmethod
    def _remediation_keys(
        at: str, service: str, flag: str, runbook: str, target: str
    ) -> list[str]:
        """Return the ``wrong_actions`` lookup keys for a remediation action.

        Keys are ordered most-specific first (``"<action>:<arg>"`` forms,
        then the bare action type) and de-duplicated, so the first match is
        the same on every run. A set would leave that choice to hash order.
        """
        candidates: list[str] = []
        if service:
            candidates.append(f"{at}:{service}")
        if flag:
            candidates.append(f"{at}:{flag}")
        if runbook:
            candidates.append(f"execute_runbook_step:{runbook}")
        if target:
            candidates.append(f"execute_runbook_step:{target}")
        candidates.append(at)
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(candidates))

    def _handle_remediation(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Apply a remediation action and return the updated ``(r, fb)``.

        If any lookup key is present in the scenario's ``wrong_actions`` the
        action costs ``R_REM_WRONG`` and the mapped reason (truncated to 80
        characters) is reported. Otherwise it earns ``R_REM_GOOD`` and its
        result is stored in ``queried_data``.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        flag = (params.flag or "").lower().strip()
        runbook = (params.runbook_action or "").lower().strip()
        target = (params.target or "").lower().strip()

        keys = self._remediation_keys(at, service, flag, runbook, target)

        wrong_map = self._scenario.get("wrong_actions", {})
        rem_data  = self._scenario.get("remediation_data", {})

        matched = next((k for k in keys if k in wrong_map), None)
        if matched is not None:
            r += R_REM_WRONG
            reason = wrong_map[matched]
            fb.append(f"wrong action '{at}': {str(reason)[:80]}")
        else:
            r += R_REM_GOOD
            fb.append(f"executed {at}" + (f" on '{service}'" if service else ""))
            at_data = rem_data.get(at, {})
            result = (
                at_data.get(service) or at_data.get(flag) or
                at_data.get(runbook) or at_data.get(target) or
                "action executed successfully"
            )
            # File the result under the first non-empty argument, falling
            # back to the action type itself.
            s["queried_data"].setdefault(at, {})[
                service or flag or runbook or target or at
            ] = result

        return r, fb

    def _handle_submit(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str], bool]:
        """Record a submission and return ``(r, fb, terminal)``.

        Every submit action is terminal and leaves ``r`` unchanged; scoring
        of the submission itself is left to the grader. For
        ``submit_resolution`` the episode is marked resolved only when the
        summary is non-blank and at least one diagnostic or remediation
        action appears in the history.
        """
        s = self._s
        s["submitted"] = True

        if at == "submit_severity":
            fb.append(f"submitted severity: {(params.severity or '').upper()}")

        elif at == "submit_root_cause":
            fb.append(
                f"submitted root cause: "
                f"service={params.service or ''}, "
                f"failure_mode={params.failure_mode or ''}"
            )

        elif at == "submit_resolution":
            summary = params.summary or ""
            # Build the union once rather than once per history entry.
            investigative = _DIAGNOSTIC | _REMEDIATION
            inv_count = sum(
                1 for a in s["action_history"]
                if a.get("action_type") in investigative
            )
            if summary.strip() and inv_count >= 1:
                s["resolved"] = True
                fb.append("resolution submitted — incident resolved")
            else:
                fb.append("resolution submitted — insufficient investigation")

        return r, fb, True

    # ── Build observation ────────────────────────────────────────────────────

    def _queried_snapshot(self) -> dict:
        """Return a copy of ``queried_data`` that callers may mutate.

        Both the outer mapping and each per-action mapping are copied, so
        edits to the snapshot's dicts never reach the live episode state.
        The stored result values themselves are not copied.
        """
        return {
            tool: dict(entries) if isinstance(entries, dict) else entries
            for tool, entries in self._s["queried_data"].items()
        }

    def _build_obs(self) -> Observation:
        """Assemble the agent-facing Observation from the current state."""
        s  = self._s
        sc = self._scenario
        td = self._task_def
        return Observation(
            episode_id=s["episode_id"],
            task_id=s["task_id"],
            scenario_id=s["scenario_id"],
            step_count=s["step_count"],
            max_steps=s["max_steps"],
            # Prefer the dedicated summary; fall back to the description.
            incident_summary=sc.get("incident_summary", sc.get("description", "")),
            alert=sc.get("alert", {}),
            available_actions=td.get("available_actions", []),
            queried_data=self._queried_snapshot(),
            cumulative_reward=s["cumulative_reward"],
            done=s["done"],
            feedback=s["feedback"],
        )