Elliot89 committed on
Commit
e4ab55f
Β·
verified Β·
1 Parent(s): 2b5d42a

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -204
agent.py DELETED
@@ -1,204 +0,0 @@
1
"""
inference.py — OpenEnv Hackathon baseline inference script.

Required env vars (set in HF Space secrets or .env):
  API_BASE_URL   OpenAI-compatible LLM endpoint
  MODEL_NAME     Model identifier
  HF_TOKEN       API key for the LLM endpoint

Runs the agent against all 3 tasks × 2 scenarios each.
Final stdout line is valid JSON — required by the hackathon validator.

Usage:
  export API_BASE_URL="https://api.groq.com/openai/v1"
  export MODEL_NAME="llama-3.1-8b-instant"
  export HF_TOKEN="gsk_your_key_here"
  python inference.py
"""

from __future__ import annotations

import json
import os
import sys

import requests
from openai import OpenAI
from dotenv import load_dotenv

# Load a local .env file (if present) before reading any configuration.
load_dotenv()

# ── Config from env vars (hackathon required names) ──────────────────────────
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "llama-3.1-8b-instant")
HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
# Base URL of the incident environment server (serves /reset, /step, /grader).
ENV_BASE_URL: str = os.environ.get("ENV_BASE_URL", "http://localhost:7860")

# Warn (but do not abort) when no API key is configured — the script can still
# start, and episode errors are caught per-run in main().
if not HF_TOKEN:
    print("[WARN] HF_TOKEN is not set β€” LLM calls will fail.", file=sys.stderr)

# OpenAI-compatible client pointed at whatever endpoint API_BASE_URL names.
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
41
-
42
# ── System prompt ─────────────────────────────────────────────────────────────
# Instructs the model to emit exactly one JSON action object per turn.
# NOTE(review): the action schemas listed here must stay in sync with what the
# environment's /step endpoint accepts — confirm against the env server.
SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) responding to a live production incident.

You receive an incident observation as JSON. Respond with ONLY a single valid JSON action object β€” no markdown, no explanation.

Available action_types and their parameters:
Diagnostic (gather info):
{"action_type": "query_logs", "parameters": {"service": "<name>"}}
{"action_type": "check_metrics", "parameters": {"service": "<name>"}}
{"action_type": "check_dependencies", "parameters": {"service": "<name>"}}
{"action_type": "check_recent_deploys", "parameters": {"service": "<name>"}}
{"action_type": "check_service_status", "parameters": {"service": "<name>"}}

Remediation (fix the issue):
{"action_type": "restart_service", "parameters": {"service": "<name>"}}
{"action_type": "rollback_deploy", "parameters": {"service": "<name>", "target_version": "previous"}}
{"action_type": "scale_service", "parameters": {"service": "<name>", "replicas": 5}}
{"action_type": "disable_feature_flag", "parameters": {"flag": "<flag_name>"}}
{"action_type": "clear_cache", "parameters": {"service": "<name>"}}
{"action_type": "execute_runbook_step", "parameters": {"runbook_action": "<action>", "target": "<name>"}}

Submission (end the episode β€” choose ONE based on task):
{"action_type": "submit_severity", "parameters": {"severity": "P1|P2|P3|P4", "service": "<root_cause_service>"}}
{"action_type": "submit_root_cause", "parameters": {"service": "<root_cause>", "failure_mode": "<what_went_wrong>"}}
{"action_type": "submit_resolution", "parameters": {"summary": "<full description of what happened and what you did>"}}

Strategy by task:
alert_classification (max 3 steps): Query 1-2 services for evidence, then submit_severity.
root_cause_analysis (max 10 steps): Query logs/metrics/deps for multiple services, trace the failure chain, then submit_root_cause.
remediation_planning (max 15 steps): Investigate, execute fix actions, then submit_resolution with a detailed summary.

Output ONLY the JSON object. Nothing else."""
74
-
75
-
76
- def _format_obs(obs: dict) -> str:
77
- parts = [
78
- f"TASK: {obs.get('task_id')} | Step {obs.get('step_count')}/{obs.get('max_steps')}",
79
- f"INCIDENT: {obs.get('incident_summary', '')}",
80
- ]
81
- alert = obs.get("alert", {})
82
- if alert:
83
- parts.append("ALERT:\n" + json.dumps(alert, indent=2))
84
- if obs.get("available_actions"):
85
- parts.append(f"AVAILABLE ACTIONS: {obs['available_actions']}")
86
- if obs.get("queried_data"):
87
- parts.append("DATA GATHERED:\n" + json.dumps(obs["queried_data"], indent=2))
88
- parts.append(f"LAST REWARD: {obs.get('cumulative_reward', 0.0)}")
89
- parts.append(f"FEEDBACK: {obs.get('feedback', '')}")
90
- return "\n\n".join(parts)
91
-
92
-
93
- def _parse_action(text: str) -> dict:
94
- text = text.strip()
95
- # Strip markdown code fences if present
96
- if text.startswith("```"):
97
- lines = [l for l in text.splitlines() if not l.startswith("```")]
98
- text = "\n".join(lines).strip()
99
- try:
100
- return json.loads(text)
101
- except json.JSONDecodeError:
102
- start, end = text.find("{"), text.rfind("}") + 1
103
- if start != -1 and end > start:
104
- return json.loads(text[start:end])
105
- raise
106
-
107
-
108
def _run_episode(task_id: str, scenario_index: int) -> float:
    """Drive one episode against the environment server and return its score.

    Resets the environment, then loops: format observation -> ask the LLM for
    an action -> POST it to /step, until the env reports done or max_steps is
    reached. Unparseable LLM replies fall back to a task-appropriate submit
    action. HTTP errors propagate to the caller. Returns the grader's "total".
    """
    reset_resp = requests.post(
        f"{ENV_BASE_URL}/reset",
        params={"task_id": task_id, "scenario_index": scenario_index},
        timeout=30,
    )
    reset_resp.raise_for_status()
    obs = reset_resp.json()

    # Full conversation history is resent on every LLM call.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    for step_idx in range(obs.get("max_steps", 10)):
        messages.append({"role": "user", "content": _format_obs(obs)})

        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
            max_tokens=256,
        )
        reply = completion.choices[0].message.content
        messages.append({"role": "assistant", "content": reply})

        try:
            action = _parse_action(reply)
        except Exception as exc:
            print(f" [WARN] parse failed at step {step_idx+1}: {exc}", file=sys.stderr)
            # Graceful fallback: submit something plausible for the task so
            # the episode still ends with a scored submission.
            fallbacks = {
                "alert_classification": {
                    "action_type": "submit_severity",
                    "parameters": {"severity": "P2", "service": "unknown"},
                },
                "root_cause_analysis": {
                    "action_type": "submit_root_cause",
                    "parameters": {"service": "unknown", "failure_mode": "unknown"},
                },
            }
            action = fallbacks.get(
                task_id,
                {
                    "action_type": "submit_resolution",
                    "parameters": {"summary": "Unable to determine root cause."},
                },
            )

        step_resp = requests.post(
            f"{ENV_BASE_URL}/step",
            json=action,
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        step_resp.raise_for_status()
        payload = step_resp.json()
        obs = payload["observation"]

        if payload.get("done"):
            break

    # Get final grader score
    grade_resp = requests.get(f"{ENV_BASE_URL}/grader", timeout=30)
    grade_resp.raise_for_status()
    return grade_resp.json().get("total", 0.0)
163
-
164
-
165
def main():
    """Run every task × scenario pair, print a score table, and finish with a
    JSON summary line (the validator parses the last stdout line)."""
    task_ids = ("alert_classification", "root_cause_analysis", "remediation_planning")
    runs = [(task, scenario) for task in task_ids for scenario in (0, 1)]

    results: dict[str, list[float]] = {}

    print(f"{'Task':<30} {'Scenario':>8} {'Score':>8}")
    print("-" * 52)

    for task_id, scenario_index in runs:
        # A failed episode scores 0.0 rather than aborting the whole sweep.
        try:
            score = _run_episode(task_id, scenario_index)
        except Exception as exc:
            print(f" [ERROR] {task_id} s{scenario_index}: {exc}", file=sys.stderr)
            score = 0.0

        label = f"{task_id} [s{scenario_index}]"
        print(f"{label:<30} {scenario_index:>8} {score:>8.4f}")
        results.setdefault(task_id, []).append(score)

    print("-" * 52)
    # Per-task mean, then a flat mean across the per-task means.
    summary = {task: round(sum(scores) / len(scores), 4) for task, scores in results.items()}
    summary["overall"] = round(sum(summary.values()) / len(summary), 4)

    print("\nBaseline Summary:")
    for name, value in summary.items():
        print(f" {name:<30}: {value:.4f}")

    # Final line must be valid JSON — parsed by /baseline endpoint
    print(json.dumps(summary))
201
-
202
-
203
# Script entry point: run the full baseline sweep when executed directly.
if __name__ == "__main__":
    main()