Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import sys | |
| import traceback | |
| import httpx | |
| from openai import OpenAI | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
# --- Endpoints and credentials (overridable through environment variables) ---
# Base URL handed to the OpenAI-compatible client.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# API_KEY takes precedence; an unset or empty API_KEY falls back to HF_TOKEN.
API_KEY = os.getenv("API_KEY", "") or os.getenv("HF_TOKEN", "")
# Model identifier sent with every chat-completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Root URL of the environment server (the /tasks, /reset, /step, /grade routes).
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")

# --- Fixed run parameters ---
# Name printed in the env= field of every [START] line.
BENCHMARK = "incidentops_env"
# Tasks executed, in this order, by main().
TASK_IDS = ["incident_easy", "incident_medium", "incident_hard"]
# Upper bound on environment steps per task.
MAX_STEPS = 12
# Sampling temperature for the chat-completion call.
TEMPERATURE = 0.2

# System message for the model; it asks for exactly one action string in reply.
SYSTEM_PROMPT = """You are an expert incident-response engineer.
You are given an incident observation with alert details, severity, affected services, and available actions.
Analyze the situation and choose the BEST single action from the available_actions list.
Rules:
- If logs are not available, request_logs first
- Investigate before escalating
- Escalate to the correct team based on evidence
- Resolve only when the incident is actually fixed
- Minimize steps to stay within SLA
Return ONLY the action string, nothing else. No explanation, no quotes."""
def log_start(task, env, model):
    """Print the ``[START]`` marker line for a task run and flush stdout.

    Args:
        task: Task identifier shown in the ``task=`` field.
        env: Environment/benchmark name shown in the ``env=`` field.
        model: Model name shown in the ``model=`` field.
    """
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(step, action, reward, done, error):
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: Step number shown in the ``step=`` field.
        action: Action string that was sent to the environment.
        reward: Numeric reward, rendered with two decimals.
        done: Completion flag, rendered as its lower-cased ``str()``.
        error: Error text; any falsy value is rendered as ``null``.
    """
    error_field = error or "null"
    done_field = str(done).lower()
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_field} error={error_field}"
    )
    print(line, flush=True)
def log_end(success, steps, score, rewards):
    """Print the ``[END]`` summary line for a task run and flush stdout.

    Args:
        success: Outcome flag, rendered as its lower-cased ``str()``.
        steps: Number of steps taken.
        score: Final score, rendered with two decimals.
        rewards: Iterable of per-step rewards, rendered as a comma-separated
            list of two-decimal numbers (empty when there are none).
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    rewards_str = ",".join(formatted_rewards)
    success_field = str(success).lower()
    print(
        f"[END] success={success_field} steps={steps} score={score:.2f} rewards={rewards_str}",
        flush=True,
    )
def choose_action_llm(client, obs):
    """Ask the LLM for the next action; use the deterministic policy otherwise.

    The model's reply is reduced to its first line with surrounding quotes,
    backticks and spaces removed, then matched against ``available_actions``:
    an exact match wins, otherwise the first available action that contains
    the reply or is contained in it.

    Args:
        client: OpenAI-style client exposing ``chat.completions.create``.
        obs: Observation mapping; ``available_actions`` lists the legal moves.

    Returns:
        An action string. ``"resolve_incident"`` when no actions are
        available; otherwise the matched action, or the result of
        ``choose_action_deterministic(obs)`` when the call fails or the reply
        is empty or matches nothing.
    """
    available = obs.get("available_actions", [])
    if not available:
        return "resolve_incident"

    # Only this fixed subset of the observation is shown to the model.
    obs_for_llm = {
        "alert_summary": obs.get("alert_summary", ""),
        "severity": obs.get("severity", ""),
        "likely_cause": obs.get("likely_cause", ""),
        "hf_confidence": obs.get("hf_confidence", 0.0),
        "logs_available": obs.get("logs_available", False),
        "log_snippet": obs.get("log_snippet", ""),
        "services_affected": obs.get("services_affected", []),
        "elapsed_steps": obs.get("elapsed_steps", 0),
        "sla_steps_remaining": obs.get("sla_steps_remaining", 0),
        "action_history": obs.get("action_history", []),
        "available_actions": available,
        "incident_resolved": obs.get("incident_resolved", False),
        "wrong_escalations": obs.get("wrong_escalations", 0),
    }

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(obs_for_llm)},
            ],
            temperature=TEMPERATURE,
            max_tokens=20,
        )
        raw = (response.choices[0].message.content or "").strip()
        # An empty reply has no first line; indexing splitlines() would raise.
        first_line = raw.splitlines()[0] if raw else ""
        text = first_line.strip().strip("'\"` ")
        # Skip matching on an empty string: "" is a substring of every action,
        # so it would otherwise select available[0] arbitrarily.
        if text:
            if text in available:
                return text
            for action in available:
                if action in text or text in action:
                    return action
    except Exception as e:
        # Best effort: any client/transport failure falls through to the
        # deterministic policy below.
        print(f"[DEBUG] LLM call error: {e}", flush=True)

    return choose_action_deterministic(obs)
def choose_action_deterministic(obs):
    """Pick an action with a fixed rule list (no model involved).

    Rules are tried in order and the first one that applies wins:
    no available actions -> ``resolve_incident``; logs missing ->
    ``request_logs``; then the (likely_cause, action) pairs below; then
    ``resolve_incident`` if offered; finally the first available action.

    Args:
        obs: Observation mapping; reads ``available_actions``,
            ``logs_available`` and ``likely_cause``.

    Returns:
        The chosen action string.
    """
    available = obs.get("available_actions", [])
    if not available:
        return "resolve_incident"

    if "request_logs" in available and not obs.get("logs_available", False):
        return "request_logs"

    likely_cause = obs.get("likely_cause", "unknown")
    # Ordered (cause, action) pairs; order matters when one cause maps to
    # several actions (e.g. dns_issue, db_timeout).
    playbook = (
        ("bad_deployment", "rollback_deploy"),
        ("dependency_issue", "query_dependencies"),
        ("ambiguous", "query_region_health"),
        ("dns_issue", "query_dns_status"),
        ("db_timeout", "escalate_db_team"),
        ("dns_issue", "escalate_network_team"),
        ("dns_issue", "broadcast_status_page"),
        ("db_timeout", "restart_service"),
        ("bad_deployment", "restart_service"),
    )
    for cause, action in playbook:
        if likely_cause == cause and action in available:
            return action

    if "resolve_incident" in available:
        return "resolve_incident"
    return available[0]
| def extract_obs(data): | |
| if "observation" in data: | |
| obs = data["observation"] | |
| else: | |
| obs = data | |
| if isinstance(obs, str): | |
| obs = json.loads(obs) | |
| return obs | |
def run_task(client, http, task_id):
    """Run one episode of ``task_id`` and print its [START]/[STEP]/[END] lines.

    Flow: POST ``/reset``, then up to ``MAX_STEPS`` iterations of choosing an
    action via ``choose_action_llm`` and POSTing it to ``/step``, then GET
    ``/grade`` for the final score. If grading fails, success falls back to
    the last observation's ``incident_resolved`` flag and the score to
    ``sum(rewards) / 5.0`` clamped to [0, 1].

    Any exception is printed with a traceback and swallowed so that the
    ``[END]`` line is always emitted (with whatever was collected so far).

    Args:
        client: LLM client forwarded to ``choose_action_llm``.
        http: HTTP client exposing ``post``/``get`` (httpx-style responses
            with ``raise_for_status()`` and ``json()``).
        task_id: Identifier of the task to reset and grade.
    """
    rewards = []
    steps_taken = 0
    success = False
    score = 0.0
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        r = http.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30.0)
        r.raise_for_status()
        obs = extract_obs(r.json())
        finished = obs.get("done", False) or obs.get("incident_resolved", False)

        for step in range(1, MAX_STEPS + 1):
            if finished:
                break
            action_name = choose_action_llm(client, obs)
            r = http.post(
                f"{ENV_URL}/step",
                json={"action": {"action": action_name}},
                timeout=30.0,
            )
            r.raise_for_status()
            step_data = r.json()
            obs = extract_obs(step_data)

            # A JSON null reward decodes to None, and dict.get() only applies
            # its default when the key is absent. Treat None like "missing"
            # at each level so float() never receives None.
            raw_reward = step_data.get("reward")
            if raw_reward is None:
                raw_reward = obs.get("reward")
            reward = float(raw_reward) if raw_reward is not None else 0.0

            finished = bool(
                step_data.get("done", obs.get("done", False))
                or obs.get("incident_resolved", False)
            )
            rewards.append(reward)
            steps_taken = step
            log_step(step, action_name, reward, finished, None)

        try:
            r = http.get(f"{ENV_URL}/grade", params={"task_id": task_id}, timeout=30.0)
            r.raise_for_status()
            grade = r.json()
            score = float(grade.get("score", 0.0))
            success = bool(grade.get("success", False))
        except Exception as e:
            # Grader unavailable: derive a local estimate instead of failing.
            print(f"[DEBUG] Grade error: {e}", flush=True)
            success = bool(obs.get("incident_resolved", False))
            score = max(0.0, min(1.0, sum(rewards) / 5.0))
    except Exception as e:
        print(f"[DEBUG] Error in task {task_id}: {e}", flush=True)
        traceback.print_exc()
    finally:
        log_end(success, steps_taken, score, rewards)
def main():
    """Run every task in ``TASK_IDS`` against the environment server.

    Exits with status 1 when neither API_KEY nor HF_TOKEN yielded a key.
    If the ``/tasks`` probe fails, a ``[START]``/``[END]`` pair with a failed,
    zero-step result is printed for each task and the function returns
    without running anything.

    The HTTP client is used as a context manager so its connections are
    released on every exit path, including the early return and an exception
    escaping ``run_task``.
    """
    if not API_KEY:
        print("[ERROR] No API_KEY or HF_TOKEN set!", flush=True)
        sys.exit(1)

    client = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY,
    )

    with httpx.Client() as http:
        # Probe the server once before starting any task.
        try:
            r = http.get(f"{ENV_URL}/tasks", timeout=10.0)
            r.raise_for_status()
        except Exception as e:
            print(f"[ERROR] Server not reachable: {e}", flush=True)
            # Still emit one START/END pair per task so the log contains an
            # entry for every task.
            for tid in TASK_IDS:
                log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
                log_end(False, 0, 0.0, [])
            return

        for task_id in TASK_IDS:
            run_task(client, http, task_id)
if __name__ == "__main__":
    # Script entry point: report any uncaught Exception from main() with a
    # [FATAL] line and a traceback. SystemExit raised by main() is not an
    # Exception subclass and therefore still propagates.
    try:
        main()
    except Exception as exc:
        print(f"[FATAL] {exc}", flush=True)
        traceback.print_exc()