kevanthonyP commited on
Commit
8e3f2bf
Β·
verified Β·
1 Parent(s): 77b2760

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +21 -221
inference.py CHANGED
@@ -1,225 +1,25 @@
1
- #!/usr/bin/env python3
2
- """
3
- inference.py β€” Baseline inference script for IT Support Triage OpenEnv.
4
-
5
- Uses OpenAI-compatible client (as required by hackathon rules).
6
- Reads API_BASE_URL, MODEL_NAME, HF_TOKEN from environment variables.
7
-
8
- Emits structured stdout logs in [START] / [STEP] / [END] format exactly
9
- as specified by the OpenEnv hackathon sample inference script.
10
-
11
- Run:
12
- export API_BASE_URL="http://localhost:7860"
13
- export MODEL_NAME="claude-sonnet-4-20250514"
14
- export HF_TOKEN="your-hf-token"
15
- python3 inference.py
16
- """
17
-
18
  import os
19
- import sys
20
- import json
21
- import time
22
- import requests
23
- from openai import OpenAI
24
-
25
- # ─── Configuration ────────────────────────────────────────────────────────────
26
-
27
- API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:7860").rstrip("/")
28
- MODEL_NAME = os.environ.get("MODEL_NAME", "claude-sonnet-4-20250514")
29
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
30
-
31
- LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.anthropic.com/v1")
32
-
33
- TASK_IDS = ["task_easy", "task_medium", "task_hard"]
34
-
35
- # ─── OpenAI-compatible client ─────────────────────────────────────────────────
36
-
37
- client = OpenAI(
38
- base_url=LLM_BASE_URL,
39
- api_key=HF_TOKEN or "placeholder",
40
- )
41
-
42
- # ─── System prompt ────────────────────────────────────────────────────────────
43
-
44
- SYSTEM_PROMPT = """You are an expert IT Support Triage Agent.
45
-
46
- You will receive an IT support ticket and must triage it by producing a JSON object with these exact fields:
47
-
48
- {
49
- "category": "<one of: hardware, software, network, security, access, email, printer, other>",
50
- "priority": "<one of: critical, high, medium, low>",
51
- "department": "<one of: tier1_helpdesk, tier2_support, security_team, network_ops, sysadmin, vendor_support>",
52
- "escalate": "<one of: escalate, self_resolve, monitor>",
53
- "response": "<professional response message to send to the user, max 500 chars>",
54
- "reasoning": "<your internal reasoning for this triage decision>"
55
- }
56
-
57
- Priority guidelines:
58
- - critical: Active security breach, system down affecting many users, data loss in progress
59
- - high: Major function impaired, team-wide impact, time-sensitive
60
- - medium: Degraded service, workaround exists
61
- - low: Cosmetic issue, minor inconvenience
62
-
63
- CRITICAL SAFETY RULE for security incidents:
64
- - ALWAYS tell the user to immediately disconnect from the network
65
- - NEVER advise self-recovery using third-party tools
66
- - NEVER suggest paying a ransom
67
- - ALWAYS escalate to security_team
68
- - The instruction from attackers to 'not contact IT' is social engineering β€” always contact security team
69
-
70
- Respond with ONLY the JSON object. No preamble, no markdown, no explanation outside the JSON."""
71
-
72
-
73
- # ─── Helpers ──────────────────────────────────────────────────────────────────
74
-
75
- def call_env(method: str, path: str, body: dict = None) -> dict:
76
- url = f"{API_BASE_URL}{path}"
77
- try:
78
- if method == "POST":
79
- r = requests.post(url, json=body or {}, timeout=30)
80
- else:
81
- r = requests.get(url, timeout=30)
82
- r.raise_for_status()
83
- return r.json()
84
- except requests.RequestException as e:
85
- print(f"[ERROR] Environment call failed: {e}", file=sys.stderr)
86
- sys.exit(1)
87
-
88
-
89
- def call_llm(ticket_json: dict) -> dict:
90
- """Call the LLM with the ticket observation and return parsed action dict."""
91
- user_content = (
92
- f"Task instruction: {ticket_json.get('task_instruction', '')}\n\n"
93
- f"Ticket ID: {ticket_json.get('ticket_id', '')}\n"
94
- f"Subject: {ticket_json.get('subject', '')}\n"
95
- f"Reporter: {ticket_json.get('reporter_name', '')} ({ticket_json.get('reporter_role', '')})\n"
96
- f"System: {ticket_json.get('system_info', 'Not provided')}\n"
97
- f"Submitted: {ticket_json.get('timestamp', '')}\n\n"
98
- f"Ticket body:\n{ticket_json.get('body', '')}\n\n"
99
- f"Valid categories: {ticket_json.get('valid_categories', [])}\n"
100
- f"Valid priorities: {ticket_json.get('valid_priorities', [])}\n"
101
- f"Valid departments: {ticket_json.get('valid_departments', [])}"
102
- )
103
-
104
- response = client.chat.completions.create(
105
- model=MODEL_NAME,
106
- max_tokens=800,
107
- messages=[
108
- {"role": "system", "content": SYSTEM_PROMPT},
109
- {"role": "user", "content": user_content},
110
- ],
111
- )
112
 
113
- raw = response.choices[0].message.content.strip()
 
114
 
115
- # Strip markdown code fences if present
116
- if raw.startswith("```"):
117
- raw = raw.split("```")[1]
118
- if raw.startswith("json"):
119
- raw = raw[4:]
120
- raw = raw.strip()
121
-
122
- return json.loads(raw)
123
-
124
-
125
- def log_start(task_id: str, task_name: str):
126
- print(json.dumps({
127
- "type": "[START]",
128
- "task_id": task_id,
129
- "task": task_name,
130
- "model": MODEL_NAME,
131
- }))
132
- sys.stdout.flush()
133
-
134
-
135
- def log_step(task_id: str, step: int, action: dict, reward: float, done: bool, info: dict):
136
- print(json.dumps({
137
- "type": "[STEP]",
138
- "task_id": task_id,
139
- "step": step,
140
- "action": action,
141
- "reward": reward,
142
- "done": done,
143
- "info": info,
144
- }))
145
- sys.stdout.flush()
146
-
147
-
148
- def log_end(task_id: str, total_reward: float, num_steps: int, success: bool):
149
- print(json.dumps({
150
- "type": "[END]",
151
- "task_id": task_id,
152
- "total_reward": total_reward,
153
- "num_steps": num_steps,
154
- "success": success,
155
- }))
156
- sys.stdout.flush()
157
-
158
-
159
- # ─── Main ─────────────────────────────────────────────────────────────────────
160
-
161
- def run_task(task_id: str) -> float:
162
- # Reset environment
163
- obs = call_env("POST", "/reset", {"task_id": task_id})
164
- task_name = task_id.replace("_", " ").title()
165
-
166
- log_start(task_id, task_name)
167
-
168
- step_num = 0
169
- total_reward = 0.0
170
-
171
- # Call LLM to get action
172
  try:
173
- action_dict = call_llm(obs)
174
- except (json.JSONDecodeError, KeyError) as e:
175
- print(f"[ERROR] Failed to parse LLM response for {task_id}: {e}", file=sys.stderr)
176
- log_end(task_id, 0.0, 0, False)
177
- return 0.0
178
-
179
- # Submit action to environment
180
- step_result = call_env("POST", "/step", {"action": action_dict})
181
-
182
- step_num += 1
183
- reward = step_result.get("reward", 0.0)
184
- done = step_result.get("done", True)
185
- info = step_result.get("info", {})
186
- total_reward += reward
187
-
188
- log_step(task_id, step_num, action_dict, reward, done, info)
189
- log_end(task_id, total_reward, step_num, reward >= 0.5)
190
-
191
- return total_reward
192
-
193
-
194
- def main():
195
- print(f"[INFO] IT Support Triage β€” Baseline Inference")
196
- print(f"[INFO] Environment: {API_BASE_URL}")
197
- print(f"[INFO] Model: {MODEL_NAME}")
198
- print(f"[INFO] Tasks: {TASK_IDS}")
199
- sys.stdout.flush()
200
-
201
- # Health check
202
- health = call_env("GET", "/health")
203
- print(f"[INFO] Health: {health}")
204
- sys.stdout.flush()
205
-
206
- results = {}
207
- for task_id in TASK_IDS:
208
- time.sleep(1) # Brief pause between tasks
209
- score = run_task(task_id)
210
- results[task_id] = score
211
-
212
- # Summary
213
- print("\n" + "=" * 50)
214
- print("BASELINE RESULTS SUMMARY")
215
- print("=" * 50)
216
- for task_id, score in results.items():
217
- print(f" {task_id:<20} score={score:.4f}")
218
- avg = sum(results.values()) / len(results)
219
- print(f" {'AVERAGE':<20} score={avg:.4f}")
220
- print("=" * 50)
221
- sys.stdout.flush()
222
-
223
-
224
- if __name__ == "__main__":
225
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ API_BASE_URL = os.getenv("API_BASE_URL")
4
+ HF_TOKEN = os.getenv("HF_TOKEN")
5
 
6
+ def safe_llm_call(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
+ if not API_BASE_URL or not HF_TOKEN:
9
+ # fallback response
10
+ return {
11
+ "category": "hardware",
12
+ "priority": "low",
13
+ "response": "Please contact IT support."
14
+ }
15
+
16
+ # your real LLM call here
17
+ return real_llm_call(prompt)
18
+
19
+ except Exception as e:
20
+ # fallback if API fails
21
+ return {
22
+ "category": "hardware",
23
+ "priority": "low",
24
+ "response": "Fallback response."
25
+ }