sre-incident-responder / baseline_script.py
kaori02's picture
feat: add expert task, gradient rewards, and test suite
697d9ce
"""Baseline agent using GPT-4o-mini with Chain-of-Thought reasoning.
Supports a mock mode for running without a real API key: scripted actions
from MOCK_STRATEGIES are replayed instead of calling the OpenAI API.
Mock mode is active when OPENAI_API_KEY is unset or set to "mock".
"""
from __future__ import annotations
import json
import os
import logging
from typing import Any
import httpx
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Settings are read from the environment once, at import time. When
# OPENAI_API_KEY is not set it defaults to the literal "mock", which is
# exactly the value that switches mock mode on.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "mock")
MOCK_MODE = (OPENAI_API_KEY == "mock")

# Base URL of the environment server that exposes /reset, /step, /state
# and /grader.
BASE_URL = os.environ.get("OPENENV_BASE_URL", "http://127.0.0.1:7860")
# ---------------------------------------------------------------------------
# Mock OpenAI responses (Chain-of-Thought strategies per task)
# ---------------------------------------------------------------------------
# Scripted action sequences, keyed by task id. In mock mode the runner replays
# the list for the current task one entry per step instead of asking the LLM.
# Each entry is an action payload; "parameters" is optional and only present
# on update_config actions here.
MOCK_STRATEGIES: dict[str, list[dict[str, Any]]] = {
    # Inspect auth and gateway, restart auth, then re-check both.
    "task_auth_restart": [
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
        {"action_type": "check_logs", "target_service": "auth"},
        {"action_type": "restart_service", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Inspect the database, raise max_connections, restart, then re-check
    # database, auth and gateway.
    "task_db_log_analysis": [
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "check_logs", "target_service": "database"},
        {"action_type": "analyze_logs", "target_service": "database"},
        {
            "action_type": "update_config",
            "target_service": "database",
            "parameters": {"max_connections": 300},
        },
        {"action_type": "restart_service", "target_service": "database"},
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Diagnose payment and gateway, then update payment's memory settings.
    "task_redis_config": [
        {"action_type": "check_health", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "payment"},
        {"action_type": "check_logs", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "gateway"},
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {
                "maxmemory": "134217728",
                "maxmemory_policy": "allkeys-lru",
            },
        },
        {"action_type": "check_health", "target_service": "payment"},
    ],
    # Inspect auth and database, set the database connection timeout,
    # restart database then auth, and re-check auth and gateway.
    "task_cascading_failure": [
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "analyze_logs", "target_service": "database"},
        {"action_type": "check_logs", "target_service": "database"},
        {
            "action_type": "update_config",
            "target_service": "database",
            "parameters": {"connection_timeout": 5},
        },
        {"action_type": "restart_service", "target_service": "database"},
        {"action_type": "restart_service", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Diagnose payment, then apply two config updates to it.
    "task_memory_leak": [
        {"action_type": "check_health", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "payment"},
        {"action_type": "check_logs", "target_service": "payment"},
        {"action_type": "analyze_logs", "target_service": "payment"},
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {"key_pattern": "session:leak:*", "ttl": 60},
        },
        # NOTE(review): this key is spelled "maxmemory-policy" (hyphen) while
        # task_redis_config uses "maxmemory_policy" (underscore) -- confirm
        # which spelling the environment expects.
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {"maxmemory-policy": "allkeys-lru"},
        },
        {"action_type": "check_health", "target_service": "payment"},
    ],
}
# ---------------------------------------------------------------------------
# OpenAI client (real or mock)
# ---------------------------------------------------------------------------
# System message placed first in the conversation for every task run. It
# lists the action types and service names offered to the model and asks for
# a JSON action object, which parse_action() later extracts from the reply.
SYSTEM_PROMPT = """You are an expert SRE agent. You must diagnose and resolve infrastructure incidents.
Available actions:
- check_health: Check a service's health status
- check_logs: View recent logs for a service
- restart_service: Restart a service
- analyze_logs: Deep analysis of logs across services
- update_config: Update service configuration (pass parameters dict)
- run_diagnostics: Run diagnostics on a service
- scale_service: Scale a service
- nop: Do nothing
Available services: gateway, auth, payment, database
Think step by step:
1. First diagnose — check health and logs
2. Identify the root cause
3. Apply the minimal fix
4. Verify the fix worked
Respond with JSON: {"action_type": "...", "target_service": "...", "parameters": {...}}
"""
def call_openai(messages: list[dict[str, str]]) -> dict[str, Any]:
    """Return the assistant message for the given conversation.

    In mock mode no request is made; a fixed placeholder message is returned.
    Otherwise *messages* is posted to the OpenAI chat completions endpoint
    and the ``message`` of the first choice is returned.

    Args:
        messages: Chat history as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The assistant message dict (``role`` and ``content`` in mock mode).

    Raises:
        httpx.HTTPStatusError: Raised by ``raise_for_status()`` when the API
            answers with an error status.
    """
    if not MOCK_MODE:
        request_headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": "gpt-4o-mini",
            "messages": messages,
            "temperature": 0.1,
            "max_tokens": 500,
        }
        response = httpx.post(
            "https://api.openai.com/v1/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=30.0,
        )
        response.raise_for_status()
        first_choice = response.json()["choices"][0]
        return first_choice["message"]

    # Mock mode: the text carries no JSON action; it is only a placeholder.
    return {"role": "assistant", "content": "Using mock mode — action from strategy."}
def parse_action(content: str) -> dict[str, Any] | None:
"""Extract JSON action from LLM response."""
try:
# Try to find JSON in the response
start = content.find("{")
end = content.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(content[start:end])
except (json.JSONDecodeError, ValueError):
pass
return None
# ---------------------------------------------------------------------------
# Baseline runner
# ---------------------------------------------------------------------------
def run_baseline_task(task_id: str, base_url: str = BASE_URL) -> dict[str, Any]:
    """Run the baseline agent on a single task.

    Resets the environment for *task_id*, then repeatedly chooses an action
    (scripted in mock mode while the script lasts, otherwise from the LLM),
    posts it to ``/step`` and refreshes the state from ``/state``. The loop
    ends when the state reports ``done``, when a ``/step`` or ``/state``
    request fails, or when the step budget (``max_steps`` from the reset
    response, default 50) is used up. The run is then graded via ``/grader``.

    Args:
        task_id: Task identifier sent to ``/reset`` and ``/grader``.
        base_url: Base URL of the environment server.

    Returns:
        A dict with ``task_id``, ``actions_taken``, ``final_reward`` (the
        ``total_reward`` of the last state fetched, default 0.0), ``passed``
        (False when grading fails) and ``steps_used``.

    Raises:
        httpx.HTTPStatusError: If the initial ``/reset`` call fails.
            Transport-level httpx errors are not caught here either.
    """
    with httpx.Client(base_url=base_url, timeout=30.0) as client:
        # Reset
        reset_resp = client.post("/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        env_state = reset_resp.json()

        actions_taken: list[dict[str, Any]] = []
        mock_actions = MOCK_STRATEGIES.get(task_id, [])
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task: {task_id}\nCurrent state: {json.dumps(env_state, default=str)}"},
        ]

        for step_idx in range(env_state.get("max_steps", 50)):
            if env_state.get("done"):
                break

            # Get action
            if MOCK_MODE and step_idx < len(mock_actions):
                # Shallow copy: the "parameters" default added below must not
                # be written into the shared module-level strategy entries.
                action_data = dict(mock_actions[step_idx])
            else:
                llm_resp = call_openai(messages)
                # "content" can be present but null; treat that as empty text.
                action_data = parse_action(llm_resp.get("content") or "")
                if action_data is None:
                    action_data = {"action_type": "nop"}

            # Ensure required fields
            action_data.setdefault("parameters", {})

            # Execute
            step_resp = client.post("/step", json={"action": action_data})
            if step_resp.status_code != 200:
                logger.warning("Step failed: %s", step_resp.text)
                break
            observation = step_resp.json()
            actions_taken.append(action_data)

            # Update conversation for LLM
            messages.append({"role": "assistant", "content": json.dumps(action_data)})
            messages.append({"role": "user", "content": f"Observation: {json.dumps(observation, default=str)}"})

            # Get updated state. On failure keep the last good state instead
            # of replacing it with an error body, and stop like a failed step.
            state_resp = client.get("/state")
            if state_resp.status_code != 200:
                logger.warning("State fetch failed: %s", state_resp.text)
                break
            env_state = state_resp.json()

        # Grade. A failed grader call counts as "not passed" rather than
        # being parsed as if it were a grade.
        grade_resp = client.post("/grader", json={"task_id": task_id})
        if grade_resp.status_code == 200:
            grade_result = grade_resp.json()
        else:
            logger.warning("Grading failed: %s", grade_resp.text)
            grade_result = {}

    return {
        "task_id": task_id,
        "actions_taken": actions_taken,
        "final_reward": env_state.get("total_reward", 0.0),
        "passed": grade_result.get("passed", False),
        "steps_used": len(actions_taken),
    }
def run_baseline_all_tasks(base_url: str = BASE_URL) -> list[dict[str, Any]]:
    """Run the baseline on every known task and return one result per task.

    A task that raises is logged with its traceback and recorded as a failed
    result (no actions, reward 0.0, not passed), so one broken task does not
    stop the remaining ones.

    Args:
        base_url: Base URL of the environment server, forwarded to
            ``run_baseline_task``.

    Returns:
        Result dicts in task order, exactly one per task id.
    """
    task_ids = [
        "task_auth_restart",
        "task_db_log_analysis",
        "task_redis_config",
        "task_cascading_failure",
        "task_memory_leak",
    ]
    results: list[dict[str, Any]] = []
    for tid in task_ids:
        # Only the task run itself is guarded. Keeping the append and the
        # success log outside the try guarantees a single entry per task even
        # if reporting goes wrong.
        try:
            result = run_baseline_task(tid, base_url)
        except Exception as exc:
            # Top-level boundary: record the failure and carry on.
            logger.exception("Task %s failed: %s", tid, exc)
            result = {
                "task_id": tid,
                "actions_taken": [],
                "final_reward": 0.0,
                "passed": False,
                "steps_used": 0,
            }
        else:
            # Lazy %-formatting: a non-numeric reward is reported by the
            # logging machinery instead of raising here.
            logger.info(
                "Task %s: passed=%s, reward=%.2f",
                tid,
                result["passed"],
                result["final_reward"],
            )
        results.append(result)
    return results
if __name__ == "__main__":
    # Script entry point: run every task and print a one-line summary each.
    logging.basicConfig(level=logging.INFO)
    results = run_baseline_all_tasks()
    for r in results:
        if r["passed"]:
            status = "PASS"
        else:
            status = "FAIL"
        print(f"[{status}] {r['task_id']}: reward={r['final_reward']:.2f}, steps={r['steps_used']}")