"""
check_leak.py — Validates BudgetRouterGRPOEnv before GRPO training.
Checks:
1. Tool methods return strings (not crash).
2. Episode ends gracefully via ValueError (TRL-idiomatic done signal).
3. Reward is a float in [0, 1] — not a dict, not NaN.
4. History uses actual_degradation_start (jittered) — NOT the config constant.
This proves grade_episode() will compute correct adaptation windows.
5. 10-episode reward trajectory printed: verify no explosion/vanishing.
6. Provider status IS present in tool responses (intentional — text interface needs it).
Run:
uv run python check_leak.py
"""
import sys
def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` unless *condition* is true.

    Used instead of a bare ``assert`` so that the validation checks are
    still enforced when the interpreter runs with ``-O`` (which strips
    ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def _parse_major_minor(version: str) -> "tuple[int, int] | None":
    """Return ``(major, minor)`` parsed from the start of *version*.

    Returns ``None`` when the string does not begin with ``<digits>.<digits>``,
    so callers can degrade to a warning instead of crashing on unusual
    version strings (e.g. ``"4"`` or ``"4.47rc1"`` with naive ``int()`` parsing).
    """
    import re

    match = re.match(r"\s*(\d+)\.(\d+)", version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def main() -> None:
    """Run the pre-training validation checks for ``BudgetRouterGRPOEnv``.

    Prints a report for each check. Calls ``sys.exit(1)`` when the project
    modules cannot be imported, and raises ``AssertionError`` on the first
    failed check. A missing or old ``transformers`` install only produces
    a warning (Check 0 is a soft check).
    """
    import math
    import statistics

    try:
        from train.grpo_env import BudgetRouterGRPOEnv
        from budget_router.reward import grade_episode
        from budget_router.tasks import HARD_MULTI
    except ImportError as e:
        print(f"[FAIL] Import error: {e}")
        sys.exit(1)

    print("=" * 60)
    print("BudgetRouterGRPOEnv — Pre-training Validation")
    print("=" * 60)

    # ── Check 0: transformers version (soft warning — required for environment_factory) ──
    print("\n[CHECK 0] transformers version (required for environment_factory)...")
    try:
        import transformers
    except ImportError:
        print(
            " ⚠️ WARNING: transformers is NOT installed in this venv.\n"
            " Install before GRPO training: pip install 'transformers>=4.47.0' trl accelerate peft"
        )
    else:
        ver_str = transformers.__version__
        # TRL's environment_factory requires transformers >= 4.47.0 (confirmed shipping in
        # stable builds as of Apr 2026). Exact minimum threshold is version-specific to TRL.
        # If not installed, training will fail at import time — caught here early.
        print(f" ✅ transformers=={ver_str} installed.")
        # Soft check: warn if below 4.47 (minimum known to ship environment_factory support).
        parsed = _parse_major_minor(ver_str)
        if parsed is None:
            # Unparseable version string: keep this a soft check rather than crashing.
            print(f" ⚠️ WARNING: could not parse transformers version {ver_str!r}; skipping minimum-version check.")
        elif parsed < (4, 47):
            print(
                f" ⚠️ WARNING: transformers {ver_str} may be too old for environment_factory.\n"
                f" Recommended: pip install 'transformers>=4.47.0' or install from main."
            )

    # ── Check 1: reset() returns a non-empty string ─────────────────────
    print("\n[CHECK 1] reset() returns rich text observation...")
    env = BudgetRouterGRPOEnv()
    obs_text = env.reset(scenario="hard_multi", seed=42)
    _require(
        isinstance(obs_text, str) and len(obs_text) > 10,
        f"reset() should return non-empty string, got: {obs_text!r}",
    )
    _require("Budget" in obs_text, "reset() should mention Budget")
    _require(
        "Provider" in obs_text,
        "reset() should include provider status (text interface, not sanitized)",
    )
    print(f" ✅ reset() returned {len(obs_text)} chars. Provider status PRESENT (correct for text interface).")
    print(f" Preview: {obs_text[:120].replace(chr(10), ' ')}...")

    # ── Check 2: Tool methods return strings step-by-step ───────────────
    print("\n[CHECK 2] Tool methods return strings and accumulate history...")
    env2 = BudgetRouterGRPOEnv()
    env2.reset(scenario="hard_multi", seed=42)
    # Fixed rotation of tool calls used as a simple scripted policy.
    actions = [env2.route_to_a, env2.route_to_b, env2.shed_load, env2.route_to_b]
    step_results = []
    episode_done = False
    for step in range(25):  # more than max_steps to test guard
        action_fn = actions[step % len(actions)]
        try:
            result = action_fn()
        except ValueError as e:
            episode_done = True
            print(f" Step {step + 1:02d}: ✅ Episode ended via ValueError (TRL-idiomatic): {str(e)[:80]}...")
            break
        _require(isinstance(result, str), f"Tool method should return str, got {type(result)}")
        step_results.append(result)
        print(f" Step {step + 1:02d}: ✅ {result[:80].replace(chr(10), ' ')}...")
    _require(episode_done, "Episode should end with ValueError before step 25")
    _require(len(step_results) > 0, "At least one tool step should complete")
    print(f" ✅ Episode ended correctly after {len(step_results)} tool calls.")

    # ── Check 3: Reward is float in [0, 1] ──────────────────────────────
    print("\n[CHECK 3] Reward is float in [0, 1]...")
    _require(
        isinstance(env2.reward, float),
        f"env.reward should be float, got {type(env2.reward)}: {env2.reward!r}",
    )
    # NaN must be tested before the range check: NaN fails every comparison,
    # so the range check would otherwise mask the more specific message.
    _require(not math.isnan(env2.reward), "env.reward is NaN — grade_episode bug")
    _require(
        0.0 <= env2.reward <= 1.0,
        f"env.reward should be in [0, 1], got {env2.reward}",
    )
    print(f" ✅ env.reward = {env2.reward:.4f} (float, in [0,1], not NaN)")

    # ── Check 4: History uses actual jittered degradation_start_step ────
    print("\n[CHECK 4] History contains jittered actual_degradation_start (not config constant)...")
    history = env2._env._internal.history
    _require(len(history) > 0, "History should not be empty after episode")
    # Read degradation_start_step from step_info (written by environment.py)
    step_info_degrade_start = history[0].get("degradation_start_step")
    # Read the actual jittered value from internal state
    actual_jittered_start = env2._env._internal.actual_degradation_start
    # Config constant for hard_multi
    config_constant = HARD_MULTI.degradation_start_step  # = 0
    print(f" Config constant (degradation_start_step): {config_constant}")
    print(f" step_info[degradation_start_step]: {step_info_degrade_start}")
    print(f" internal.actual_degradation_start: {actual_jittered_start}")
    _require(
        step_info_degrade_start is not None,
        "step_info missing degradation_start_step — grade_episode() will break",
    )
    _require(
        step_info_degrade_start == actual_jittered_start,
        f"step_info uses wrong degradation onset! "
        f"Got {step_info_degrade_start}, expected {actual_jittered_start}. "
        f"This would corrupt adaptation scores in grade_episode().",
    )
    print(" ✅ Jittered onset correctly propagated through step_info.")

    # ── Check 5: grade_episode() on history returns consistent score ─────
    print("\n[CHECK 5] grade_episode(history) matches env.reward...")
    grader_result = grade_episode(history)
    _require(isinstance(grader_result, dict), "grade_episode should return dict")
    grader_score = float(grader_result["overall_score"])
    _require(
        abs(grader_score - env2.reward) < 1e-6,
        f"env.reward ({env2.reward}) != grade_episode score ({grader_score}). Mismatch.",
    )
    print(f" ✅ grade_episode overall_score = {grader_score:.4f}, env.reward = {env2.reward:.4f}. Match confirmed.")

    # ── Check 6: 10-episode reward trajectory ────────────────────────────
    print("\n[CHECK 6] 10-episode reward trajectory (hard_multi, varying seeds)...")
    print(" Episode | Seed | Steps | Score | Reward-in-range")
    rewards = []
    out_of_range = []  # (seed, reward) pairs that violate [0, 1]
    for ep, seed in enumerate(range(10), start=1):
        env3 = BudgetRouterGRPOEnv()
        env3.reset(scenario="hard_multi", seed=seed)
        done = False
        steps = 0
        while not done and steps < 30:
            # Alternate actions: A, B, A, B... (simple test policy)
            action_fn = env3.route_to_a if steps % 2 == 0 else env3.route_to_b
            try:
                action_fn()
                steps += 1
            except ValueError:
                done = True
        reward = env3.reward
        rewards.append(reward)
        in_bounds = 0.0 <= reward <= 1.0
        if not in_bounds:
            out_of_range.append((seed, reward))
        in_range = "✅" if in_bounds else "❌"
        print(f" Ep {ep:02d} | {seed:4d} | {steps:5d} | {reward:.4f} | {in_range}")
    # Fail after printing the full table so every offending episode is visible;
    # previously a ❌ row still led to "ALL CHECKS PASSED".
    _require(
        not out_of_range,
        f"Episode rewards outside [0, 1] (seed, reward): {out_of_range}",
    )

    if len(rewards) > 1:
        std = statistics.stdev(rewards)
        mean = statistics.mean(rewards)
        print(f"\n Mean reward: {mean:.4f} | Std: {std:.4f}")
        if std < 0.03:
            print(
                f" ⚠️ WARNING: Low reward variance (std={std:.4f}). GRPO may get weak gradient signal.\n"
                f" Mitigation: Use num_generations=8, hard_multi scenario, and a small LLM\n"
                f" at initialization that makes diverse routing decisions."
            )
        else:
            print(f" ✅ Reward variance is sufficient for GRPO learning (std={std:.4f} > 0.03).")

    print("\n" + "=" * 60)
    print("✅ ALL CHECKS PASSED — BudgetRouterGRPOEnv is ready for GRPO training.")
    print("=" * 60)
    print("\nRecommended training config (Mac MPS / Colab):")
    print(" scenario: hard_multi")
    print(" num_generations: 8")
    print(" model: Qwen2.5-1.5B (Mac 16GB) / Qwen2.5-7B (Colab T4)")
    print(" Mac: TRL + PyTorch MPS (set PYTORCH_ENABLE_MPS_FALLBACK=1)")
    print(" Colab: Unsloth + vLLM on NVIDIA T4/A100")
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|