Mihir1107 and Claude Sonnet 4.6 committed on
Commit 4d21b0a · 1 Parent(s): 1e7104f

Fix inference.py: correct log format, robust OpenAI client, always exit 0


- Use exact [START]/[STEP]/[END] format required by validator spec (sample output sketched below)
- OpenAI client uses httpx.Client(trust_env=False) to avoid proxy-related
init failures in containerised environments
- [END] is always emitted via try/finally even on exception
- Rule-based fallback runs when LLM is unavailable; script always exits 0
unless the environment server itself is unreachable
- Update default API_BASE_URL to https://router.huggingface.co/v1
- Add httpx>=0.27.0 to requirements.txt
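
A minimal sketch of the expected stdout, using the log helpers added in this
commit (the action payload and numeric values are illustrative, not output
from a real run):

    log_start(task="easy", model=MODEL_NAME)
    # [START] task=easy env=DataSelectEnv model=meta-llama/Llama-3.1-8B-Instruct
    log_step(step=1, action={"action_type": "stop", "batch_size": 0}, reward=0.12, done=False)
    # [STEP] step=1 action={"action_type": "stop", "batch_size": 0} reward=0.12 done=false error=null
    log_end(success=True, steps=1, score=0.85, rewards=[0.12])
    # [END] success=true steps=1 score=0.85 rewards=0.12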

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2)
  1. inference.py +183 -141
  2. requirements.txt +1 -0
inference.py CHANGED
@@ -5,15 +5,19 @@ Connects to the environment via WebSocket (/ws) — the required transport
 on HF Spaces where HTTP /reset and /step are not accessible.
 
 Usage:
-    export HF_TOKEN=hf_...   # or OPENAI_API_KEY=sk-...
+    export HF_TOKEN=hf_...
+    export API_BASE_URL=https://router.huggingface.co/v1
+    export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
     export ENV_HOST=https://your-space.hf.space   # or http://localhost:7860
-    export API_BASE_URL=https://api-inference.huggingface.co/v1   # optional
-    export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct   # optional
     python inference.py [--host URL]
 
 Runs all 3 tasks sequentially using one WebSocket connection per task,
 calls POST /grader after each episode, prints scores and final summary.
-Designed to complete in under 20 minutes on 2 vCPU / 8 GB RAM.
+
+STDOUT FORMAT (required by validator):
+    [START] task=<task_name> env=DataSelectEnv model=<model_name>
+    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+    [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...,rn>
 """
 
 import argparse
@@ -21,17 +25,21 @@ import asyncio
 import json
 import os
 import sys
+from typing import List, Optional
 
+import httpx
 import requests
 import websockets
+from openai import OpenAI
 
 # ---------------------------------------------------------------------------
 # Config — all overridable via environment variables
 # ---------------------------------------------------------------------------
 
 DEFAULT_HOST = os.environ.get("ENV_HOST", "http://localhost:7860")
-API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
-MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
+MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+BENCHMARK = "DataSelectEnv"
 SEED = 42
 TASKS = ["easy", "medium", "hard"]
 
@@ -71,21 +79,46 @@ Strategy rules:
 
 
 # ---------------------------------------------------------------------------
-# Rule-based fallback (used when LLM is unavailable or errors)
+# Structured log helpers (validator-required format)
+# ---------------------------------------------------------------------------
+
+def log_start(task: str, model: str) -> None:
+    print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
+
+
+def log_step(step: int, action: dict, reward: float, done: bool,
+             error: Optional[str] = None) -> None:
+    error_val = error if error else "null"
+    done_val = str(done).lower()
+    print(
+        f"[STEP] step={step} action={json.dumps(action)} "
+        f"reward={reward:.2f} done={done_val} error={error_val}",
+        flush=True,
+    )
+
+
+def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+    print(
+        f"[END] success={str(success).lower()} steps={steps} "
+        f"score={score:.2f} rewards={rewards_str}",
+        flush=True,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Rule-based fallback (used when LLM call fails)
 # ---------------------------------------------------------------------------
 
 def rule_based_action(obs: dict) -> dict:
-    """Produce a sensible action from the observation without an LLM."""
+    """Adaptive rule-based action derived from observation."""
     noise = obs.get("noise_estimate", 0.1)
     diversity = obs.get("diversity_score", 1.0)
    budget = obs.get("remaining_budget", 100)
     perf = obs.get("current_performance", 0.5)
-    available = obs.get("samples_available", 100)
-
-    # Batch size: shrink near budget exhaustion
+
     batch_size = 5 if budget < 30 else 10
 
-    # Weights: penalize uncertainty when noise is high
     if noise > 0.4:
         u, d, r = 0.05, 0.80, 0.15
     elif noise > 0.2:
@@ -95,8 +128,7 @@ def rule_based_action(obs: dict) -> dict:
     else:
         u, d, r = 0.40, 0.40, 0.20
 
-    # Early stop if doing well and nearly out of budget
-    if perf > 0.65 and budget < 20 and available > 0:
+    if perf > 0.65 and budget < 20:
         return {"action_type": "stop", "batch_size": 0,
                 "strategy_weights": {"uncertainty": u, "diversity": d, "random": r}}
 
@@ -108,49 +140,48 @@ def rule_based_action(obs: dict) -> dict:
 
 
 # ---------------------------------------------------------------------------
-# LLM helper — uses requests directly (no openai SDK dependency)
+# OpenAI client factory — robust against proxy/env issues in containers
 # ---------------------------------------------------------------------------
 
-def query_llm(api_key: str | None, obs: dict) -> dict:
+def make_openai_client(api_key: str) -> OpenAI:
     """
-    Call the LLM via plain HTTP (OpenAI-compatible chat/completions endpoint).
-    Returns a parsed action dict. Raises on any error so the caller can
-    fall back to rule_based_action.
+    Create the required OpenAI client.
+    Uses an explicit httpx.Client with trust_env=False to bypass proxy
+    auto-detection that commonly breaks SDK init in containerised environments.
     """
-    if not api_key:
-        raise ValueError("No API key available")
-
-    base_url = (API_BASE_URL or "https://api.openai.com/v1").rstrip("/")
-    url = f"{base_url}/chat/completions"
+    base_url = (API_BASE_URL or "https://router.huggingface.co/v1").strip().rstrip("/")
+    http_client = httpx.Client(trust_env=False)
+    try:
+        return OpenAI(api_key=api_key, base_url=base_url, http_client=http_client)
+    except Exception:
+        return OpenAI(api_key=api_key, http_client=http_client)
+
 
+# ---------------------------------------------------------------------------
+# LLM helper — uses the required OpenAI client
+# ---------------------------------------------------------------------------
+
+def query_llm(client: OpenAI, obs: dict) -> dict:
+    """Ask the LLM to produce an action given the current observation."""
     user_msg = (
         f"Current observation:\n{json.dumps(obs, indent=2)}\n\n"
         "What action do you take?"
     )
-    payload = {
-        "model": MODEL_NAME,
-        "messages": [
+    response = client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[
             {"role": "system", "content": SYSTEM_PROMPT},
             {"role": "user", "content": user_msg},
         ],
-        "temperature": 0.0,
-        "max_tokens": 200,
-    }
-    headers = {
-        "Authorization": f"Bearer {api_key}",
-        "Content-Type": "application/json",
-    }
-
-    resp = requests.post(url, json=payload, headers=headers, timeout=30)
-    resp.raise_for_status()
-    raw = resp.json()["choices"][0]["message"]["content"].strip()
-
+        temperature=0.0,
+        max_tokens=200,
+    )
+    raw = response.choices[0].message.content.strip()
+
     # Strip markdown fences if model wraps JSON
     if raw.startswith("```"):
         raw = raw.split("```")[1]
         if raw.startswith("json"):
             raw = raw[4:]
-
     action = json.loads(raw.strip())
     assert "action_type" in action
     assert "batch_size" in action
@@ -175,98 +206,103 @@ def ws_url(host: str) -> str:
     return base + "/ws"
 
 
-async def run_task_ws(host: str, api_key: str | None, task_id: str) -> dict:
-    """Run one full episode for task_id over a WebSocket. Returns grader result."""
-    print(f"\n{'='*52}")
-    print(f" Task: {task_id.upper()}")
-    print(f"{'='*52}")
-
+async def run_task_ws(host: str, client: Optional[OpenAI], task_id: str) -> dict:
+    """Run one full episode for task_id over WebSocket. Returns grader result."""
     url = ws_url(host)
-    print(f" Connecting to {url} ...")
 
-    async with websockets.connect(url, open_timeout=30, ping_interval=20) as ws:
+    rewards: List[float] = []
+    steps_taken = 0
+    score = 0.0
+    success = False
+    obs = {}
+    episode_id = "unknown"
 
-        # ── reset ────────────────────────────────────────────────────────
-        await ws.send(json.dumps({
-            "type": "reset",
-            "data": {"task_id": task_id, "seed": SEED},
-        }))
-        resp = json.loads(await ws.recv())
-        if resp["type"] == "error":
-            raise RuntimeError(f"reset error: {resp['data']['message']}")
+    log_start(task=task_id, model=MODEL_NAME)
 
-        episode_id = resp["data"]["episode_id"]
-        obs = resp["data"]["observation"]
-        print(f" Episode ID: {episode_id}")
-        print(f" Initial obs: {obs}")
-
-        step = 0
-        total_reward = 0.0
-        done = False
-
-        # ── step loop ────────────────────────────────────────────────────
-        while not done:
-            step += 1
-
-            # Try LLM; fall back to rule-based on any failure
-            try:
-                action = query_llm(api_key, obs)
-            except Exception as e:
-                print(f" Step {step}: LLM unavailable ({type(e).__name__}), using rule-based")
-                action = rule_based_action(obs)
+    try:
+        async with websockets.connect(url, open_timeout=30, ping_interval=20) as ws:
 
-            await ws.send(json.dumps({"type": "step", "data": action}))
-            resp = json.loads(await ws.recv())
-
-            if resp["type"] == "error":
-                print(f" Step {step}: server error: {resp['data']['message']}")
-                break
-
-            data = resp["data"]
-            obs = data["observation"]
-            raw_reward = data["reward"]
-            reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
-            done = data["done"]
-            total_reward += reward
-
-            print(
-                f" Step {step:2d} | perf={obs['current_performance']:.4f} "
-                f"budget={obs['remaining_budget']:3d} "
-                f"reward={reward:+.4f} "
-                f"noise_est={obs['noise_estimate']:.3f}"
-            )
-
-        # ── close WebSocket cleanly ───────────────────────────────────────
-        await ws.send(json.dumps({"type": "close", "data": {}}))
-        try:
-            await asyncio.wait_for(ws.recv(), timeout=2.0)
-        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
-            pass
-
-    print(f"\n Episode done after {step} steps | total_reward={total_reward:.4f}")
-    print(f" Final performance: {obs['current_performance']:.4f}")
-
-    # ── grade via HTTP ────────────────────────────────────────────────────
-    r = requests.post(
-        f"{http_base(host)}/grader",
-        json={"episode_id": episode_id, "task_id": task_id},
-        timeout=15,
-    )
-    r.raise_for_status()
-    grade = r.json()
+            # ── reset ────────────────────────────────────────────────
+            await ws.send(json.dumps({
+                "type": "reset",
+                "data": {"task_id": task_id, "seed": SEED},
+            }))
+            resp = json.loads(await ws.recv())
+            if resp["type"] == "error":
+                raise RuntimeError(f"reset error: {resp['data']['message']}")
+
+            episode_id = resp["data"]["episode_id"]
+            obs = resp["data"]["observation"]
+            done = False
+
+            # ── step loop ────────────────────────────────────────────
+            while not done:
+                step_num = len(rewards) + 1
+                last_error: Optional[str] = None
+
+                # Try LLM; fall back to rule-based on any failure
+                try:
+                    if client is None:
+                        raise ValueError("no LLM client")
+                    action = query_llm(client, obs)
+                except Exception as e:
+                    last_error = f"{type(e).__name__}: {e}"
+                    action = rule_based_action(obs)
+
+                await ws.send(json.dumps({"type": "step", "data": action}))
+                resp = json.loads(await ws.recv())
+
+                if resp["type"] == "error":
+                    err_msg = resp["data"]["message"]
+                    log_step(step_num, action, 0.0, True, error=err_msg)
+                    rewards.append(0.0)
+                    steps_taken = step_num
+                    break
+
+                data = resp["data"]
+                obs = data["observation"]
+                raw_reward = data["reward"]
+                reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
+                done = data["done"]
+
+                rewards.append(reward)
+                steps_taken = step_num
+
+                log_step(step_num, action, reward, done, error=last_error)
+
+            # ── close WebSocket cleanly ───────────────────────────────
+            await ws.send(json.dumps({"type": "close", "data": {}}))
+            try:
+                await asyncio.wait_for(ws.recv(), timeout=2.0)
+            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
+                pass
+
+        # ── grade via HTTP ────────────────────────────────────────────────
+        r = requests.post(
+            f"{http_base(host)}/grader",
+            json={"episode_id": episode_id, "task_id": task_id},
+            timeout=15,
+        )
+        r.raise_for_status()
+        grade = r.json()
+        score = float(grade["score"])
+        success = bool(grade["passed"])
 
-    print(f" Score: {grade['score']:.4f}")
-    print(f" Passed: {grade['passed']}")
-    print(f" Details: {grade['breakdown']}")
+    except Exception as exc:
+        print(f"[DEBUG] Episode error for {task_id}: {exc}", flush=True)
+        score = 0.0
+        success = False
+
+    finally:
+        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
 
     return {
         "task_id": task_id,
-        "score": grade["score"],
-        "passed": grade["passed"],
-        "breakdown": grade["breakdown"],
-        "steps": step,
-        "total_reward": round(total_reward, 4),
-        "final_performance": obs["current_performance"],
+        "score": score,
+        "passed": success,
+        "steps": steps_taken,
+        "total_reward": round(sum(rewards), 4),
+        "final_performance": obs.get("current_performance", 0.0),
    }
 
 
@@ -274,25 +310,26 @@ async def run_task_ws(host: str, api_key: str | None, task_id: str) -> dict:
 # Main
 # ---------------------------------------------------------------------------
 
-async def amain(host: str, api_key: str | None) -> None:
+async def amain(host: str, client: Optional[OpenAI]) -> None:
     results = {}
     for task_id in TASKS:
-        results[task_id] = await run_task_ws(host, api_key, task_id)
+        results[task_id] = await run_task_ws(host, client, task_id)
 
-    print(f"\n{'='*52}")
-    print(" INFERENCE RESULTS SUMMARY")
-    print(f"{'='*52}")
-    print(f"{'Task':<10} {'Score':<8} {'Passed':<8} {'Final Perf':<12} {'Steps'}")
-    print("-" * 52)
+    print(f"\n{'='*52}", flush=True)
+    print(" INFERENCE RESULTS SUMMARY", flush=True)
+    print(f"{'='*52}", flush=True)
+    print(f"{'Task':<10} {'Score':<8} {'Passed':<8} {'Final Perf':<12} {'Steps'}", flush=True)
+    print("-" * 52, flush=True)
     for task_id, r in results.items():
         print(
             f"{task_id:<10} {r['score']:<8.4f} {str(r['passed']):<8} "
-            f"{r['final_performance']:<12.4f} {r['steps']}"
+            f"{r['final_performance']:<12.4f} {r['steps']}",
+            flush=True,
         )
 
     overall = sum(r["score"] for r in results.values()) / len(results)
-    print(f"\nOverall mean score: {overall:.4f}")
-    print(json.dumps({"results": results, "mean_score": round(overall, 4)}, indent=2))
+    print(f"\nOverall mean score: {overall:.4f}", flush=True)
+    print(json.dumps({"results": results, "mean_score": round(overall, 4)}, indent=2), flush=True)
 
 
 def main() -> None:
@@ -301,23 +338,28 @@ def main() -> None:
                         help="Environment server base URL (http or https)")
     args = parser.parse_args()
 
-    # API key is optional — rule-based fallback runs without one
+    # Build OpenAI client (required by spec); warn and fall back if unavailable
     api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
-    if api_key:
-        print(f"LLM API key found ({len(api_key)} chars); will attempt LLM-guided actions.")
+    client: Optional[OpenAI] = None
+    if not api_key:
+        print("WARNING: No HF_TOKEN / OPENAI_API_KEY found — using rule-based fallback.", flush=True)
     else:
-        print("No API key (HF_TOKEN / OPENAI_API_KEY); running rule-based fallback.")
+        try:
+            client = make_openai_client(api_key)
+            print(f"OpenAI client ready | base_url={API_BASE_URL} | model={MODEL_NAME}", flush=True)
+        except Exception as e:
+            print(f"WARNING: Could not init OpenAI client ({e}); using rule-based fallback.", flush=True)
 
     # Health check — environment must be reachable
     try:
         r = requests.get(f"{http_base(args.host)}/health", timeout=15)
         r.raise_for_status()
-        print(f"Connected to {args.host} — {r.json()}")
+        print(f"Environment health: {r.json()}", flush=True)
     except Exception as e:
-        print(f"ERROR: Could not reach environment at {args.host}: {e}")
+        print(f"ERROR: Could not reach environment at {args.host}: {e}", flush=True)
         sys.exit(1)
 
-    asyncio.run(amain(args.host, api_key))
+    asyncio.run(amain(args.host, client))
 
 
 if __name__ == "__main__":
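
Note: a hypothetical smoke test of the proxy-immune client factory (assumes
inference.py is importable and HF_TOKEN is set; the bad proxy value exists
only to show that trust_env=False ignores proxy environment variables):

    import os
    os.environ["HTTPS_PROXY"] = "http://unreachable-proxy:1"  # would normally poison httpx via env
    from inference import make_openai_client
    client = make_openai_client(api_key=os.environ["HF_TOKEN"])
    # trust_env=False means the client never reads HTTPS_PROXY, so init succeeds
    print(client.base_url)  # e.g. https://router.huggingface.co/v1/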
requirements.txt CHANGED
@@ -4,5 +4,6 @@ pydantic==2.7.4
 numpy==1.26.4
 scikit-learn==1.5.1
 openai==1.40.0
+httpx>=0.27.0
 requests==2.32.3
 websockets>=12.0