Spaces:
Sleeping
Sleeping
| """ | |
server/environment.py — Core OpenEnv environment for Cloud Incident Response.

Implements the full OpenEnv interface:
    reset(task_id, scenario_index) -> Observation
    step(action) -> (Observation, Reward, done, info)
    state() -> EpisodeState

All state is in-memory. Thread-safe via a lock.
| """ | |
from __future__ import annotations

import copy
import os
import sys
import threading
import uuid

# Put the directory above this file's directory on sys.path so the top-level
# ``tasks``, ``graders`` and ``server`` packages import from any working dir.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tasks import get_task, get_scenario
from graders import grade
from server.models import Action, ActionParameters, Observation, Reward, EpisodeState
# ── Action type classification ───────────────────────────────────────────────
# step() routes every action_type into exactly one of these families; anything
# outside all three is penalised with R_BAD_ACTION.

# Investigation tools: look up canned scenario data for one service.
_DIAGNOSTIC = frozenset((
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
))

# Corrective actions: scored against the scenario's wrong-action entries.
_REMEDIATION = frozenset((
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
))

# Final answers: any of these ends the episode.
_SUBMIT = frozenset((
    "submit_severity",
    "submit_root_cause",
    "submit_resolution",
))

# ── Reward constants ─────────────────────────────────────────────────────────
# Per-step shaping rewards, summed by step() into Reward.value.
R_QUERY_FIRST = 0.05     # first query of a known service with a given tool
R_QUERY_REPEAT = 0.01    # repeat of an already-seen (tool, service) query
R_QUERY_UNKNOWN = -0.05  # query naming a service outside known_services
R_REM_GOOD = 0.10        # remediation matching no wrong-action entry
R_REM_WRONG = -0.10      # remediation matching a wrong-action entry
R_PAST_HALF = -0.02      # charged on every step beyond max_steps // 2
R_TIMEOUT = -0.10        # reaching max_steps without any submission
R_BAD_ACTION = -0.03     # unrecognised action_type
class IncidentEnvironment:
    """
    OpenEnv environment for Cloud Incident Response.

    One instance handles one episode at a time. The public methods
    (``reset``, ``step``, ``state``) hold an internal lock while they read or
    mutate episode state, and the containers they hand back are deep-copied
    snapshots, so callers never share mutable dicts/lists with the live
    episode.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Per-episode mutable state; populated by reset().
        self._s: dict = {}
        # Scenario definition for the current episode (from get_scenario).
        self._scenario: dict = {}
        # Task definition for the current episode (from get_task).
        self._task_def: dict = {}
        # Guards step()/state() against use before the first reset().
        self._ready = False

    # ── Public OpenEnv API ───────────────────────────────────────────────────

    def reset(self, task_id: str, scenario_index: int = 0) -> Observation:
        """Start a fresh episode and return its initial Observation.

        Args:
            task_id: Task identifier passed to ``get_task``/``get_scenario``.
            scenario_index: Which scenario of the task to load.

        Returns:
            The Observation for step 0.

        Raises:
            Whatever ``get_task`` / ``get_scenario`` raise for an unknown task
            or scenario. Both lookups run before any state is replaced, so the
            previous episode is left untouched in that case.
        """
        with self._lock:
            task_def = get_task(task_id)
            scenario = get_scenario(task_id, scenario_index)
            self._task_def = task_def
            self._scenario = scenario
            self._s = {
                "episode_id": str(uuid.uuid4()),
                "task_id": task_id,
                "scenario_id": scenario["scenario_id"],
                "step_count": 0,
                "max_steps": task_def["max_steps"],
                "action_history": [],
                "queried_data": {},
                # (action_type, service) pairs already queried once.
                "queried_keys": set(),
                "submitted": False,
                "resolved": False,
                "done": False,
                "cumulative_reward": 0.0,
                "feedback": f"Episode started. {scenario['description']}",
            }
            self._ready = True
            return self._build_obs()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Apply one agent action and advance the episode.

        Args:
            action: The agent's action; ``action_type`` selects the handler
                and ``parameters`` carries its arguments.

        Returns:
            ``(observation, reward, done, info)``. ``reward.value`` is this
            step's shaping reward only; on the terminal step the grader's
            ``total`` is added to ``reward.cumulative`` but not to
            ``reward.value``. ``info`` holds the step number and feedback
            (it is empty when the episode was already over).

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("Call reset() before step().")
            s = self._s
            if s["done"]:
                # Stepping a finished episode is a zero-reward no-op.
                return (
                    self._build_obs(),
                    Reward(value=0.0, reason="episode already done",
                           cumulative=s["cumulative_reward"]),
                    True,
                    {},
                )

            s["step_count"] += 1
            step_num = s["step_count"]
            at = action.action_type
            params = action.parameters
            s["action_history"].append({
                "action_type": at,
                "parameters": params.model_dump(exclude_none=True),
                "step": step_num,
            })

            r = 0.0
            fb: list[str] = []

            # Efficiency penalty for every step past the halfway mark.
            if step_num > s["max_steps"] // 2:
                r += R_PAST_HALF
                fb.append("efficiency penalty")

            if at in _DIAGNOSTIC:
                r, fb = self._handle_diagnostic(at, params, r, fb)
            elif at in _REMEDIATION:
                r, fb = self._handle_remediation(at, params, r, fb)
            elif at in _SUBMIT:
                r, fb, terminal = self._handle_submit(at, params, r, fb)
                if terminal:
                    s["done"] = True
            else:
                r += R_BAD_ACTION
                fb.append(f"unknown action_type '{at}'")

            # Out of steps without a submission: penalise and end the episode.
            if step_num >= s["max_steps"] and not s["done"]:
                r += R_TIMEOUT
                fb.append("timeout — no submission made")
                s["done"] = True

            # The grader scores the whole episode once, on the terminal step;
            # its total goes into the cumulative reward only.
            if s["done"]:
                result = grade(s["task_id"], s, self._scenario)
                s["cumulative_reward"] = round(
                    s["cumulative_reward"] + r + result["total"], 4
                )
                fb.append(f"grader={result['feedback']}")
            else:
                s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)

            s["feedback"] = " | ".join(fb) if fb else "ok"
            return (
                self._build_obs(),
                Reward(
                    value=round(r, 4),
                    reason=s["feedback"],
                    cumulative=s["cumulative_reward"],
                ),
                s["done"],
                {"step": step_num, "feedback": s["feedback"]},
            )

    def state(self) -> EpisodeState:
        """Return a snapshot of the full current episode state.

        ``action_history`` and ``queried_data`` are deep-copied. Mutating the
        returned model therefore cannot corrupt the live episode, and later
        steps (which add entries to the nested ``queried_data`` dicts) cannot
        change the snapshot.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("No active episode — call reset() first.")
            s = self._s
            return EpisodeState(
                episode_id=s["episode_id"],
                task_id=s["task_id"],
                scenario_id=s["scenario_id"],
                step_count=s["step_count"],
                max_steps=s["max_steps"],
                action_history=copy.deepcopy(s["action_history"]),
                queried_data=copy.deepcopy(s["queried_data"]),
                submitted=s["submitted"],
                resolved=s["resolved"],
                done=s["done"],
                cumulative_reward=s["cumulative_reward"],
                feedback=s["feedback"],
            )

    # ── Action handlers ──────────────────────────────────────────────────────

    def _handle_diagnostic(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Run a diagnostic tool against one service.

        A known service earns R_QUERY_FIRST the first time a given
        (tool, service) pair is queried and R_QUERY_REPEAT afterwards. In
        both cases its canned response is stored in
        ``queried_data[at][service]``. An unknown service earns
        R_QUERY_UNKNOWN, and a missing service earns nothing. Service names
        are matched case-insensitively against both ``known_services`` and
        the ``tool_responses`` keys.

        Returns:
            The updated ``(reward, feedback)`` pair.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        known = {sv.lower() for sv in self._scenario.get("known_services", set())}
        # Normalise response keys the same way as ``service``/``known``;
        # otherwise a mixed-case scenario key passes the known-service check
        # below but its data is never found.
        raw_responses = self._scenario.get("tool_responses", {}).get(at, {})
        tool_data = {str(name).lower(): data for name, data in raw_responses.items()}
        key = (at, service)

        if service and service in known:
            if key not in s["queried_keys"]:
                r += R_QUERY_FIRST
                fb.append(f"queried {service} (+{R_QUERY_FIRST})")
                s["queried_keys"].add(key)
            else:
                r += R_QUERY_REPEAT
                fb.append(f"re-queried {service} (+{R_QUERY_REPEAT})")
            result = tool_data.get(service, f"No data for '{service}'.")
            s["queried_data"].setdefault(at, {})[service] = result
        elif service:
            r += R_QUERY_UNKNOWN
            fb.append(f"unknown service '{service}' ({R_QUERY_UNKNOWN})")
        else:
            fb.append(f"{at}: no service specified")
        return r, fb

    def _handle_remediation(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Execute a remediation action and score it against the scenario.

        The action is looked up in ``wrong_actions`` under these keys:
        - the bare action type;
        - ``"<at>:<service>"``;
        - ``"<at>:<flag>"``;
        - ``"execute_runbook_step:<runbook_action>"``;
        - ``"execute_runbook_step:<target>"``.

        Matching is case-insensitive. Any match earns R_REM_WRONG. Otherwise
        the action earns R_REM_GOOD, and its canned result from
        ``remediation_data`` is recorded in ``queried_data``. When no canned
        result exists, a generic success message is recorded instead.

        Returns:
            The updated ``(reward, feedback)`` pair.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        flag = (params.flag or "").lower().strip()
        runbook = (params.runbook_action or "").lower().strip()
        target = (params.target or "").lower().strip()

        # Ordered, de-duplicated candidates (not a set): str-set iteration
        # order depends on PYTHONHASHSEED, which made the reported reason
        # vary between runs when several keys matched.
        candidates = [at]
        if service:
            candidates.append(f"{at}:{service}")
        if flag:
            candidates.append(f"{at}:{flag}")
        if runbook:
            candidates.append(f"execute_runbook_step:{runbook}")
        if target:
            candidates.append(f"execute_runbook_step:{target}")
        keys = list(dict.fromkeys(candidates))

        # Lower-case the scenario keys to match the lower-cased parameters.
        wrong_map = {
            str(k).lower(): v
            for k, v in self._scenario.get("wrong_actions", {}).items()
        }
        matched = next((k for k in keys if k in wrong_map), None)

        if matched is not None:
            r += R_REM_WRONG
            reason = wrong_map[matched]
            fb.append(f"wrong action '{at}': {str(reason)[:80]}")
        else:
            r += R_REM_GOOD
            fb.append(f"executed {at}" + (f" on '{service}'" if service else ""))
            raw_at_data = self._scenario.get("remediation_data", {}).get(at, {})
            at_data = {str(k).lower(): v for k, v in raw_at_data.items()}
            result = (
                at_data.get(service) or at_data.get(flag) or
                at_data.get(runbook) or at_data.get(target) or
                "action executed successfully"
            )
            s["queried_data"].setdefault(at, {})[
                service or flag or runbook or target or at
            ] = result
        return r, fb

    def _handle_submit(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str], bool]:
        """Record a submission; every submission ends the episode.

        ``submit_resolution`` marks the incident resolved only when the
        summary is non-blank and the history contains at least one
        diagnostic or remediation action. This handler adds no shaping
        reward itself.

        Returns:
            ``(reward, feedback, terminal)`` with ``terminal`` always True.
        """
        s = self._s
        s["submitted"] = True
        if at == "submit_severity":
            fb.append(f"submitted severity: {(params.severity or '').upper()}")
        elif at == "submit_root_cause":
            fb.append(
                f"submitted root cause: "
                f"service={params.service or ''}, "
                f"failure_mode={params.failure_mode or ''}"
            )
        elif at == "submit_resolution":
            summary = params.summary or ""
            # Build the union once rather than once per history entry; only
            # "at least one" matters, so stop at the first hit.
            investigative = _DIAGNOSTIC | _REMEDIATION
            investigated = any(
                entry.get("action_type") in investigative
                for entry in s["action_history"]
            )
            if summary.strip() and investigated:
                s["resolved"] = True
                fb.append("resolution submitted — incident resolved")
            else:
                fb.append("resolution submitted — insufficient investigation")
        return r, fb, True

    # ── Build observation ────────────────────────────────────────────────────

    def _build_obs(self) -> Observation:
        """Build the agent-facing Observation for the current episode.

        Called only from reset()/step() while ``self._lock`` is held.
        Mutable containers are deep-copied so the Observation handed to the
        caller shares no dicts/lists with the live episode, the scenario or
        the task definition.
        """
        s = self._s
        sc = self._scenario
        td = self._task_def
        return Observation(
            episode_id=s["episode_id"],
            task_id=s["task_id"],
            scenario_id=s["scenario_id"],
            step_count=s["step_count"],
            max_steps=s["max_steps"],
            incident_summary=sc.get("incident_summary", sc.get("description", "")),
            alert=copy.deepcopy(sc.get("alert", {})),
            available_actions=copy.deepcopy(td.get("available_actions", [])),
            queried_data=copy.deepcopy(s["queried_data"]),
            cumulative_reward=s["cumulative_reward"],
            done=s["done"],
            feedback=s["feedback"],
        )