"""
check_leak.py — Validates BudgetRouterGRPOEnv before GRPO training.

Checks:
  1. Tool methods return strings (not crash).
  2. Episode ends gracefully via ValueError (TRL-idiomatic done signal).
  3. Reward is a float in [0, 1] — not a dict, not NaN.
  4. History uses actual_degradation_start (jittered) — NOT the config constant.
     This proves grade_episode() will compute correct adaptation windows.
  5. 10-step reward trajectory printed: verify no explosion/vanishing.
  6. Provider status IS present in tool responses (intentional — text interface needs it).

Run:
    uv run python check_leak.py
"""

import math
import statistics
import sys


def main() -> None:
    try:
        from train.grpo_env import BudgetRouterGRPOEnv
        from budget_router.reward import grade_episode
        from budget_router.tasks import HARD_MULTI
    except ImportError as e:
        print(f"[FAIL] Import error: {e}")
        sys.exit(1)

    print("=" * 60)
    print("BudgetRouterGRPOEnv — Pre-training Validation")
    print("=" * 60)

    # ── Check 0: transformers version (soft warning — required for environment_factory) ──
    print("\n[CHECK 0] transformers version (required for environment_factory)...")
    try:
        import transformers
        ver_str = transformers.__version__
        # TRL's environment_factory requires transformers >= 4.47.0 (confirmed
        # shipping in stable builds as of Apr 2026); the exact minimum depends
        # on the TRL version. If transformers is missing entirely, training
        # fails at import time — caught here early.
        print(f"  ✅ transformers=={ver_str} installed.")
        # Soft check: warn if below 4.47 (minimum known to ship environment_factory support)
        major, minor = (int(part) for part in ver_str.split(".")[:2])
        if major < 4 or (major == 4 and minor < 47):
            print(
                f"  ⚠️  WARNING: transformers {ver_str} may be too old for environment_factory.\n"
                f"     Recommended: pip install 'transformers>=4.47.0' or install from main."
            )
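        # A more robust comparison, assuming the `packaging` package (normally
        # present alongside pip) is importable: packaging handles prerelease
        # tags like "4.48.0.dev0" that the manual split above only approximates.
        try:
            from packaging.version import Version

            if Version(ver_str) < Version("4.47.0"):
                print(f"  (packaging agrees: {ver_str} sorts below 4.47.0)")
        except ImportError:
            pass  # optional cross-check only; the split-based check stands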
    except ImportError:
        print(
            "  ⚠️  WARNING: transformers is NOT installed in this venv.\n"
            "     Install before GRPO training: pip install 'transformers>=4.47.0' trl accelerate peft"
        )

    # ── Check 1: reset() returns a non-empty string ─────────────────────
    print("\n[CHECK 1] reset() returns rich text observation...")

    env = BudgetRouterGRPOEnv()
    obs_text = env.reset(scenario="hard_multi", seed=42)
    assert isinstance(obs_text, str) and len(obs_text) > 10, \
        f"reset() should return non-empty string, got: {obs_text!r}"
    assert "Budget" in obs_text, "reset() should mention Budget"
    assert "Provider" in obs_text, "reset() should include provider status (text interface, not sanitized)"
    print(f"  ✅ reset() returned {len(obs_text)} chars. Provider status PRESENT (correct for text interface).")
    print(f"  Preview: {obs_text[:120].replace(chr(10), ' ')}...")

    # ── Check 2: Tool methods return strings step-by-step ───────────────
    print("\n[CHECK 2] Tool methods return strings and accumulate history...")
    env2 = BudgetRouterGRPOEnv()
    env2.reset(scenario="hard_multi", seed=42)

    step_results = []
    episode_done = False
    for step in range(25):  # more than max_steps to test guard
        action_fn = [env2.route_to_a, env2.route_to_b, env2.shed_load, env2.route_to_b][step % 4]
        try:
            result = action_fn()
            assert isinstance(result, str), f"Tool method should return str, got {type(result)}"
            step_results.append(result)
            print(f"  Step {step + 1:02d}: ✅ {result[:80].replace(chr(10), ' ')}...")
        except ValueError as e:
            episode_done = True
            print(f"  Step {step + 1:02d}: ✅ Episode ended via ValueError (TRL-idiomatic): {str(e)[:80]}...")
            break

    assert episode_done, "Episode should end with ValueError before step 25"
    assert len(step_results) > 0, "At least one tool step should complete"
    print(f"  ✅ Episode ended correctly after {len(step_results)} tool calls.")

    # ── Check 3: Reward is float in [0, 1] ──────────────────────────────
    print("\n[CHECK 3] Reward is float in [0, 1]...")
    assert isinstance(env2.reward, float), \
        f"env.reward should be float, got {type(env2.reward)}: {env2.reward!r}"
    assert 0.0 <= env2.reward <= 1.0, \
        f"env.reward should be in [0, 1], got {env2.reward}"
    assert not math.isnan(env2.reward), "env.reward is NaN — grade_episode bug"
    print(f"  ✅ env.reward = {env2.reward:.4f} (float, in [0,1], not NaN)")

    # ── Check 4: History uses actual jittered degradation_start_step ────
    print("\n[CHECK 4] History contains jittered actual_degradation_start (not config constant)...")
    history = env2._env._internal.history
    assert len(history) > 0, "History should not be empty after episode"

    # Read degradation_start_step from step_info (written by environment.py)
    step_info_degrade_start = history[0].get("degradation_start_step")
    # Read the actual jittered value from internal state
    actual_jittered_start = env2._env._internal.actual_degradation_start
    # Config constant for hard_multi
    config_constant = HARD_MULTI.degradation_start_step  # = 0

    print(f"  Config constant (degradation_start_step): {config_constant}")
    print(f"  step_info[degradation_start_step]: {step_info_degrade_start}")
    print(f"  internal.actual_degradation_start: {actual_jittered_start}")

    assert step_info_degrade_start is not None, \
        "step_info missing degradation_start_step — grade_episode() will break"
    assert step_info_degrade_start == actual_jittered_start, \
        (f"step_info uses wrong degradation onset! "
         f"Got {step_info_degrade_start}, expected {actual_jittered_start}. "
         f"This would corrupt adaptation scores in grade_episode().")
    print(f"  ✅ Jittered onset correctly propagated through step_info.")

    # ── Check 5: grade_episode() on history returns consistent score ─────
    print("\n[CHECK 5] grade_episode(history) matches env.reward...")
    grader_result = grade_episode(history)
    assert isinstance(grader_result, dict), "grade_episode should return dict"
    grader_score = float(grader_result["overall_score"])
    assert abs(grader_score - env2.reward) < 1e-6, \
        f"env.reward ({env2.reward}) != grade_episode score ({grader_score}). Mismatch."
    print(f"  ✅ grade_episode overall_score = {grader_score:.4f}, env.reward = {env2.reward:.4f}. Match confirmed.")

    # ── Check 6: 10-episode reward trajectory ────────────────────────────
    print("\n[CHECK 6] 10-episode reward trajectory (hard_multi, varying seeds)...")
    print("  Episode | Seed | Steps | Score | Reward-in-range")
    rewards = []
    for ep, seed in enumerate(range(10)):
        env3 = BudgetRouterGRPOEnv()
        env3.reset(scenario="hard_multi", seed=seed)
        done = False
        steps = 0
        while not done and steps < 30:
            # Alternate actions: A, B, A, B... (simple test policy)
            action_fn = env3.route_to_a if steps % 2 == 0 else env3.route_to_b
            try:
                action_fn()
                steps += 1
            except ValueError:
                done = True
        reward = env3.reward
        rewards.append(reward)
        in_range = "✅" if 0.0 <= reward <= 1.0 else "❌"
        print(f"  Ep {ep+1:02d}     | {seed:4d} | {steps:5d} | {reward:.4f} | {in_range}")

    if len(rewards) > 1:
        std = statistics.stdev(rewards)
        mean = statistics.mean(rewards)
        print(f"\n  Mean reward: {mean:.4f} | Std: {std:.4f}")
        if std < 0.03:
            print(
                f"  ⚠️  WARNING: Low reward variance (std={std:.4f}). GRPO may get a weak gradient signal.\n"
                f"     Mitigation: use num_generations=8, the hard_multi scenario, and a base\n"
                f"     model whose initial routing decisions are diverse."
            )
        else:
            print(f"  ✅ Reward variance is sufficient for GRPO learning (std={std:.4f} > 0.03).")

    print("\n" + "=" * 60)
    print("✅ ALL CHECKS PASSED — BudgetRouterGRPOEnv is ready for GRPO training.")
    print("=" * 60)
    print("\nRecommended training config (Mac MPS / Colab):")
    print("  scenario: hard_multi")
    print("  num_generations: 8")
    print("  model: Qwen2.5-1.5B (Mac 16GB) / Qwen2.5-7B (Colab T4)")
    print("  Mac: TRL + PyTorch MPS (set PYTORCH_ENABLE_MPS_FALLBACK=1)")
    print("  Colab: Unsloth + vLLM on NVIDIA T4/A100")


if __name__ == "__main__":
    main()