OGrohit committed on
Commit
3358379
·
1 Parent(s): 6fb943b

Day 5: baseline.py, /baseline endpoint, openenv.yaml updated

Browse files
Files changed (3) hide show
  1. baseline.py +395 -0
  2. openenv.yaml +1 -0
  3. server/app.py +54 -2
baseline.py CHANGED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline inference script for LogTriageEnv.
3
+ Uses an LLM agent to play all 3 tasks and produce reproducible scores.
4
+
5
+ Usage:
6
+ # Set API key as environment variable (never hardcode)
7
+ export GROQ_API_KEY=your_key_here # Linux/Mac
8
+ set GROQ_API_KEY=your_key_here # Windows CMD
9
+ $env:GROQ_API_KEY="your_key_here" # Windows PowerShell
10
+
11
+ python baseline.py
12
+
13
+ Environment variables:
14
+ GROQ_API_KEY - Groq API key (primary)
15
+ NVIDIA_API_KEY - NVIDIA NIM API key (fallback)
16
+ OPENROUTER_API_KEY - OpenRouter API key (fallback)
17
+ OPENAI_API_KEY - OpenAI API key (fallback)
18
+ ENV_URL - Base URL of deployed environment (default: http://localhost:7860)
19
+ """
20
+ from __future__ import annotations
21
+ import os
22
+ import json
23
+ import time
24
+ import requests
25
+ from openai import OpenAI
26
+
27
# --- Provider configuration --------------------------------------------------
# Change PROVIDER to switch LLM backends. Nothing else changes.

PROVIDER = "groq"  # options: "groq", "nvidia", "openrouter", "openai"

# One entry per OpenAI-compatible backend:
#   base_url    - API root handed to the OpenAI client (None = client default)
#   api_key_env - name of the environment variable that holds the API key
#   model       - model identifier sent with every chat-completion request
PROVIDERS = {
    "groq": {
        "base_url": "https://api.groq.com/openai/v1",
        "api_key_env": "GROQ_API_KEY",
        "model": "llama-3.3-70b-versatile",
    },
    "nvidia": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "api_key_env": "NVIDIA_API_KEY",
        "model": "openai/gpt-oss-20b",
    },
    "openrouter": {
        "base_url": "https://openrouter.ai/api/v1",
        "api_key_env": "OPENROUTER_API_KEY",
        "model": "meta-llama/llama-3.1-8b-instruct:free",
    },
    "openai": {
        "base_url": None,
        "api_key_env": "OPENAI_API_KEY",
        "model": "gpt-4o-mini",
    },
}

# --- Environment configuration -----------------------------------------------

# An empty ENV_URL falls back to the default, and trailing slashes are
# stripped so that f"{ENV_URL}/reset" never produces a "//reset" path.
ENV_URL = (os.getenv("ENV_URL") or "http://localhost:7860").rstrip("/")
TASKS = ["single_crash", "cascading_failure", "silent_degradation"]
MAX_STEPS_PER_TASK = {
    "single_crash": 8,
    "cascading_failure": 12,
    "silent_degradation": 15,
}
SEED = 42  # fixed seed for reproducibility
60
+
61
# --- System prompt -----------------------------------------------------------
# Sent as the "system" message on every LLM call. It defines the JSON action
# schema the agent must answer with and the allowed values per action type.

SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) performing incident triage.
You will receive log lines from a microservice cluster and must diagnose and resolve the incident.

Available services: api-gateway, auth-service, user-db, payment-service, payment-db, notification-service, email-queue
Available teams: sre-team, backend-team, dba-team, security-team

You must respond with ONLY a valid JSON object in this exact format:
{
"action_type": "<one of: classify_severity, identify_root_cause, escalate, remediate, request_more_logs, resolve, ignore>",
"value": "<depends on action_type>",
"confidence": <float 0.0-1.0>,
"reasoning": "<brief explanation>"
}

Value rules by action_type:
- classify_severity: value must be "P1", "P2", or "P3"
- identify_root_cause: value must be a service name from the list above
- escalate: value must be a team name from the list above
- remediate: value must be "restart:<service>", "rollback:<service>", "scale:<service>", "flush-cache:<service>", or "kill-query:<service>"
- request_more_logs: value must be a service name or "all"
- resolve: value must be "resolved"
- ignore: value must be "noise"

Strategy:
1. Read all log lines carefully
2. Look at system_state for service health (error_rate, latency_p99_ms, status)
3. Identify which service is the ROOT CAUSE (not just a symptom)
4. Classify severity based on actual impact:
- P1: service down or error rate > 5% (customer impact)
- P2: degraded performance, trending toward P1 (no outage yet)
- P3: warning, no immediate impact
5. Apply the correct fix to the ROOT CAUSE service, not symptom services
6. Once you have classified, identified root cause, and remediated — resolve the incident

IMPORTANT: Respond with ONLY the JSON object. No explanation, no markdown, no backticks."""
98
+
99
+
100
+ def _build_user_prompt(obs: dict) -> str:
101
+ """Convert observation dict to a prompt string for the LLM."""
102
+ lines = []
103
+
104
+ # System state summary
105
+ lines.append("=== SYSTEM STATE ===")
106
+ for svc, status in obs.get("system_state", {}).items():
107
+ if isinstance(status, dict):
108
+ s = status.get("status", "unknown")
109
+ er = status.get("error_rate", 0)
110
+ lat = status.get("latency_p99_ms", 0)
111
+ if s != "up" or er > 0.01 or lat > 200:
112
+ lines.append(f" {svc}: {s} | error_rate={er:.1%} | latency_p99={lat}ms")
113
+ lines.append("")
114
+
115
+ # Active alerts
116
+ alerts = obs.get("active_alerts", [])
117
+ if alerts:
118
+ lines.append("=== ACTIVE ALERTS ===")
119
+ for alert in alerts:
120
+ lines.append(f" {alert}")
121
+ lines.append("")
122
+
123
+ # Log lines
124
+ lines.append("=== LOG LINES ===")
125
+ for log in obs.get("logs", []):
126
+ if isinstance(log, dict):
127
+ ts = log.get("timestamp", "")[-8:] # just time part
128
+ level = log.get("level", "INFO")
129
+ svc = log.get("service", "unknown")
130
+ msg = log.get("message", "")
131
+ lines.append(f" [{ts}] {level:<5} {svc:<25} {msg}")
132
+ lines.append("")
133
+
134
+ # Episode context
135
+ lines.append(f"Step: {obs.get('step_count', 0)} | "
136
+ f"Task: {obs.get('task_id', '')} | "
137
+ f"Time elapsed: {obs.get('time_elapsed_seconds', 0)}s")
138
+
139
+ # Feedback from last action
140
+ feedback = obs.get("last_action_feedback", "")
141
+ if feedback and feedback != "Incident detected. Analyze the logs and take action.":
142
+ lines.append(f"Last action feedback: {feedback}")
143
+
144
+ lines.append("")
145
+ lines.append("Based on the above, what is your next triage action? Respond with JSON only.")
146
+ return "\n".join(lines)
147
+
148
+
149
+ def _parse_action(response_text: str) -> dict | None:
150
+ """Parse LLM response into action dict. Returns None if parsing fails."""
151
+ text = response_text.strip()
152
+
153
+ # Strip markdown code blocks if present
154
+ if text.startswith("```"):
155
+ lines = text.split("\n")
156
+ text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
157
+
158
+ try:
159
+ action = json.loads(text)
160
+ # Validate required fields
161
+ if "action_type" not in action or "value" not in action:
162
+ return None
163
+ # Ensure confidence and reasoning exist
164
+ action.setdefault("confidence", 0.8)
165
+ action.setdefault("reasoning", "")
166
+ return action
167
+ except json.JSONDecodeError:
168
+ # Try to extract JSON from text
169
+ import re
170
+ match = re.search(r'\{[^{}]+\}', text, re.DOTALL)
171
+ if match:
172
+ try:
173
+ return json.loads(match.group())
174
+ except json.JSONDecodeError:
175
+ return None
176
+ return None
177
+
178
+
179
+ def _get_fallback_action(obs: dict, step: int) -> dict:
180
+ """
181
+ Fallback action when LLM fails to produce valid JSON.
182
+ Uses simple heuristics to make a reasonable action.
183
+ """
184
+ system_state = obs.get("system_state", {})
185
+ task_id = obs.get("task_id", "")
186
+
187
+ # Find the most degraded service
188
+ worst_service = None
189
+ worst_error_rate = 0
190
+ for svc, status in system_state.items():
191
+ if isinstance(status, dict):
192
+ er = status.get("error_rate", 0)
193
+ if er > worst_error_rate:
194
+ worst_error_rate = er
195
+ worst_service = svc
196
+
197
+ if step == 0:
198
+ return {"action_type": "classify_severity", "value": "P1", "confidence": 0.5, "reasoning": "fallback"}
199
+ elif step == 1 and worst_service:
200
+ return {"action_type": "identify_root_cause", "value": worst_service, "confidence": 0.5, "reasoning": "fallback"}
201
+ elif step == 2 and worst_service:
202
+ return {"action_type": "remediate", "value": f"restart:{worst_service}", "confidence": 0.5, "reasoning": "fallback"}
203
+ else:
204
+ return {"action_type": "resolve", "value": "resolved", "confidence": 0.5, "reasoning": "fallback"}
205
+
206
+
207
def _post_json(path: str, **kwargs) -> dict:
    """POST to the environment at ``ENV_URL + path`` and return its JSON object.

    Raises:
        requests.RequestException: on connection errors or non-2xx responses.
        ValueError: when the body is not JSON or is not a JSON object.
    """
    resp = requests.post(f"{ENV_URL}{path}", timeout=30, **kwargs)
    resp.raise_for_status()
    payload = resp.json()
    if not isinstance(payload, dict):
        raise ValueError(f"expected a JSON object from {path}, got {type(payload).__name__}")
    return payload


def run_task(client: OpenAI, model: str, task_id: str, seed: int = 42) -> dict:
    """
    Run one complete episode for a given task.
    Returns dict with score, steps, and breakdown.

    The episode is reset via POST /reset, then the LLM is asked for one action
    per step until the environment reports ``done`` or the per-task step
    budget is used up. If the LLM call fails or its answer cannot be parsed,
    a heuristic fallback action is used for that step. The final score comes
    from POST /grader, falling back to the last observation's
    ``cumulative_score`` when the grader call fails.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model: Model identifier passed to the client.
        task_id: Task name sent to the environment's /reset endpoint.
        seed: Seed sent to /reset for reproducibility.

    Returns:
        Dict with ``task_id``, ``score``, ``steps_taken`` and ``breakdown``.
        When the reset fails, the same keys are present (score 0.0, zero
        steps) plus an ``error`` message.
    """
    print(f"\n Running task: {task_id}...")

    # Reset environment
    try:
        obs = _post_json("/reset", params={"task": task_id, "seed": seed})
    except (requests.RequestException, ValueError) as e:
        print(f" ERROR: Failed to reset environment: {e}")
        # Same keys as the success path so report code can index them safely.
        return {
            "task_id": task_id,
            "score": 0.0,
            "steps_taken": 0,
            "breakdown": {},
            "error": str(e),
        }

    max_steps = MAX_STEPS_PER_TASK.get(task_id, 10)
    # Completed user/assistant pairs only, so the window always starts with a
    # user message and never holds a prompt that received no answer.
    conversation_history = []
    steps_taken = 0
    done = obs.get("done", False)

    while not done and steps_taken < max_steps:
        # Build prompt from observation
        user_message = {"role": "user", "content": _build_user_prompt(obs)}

        # Call LLM
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    *conversation_history,
                    user_message,
                ],
                max_tokens=200,
                temperature=0,  # deterministic
            )
            # content may be None (e.g. refusals); treat as an empty answer.
            response_text = response.choices[0].message.content or ""
            action = _parse_action(response_text)
        except Exception as e:  # provider SDK errors vary; any failure -> fallback
            print(f" Step {steps_taken}: LLM call failed ({e}), using fallback")
            action = _get_fallback_action(obs, steps_taken)
        else:
            # Keep last 4 exchanges (8 messages) for context.
            conversation_history.append(user_message)
            conversation_history.append({"role": "assistant", "content": response_text})
            conversation_history = conversation_history[-8:]
            if action is None:
                print(f" Step {steps_taken}: LLM parse failed, using fallback")
                action = _get_fallback_action(obs, steps_taken)

        # Take action in environment
        try:
            obs = _post_json("/step", json=action)
        except (requests.RequestException, ValueError) as e:
            print(f" Step {steps_taken}: Environment step failed: {e}")
            break

        done = obs.get("done", False)
        # Null reward/feedback must not crash the progress line.
        reward = obs.get("reward") or 0.0
        feedback = str(obs.get("last_action_feedback") or "")
        print(f" Step {steps_taken}: {action['action_type']}({action['value']}) "
              f"-> reward={reward:+.2f} | {feedback[:60]}")

        steps_taken += 1
        time.sleep(0.1)  # small delay to avoid rate limits

    # Get official grader score
    try:
        grader_result = _post_json("/grader")
        score = grader_result.get("score", 0.0)
        breakdown = grader_result.get("breakdown", {})
    except (requests.RequestException, ValueError) as e:
        print(f" ERROR: Grader call failed: {e}")
        score = obs.get("cumulative_score", 0.0)
        breakdown = {}
    if not isinstance(score, (int, float)):
        score = 0.0

    print(f" Final score: {score:.4f} ({steps_taken} steps)")
    return {
        "task_id": task_id,
        "score": score,
        "steps_taken": steps_taken,
        "breakdown": breakdown,
    }
306
+
307
+
308
def main():
    """Run baseline agent against all 3 tasks and report scores.

    Builds an OpenAI-compatible client for the configured PROVIDER, checks
    that the environment answers on /health, runs every task in TASKS with
    the fixed SEED, prints a human-readable report and finally prints the
    same results as JSON after a marker line.

    Returns:
        Dict with ``provider``, ``model``, ``seed``, ``results`` and
        ``average_score``.

    Raises:
        ValueError: if the provider's API key environment variable is unset.
        RuntimeError: if the environment health check fails.
    """

    # -- Setup provider -------------------------------------------------------
    provider_config = PROVIDERS[PROVIDER]
    key_env = provider_config["api_key_env"]
    api_key = os.environ.get(key_env)
    model = provider_config["model"]
    base_url = provider_config["base_url"]

    if not api_key:
        raise ValueError(
            f"API key not found. Set environment variable: {key_env}\n"
            f" Windows PowerShell: $env:{key_env}='your_key'\n"
            f" Windows CMD: set {key_env}={'your_key'}"
        )

    # Build OpenAI-compatible client; base_url is omitted for the default API.
    client_kwargs = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
    client = OpenAI(**client_kwargs)

    print("=" * 60)
    print("LogTriageEnv — Baseline Inference Script")
    print("=" * 60)
    print(f"Provider: {PROVIDER}")
    print(f"Model: {model}")
    print(f"Environment: {ENV_URL}")
    print(f"Seed: {SEED}")
    print(f"Tasks: {', '.join(TASKS)}")
    print("=" * 60)

    # -- Verify environment is running ----------------------------------------
    try:
        health = requests.get(f"{ENV_URL}/health", timeout=10)
        health.raise_for_status()
    except requests.RequestException as e:
        raise RuntimeError(
            f"Environment not responding at {ENV_URL}\n"
            f"Start it with: python -m uvicorn server.app:app --port 7860\n"
            f"Error: {e}"
        ) from e
    print("Environment health: OK")

    # -- Run all tasks --------------------------------------------------------
    results = [run_task(client, model, task_id, seed=SEED) for task_id in TASKS]

    # -- Print final report ---------------------------------------------------
    print("\n" + "=" * 60)
    print("BASELINE RESULTS")
    print("=" * 60)

    total_score = 0.0
    for result in results:
        task = result.get("task_id", "")
        score = result.get("score") or 0.0
        # A task whose reset failed may not report a step count.
        steps = result.get("steps_taken", 0)
        total_score += score
        # Clamp so an out-of-range score cannot produce a malformed bar.
        filled = max(0, min(20, int(score * 20)))
        bar = "#" * filled + "-" * (20 - filled)
        print(f"{task:<25} {score:.4f} [{bar}] ({steps} steps)")
        for k, v in (result.get("breakdown") or {}).items():
            print(f" {k:<20} {v}")

    avg_score = total_score / len(results) if results else 0.0
    print("-" * 60)
    print(f"{'AVERAGE':<25} {avg_score:.4f}")
    print("=" * 60)

    # -- Machine-readable output ----------------------------------------------
    output = {
        "provider": PROVIDER,
        "model": model,
        "seed": SEED,
        "results": results,
        "average_score": round(avg_score, 4),
    }
    # The server's /baseline endpoint locates the JSON by this exact marker line.
    print("\nJSON Output (for /baseline endpoint):")
    print(json.dumps(output, indent=2))

    return output
392
+
393
+
394
+ if __name__ == "__main__":
395
+ main()
openenv.yaml CHANGED
@@ -6,6 +6,7 @@ description: >
6
  and must diagnose, prioritize, and resolve incidents across 3 tasks
7
  of increasing difficulty.
8
  author: Rohit Patil
 
9
  tags:
10
  - openenv
11
  - sre
 
6
  and must diagnose, prioritize, and resolve incidents across 3 tasks
7
  of increasing difficulty.
8
  author: Rohit Patil
9
+ space_url: https://ogrohit-logtriage-env.hf.space
10
  tags:
11
  - openenv
12
  - sre
server/app.py CHANGED
@@ -1,6 +1,7 @@
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import JSONResponse
3
  import uvicorn
 
4
 
5
  from server.models import TriageAction
6
  from server.environment import LogTriageEnvironment
@@ -114,8 +115,59 @@ def grader():
114
 
115
  @app.post("/baseline")
116
  def baseline():
117
- # TODO Day 5: wire to baseline.py
118
- return {"message": "baseline endpoint β€” to be wired on Day 5"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import JSONResponse
3
  import uvicorn
4
+ import os
5
 
6
  from server.models import TriageAction
7
  from server.environment import LogTriageEnvironment
 
115
 
116
@app.post("/baseline")
def baseline():
    """
    Run the baseline inference script against all 3 tasks.
    Returns scores for each task produced by the LLM agent.
    Note: Requires GROQ_API_KEY (or other provider key) to be set.

    ``baseline.py`` is run as a subprocess from the directory above this
    package, inheriting the server's environment variables. Its stdout is
    scanned for the marker line "JSON Output (for /baseline endpoint):" and
    everything after that line is decoded as the JSON response.

    Returns:
        The decoded JSON report on success; a plain message with the tail of
        stdout when the marker is missing; a 500 response when the script
        fails or prints invalid JSON; a 504 response on timeout.
    """
    import json as json_lib
    import subprocess
    import sys

    marker = "JSON Output (for /baseline endpoint):"
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    try:
        result = subprocess.run(
            [sys.executable, "baseline.py"],
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
            cwd=project_root,
            env=os.environ.copy(),  # pass provider API keys through
        )
    except subprocess.TimeoutExpired:
        return JSONResponse(status_code=504, content={"error": "Baseline timed out after 5 minutes"})
    except OSError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    if result.returncode != 0:
        return JSONResponse(
            status_code=500,
            content={
                "error": "Baseline script failed",
                # Only the tail: tracebacks end with the actual error.
                "stderr": result.stderr[-500:] if result.stderr else "",
            },
        )

    # Extract JSON from output: everything after the marker line.
    output_lines = result.stdout.strip().split("\n")
    marker_index = next(
        (i for i, line in enumerate(output_lines) if line.strip() == marker),
        None,
    )
    if marker_index is None or marker_index + 1 >= len(output_lines):
        return {"message": "Baseline completed", "output": result.stdout[-1000:]}

    try:
        return json_lib.loads("\n".join(output_lines[marker_index + 1:]))
    except json_lib.JSONDecodeError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
171
 
172
 
173
  if __name__ == "__main__":