Spaces:
Running
Running
| """ | |
| Core LogTriageEnvironment class. | |
| Implements OpenEnv interface: reset(), step(), state property. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from datetime import datetime | |
| from uuid import uuid4 | |
| from server.models import ( | |
| TriageAction, | |
| TriageObservation, | |
| EpisodeState, | |
| LogLine, | |
| ServiceStatus, | |
| ) | |
| from server.scenarios import single_crash | |
| from server.scenarios import cascading | |
| from server.scenarios import silent_degrade | |
| from server.log_generator import generate_healthy_system_state, _make_timestamp | |
# ─── TASK REGISTRY ─────────────────────────────────────────────────────────────
# Step budget per task id. The keys double as the registry of valid task ids:
# reset() rejects any task_id that is not listed here.
TASK_MAX_STEPS = dict(
    single_crash=8,
    cascading_failure=12,
    silent_degradation=15,
)
# ─── REWARD CONSTANTS ──────────────────────────────────────────────────────────
# Positive rewards.
R_CORRECT_SEVERITY: float = 0.30      # first correct classify_severity
R_CORRECT_ROOT_CAUSE: float = 0.35    # first correct identify_root_cause
R_CORRECT_REMEDIATION: float = 0.25   # first correct remediate (prefix + service)
R_CORRECT_ESCALATION: float = 0.10    # escalate to a team in the ground-truth set
R_SPEED_BONUS: float = 0.10           # resolve within 60% of the step budget
R_PARTIAL_SERVICE_FAM: float = 0.10   # root cause shares the service family (text before "-")
R_PARTIAL_SEVERITY_ADJ: float = 0.10  # P2 given when P1 is expected

# Penalties (negative rewards).
P_WRONG_ESCALATION: float = -0.10     # escalate to a team outside the ground-truth set
P_IGNORE_P1: float = -0.50            # ignore on a P1 incident
P_REDUNDANT_ACTION: float = -0.05     # repeating an action type already taken
P_EXCEEDED_BUDGET: float = -0.20      # intended penalty for exhausting the step budget
# Despite its name, this is used when P3 is given for a P1 incident
# (under-classification), not for over-escalating a P3.
P_OVERESCALATE_P3_P1: float = -0.15
class LogTriageEnvironment:
    """
    OpenEnv-compatible environment for SRE incident triage.

    Usage:
        env = LogTriageEnvironment()
        obs = env.reset(task_id="single_crash", seed=42)
        while not obs.done:
            action = agent.act(obs)
            obs = env.step(action)
        score = env.get_grader_score()

    Every per-step reward reported in an observation is also added to the
    episode's ``cumulative_score`` (including the step-budget penalty);
    ``get_grader_score()`` clamps that running total to [0.0, 1.0].
    """

    def __init__(self):
        self._state: EpisodeState | None = None
        self._rng: random.Random = random.Random()
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 and
        # returns a naive datetime. Kept because this value is handed to the
        # scenario generators, which may expect a naive timestamp — confirm
        # before switching to datetime.now(timezone.utc).
        self._base_time: datetime = datetime.utcnow()
        self._task_id: str = "single_crash"
        self._ground_truth: dict = {}
        self._current_obs: TriageObservation | None = None
        # Keys of one-shot credits already paid out this episode (escalation
        # teams, partial-credit buckets); see _credit_once().
        self._credited: set[str] = set()

    # ─── OPENENV INTERFACE ─────────────────────────────────────────────────

    def reset(self, task_id: str = "single_crash", seed: int | None = None) -> TriageObservation:
        """Start a fresh episode and return the initial (step 0) observation.

        Args:
            task_id: One of the keys of TASK_MAX_STEPS.
            seed: Seed for the per-episode RNG passed to the scenario
                generators; None gives a non-reproducible episode.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id not in TASK_MAX_STEPS:
            raise ValueError(f"Unknown task_id '{task_id}'. Valid: {list(TASK_MAX_STEPS.keys())}")
        self._task_id = task_id
        self._rng = random.Random(seed)
        self._base_time = datetime.utcnow()
        self._credited = set()

        # Load ground truth for this task.
        scenario = self._scenario_module(task_id)
        self._ground_truth = scenario.GROUND_TRUTH if scenario is not None else {}

        # Initialize episode state.
        self._state = EpisodeState(
            episode_id=str(uuid4()),
            task_id=task_id,
            step_count=0,
            max_steps=TASK_MAX_STEPS[task_id],
            done=False,
            cumulative_score=0.0,
            actions_taken=[],
            correct_severity=None,
            correct_root_cause=None,
            correct_remediation=False,
        )

        # At step 0 _make_obs yields time_elapsed 0, reward 0.0, score 0.0.
        obs = self._make_obs(
            reward=0.0,
            feedback="Incident detected. Analyze the logs and take action.",
        )
        self._current_obs = obs
        return obs

    def step(self, action: TriageAction) -> TriageObservation:
        """Apply one action and return the next observation.

        The observation's ``reward`` is the reward for this action, plus
        ``P_EXCEEDED_BUDGET`` when this step used up the step budget without a
        terminal action (resolve, or ignore on a P1 incident).

        Raises:
            RuntimeError: If called before reset() or after the episode ended.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        valid, err = action.is_valid()
        if not valid:
            # Invalid actions earn nothing and do not consume a step.
            obs = self._make_obs(
                reward=0.0,
                feedback=f"Invalid action: {err}",
                invalid_action_error=err,
                advance_step=False,
            )
            self._current_obs = obs
            return obs

        reward, feedback = self._evaluate_action(action)

        # Record the action only after evaluation: the redundancy check in
        # _evaluate_action compares against actions taken *before* this one.
        self._state.actions_taken.append(action.action_type)
        self._state.action_history.append(action.model_dump())
        self._state.step_count += 1

        # Decided independently of _check_done(), which also returns True once
        # the budget is spent (so a "not done" test here could never fire).
        # Resolving on the last allowed step is not penalized.
        budget_spent = self._state.step_count >= self._state.max_steps
        if budget_spent and not self._is_terminal_action(action):
            reward = round(reward + P_EXCEEDED_BUDGET, 4)
            feedback += f" Step budget exceeded ({self._state.max_steps} steps). Penalty applied."

        self._state.cumulative_score = round(self._state.cumulative_score + reward, 4)
        self._state.done = self._check_done(action)

        obs = self._make_obs(reward=reward, feedback=feedback, advance_step=True)
        self._current_obs = obs
        return obs

    def state(self) -> EpisodeState:
        """Return the current episode state.

        NOTE(review): the module docstring describes ``state`` as a property,
        but it is a plain method here (call ``env.state()``); left as-is for
        backward compatibility with existing callers.

        Raises:
            RuntimeError: If reset() has not been called.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first.")
        return self._state

    def get_grader_score(self) -> float:
        """
        Return the final grader score for the episode.

        The cumulative score is clamped to [0.0, 1.0] and rounded to 4
        decimals; 0.0 if no episode has been started.
        """
        if self._state is None:
            return 0.0
        raw = self._state.cumulative_score
        return round(max(0.0, min(1.0, raw)), 4)

    # ─── INTERNAL HELPERS ──────────────────────────────────────────────────

    @staticmethod
    def _scenario_module(task_id: str):
        """Return the scenario module implementing ``task_id``, or None if unknown."""
        scenarios = {
            "single_crash": single_crash,
            "cascading_failure": cascading,
            "silent_degradation": silent_degrade,
        }
        return scenarios.get(task_id)

    def _credit_once(self, key: str, amount: float) -> float:
        """Return ``amount`` the first time ``key`` is credited this episode, else 0.0.

        Prevents reward farming by repeating the same partially-correct (or
        already-rewarded) action until the step budget runs out.
        """
        if key in self._credited:
            return 0.0
        self._credited.add(key)
        return amount

    def _evaluate_action(self, action: TriageAction) -> tuple[float, str]:
        """
        Score ``action`` against the task's ground truth.

        Updates episode bookkeeping (correct_severity, correct_root_cause,
        correct_remediation, one-shot credits) but not the step counter or
        action history.

        Returns:
            (reward, feedback): reward rounded to 4 decimals; feedback is the
            " | "-joined messages, or "Action processed." if there are none.
        """
        gt = self._ground_truth
        state = self._state
        kind = action.action_type
        reward = 0.0
        feedback_parts: list[str] = []

        # Repeating an action *type* already taken costs a small penalty.
        if kind in state.actions_taken:
            reward += P_REDUNDANT_ACTION
            feedback_parts.append("Redundant action — you've already done this.")

        if kind == "classify_severity":
            correct_sev = gt.get("severity", "")
            if action.value == correct_sev:
                if state.correct_severity is None:  # only reward the first time
                    reward += R_CORRECT_SEVERITY
                    feedback_parts.append(f"Correct severity: {action.value}. +{R_CORRECT_SEVERITY}")
                state.correct_severity = action.value
            elif correct_sev == "P1" and action.value == "P3":
                # Under-classifying a customer-impacting P1 is penalized
                # (despite the constant's name, this is the too-low direction).
                reward += P_OVERESCALATE_P3_P1
                feedback_parts.append(
                    f"Incorrect severity: {action.value}. P1 expected. This is a customer-impacting incident."
                )
            elif correct_sev == "P1" and action.value == "P2":
                # P1 vs P2 is close: partial credit, paid at most once.
                credit = self._credit_once("severity:partial", R_PARTIAL_SEVERITY_ADJ)
                reward += credit
                if credit:
                    feedback_parts.append(f"Close — {action.value} given, P1 expected. Partial credit.")
                else:
                    feedback_parts.append(f"Close — {action.value} given, P1 expected.")
            else:
                feedback_parts.append(f"Incorrect severity: {action.value}. Reassess.")

        elif kind == "identify_root_cause":
            correct_rc = gt.get("root_cause", "")
            if action.value == correct_rc:
                if state.correct_root_cause is None:
                    reward += R_CORRECT_ROOT_CAUSE
                    feedback_parts.append(f"Correct root cause: {action.value}. +{R_CORRECT_ROOT_CAUSE}")
                state.correct_root_cause = action.value
            elif correct_rc.split("-")[0] == action.value.split("-")[0]:
                # Same service family (text before the first "-", e.g.
                # payment-db vs payment-service): partial credit, paid once.
                reward += self._credit_once("root_cause:partial", R_PARTIAL_SERVICE_FAM)
                feedback_parts.append(
                    f"Close — {action.value} is in the right service family. Check more carefully."
                )
            else:
                feedback_parts.append(
                    f"Incorrect root cause: {action.value}. Look at which service is actually failing."
                )

        elif kind == "escalate":
            correct_teams = gt.get("correct_teams", set())
            if action.value in correct_teams:
                # Each correct team is rewarded once; re-escalating earns nothing.
                credit = self._credit_once(f"escalate:{action.value}", R_CORRECT_ESCALATION)
                reward += credit
                if credit:
                    feedback_parts.append(f"Correct escalation to {action.value}. +{R_CORRECT_ESCALATION}")
                else:
                    feedback_parts.append(f"Already escalated to {action.value}.")
            else:
                reward += P_WRONG_ESCALATION
                feedback_parts.append(f"Wrong team escalated: {action.value}. Penalty applied.")

        elif kind == "remediate":
            # Expected "<remediation>:<service>"; without ":" the service is "".
            parts = action.value.split(":")
            prefix = parts[0]
            service = parts[1] if len(parts) > 1 else ""
            correct_prefixes = gt.get("remediation_prefixes", set())
            correct_service = gt.get("remediation_service", "")
            if service == correct_service and prefix in correct_prefixes:
                if not state.correct_remediation:
                    reward += R_CORRECT_REMEDIATION
                    feedback_parts.append(f"Correct remediation: {action.value}. +{R_CORRECT_REMEDIATION}")
                state.correct_remediation = True
            elif service == correct_service:
                # Right service, wrong remediation type: small one-shot credit.
                reward += self._credit_once("remediation:partial", 0.05)
                feedback_parts.append(
                    f"Right service, but '{prefix}' may not fix this. Try another remediation type."
                )
            else:
                feedback_parts.append(
                    f"Incorrect remediation: {action.value}. Reconsider which service needs fixing."
                )

        elif kind == "ignore":
            if gt.get("severity", "") == "P1":
                reward += P_IGNORE_P1
                feedback_parts.append("CRITICAL ERROR: Ignored a P1 incident! Major penalty applied.")
            else:
                feedback_parts.append("Marked as noise.")

        elif kind == "request_more_logs":
            feedback_parts.append(f"Fetching more logs for {action.value}...")

        elif kind == "resolve":
            # Speed bonus requires the incident to actually be fixed (correct
            # remediation) within 60% of the step budget; otherwise resolving
            # immediately would earn a free bonus.
            within_budget = state.step_count <= int(state.max_steps * 0.6)
            if within_budget and state.correct_remediation:
                reward += R_SPEED_BONUS
                feedback_parts.append(f"Incident resolved efficiently. Speed bonus: +{R_SPEED_BONUS}")
            else:
                feedback_parts.append("Incident resolved.")

        return round(reward, 4), " | ".join(feedback_parts) or "Action processed."

    def _is_terminal_action(self, action: TriageAction) -> bool:
        """Return True if ``action`` ends the episode on its own: resolve, or ignore on a P1."""
        if action.action_type == "resolve":
            return True
        return action.action_type == "ignore" and self._ground_truth.get("severity") == "P1"

    def _check_done(self, action: TriageAction) -> bool:
        """Episode ends on resolve, ignore (with P1), or once the step budget is exhausted."""
        return self._is_terminal_action(action) or self._state.step_count >= self._state.max_steps

    def _get_step_data(self, step: int):
        """Return (logs, system_state) for ``step`` from the current task's scenario."""
        scenario = self._scenario_module(self._task_id)
        if scenario is None:
            return [], generate_healthy_system_state(self._base_time)
        return scenario.get_step_data(step, self._base_time, self._rng)

    def _get_alerts(self, step: int) -> list[str]:
        """Return the active alerts for ``step`` from the current task's scenario."""
        scenario = self._scenario_module(self._task_id)
        if scenario is None:
            return []
        return scenario.get_active_alerts(step)

    def _make_obs(
        self,
        reward: float,
        feedback: str,
        invalid_action_error: str | None = None,
        advance_step: bool = True,
    ) -> TriageObservation:
        """Build a TriageObservation for the current episode state.

        ``advance_step`` is accepted for call-site readability only; the step
        counter is advanced by step() itself, not here. Elapsed time is
        reported as 30 seconds per step.
        """
        step = self._state.step_count
        logs, system_state = self._get_step_data(step)
        alerts = self._get_alerts(step)
        return TriageObservation(
            logs=logs,
            system_state=system_state,
            incident_id=self._state.episode_id,
            task_id=self._state.task_id,
            step_count=step,
            time_elapsed_seconds=step * 30,
            active_alerts=alerts,
            reward=reward,
            cumulative_score=self._state.cumulative_score,
            done=self._state.done,
            last_action_feedback=feedback,
            invalid_action_error=invalid_action_error,
        )