cloud-incident / inference.py
Elliot89's picture
feat: cloud incident response OpenEnv v0.1.0
37204eb
"""
inference.py β€” Cloud Incident Response OpenEnv baseline inference script.
Required env vars:
API_BASE_URL OpenAI-compatible LLM endpoint
MODEL_NAME Model identifier
HF_TOKEN API key (Hugging Face token or any OpenAI-compatible key)
Also accepts OPENAI_API_KEY as fallback for HF_TOKEN.
Runs the agent against all 3 tasks x 2 scenarios = 6 episodes.
Final stdout line is valid JSON β€” required by hackathon validator.
Usage:
export API_BASE_URL="https://api-inference.huggingface.co/v1"
export MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
export HF_TOKEN="hf_your_token_here"
python inference.py
"""
from __future__ import annotations
import json
import os
import sys
import requests
from openai import OpenAI
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
# ── Config β€” accepts both HF_TOKEN and OPENAI_API_KEY ────────────────────────
API_BASE_URL = os.environ.get(
"API_BASE_URL", "https://api-inference.huggingface.co/v1"
)
MODEL_NAME = os.environ.get(
"MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct"
)
HF_TOKEN = (
os.environ.get("HF_TOKEN")
or os.environ.get("OPENAI_API_KEY")
or ""
)
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
if not HF_TOKEN:
print(
"[WARN] Neither HF_TOKEN nor OPENAI_API_KEY is set β€” "
"LLM calls will fail.",
file=sys.stderr,
)
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
# ── System prompt ─────────────────────────────────────────────────────────────
SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) \
responding to a live production cloud incident.
You receive the incident observation as JSON. \
Respond with ONLY a single valid JSON action β€” no markdown, no explanation.
Available action_types:
Diagnostic (gather evidence):
{"action_type": "query_logs", "parameters": {"service": "<name>"}}
{"action_type": "check_metrics", "parameters": {"service": "<name>"}}
{"action_type": "check_dependencies", "parameters": {"service": "<name>"}}
{"action_type": "check_recent_deploys", "parameters": {"service": "<name>"}}
{"action_type": "check_service_status", "parameters": {"service": "<name>"}}
Remediation (fix the incident):
{"action_type": "restart_service", "parameters": {"service": "<name>"}}
{"action_type": "rollback_deploy", "parameters": {"service": "<name>", "target_version": "previous"}}
{"action_type": "scale_service", "parameters": {"service": "<name>", "replicas": 5}}
{"action_type": "disable_feature_flag", "parameters": {"flag": "<flag_name>"}}
{"action_type": "execute_runbook_step", "parameters": {"runbook_action": "<action>"}}
Submission (ends the episode β€” pick ONE matching the task):
{"action_type": "submit_severity", "parameters": {"severity": "P1|P2|P3|P4", "service": "<root_cause>"}}
{"action_type": "submit_root_cause", "parameters": {"service": "<root_cause>", "failure_mode": "<description>"}}
{"action_type": "submit_resolution", "parameters": {"summary": "<detailed description of what happened and what you did>"}}
Strategy:
- alert_classification (3 steps max): Query 1-2 key services, then submit_severity.
- root_cause_analysis (10 steps max): Trace the failure chain, identify root service, submit_root_cause.
- remediation_planning (15 steps max): Diagnose, fix the root cause with remediation actions, submit_resolution.
Output ONLY the JSON object. Nothing else."""
def _fmt(obs: dict) -> str:
parts = [
f"TASK: {obs.get('task_id')} | "
f"Step {obs.get('step_count')}/{obs.get('max_steps')}",
f"INCIDENT: {obs.get('incident_summary', '')}",
]
if obs.get("alert"):
parts.append("ALERT:\n" + json.dumps(obs["alert"], indent=2))
if obs.get("available_actions"):
parts.append(f"AVAILABLE ACTIONS: {obs['available_actions']}")
if obs.get("queried_data"):
parts.append("DATA GATHERED:\n" + json.dumps(obs["queried_data"], indent=2))
parts.append(f"CUMULATIVE REWARD: {obs.get('cumulative_reward', 0.0)}")
parts.append(f"FEEDBACK: {obs.get('feedback', '')}")
return "\n\n".join(parts)
def _parse(text: str) -> dict:
text = text.strip()
if text.startswith("```"):
text = "\n".join(
l for l in text.splitlines() if not l.startswith("```")
).strip()
try:
return json.loads(text)
except json.JSONDecodeError:
s, e = text.find("{"), text.rfind("}") + 1
if s != -1 and e > s:
return json.loads(text[s:e])
raise
def _fallback(task_id: str) -> dict:
"""Safe fallback action when LLM output can't be parsed."""
if task_id == "alert_classification":
return {"action_type": "submit_severity",
"parameters": {"severity": "P2", "service": "unknown"}}
if task_id == "root_cause_analysis":
return {"action_type": "submit_root_cause",
"parameters": {"service": "unknown", "failure_mode": "unknown"}}
return {"action_type": "submit_resolution",
"parameters": {"summary": "Unable to determine root cause."}}
def _run_episode(task_id: str, scenario_index: int) -> float:
r = requests.post(
f"{ENV_BASE_URL}/reset",
params={"task_id": task_id, "scenario_index": scenario_index},
timeout=30,
)
r.raise_for_status()
obs = r.json()
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
for _ in range(obs.get("max_steps", 10)):
messages.append({"role": "user", "content": _fmt(obs)})
try:
resp = client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
temperature=0.0,
max_tokens=256,
)
raw = resp.choices[0].message.content
except Exception as e:
print(f" [WARN] LLM call failed: {e}", file=sys.stderr)
raw = json.dumps(_fallback(task_id))
messages.append({"role": "assistant", "content": raw})
try:
action = _parse(raw)
except Exception as e:
print(f" [WARN] parse failed: {e}", file=sys.stderr)
action = _fallback(task_id)
sr = requests.post(
f"{ENV_BASE_URL}/step",
json=action,
headers={"Content-Type": "application/json"},
timeout=30,
)
sr.raise_for_status()
result = sr.json()
obs = result["observation"]
if result.get("done"):
break
g = requests.get(f"{ENV_BASE_URL}/grader", timeout=30)
g.raise_for_status()
return g.json().get("total", 0.0)
def main():
runs = [
("alert_classification", 0),
("alert_classification", 1),
("root_cause_analysis", 0),
("root_cause_analysis", 1),
("remediation_planning", 0),
("remediation_planning", 1),
]
results: dict[str, list[float]] = {}
print(f"{'Task':<32} {'S':>2} {'Score':>7}")
print("-" * 46)
for task_id, scenario_index in runs:
try:
score = _run_episode(task_id, scenario_index)
except Exception as e:
print(f" [ERROR] {task_id} s{scenario_index}: {e}",
file=sys.stderr)
score = 0.0
label = f"{task_id} [s{scenario_index}]"
print(f"{label:<32} {scenario_index:>2} {score:>7.4f}")
results.setdefault(task_id, []).append(score)
print("-" * 46)
summary = {
t: round(sum(v) / len(v), 4)
for t, v in results.items()
}
summary["overall"] = round(
sum(summary.values()) / len(summary), 4
)
print("\nBaseline Summary:")
for k, v in summary.items():
print(f" {k:<32}: {v:.4f}")
# Final line MUST be valid JSON β€” parsed by /baseline endpoint
print(json.dumps(summary))
if __name__ == "__main__":
main()