Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """End-to-end smoke tests for the meta-RL training pipeline (no GPU). | |
| Validates that dataset generation, the LLM-output parser, and all four | |
| reward functions agree with each other before any GPU training spend. | |
| The most important check is test_no_reward_layer_is_constant_across_kinds — | |
| it catches the iter-1 mode-collapse class of bug (a reward layer that | |
| returns the same value for every completion in a GRPO group contributes | |
| zero to the advantage). | |
| """ | |
| import random | |
| import pytest | |
| from models import ActionType | |
| from server.rhythm_environment import sample_profile, profile_to_belief_vector | |
| from training.dataset import generate_dataset | |
| from training.reward_functions import ( | |
| action_legal, | |
| belief_accuracy, | |
| env_reward, | |
| extract_action_and_belief, | |
| format_valid, | |
| ) | |
@pytest.fixture(scope="module")
def small_dataset():
    """20 episodes, ~80 samples max — enough variety for reward checks.

    Module-scoped so the dataset is generated once and shared by every test
    that requests it; none of the tests in this file mutate it.
    """
    return generate_dataset(
        num_episodes=20,
        strategy="mixed",
        max_samples=80,
        hint_fraction=0.1,
    )
def _synth_completion(seed: int, kind: str) -> str:
    """Synthesize a completion of a given quality.

    The action and three random belief digits are always drawn first from a
    ``random.Random(seed)``, so for a given seed every kind uses the same
    action token.

    Args:
        seed: Seeds the local RNG and, for "perfect"/"wrong_belief", the
            profile passed to ``sample_profile``.
        kind: One of
            "perfect"      — belief digits are the true belief vector times 9,
                             rounded (assumes the vector lies in [0, 1]);
            "good"         — well-formed random belief digits plus the action;
            "action_only"  — the action token alone, no belief;
            "wrong_belief" — belief digits from (1 - true) times 9, rounded;
            "garbage"      — free text containing no action.

    Returns:
        ``"S M W ACTION"`` for belief-carrying kinds, the bare action for
        "action_only", or a fixed sentence for "garbage".

    Raises:
        ValueError: If ``kind`` is not one of the names above.
    """
    rng = random.Random(seed)
    action_str = rng.choice(list(ActionType)).value.upper()
    s, m, w = rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)

    if kind in ("perfect", "wrong_belief"):
        truth = profile_to_belief_vector(sample_profile(seed))
        dims = (truth[0], truth[1], truth[2])
        if kind == "wrong_belief":
            # Mirror each dimension around 0.5 so the belief points away
            # from the true profile.
            dims = tuple(1 - d for d in dims)
        s, m, w = (round(d * 9) for d in dims)
        return f"{s} {m} {w} {action_str}"
    if kind == "good":
        return f"{s} {m} {w} {action_str}"
    if kind == "action_only":
        return action_str
    if kind == "garbage":
        return "I don't know what to do here"
    # Fail loudly: a silent fallback would let a mistyped kind masquerade as
    # "action_only" and make discrimination tests compare the wrong things.
    raise ValueError(f"unknown completion kind: {kind!r}")
| # --------------------------------------------------------------------------- | |
| # Dataset shape | |
| # --------------------------------------------------------------------------- | |
| def test_dataset_is_non_empty(small_dataset): | |
| assert len(small_dataset) > 0 | |
| def test_dataset_rows_have_required_replay_columns(small_dataset): | |
| expected = {"prompt", "seed", "step_index", "action_history", "profile_mode"} | |
| for row in small_dataset: | |
| missing = expected - row.keys() | |
| assert not missing, f"row missing columns: {missing}" | |
| def test_dataset_prompts_are_chat_messages(small_dataset): | |
| for row in small_dataset[:5]: | |
| msgs = row["prompt"] | |
| assert isinstance(msgs, list) and len(msgs) == 2 | |
| assert msgs[0]["role"] == "system" | |
| assert msgs[1]["role"] == "user" | |
| # --------------------------------------------------------------------------- | |
| # Parser | |
| # --------------------------------------------------------------------------- | |
def test_parser_belief_first_format():
    """A "S M W ACTION" completion parses each digit d as d / 9 plus the action."""
    parsed_action, parsed_belief, belief_given = extract_action_and_belief("3 8 7 DEEP_WORK")
    expected_belief = [digit / 9 for digit in (3, 8, 7)]
    assert parsed_action == ActionType.DEEP_WORK
    assert parsed_belief == pytest.approx(expected_belief, abs=1e-3)
    assert belief_given is True
def test_parser_action_only_returns_default_belief():
    """A bare action parses with the neutral 0.5 belief and provided=False."""
    parsed_action, parsed_belief, belief_given = extract_action_and_belief("DEEP_WORK")
    neutral_belief = [0.5] * 3
    assert parsed_action == ActionType.DEEP_WORK
    assert parsed_belief == neutral_belief
    assert belief_given is False
def test_parser_garbage_returns_none():
    """Free text with no action token must not parse to any action."""
    unparseable = "I don't know what to do here"
    parsed_action, _belief, _provided = extract_action_and_belief(unparseable)
    assert parsed_action is None
| # --------------------------------------------------------------------------- | |
| # Reward layers run end-to-end on synth completions | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture(scope="module")
def replay_columns(small_dataset):
    """Aligned replay columns for up to 30 samples with step_index > 0.

    Returns:
        A dict holding the selected rows under "samples" plus one per-sample
        list each for "seed", "history" (action_history), "mode"
        (profile_mode) and "step_index", all index-aligned with "samples".
        These are the keyword arguments env_reward / belief_accuracy take.
    """
    # Skip step_index=0 samples — belief_accuracy intentionally returns 0
    # at step 0 (no info yet), which would mask gradient checks.
    sub = [s for s in small_dataset if s["step_index"] > 0][:30]
    # Consumers average over len(sub); fail here with a clear message rather
    # than with a ZeroDivisionError deep inside a discrimination test.
    assert sub, "small_dataset has no samples with step_index > 0"
    return {
        "samples": sub,
        "seed": [s["seed"] for s in sub],
        "history": [s["action_history"] for s in sub],
        "mode": [s["profile_mode"] for s in sub],
        "step_index": [s["step_index"] for s in sub],
    }
@pytest.mark.parametrize(
    "kind", ["perfect", "good", "action_only", "wrong_belief", "garbage"]
)
def test_all_reward_layers_return_floats(kind, replay_columns):
    """Every reward layer returns exactly one float per completion, for every kind.

    Args:
        kind: Completion quality forwarded to ``_synth_completion``.
        replay_columns: Fixture with index-aligned per-sample replay columns.
    """
    sub = replay_columns["samples"]
    completions = [[{"content": _synth_completion(s["seed"], kind)}] for s in sub]
    # env_reward and belief_accuracy take the per-sample replay columns as
    # keyword arguments; format_valid and action_legal only see the text.
    replay_kwargs = {
        "seed": replay_columns["seed"],
        "action_history": replay_columns["history"],
        "profile_mode": replay_columns["mode"],
        "step_index": replay_columns["step_index"],
    }
    scores_by_layer = {
        "format_valid": format_valid(completions),
        "action_legal": action_legal(completions),
        "env_reward": env_reward(completions, **replay_kwargs),
        "belief_accuracy": belief_accuracy(completions, **replay_kwargs),
    }
    for layer_name, scores in scores_by_layer.items():
        assert len(scores) == len(completions), f"{layer_name} length mismatch"
        for score in scores:
            assert isinstance(score, float), f"{layer_name} returned non-float: {type(score)}"
| # --------------------------------------------------------------------------- | |
| # Reward layers DISCRIMINATE between completion qualities (anti-mode-collapse) | |
| # --------------------------------------------------------------------------- | |
def test_format_valid_discriminates_belief_vs_action_only(replay_columns):
    """Belief-plus-action output must out-score a bare action under format_valid."""
    seeds = [sample["seed"] for sample in replay_columns["samples"]]

    def mean_format_score(kind):
        batch = [[{"content": _synth_completion(seed, kind)}] for seed in seeds]
        return sum(format_valid(batch)) / len(batch)

    good_avg = mean_format_score("good")
    action_only_avg = mean_format_score("action_only")
    assert good_avg > action_only_avg, (
        f"format_valid did not push toward belief output: "
        f"good={good_avg:.3f} action_only={action_only_avg:.3f}"
    )
def test_belief_accuracy_discriminates_perfect_vs_wrong(replay_columns):
    """Core meta-RL signal: a truer belief must earn a higher belief_accuracy."""
    samples = replay_columns["samples"]
    replay_kwargs = dict(
        seed=replay_columns["seed"],
        action_history=replay_columns["history"],
        profile_mode=replay_columns["mode"],
        step_index=replay_columns["step_index"],
    )

    def mean_belief_score(kind):
        batch = [[{"content": _synth_completion(sample["seed"], kind)}] for sample in samples]
        return sum(belief_accuracy(batch, **replay_kwargs)) / len(batch)

    perfect_avg = mean_belief_score("perfect")
    wrong_avg = mean_belief_score("wrong_belief")
    assert perfect_avg > wrong_avg, (
        f"belief_accuracy gave no gradient: perfect={perfect_avg:.3f} wrong={wrong_avg:.3f}"
    )
def test_no_reward_layer_is_constant_across_kinds(replay_columns):
    """Guard against the iter-1 collapse trap.

    A reward layer that scores every completion in a GRPO group identically
    adds nothing to the advantage. Across the kinds good, action_only,
    garbage and wrong_belief, the per-kind mean score of each learning-signal
    layer checked here must spread by more than a small threshold.
    """
    samples = replay_columns["samples"]
    sample_count = len(samples)
    kinds = ("good", "action_only", "garbage", "wrong_belief")
    batches = {
        kind: [[{"content": _synth_completion(sample["seed"], kind)}] for sample in samples]
        for kind in kinds
    }
    replay_kwargs = dict(
        seed=replay_columns["seed"],
        action_history=replay_columns["history"],
        profile_mode=replay_columns["mode"],
        step_index=replay_columns["step_index"],
    )

    def spread(values):
        return max(values) - min(values)

    # format_valid and belief_accuracy MUST discriminate.
    f_means = [sum(format_valid(batches[kind])) / sample_count for kind in kinds]
    assert spread(f_means) > 0.1, f"format_valid is near-constant: {f_means}"

    b_means = [
        sum(belief_accuracy(batches[kind], **replay_kwargs)) / sample_count
        for kind in kinds
    ]
    assert spread(b_means) > 0.05, f"belief_accuracy is near-constant: {b_means}"