rhythm_env / tests /test_pipeline_smoke.py
InosLihka's picture
Algorithm Distillation: grader v2 with belief_accuracy + SFT pipeline
ece0bbe
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""End-to-end smoke tests for the meta-RL training pipeline (no GPU).
Validates that dataset generation, the LLM-output parser, and all four
reward functions agree with each other before any GPU training spend.
The most important check is test_no_reward_layer_is_constant_across_kinds —
that's what catches the iter-1 mode-collapse class of bug (a reward
layer returns the same value for every completion in a GRPO group,
contributing zero to advantage).
"""
import random
import pytest
from models import ActionType
from server.rhythm_environment import sample_profile, profile_to_belief_vector
from training.dataset import generate_dataset
from training.reward_functions import (
action_legal,
belief_accuracy,
env_reward,
extract_action_and_belief,
format_valid,
)
@pytest.fixture(scope="module")
def small_dataset():
    """Build one shared dataset for every test in this module.

    Uses 20 episodes with the "mixed" strategy, capped at 80 samples and
    with hint_fraction=0.1 — enough variety for the reward checks.
    """
    generation_config = dict(
        num_episodes=20,
        strategy="mixed",
        max_samples=80,
        hint_fraction=0.1,
    )
    return generate_dataset(**generation_config)
def _synth_completion(seed: int, kind: str) -> str:
"""Synthesize a completion of a given quality."""
rng = random.Random(seed)
action_str = rng.choice(list(ActionType)).value.upper()
s, m, w = rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)
if kind == "perfect":
true = profile_to_belief_vector(sample_profile(seed))
s, m, w = round(true[0] * 9), round(true[1] * 9), round(true[2] * 9)
return f"{s} {m} {w} {action_str}"
if kind == "good":
return f"{s} {m} {w} {action_str}"
if kind == "action_only":
return action_str
if kind == "wrong_belief":
true = profile_to_belief_vector(sample_profile(seed))
s = round((1 - true[0]) * 9)
m = round((1 - true[1]) * 9)
w = round((1 - true[2]) * 9)
return f"{s} {m} {w} {action_str}"
if kind == "garbage":
return "I don't know what to do here"
return action_str
# ---------------------------------------------------------------------------
# Dataset shape
# ---------------------------------------------------------------------------
def test_dataset_is_non_empty(small_dataset):
assert len(small_dataset) > 0
def test_dataset_rows_have_required_replay_columns(small_dataset):
expected = {"prompt", "seed", "step_index", "action_history", "profile_mode"}
for row in small_dataset:
missing = expected - row.keys()
assert not missing, f"row missing columns: {missing}"
def test_dataset_prompts_are_chat_messages(small_dataset):
for row in small_dataset[:5]:
msgs = row["prompt"]
assert isinstance(msgs, list) and len(msgs) == 2
assert msgs[0]["role"] == "system"
assert msgs[1]["role"] == "user"
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
def test_parser_belief_first_format():
    """Belief digits before the action parse as digit / 9 with provided=True."""
    parsed_action, parsed_belief, belief_given = extract_action_and_belief(
        "3 8 7 DEEP_WORK"
    )
    expected_belief = [digit / 9 for digit in (3, 8, 7)]
    assert parsed_action == ActionType.DEEP_WORK
    assert parsed_belief == pytest.approx(expected_belief, abs=1e-3)
    assert belief_given is True
def test_parser_action_only_returns_default_belief():
    """A bare action parses with a neutral [0.5, 0.5, 0.5] belief, provided=False."""
    parsed_action, parsed_belief, belief_given = extract_action_and_belief("DEEP_WORK")
    neutral_belief = [0.5] * 3
    assert parsed_action == ActionType.DEEP_WORK
    assert parsed_belief == neutral_belief
    assert belief_given is False
def test_parser_garbage_returns_none():
    """Free text with no action token parses to a None action."""
    parsed_action, _belief, _provided = extract_action_and_belief(
        "I don't know what to do here"
    )
    assert parsed_action is None
# ---------------------------------------------------------------------------
# Reward layers run end-to-end on synth completions
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def replay_columns(small_dataset):
    """Columnar replay kwargs for the reward functions, built from small_dataset.

    Returns a dict with:
      - "samples": up to 30 rows whose ``step_index`` is > 0;
      - "seed", "history" (from ``action_history``), "mode" (from
        ``profile_mode``) and "step_index": per-sample column lists aligned
        index-by-index with "samples".
    """
    # Skip step_index=0 samples — belief_accuracy intentionally returns 0
    # at step 0 (no info yet), which would mask gradient checks.
    sub = [s for s in small_dataset if s["step_index"] > 0][:30]
    # Fail loudly here instead of letting downstream tests pass vacuously
    # (their per-score loops run over nothing) or crash with
    # ZeroDivisionError when averaging over an empty batch.
    assert sub, "small_dataset produced no samples with step_index > 0"
    return {
        "samples": sub,
        "seed": [s["seed"] for s in sub],
        "history": [s["action_history"] for s in sub],
        "mode": [s["profile_mode"] for s in sub],
        "step_index": [s["step_index"] for s in sub],
    }
@pytest.mark.parametrize("kind", ["perfect", "good", "action_only", "wrong_belief", "garbage"])
def test_all_reward_layers_return_floats(kind, replay_columns):
    """Every reward layer returns one float per completion, for every kind."""
    samples = replay_columns["samples"]
    completions = [
        [{"content": _synth_completion(sample["seed"], kind)}] for sample in samples
    ]
    replay_kwargs = {
        "seed": replay_columns["seed"],
        "action_history": replay_columns["history"],
        "profile_mode": replay_columns["mode"],
        "step_index": replay_columns["step_index"],
    }
    # Dict literals evaluate left to right, so the layers run in this order.
    scores_by_layer = {
        "format_valid": format_valid(completions),
        "action_legal": action_legal(completions),
        "env_reward": env_reward(completions, **replay_kwargs),
        "belief_accuracy": belief_accuracy(completions, **replay_kwargs),
    }
    for layer_name, scores in scores_by_layer.items():
        assert len(scores) == len(completions), f"{layer_name} length mismatch"
        for score in scores:
            assert isinstance(score, float), f"{layer_name} returned non-float: {type(score)}"
# ---------------------------------------------------------------------------
# Reward layers DISCRIMINATE between completion qualities (anti-mode-collapse)
# ---------------------------------------------------------------------------
def test_format_valid_discriminates_belief_vs_action_only(replay_columns):
    """format_valid must score belief+action above a bare action on average."""
    samples = replay_columns["samples"]

    def batch_of(kind):
        return [[{"content": _synth_completion(sample["seed"], kind)}] for sample in samples]

    belief_batch = batch_of("good")
    bare_batch = batch_of("action_only")
    good_avg = sum(format_valid(belief_batch)) / len(belief_batch)
    action_only_avg = sum(format_valid(bare_batch)) / len(bare_batch)
    assert good_avg > action_only_avg, (
        f"format_valid did not push toward belief output: "
        f"good={good_avg:.3f} action_only={action_only_avg:.3f}"
    )
def test_belief_accuracy_discriminates_perfect_vs_wrong(replay_columns):
    """A belief matching the true profile must outscore a mirrored one on average."""
    samples = replay_columns["samples"]
    replay_kwargs = {
        "seed": replay_columns["seed"],
        "action_history": replay_columns["history"],
        "profile_mode": replay_columns["mode"],
        "step_index": replay_columns["step_index"],
    }

    def batch_of(kind):
        return [[{"content": _synth_completion(sample["seed"], kind)}] for sample in samples]

    def mean_belief_score(batch):
        return sum(belief_accuracy(batch, **replay_kwargs)) / len(batch)

    perfect_batch = batch_of("perfect")
    wrong_batch = batch_of("wrong_belief")
    perfect_avg = mean_belief_score(perfect_batch)
    wrong_avg = mean_belief_score(wrong_batch)
    assert perfect_avg > wrong_avg, (
        f"belief_accuracy gave no gradient: perfect={perfect_avg:.3f} wrong={wrong_avg:.3f}"
    )
def test_no_reward_layer_is_constant_across_kinds(replay_columns):
    """Guard against the iter-1 collapse trap.

    A reward layer that scores every completion in a GRPO group identically
    contributes zero to the advantage. Across the kinds good, action_only,
    garbage and wrong_belief, the spread between the highest and lowest mean
    score must exceed 0.1 for format_valid and 0.05 for belief_accuracy.
    """
    samples = replay_columns["samples"]
    sample_count = len(samples)
    kinds = ["good", "action_only", "garbage", "wrong_belief"]
    batches = {}
    for kind in kinds:
        batches[kind] = [
            [{"content": _synth_completion(sample["seed"], kind)}] for sample in samples
        ]

    # format_valid and belief_accuracy MUST discriminate.
    f_means = []
    for kind in kinds:
        f_means.append(sum(format_valid(batches[kind])) / sample_count)
    assert max(f_means) - min(f_means) > 0.1, f"format_valid is near-constant: {f_means}"

    replay_kwargs = {
        "seed": replay_columns["seed"],
        "action_history": replay_columns["history"],
        "profile_mode": replay_columns["mode"],
        "step_index": replay_columns["step_index"],
    }
    b_means = []
    for kind in kinds:
        b_means.append(sum(belief_accuracy(batches[kind], **replay_kwargs)) / sample_count)
    assert max(b_means) - min(b_means) > 0.05, f"belief_accuracy is near-constant: {b_means}"