Spaces:
Sleeping
Sleeping
Commit Β·
c314a65
1
Parent(s): 3948a09
Enforce strict (0.001, 0.999) bounds on ALL rewards and scores
Browse files- rewards.py: fix early returns for invalid/timeout that skipped clamp
- env_core.py: align _REWARD_MIN/MAX to 0.001/0.999
- graders.py: align _SCORE_MIN/MAX to 0.001/0.999
- inference.py: align _clamp01 to 0.001/0.999
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
backend/src/polypharmacy_env/env_core.py
CHANGED
|
@@ -25,6 +25,15 @@ from .models import (
|
|
| 25 |
PolypharmacyState,
|
| 26 |
)
|
| 27 |
from .rewards import compute_regimen_risk, compute_shaped_reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from .tasks import get_task_config, sample_episode
|
| 29 |
|
| 30 |
|
|
@@ -123,10 +132,10 @@ class PolypharmacyEnv(
|
|
| 123 |
# Validate basic action structure
|
| 124 |
valid, err = self._validate_action(action)
|
| 125 |
if not valid:
|
| 126 |
-
reward = compute_shaped_reward(
|
| 127 |
self._current_risk, self._current_risk,
|
| 128 |
action.action_type, is_invalid=True,
|
| 129 |
-
)
|
| 130 |
info["error"] = err
|
| 131 |
self._step_count += 1
|
| 132 |
return self._check_timeout_and_build_obs(reward, info)
|
|
@@ -140,7 +149,7 @@ class PolypharmacyEnv(
|
|
| 140 |
elif action.action_type == "finish_review":
|
| 141 |
self._done = True
|
| 142 |
score = self._run_grader()
|
| 143 |
-
reward = score # terminal bonus
|
| 144 |
info["grader_score"] = score
|
| 145 |
|
| 146 |
self._step_count += 1
|
|
@@ -187,10 +196,10 @@ class PolypharmacyEnv(
|
|
| 187 |
assert action.drug_id_1 and action.drug_id_2
|
| 188 |
|
| 189 |
if self._remaining_query_budget <= 0:
|
| 190 |
-
reward = compute_shaped_reward(
|
| 191 |
self._current_risk, self._current_risk,
|
| 192 |
"query_ddi", is_invalid=True,
|
| 193 |
-
)
|
| 194 |
info["error"] = "Query budget exhausted"
|
| 195 |
return reward, info
|
| 196 |
|
|
@@ -210,12 +219,12 @@ class PolypharmacyEnv(
|
|
| 210 |
if discovered_severe:
|
| 211 |
self._severe_moderate_discovered += 1
|
| 212 |
|
| 213 |
-
reward = compute_shaped_reward(
|
| 214 |
self._current_risk, self._current_risk,
|
| 215 |
"query_ddi",
|
| 216 |
discovered_severe=(result.severity == "severe"),
|
| 217 |
discovered_moderate=(result.severity == "moderate"),
|
| 218 |
-
)
|
| 219 |
info["ddi_result"] = {
|
| 220 |
"severity": result.severity,
|
| 221 |
"recommendation": result.recommendation,
|
|
@@ -229,10 +238,10 @@ class PolypharmacyEnv(
|
|
| 229 |
assert action.intervention_type and action.intervention_type != "none"
|
| 230 |
|
| 231 |
if self._remaining_intervention_budget <= 0:
|
| 232 |
-
reward = compute_shaped_reward(
|
| 233 |
self._current_risk, self._current_risk,
|
| 234 |
"propose_intervention", is_invalid=True,
|
| 235 |
-
)
|
| 236 |
info["error"] = "Intervention budget exhausted"
|
| 237 |
return reward, info
|
| 238 |
|
|
@@ -244,10 +253,10 @@ class PolypharmacyEnv(
|
|
| 244 |
break
|
| 245 |
|
| 246 |
if target_idx is None:
|
| 247 |
-
reward = compute_shaped_reward(
|
| 248 |
self._current_risk, self._current_risk,
|
| 249 |
"propose_intervention", is_invalid=True,
|
| 250 |
-
)
|
| 251 |
info["error"] = f"Drug {action.target_drug_id} not in current medications"
|
| 252 |
return reward, info
|
| 253 |
|
|
@@ -321,7 +330,7 @@ class PolypharmacyEnv(
|
|
| 321 |
step_index=self._step_count,
|
| 322 |
))
|
| 323 |
|
| 324 |
-
reward = compute_shaped_reward(previous_risk, self._current_risk, "propose_intervention")
|
| 325 |
info["risk_delta"] = risk_delta
|
| 326 |
return reward, info
|
| 327 |
|
|
@@ -354,7 +363,7 @@ class PolypharmacyEnv(
|
|
| 354 |
self._total_drug_changes,
|
| 355 |
self._critical_stopped_without_sub,
|
| 356 |
)
|
| 357 |
-
return 0.001 # strict (0,
|
| 358 |
|
| 359 |
def _get_severe_pairs(self) -> List[Tuple[str, str]]:
|
| 360 |
"""Return all severe DDI pairs present in the *initial* medication list."""
|
|
@@ -377,8 +386,8 @@ class PolypharmacyEnv(
|
|
| 377 |
if not self._done and self._step_count >= self._task_cfg.max_steps:
|
| 378 |
self._done = True
|
| 379 |
score = self._run_grader()
|
| 380 |
-
# Terminal reward = grader score only (strictly in
|
| 381 |
-
reward = score
|
| 382 |
info["timeout"] = True
|
| 383 |
info["grader_score"] = score
|
| 384 |
|
|
|
|
| 25 |
PolypharmacyState,
|
| 26 |
)
|
| 27 |
from .rewards import compute_regimen_risk, compute_shaped_reward
|
| 28 |
+
|
| 29 |
+
# ── Reward clamping ──────────────────────────────────────────────────────────
|
| 30 |
+
_REWARD_MIN = 0.001
|
| 31 |
+
_REWARD_MAX = 0.999
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _clamp_reward(v: float) -> float:
|
| 35 |
+
"""Clamp any reward to strict (0.001, 0.999) bounds."""
|
| 36 |
+
return max(_REWARD_MIN, min(_REWARD_MAX, v))
|
| 37 |
from .tasks import get_task_config, sample_episode
|
| 38 |
|
| 39 |
|
|
|
|
| 132 |
# Validate basic action structure
|
| 133 |
valid, err = self._validate_action(action)
|
| 134 |
if not valid:
|
| 135 |
+
reward = _clamp_reward(compute_shaped_reward(
|
| 136 |
self._current_risk, self._current_risk,
|
| 137 |
action.action_type, is_invalid=True,
|
| 138 |
+
))
|
| 139 |
info["error"] = err
|
| 140 |
self._step_count += 1
|
| 141 |
return self._check_timeout_and_build_obs(reward, info)
|
|
|
|
| 149 |
elif action.action_type == "finish_review":
|
| 150 |
self._done = True
|
| 151 |
score = self._run_grader()
|
| 152 |
+
reward = _clamp_reward(score) # terminal bonus
|
| 153 |
info["grader_score"] = score
|
| 154 |
|
| 155 |
self._step_count += 1
|
|
|
|
| 196 |
assert action.drug_id_1 and action.drug_id_2
|
| 197 |
|
| 198 |
if self._remaining_query_budget <= 0:
|
| 199 |
+
reward = _clamp_reward(compute_shaped_reward(
|
| 200 |
self._current_risk, self._current_risk,
|
| 201 |
"query_ddi", is_invalid=True,
|
| 202 |
+
))
|
| 203 |
info["error"] = "Query budget exhausted"
|
| 204 |
return reward, info
|
| 205 |
|
|
|
|
| 219 |
if discovered_severe:
|
| 220 |
self._severe_moderate_discovered += 1
|
| 221 |
|
| 222 |
+
reward = _clamp_reward(compute_shaped_reward(
|
| 223 |
self._current_risk, self._current_risk,
|
| 224 |
"query_ddi",
|
| 225 |
discovered_severe=(result.severity == "severe"),
|
| 226 |
discovered_moderate=(result.severity == "moderate"),
|
| 227 |
+
))
|
| 228 |
info["ddi_result"] = {
|
| 229 |
"severity": result.severity,
|
| 230 |
"recommendation": result.recommendation,
|
|
|
|
| 238 |
assert action.intervention_type and action.intervention_type != "none"
|
| 239 |
|
| 240 |
if self._remaining_intervention_budget <= 0:
|
| 241 |
+
reward = _clamp_reward(compute_shaped_reward(
|
| 242 |
self._current_risk, self._current_risk,
|
| 243 |
"propose_intervention", is_invalid=True,
|
| 244 |
+
))
|
| 245 |
info["error"] = "Intervention budget exhausted"
|
| 246 |
return reward, info
|
| 247 |
|
|
|
|
| 253 |
break
|
| 254 |
|
| 255 |
if target_idx is None:
|
| 256 |
+
reward = _clamp_reward(compute_shaped_reward(
|
| 257 |
self._current_risk, self._current_risk,
|
| 258 |
"propose_intervention", is_invalid=True,
|
| 259 |
+
))
|
| 260 |
info["error"] = f"Drug {action.target_drug_id} not in current medications"
|
| 261 |
return reward, info
|
| 262 |
|
|
|
|
| 330 |
step_index=self._step_count,
|
| 331 |
))
|
| 332 |
|
| 333 |
+
reward = _clamp_reward(compute_shaped_reward(previous_risk, self._current_risk, "propose_intervention"))
|
| 334 |
info["risk_delta"] = risk_delta
|
| 335 |
return reward, info
|
| 336 |
|
|
|
|
| 363 |
self._total_drug_changes,
|
| 364 |
self._critical_stopped_without_sub,
|
| 365 |
)
|
| 366 |
+
return 0.001 # strict (0.001, 0.999) range required
|
| 367 |
|
| 368 |
def _get_severe_pairs(self) -> List[Tuple[str, str]]:
|
| 369 |
"""Return all severe DDI pairs present in the *initial* medication list."""
|
|
|
|
| 386 |
if not self._done and self._step_count >= self._task_cfg.max_steps:
|
| 387 |
self._done = True
|
| 388 |
score = self._run_grader()
|
| 389 |
+
# Terminal reward = grader score only (strictly in (0.001, 0.999))
|
| 390 |
+
reward = _clamp_reward(score)
|
| 391 |
info["timeout"] = True
|
| 392 |
info["grader_score"] = score
|
| 393 |
|
backend/src/polypharmacy_env/graders.py
CHANGED
|
@@ -12,7 +12,7 @@ from .models import InterventionRecord
|
|
| 12 |
|
| 13 |
_EPS = 1e-8
|
| 14 |
|
| 15 |
-
# Scores must be strictly in (0,
|
| 16 |
_SCORE_MIN = 0.001
|
| 17 |
_SCORE_MAX = 0.999
|
| 18 |
|
|
|
|
| 12 |
|
| 13 |
_EPS = 1e-8
|
| 14 |
|
| 15 |
+
# Scores must be strictly in (0.001, 0.999) — never outside this range
|
| 16 |
_SCORE_MIN = 0.001
|
| 17 |
_SCORE_MAX = 0.999
|
| 18 |
|
backend/src/polypharmacy_env/rewards.py
CHANGED
|
@@ -75,22 +75,20 @@ def compute_shaped_reward(
|
|
| 75 |
reward = 0.0
|
| 76 |
|
| 77 |
if is_invalid:
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if action_type == "query_ddi":
|
| 84 |
reward -= QUERY_COST
|
| 85 |
if discovered_severe:
|
| 86 |
reward += SEVERE_DDI_DISCOVERY_BONUS
|
| 87 |
elif discovered_moderate:
|
| 88 |
reward += MODERATE_DDI_DISCOVERY_BONUS
|
| 89 |
-
|
| 90 |
elif action_type == "propose_intervention":
|
| 91 |
reward += (previous_risk - new_risk)
|
| 92 |
reward -= INTERVENTION_COST
|
| 93 |
|
| 94 |
# finish_review terminal bonus is added by the caller after grading
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
| 75 |
reward = 0.0
|
| 76 |
|
| 77 |
if is_invalid:
|
| 78 |
+
reward = -INVALID_ACTION_PENALTY
|
| 79 |
+
elif is_timeout:
|
| 80 |
+
reward = -TIMEOUT_PENALTY
|
| 81 |
+
elif action_type == "query_ddi":
|
|
|
|
|
|
|
| 82 |
reward -= QUERY_COST
|
| 83 |
if discovered_severe:
|
| 84 |
reward += SEVERE_DDI_DISCOVERY_BONUS
|
| 85 |
elif discovered_moderate:
|
| 86 |
reward += MODERATE_DDI_DISCOVERY_BONUS
|
|
|
|
| 87 |
elif action_type == "propose_intervention":
|
| 88 |
reward += (previous_risk - new_risk)
|
| 89 |
reward -= INTERVENTION_COST
|
| 90 |
|
| 91 |
# finish_review terminal bonus is added by the caller after grading
|
| 92 |
|
| 93 |
+
# Clamp all rewards to strict (0.001, 0.999) range
|
| 94 |
+
return max(0.001, min(0.999, reward))
|
inference.py
CHANGED
|
@@ -68,7 +68,7 @@ def _fmt_reward(v: float) -> str:
|
|
| 68 |
|
| 69 |
|
| 70 |
def _clamp01(v: float) -> float:
|
| 71 |
-
"""Clamp score to strict (0,
|
| 72 |
return max(0.001, min(0.999, float(v)))
|
| 73 |
|
| 74 |
|
|
@@ -179,8 +179,8 @@ def _reset(task_id: str) -> Dict[str, Any]:
|
|
| 179 |
def _step(action: Dict[str, Any]) -> Dict[str, Any]:
|
| 180 |
r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=45)
|
| 181 |
if r.status_code == 422:
|
| 182 |
-
# Invalid action → return a penalty and let the agent continue
|
| 183 |
-
return {"observation": {}, "reward":
|
| 184 |
r.raise_for_status()
|
| 185 |
return r.json()
|
| 186 |
|
|
@@ -189,7 +189,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
|
|
| 189 |
rewards: List[float] = []
|
| 190 |
steps = 0
|
| 191 |
success = False
|
| 192 |
-
score = 0.001 # strict (0,
|
| 193 |
log_start(task_id)
|
| 194 |
try:
|
| 195 |
reset_payload = _reset(task_id)
|
|
@@ -203,7 +203,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
|
|
| 203 |
action_str = json.dumps(action, separators=(",", ":"))
|
| 204 |
step_payload = _step(action)
|
| 205 |
obs = step_payload.get("observation", {})
|
| 206 |
-
reward = float(step_payload.get("reward") or 0.0)
|
| 207 |
done = bool(step_payload.get("done", False))
|
| 208 |
metadata = (obs or {}).get("metadata", {}) or {}
|
| 209 |
last_error = metadata.get("error")
|
|
@@ -217,7 +217,7 @@ def run_task(client: OpenAI, task_id: str) -> None:
|
|
| 217 |
score = _clamp01(float(raw_score))
|
| 218 |
else:
|
| 219 |
score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
|
| 220 |
-
success = score > 0.
|
| 221 |
break
|
| 222 |
except Exception:
|
| 223 |
success = False
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def _clamp01(v: float) -> float:
|
| 71 |
+
"""Clamp score to strict (0.001, 0.999) β never outside this range."""
|
| 72 |
return max(0.001, min(0.999, float(v)))
|
| 73 |
|
| 74 |
|
|
|
|
| 179 |
def _step(action: Dict[str, Any]) -> Dict[str, Any]:
|
| 180 |
r = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=45)
|
| 181 |
if r.status_code == 422:
|
| 182 |
+
# Invalid action → return a clamped penalty and let the agent continue
|
| 183 |
+
return {"observation": {}, "reward": 0.01, "done": False, "info": {"error": r.text[:200]}}
|
| 184 |
r.raise_for_status()
|
| 185 |
return r.json()
|
| 186 |
|
|
|
|
| 189 |
rewards: List[float] = []
|
| 190 |
steps = 0
|
| 191 |
success = False
|
| 192 |
+
score = 0.001  # strict (0.001, 0.999) — never outside this range
|
| 193 |
log_start(task_id)
|
| 194 |
try:
|
| 195 |
reset_payload = _reset(task_id)
|
|
|
|
| 203 |
action_str = json.dumps(action, separators=(",", ":"))
|
| 204 |
step_payload = _step(action)
|
| 205 |
obs = step_payload.get("observation", {})
|
| 206 |
+
reward = _clamp01(float(step_payload.get("reward") or 0.0))
|
| 207 |
done = bool(step_payload.get("done", False))
|
| 208 |
metadata = (obs or {}).get("metadata", {}) or {}
|
| 209 |
last_error = metadata.get("error")
|
|
|
|
| 217 |
score = _clamp01(float(raw_score))
|
| 218 |
else:
|
| 219 |
score = _clamp01(sum(max(0.0, r) for r in rewards) / max(len(rewards), 1))
|
| 220 |
+
success = score > 0.001
|
| 221 |
break
|
| 222 |
except Exception:
|
| 223 |
success = False
|