ar9avg commited on
Commit
2014920
Β·
1 Parent(s): 98b87b7

Bulletproof _safe_score for all bad inputs (None, NaN, strings, bool)

Browse files

log_step and log_end now route every reward/score through _safe_score
which handles: None, NaN, inf, empty string, non-numeric string, bool,
negative, >1. All map to the closed range [0.05, 0.95] β€” strictly in (0, 1).

Also added catch-alls in run_episode and main() so if the env or LLM
client crashes at any point, every task still emits a valid
[START]/[STEP]/[END] block with score in (0, 1).

Files changed (1) hide show
  1. inference.py +80 -39
inference.py CHANGED
@@ -82,25 +82,58 @@ SYSTEM_PROMPT = textwrap.dedent("""
82
 
83
  # ── Logging ───────────────────────────────────────────────────────────────────
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def log_start(task: str, model: str) -> None:
86
  print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
87
 
88
 
89
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
90
- error_val = error.replace("\n", " ").strip() if error else "null"
91
- done_val = str(done).lower()
 
 
 
92
  print(
93
- f"[STEP] step={step} action={action} reward={reward:.2f} "
94
  f"done={done_val} error={error_val}",
95
  flush=True,
96
  )
97
 
98
 
99
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
100
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
 
 
 
 
101
  print(
102
- f"[END] success={str(success).lower()} steps={steps} "
103
- f"score={score:.3f} rewards={rewards_str}",
104
  flush=True,
105
  )
106
 
@@ -152,16 +185,6 @@ def pick_action(
152
 
153
  # ── Single-episode runner ─────────────────────────────────────────────────────
154
 
155
- _SCORE_EPS = 0.05 # strict (0, 1) with generous margin for :.2f/:.3f rounding
156
-
157
-
158
- def _clamp_score(x: float) -> float:
159
- """Clamp to strictly (0, 1). Uses 0.05 margin so :.2f/:.3f formatting stays safe."""
160
- if x != x: # NaN
161
- return 0.5
162
- return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, x))
163
-
164
-
165
  async def run_episode(
166
  env: SQLAgentEnv,
167
  client: OpenAI,
@@ -172,7 +195,7 @@ async def run_episode(
172
 
173
  rewards: List[float] = []
174
  steps_taken = 0
175
- score = _SCORE_EPS
176
  success = False
177
  last_error: Optional[str] = None
178
 
@@ -180,28 +203,30 @@ async def run_episode(
180
  try:
181
  obs = env.reset(task_id)
182
  except Exception as exc:
183
- log_step(step=1, action="reset", reward=_SCORE_EPS, done=True, error=str(exc))
184
- rewards.append(_SCORE_EPS)
185
  steps_taken = 1
186
  return
187
 
188
  for step in range(1, MAX_STEPS + 1):
189
- action_name = pick_action(client, obs, step)
 
 
 
190
  action = Action(repair_action=action_name)
191
 
192
  try:
193
  obs, reward_info = await env.step(action)
194
  except Exception as exc:
195
- log_step(step=step, action=action_name, reward=_SCORE_EPS, done=True, error=str(exc))
196
- rewards.append(_SCORE_EPS)
197
  steps_taken = step
198
  break
199
 
200
- raw_reward = reward_info.value if reward_info.value is not None else _SCORE_EPS
201
- reward = _clamp_score(raw_reward)
202
- done = reward_info.done
203
- last_error = obs.error_message
204
- success = reward_info.success
205
 
206
  rewards.append(reward)
207
  steps_taken = step
@@ -218,16 +243,19 @@ async def run_episode(
218
  break
219
 
220
  denom = max(len(rewards), 1)
221
- avg = sum(rewards) / denom if rewards else _SCORE_EPS
222
- score = _clamp_score(avg)
223
 
 
 
 
 
 
 
224
  finally:
225
- # Final safety net: score and every reward must be strictly in (0, 1)
226
- score = _clamp_score(score)
227
- rewards = [_clamp_score(r) for r in rewards]
228
  log_end(
229
  success=success,
230
- steps=steps_taken,
231
  score=score,
232
  rewards=rewards,
233
  )
@@ -236,12 +264,25 @@ async def run_episode(
236
  # ── Main ──────────────────────────────────────────────────────────────────────
237
 
238
  async def main() -> None:
239
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
240
- env = SQLAgentEnv()
 
 
 
 
 
 
 
 
 
241
 
242
  for task_id in TASKS:
243
- await run_episode(env, client, task_id)
244
- # Small gap between tasks for readability
 
 
 
 
245
  print("", flush=True)
246
 
247
 
 
82
 
83
  # ── Logging ───────────────────────────────────────────────────────────────────
84
 
85
+ # Hard bounds: every score/reward we ever emit is clamped to this closed range.
86
+ # 0.05 margin guarantees that :.2f and :.3f formatting never produces
87
+ # "0.00", "0.000", "1.00", or "1.000" (all of which parse as exactly 0.0 / 1.0).
88
+ _MIN_SCORE = 0.05
89
+ _MAX_SCORE = 0.95
90
+
91
+
92
+ def _safe_score(x) -> float:
93
+ """Coerce anything (None, NaN, str, bool, int, float) to a float strictly in (0, 1)."""
94
+ try:
95
+ if x is None:
96
+ return _MIN_SCORE
97
+ if isinstance(x, bool):
98
+ return _MAX_SCORE if x else _MIN_SCORE
99
+ v = float(x)
100
+ if v != v: # NaN
101
+ return _MIN_SCORE
102
+ if v == float("inf"):
103
+ return _MAX_SCORE
104
+ if v == float("-inf"):
105
+ return _MIN_SCORE
106
+ except (TypeError, ValueError):
107
+ return _MIN_SCORE
108
+ return max(_MIN_SCORE, min(_MAX_SCORE, v))
109
+
110
+
111
  def log_start(task: str, model: str) -> None:
112
  print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
113
 
114
 
115
+ def log_step(step: int, action: str, reward, done: bool, error: Optional[str]) -> None:
116
+ r = _safe_score(reward)
117
+ error_val = (error or "null")
118
+ if hasattr(error_val, "replace"):
119
+ error_val = error_val.replace("\n", " ").strip() or "null"
120
+ done_val = str(bool(done)).lower()
121
  print(
122
+ f"[STEP] step={int(step)} action={action or 'noop'} reward={r:.2f} "
123
  f"done={done_val} error={error_val}",
124
  flush=True,
125
  )
126
 
127
 
128
+ def log_end(success: bool, steps: int, score, rewards: List) -> None:
129
+ s = _safe_score(score)
130
+ safe_rewards = [_safe_score(r) for r in (rewards or [])]
131
+ if not safe_rewards:
132
+ safe_rewards = [_MIN_SCORE]
133
+ rewards_str = ",".join(f"{r:.2f}" for r in safe_rewards)
134
  print(
135
+ f"[END] success={str(bool(success)).lower()} steps={int(steps)} "
136
+ f"score={s:.3f} rewards={rewards_str}",
137
  flush=True,
138
  )
139
 
 
185
 
186
  # ── Single-episode runner ─────────────────────────────────────────────────────
187
 
 
 
 
 
 
 
 
 
 
 
188
  async def run_episode(
189
  env: SQLAgentEnv,
190
  client: OpenAI,
 
195
 
196
  rewards: List[float] = []
197
  steps_taken = 0
198
+ score = _MIN_SCORE
199
  success = False
200
  last_error: Optional[str] = None
201
 
 
203
  try:
204
  obs = env.reset(task_id)
205
  except Exception as exc:
206
+ log_step(step=1, action="reset", reward=_MIN_SCORE, done=True, error=str(exc))
207
+ rewards.append(_MIN_SCORE)
208
  steps_taken = 1
209
  return
210
 
211
  for step in range(1, MAX_STEPS + 1):
212
+ try:
213
+ action_name = pick_action(client, obs, step)
214
+ except Exception:
215
+ action_name = "generate"
216
  action = Action(repair_action=action_name)
217
 
218
  try:
219
  obs, reward_info = await env.step(action)
220
  except Exception as exc:
221
+ log_step(step=step, action=action_name, reward=_MIN_SCORE, done=True, error=str(exc))
222
+ rewards.append(_MIN_SCORE)
223
  steps_taken = step
224
  break
225
 
226
+ reward = _safe_score(getattr(reward_info, "value", None))
227
+ done = bool(getattr(reward_info, "done", False))
228
+ last_error = getattr(obs, "error_message", None)
229
+ success = bool(getattr(reward_info, "success", False))
 
230
 
231
  rewards.append(reward)
232
  steps_taken = step
 
243
  break
244
 
245
  denom = max(len(rewards), 1)
246
+ avg = sum(rewards) / denom if rewards else _MIN_SCORE
247
+ score = _safe_score(avg)
248
 
249
+ except Exception as exc:
250
+ # Catch-all so we always emit a valid [END] line
251
+ log_step(step=steps_taken or 1, action="error", reward=_MIN_SCORE, done=True, error=str(exc))
252
+ if not rewards:
253
+ rewards.append(_MIN_SCORE)
254
+ score = _MIN_SCORE
255
  finally:
 
 
 
256
  log_end(
257
  success=success,
258
+ steps=max(int(steps_taken), 1),
259
  score=score,
260
  rewards=rewards,
261
  )
 
264
  # ── Main ──────────────────────────────────────────────────────────────────────
265
 
266
  async def main() -> None:
267
+ try:
268
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
269
+ env = SQLAgentEnv()
270
+ except Exception as exc:
271
+ # Environment couldn't init β€” still emit a valid [START]/[STEP]/[END] per task
272
+ for task_id in TASKS:
273
+ log_start(task=task_id, model=MODEL_NAME)
274
+ log_step(step=1, action="init_error", reward=_MIN_SCORE, done=True, error=str(exc))
275
+ log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
276
+ print("", flush=True)
277
+ return
278
 
279
  for task_id in TASKS:
280
+ try:
281
+ await run_episode(env, client, task_id)
282
+ except Exception as exc:
283
+ # run_episode already has its own catch-all, but guard against anything leaking
284
+ log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
285
+ print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", flush=True)
286
  print("", flush=True)
287
 
288