cloud-incident-response / server /environment.py
Elliot89's picture
Prepare project for push: update files
8151d99
"""
server/environment.py — Core OpenEnv environment for Cloud Incident Response.
Difficulty comes from SCENARIO DESIGN, not mechanics:
EASY: 3 services, clear metrics, obvious severity
MEDIUM: 8 services, root cause NOT in alert, must follow log breadcrumbs
HARD: 8 services + 5-7 remediation steps + quality summary + penalties
"""
from __future__ import annotations
import os
import sys
import threading
import uuid
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from graders import _svc_match, grade
from server.models import Action, ActionParameters, EpisodeState, Observation, Reward
from tasks import get_scenario, get_task
_DIAGNOSTIC = frozenset({
"query_logs", "check_metrics", "check_dependencies",
"check_recent_deploys", "check_service_status",
})
_REMEDIATION = frozenset({
"restart_service", "rollback_deploy", "scale_service",
"disable_feature_flag", "clear_cache", "execute_runbook_step",
})
_SUBMIT = frozenset({
"submit_severity", "submit_root_cause", "submit_resolution",
})
_TASK_SUBMIT = {
"alert_classification": "submit_severity",
"root_cause_analysis": "submit_root_cause",
"remediation_planning": "submit_resolution",
}
_REWARD_TABLE = {
"easy": {
"query_new_svc": +0.04, "query_new_action": +0.02,
"query_repeat": -0.03, "query_unknown_svc": -0.06,
"query_no_service": -0.04, "rem_good": +0.00,
"rem_wrong": -0.08, "rem_no_target": -0.05,
"submit_correct": +0.02, "submit_wrong": -0.08,
"past_half": -0.04, "timeout": -0.15,
"bad_action": -0.05, "exact_repeat": -0.04,
},
"medium": {
"query_new_svc": +0.04, "query_new_action": +0.02,
"query_repeat": -0.04, "query_unknown_svc": -0.06,
"query_no_service": -0.04, "rem_good": +0.06,
"rem_wrong": -0.10, "rem_no_target": -0.06,
"submit_correct": +0.02, "submit_wrong": -0.10,
"past_half": -0.02, "timeout": -0.15,
"bad_action": -0.05, "exact_repeat": -0.05,
},
"hard": {
"query_new_svc": +0.03, "query_new_action": +0.01,
"query_repeat": -0.03, "query_unknown_svc": -0.05,
"query_no_service": -0.03, "rem_good": +0.06,
"rem_wrong": -0.15, "rem_no_target": -0.05,
"submit_correct": +0.02, "submit_wrong": -0.12,
"past_half": -0.02, "timeout": -0.20,
"bad_action": -0.05, "exact_repeat": -0.04,
},
}
_TASK_DIFFICULTY = {
"alert_classification": "easy",
"root_cause_analysis": "medium",
"remediation_planning": "hard",
}
class IncidentEnvironment:
def __init__(self) -> None:
self._lock = threading.Lock()
self._s: dict = {}
self._scenario: dict = {}
self._task_def: dict = {}
self._ready = False
def reset(self, task_id: str = "alert_classification",
scenario_index: int = 0) -> Observation:
with self._lock:
task_def = get_task(task_id)
scenario = get_scenario(task_id, scenario_index)
self._task_def = task_def
self._scenario = scenario
self._s = {
"episode_id": str(uuid.uuid4()),
"task_id": task_id,
"scenario_id": scenario["scenario_id"],
"step_count": 0,
"max_steps": task_def["max_steps"],
"action_history": [],
"queried_data": {},
"queried_keys": set(),
"services_queried": set(),
"exact_hashes": set(),
"submitted": False,
"resolved": False,
"done": False,
"cumulative_reward": 0.0,
"feedback": f"Episode started. {scenario['description']}",
"last_action_error": None,
}
self._ready = True
return self._build_obs()
def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
with self._lock:
if not self._ready:
raise RuntimeError("Call reset() before step().")
s = self._s
s["last_action_error"] = None
if s["done"]:
return (self._build_obs(),
Reward(score=0.0, reason="episode already done",
cumulative=s["cumulative_reward"]),
True, {})
s["step_count"] += 1
step_num = s["step_count"]
at = action.action_type
params = action.parameters
task_id = s["task_id"]
diff = _TASK_DIFFICULTY.get(task_id, "medium")
rt = _REWARD_TABLE[diff]
s["action_history"].append({
"action_type": at,
"parameters": params.model_dump(exclude_none=True),
"step": step_num,
})
r = 0.0
fb: list[str] = []
h = f"{at}|{params.model_dump_json(exclude_none=True)}"
if h in s["exact_hashes"]:
r += rt["exact_repeat"]
fb.append(f"exact repeat ({rt['exact_repeat']:+.2f})")
s["exact_hashes"].add(h)
half = max(1, s["max_steps"] // 2)
if step_num > half and at not in _SUBMIT:
r += rt["past_half"]
fb.append(f"past halfway ({rt['past_half']:+.3f})")
if at in _DIAGNOSTIC:
r, fb = self._handle_diagnostic(at, params, r, fb, rt)
elif at in _REMEDIATION:
r, fb = self._handle_remediation(at, params, r, fb, rt, task_id)
elif at in _SUBMIT:
r, fb, terminal = self._handle_submit(at, params, r, fb, rt, task_id)
if terminal:
s["done"] = True
else:
r += rt["bad_action"]
fb.append(f"unknown action '{at}' ({rt['bad_action']:+.2f})")
s["last_action_error"] = f"Unknown action type: {at}"
if step_num >= s["max_steps"] and not s["done"]:
r += rt["timeout"]
fb.append(f"timeout ({rt['timeout']:+.2f})")
s["done"] = True
if s["done"]:
result = grade(s["task_id"], s, self._scenario)
grader_score = result["total"]
s["cumulative_reward"] = round(
s["cumulative_reward"] + r + grader_score, 4)
fb.append(f"grader={grader_score:.3f} ({result['feedback']})")
else:
s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)
s["feedback"] = " | ".join(fb) if fb else "ok"
return (self._build_obs(),
Reward(score=round(r, 4), reason=s["feedback"],
cumulative=s["cumulative_reward"]),
s["done"],
{"step": step_num, "feedback": s["feedback"]})
def state(self) -> EpisodeState:
with self._lock:
if not self._ready:
raise RuntimeError("No active episode — call reset() first.")
s = self._s
return EpisodeState(
episode_id=s["episode_id"], task_id=s["task_id"],
scenario_id=s["scenario_id"], step_count=s["step_count"],
max_steps=s["max_steps"],
action_history=list(s["action_history"]),
queried_data=dict(s["queried_data"]),
submitted=s["submitted"], resolved=s["resolved"],
done=s["done"], cumulative_reward=s["cumulative_reward"],
feedback=s["feedback"])
def _handle_diagnostic(self, at, params, r, fb, rt):
s = self._s
svc = (params.service or "").lower().strip()
known = {v.lower() for v in self._scenario.get("known_services", set())}
tool = self._scenario.get("tool_responses", {}).get(at, {})
key = (at, svc)
if not svc:
r += rt["query_no_service"]
fb.append(f"{at}: no service ({rt['query_no_service']:+.2f})")
s["last_action_error"] = f"{at} requires a service parameter"
return r, fb
if svc not in known:
r += rt["query_unknown_svc"]
fb.append(f"unknown service '{svc}' ({rt['query_unknown_svc']:+.2f})")
s["last_action_error"] = f"Unknown service: {svc}"
return r, fb
if key in s["queried_keys"]:
r += rt["query_repeat"]
fb.append(f"repeat [{at}][{svc}] ({rt['query_repeat']:+.2f})")
elif svc in s["services_queried"]:
r += rt["query_new_action"]
fb.append(f"new action on {svc} ({rt['query_new_action']:+.2f})")
s["queried_keys"].add(key)
else:
r += rt["query_new_svc"]
fb.append(f"new service {svc} ({rt['query_new_svc']:+.2f})")
s["queried_keys"].add(key)
s["services_queried"].add(svc)
result = tool.get(svc, f"No data available for '{svc}'.")
s["queried_data"].setdefault(at, {})[svc] = result
return r, fb
def _handle_remediation(self, at, params, r, fb, rt, task_id):
s = self._s
if task_id == "alert_classification":
r += rt["rem_wrong"]
fb.append(f"remediation in easy task ({rt['rem_wrong']:+.2f})")
s["last_action_error"] = "Remediation not available in alert_classification"
return r, fb
svc = (params.service or "").lower().strip()
flag = (params.flag or "").lower().strip()
runbook = (params.runbook_action or "").lower().strip()
target = (params.target or params.target_version or "").lower().strip()
if not (svc or flag or runbook or target):
r += rt["rem_no_target"]
fb.append(f"{at}: no target ({rt['rem_no_target']:+.2f})")
s["last_action_error"] = f"{at} requires a target"
return r, fb
keys = {at}
if svc: keys.add(f"{at}:{svc}")
if flag: keys.add(f"{at}:{flag}")
if runbook: keys.add(f"execute_runbook_step:{runbook}")
if target: keys.add(f"execute_runbook_step:{target}")
wrong_map = self._scenario.get("wrong_actions", {})
rem_data = self._scenario.get("remediation_data", {})
is_wrong = any(k in wrong_map for k in keys)
if not is_wrong and svc:
for wk in wrong_map:
if ":" in wk:
w_at, w_svc = wk.split(":", 1)
if w_at == at and _svc_match(svc, w_svc):
is_wrong = True
break
if is_wrong:
r += rt["rem_wrong"]
reason = next((wrong_map[k] for k in keys if k in wrong_map), "wrong")
fb.append(f"wrong: {at}{str(reason)[:60]} ({rt['rem_wrong']:+.2f})")
else:
r += rt["rem_good"]
tgt = svc or flag or runbook or target
fb.append(f"executed {at}:{tgt} ({rt['rem_good']:+.2f})")
at_data = rem_data.get(at, {})
result = (at_data.get(svc) or at_data.get(flag) or at_data.get(runbook)
or at_data.get(target) or "action executed successfully")
s["queried_data"].setdefault(at, {})[tgt] = result
return r, fb
def _handle_submit(self, at, params, r, fb, rt, task_id):
s = self._s
correct = _TASK_SUBMIT.get(task_id, "")
if at != correct:
r += rt["submit_wrong"]
fb.append(f"wrong submit '{at}' (need '{correct}') ({rt['submit_wrong']:+.2f})")
s["last_action_error"] = f"Wrong submission type: use {correct}"
return r, fb, False
s["submitted"] = True
r += rt["submit_correct"]
fb.append(f"submitted ({rt['submit_correct']:+.2f})")
if at == "submit_severity":
fb.append(f"severity={(params.severity or '').upper().strip()}")
elif at == "submit_root_cause":
fb.append(f"svc={params.service or ''}, mode={params.failure_mode or ''}")
elif at == "submit_resolution":
summary = params.summary or ""
inv = sum(1 for a in s["action_history"]
if a.get("action_type") in _DIAGNOSTIC | _REMEDIATION)
if summary.strip() and inv >= 1:
s["resolved"] = True
fb.append("resolved")
else:
fb.append("insufficient investigation")
return r, fb, True
def _build_obs(self):
s = self._s
sc = self._scenario
td = self._task_def
return Observation(
episode_id=s["episode_id"], task_id=s["task_id"],
scenario_id=s["scenario_id"], step_count=s["step_count"],
max_steps=s["max_steps"],
incident_summary=sc.get("incident_summary", sc.get("description", "")),
alert=sc.get("alert", {}),
available_actions=td.get("available_actions", []),
queried_data=dict(s["queried_data"]),
cumulative_reward=s["cumulative_reward"],
done=s["done"], feedback=s["feedback"],
known_services=sorted(sc.get("known_services", set())),
last_action_error=s.get("last_action_error"))