Spaces:
Sleeping
Sleeping
"""Baseline agent using GPT-4o-mini with Chain-of-Thought reasoning.

Supports a MOCK_OPENAI mode for running without a real API key.
Mock mode is active when OPENAI_API_KEY is unset or set to "mock"; scripted
actions from MOCK_STRATEGIES are then used instead of calling the OpenAI API.
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import logging | |
| from typing import Any | |
| import httpx | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# The API key falls back to the sentinel "mock" when the variable is unset,
# so mock mode is the default behaviour.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "mock")
MOCK_MODE = OPENAI_API_KEY == "mock"

# Root URL of the environment server that exposes /reset, /step, /state, /grader.
BASE_URL = os.environ.get("OPENENV_BASE_URL", "http://127.0.0.1:7860")
| # --------------------------------------------------------------------------- | |
| # Mock OpenAI responses (Chain-of-Thought strategies per task) | |
| # --------------------------------------------------------------------------- | |
# Scripted action sequences, keyed by task id. In mock mode the runner plays
# these back one action per step instead of asking the LLM.
MOCK_STRATEGIES: dict[str, list[dict[str, Any]]] = {
    # Inspect auth and gateway, restart auth, then re-check both.
    "task_auth_restart": [
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
        {"action_type": "check_logs", "target_service": "auth"},
        {"action_type": "restart_service", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Inspect the database, raise max_connections, restart it, then re-check
    # the database, auth and gateway.
    "task_db_log_analysis": [
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "check_logs", "target_service": "database"},
        {"action_type": "analyze_logs", "target_service": "database"},
        {
            "action_type": "update_config",
            "target_service": "database",
            "parameters": {"max_connections": 300},
        },
        {"action_type": "restart_service", "target_service": "database"},
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Diagnose payment and gateway, then update payment's memory settings.
    "task_redis_config": [
        {"action_type": "check_health", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "payment"},
        {"action_type": "check_logs", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "gateway"},
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {
                "maxmemory": "134217728",
                "maxmemory_policy": "allkeys-lru",
            },
        },
        {"action_type": "check_health", "target_service": "payment"},
    ],
    # Inspect auth and database, set the database connection timeout, restart
    # database then auth, and re-check auth and gateway.
    "task_cascading_failure": [
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "database"},
        {"action_type": "analyze_logs", "target_service": "database"},
        {"action_type": "check_logs", "target_service": "database"},
        {
            "action_type": "update_config",
            "target_service": "database",
            "parameters": {"connection_timeout": 5},
        },
        {"action_type": "restart_service", "target_service": "database"},
        {"action_type": "restart_service", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "auth"},
        {"action_type": "check_health", "target_service": "gateway"},
    ],
    # Diagnose payment, then apply two config updates to it.
    "task_memory_leak": [
        {"action_type": "check_health", "target_service": "payment"},
        {"action_type": "run_diagnostics", "target_service": "payment"},
        {"action_type": "check_logs", "target_service": "payment"},
        {"action_type": "analyze_logs", "target_service": "payment"},
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {"key_pattern": "session:leak:*", "ttl": 60},
        },
        # NOTE(review): this key is spelled "maxmemory-policy" (hyphen) while
        # task_redis_config uses "maxmemory_policy" (underscore) - confirm
        # which spelling the environment expects for each task.
        {
            "action_type": "update_config",
            "target_service": "payment",
            "parameters": {"maxmemory-policy": "allkeys-lru"},
        },
        {"action_type": "check_health", "target_service": "payment"},
    ],
}
| # --------------------------------------------------------------------------- | |
| # OpenAI client (real or mock) | |
| # --------------------------------------------------------------------------- | |
# System message placed first in the conversation sent to the model. It lists
# the available actions and services and asks for a JSON action in reply.
SYSTEM_PROMPT = """You are an expert SRE agent. You must diagnose and resolve infrastructure incidents.
Available actions:
- check_health: Check a service's health status
- check_logs: View recent logs for a service
- restart_service: Restart a service
- analyze_logs: Deep analysis of logs across services
- update_config: Update service configuration (pass parameters dict)
- run_diagnostics: Run diagnostics on a service
- scale_service: Scale a service
- nop: Do nothing
Available services: gateway, auth, payment, database
Think step by step:
1. First diagnose — check health and logs
2. Identify the root cause
3. Apply the minimal fix
4. Verify the fix worked
Respond with JSON: {"action_type": "...", "target_service": "...", "parameters": {...}}
"""
def call_openai(messages: list[dict[str, str]]) -> dict[str, Any]:
    """Return the assistant message for the given conversation.

    In mock mode no request is made and a fixed placeholder message is
    returned. Otherwise ``messages`` is posted to the OpenAI chat completions
    endpoint and the ``message`` of the first choice is returned; an HTTP
    error status is raised by ``raise_for_status``.
    """
    if MOCK_MODE:
        return {"role": "assistant", "content": "Using mock mode — action from strategy."}

    request_headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }
    request_body = {
        "model": "gpt-4o-mini",
        "messages": messages,
        "temperature": 0.1,
        "max_tokens": 500,
    }
    response = httpx.post(
        "https://api.openai.com/v1/chat/completions",
        headers=request_headers,
        json=request_body,
        timeout=30.0,
    )
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return first_choice["message"]
def parse_action(content: str | None) -> dict[str, Any] | None:
    """Extract the first JSON object embedded in an LLM response.

    The reply may surround the action with free-form reasoning, stray braces
    or further objects, so each ``{`` is tried in turn as the start of a JSON
    document and the first one that decodes successfully is returned.

    Args:
        content: Raw message text. ``None`` or any other non-string value
            (for example a null ``content`` field) is treated as "no action".

    Returns:
        The decoded JSON object, or ``None`` when no valid object is found.
    """
    if not isinstance(content, str):
        return None

    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            # raw_decode tolerates trailing text after the JSON document.
            candidate, _end = decoder.raw_decode(content, start)
        except ValueError:  # includes json.JSONDecodeError
            start = content.find("{", start + 1)
            continue
        # Decoding began on "{", so a successful result is a JSON object.
        return candidate
    return None
| # --------------------------------------------------------------------------- | |
| # Baseline runner | |
| # --------------------------------------------------------------------------- | |
def run_baseline_task(task_id: str, base_url: str = BASE_URL) -> dict[str, Any]:
    """Run the baseline agent on a single task.

    Resets the environment, then repeatedly picks an action (scripted in mock
    mode while the script lasts, otherwise parsed from the LLM reply), posts
    it to ``/step`` and refreshes the state from ``/state``. The loop ends
    when the state reports ``done``, a step is rejected, or the step budget
    (``max_steps`` from the reset state, default 50) is used up. Finally the
    ``/grader`` endpoint is asked for the verdict.

    Args:
        task_id: Identifier of the task to reset and grade.
        base_url: Root URL of the environment server.

    Returns:
        A dict with ``task_id``, ``actions_taken``, ``final_reward``,
        ``passed`` and ``steps_used``.

    Raises:
        httpx.HTTPStatusError: If ``/reset`` or ``/state`` answers with an
            error status.
    """
    with httpx.Client(base_url=base_url, timeout=30.0) as client:
        # Reset
        reset_resp = client.post("/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        env_state = reset_resp.json()

        actions_taken: list[dict[str, Any]] = []
        mock_actions = MOCK_STRATEGIES.get(task_id, [])
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task: {task_id}\nCurrent state: {json.dumps(env_state, default=str)}"},
        ]

        for step_idx in range(env_state.get("max_steps", 50)):
            if env_state.get("done"):
                break

            # Get action
            if MOCK_MODE and step_idx < len(mock_actions):
                # Copy so that filling in defaults below never mutates the
                # shared MOCK_STRATEGIES entry.
                action_data = dict(mock_actions[step_idx])
            else:
                llm_resp = call_openai(messages)
                # "content" may be present but None; treat that as empty text.
                action_data = parse_action(llm_resp.get("content") or "")
                if action_data is None:
                    action_data = {"action_type": "nop"}

            # Ensure required fields
            action_data.setdefault("parameters", {})

            # Execute; a rejected step ends the episode but is still graded.
            step_resp = client.post("/step", json={"action": action_data})
            if step_resp.status_code != 200:
                logger.warning("Step failed: %s", step_resp.text)
                break
            observation = step_resp.json()
            actions_taken.append(action_data)

            # Update conversation for LLM
            messages.append({"role": "assistant", "content": json.dumps(action_data)})
            messages.append({"role": "user", "content": f"Observation: {json.dumps(observation, default=str)}"})

            # Get updated state; fail loudly rather than treating an error
            # body as the environment state.
            state_resp = client.get("/state")
            state_resp.raise_for_status()
            env_state = state_resp.json()

        # Grade
        grade_resp = client.post("/grader", json={"task_id": task_id})
        grade_result = grade_resp.json()

        return {
            "task_id": task_id,
            "actions_taken": actions_taken,
            "final_reward": env_state.get("total_reward", 0.0),
            "passed": grade_result.get("passed", False),
            "steps_used": len(actions_taken),
        }
def run_baseline_all_tasks(base_url: str = BASE_URL) -> list[dict[str, Any]]:
    """Run the baseline on every known task and return one result per task.

    A task whose run raises is logged and recorded as a failed run with zero
    reward, so one broken task does not stop the remaining ones.

    Args:
        base_url: Root URL of the environment server.

    Returns:
        A list with exactly one result dict per task, in task order.
    """
    task_ids = ["task_auth_restart", "task_db_log_analysis", "task_redis_config", "task_cascading_failure", "task_memory_leak"]
    results: list[dict[str, Any]] = []
    for tid in task_ids:
        # Keep the try body to the call that is expected to fail, so a problem
        # while reporting a successful run cannot add a second record.
        try:
            result = run_baseline_task(tid, base_url)
        except Exception as e:  # deliberate best-effort boundary per task
            logger.error("Task %s failed: %s", tid, e)
            result = {
                "task_id": tid,
                "actions_taken": [],
                "final_reward": 0.0,
                "passed": False,
                "steps_used": 0,
            }
        else:
            logger.info(
                "Task %s: passed=%s, reward=%.2f",
                tid,
                result["passed"],
                result["final_reward"],
            )
        results.append(result)
    return results
if __name__ == "__main__":
    # Script entry point: run every task and print a one-line summary each.
    logging.basicConfig(level=logging.INFO)
    for outcome in run_baseline_all_tasks():
        verdict = "PASS" if outcome["passed"] else "FAIL"
        print(
            f"[{verdict}] {outcome['task_id']}: "
            f"reward={outcome['final_reward']:.2f}, steps={outcome['steps_used']}"
        )