Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +21 -221
inference.py
CHANGED
|
@@ -1,225 +1,25 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
inference.py — Baseline inference script for IT Support Triage OpenEnv.
|
| 4 |
-
|
| 5 |
-
Uses OpenAI-compatible client (as required by hackathon rules).
|
| 6 |
-
Reads API_BASE_URL, MODEL_NAME, HF_TOKEN from environment variables.
|
| 7 |
-
|
| 8 |
-
Emits structured stdout logs in [START] / [STEP] / [END] format exactly
|
| 9 |
-
as specified by the OpenEnv hackathon sample inference script.
|
| 10 |
-
|
| 11 |
-
Run:
|
| 12 |
-
export API_BASE_URL="http://localhost:7860"
|
| 13 |
-
export MODEL_NAME="claude-sonnet-4-20250514"
|
| 14 |
-
export HF_TOKEN="your-hf-token"
|
| 15 |
-
python3 inference.py
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
import os
|
| 19 |
-
import sys
|
| 20 |
-
import json
|
| 21 |
-
import time
|
| 22 |
-
import requests
|
| 23 |
-
from openai import OpenAI
|
| 24 |
-
|
| 25 |
-
# ─── Configuration ────────────────────────────────────────────────────────────
|
| 26 |
-
|
| 27 |
-
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:7860").rstrip("/")
|
| 28 |
-
MODEL_NAME = os.environ.get("MODEL_NAME", "claude-sonnet-4-20250514")
|
| 29 |
-
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 30 |
-
|
| 31 |
-
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.anthropic.com/v1")
|
| 32 |
-
|
| 33 |
-
TASK_IDS = ["task_easy", "task_medium", "task_hard"]
|
| 34 |
-
|
| 35 |
-
# ─── OpenAI-compatible client ─────────────────────────────────────────────────
|
| 36 |
-
|
| 37 |
-
client = OpenAI(
|
| 38 |
-
base_url=LLM_BASE_URL,
|
| 39 |
-
api_key=HF_TOKEN or "placeholder",
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
# ─── System prompt ────────────────────────────────────────────────────────────
|
| 43 |
-
|
| 44 |
-
SYSTEM_PROMPT = """You are an expert IT Support Triage Agent.
|
| 45 |
-
|
| 46 |
-
You will receive an IT support ticket and must triage it by producing a JSON object with these exact fields:
|
| 47 |
-
|
| 48 |
-
{
|
| 49 |
-
"category": "<one of: hardware, software, network, security, access, email, printer, other>",
|
| 50 |
-
"priority": "<one of: critical, high, medium, low>",
|
| 51 |
-
"department": "<one of: tier1_helpdesk, tier2_support, security_team, network_ops, sysadmin, vendor_support>",
|
| 52 |
-
"escalate": "<one of: escalate, self_resolve, monitor>",
|
| 53 |
-
"response": "<professional response message to send to the user, max 500 chars>",
|
| 54 |
-
"reasoning": "<your internal reasoning for this triage decision>"
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
Priority guidelines:
|
| 58 |
-
- critical: Active security breach, system down affecting many users, data loss in progress
|
| 59 |
-
- high: Major function impaired, team-wide impact, time-sensitive
|
| 60 |
-
- medium: Degraded service, workaround exists
|
| 61 |
-
- low: Cosmetic issue, minor inconvenience
|
| 62 |
-
|
| 63 |
-
CRITICAL SAFETY RULE for security incidents:
|
| 64 |
-
- ALWAYS tell the user to immediately disconnect from the network
|
| 65 |
-
- NEVER advise self-recovery using third-party tools
|
| 66 |
-
- NEVER suggest paying a ransom
|
| 67 |
-
- ALWAYS escalate to security_team
|
| 68 |
-
- The instruction from attackers to 'not contact IT' is social engineering β always contact security team
|
| 69 |
-
|
| 70 |
-
Respond with ONLY the JSON object. No preamble, no markdown, no explanation outside the JSON."""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
| 74 |
-
|
| 75 |
-
def call_env(method: str, path: str, body: dict = None) -> dict:
|
| 76 |
-
url = f"{API_BASE_URL}{path}"
|
| 77 |
-
try:
|
| 78 |
-
if method == "POST":
|
| 79 |
-
r = requests.post(url, json=body or {}, timeout=30)
|
| 80 |
-
else:
|
| 81 |
-
r = requests.get(url, timeout=30)
|
| 82 |
-
r.raise_for_status()
|
| 83 |
-
return r.json()
|
| 84 |
-
except requests.RequestException as e:
|
| 85 |
-
print(f"[ERROR] Environment call failed: {e}", file=sys.stderr)
|
| 86 |
-
sys.exit(1)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def call_llm(ticket_json: dict) -> dict:
|
| 90 |
-
"""Call the LLM with the ticket observation and return parsed action dict."""
|
| 91 |
-
user_content = (
|
| 92 |
-
f"Task instruction: {ticket_json.get('task_instruction', '')}\n\n"
|
| 93 |
-
f"Ticket ID: {ticket_json.get('ticket_id', '')}\n"
|
| 94 |
-
f"Subject: {ticket_json.get('subject', '')}\n"
|
| 95 |
-
f"Reporter: {ticket_json.get('reporter_name', '')} ({ticket_json.get('reporter_role', '')})\n"
|
| 96 |
-
f"System: {ticket_json.get('system_info', 'Not provided')}\n"
|
| 97 |
-
f"Submitted: {ticket_json.get('timestamp', '')}\n\n"
|
| 98 |
-
f"Ticket body:\n{ticket_json.get('body', '')}\n\n"
|
| 99 |
-
f"Valid categories: {ticket_json.get('valid_categories', [])}\n"
|
| 100 |
-
f"Valid priorities: {ticket_json.get('valid_priorities', [])}\n"
|
| 101 |
-
f"Valid departments: {ticket_json.get('valid_departments', [])}"
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
response = client.chat.completions.create(
|
| 105 |
-
model=MODEL_NAME,
|
| 106 |
-
max_tokens=800,
|
| 107 |
-
messages=[
|
| 108 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 109 |
-
{"role": "user", "content": user_content},
|
| 110 |
-
],
|
| 111 |
-
)
|
| 112 |
|
| 113 |
-
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
-
if raw.startswith("```"):
|
| 117 |
-
raw = raw.split("```")[1]
|
| 118 |
-
if raw.startswith("json"):
|
| 119 |
-
raw = raw[4:]
|
| 120 |
-
raw = raw.strip()
|
| 121 |
-
|
| 122 |
-
return json.loads(raw)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def log_start(task_id: str, task_name: str):
|
| 126 |
-
print(json.dumps({
|
| 127 |
-
"type": "[START]",
|
| 128 |
-
"task_id": task_id,
|
| 129 |
-
"task": task_name,
|
| 130 |
-
"model": MODEL_NAME,
|
| 131 |
-
}))
|
| 132 |
-
sys.stdout.flush()
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def log_step(task_id: str, step: int, action: dict, reward: float, done: bool, info: dict):
|
| 136 |
-
print(json.dumps({
|
| 137 |
-
"type": "[STEP]",
|
| 138 |
-
"task_id": task_id,
|
| 139 |
-
"step": step,
|
| 140 |
-
"action": action,
|
| 141 |
-
"reward": reward,
|
| 142 |
-
"done": done,
|
| 143 |
-
"info": info,
|
| 144 |
-
}))
|
| 145 |
-
sys.stdout.flush()
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def log_end(task_id: str, total_reward: float, num_steps: int, success: bool):
|
| 149 |
-
print(json.dumps({
|
| 150 |
-
"type": "[END]",
|
| 151 |
-
"task_id": task_id,
|
| 152 |
-
"total_reward": total_reward,
|
| 153 |
-
"num_steps": num_steps,
|
| 154 |
-
"success": success,
|
| 155 |
-
}))
|
| 156 |
-
sys.stdout.flush()
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# ─── Main ─────────────────────────────────────────────────────────────────────
|
| 160 |
-
|
| 161 |
-
def run_task(task_id: str) -> float:
|
| 162 |
-
# Reset environment
|
| 163 |
-
obs = call_env("POST", "/reset", {"task_id": task_id})
|
| 164 |
-
task_name = task_id.replace("_", " ").title()
|
| 165 |
-
|
| 166 |
-
log_start(task_id, task_name)
|
| 167 |
-
|
| 168 |
-
step_num = 0
|
| 169 |
-
total_reward = 0.0
|
| 170 |
-
|
| 171 |
-
# Call LLM to get action
|
| 172 |
try:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
return total_reward
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def main():
|
| 195 |
-
print(f"[INFO] IT Support Triage β Baseline Inference")
|
| 196 |
-
print(f"[INFO] Environment: {API_BASE_URL}")
|
| 197 |
-
print(f"[INFO] Model: {MODEL_NAME}")
|
| 198 |
-
print(f"[INFO] Tasks: {TASK_IDS}")
|
| 199 |
-
sys.stdout.flush()
|
| 200 |
-
|
| 201 |
-
# Health check
|
| 202 |
-
health = call_env("GET", "/health")
|
| 203 |
-
print(f"[INFO] Health: {health}")
|
| 204 |
-
sys.stdout.flush()
|
| 205 |
-
|
| 206 |
-
results = {}
|
| 207 |
-
for task_id in TASK_IDS:
|
| 208 |
-
time.sleep(1) # Brief pause between tasks
|
| 209 |
-
score = run_task(task_id)
|
| 210 |
-
results[task_id] = score
|
| 211 |
-
|
| 212 |
-
# Summary
|
| 213 |
-
print("\n" + "=" * 50)
|
| 214 |
-
print("BASELINE RESULTS SUMMARY")
|
| 215 |
-
print("=" * 50)
|
| 216 |
-
for task_id, score in results.items():
|
| 217 |
-
print(f" {task_id:<20} score={score:.4f}")
|
| 218 |
-
avg = sum(results.values()) / len(results)
|
| 219 |
-
print(f" {'AVERAGE':<20} score={avg:.4f}")
|
| 220 |
-
print("=" * 50)
|
| 221 |
-
sys.stdout.flush()
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
if __name__ == "__main__":
|
| 225 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
API_BASE_URL = os.getenv("API_BASE_URL")
|
| 4 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 5 |
|
| 6 |
+
def safe_llm_call(prompt):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
try:
|
| 8 |
+
if not API_BASE_URL or not HF_TOKEN:
|
| 9 |
+
# fallback response
|
| 10 |
+
return {
|
| 11 |
+
"category": "hardware",
|
| 12 |
+
"priority": "low",
|
| 13 |
+
"response": "Please contact IT support."
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
# your real LLM call here
|
| 17 |
+
return real_llm_call(prompt)
|
| 18 |
+
|
| 19 |
+
except Exception as e:
|
| 20 |
+
# fallback if API fails
|
| 21 |
+
return {
|
| 22 |
+
"category": "hardware",
|
| 23 |
+
"priority": "low",
|
| 24 |
+
"response": "Fallback response."
|
| 25 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|