"""
check_leak.py — Validates BudgetRouterGRPOEnv before GRPO training.
Checks:
0. transformers version is new enough for TRL's environment_factory (soft warning).
1. Tool methods return strings (not crash).
2. Episode ends gracefully via ValueError (TRL-idiomatic done signal).
3. Reward is a float in [0, 1] — not a dict, not NaN.
4. History uses actual_degradation_start (jittered) — NOT the config constant.
This proves grade_episode() will compute correct adaptation windows.
5. 10-episode reward trajectory printed: verify no explosion/vanishing.
6. Provider status IS present in tool responses (intentional — text interface needs it).
Run:
uv run python check_leak.py
"""
import math
import statistics
import sys
def main() -> None:
    try:
        from train.grpo_env import BudgetRouterGRPOEnv
        from budget_router.reward import grade_episode
        from budget_router.tasks import HARD_MULTI
    except ImportError as e:
        print(f"[FAIL] Import error: {e}")
        sys.exit(1)
print("=" * 60)
print("BudgetRouterGRPOEnv — Pre-training Validation")
print("=" * 60)
# ── Check 0: transformers version (soft warning — required for environment_factory) ──
print("\n[CHECK 0] transformers version (required for environment_factory)...")
try:
import transformers
ver_str = transformers.__version__
# TRL's environment_factory requires transformers >= 4.47.0 (confirmed shipping in
# stable builds as of Apr 2026). Exact minimum threshold is version-specific to TRL.
# If not installed, training will fail at import time — caught here early.
print(f" ✅ transformers=={ver_str} installed.")
# Soft check: warn if below 4.47 (minimum known to ship environment_factory support)
major, minor = int(ver_str.split(".")[0]), int(ver_str.split(".")[1])
if major < 4 or (major == 4 and minor < 47):
print(
f" ⚠️ WARNING: transformers {ver_str} may be too old for environment_factory.\n"
f" Recommended: pip install 'transformers>=4.47.0' or install from main."
)
except ImportError:
print(
" ⚠️ WARNING: transformers is NOT installed in this venv.\n"
" Install before GRPO training: pip install 'transformers>=4.47.0' trl accelerate peft"
)
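    # Optional, more robust alternative to the manual major/minor split above:
    # packaging's Version handles any PEP 440 version string, including
    # pre-releases (a sketch — requires the `packaging` package, which is
    # usually already installed alongside setuptools):
    #
    #     from packaging.version import Version
    #     if Version(transformers.__version__) < Version("4.47.0"):
    #         ...  # emit the same warning as above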
    # ── Check 1: reset() returns a non-empty string ─────────────────────
    print("\n[CHECK 1] reset() returns rich text observation...")
    env = BudgetRouterGRPOEnv()
    obs_text = env.reset(scenario="hard_multi", seed=42)
    assert isinstance(obs_text, str) and len(obs_text) > 10, \
        f"reset() should return non-empty string, got: {obs_text!r}"
    assert "Budget" in obs_text, "reset() should mention Budget"
    assert "Provider" in obs_text, "reset() should include provider status (text interface, not sanitized)"
    print(f" ✅ reset() returned {len(obs_text)} chars. Provider status PRESENT (correct for text interface).")
    print(f" Preview: {obs_text[:120].replace(chr(10), ' ')}...")
    # ── Check 2: Tool methods return strings step-by-step ───────────────
    print("\n[CHECK 2] Tool methods return strings and accumulate history...")
    env2 = BudgetRouterGRPOEnv()
    env2.reset(scenario="hard_multi", seed=42)
    step_results = []
    episode_done = False
    for step in range(25):  # more than max_steps to test the guard
        action_fn = [env2.route_to_a, env2.route_to_b, env2.shed_load, env2.route_to_b][step % 4]
        try:
            result = action_fn()
            assert isinstance(result, str), f"Tool method should return str, got {type(result)}"
            step_results.append(result)
            print(f" Step {step + 1:02d}: ✅ {result[:80].replace(chr(10), ' ')}...")
        except ValueError as e:
            episode_done = True
            print(f" Step {step + 1:02d}: ✅ Episode ended via ValueError (TRL-idiomatic): {str(e)[:80]}...")
            break
    assert episode_done, "Episode should end with ValueError before step 25"
    assert len(step_results) > 0, "At least one tool step should complete"
    print(f" ✅ Episode ended correctly after {len(step_results)} tool calls.")
    # ── Check 3: Reward is float in [0, 1] ──────────────────────────────
    print("\n[CHECK 3] Reward is float in [0, 1]...")
    assert isinstance(env2.reward, float), \
        f"env.reward should be float, got {type(env2.reward)}: {env2.reward!r}"
    assert 0.0 <= env2.reward <= 1.0, \
        f"env.reward should be in [0, 1], got {env2.reward}"
    assert not math.isnan(env2.reward), "env.reward is NaN — grade_episode bug"
    print(f" ✅ env.reward = {env2.reward:.4f} (float, in [0,1], not NaN)")
    # ── Check 4: History uses actual jittered degradation_start_step ────
    print("\n[CHECK 4] History contains jittered actual_degradation_start (not config constant)...")
    history = env2._env._internal.history
    assert len(history) > 0, "History should not be empty after episode"
    # Read degradation_start_step from step_info (written by environment.py)
    step_info_degrade_start = history[0].get("degradation_start_step")
    # Read the actual jittered value from internal state
    actual_jittered_start = env2._env._internal.actual_degradation_start
    # Config constant for hard_multi
    config_constant = HARD_MULTI.degradation_start_step  # = 0
    print(f" Config constant (degradation_start_step): {config_constant}")
    print(f" step_info[degradation_start_step]: {step_info_degrade_start}")
    print(f" internal.actual_degradation_start: {actual_jittered_start}")
    assert step_info_degrade_start is not None, \
        "step_info missing degradation_start_step — grade_episode() will break"
    assert step_info_degrade_start == actual_jittered_start, \
        (f"step_info uses wrong degradation onset! "
         f"Got {step_info_degrade_start}, expected {actual_jittered_start}. "
         f"This would corrupt adaptation scores in grade_episode().")
    print(" ✅ Jittered onset correctly propagated through step_info.")
    # ── Check 5: grade_episode() on history returns consistent score ─────
    print("\n[CHECK 5] grade_episode(history) matches env.reward...")
    grader_result = grade_episode(history)
    assert isinstance(grader_result, dict), "grade_episode should return dict"
    grader_score = float(grader_result["overall_score"])
    assert abs(grader_score - env2.reward) < 1e-6, \
        f"env.reward ({env2.reward}) != grade_episode score ({grader_score}). Mismatch."
    print(f" ✅ grade_episode overall_score = {grader_score:.4f}, env.reward = {env2.reward:.4f}. Match confirmed.")
    # ── Check 6: 10-episode reward trajectory ────────────────────────────
    print("\n[CHECK 6] 10-episode reward trajectory (hard_multi, varying seeds)...")
    print(" Episode | Seed | Steps | Score | Reward-in-range")
    rewards = []
    for ep, seed in enumerate(range(10)):
        env3 = BudgetRouterGRPOEnv()
        env3.reset(scenario="hard_multi", seed=seed)
        done = False
        steps = 0
        while not done and steps < 30:
            # Alternate actions: A, B, A, B... (simple test policy)
            action_fn = env3.route_to_a if steps % 2 == 0 else env3.route_to_b
            try:
                action_fn()
                steps += 1
            except ValueError:
                done = True
        reward = env3.reward
        rewards.append(reward)
        in_range = "✅" if 0.0 <= reward <= 1.0 else "❌"
        print(f" Ep {ep + 1:02d} | {seed:4d} | {steps:5d} | {reward:.4f} | {in_range}")
    if len(rewards) > 1:
        std = statistics.stdev(rewards)
        mean = statistics.mean(rewards)
        print(f"\n Mean reward: {mean:.4f} | Std: {std:.4f}")
        if std < 0.03:
            print(
                f" ⚠️ WARNING: Low reward variance (std={std:.4f}). GRPO may get a weak gradient signal.\n"
                f" Mitigation: use num_generations=8, the hard_multi scenario, and a base model\n"
                f" that makes diverse routing decisions at initialization."
            )
        else:
            print(f" ✅ Reward variance is sufficient for GRPO learning (std={std:.4f} >= 0.03).")
print("\n" + "=" * 60)
print("✅ ALL CHECKS PASSED — BudgetRouterGRPOEnv is ready for GRPO training.")
print("=" * 60)
print("\nRecommended training config (Mac MPS / Colab):")
print(" scenario: hard_multi")
print(" num_generations: 8")
print(" model: Qwen2.5-1.5B (Mac 16GB) / Qwen2.5-7B (Colab T4)")
print(" Mac: TRL + PyTorch MPS (set PYTORCH_ENABLE_MPS_FALLBACK=1)")
print(" Colab: Unsloth + vLLM on NVIDIA T4/A100")
if __name__ == "__main__":
    main()