ar9avg commited on
Commit
e99d0aa
Β·
1 Parent(s): 2014920

Clamp all remaining score leak paths: /state, step_rewards, demo SSE

Browse files

- ep.step_rewards now stores clamped task_score instead of raw RL reward
(compute_reward can return exact 1.0 or negative values)
- /state endpoint returns clamped step_rewards and total_reward
- Removed info.rl_reward from /step response (leaked unclamped reward)
- demo.py SSE chat/benchmark endpoints now clamp task_score
- Added .dockerignore excluding backend/data/ so persisted bandit state
from prior runs can't leak unclamped values into the image

Single source of truth: _clamp_score() in sql_env.py β†’ [0.05, 0.95]

Files changed (3) hide show
  1. .dockerignore +18 -0
  2. backend/api/demo.py +3 -3
  3. backend/env/sql_env.py +31 -13
.dockerignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .git/
6
+ .gitignore
7
+ .venv/
8
+ venv/
9
+ node_modules/
10
+ .DS_Store
11
+ *.log
12
+
13
+ # Exclude persisted bandit/GEPA state so the image starts with a clean slate.
14
+ # Any historical rewards stored on disk could contain unclamped values and
15
+ # leak into validator runs.
16
+ backend/data/rl_experiences.json
17
+ backend/data/gepa_prompt.json
18
+ frontend/node_modules/
backend/api/demo.py CHANGED
@@ -42,7 +42,7 @@ _DIFFICULTY_MAP = {
42
  "hard": "complex_queries",
43
  }
44
  from env.tasks import TASKS, get_task
45
- from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql
46
  from rl.environment import get_bandit_state
47
  from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
48
  from rl.error_classifier import classify_error, extract_offending_token
@@ -275,7 +275,7 @@ async def execute_query_stream(req: ExecuteQueryRequest):
275
 
276
  # For free-form chat, success = no SQL error (not task grader)
277
  attempt_success = (error is None)
278
- task_score = 1.0 if attempt_success else 0.0
279
 
280
  current_error_class = None
281
  error_class_name = None
@@ -544,7 +544,7 @@ async def run_benchmark(req: BenchmarkRequest):
544
  attempt = 0
545
  sql = ""
546
  success = False
547
- task_score = 0.0
548
  max_attempts = env.MAX_ATTEMPTS
549
  ep = env._episode # type: ignore[union-attr]
550
 
 
42
  "hard": "complex_queries",
43
  }
44
  from env.tasks import TASKS, get_task
45
+ from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql, _clamp_score
46
  from rl.environment import get_bandit_state
47
  from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
48
  from rl.error_classifier import classify_error, extract_offending_token
 
275
 
276
  # For free-form chat, success = no SQL error (not task grader)
277
  attempt_success = (error is None)
278
+ task_score = _clamp_score(1.0 if attempt_success else 0.0)
279
 
280
  current_error_class = None
281
  error_class_name = None
 
544
  attempt = 0
545
  sql = ""
546
  success = False
547
+ task_score = _clamp_score(0.0)
548
  max_attempts = env.MAX_ATTEMPTS
549
  ep = env._episode # type: ignore[union-attr]
550
 
backend/env/sql_env.py CHANGED
@@ -72,6 +72,27 @@ _MODEL = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
72
  HF_TOKEN = os.environ.get("HF_TOKEN") # no default β€” must be set explicitly
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def _make_client() -> AsyncOpenAI:
76
  return AsyncOpenAI(
77
  api_key=HF_TOKEN,
@@ -316,7 +337,8 @@ class SQLAgentEnv:
316
  )
317
  ep.steps.append(step_obj)
318
 
319
- ep.step_rewards.append(grader_out.reward)
 
320
  ep.current_sql = generated_sql
321
  ep.error_message = error
322
  ep.error_class = error_class_name
@@ -331,20 +353,14 @@ class SQLAgentEnv:
331
  ep.success = success
332
 
333
  obs = self._build_observation()
 
334
  reward_info = RewardInfo(
335
- value=task_score, # always in (eps, 1-eps) per OpenEnv spec
336
  success=success,
337
  done=done,
338
  info={
339
- "task_score": task_score,
340
- "rl_reward": grader_out.reward,
341
  "attempt": ep.attempt_number,
342
- "breakdown": {
343
- "base": grader_out.breakdown.base,
344
- "attempt_penalty": grader_out.breakdown.attempt_penalty,
345
- "severity_bonus": grader_out.breakdown.severity_bonus,
346
- "change_bonus": grader_out.breakdown.change_bonus,
347
- },
348
  "rows": rows[:5] if rows else [],
349
  "row_count": len(rows),
350
  "sql": generated_sql,
@@ -458,7 +474,7 @@ class SQLAgentEnv:
458
  ep.steps.append(step_obj)
459
  self._bandit.update(ep.current_features, repair_action_enum, grader_out.reward)
460
 
461
- ep.step_rewards.append(grader_out.reward)
462
  ep.current_sql = generated_sql
463
  ep.error_message = error
464
  ep.error_class = error_class_name
@@ -500,6 +516,8 @@ class SQLAgentEnv:
500
  if self._episode is None:
501
  return {"active": False}
502
  ep = self._episode
 
 
503
  return {
504
  "active": True,
505
  "task_id": ep.task_id,
@@ -512,8 +530,8 @@ class SQLAgentEnv:
512
  "error_class": ep.error_class,
513
  "done": ep.done,
514
  "success": ep.success,
515
- "step_rewards": ep.step_rewards,
516
- "total_reward": compute_episode_reward(ep.step_rewards, ep.success),
517
  }
518
 
519
  # ─── Private Helpers ──────────────────────────────────────────
 
72
  HF_TOKEN = os.environ.get("HF_TOKEN") # no default β€” must be set explicitly
73
 
74
 
75
+ # ─── Score clamping (strictly in (0, 1)) ─────────────────────────
76
+
77
+ _SCORE_MIN = 0.05
78
+ _SCORE_MAX = 0.95
79
+
80
+
81
+ def _clamp_score(x) -> float:
82
+ """Coerce any value into strictly (0, 1). None/NaN/invalid β†’ _SCORE_MIN."""
83
+ try:
84
+ if x is None:
85
+ return _SCORE_MIN
86
+ if isinstance(x, bool):
87
+ return _SCORE_MAX if x else _SCORE_MIN
88
+ v = float(x)
89
+ if v != v or v == float("inf") or v == float("-inf"):
90
+ return _SCORE_MIN if v != float("inf") else _SCORE_MAX
91
+ except (TypeError, ValueError):
92
+ return _SCORE_MIN
93
+ return max(_SCORE_MIN, min(_SCORE_MAX, v))
94
+
95
+
96
  def _make_client() -> AsyncOpenAI:
97
  return AsyncOpenAI(
98
  api_key=HF_TOKEN,
 
337
  )
338
  ep.steps.append(step_obj)
339
 
340
+ # Store clamped reward so /state never returns raw RL values
341
+ ep.step_rewards.append(_clamp_score(task_score))
342
  ep.current_sql = generated_sql
343
  ep.error_message = error
344
  ep.error_class = error_class_name
 
353
  ep.success = success
354
 
355
  obs = self._build_observation()
356
+ safe_task_score = _clamp_score(task_score)
357
  reward_info = RewardInfo(
358
+ value=safe_task_score, # strictly in (0, 1) per OpenEnv spec
359
  success=success,
360
  done=done,
361
  info={
362
+ "task_score": safe_task_score,
 
363
  "attempt": ep.attempt_number,
 
 
 
 
 
 
364
  "rows": rows[:5] if rows else [],
365
  "row_count": len(rows),
366
  "sql": generated_sql,
 
474
  ep.steps.append(step_obj)
475
  self._bandit.update(ep.current_features, repair_action_enum, grader_out.reward)
476
 
477
+ ep.step_rewards.append(_clamp_score(task_score))
478
  ep.current_sql = generated_sql
479
  ep.error_message = error
480
  ep.error_class = error_class_name
 
516
  if self._episode is None:
517
  return {"active": False}
518
  ep = self._episode
519
+ safe_rewards = [_clamp_score(r) for r in ep.step_rewards]
520
+ total = sum(safe_rewards) / max(len(safe_rewards), 1) if safe_rewards else _SCORE_MIN
521
  return {
522
  "active": True,
523
  "task_id": ep.task_id,
 
530
  "error_class": ep.error_class,
531
  "done": ep.done,
532
  "success": ep.success,
533
+ "step_rewards": safe_rewards,
534
+ "total_reward": _clamp_score(total),
535
  }
536
 
537
  # ─── Private Helpers ──────────────────────────────────────────