Pramod Basavaraj Menasi committed on
Commit
18f9f38
·
1 Parent(s): 66ae73a

fixed errors

Browse files
Files changed (1) hide show
  1. inference.py +106 -41
inference.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  import sys
4
  import traceback
5
 
6
- print("[DEBUG] line 6", flush=True)
 
7
 
8
  try:
9
  from dotenv import load_dotenv
@@ -11,26 +12,29 @@ try:
11
  except ImportError:
12
  pass
13
 
14
- print("[DEBUG] line 14", flush=True)
15
-
16
- import httpx
17
-
18
- print("[DEBUG] line 18", flush=True)
19
-
20
- from openai import OpenAI
21
-
22
- print("[DEBUG] line 22", flush=True)
23
 
24
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
25
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
26
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
27
  BENCHMARK = "incidentops_env"
28
  TASK_IDS = ["incident_easy", "incident_medium", "incident_hard"]
29
- ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
30
  MAX_STEPS = 12
31
  TEMPERATURE = 0.2
32
 
33
- print("[DEBUG] line 33", flush=True)
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  def log_start(task, env, model):
@@ -48,7 +52,59 @@ def log_end(success, steps, score, rewards):
48
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
49
 
50
 
51
- def choose_action(obs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  available = obs.get("available_actions", [])
53
  logs_available = obs.get("logs_available", False)
54
  likely_cause = obs.get("likely_cause", "unknown")
@@ -88,8 +144,7 @@ def extract_obs(data):
88
  return obs
89
 
90
 
91
- def run_task(http, task_id):
92
- print(f"[DEBUG] Starting task: {task_id}", flush=True)
93
  rewards = []
94
  steps_taken = 0
95
  success = False
@@ -98,10 +153,10 @@ def run_task(http, task_id):
98
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
99
 
100
  try:
 
101
  r = http.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30.0)
102
  r.raise_for_status()
103
  obs = extract_obs(r.json())
104
- print(f"[DEBUG] Reset OK: cause={obs.get('likely_cause')}", flush=True)
105
 
106
  finished = obs.get("done", False) or obs.get("incident_resolved", False)
107
 
@@ -109,9 +164,10 @@ def run_task(http, task_id):
109
  if finished:
110
  break
111
 
112
- action_name = choose_action(obs)
113
- print(f"[DEBUG] Step {step}: {action_name}", flush=True)
114
 
 
115
  r = http.post(
116
  f"{ENV_URL}/step",
117
  json={"action": {"action": action_name}},
@@ -131,49 +187,58 @@ def run_task(http, task_id):
131
  steps_taken = step
132
  log_step(step, action_name, reward, finished, None)
133
 
134
- r = http.get(f"{ENV_URL}/grade", params={"task_id": task_id}, timeout=30.0)
135
- r.raise_for_status()
136
- grade = r.json()
137
- score = float(grade.get("score", 0.0))
138
- success = bool(grade.get("success", False))
139
- print(f"[DEBUG] Grade: {grade}", flush=True)
 
 
 
 
 
140
 
141
  except Exception as e:
142
- print(f"[DEBUG] Error: {e}", flush=True)
143
  traceback.print_exc()
144
 
145
  finally:
146
  log_end(success, steps_taken, score, rewards)
147
 
148
 
149
- print("[DEBUG] line 137 - about to define main", flush=True)
150
-
151
-
152
  def main():
153
- print(f"[DEBUG] main() called", flush=True)
154
- print(f"[DEBUG] ENV_URL={ENV_URL}", flush=True)
 
 
 
 
 
 
 
155
 
156
  http = httpx.Client()
157
 
 
158
  try:
159
  r = http.get(f"{ENV_URL}/tasks", timeout=10.0)
160
- print(f"[DEBUG] Server OK: {r.status_code}", flush=True)
161
  except Exception as e:
162
- print(f"[ERROR] Server not running: {e}", flush=True)
 
 
 
163
  return
164
 
 
165
  for task_id in TASK_IDS:
166
- run_task(http, task_id)
167
 
168
  http.close()
169
- print("[DEBUG] Done!", flush=True)
170
-
171
 
172
- print("[DEBUG] line 160 - about to check name", flush=True)
173
- print(f"[DEBUG] name = {__name__}", flush=True)
174
 
175
  if __name__ == "__main__":
176
- print("[DEBUG] entering main()", flush=True)
177
  try:
178
  main()
179
  except Exception as e:
 
3
  import sys
4
  import traceback
5
 
6
+ import httpx
7
+ from openai import OpenAI
8
 
9
  try:
10
  from dotenv import load_dotenv
 
12
  except ImportError:
13
  pass
14
 
15
+ # MANDATORY: Use the injected environment variables
16
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
17
+ API_KEY = os.environ.get("API_KEY", "") or os.environ.get("HF_TOKEN", "")
18
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 
 
 
 
 
19
 
 
 
 
20
  BENCHMARK = "incidentops_env"
21
  TASK_IDS = ["incident_easy", "incident_medium", "incident_hard"]
22
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
23
  MAX_STEPS = 12
24
  TEMPERATURE = 0.2
25
 
26
+ SYSTEM_PROMPT = """You are an expert incident-response engineer.
27
+ You are given an incident observation with alert details, severity, affected services, and available actions.
28
+ Analyze the situation and choose the BEST single action from the available_actions list.
29
+
30
+ Rules:
31
+ - If logs are not available, request_logs first
32
+ - Investigate before escalating
33
+ - Escalate to the correct team based on evidence
34
+ - Resolve only when the incident is actually fixed
35
+ - Minimize steps to stay within SLA
36
+
37
+ Return ONLY the action string, nothing else. No explanation, no quotes."""
38
 
39
 
40
  def log_start(task, env, model):
 
52
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
53
 
54
 
55
+ def choose_action_llm(client, obs):
56
+ """Always call the LLM first, fall back to deterministic only on error."""
57
+ available = obs.get("available_actions", [])
58
+ if not available:
59
+ return "resolve_incident"
60
+
61
+ obs_for_llm = {
62
+ "alert_summary": obs.get("alert_summary", ""),
63
+ "severity": obs.get("severity", ""),
64
+ "likely_cause": obs.get("likely_cause", ""),
65
+ "hf_confidence": obs.get("hf_confidence", 0.0),
66
+ "logs_available": obs.get("logs_available", False),
67
+ "log_snippet": obs.get("log_snippet", ""),
68
+ "services_affected": obs.get("services_affected", []),
69
+ "elapsed_steps": obs.get("elapsed_steps", 0),
70
+ "sla_steps_remaining": obs.get("sla_steps_remaining", 0),
71
+ "action_history": obs.get("action_history", []),
72
+ "available_actions": available,
73
+ "incident_resolved": obs.get("incident_resolved", False),
74
+ "wrong_escalations": obs.get("wrong_escalations", 0),
75
+ }
76
+
77
+ try:
78
+ response = client.chat.completions.create(
79
+ model=MODEL_NAME,
80
+ messages=[
81
+ {"role": "system", "content": SYSTEM_PROMPT},
82
+ {"role": "user", "content": json.dumps(obs_for_llm)},
83
+ ],
84
+ temperature=TEMPERATURE,
85
+ max_tokens=20,
86
+ )
87
+ text = (response.choices[0].message.content or "").strip()
88
+ # Clean up response - take first line, remove quotes
89
+ text = text.splitlines()[0].strip().strip("'\"` ")
90
+
91
+ if text in available:
92
+ return text
93
+
94
+ # Try partial match
95
+ for action in available:
96
+ if action in text or text in action:
97
+ return action
98
+
99
+ except Exception as e:
100
+ print(f"[DEBUG] LLM call error: {e}", flush=True)
101
+
102
+ # Deterministic fallback only if LLM fails
103
+ return choose_action_deterministic(obs)
104
+
105
+
106
+ def choose_action_deterministic(obs):
107
+ """Fallback deterministic policy."""
108
  available = obs.get("available_actions", [])
109
  logs_available = obs.get("logs_available", False)
110
  likely_cause = obs.get("likely_cause", "unknown")
 
144
  return obs
145
 
146
 
147
+ def run_task(client, http, task_id):
 
148
  rewards = []
149
  steps_taken = 0
150
  success = False
 
153
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
154
 
155
  try:
156
+ # RESET
157
  r = http.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30.0)
158
  r.raise_for_status()
159
  obs = extract_obs(r.json())
 
160
 
161
  finished = obs.get("done", False) or obs.get("incident_resolved", False)
162
 
 
164
  if finished:
165
  break
166
 
167
+ # ALWAYS call LLM (required by validator)
168
+ action_name = choose_action_llm(client, obs)
169
 
170
+ # STEP
171
  r = http.post(
172
  f"{ENV_URL}/step",
173
  json={"action": {"action": action_name}},
 
187
  steps_taken = step
188
  log_step(step, action_name, reward, finished, None)
189
 
190
+ # GRADE
191
+ try:
192
+ r = http.get(f"{ENV_URL}/grade", params={"task_id": task_id}, timeout=30.0)
193
+ r.raise_for_status()
194
+ grade = r.json()
195
+ score = float(grade.get("score", 0.0))
196
+ success = bool(grade.get("success", False))
197
+ except Exception as e:
198
+ print(f"[DEBUG] Grade error: {e}", flush=True)
199
+ success = obs.get("incident_resolved", False)
200
+ score = max(0.0, min(1.0, sum(rewards) / 5.0))
201
 
202
  except Exception as e:
203
+ print(f"[DEBUG] Error in task {task_id}: {e}", flush=True)
204
  traceback.print_exc()
205
 
206
  finally:
207
  log_end(success, steps_taken, score, rewards)
208
 
209
 
 
 
 
210
  def main():
211
+ if not API_KEY:
212
+ print("[ERROR] No API_KEY or HF_TOKEN set!", flush=True)
213
+ sys.exit(1)
214
+
215
+ # Initialize OpenAI client with injected credentials
216
+ client = OpenAI(
217
+ base_url=API_BASE_URL,
218
+ api_key=API_KEY,
219
+ )
220
 
221
  http = httpx.Client()
222
 
223
+ # Health check
224
  try:
225
  r = http.get(f"{ENV_URL}/tasks", timeout=10.0)
226
+ r.raise_for_status()
227
  except Exception as e:
228
+ print(f"[ERROR] Server not reachable: {e}", flush=True)
229
+ for tid in TASK_IDS:
230
+ log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
231
+ log_end(False, 0, 0.0, [])
232
  return
233
 
234
+ # Run all 3 tasks
235
  for task_id in TASK_IDS:
236
+ run_task(client, http, task_id)
237
 
238
  http.close()
 
 
239
 
 
 
240
 
241
  if __name__ == "__main__":
 
242
  try:
243
  main()
244
  except Exception as e: