TheJackBright Claude Opus 4.6 committed on
Commit
c314a65
·
1 Parent(s): 3948a09

Enforce strict (0.001, 0.999) bounds on ALL rewards and scores

Browse files

- rewards.py: fix early returns for invalid/timeout that skipped clamp
- env_core.py: align _REWARD_MIN/MAX to 0.001/0.999
- graders.py: align _SCORE_MIN/MAX to 0.001/0.999
- inference.py: align _clamp01 to 0.001/0.999

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

backend/src/polypharmacy_env/env_core.py CHANGED
@@ -25,6 +25,15 @@ from .models import (
25
  PolypharmacyState,
26
  )
27
  from .rewards import compute_regimen_risk, compute_shaped_reward
 
 
 
 
 
 
 
 
 
28
  from .tasks import get_task_config, sample_episode
29
 
30
 
@@ -123,10 +132,10 @@ class PolypharmacyEnv(
123
  # Validate basic action structure
124
  valid, err = self._validate_action(action)
125
  if not valid:
126
- reward = compute_shaped_reward(
127
  self._current_risk, self._current_risk,
128
  action.action_type, is_invalid=True,
129
- )
130
  info["error"] = err
131
  self._step_count += 1
132
  return self._check_timeout_and_build_obs(reward, info)
@@ -140,7 +149,7 @@ class PolypharmacyEnv(
140
  elif action.action_type == "finish_review":
141
  self._done = True
142
  score = self._run_grader()
143
- reward = score # terminal bonus
144
  info["grader_score"] = score
145
 
146
  self._step_count += 1
@@ -187,10 +196,10 @@ class PolypharmacyEnv(
187
  assert action.drug_id_1 and action.drug_id_2
188
 
189
  if self._remaining_query_budget <= 0:
190
- reward = compute_shaped_reward(
191
  self._current_risk, self._current_risk,
192
  "query_ddi", is_invalid=True,
193
- )
194
  info["error"] = "Query budget exhausted"
195
  return reward, info
196
 
@@ -210,12 +219,12 @@ class PolypharmacyEnv(
210
  if discovered_severe:
211
  self._severe_moderate_discovered += 1
212
 
213
- reward = compute_shaped_reward(
214
  self._current_risk, self._current_risk,
215
  "query_ddi",
216
  discovered_severe=(result.severity == "severe"),
217
  discovered_moderate=(result.severity == "moderate"),
218
- )
219
  info["ddi_result"] = {
220
  "severity": result.severity,
221
  "recommendation": result.recommendation,
@@ -229,10 +238,10 @@ class PolypharmacyEnv(
229
  assert action.intervention_type and action.intervention_type != "none"
230
 
231
  if self._remaining_intervention_budget <= 0:
232
- reward = compute_shaped_reward(
233
  self._current_risk, self._current_risk,
234
  "propose_intervention", is_invalid=True,
235
- )
236
  info["error"] = "Intervention budget exhausted"
237
  return reward, info
238
 
@@ -244,10 +253,10 @@ class PolypharmacyEnv(
244
  break
245
 
246
  if target_idx is None:
247
- reward = compute_shaped_reward(
248
  self._current_risk, self._current_risk,
249
  "propose_intervention", is_invalid=True,
250
- )
251
  info["error"] = f"Drug {action.target_drug_id} not in current medications"
252
  return reward, info
253
 
@@ -321,7 +330,7 @@ class PolypharmacyEnv(
321
  step_index=self._step_count,
322
  ))
323
 
324
- reward = compute_shaped_reward(previous_risk, self._current_risk, "propose_intervention")
325
  info["risk_delta"] = risk_delta
326
  return reward, info
327
 
@@ -354,7 +363,7 @@ class PolypharmacyEnv(
354
  self._total_drug_changes,
355
  self._critical_stopped_without_sub,
356
  )
357
- return 0.001 # strict (0, 1) range required
358
 
359
  def _get_severe_pairs(self) -> List[Tuple[str, str]]:
360
  """Return all severe DDI pairs present in the *initial* medication list."""
@@ -377,8 +386,8 @@ class PolypharmacyEnv(
377
  if not self._done and self._step_count >= self._task_cfg.max_steps:
378
  self._done = True
379
  score = self._run_grader()
380
- # Terminal reward = grader score only (strictly in (0, 1))
381
- reward = score
382
  info["timeout"] = True
383
  info["grader_score"] = score
384
 
 
25
  PolypharmacyState,
26
  )
27
  from .rewards import compute_regimen_risk, compute_shaped_reward
28
+
29
+ # ── Reward clamping ──────────────────────────────────────────────────────────
30
+ _REWARD_MIN = 0.001
31
+ _REWARD_MAX = 0.999
32
+
33
+
34
+ def _clamp_reward(v: float) -> float:
35
+ """Clamp any reward to strict (0.001, 0.999) bounds."""
36
+ return max(_REWARD_MIN, min(_REWARD_MAX, v))
37
  from .tasks import get_task_config, sample_episode
38
 
39
 
 
132
  # Validate basic action structure
133
  valid, err = self._validate_action(action)
134
  if not valid:
135
+ reward = _clamp_reward(compute_shaped_reward(
136
  self._current_risk, self._current_risk,
137
  action.action_type, is_invalid=True,
138
+ ))
139
  info["error"] = err
140
  self._step_count += 1
141
  return self._check_timeout_and_build_obs(reward, info)
 
149
  elif action.action_type == "finish_review":
150
  self._done = True
151
  score = self._run_grader()
152
+ reward = _clamp_reward(score) # terminal bonus
153
  info["grader_score"] = score
154
 
155
  self._step_count += 1
 
196
  assert action.drug_id_1 and action.drug_id_2
197
 
198
  if self._remaining_query_budget <= 0:
199
+ reward = _clamp_reward(compute_shaped_reward(
200
  self._current_risk, self._current_risk,
201
  "query_ddi", is_invalid=True,
202
+ ))
203
  info["error"] = "Query budget exhausted"
204
  return reward, info
205
 
 
219
  if discovered_severe:
220
  self._severe_moderate_discovered += 1
221
 
222
+ reward = _clamp_reward(compute_shaped_reward(
223
  self._current_risk, self._current_risk,
224
  "query_ddi",
225
  discovered_severe=(result.severity == "severe"),
226
  discovered_moderate=(result.severity == "moderate"),
227
+ ))
228
  info["ddi_result"] = {
229
  "severity": result.severity,
230
  "recommendation": result.recommendation,
 
238
  assert action.intervention_type and action.intervention_type != "none"
239
 
240
  if self._remaining_intervention_budget <= 0:
241
+ reward = _clamp_reward(compute_shaped_reward(
242
  self._current_risk, self._current_risk,
243
  "propose_intervention", is_invalid=True,
244
+ ))
245
  info["error"] = "Intervention budget exhausted"
246
  return reward, info
247
 
 
253
  break
254
 
255
  if target_idx is None:
256
+ reward = _clamp_reward(compute_shaped_reward(
257
  self._current_risk, self._current_risk,
258
  "propose_intervention", is_invalid=True,
259
+ ))
260
  info["error"] = f"Drug {action.target_drug_id} not in current medications"
261
  return reward, info
262
 
 
330
  step_index=self._step_count,
331
  ))
332
 
333
+ reward = _clamp_reward(compute_shaped_reward(previous_risk, self._current_risk, "propose_intervention"))
334
  info["risk_delta"] = risk_delta
335
  return reward, info
336
 
 
363
  self._total_drug_changes,
364
  self._critical_stopped_without_sub,
365
  )
366
+ return 0.001 # strict (0.001, 0.999) range required
367
 
368
  def _get_severe_pairs(self) -> List[Tuple[str, str]]:
369
  """Return all severe DDI pairs present in the *initial* medication list."""
 
386
  if not self._done and self._step_count >= self._task_cfg.max_steps:
387
  self._done = True
388
  score = self._run_grader()
389
+ # Terminal reward = grader score only (clamped to [0.001, 0.999])
390
+ reward = _clamp_reward(score)
391
  info["timeout"] = True
392
  info["grader_score"] = score
393
 
backend/src/polypharmacy_env/graders.py CHANGED
@@ -12,7 +12,7 @@ from .models import InterventionRecord
12
 
13
  _EPS = 1e-8
14
 
15
- # Scores must be strictly in (0, 1) β€” never exactly 0.0 or 1.0
16
  _SCORE_MIN = 0.001
17
  _SCORE_MAX = 0.999
18
 
 
12
 
13
  _EPS = 1e-8
14
 
15
+ # Scores must be strictly in (0.001, 0.999) β€” never outside this range
16
  _SCORE_MIN = 0.001
17
  _SCORE_MAX = 0.999
18
 
backend/src/polypharmacy_env/rewards.py CHANGED
@@ -75,22 +75,20 @@ def compute_shaped_reward(
75
  reward = 0.0
76
 
77
  if is_invalid:
78
- return -INVALID_ACTION_PENALTY
79
-
80
- if is_timeout:
81
- return -TIMEOUT_PENALTY
82
-
83
- if action_type == "query_ddi":
84
  reward -= QUERY_COST
85
  if discovered_severe:
86
  reward += SEVERE_DDI_DISCOVERY_BONUS
87
  elif discovered_moderate:
88
  reward += MODERATE_DDI_DISCOVERY_BONUS
89
-
90
  elif action_type == "propose_intervention":
91
  reward += (previous_risk - new_risk)
92
  reward -= INTERVENTION_COST
93
 
94
  # finish_review terminal bonus is added by the caller after grading
95
 
96
- return max(0.001, reward)
 
 
75
  reward = 0.0
76
 
77
  if is_invalid:
78
+ reward = -INVALID_ACTION_PENALTY
79
+ elif is_timeout:
80
+ reward = -TIMEOUT_PENALTY
81
+ elif action_type == "query_ddi":
 
 
82
  reward -= QUERY_COST
83
  if discovered_severe:
84
  reward += SEVERE_DDI_DISCOVERY_BONUS
85
  elif discovered_moderate:
86
  reward += MODERATE_DDI_DISCOVERY_BONUS
 
87
  elif action_type == "propose_intervention":
88
  reward += (previous_risk - new_risk)
89
  reward -= INTERVENTION_COST
90
 
91
  # finish_review terminal bonus is added by the caller after grading
92
 
93
+ # Clamp all rewards to strict (0.001, 0.999) range
94
+ return max(0.001, min(0.999, reward))
inference.py CHANGED
@@ -68,7 +68,7 @@ def _fmt_reward(v: float) -> str:
68
 
69
 
70
  def _clamp01(v: float) -> float:
71
- """Clamp score to strict (0, 1) β€” never exactly 0.0 or 1.0."""
72
  return max(0.001, min(0.999, float(v)))
73
 
74
 
@@ -179,8 +179,8 @@ def _reset(task_id: str) -> Dict[str, Any]:
179
  def _step(action: Dict[str, Any]) -> Dict[str, Any]:
180
  r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=45)
181
  if r.status_code == 422:
182
- # Invalid action β€” return a penalty and let the agent continue
183
- return {"observation": {}, "reward": -0.1, "done": False, "info": {"error": r.text[:200]}}
184
  r.raise_for_status()
185
  return r.json()
186
 
@@ -189,7 +189,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
189
  rewards: List[float] = []
190
  steps = 0
191
  success = False
192
- score = 0.001 # strict (0, 1) β€” never exactly 0.0
193
  log_start(task_id)
194
  try:
195
  reset_payload = _reset(task_id)
@@ -203,7 +203,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
203
  action_str = json.dumps(action, separators=(",", ":"))
204
  step_payload = _step(action)
205
  obs = step_payload.get("observation", {})
206
- reward = float(step_payload.get("reward") or 0.0)
207
  done = bool(step_payload.get("done", False))
208
  metadata = (obs or {}).get("metadata", {}) or {}
209
  last_error = metadata.get("error")
@@ -217,7 +217,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
217
  score = _clamp01(float(raw_score))
218
  else:
219
  score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
220
- success = score > 0.0
221
  break
222
  except Exception:
223
  success = False
 
68
 
69
 
70
  def _clamp01(v: float) -> float:
71
+ """Clamp score to strict (0.001, 0.999) β€” never outside this range."""
72
  return max(0.001, min(0.999, float(v)))
73
 
74
 
 
179
  def _step(action: Dict[str, Any]) -> Dict[str, Any]:
180
  r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=45)
181
  if r.status_code == 422:
182
+ # Invalid action β€” return a clamped penalty and let the agent continue
183
+ return {"observation": {}, "reward": 0.01, "done": False, "info": {"error": r.text[:200]}}
184
  r.raise_for_status()
185
  return r.json()
186
 
 
189
  rewards: List[float] = []
190
  steps = 0
191
  success = False
192
+ score = 0.001 # strict (0.001, 0.999) β€” never outside this range
193
  log_start(task_id)
194
  try:
195
  reset_payload = _reset(task_id)
 
203
  action_str = json.dumps(action, separators=(",", ":"))
204
  step_payload = _step(action)
205
  obs = step_payload.get("observation", {})
206
+ reward = _clamp01(float(step_payload.get("reward") or 0.0))
207
  done = bool(step_payload.get("done", False))
208
  metadata = (obs or {}).get("metadata", {}) or {}
209
  last_error = metadata.get("error")
 
217
  score = _clamp01(float(raw_score))
218
  else:
219
  score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
220
+ success = score > 0.001
221
  break
222
  except Exception:
223
  success = False