File size: 15,400 Bytes
5884d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb0927
 
5884d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb0927
 
 
 
5884d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb0927
5884d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb0927
 
 
 
5884d9c
 
 
 
 
 
fbb0927
 
 
 
5884d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""
Core LogTriageEnvironment class.
Implements OpenEnv interface: reset(), step(), state property.
"""
from __future__ import annotations
import random
from datetime import datetime
from uuid import uuid4

from server.models import (
    TriageAction,
    TriageObservation,
    EpisodeState,
    LogLine,
    ServiceStatus,
)
from server.scenarios import single_crash
from server.scenarios import cascading
from server.scenarios import silent_degrade
from server.log_generator import generate_healthy_system_state, _make_timestamp

# ─── TASK REGISTRY ─────────────────────────────────────────────────────────────

# Step budget per task id. Also serves as the registry of valid task ids:
# reset() rejects any task_id that is not a key here.
TASK_MAX_STEPS = dict(
    single_crash=8,
    cascading_failure=12,
    silent_degradation=15,
)

# ─── REWARD CONSTANTS ──────────────────────────────────────────────────────────

# Rewards (added to the step reward when the condition is met).
R_CORRECT_SEVERITY = 0.30      # first correct classify_severity
R_CORRECT_ROOT_CAUSE = 0.35    # first correct identify_root_cause
R_CORRECT_REMEDIATION = 0.25   # first correct remediate (action prefix + service)
R_CORRECT_ESCALATION = 0.10    # escalate to a team listed in the ground truth
R_SPEED_BONUS = 0.10           # resolve within 60% of the step budget
R_PARTIAL_SERVICE_FAM = 0.10   # wrong root cause, same leading "-"-separated token
R_PARTIAL_SEVERITY_ADJ = 0.10  # P2 given when P1 was expected

# Penalties (negative values, added the same way).
P_WRONG_ESCALATION = -0.10     # escalate to a team not in the ground truth
P_IGNORE_P1 = -0.50            # "ignore" on an incident whose true severity is P1
P_REDUNDANT_ACTION = -0.05     # action_type already used earlier in the episode
P_EXCEEDED_BUDGET = -0.20      # meant for episodes that run out of step budget
P_OVERESCALATE_P3_P1 = -0.15   # P3 given when P1 was expected (despite the name,
                               # this is an under-classification)


class LogTriageEnvironment:
    """
    OpenEnv-compatible environment for SRE incident triage.

    An episode is started with reset(), advanced one action at a time with
    step(), and scored with get_grader_score(). Each action is graded against
    the ground truth of the selected scenario; the per-step reward is
    accumulated in the episode state.

    Usage:
        env = LogTriageEnvironment()
        obs = env.reset(task_id="single_crash", seed=42)
        while not obs.done:
            action = agent.act(obs)
            obs = env.step(action)
        score = env.get_grader_score()
    """

    def __init__(self) -> None:
        """Create an environment with no active episode; call reset() next."""
        self._state: EpisodeState | None = None
        self._rng: random.Random = random.Random()
        self._base_time: datetime = datetime.utcnow()
        self._task_id: str = "single_crash"
        self._ground_truth: dict = {}
        self._current_obs: TriageObservation | None = None

    # ─── OPENENV INTERFACE ─────────────────────────────────────────────────────

    def reset(self, task_id: str = "single_crash", seed: int | None = None) -> TriageObservation:
        """
        Start a fresh episode and return its initial observation.

        Args:
            task_id: One of the keys of TASK_MAX_STEPS.
            seed: Seed for the episode's random generator (None = unseeded).

        Returns:
            The step-0 observation (reward 0.0, done False).

        Raises:
            ValueError: If task_id is not a registered task.
        """
        if task_id not in TASK_MAX_STEPS:
            raise ValueError(f"Unknown task_id '{task_id}'. Valid: {list(TASK_MAX_STEPS.keys())}")

        self._task_id = task_id
        self._rng = random.Random(seed)
        # NOTE(review): wall-clock base time means timestamps differ between
        # runs even with a fixed seed — confirm that is acceptable.
        self._base_time = datetime.utcnow()

        # Load ground truth for this task. Falling back to {} guarantees a
        # previous episode's answers can never leak into this one.
        scenario = self._scenario()
        self._ground_truth = scenario.GROUND_TRUTH if scenario is not None else {}

        # Initialize episode state
        self._state = EpisodeState(
            episode_id=str(uuid4()),
            task_id=task_id,
            step_count=0,
            max_steps=TASK_MAX_STEPS[task_id],
            done=False,
            cumulative_score=0.0,
            actions_taken=[],
            correct_severity=None,
            correct_root_cause=None,
            correct_remediation=False,
        )

        # Get initial observation (step 0)
        logs, system_state = self._get_step_data(0)
        alerts = self._get_alerts(0)

        obs = TriageObservation(
            logs=logs,
            system_state=system_state,
            incident_id=self._state.episode_id,
            task_id=task_id,
            step_count=0,
            time_elapsed_seconds=0,
            active_alerts=alerts,
            reward=0.0,
            cumulative_score=0.0,
            done=False,
            last_action_feedback="Incident detected. Analyze the logs and take action.",
            invalid_action_error=None,
        )
        self._current_obs = obs
        return obs

    def step(self, action: TriageAction) -> TriageObservation:
        """
        Apply one action and return the next observation with its reward.

        Invalid actions (per action.is_valid()) are rejected with reward 0.0
        and do not consume a step. Valid actions are graded, recorded, and
        advance the step counter. The episode ends on a terminal action
        (see _is_terminal_action) or when the step budget is used up; in the
        latter case P_EXCEEDED_BUDGET is added to the cumulative score.

        Raises:
            RuntimeError: If reset() has not been called or the episode is done.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        # Validate action
        valid, err = action.is_valid()
        if not valid:
            obs = self._make_obs(
                reward=0.0,
                feedback=f"Invalid action: {err}",
                invalid_action_error=err,
                advance_step=False,
            )
            self._current_obs = obs
            return obs

        # Grade first: _evaluate_action must see the history *before* this
        # action is recorded (redundancy check, speed-bonus step count).
        reward, feedback = self._evaluate_action(action)

        # Update state
        state = self._state
        state.cumulative_score = round(state.cumulative_score + reward, 4)
        state.actions_taken.append(action.action_type)
        state.action_history.append(action.model_dump())
        state.step_count += 1

        # Check if episode should end
        state.done = self._check_done(action)

        # Budget penalty: only when the episode ends because the agent ran out
        # of steps, not when its last allowed action was itself terminal.
        # (The previous guard `step_count >= max_steps and not done` could
        # never be true because _check_done already covers the budget.)
        if self._budget_exhausted() and not self._is_terminal_action(action):
            state.cumulative_score = round(state.cumulative_score + P_EXCEEDED_BUDGET, 4)
            state.done = True
            feedback += f" Step budget exceeded ({state.max_steps} steps). Penalty applied."

        obs = self._make_obs(reward=reward, feedback=feedback, advance_step=True)
        self._current_obs = obs  # keep the cached observation current
        return obs

    @property
    def state(self) -> EpisodeState:
        """
        Return the current episode state.

        Raises:
            RuntimeError: If reset() has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first.")
        return self._state

    def get_grader_score(self) -> float:
        """
        Return the final grader score for the episode.

        The cumulative score is clamped to [0.0, 1.0] and rounded to four
        decimals. Returns 0.0 when no episode has been started.
        """
        if self._state is None:
            return 0.0
        raw = self._state.cumulative_score
        return round(max(0.0, min(1.0, raw)), 4)

    # ─── INTERNAL HELPERS ──────────────────────────────────────────────────────

    def _scenario(self):
        """Return the scenario module for the active task, or None if unknown."""
        # Built on each call (not at class level) so the scenario modules are
        # only looked up when a method actually needs them.
        return {
            "single_crash": single_crash,
            "cascading_failure": cascading,
            "silent_degradation": silent_degrade,
        }.get(self._task_id)

    def _action_scorers(self) -> dict:
        """Map each graded action_type to the bound helper that scores it."""
        return {
            "classify_severity": self._score_classify_severity,
            "identify_root_cause": self._score_identify_root_cause,
            "escalate": self._score_escalate,
            "remediate": self._score_remediate,
            "ignore": self._score_ignore,
            "request_more_logs": self._score_request_more_logs,
            "resolve": self._score_resolve,
        }

    def _evaluate_action(self, action: TriageAction) -> tuple[float, str]:
        """
        Evaluate the action against ground truth.

        A redundancy penalty is applied when the same action_type was already
        used this episode, then the action-specific scorer (if any) runs.

        Returns:
            (reward, feedback): reward rounded to four decimals; feedback is
            the " | "-joined messages, or "Action processed." if there are none.
        """
        reward = 0.0
        feedback_parts: list[str] = []

        # Penalize redundant actions (keyed on action_type only, not on value)
        if action.action_type in self._state.actions_taken:
            reward += P_REDUNDANT_ACTION
            feedback_parts.append("Redundant action — you've already done this.")

        scorer = self._action_scorers().get(action.action_type)
        if scorer is not None:
            delta, messages = scorer(action)
            reward += delta
            feedback_parts.extend(messages)

        return round(reward, 4), " | ".join(feedback_parts) or "Action processed."

    def _score_classify_severity(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score a classify_severity action; full credit is given only once."""
        expected = self._ground_truth.get("severity", "")
        if action.value == expected:
            if self._state.correct_severity is not None:
                return 0.0, []  # already rewarded earlier in the episode
            self._state.correct_severity = action.value
            return R_CORRECT_SEVERITY, [f"Correct severity: {action.value}. +{R_CORRECT_SEVERITY}"]

        # Partial credit: P1 vs P2 is close, P1 vs P3 is not
        if expected == "P1" and action.value == "P3":
            return P_OVERESCALATE_P3_P1, [
                f"Incorrect severity: {action.value}. P1 expected. This is a customer-impacting incident."
            ]
        if expected == "P1" and action.value == "P2":
            return R_PARTIAL_SEVERITY_ADJ, [f"Close — {action.value} given, P1 expected. Partial credit."]
        return 0.0, [f"Incorrect severity: {action.value}. Reassess."]

    def _score_identify_root_cause(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score an identify_root_cause action; full credit is given only once."""
        expected = self._ground_truth.get("root_cause", "")
        if action.value == expected:
            if self._state.correct_root_cause is not None:
                return 0.0, []  # already rewarded earlier in the episode
            self._state.correct_root_cause = action.value
            return R_CORRECT_ROOT_CAUSE, [f"Correct root cause: {action.value}. +{R_CORRECT_ROOT_CAUSE}"]

        # Partial credit: same tier (e.g. payment-db instead of payment-service),
        # detected by comparing the text before the first "-".
        if expected.split("-")[0] == action.value.split("-")[0]:
            return R_PARTIAL_SERVICE_FAM, [
                f"Close — {action.value} is in the right service family. Check more carefully."
            ]
        return 0.0, [f"Incorrect root cause: {action.value}. Look at which service is actually failing."]

    def _score_escalate(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score an escalate action against the ground-truth team set."""
        # NOTE(review): unlike severity/root cause/remediation there is no
        # "first time only" guard here; a repeated correct escalation is offset
        # only by the redundancy penalty — confirm this is intended.
        if action.value in self._ground_truth.get("correct_teams", set()):
            return R_CORRECT_ESCALATION, [f"Correct escalation to {action.value}. +{R_CORRECT_ESCALATION}"]
        return P_WRONG_ESCALATION, [f"Wrong team escalated: {action.value}. Penalty applied."]

    def _score_remediate(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score a remediate action whose value has the form "<prefix>:<service>"."""
        pieces = action.value.split(":")
        prefix = pieces[0]
        service = pieces[1] if len(pieces) > 1 else ""
        correct_prefixes = self._ground_truth.get("remediation_prefixes", set())
        correct_service = self._ground_truth.get("remediation_service", "")

        if service != correct_service:
            return 0.0, [f"Incorrect remediation: {action.value}. Reconsider which service needs fixing."]
        if prefix not in correct_prefixes:
            # right service, wrong action
            return 0.05, [f"Right service, but '{prefix}' may not fix this. Try another remediation type."]
        if self._state.correct_remediation:
            return 0.0, []  # already rewarded earlier in the episode
        self._state.correct_remediation = True
        return R_CORRECT_REMEDIATION, [f"Correct remediation: {action.value}. +{R_CORRECT_REMEDIATION}"]

    def _score_ignore(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score an ignore action; ignoring a P1 incident is heavily penalized."""
        if self._ground_truth.get("severity", "") == "P1":
            return P_IGNORE_P1, ["CRITICAL ERROR: Ignored a P1 incident! Major penalty applied."]
        return 0.0, ["Marked as noise."]

    def _score_request_more_logs(self, action: TriageAction) -> tuple[float, list[str]]:
        """Acknowledge a request_more_logs action; it carries no reward."""
        return 0.0, [f"Fetching more logs for {action.value}..."]

    def _score_resolve(self, action: TriageAction) -> tuple[float, list[str]]:
        """Score a resolve action; a speed bonus applies within 60% of the budget."""
        # step_count has not been incremented for this action yet.
        if self._state.step_count <= int(self._state.max_steps * 0.6):
            return R_SPEED_BONUS, [f"Incident resolved efficiently. Speed bonus: +{R_SPEED_BONUS}"]
        return 0.0, ["Incident resolved."]

    def _is_terminal_action(self, action: TriageAction) -> bool:
        """Return True if the action itself ends the episode (resolve, or ignoring a P1)."""
        if action.action_type == "resolve":
            return True
        # Catastrophic — ignoring a P1 ends the episode immediately
        return action.action_type == "ignore" and self._ground_truth.get("severity") == "P1"

    def _budget_exhausted(self) -> bool:
        """Return True once the step counter has reached the task's step budget."""
        return self._state.step_count >= self._state.max_steps

    def _check_done(self, action: TriageAction) -> bool:
        """Episode ends on resolve, ignore (with P1), or step budget exhausted."""
        return self._is_terminal_action(action) or self._budget_exhausted()

    def _get_step_data(self, step: int):
        """
        Get logs and system state for the given step of the active scenario.

        Falls back to no logs and a healthy system state when the task id has
        no scenario module.
        """
        scenario = self._scenario()
        if scenario is None:
            return [], generate_healthy_system_state(self._base_time)
        return scenario.get_step_data(step, self._base_time, self._rng)

    def _get_alerts(self, step: int) -> list[str]:
        """Get active alerts for the given step (empty for an unknown task id)."""
        scenario = self._scenario()
        if scenario is None:
            return []
        return scenario.get_active_alerts(step)

    def _make_obs(
        self,
        reward: float,
        feedback: str,
        invalid_action_error: str | None = None,
        advance_step: bool = True,
    ) -> TriageObservation:
        """
        Build a TriageObservation for the current state.

        Args:
            reward: Reward earned by the action that led to this observation.
            feedback: Human-readable feedback for that action.
            invalid_action_error: Validation error text, if the action was rejected.
            advance_step: Accepted for interface compatibility; not read here.
                The step counter is advanced by step() before this is called.
        """
        step = self._state.step_count
        logs, system_state = self._get_step_data(step)
        alerts = self._get_alerts(step)

        return TriageObservation(
            logs=logs,
            system_state=system_state,
            incident_id=self._state.episode_id,
            task_id=self._state.task_id,
            step_count=step,
            time_elapsed_seconds=step * 30,  # each step counts as 30 seconds
            active_alerts=alerts,
            reward=reward,
            cumulative_score=self._state.cumulative_score,
            done=self._state.done,
            last_action_feedback=feedback,
            invalid_action_error=invalid_action_error,
        )