Spaces:
Sleeping
Sleeping
| """ | |
| server/environment.py — Core OpenEnv environment for Cloud Incident Response. | |
| Difficulty comes from SCENARIO DESIGN, not mechanics: | |
| EASY: 3 services, clear metrics, obvious severity | |
| MEDIUM: 8 services, root cause NOT in alert, must follow log breadcrumbs | |
| HARD: 8 services + 5-7 remediation steps + quality summary + penalties | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import threading | |
| import uuid | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from graders import _svc_match, grade | |
| from server.models import Action, ActionParameters, EpisodeState, Observation, Reward | |
| from tasks import get_scenario, get_task | |
# Action types that only read scenario data; dispatched to _handle_diagnostic.
_DIAGNOSTIC = frozenset({
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
})

# Action types that act on a target; dispatched to _handle_remediation.
_REMEDIATION = frozenset({
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
})

# Submission action types; dispatched to _handle_submit.
_SUBMIT = frozenset({
    "submit_severity",
    "submit_root_cause",
    "submit_resolution",
})
# The one submission action each task accepts; any other submit_* action is
# penalised as a wrong submission and does not end the episode.
_TASK_SUBMIT = {
    "alert_classification": "submit_severity",
    "root_cause_analysis": "submit_root_cause",
    "remediation_planning": "submit_resolution",
}
# Per-step shaping rewards, keyed by difficulty and then by event name.
#   query_*       - outcomes of diagnostic actions
#   rem_*         - outcomes of remediation actions
#   submit_*      - correct / wrong submission action for the task
#   past_half     - non-submit action taken after half of max_steps
#   timeout       - step budget exhausted without a terminal submission
#   bad_action    - action type not in any known set
#   exact_repeat  - same action type and parameters seen earlier in the episode
_REWARD_TABLE = {
    "easy": {
        "query_new_svc": +0.04,
        "query_new_action": +0.02,
        "query_repeat": -0.03,
        "query_unknown_svc": -0.06,
        "query_no_service": -0.04,
        "rem_good": +0.00,
        "rem_wrong": -0.08,
        "rem_no_target": -0.05,
        "submit_correct": +0.02,
        "submit_wrong": -0.08,
        "past_half": -0.04,
        "timeout": -0.15,
        "bad_action": -0.05,
        "exact_repeat": -0.04,
    },
    "medium": {
        "query_new_svc": +0.04,
        "query_new_action": +0.02,
        "query_repeat": -0.04,
        "query_unknown_svc": -0.06,
        "query_no_service": -0.04,
        "rem_good": +0.06,
        "rem_wrong": -0.10,
        "rem_no_target": -0.06,
        "submit_correct": +0.02,
        "submit_wrong": -0.10,
        "past_half": -0.02,
        "timeout": -0.15,
        "bad_action": -0.05,
        "exact_repeat": -0.05,
    },
    "hard": {
        "query_new_svc": +0.03,
        "query_new_action": +0.01,
        "query_repeat": -0.03,
        "query_unknown_svc": -0.05,
        "query_no_service": -0.03,
        "rem_good": +0.06,
        "rem_wrong": -0.15,
        "rem_no_target": -0.05,
        "submit_correct": +0.02,
        "submit_wrong": -0.12,
        "past_half": -0.02,
        "timeout": -0.20,
        "bad_action": -0.05,
        "exact_repeat": -0.04,
    },
}
# Maps each task id to the _REWARD_TABLE difficulty row used for its episodes.
_TASK_DIFFICULTY = {
    "alert_classification": "easy",
    "root_cause_analysis": "medium",
    "remediation_planning": "hard",
}
class IncidentEnvironment:
    """Single-episode incident-response environment.

    One instance holds the state of at most one episode at a time. ``reset``
    starts an episode, ``step`` applies one action and returns the shaped
    reward for it, and ``state`` returns a snapshot of the episode. All three
    public methods run under the same lock.
    """

    def __init__(self) -> None:
        """Create an environment with no active episode."""
        self._lock = threading.Lock()
        self._s: dict = {}          # mutable per-episode state
        self._scenario: dict = {}   # scenario definition from get_scenario()
        self._task_def: dict = {}   # task definition from get_task()
        self._ready = False         # True once reset() has succeeded

    def reset(self, task_id: str = "alert_classification",
              scenario_index: int = 0) -> Observation:
        """Start a fresh episode and return its initial observation.

        Args:
            task_id: Task identifier passed to ``get_task``/``get_scenario``.
            scenario_index: Scenario selector passed to ``get_scenario``.
        """
        with self._lock:
            # Resolve both lookups before touching instance state so a failed
            # lookup leaves any running episode intact.
            task_def = get_task(task_id)
            scenario = get_scenario(task_id, scenario_index)
            self._task_def = task_def
            self._scenario = scenario
            self._s = {
                "episode_id": str(uuid.uuid4()),
                "task_id": task_id,
                "scenario_id": scenario["scenario_id"],
                "step_count": 0,
                "max_steps": task_def["max_steps"],
                "action_history": [],
                "queried_data": {},
                "queried_keys": set(),
                "services_queried": set(),
                "exact_hashes": set(),
                "submitted": False,
                "resolved": False,
                "done": False,
                "cumulative_reward": 0.0,
                "feedback": f"Episode started. {scenario['description']}",
                "last_action_error": None,
            }
            self._ready = True
            return self._build_obs()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Apply one action and return ``(observation, reward, done, info)``.

        The returned ``Reward.score`` is the shaping reward of this step only.
        On the terminal step the grader total is additionally folded into the
        cumulative reward.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("Call reset() before step().")
            s = self._s
            s["last_action_error"] = None
            if s["done"]:
                return (self._build_obs(),
                        Reward(score=0.0, reason="episode already done",
                               cumulative=s["cumulative_reward"]),
                        True, {})

            s["step_count"] += 1
            step_num = s["step_count"]
            at = action.action_type
            params = action.parameters
            task_id = s["task_id"]
            diff = _TASK_DIFFICULTY.get(task_id, "medium")
            rt = _REWARD_TABLE[diff]
            s["action_history"].append({
                "action_type": at,
                "parameters": params.model_dump(exclude_none=True),
                "step": step_num,
            })

            r = 0.0
            fb: list[str] = []

            # Penalise an action whose type and parameters were already seen.
            h = f"{at}|{params.model_dump_json(exclude_none=True)}"
            if h in s["exact_hashes"]:
                r += rt["exact_repeat"]
                fb.append(f"exact repeat ({rt['exact_repeat']:+.2f})")
            s["exact_hashes"].add(h)

            # Any non-submit action after the halfway point costs extra.
            half = max(1, s["max_steps"] // 2)
            if step_num > half and at not in _SUBMIT:
                r += rt["past_half"]
                fb.append(f"past halfway ({rt['past_half']:+.3f})")

            if at in _DIAGNOSTIC:
                r, fb = self._handle_diagnostic(at, params, r, fb, rt)
            elif at in _REMEDIATION:
                r, fb = self._handle_remediation(at, params, r, fb, rt, task_id)
            elif at in _SUBMIT:
                r, fb, terminal = self._handle_submit(at, params, r, fb, rt, task_id)
                if terminal:
                    s["done"] = True
            else:
                r += rt["bad_action"]
                fb.append(f"unknown action '{at}' ({rt['bad_action']:+.2f})")
                s["last_action_error"] = f"Unknown action type: {at}"

            # Step budget exhausted without a terminal submission.
            if step_num >= s["max_steps"] and not s["done"]:
                r += rt["timeout"]
                fb.append(f"timeout ({rt['timeout']:+.2f})")
                s["done"] = True

            if s["done"]:
                result = grade(s["task_id"], s, self._scenario)
                grader_score = result["total"]
                s["cumulative_reward"] = round(
                    s["cumulative_reward"] + r + grader_score, 4)
                fb.append(f"grader={grader_score:.3f} ({result['feedback']})")
            else:
                s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)

            s["feedback"] = " | ".join(fb) if fb else "ok"
            return (self._build_obs(),
                    Reward(score=round(r, 4), reason=s["feedback"],
                           cumulative=s["cumulative_reward"]),
                    s["done"],
                    {"step": step_num, "feedback": s["feedback"]})

    def state(self) -> EpisodeState:
        """Return a snapshot of the current episode.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("No active episode — call reset() first.")
            s = self._s
            return EpisodeState(
                episode_id=s["episode_id"], task_id=s["task_id"],
                scenario_id=s["scenario_id"], step_count=s["step_count"],
                max_steps=s["max_steps"],
                action_history=list(s["action_history"]),
                queried_data=self._queried_snapshot(),
                submitted=s["submitted"], resolved=s["resolved"],
                done=s["done"], cumulative_reward=s["cumulative_reward"],
                feedback=s["feedback"])

    @staticmethod
    def _lookup_service(mapping, svc, default):
        """Return ``mapping[svc]``, else a case-insensitive match, else ``default``.

        ``svc`` is expected to be lower-cased already. The fallback keeps the
        data lookup consistent with the lower-cased ``known_services`` check.
        """
        if svc in mapping:
            return mapping[svc]
        for name, value in mapping.items():
            if isinstance(name, str) and name.lower() == svc:
                return value
        return default

    def _queried_snapshot(self) -> dict:
        """Copy ``queried_data`` one level deep.

        The per-action inner dicts are copied as well, so a caller mutating
        the returned mapping cannot alter live episode state.
        """
        return {name: dict(per_target)
                for name, per_target in self._s["queried_data"].items()}

    def _handle_diagnostic(self, at, params, r, fb, rt):
        """Score a diagnostic action and record its tool response.

        Args:
            at: Action type (a member of ``_DIAGNOSTIC``).
            params: Action parameters; only ``service`` is read.
            r: Reward accumulated so far for this step.
            fb: Feedback fragments accumulated so far (appended in place).
            rt: Reward row for the episode's difficulty.

        Returns:
            The updated ``(r, fb)`` pair.
        """
        s = self._s
        svc = (params.service or "").lower().strip()
        known = {v.lower() for v in self._scenario.get("known_services", set())}
        tool = self._scenario.get("tool_responses", {}).get(at, {})
        key = (at, svc)

        if not svc:
            r += rt["query_no_service"]
            fb.append(f"{at}: no service ({rt['query_no_service']:+.2f})")
            s["last_action_error"] = f"{at} requires a service parameter"
            return r, fb
        if svc not in known:
            r += rt["query_unknown_svc"]
            fb.append(f"unknown service '{svc}' ({rt['query_unknown_svc']:+.2f})")
            s["last_action_error"] = f"Unknown service: {svc}"
            return r, fb

        if key in s["queried_keys"]:
            r += rt["query_repeat"]
            fb.append(f"repeat [{at}][{svc}] ({rt['query_repeat']:+.2f})")
        elif svc in s["services_queried"]:
            r += rt["query_new_action"]
            fb.append(f"new action on {svc} ({rt['query_new_action']:+.2f})")
            s["queried_keys"].add(key)
        else:
            r += rt["query_new_svc"]
            fb.append(f"new service {svc} ({rt['query_new_svc']:+.2f})")
            s["queried_keys"].add(key)
            s["services_queried"].add(svc)

        # The service passed the lower-cased known_services check, so also
        # accept tool_responses keys that differ from it only by case.
        result = self._lookup_service(tool, svc, f"No data available for '{svc}'.")
        s["queried_data"].setdefault(at, {})[svc] = result
        return r, fb

    def _handle_remediation(self, at, params, r, fb, rt, task_id):
        """Score a remediation action and record its result when it is accepted.

        Args:
            at: Action type (a member of ``_REMEDIATION``).
            params: Action parameters; ``service``, ``flag``,
                ``runbook_action``, ``target`` and ``target_version`` are read.
            r: Reward accumulated so far for this step.
            fb: Feedback fragments accumulated so far (appended in place).
            rt: Reward row for the episode's difficulty.
            task_id: Current task; remediation is rejected outright in
                ``alert_classification``.

        Returns:
            The updated ``(r, fb)`` pair.
        """
        s = self._s
        if task_id == "alert_classification":
            r += rt["rem_wrong"]
            fb.append(f"remediation in easy task ({rt['rem_wrong']:+.2f})")
            s["last_action_error"] = "Remediation not available in alert_classification"
            return r, fb

        svc = (params.service or "").lower().strip()
        flag = (params.flag or "").lower().strip()
        runbook = (params.runbook_action or "").lower().strip()
        target = (params.target or params.target_version or "").lower().strip()
        if not (svc or flag or runbook or target):
            r += rt["rem_no_target"]
            fb.append(f"{at}: no target ({rt['rem_no_target']:+.2f})")
            s["last_action_error"] = f"{at} requires a target"
            return r, fb

        # Candidate wrong-action keys, most specific first and the bare action
        # type last, so the reported reason is deterministic.
        candidates = []
        if svc:
            candidates.append(f"{at}:{svc}")
        if flag:
            candidates.append(f"{at}:{flag}")
        if runbook:
            candidates.append(f"execute_runbook_step:{runbook}")
        if target:
            candidates.append(f"execute_runbook_step:{target}")
        candidates.append(at)

        wrong_map = self._scenario.get("wrong_actions", {})
        matched = next((k for k in candidates if k in wrong_map), None)
        if matched is None and svc:
            # No exact key: fall back to the grader's service-name matcher on
            # "<action>:<service>" entries for this action type.
            for wk in wrong_map:
                w_at, sep, w_svc = wk.partition(":")
                if sep and w_at == at and _svc_match(svc, w_svc):
                    matched = wk
                    break

        if matched is not None:
            r += rt["rem_wrong"]
            reason = wrong_map[matched]
            fb.append(f"wrong: {at} — {str(reason)[:60]} ({rt['rem_wrong']:+.2f})")
        else:
            r += rt["rem_good"]
            tgt = svc or flag or runbook or target
            fb.append(f"executed {at}:{tgt} ({rt['rem_good']:+.2f})")
            at_data = self._scenario.get("remediation_data", {}).get(at, {})
            result = (at_data.get(svc) or at_data.get(flag) or at_data.get(runbook)
                      or at_data.get(target) or "action executed successfully")
            s["queried_data"].setdefault(at, {})[tgt] = result
        return r, fb

    def _handle_submit(self, at, params, r, fb, rt, task_id):
        """Score a submission action.

        Args:
            at: Action type (a member of ``_SUBMIT``).
            params: Action parameters; the fields read depend on ``at``.
            r: Reward accumulated so far for this step.
            fb: Feedback fragments accumulated so far (appended in place).
            rt: Reward row for the episode's difficulty.
            task_id: Current task, used to pick the expected submit action.

        Returns:
            ``(r, fb, terminal)``. ``terminal`` is False only when the wrong
            submit action was used for the task.
        """
        s = self._s
        correct = _TASK_SUBMIT.get(task_id, "")
        if at != correct:
            r += rt["submit_wrong"]
            fb.append(f"wrong submit '{at}' (need '{correct}') ({rt['submit_wrong']:+.2f})")
            s["last_action_error"] = f"Wrong submission type: use {correct}"
            return r, fb, False

        s["submitted"] = True
        r += rt["submit_correct"]
        fb.append(f"submitted ({rt['submit_correct']:+.2f})")
        if at == "submit_severity":
            fb.append(f"severity={(params.severity or '').upper().strip()}")
        elif at == "submit_root_cause":
            fb.append(f"svc={params.service or ''}, mode={params.failure_mode or ''}")
        elif at == "submit_resolution":
            summary = params.summary or ""
            # "resolved" needs a non-blank summary and at least one prior
            # diagnostic or remediation action in the history.
            inv = sum(1 for a in s["action_history"]
                      if a.get("action_type") in _DIAGNOSTIC | _REMEDIATION)
            if summary.strip() and inv >= 1:
                s["resolved"] = True
                fb.append("resolved")
            else:
                fb.append("insufficient investigation")
        return r, fb, True

    def _build_obs(self):
        """Build the Observation for the current episode state."""
        s = self._s
        sc = self._scenario
        td = self._task_def
        return Observation(
            episode_id=s["episode_id"], task_id=s["task_id"],
            scenario_id=s["scenario_id"], step_count=s["step_count"],
            max_steps=s["max_steps"],
            incident_summary=sc.get("incident_summary", sc.get("description", "")),
            alert=sc.get("alert", {}),
            available_actions=td.get("available_actions", []),
            queried_data=self._queried_snapshot(),
            cumulative_reward=s["cumulative_reward"],
            done=s["done"], feedback=s["feedback"],
            known_services=sorted(sc.get("known_services", set())),
            last_action_error=s.get("last_action_error"))