Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from models import IncidentopsAction, IncidentopsObservation | |
| except ImportError: | |
| from models import IncidentopsAction, IncidentopsObservation | |
@dataclass
class IncidentSnapshot:
    """Mutable per-episode record of one incident scenario.

    The first thirteen fields are required and are supplied by unpacking a
    scenario dict (``IncidentSnapshot(**scenario)``); the remaining fields
    carry defaults and track progress as the episode is stepped.

    The ``@dataclass`` decorator is required: without it the class has no
    keyword constructor and ``field(default_factory=list)`` would be a bare
    ``Field`` object shared by every instance instead of a fresh list.
    """

    # --- Scenario description (required) ---------------------------------
    scenario_id: str  # unique identifier of the scenario, e.g. "easy_001"
    task: str  # task/difficulty key the scenario belongs to
    alert_text: str  # alert message surfaced in observations
    hidden_truth: str  # actual root cause; drives reward decisions
    severity: str  # severity label surfaced in observations
    affected_services: List[str]  # names of the impacted services
    logs_available: bool  # whether log_snippet is currently visible
    log_snippet: str  # log excerpt shown only while logs_available is True
    likely_cause: str  # current root-cause hypothesis shown to the agent
    hf_confidence: float  # confidence attached to likely_cause
    available_actions: List[str]  # action names offered for this scenario
    correct_action_sequence: List[str]  # reference actions used for grading
    sla_steps: int  # step budget before the episode is cut off

    # --- Episode progress (defaults) -------------------------------------
    step_count: int = 0  # number of steps taken so far
    resolved: bool = False  # True once the incident has been resolved
    wrong_escalations: int = 0  # escalations sent to the wrong team
    action_history: List[str] = field(default_factory=list)  # actions taken, in order
    evidence_collected: bool = False  # True after a diagnostic action paid off
    team_engaged: Optional[str] = None  # team correctly escalated to, if any
| # Scenarios | |
# Scenario catalogue, keyed by task id ("incident_easy", "incident_medium",
# "incident_hard"). Each task maps to a list of scenario dicts whose keys are
# exactly the required fields of IncidentSnapshot, so a scenario can be
# unpacked straight into its constructor. reset() currently uses only the
# first scenario of the selected task.
SCENARIOS: Dict[str, List[Dict[str, Any]]] = {
    # Easy: logs are visible from the start and the cause shown to the agent
    # matches the hidden truth.
    "incident_easy": [
        {
            "scenario_id": "easy_001",
            "task": "incident_easy",
            "alert_text": "SEV-2: payment-service latency high after deploy.",
            "hidden_truth": "bad_deployment",
            "severity": "high",
            "affected_services": ["payment-service"],
            "logs_available": True,
            "log_snippet": "deploy at 14:32 UTC caused connection pool exhaustion",
            "likely_cause": "bad_deployment",
            "hf_confidence": 0.92,
            "available_actions": [
                "request_logs",
                "rollback_deploy",
                "restart_service",
                "resolve_incident",
            ],
            "correct_action_sequence": [
                "rollback_deploy",
                "resolve_incident",
            ],
            "sla_steps": 5,
        }
    ],
    # Medium: logs start hidden (log_snippet is withheld from observations
    # until "request_logs" is issued) and the displayed cause is only the
    # generic "dependency_issue" rather than the hidden "db_timeout".
    "incident_medium": [
        {
            "scenario_id": "medium_001",
            "task": "incident_medium",
            "alert_text": "SEV-1: api-gateway 5xx errors; user-profile-service slow; no logs available.",
            "hidden_truth": "db_timeout",
            "severity": "critical",
            "affected_services": ["api-gateway", "user-profile-service"],
            "logs_available": False,
            "log_snippet": "DB timeout errors from checkout reads",
            "likely_cause": "dependency_issue",
            "hf_confidence": 0.72,
            "available_actions": [
                "request_logs",
                "query_dependencies",
                "escalate_db_team",
                "escalate_network_team",
                "restart_service",
                "resolve_incident",
            ],
            "correct_action_sequence": [
                "request_logs",
                "query_dependencies",
                "escalate_db_team",
                "restart_service",
                "resolve_incident",
            ],
            "sla_steps": 8,
        }
    ],
    # Hard: logs start hidden, the displayed cause is "ambiguous", and the
    # action list includes options that are penalised for this root cause
    # (rollback_deploy, restart_service, escalate_db_team).
    "incident_hard": [
        {
            "scenario_id": "hard_001",
            "task": "incident_hard",
            "alert_text": "SEV-1: EU checkout failures. Auth and payment degraded. Logs incomplete.",
            "hidden_truth": "dns_issue",
            "severity": "critical",
            "affected_services": ["auth-service", "payment-service", "checkout-service"],
            "logs_available": False,
            "log_snippet": "DNS query failures in EU region resolver",
            "likely_cause": "ambiguous",
            "hf_confidence": 0.55,
            "available_actions": [
                "request_logs",
                "query_dns_status",
                "query_region_health",
                "rollback_deploy",
                "restart_service",
                "escalate_network_team",
                "escalate_db_team",
                "broadcast_status_page",
                "resolve_incident",
            ],
            "correct_action_sequence": [
                "query_region_health",
                "query_dns_status",
                "escalate_network_team",
                "broadcast_status_page",
                "resolve_incident",
            ],
            "sla_steps": 12,
        }
    ],
}
| # Main Environment Class | |
class IncidentopsEnvironment(Environment):
    """Incident-response simulation built on the ``Environment`` interface.

    An episode loads one scenario from ``SCENARIOS`` into an
    ``IncidentSnapshot``. Each ``step`` applies a named action, mutates the
    snapshot, and returns an ``IncidentopsObservation`` carrying the shaped
    reward. The episode ends when the incident is resolved or the scenario's
    step budget (``sla_steps``) is used up; at that point a grader score is
    attached to the observation metadata.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Create an environment with no scenario loaded yet.

        ``reset`` must run before an episode can be graded; ``step`` will
        call ``reset`` itself if no scenario is loaded.
        """
        # NOTE(review): assumes Environment.__init__ has no required
        # arguments - confirm against the installed openenv version.
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._snapshot: Optional[IncidentSnapshot] = None
        self._difficulty = "incident_easy"
        self._last_observation: Optional[IncidentopsObservation] = None

    def _build_observation(self) -> IncidentopsObservation:
        """Project the current snapshot into an observation.

        The observation is built with ``reward=0.0`` and ``done`` equal to
        the snapshot's ``resolved`` flag; ``step`` overwrites both. List
        fields are copied so that a consumer mutating the observation cannot
        alter the snapshot.
        """
        assert self._snapshot is not None
        snap = self._snapshot
        remaining = max(snap.sla_steps - snap.step_count, 0)
        return IncidentopsObservation(
            alert_summary=snap.alert_text,
            severity=snap.severity,
            likely_cause=snap.likely_cause,
            hf_confidence=snap.hf_confidence,
            services_affected=list(snap.affected_services),
            logs_available=snap.logs_available,
            # The log excerpt stays hidden until logs have been made available.
            log_snippet=snap.log_snippet if snap.logs_available else "",
            service_healthy=snap.resolved,
            elapsed_steps=snap.step_count,
            sla_steps_remaining=remaining,
            action_history=list(snap.action_history),
            available_actions=list(snap.available_actions),
            incident_resolved=snap.resolved,
            wrong_escalations=snap.wrong_escalations,
            # NOTE(review): metadata carries hidden_truth; confirm it is not
            # forwarded to the acting agent.
            metadata={
                "scenario_id": snap.scenario_id,
                "task": snap.task,
                "hidden_truth": snap.hidden_truth,
                "team_engaged": snap.team_engaged,
                "evidence_collected": snap.evidence_collected,
            },
            reward=0.0,
            done=snap.resolved,
        )

    # Reward
    def _calc_reward(self, action: str) -> float:
        """Score ``action`` and apply its side effects to the snapshot.

        Must be called after ``action`` has been appended to the snapshot's
        ``action_history`` and after ``step_count`` has been incremented
        (``step`` does both), because the repeat check and the SLA checks
        read those values.

        Args:
            action: Name of the action being taken.

        Returns:
            The shaped reward: a flat -0.05 step cost plus the bonuses and
            penalties triggered by ``action`` for the current scenario.
        """
        assert self._snapshot is not None
        s = self._snapshot
        reward = -0.05  # flat per-step cost

        # The current action is already in the history, so a count above one
        # means it was taken before.
        if s.action_history.count(action) > 1:
            reward -= 0.2

        # Diagnostic actions: reveal information and mark evidence collected.
        if action == "request_logs" and not s.logs_available:
            reward += 0.3
            s.logs_available = True
            s.evidence_collected = True
        if action == "query_dependencies" and s.hidden_truth == "db_timeout":
            reward += 0.5
            s.likely_cause = "db_timeout"
            s.hf_confidence = min(0.95, s.hf_confidence + 0.15)
            s.evidence_collected = True
        if action == "query_dns_status" and s.hidden_truth == "dns_issue":
            reward += 0.5
            s.likely_cause = "dns_issue"
            s.hf_confidence = min(0.95, s.hf_confidence + 0.20)
            s.evidence_collected = True
        if action == "query_region_health" and s.hidden_truth == "dns_issue":
            reward += 0.4
            s.hf_confidence = min(0.95, s.hf_confidence + 0.10)

        # Remediation: a rollback resolves a bad deployment outright and is
        # penalised for any other root cause.
        if action == "rollback_deploy" and s.hidden_truth == "bad_deployment":
            reward += 1.0
            s.resolved = True
        elif action == "rollback_deploy":
            reward -= 0.8

        # Escalations: rewarded for the matching team, otherwise penalised
        # and counted as a wrong escalation.
        if action == "escalate_db_team" and s.hidden_truth == "db_timeout":
            reward += 0.7
            s.team_engaged = "db_team"
        elif action == "escalate_db_team":
            reward -= 0.5
            s.wrong_escalations += 1
        if action == "escalate_network_team" and s.hidden_truth == "dns_issue":
            reward += 0.7
            s.team_engaged = "network_team"
        elif action == "escalate_network_team":
            reward -= 0.5
            s.wrong_escalations += 1

        # Status-page updates are worth more within the first two steps.
        if action == "broadcast_status_page":
            reward += 0.2 if s.step_count <= 2 else 0.05

        if action == "restart_service" and s.hidden_truth in {"bad_deployment", "db_timeout"}:
            reward += 0.8
        elif action == "restart_service":
            reward -= 0.2

        # Closing the incident pays off only within the SLA and only when
        # there is evidence, an engaged team, or a bad-deployment root cause;
        # otherwise it is a heavily penalised premature close.
        if action == "resolve_incident":
            if s.resolved or s.hidden_truth in {"bad_deployment", "db_timeout", "dns_issue"}:
                if s.step_count <= s.sla_steps and (
                    s.evidence_collected or s.team_engaged is not None or s.hidden_truth == "bad_deployment"
                ):
                    reward += 1.5
                    s.resolved = True
                else:
                    reward -= 2.0
            else:
                reward -= 1.0

        # Extra penalty for any step taken past the SLA budget.
        if s.step_count > s.sla_steps:
            reward -= 0.5
        return reward

    # Reset
    def reset(self, episode_id: Optional[str] = None, task_id: str = "incident_easy", **kwargs: Any) -> IncidentopsObservation:
        """Start a new episode for ``task_id`` and return its first observation.

        Args:
            episode_id: Identifier for the new episode; a random UUID is
                generated when omitted or empty.
            task_id: Key into ``SCENARIOS``. Unknown keys fall back to the
                ``"incident_easy"`` scenarios.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The initial observation of the loaded scenario.
        """
        print(f"[ENV] reset called: task_id={task_id}", flush=True)
        scenarios = SCENARIOS.get(task_id, SCENARIOS["incident_easy"])
        scenario = scenarios[0]
        # Record which task is actually loaded, after the fallback above.
        self._difficulty = task_id if task_id in SCENARIOS else "incident_easy"
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        # Copy list values so the episode never shares mutable state with the
        # module-level SCENARIOS template (or with other sessions).
        fields = {
            key: list(value) if isinstance(value, list) else value
            for key, value in scenario.items()
        }
        self._snapshot = IncidentSnapshot(**fields)
        self._snapshot.action_history = []
        self._last_observation = self._build_observation()
        return self._last_observation

    # Step
    def step(self, action: Any) -> IncidentopsObservation:
        """Apply one action and return the resulting observation.

        Accepts an ``IncidentopsAction``, a dict with an ``"action"`` key
        (defaulting to ``"resolve_incident"`` when the key is missing), a
        plain string, or any other object, which is converted with ``str``.
        If no scenario is loaded, ``reset`` is called first with defaults.

        Returns:
            The observation after the action, with ``reward`` and ``done``
            set, and ``last_action`` / ``last_reward`` added to metadata.
            When the episode is done, metadata also has ``grader_score``.
        """
        print(f"[ENV] step called: action={action}, type={type(action)}", flush=True)
        if isinstance(action, IncidentopsAction):
            action_name = action.action
        elif isinstance(action, dict):
            action_name = action.get("action", "resolve_incident")
        elif isinstance(action, str):
            action_name = action
        else:
            action_name = str(action)
        print(f"[ENV] action_name={action_name}", flush=True)

        if self._snapshot is None:
            print("[ENV] ERROR: No snapshot! Calling reset first.", flush=True)
            self.reset()
        assert self._snapshot is not None

        # Count the step and record the action before scoring: the reward
        # logic reads both step_count and action_history.
        self._snapshot.step_count += 1
        self._state.step_count = self._snapshot.step_count
        self._snapshot.action_history.append(action_name)
        reward = self._calc_reward(action_name)

        # The episode ends on resolution or when the step budget is used up.
        done = self._snapshot.resolved or self._snapshot.step_count >= self._snapshot.sla_steps
        obs = self._build_observation()
        obs.reward = reward
        obs.done = done
        obs.metadata = {
            **(obs.metadata or {}),
            "last_action": action_name,
            "last_reward": reward,
        }
        if done:
            grade_result = self.grade()
            obs.metadata["grader_score"] = grade_result["score"]
        self._last_observation = obs
        print(f"[ENV] step done: reward={reward:.2f}, done={done}", flush=True)
        return obs

    # Grade
    def grade(self) -> dict:
        """Grade the current episode.

        The correctness ratio is the fraction of *distinct* reference actions
        (``correct_action_sequence``) that appear in the action history, so
        it always lies in [0, 1] and repeating one correct action cannot
        inflate it. The efficiency bonus is the unused share of the step
        budget.

        Returns:
            A dict with ``score`` (rounded to 4 places), ``success``,
            ``incident_resolved``, ``steps_taken``, ``sla_met``,
            ``efficiency_bonus``, ``wrong_escalations`` and
            ``evidence_collected``.
        """
        assert self._snapshot is not None
        s = self._snapshot
        total_steps = max(s.step_count, 1)
        sla_ok = s.step_count <= s.sla_steps

        expected = set(s.correct_action_sequence)
        correct_actions = len(expected.intersection(s.action_history))
        correctness_ratio = correct_actions / max(len(expected), 1)
        # Guard the divisor so a scenario with a zero step budget cannot
        # raise ZeroDivisionError.
        efficiency_bonus = max(0.0, (s.sla_steps - total_steps) / max(s.sla_steps, 1))

        if s.resolved and sla_ok:
            score = min(1.0, 0.5 + 0.3 * correctness_ratio + 0.2 * efficiency_bonus)
        elif s.resolved:
            # Resolved, but past the SLA: capped at 0.6.
            score = min(0.6, 0.3 + 0.3 * correctness_ratio)
        else:
            # Unresolved: only a small partial credit for correct actions.
            score = max(0.0, 0.1 * correctness_ratio)

        return {
            "score": round(score, 4),
            "success": s.resolved and sla_ok,
            "incident_resolved": s.resolved,
            "steps_taken": s.step_count,
            "sla_met": sla_ok,
            "efficiency_bonus": round(efficiency_bonus, 4),
            "wrong_escalations": s.wrong_escalations,
            "evidence_collected": s.evidence_collected,
        }

    # NOTE(review): exposed as a property on the assumption that the
    # Environment interface declares ``state`` as an abstract property -
    # confirm against openenv.core.env_server.interfaces.
    @property
    def state(self) -> State:
        """Return the episode-level ``State`` (episode id and step count)."""
        return self._state