cloud-incident / server /environment.py
Elliot89's picture
feat: cloud incident response OpenEnv v0.1.0
37204eb
"""
server/environment.py — Core OpenEnv environment for Cloud Incident Response.
Implements the full OpenEnv interface:
reset(task_id, scenario_index) -> Observation
step(action) -> (Observation, Reward, done, info)
state() -> EpisodeState
All state is in-memory. Thread-safe via a lock.
"""
from __future__ import annotations
import uuid
import threading
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from tasks import get_task, get_scenario
from graders import grade
from server.models import Action, ActionParameters, Observation, Reward, EpisodeState
# ── Action type classification ────────────────────────────────────────────────
# Read-only investigation actions: they look things up for a named service.
_DIAGNOSTIC = frozenset((
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
))

# Actions that change the simulated system (scored as good or wrong).
_REMEDIATION = frozenset((
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
))

# Submission actions; any of them ends the episode.
_SUBMIT = frozenset((
    "submit_severity",
    "submit_root_cause",
    "submit_resolution",
))
# ── Reward constants ──────────────────────────────────────────────────────────
# Diagnostic queries.
R_QUERY_FIRST = 0.05     # first (tool, service) query against a known service
R_QUERY_REPEAT = 0.01    # same (tool, service) pair queried again
R_QUERY_UNKNOWN = -0.05  # service name not in the scenario's known services

# Remediation actions.
R_REM_GOOD = 0.10        # remediation not listed among the scenario's wrong actions
R_REM_WRONG = -0.10      # remediation listed among the scenario's wrong actions

# Per-step shaping and penalties.
R_PAST_HALF = -0.02      # charged on every step after half of max_steps
R_TIMEOUT = -0.10        # max_steps reached without any submission
R_BAD_ACTION = -0.03     # action_type not recognised at all
class IncidentEnvironment:
    """
    OpenEnv environment for Cloud Incident Response.

    One instance handles one episode at a time. The public methods
    (``reset``, ``step``, ``state``) each run entirely under ``self._lock``,
    so calls from different threads are serialised. The lock is a plain
    (non-reentrant) ``threading.Lock``; the private helpers below assume the
    caller already holds it and must never acquire it themselves.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._s: dict = {}          # mutable per-episode state (see reset())
        self._scenario: dict = {}   # scenario definition from get_scenario()
        self._task_def: dict = {}   # task definition from get_task()
        self._ready = False         # True once reset() has succeeded

    # ── Public OpenEnv API ───────────────────────────────────────────────────

    def reset(self, task_id: str, scenario_index: int = 0) -> Observation:
        """Start a fresh episode and return the initial Observation.

        Any exception raised while loading the task/scenario or while
        building the new state (e.g. a missing scenario key) propagates
        before the instance is modified, so a failed reset leaves the
        previous episode intact.
        """
        with self._lock:
            task_def = get_task(task_id)
            scenario = get_scenario(task_id, scenario_index)

            # Build the complete state first; commit only once it exists so
            # the scenario and the episode state can never get out of sync.
            new_state = {
                "episode_id": str(uuid.uuid4()),
                "task_id": task_id,
                "scenario_id": scenario["scenario_id"],
                "step_count": 0,
                "max_steps": task_def["max_steps"],
                "action_history": [],
                "queried_data": {},
                "queried_keys": set(),
                "submitted": False,
                "resolved": False,
                "done": False,
                "cumulative_reward": 0.0,
                "feedback": f"Episode started. {scenario['description']}",
            }

            self._task_def = task_def
            self._scenario = scenario
            self._s = new_state
            self._ready = True
            return self._build_obs()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Process one agent action.

        Returns ``(Observation, Reward, done, info)``.

        ``Reward.value`` is the shaping reward of this step only. On the
        terminal step the grader's ``total`` is added to
        ``Reward.cumulative`` but not to ``Reward.value``.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("Call reset() before step().")

            s = self._s

            # Stepping a finished episode is a no-op with zero reward.
            if s["done"]:
                return (
                    self._build_obs(),
                    Reward(
                        value=0.0,
                        reason="episode already done",
                        cumulative=s["cumulative_reward"],
                    ),
                    True,
                    {},
                )

            s["step_count"] += 1
            step_num = s["step_count"]
            at = action.action_type
            params = action.parameters

            s["action_history"].append({
                "action_type": at,
                "parameters": params.model_dump(exclude_none=True),
                "step": step_num,
            })

            r = 0.0
            fb: list[str] = []

            # Efficiency penalty: every step past the halfway mark costs a bit.
            if step_num > s["max_steps"] // 2:
                r += R_PAST_HALF
                fb.append("efficiency penalty")

            if at in _DIAGNOSTIC:
                r, fb = self._handle_diagnostic(at, params, r, fb)
            elif at in _REMEDIATION:
                r, fb = self._handle_remediation(at, params, r, fb)
            elif at in _SUBMIT:
                r, fb, terminal = self._handle_submit(at, params, r, fb)
                if terminal:
                    s["done"] = True
            else:
                r += R_BAD_ACTION
                fb.append(f"unknown action_type '{at}'")

            # Timeout: step budget exhausted without a submission.
            if step_num >= s["max_steps"] and not s["done"]:
                r += R_TIMEOUT
                fb.append("timeout — no submission made")
                s["done"] = True

            # The grader runs exactly once, on the step that ends the episode.
            if s["done"]:
                result = grade(s["task_id"], s, self._scenario)
                s["cumulative_reward"] = round(
                    s["cumulative_reward"] + r + result["total"], 4
                )
                fb.append(f"grader={result['feedback']}")
            else:
                s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)

            s["feedback"] = " | ".join(fb) if fb else "ok"

            return (
                self._build_obs(),
                Reward(
                    value=round(r, 4),
                    reason=s["feedback"],
                    cumulative=s["cumulative_reward"],
                ),
                s["done"],
                {"step": step_num, "feedback": s["feedback"]},
            )

    def state(self) -> EpisodeState:
        """Return a snapshot of the full current episode state.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("No active episode — call reset() first.")

            s = self._s
            return EpisodeState(
                episode_id=s["episode_id"],
                task_id=s["task_id"],
                scenario_id=s["scenario_id"],
                step_count=s["step_count"],
                max_steps=s["max_steps"],
                action_history=list(s["action_history"]),
                queried_data=self._snapshot_queried_data(),
                submitted=s["submitted"],
                resolved=s["resolved"],
                done=s["done"],
                cumulative_reward=s["cumulative_reward"],
                feedback=s["feedback"],
            )

    # ── Action handlers ──────────────────────────────────────────────────────

    def _handle_diagnostic(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Score a diagnostic action and record the tool output it reveals.

        Takes the running reward ``r`` and feedback list ``fb`` and returns
        them updated. Service names are compared case-insensitively against
        the scenario's ``known_services``.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        known = {sv.lower() for sv in self._scenario.get("known_services", set())}
        tool_data = self._scenario.get("tool_responses", {}).get(at, {})
        key = (at, service)

        if service and service in known:
            if key not in s["queried_keys"]:
                r += R_QUERY_FIRST
                fb.append(f"queried {service} (+{R_QUERY_FIRST})")
                s["queried_keys"].add(key)
            else:
                r += R_QUERY_REPEAT
                fb.append(f"re-queried {service} (+{R_QUERY_REPEAT})")
            # NOTE(review): the lookup uses the lower-cased name; this assumes
            # tool_responses is keyed by lower-case service names — confirm
            # against the scenario data.
            result = tool_data.get(service, f"No data for '{service}'.")
            s["queried_data"].setdefault(at, {})[service] = result
        elif service:
            r += R_QUERY_UNKNOWN
            fb.append(f"unknown service '{service}' ({R_QUERY_UNKNOWN})")
        else:
            # No service given: no reward change, only a feedback hint.
            fb.append(f"{at}: no service specified")

        return r, fb

    def _handle_remediation(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str]]:
        """Score a remediation action against the scenario's ``wrong_actions``.

        The action is wrong if the bare action type, ``"<action>:<service>"``,
        ``"<action>:<flag>"`` or ``"execute_runbook_step:<runbook|target>"``
        appears in ``wrong_actions``. Anything else is rewarded as a good
        remediation and its result is stored in ``queried_data``.
        """
        s = self._s
        service = (params.service or "").lower().strip()
        flag = (params.flag or "").lower().strip()
        runbook = (params.runbook_action or "").lower().strip()
        target = (params.target or "").lower().strip()

        # Ordered (most specific first) so that, when several keys match, the
        # reported reason is deterministic rather than dependent on set
        # iteration order.
        candidates: list[str] = []
        if service:
            candidates.append(f"{at}:{service}")
        if flag:
            candidates.append(f"{at}:{flag}")
        if runbook:
            candidates.append(f"execute_runbook_step:{runbook}")
        if target:
            candidates.append(f"execute_runbook_step:{target}")
        candidates.append(at)

        wrong_map = self._scenario.get("wrong_actions", {})
        rem_data = self._scenario.get("remediation_data", {})

        matched = next((k for k in candidates if k in wrong_map), None)
        if matched is not None:
            r += R_REM_WRONG
            reason = wrong_map[matched]
            fb.append(f"wrong action '{at}': {str(reason)[:80]}")
        else:
            r += R_REM_GOOD
            fb.append(f"executed {at}" + (f" on '{service}'" if service else ""))
            at_data = rem_data.get(at, {})
            result = (
                at_data.get(service) or at_data.get(flag) or
                at_data.get(runbook) or at_data.get(target) or
                "action executed successfully"
            )
            # Stored under the first non-empty identifier, else the action type.
            s["queried_data"].setdefault(at, {})[
                service or flag or runbook or target or at
            ] = result

        return r, fb

    def _handle_submit(
        self, at: str, params: ActionParameters, r: float, fb: list[str]
    ) -> tuple[float, list[str], bool]:
        """Record a submission; every submission is terminal.

        Returns ``(r, fb, True)``. The reward is left unchanged here. A
        ``submit_resolution`` marks the episode resolved only if it carries a
        non-blank summary and at least one diagnostic or remediation action
        was taken earlier in the episode.
        """
        s = self._s
        s["submitted"] = True

        if at == "submit_severity":
            fb.append(f"submitted severity: {(params.severity or '').upper()}")
        elif at == "submit_root_cause":
            fb.append(
                f"submitted root cause: "
                f"service={params.service or ''}, "
                f"failure_mode={params.failure_mode or ''}"
            )
        elif at == "submit_resolution":
            summary = params.summary or ""
            investigative = _DIAGNOSTIC | _REMEDIATION  # built once, not per entry
            investigated = any(
                a.get("action_type") in investigative
                for a in s["action_history"]
            )
            if summary.strip() and investigated:
                s["resolved"] = True
                fb.append("resolution submitted — incident resolved")
            else:
                fb.append("resolution submitted — insufficient investigation")

        return r, fb, True

    # ── Build observation ────────────────────────────────────────────────────

    def _snapshot_queried_data(self) -> dict:
        """Copy ``queried_data`` including its per-tool inner dicts.

        A plain ``dict(...)`` would share the inner dicts with live state, so
        later queries could show up in a snapshot that was handed out earlier.
        """
        return {
            tool: dict(entries) if isinstance(entries, dict) else entries
            for tool, entries in self._s["queried_data"].items()
        }

    def _build_obs(self) -> Observation:
        """Assemble the agent-facing Observation from current state.

        The incident summary falls back to the scenario's ``description``
        (then to an empty string) when ``incident_summary`` is absent.
        """
        s = self._s
        sc = self._scenario
        td = self._task_def
        return Observation(
            episode_id=s["episode_id"],
            task_id=s["task_id"],
            scenario_id=s["scenario_id"],
            step_count=s["step_count"],
            max_steps=s["max_steps"],
            incident_summary=sc.get("incident_summary", sc.get("description", "")),
            alert=sc.get("alert", {}),
            available_actions=td.get("available_actions", []),
            queried_data=self._snapshot_queried_data(),
            cumulative_reward=s["cumulative_reward"],
            done=s["done"],
            feedback=s["feedback"],
        )