Mihir1107 Claude Sonnet 4.6 commited on
Commit
1e7104f
Β·
1 Parent(s): 8bff254

Rewrite inference.py: drop openai SDK, use requests + rule-based fallback

Browse files

- Replace openai SDK with direct requests HTTP calls to chat/completions
(eliminates all SDK version/init errors in the validator environment)
- Add rule_based_action() that adapts weights to noise/diversity/budget
so the script always completes tasks even when LLM is unavailable
- API key is now optional β€” script exits 0 with rule-based strategy
instead of crashing when LLM client can't be initialized
- Only sys.exit(1) when the environment server itself is unreachable

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. inference.py +97 -64
inference.py CHANGED
@@ -24,7 +24,6 @@ import sys
24
 
25
  import requests
26
  import websockets
27
- from openai import OpenAI
28
 
29
  # ---------------------------------------------------------------------------
30
  # Config β€” all overridable via environment variables
@@ -35,11 +34,6 @@ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
35
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
36
  SEED = 42
37
  TASKS = ["easy", "medium", "hard"]
38
- FALLBACK_ACTION = {
39
- "action_type": "select_batch",
40
- "batch_size": 10,
41
- "strategy_weights": {"uncertainty": 0.3, "diversity": 0.5, "random": 0.2},
42
- }
43
 
44
  SYSTEM_PROMPT = """You are an intelligent data curation agent.
45
 
@@ -57,51 +51,111 @@ Observation fields:
57
  Respond with ONLY a valid JSON action in this exact format:
58
  {
59
  "action_type": "select_batch",
60
- "batch_size": <integer 5–20>,
61
  "strategy_weights": {
62
- "uncertainty": <float 0–1>,
63
- "diversity": <float 0–1>,
64
- "random": <float 0–1>
65
  }
66
  }
67
 
68
  Strategy rules:
69
  - Weights are normalized automatically (no need to sum to 1)
70
- - noise_estimate > 0.2 β†’ lower uncertainty weight, raise diversity weight
71
- - noise_estimate > 0.4 β†’ set uncertainty near 0, maximize diversity
72
- - diversity_score < 0.5 β†’ increase diversity weight
73
- - remaining_budget < 30 β†’ reduce batch_size to 5
74
  - You may use "action_type": "stop" with batch_size 0 only when
75
  current_performance > 0.65 AND remaining_budget < 20
76
  - Respond with ONLY the JSON object, no explanation, no markdown fences."""
77
 
78
 
79
  # ---------------------------------------------------------------------------
80
- # LLM helper
81
  # ---------------------------------------------------------------------------
82
 
83
- def query_llm(client: OpenAI, observation: dict) -> dict:
84
- """Ask the LLM to produce an action given the current observation."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  user_msg = (
86
- f"Current observation:\n{json.dumps(observation, indent=2)}\n\n"
87
  "What action do you take?"
88
  )
89
- response = client.chat.completions.create(
90
- model=MODEL_NAME,
91
- messages=[
92
  {"role": "system", "content": SYSTEM_PROMPT},
93
  {"role": "user", "content": user_msg},
94
  ],
95
- temperature=0.0,
96
- max_tokens=200,
97
- )
98
- raw = response.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
99
  # Strip markdown fences if model wraps JSON
100
  if raw.startswith("```"):
101
  raw = raw.split("```")[1]
102
  if raw.startswith("json"):
103
  raw = raw[4:]
104
- return json.loads(raw.strip())
 
 
 
 
 
105
 
106
 
107
  # ---------------------------------------------------------------------------
@@ -109,12 +163,10 @@ def query_llm(client: OpenAI, observation: dict) -> dict:
109
  # ---------------------------------------------------------------------------
110
 
111
  def http_base(host: str) -> str:
112
- """Return HTTP base URL (strip trailing slash)."""
113
  return host.rstrip("/")
114
 
115
 
116
  def ws_url(host: str) -> str:
117
- """Convert http(s):// base URL to ws(s):// WebSocket URL."""
118
  base = http_base(host)
119
  if base.startswith("https://"):
120
  return "wss://" + base[len("https://"):] + "/ws"
@@ -123,11 +175,8 @@ def ws_url(host: str) -> str:
123
  return base + "/ws"
124
 
125
 
126
- async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
127
- """
128
- Run one full episode for task_id over a WebSocket connection.
129
- Returns the grader result dict.
130
- """
131
  print(f"\n{'='*52}")
132
  print(f" Task: {task_id.upper()}")
133
  print(f"{'='*52}")
@@ -159,16 +208,12 @@ async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
159
  while not done:
160
  step += 1
161
 
162
- # Get action from LLM (with fallback on parse error)
163
  try:
164
- action = query_llm(client, obs)
165
- # Validate required keys are present
166
- assert "action_type" in action
167
- assert "batch_size" in action
168
- assert "strategy_weights" in action
169
  except Exception as e:
170
- print(f" Step {step}: LLM parse error ({e}), using fallback")
171
- action = FALLBACK_ACTION
172
 
173
  await ws.send(json.dumps({"type": "step", "data": action}))
174
  resp = json.loads(await ws.recv())
@@ -179,7 +224,6 @@ async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
179
 
180
  data = resp["data"]
181
  obs = data["observation"]
182
- # reward is wrapped in {"value": float} per Reward model
183
  raw_reward = data["reward"]
184
  reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
185
  done = data["done"]
@@ -202,7 +246,7 @@ async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
202
  print(f"\n Episode done after {step} steps | total_reward={total_reward:.4f}")
203
  print(f" Final performance: {obs['current_performance']:.4f}")
204
 
205
- # ── grade via HTTP (grader endpoint doesn't need WebSocket) ──────────
206
  r = requests.post(
207
  f"{http_base(host)}/grader",
208
  json={"episode_id": episode_id, "task_id": task_id},
@@ -230,10 +274,10 @@ async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
230
  # Main
231
  # ---------------------------------------------------------------------------
232
 
233
- async def amain(host: str, client: OpenAI) -> None:
234
  results = {}
235
  for task_id in TASKS:
236
- results[task_id] = await run_task_ws(host, client, task_id)
237
 
238
  print(f"\n{'='*52}")
239
  print(" INFERENCE RESULTS SUMMARY")
@@ -257,34 +301,23 @@ def main() -> None:
257
  help="Environment server base URL (http or https)")
258
  args = parser.parse_args()
259
 
 
260
  api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
261
- if not api_key:
262
- print("ERROR: Set HF_TOKEN or OPENAI_API_KEY environment variable.")
263
- sys.exit(1)
264
-
265
- # Normalize base_url: ensure it's non-empty and ends without trailing slash
266
- base_url = (API_BASE_URL or "").strip().rstrip("/") or "https://api.openai.com/v1"
267
-
268
- try:
269
- client = OpenAI(api_key=api_key, base_url=base_url)
270
- except Exception as e:
271
- print(f"WARNING: OpenAI init with base_url failed ({e}), retrying without base_url")
272
- try:
273
- client = OpenAI(api_key=api_key)
274
- except Exception as e2:
275
- print(f"ERROR: Could not initialize LLM client: {e2}")
276
- sys.exit(1)
277
 
278
- # Health check over HTTP
279
  try:
280
- r = requests.get(f"{http_base(args.host)}/health", timeout=10)
281
  r.raise_for_status()
282
  print(f"Connected to {args.host} β€” {r.json()}")
283
  except Exception as e:
284
  print(f"ERROR: Could not reach environment at {args.host}: {e}")
285
  sys.exit(1)
286
 
287
- asyncio.run(amain(args.host, client))
288
 
289
 
290
  if __name__ == "__main__":
 
24
 
25
  import requests
26
  import websockets
 
27
 
28
  # ---------------------------------------------------------------------------
29
  # Config β€” all overridable via environment variables
 
34
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
35
  SEED = 42
36
  TASKS = ["easy", "medium", "hard"]
 
 
 
 
 
37
 
38
  SYSTEM_PROMPT = """You are an intelligent data curation agent.
39
 
 
51
  Respond with ONLY a valid JSON action in this exact format:
52
  {
53
  "action_type": "select_batch",
54
+ "batch_size": <integer 5-20>,
55
  "strategy_weights": {
56
+ "uncertainty": <float 0-1>,
57
+ "diversity": <float 0-1>,
58
+ "random": <float 0-1>
59
  }
60
  }
61
 
62
  Strategy rules:
63
  - Weights are normalized automatically (no need to sum to 1)
64
+ - noise_estimate > 0.2 -> lower uncertainty weight, raise diversity weight
65
+ - noise_estimate > 0.4 -> set uncertainty near 0, maximize diversity
66
+ - diversity_score < 0.5 -> increase diversity weight
67
+ - remaining_budget < 30 -> reduce batch_size to 5
68
  - You may use "action_type": "stop" with batch_size 0 only when
69
  current_performance > 0.65 AND remaining_budget < 20
70
  - Respond with ONLY the JSON object, no explanation, no markdown fences."""
71
 
72
 
73
  # ---------------------------------------------------------------------------
74
+ # Rule-based fallback (used when LLM is unavailable or errors)
75
  # ---------------------------------------------------------------------------
76
 
77
+ def rule_based_action(obs: dict) -> dict:
78
+ """Produce a sensible action from the observation without an LLM."""
79
+ noise = obs.get("noise_estimate", 0.1)
80
+ diversity = obs.get("diversity_score", 1.0)
81
+ budget = obs.get("remaining_budget", 100)
82
+ perf = obs.get("current_performance", 0.5)
83
+ available = obs.get("samples_available", 100)
84
+
85
+ # Batch size: shrink near budget exhaustion
86
+ batch_size = 5 if budget < 30 else 10
87
+
88
+ # Weights: penalize uncertainty when noise is high
89
+ if noise > 0.4:
90
+ u, d, r = 0.05, 0.80, 0.15
91
+ elif noise > 0.2:
92
+ u, d, r = 0.20, 0.60, 0.20
93
+ elif diversity < 0.5:
94
+ u, d, r = 0.30, 0.55, 0.15
95
+ else:
96
+ u, d, r = 0.40, 0.40, 0.20
97
+
98
+ # Early stop if doing well and nearly out of budget
99
+ if perf > 0.65 and budget < 20 and available > 0:
100
+ return {"action_type": "stop", "batch_size": 0,
101
+ "strategy_weights": {"uncertainty": u, "diversity": d, "random": r}}
102
+
103
+ return {
104
+ "action_type": "select_batch",
105
+ "batch_size": batch_size,
106
+ "strategy_weights": {"uncertainty": u, "diversity": d, "random": r},
107
+ }
108
+
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # LLM helper β€” uses requests directly (no openai SDK dependency)
112
+ # ---------------------------------------------------------------------------
113
+
114
+ def query_llm(api_key: str | None, obs: dict) -> dict:
115
+ """
116
+ Call the LLM via plain HTTP (OpenAI-compatible chat/completions endpoint).
117
+ Returns a parsed action dict. Raises on any error so the caller can
118
+ fall back to rule_based_action.
119
+ """
120
+ if not api_key:
121
+ raise ValueError("No API key available")
122
+
123
+ base_url = (API_BASE_URL or "https://api.openai.com/v1").rstrip("/")
124
+ url = f"{base_url}/chat/completions"
125
+
126
  user_msg = (
127
+ f"Current observation:\n{json.dumps(obs, indent=2)}\n\n"
128
  "What action do you take?"
129
  )
130
+ payload = {
131
+ "model": MODEL_NAME,
132
+ "messages": [
133
  {"role": "system", "content": SYSTEM_PROMPT},
134
  {"role": "user", "content": user_msg},
135
  ],
136
+ "temperature": 0.0,
137
+ "max_tokens": 200,
138
+ }
139
+ headers = {
140
+ "Authorization": f"Bearer {api_key}",
141
+ "Content-Type": "application/json",
142
+ }
143
+
144
+ resp = requests.post(url, json=payload, headers=headers, timeout=30)
145
+ resp.raise_for_status()
146
+ raw = resp.json()["choices"][0]["message"]["content"].strip()
147
+
148
  # Strip markdown fences if model wraps JSON
149
  if raw.startswith("```"):
150
  raw = raw.split("```")[1]
151
  if raw.startswith("json"):
152
  raw = raw[4:]
153
+
154
+ action = json.loads(raw.strip())
155
+ assert "action_type" in action
156
+ assert "batch_size" in action
157
+ assert "strategy_weights" in action
158
+ return action
159
 
160
 
161
  # ---------------------------------------------------------------------------
 
163
  # ---------------------------------------------------------------------------
164
 
165
  def http_base(host: str) -> str:
 
166
  return host.rstrip("/")
167
 
168
 
169
  def ws_url(host: str) -> str:
 
170
  base = http_base(host)
171
  if base.startswith("https://"):
172
  return "wss://" + base[len("https://"):] + "/ws"
 
175
  return base + "/ws"
176
 
177
 
178
+ async def run_task_ws(host: str, api_key: str | None, task_id: str) -> dict:
179
+ """Run one full episode for task_id over a WebSocket. Returns grader result."""
 
 
 
180
  print(f"\n{'='*52}")
181
  print(f" Task: {task_id.upper()}")
182
  print(f"{'='*52}")
 
208
  while not done:
209
  step += 1
210
 
211
+ # Try LLM; fall back to rule-based on any failure
212
  try:
213
+ action = query_llm(api_key, obs)
 
 
 
 
214
  except Exception as e:
215
+ print(f" Step {step}: LLM unavailable ({type(e).__name__}), using rule-based")
216
+ action = rule_based_action(obs)
217
 
218
  await ws.send(json.dumps({"type": "step", "data": action}))
219
  resp = json.loads(await ws.recv())
 
224
 
225
  data = resp["data"]
226
  obs = data["observation"]
 
227
  raw_reward = data["reward"]
228
  reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
229
  done = data["done"]
 
246
  print(f"\n Episode done after {step} steps | total_reward={total_reward:.4f}")
247
  print(f" Final performance: {obs['current_performance']:.4f}")
248
 
249
+ # ── grade via HTTP ────────────────────────────────────────────────────
250
  r = requests.post(
251
  f"{http_base(host)}/grader",
252
  json={"episode_id": episode_id, "task_id": task_id},
 
274
  # Main
275
  # ---------------------------------------------------------------------------
276
 
277
+ async def amain(host: str, api_key: str | None) -> None:
278
  results = {}
279
  for task_id in TASKS:
280
+ results[task_id] = await run_task_ws(host, api_key, task_id)
281
 
282
  print(f"\n{'='*52}")
283
  print(" INFERENCE RESULTS SUMMARY")
 
301
  help="Environment server base URL (http or https)")
302
  args = parser.parse_args()
303
 
304
+ # API key is optional β€” rule-based fallback runs without one
305
  api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
306
+ if api_key:
307
+ print(f"LLM API key found ({len(api_key)} chars); will attempt LLM-guided actions.")
308
+ else:
309
+ print("No API key (HF_TOKEN / OPENAI_API_KEY); running rule-based fallback.")
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
+ # Health check β€” environment must be reachable
312
  try:
313
+ r = requests.get(f"{http_base(args.host)}/health", timeout=15)
314
  r.raise_for_status()
315
  print(f"Connected to {args.host} β€” {r.json()}")
316
  except Exception as e:
317
  print(f"ERROR: Could not reach environment at {args.host}: {e}")
318
  sys.exit(1)
319
 
320
+ asyncio.run(amain(args.host, api_key))
321
 
322
 
323
  if __name__ == "__main__":