# incidentops-env / inference.py
# Commit 61e0797 ("fixed small error") by Pramod Basavaraj Menasi
import json
import os
import sys
import traceback
import httpx
from openai import OpenAI
# Optional: load variables from a local .env file when python-dotenv is
# installed; silently skip otherwise so the script still runs without it.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
# --- Configuration (each value overridable via environment variables) ---
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
# API_KEY takes precedence; HF_TOKEN is accepted as a fallback credential.
API_KEY = os.environ.get("API_KEY", "") or os.environ.get("HF_TOKEN", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK = "incidentops_env"
# Task ids the benchmark server exposes, run in this order.
TASK_IDS = ["incident_easy", "incident_medium", "incident_hard"]
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
MAX_STEPS = 12  # hard cap on environment steps per task
TEMPERATURE = 0.2  # low temperature for near-deterministic action picks
# System message sent on every LLM call; the model must reply with exactly
# one action string (parsed/validated in choose_action_llm).
SYSTEM_PROMPT = """You are an expert incident-response engineer.
You are given an incident observation with alert details, severity, affected services, and available actions.
Analyze the situation and choose the BEST single action from the available_actions list.
Rules:
- If logs are not available, request_logs first
- Investigate before escalating
- Escalate to the correct team based on evidence
- Resolve only when the incident is actually fixed
- Minimize steps to stay within SLA
Return ONLY the action string, nothing else. No explanation, no quotes."""
def log_start(task, env, model):
    """Emit the [START] marker line identifying task, environment, and model."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(step, action, reward, done, error):
    """Emit one [STEP] marker line; a falsy error is printed as 'null'."""
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error or 'null'}"
    )
    print(line, flush=True)
def log_end(success, steps, score, rewards):
    """Emit the final [END] marker line with a comma-joined reward list."""
    joined = ",".join(format(r, ".2f") for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={joined}",
        flush=True,
    )
def choose_action_llm(client, obs):
    """Pick the next action by asking the LLM; fall back to the deterministic
    policy on any API error or unusable reply.

    Parameters:
        client: OpenAI-compatible chat client.
        obs: observation dict from the environment.

    Returns:
        An action string, drawn from obs["available_actions"] whenever the
        model produces a recognizable one.
    """
    available = obs.get("available_actions", [])
    if not available:
        # Nothing to choose from; resolving is the only sensible default.
        return "resolve_incident"
    # Forward only the fields the system prompt's rules reason about.
    obs_for_llm = {
        "alert_summary": obs.get("alert_summary", ""),
        "severity": obs.get("severity", ""),
        "likely_cause": obs.get("likely_cause", ""),
        "hf_confidence": obs.get("hf_confidence", 0.0),
        "logs_available": obs.get("logs_available", False),
        "log_snippet": obs.get("log_snippet", ""),
        "services_affected": obs.get("services_affected", []),
        "elapsed_steps": obs.get("elapsed_steps", 0),
        "sla_steps_remaining": obs.get("sla_steps_remaining", 0),
        "action_history": obs.get("action_history", []),
        "available_actions": available,
        "incident_resolved": obs.get("incident_resolved", False),
        "wrong_escalations": obs.get("wrong_escalations", 0),
    }
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(obs_for_llm)},
            ],
            temperature=TEMPERATURE,
            max_tokens=20,
        )
        text = (response.choices[0].message.content or "").strip()
        # BUG FIX: a whitespace-only reply made splitlines() return [] and
        # raised IndexError here; guard before taking the first line.
        if text:
            # Keep only the first line and strip wrapping quotes/backticks.
            text = text.splitlines()[0].strip().strip("'\"` ")
        if text in available:
            return text
        # BUG FIX: fuzzy-match only a non-empty reply — an empty string is a
        # substring of every action and would wrongly select the first one.
        if text:
            for action in available:
                if action in text or text in action:
                    return action
    except Exception as e:
        print(f"[DEBUG] LLM call error: {e}", flush=True)
    return choose_action_deterministic(obs)
def choose_action_deterministic(obs):
    """Rule-based backup policy.

    Gathers logs first, then picks the highest-priority action that matches
    the diagnosed cause, preferring investigation over escalation.
    """
    available = obs.get("available_actions", [])
    if not available:
        return "resolve_incident"
    if not obs.get("logs_available", False) and "request_logs" in available:
        return "request_logs"
    cause = obs.get("likely_cause", "unknown")
    # (diagnosed cause, candidate action) pairs in strict priority order.
    priorities = (
        ("bad_deployment", "rollback_deploy"),
        ("dependency_issue", "query_dependencies"),
        ("ambiguous", "query_region_health"),
        ("dns_issue", "query_dns_status"),
        ("db_timeout", "escalate_db_team"),
        ("dns_issue", "escalate_network_team"),
        ("dns_issue", "broadcast_status_page"),
    )
    for expected, action in priorities:
        if cause == expected and action in available:
            return action
    if cause in ("db_timeout", "bad_deployment") and "restart_service" in available:
        return "restart_service"
    if "resolve_incident" in available:
        return "resolve_incident"
    # available is known non-empty here; fall back to its first entry.
    return available[0]
def extract_obs(data):
    """Pull the observation dict out of a server payload.

    Accepts either {"observation": ...} or a bare observation dict; a
    JSON-encoded string observation is decoded before returning.
    """
    obs = data["observation"] if "observation" in data else data
    return json.loads(obs) if isinstance(obs, str) else obs
def run_task(client, http, task_id):
    """Run one benchmark task end-to-end: reset, step loop, grade, log.

    Parameters:
        client: OpenAI-compatible chat client used for action selection.
        http: shared httpx.Client for talking to the environment server.
        task_id: environment task identifier (e.g. "incident_easy").

    Always emits [START]/[STEP]/[END] marker lines; never raises (errors are
    logged and the task ends with whatever partial results were collected).
    """
    rewards = []
    steps_taken = 0
    success = False
    score = 0.0
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Reset the environment for this task and take the first observation.
        r = http.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30.0)
        r.raise_for_status()
        obs = extract_obs(r.json())
        finished = obs.get("done", False) or obs.get("incident_resolved", False)
        for step in range(1, MAX_STEPS + 1):
            if finished:
                break
            action_name = choose_action_llm(client, obs)
            r = http.post(
                f"{ENV_URL}/step",
                json={"action": {"action": action_name}},
                timeout=30.0,
            )
            r.raise_for_status()
            step_data = r.json()
            obs = extract_obs(step_data)
            # reward/done may live on the top-level payload or inside the
            # observation itself; check both, preferring the top level.
            reward = float(step_data.get("reward", obs.get("reward", 0.0)))
            finished = bool(
                step_data.get("done", obs.get("done", False))
                or obs.get("incident_resolved", False)
            )
            rewards.append(reward)
            steps_taken = step
            log_step(step, action_name, reward, finished, None)
        try:
            # Ask the server for the official grade.
            r = http.get(f"{ENV_URL}/grade", params={"task_id": task_id}, timeout=30.0)
            r.raise_for_status()
            grade = r.json()
            score = float(grade.get("score", 0.0))
            success = bool(grade.get("success", False))
        except Exception as e:
            # Grading endpoint unavailable: approximate from the episode.
            print(f"[DEBUG] Grade error: {e}", flush=True)
            success = obs.get("incident_resolved", False)
            # NOTE(review): the /5.0 looks like a max-total-reward
            # normalizer — confirm against the environment's reward scale.
            score = max(0.0, min(1.0, sum(rewards) / 5.0))
    except Exception as e:
        print(f"[DEBUG] Error in task {task_id}: {e}", flush=True)
        traceback.print_exc()
    finally:
        # Always emit the [END] marker, even on error paths.
        log_end(success, steps_taken, score, rewards)
def main():
    """Entry point: validate credentials, check server health, run all tasks.

    Exits with status 1 when no API credential is configured. When the
    environment server is unreachable, emits zero-score START/END markers
    for every task so the harness still records results.
    """
    if not API_KEY:
        print("[ERROR] No API_KEY or HF_TOKEN set!", flush=True)
        sys.exit(1)
    client = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY,
    )
    http = httpx.Client()
    try:
        try:
            # Health check: the /tasks listing doubles as a reachability probe.
            r = http.get(f"{ENV_URL}/tasks", timeout=10.0)
            r.raise_for_status()
        except Exception as e:
            print(f"[ERROR] Server not reachable: {e}", flush=True)
            for tid in TASK_IDS:
                log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
                log_end(False, 0, 0.0, [])
            return
        for task_id in TASK_IDS:
            run_task(client, http, task_id)
    finally:
        # BUG FIX: the original leaked the httpx client on the early
        # "server not reachable" return; close it on every exit path.
        http.close()
# Script entry point: catch anything main() lets escape so the harness still
# gets a [FATAL] line plus a traceback instead of a bare interpreter crash.
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"[FATAL] {e}", flush=True)
        traceback.print_exc()