yadnyeshkolte committed
Commit 92d9fa2 · 1 Parent(s): caf7c32

Add mandatory inference.py to HF Space root

Files changed (1)
inference.py +233 -0
inference.py ADDED
@@ -0,0 +1,233 @@
"""
Inference Script for API Integration Debugging Environment
===========================================================

MANDATORY
- Before submitting, ensure the following variables are defined in your
  environment configuration:

    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  The name of the local image to use for the environment
                      if you are using from_docker_image().

- Defaults are set only for API_BASE_URL and MODEL_NAME:

    API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
    MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
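
  A minimal example environment setup (illustrative values, not real
  credentials):

    export API_BASE_URL="https://router.huggingface.co/v1"
    export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
    export HF_TOKEN="hf_xxxxxxxx"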

- The inference script must be named `inference.py` and placed in the root
  directory of the project.
- Participants must use the OpenAI client for all LLM calls, using the
  variables above.

STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:

    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
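
  For example, a short successful run might emit (all values are illustrative,
  including the hypothetical service name auth_service):

    [START] task=easy env=api_debug_env model=Qwen/Qwen2.5-72B-Instruct
    [STEP] step=1 action=inspect_logs(target=auth_service) reward=0.00 done=false error=null
    [STEP] step=2 action=submit_fix(target=auth_service, fix={"timeout": 30}) reward=1.00 done=true error=null
    [END] success=true steps=2 score=1.00 rewards=0.00,1.00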
"""

import asyncio
import json
import os
import textwrap
from typing import Dict, List, Optional, Tuple

from openai import OpenAI

from models import ApiDebugAction, ApiDebugObservation
from server.api_debug_env_environment import ApiDebugEnvironment
from scenarios import get_all_task_ids

# ─── Environment Variables ─────────────────────────────────────────────────────

# The docstring documents LOCAL_IMAGE_NAME, so check it first and fall back to
# IMAGE_NAME. Only needed if you are using a local Docker image.
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
BENCHMARK = "api_debug_env"
MAX_STEPS = 40  # maximum across all tasks (the hard task allows 40 steps)
TEMPERATURE = 0.3
MAX_TOKENS = 800
SUCCESS_SCORE_THRESHOLD = 0.1

SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert API debugging agent. You are tasked with diagnosing and fixing
    broken API integrations. You interact with a simulated multi-service environment.

    Available actions (respond with JSON):
    {
      "action_type": "inspect_logs" | "inspect_config" | "inspect_endpoint" | "submit_fix",
      "target": "<service_name>",
      "fix_payload": { ... }  // required only for submit_fix
    }

    Strategy:
    1. First inspect_logs on each service to identify error patterns
    2. Then inspect_config to understand the current (broken) settings
    3. Use inspect_endpoint to see the actual error responses
    4. Submit fixes with the corrected configuration values

    IMPORTANT: When submitting a fix, include ALL the corrected key-value pairs in fix_payload.
    For nested keys like "headers.Authorization", use the nested format:
    {"headers.Authorization": "Bearer <token>"}

    Respond with ONLY valid JSON. No explanation text.
""").strip()


# ─── Logging Functions ──────────────────────────────────────────────────────────

def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
        flush=True,
    )


# ─── LLM Interaction ────────────────────────────────────────────────────────────

def build_user_prompt(obs: ApiDebugObservation, step: int) -> str:
    """Build a prompt from the current observation."""
    parts = [
        f"Step: {step}",
        f"Task: {obs.task_description}",
        f"Remaining steps: {obs.remaining_steps}",
        f"Issues found: {obs.issues_found}/{obs.issues_total}",
        f"Issues fixed: {obs.issues_fixed}/{obs.issues_total}",
        f"Last action result: {obs.action_result}",
        f"Available targets: {obs.available_targets}",
    ]

    # Attach whatever diagnostic context the last action surfaced.
    if obs.logs:
        parts.append("Logs:\n" + "\n".join(obs.logs))
    if obs.config_snapshot:
        parts.append(f"Config: {json.dumps(obs.config_snapshot, indent=2)}")
    if obs.api_response:
        parts.append(f"API Response: {json.dumps(obs.api_response, indent=2)}")
    if obs.hints:
        parts.append(f"Hints: {'; '.join(obs.hints)}")

    return "\n".join(parts)


def get_model_action(
    client: OpenAI,
    obs: ApiDebugObservation,
    step: int,
    messages: List[Dict],
) -> ApiDebugAction:
    """Get the next action from the LLM."""
    user_prompt = build_user_prompt(obs, step)
    messages.append({"role": "user", "content": user_prompt})

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()

        # Extract the outermost JSON object from the response. This handles
        # models that wrap the JSON in markdown code fences or surround it
        # with prose; for a bare JSON response it is a no-op.
        json_start = text.find("{")
        json_end = text.rfind("}") + 1
        if json_start >= 0 and json_end > json_start:
            text = text[json_start:json_end]

        action_json = json.loads(text)
        messages.append({"role": "assistant", "content": json.dumps(action_json)})

        return ApiDebugAction(
            action_type=action_json.get("action_type", "inspect_logs"),
            target=action_json.get("target", obs.available_targets[0] if obs.available_targets else ""),
            fix_payload=action_json.get("fix_payload"),
        )
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        # Fallback: inspect the logs of the first available target.
        fallback_target = obs.available_targets[0] if obs.available_targets else ""
        return ApiDebugAction(
            action_type="inspect_logs",
            target=fallback_target,
        )


# ─── Main Execution ─────────────────────────────────────────────────────────────

async def run_task(task_id: str, client: OpenAI) -> Tuple[float, List[float], int]:
    """Run a single task and return (score, rewards, steps_taken)."""
    env = ApiDebugEnvironment(task_id=task_id)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        obs = env.reset()
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        for step in range(1, MAX_STEPS + 1):
            if obs.done:
                break

            action = get_model_action(client, obs, step, messages)
            action_str = f"{action.action_type}(target={action.target})"
            if action.fix_payload:
                action_str = f"{action.action_type}(target={action.target}, fix={json.dumps(action.fix_payload)})"

            obs = env.step(action)

            reward = obs.reward if obs.reward is not None else 0.0
            done = obs.done
            error = None

            rewards.append(reward)
            steps_taken = step

            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            if done:
                break

        # Clamp the grade to [0, 1] before applying the success threshold.
        score = env.grade()
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Error during task {task_id}: {e}", flush=True)
    finally:
        # The [END] line must be emitted even when the task raises.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score, rewards, steps_taken


async def main() -> None:
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    task_ids = get_all_task_ids()  # ["easy", "medium", "hard"]

    for task_id in task_ids:
        await run_task(task_id, client)


if __name__ == "__main__":
    asyncio.run(main())
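
# Example invocation (illustrative; substitute your own credentials):
#   HF_TOKEN="hf_xxxxxxxx" python inference.py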