# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""End-to-end smoke tests for the meta-RL training pipeline (no GPU).

Validates that dataset generation, the LLM-output parser, and all four reward
functions agree with each other before any GPU training spend.

The most important check is ``test_no_reward_layer_is_constant_across_kinds``
— that's what catches the iter-1 mode-collapse class of bug (a reward layer
returns the same value for every completion in a GRPO group, contributing
zero to advantage).
"""

import random

import pytest

from models import ActionType
from server.rhythm_environment import sample_profile, profile_to_belief_vector
from training.dataset import generate_dataset
from training.reward_functions import (
    action_legal,
    belief_accuracy,
    env_reward,
    extract_action_and_belief,
    format_valid,
)

# Every completion quality ``_synth_completion`` knows how to produce.
_COMPLETION_KINDS = ("perfect", "good", "action_only", "wrong_belief", "garbage")


@pytest.fixture(scope="module")
def small_dataset():
    """20 episodes, ~80 samples max — enough variety for reward checks."""
    return generate_dataset(
        num_episodes=20,
        strategy="mixed",
        max_samples=80,
        hint_fraction=0.1,
    )


def _true_belief_digits(seed: int, *, invert: bool = False) -> str:
    """Render the first three belief components for ``seed`` as "s m w".

    Each component ``v`` of ``profile_to_belief_vector(sample_profile(seed))``
    is scaled to the 0-9 digit range via ``round(v * 9)``. With
    ``invert=True`` the component is first replaced by ``1 - v``, which
    yields a deliberately wrong belief.

    NOTE(review): assumes ``sample_profile(seed)`` is the profile the env
    replays for a dataset row with that seed (rows also carry
    ``profile_mode``) — confirm.
    """
    true = profile_to_belief_vector(sample_profile(seed))
    digits = []
    for value in true[:3]:
        if invert:
            value = 1 - value
        digits.append(str(round(value * 9)))
    return " ".join(digits)


def _synth_completion(seed: int, kind: str) -> str:
    """Synthesize a completion of a given quality.

    Kinds:
        perfect: true belief digits followed by a random action.
        good: well-formed but random belief digits followed by a random action.
        action_only: a random action with no belief prefix.
        wrong_belief: inverted true belief digits followed by a random action.
        garbage: unparseable prose.

    The action (and random digits) come from ``random.Random(seed)``, so the
    output is deterministic per ``(seed, kind)``.

    Raises:
        ValueError: if ``kind`` is not one of ``_COMPLETION_KINDS`` (a typo
            would otherwise silently test the wrong completion shape).
    """
    rng = random.Random(seed)
    # Draw order is fixed so every kind sees the same action for a given seed.
    action_str = rng.choice(list(ActionType)).value.upper()
    s, m, w = rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)
    if kind == "perfect":
        return f"{_true_belief_digits(seed)} {action_str}"
    if kind == "good":
        return f"{s} {m} {w} {action_str}"
    if kind == "action_only":
        return action_str
    if kind == "wrong_belief":
        return f"{_true_belief_digits(seed, invert=True)} {action_str}"
    if kind == "garbage":
        return "I don't know what to do here"
    raise ValueError(
        f"unknown completion kind {kind!r}; expected one of {_COMPLETION_KINDS}"
    )
# ---------------------------------------------------------------------------
# Dataset shape
# ---------------------------------------------------------------------------


def test_dataset_is_non_empty(small_dataset):
    assert len(small_dataset) > 0


def test_dataset_rows_have_required_replay_columns(small_dataset):
    expected = {"prompt", "seed", "step_index", "action_history", "profile_mode"}
    for row in small_dataset:
        missing = expected - row.keys()
        assert not missing, f"row missing columns: {missing}"


def test_dataset_prompts_are_chat_messages(small_dataset):
    for row in small_dataset[:5]:
        msgs = row["prompt"]
        assert isinstance(msgs, list) and len(msgs) == 2
        assert msgs[0]["role"] == "system"
        assert msgs[1]["role"] == "user"


# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------


def test_parser_belief_first_format():
    action, belief, provided = extract_action_and_belief("3 8 7 DEEP_WORK")
    assert action == ActionType.DEEP_WORK
    assert belief == pytest.approx([3 / 9, 8 / 9, 7 / 9], abs=1e-3)
    assert provided is True


def test_parser_action_only_returns_default_belief():
    action, belief, provided = extract_action_and_belief("DEEP_WORK")
    assert action == ActionType.DEEP_WORK
    assert belief == [0.5, 0.5, 0.5]
    assert provided is False


def test_parser_garbage_returns_none():
    action, _, _ = extract_action_and_belief("I don't know what to do here")
    assert action is None


# ---------------------------------------------------------------------------
# Helpers shared by the reward-layer tests
# ---------------------------------------------------------------------------


def _mean(scores) -> float:
    """Return the arithmetic mean of a non-empty sequence of reward scores."""
    return sum(scores) / len(scores)


def _replay_kwargs(replay_columns) -> dict:
    """Build the replay keyword arguments passed to env_reward / belief_accuracy."""
    return {
        "seed": replay_columns["seed"],
        "action_history": replay_columns["history"],
        "profile_mode": replay_columns["mode"],
        "step_index": replay_columns["step_index"],
    }


def _completions_of_kind(samples, kind: str) -> list:
    """Build one single-message completion per sample, synthesized at ``kind``."""
    return [[{"content": _synth_completion(s["seed"], kind)}] for s in samples]


# ---------------------------------------------------------------------------
# Reward layers run end-to-end on synth completions
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def replay_columns(small_dataset):
    """Up to 30 samples with step_index > 0, plus their replay columns.

    Skip step_index=0 samples — belief_accuracy intentionally returns 0 at
    step 0 (no info yet), which would mask gradient checks.
    """
    sub = [s for s in small_dataset if s["step_index"] > 0][:30]
    # Every averaging test below divides by len(sub); fail here with a
    # clear message instead of an opaque ZeroDivisionError later.
    assert sub, "small_dataset has no samples with step_index > 0"
    return {
        "samples": sub,
        "seed": [s["seed"] for s in sub],
        "history": [s["action_history"] for s in sub],
        "mode": [s["profile_mode"] for s in sub],
        "step_index": [s["step_index"] for s in sub],
    }


@pytest.mark.parametrize("kind", ["perfect", "good", "action_only", "wrong_belief", "garbage"])
def test_all_reward_layers_return_floats(kind, replay_columns):
    completions = _completions_of_kind(replay_columns["samples"], kind)
    replay = _replay_kwargs(replay_columns)
    scores_by_layer = {
        "format_valid": format_valid(completions),
        "action_legal": action_legal(completions),
        "env_reward": env_reward(completions, **replay),
        "belief_accuracy": belief_accuracy(completions, **replay),
    }
    for layer_name, scores in scores_by_layer.items():
        assert len(scores) == len(completions), f"{layer_name} length mismatch"
        for s in scores:
            assert isinstance(s, float), f"{layer_name} returned non-float: {type(s)}"


# ---------------------------------------------------------------------------
# Reward layers DISCRIMINATE between completion qualities (anti-mode-collapse)
# ---------------------------------------------------------------------------


def test_format_valid_discriminates_belief_vs_action_only(replay_columns):
    """format_valid must reward belief+action higher than action-only."""
    sub = replay_columns["samples"]
    good_avg = _mean(format_valid(_completions_of_kind(sub, "good")))
    action_only_avg = _mean(format_valid(_completions_of_kind(sub, "action_only")))
    assert good_avg > action_only_avg, (
        f"format_valid did not push toward belief output: "
        f"good={good_avg:.3f} action_only={action_only_avg:.3f}"
    )


def test_belief_accuracy_discriminates_perfect_vs_wrong(replay_columns):
    """The whole point of the meta-RL signal: better belief → higher reward."""
    sub = replay_columns["samples"]
    replay = _replay_kwargs(replay_columns)
    perfect_avg = _mean(belief_accuracy(_completions_of_kind(sub, "perfect"), **replay))
    wrong_avg = _mean(belief_accuracy(_completions_of_kind(sub, "wrong_belief"), **replay))
    assert perfect_avg > wrong_avg, (
        f"belief_accuracy gave no gradient: perfect={perfect_avg:.3f} wrong={wrong_avg:.3f}"
    )


def test_no_reward_layer_is_constant_across_kinds(replay_columns):
    """Guard against the iter-1 collapse trap.

    A reward layer returning the same value for every completion in a GRPO
    group contributes zero to advantage. Across the kinds {good, action_only,
    garbage, wrong_belief}, the per-kind mean score must spread by more than
    a threshold (0.1 for format_valid, 0.05 for belief_accuracy). Only these
    two layers are checked here; env_reward and action_legal are not.
    """
    sub = replay_columns["samples"]
    replay = _replay_kwargs(replay_columns)
    kinds = ["good", "action_only", "garbage", "wrong_belief"]
    completions_by_kind = {kind: _completions_of_kind(sub, kind) for kind in kinds}

    # format_valid and belief_accuracy MUST discriminate.
    f_means = [_mean(format_valid(completions_by_kind[k])) for k in kinds]
    assert max(f_means) - min(f_means) > 0.1, f"format_valid is near-constant: {f_means}"

    b_means = [_mean(belief_accuracy(completions_by_kind[k], **replay)) for k in kinds]
    assert max(b_means) - min(b_means) > 0.05, f"belief_accuracy is near-constant: {b_means}"