"""
RhythmEnv GRPO Training Script (Meta-RL version).
Trains an LLM agent to BOTH (a) balance life meters AND (b) infer the hidden
personality of the person it's helping. Four-layer reward stack:
format_valid — output parseable as ACTION + 3 belief digits
action_legal — action is one of 10 valid types
env_reward — actual env reward for the chosen action (seed replay)
belief_accuracy — how close the belief vector is to the hidden profile
Usage (Colab T4):
!pip install unsloth transformers trl datasets
!python training/train.py --max_steps 1500
Setup-check (no GPU): run the smoke tests instead of starting a real run:
python -m pytest tests/test_pipeline_smoke.py -q
"""
import argparse
import json
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))


def main():
    parser = argparse.ArgumentParser(description="Train RhythmEnv agent with GRPO (meta-RL)")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--max_steps", type=int, default=1500,
                        help="Number of GRPO training steps (1500 recommended for meta-RL)")
    parser.add_argument("--num_episodes", type=int, default=300,
                        help="Number of episodes for dataset generation (more diversity = better meta-RL)")
    parser.add_argument("--max_samples", type=int, default=3000,
                        help="Maximum training samples")
    parser.add_argument("--num_generations", type=int, default=8,
                        help="Completions per prompt for GRPO (8 default, lower variance for continuous-profile meta-RL)")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--beta", type=float, default=0.04,
                        help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
    parser.add_argument("--lora_rank", type=int, default=8,
                        help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
    parser.add_argument("--hint_fraction", type=float, default=0.0,
                        help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
                             "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
                             "show hints during eval.")
    parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained")
    parser.add_argument("--report_to", type=str, default="none")
    args = parser.parse_args()

    # ---------------------------------------------------------------
    # 1. Generate dataset
    # ---------------------------------------------------------------
    print("=" * 60)
    print("Step 1: Generating training dataset (continuous profiles)")
    print("=" * 60)
    from dataset import generate_dataset
    from datasets import Dataset

    raw_samples = generate_dataset(
        num_episodes=args.num_episodes,
        strategy="mixed",
        max_samples=args.max_samples,
        hint_fraction=args.hint_fraction,
    )
    # Replay metadata so env_reward + belief_accuracy can reconstruct state
    dataset = Dataset.from_list([
        {
            "prompt": sample["prompt"],
            "seed": sample["seed"],
            "step_index": sample["step_index"],
            "action_history": sample["action_history"],
            "profile_mode": sample["profile_mode"],
        }
        for sample in raw_samples
    ])
    print(f"Dataset size: {len(dataset)}")

    # ---------------------------------------------------------------
    # 2. Load model with Unsloth
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print(f"Step 2: Loading model {args.model_name}")
    print("=" * 60)
    from unsloth import FastLanguageModel

    max_seq_length = 1024  # bumped from 768 to fit longer prompts with history
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        load_in_4bit=True,
        max_seq_length=max_seq_length,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_rank * 2}")

    # ---------------------------------------------------------------
    # 3. Reward functions (4-layer stack including belief_accuracy)
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 3: Setting up reward functions")
    print("=" * 60)
    from reward_functions import format_valid, action_legal, env_reward, belief_accuracy

    reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
    print("Using: format_valid + action_legal + env_reward + belief_accuracy")

    # ---------------------------------------------------------------
    # 4. GRPO trainer config
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 4: Configuring GRPO trainer")
    print("=" * 60)
    from trl import GRPOConfig, GRPOTrainer

    max_prompt_length = 600  # history + hint room
    max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits
    # reward_weights: suppress the format/action_legal layers (small, low-variance
    # signals — too constant across a GRPO group to contribute meaningful advantage)
    # and amplify the variable signals env_reward and belief_accuracy. belief_accuracy
    # at 3.0 is the dominant learning signal.
    # Order MUST match reward_funcs above: format_valid, action_legal, env_reward, belief_accuracy
    reward_weights = [0.05, 0.05, 1.5, 3.0]
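    # With these weights the scalar reward TRL optimises is, in effect:
    #   0.05 * format_valid + 0.05 * action_legal + 1.5 * env_reward + 3.0 * belief_accuracy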
    training_args_kwargs = dict(
        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_grad_norm=0.5,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=args.max_steps,
        save_steps=250,
        report_to=args.report_to,
        output_dir=args.output_dir,
    )
    # reward_weights was added in TRL 0.13+; pass only if supported
    try:
        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
        print(f"Using GRPOConfig with reward_weights={reward_weights}")
    except TypeError:
        training_args = GRPOConfig(**training_args_kwargs)
        print("WARN: TRL version does not support reward_weights; using uniform weighting")
print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
f"lr={args.learning_rate}, beta={args.beta}")
print(f"max_prompt_length={max_prompt_length}, max_completion_length={max_completion_length}")
print(f"hint_fraction={args.hint_fraction} (curriculum warmup)")

    # ---------------------------------------------------------------
    # 5. Train
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 5: Starting GRPO training")
    print("=" * 60)
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
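    # If a Colab session dies mid-run, GRPOTrainer inherits the standard Trainer
    # resume path, so restarting with trainer.train(resume_from_checkpoint=True)
    # should pick up the latest checkpoint written by save_steps=250 in output_dir
    # (assuming at least one checkpoint exists).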

    # ---------------------------------------------------------------
    # 6. Save merged model
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 6: Saving model")
    print("=" * 60)
    model.save_pretrained_merged(
        args.output_dir,
        tokenizer,
        save_method="merged_16bit",
    )
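    # Merging to 16-bit needs enough CPU RAM to hold the full model; if that is
    # tight on the runtime, saving only the LoRA adapters (e.g. via PEFT's
    # model.save_pretrained) is a smaller fallback, at the cost of needing the
    # base model again at load time.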
    config_path = os.path.join(args.output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(vars(args), f, indent=2)
    # Save log_history for offline plotting (job runs don't have a notebook to inspect trainer.state)
    log_path = os.path.join(args.output_dir, "log_history.json")
    with open(log_path, "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    print(f"Model saved to: {args.output_dir}")
    print(f"Training config saved to: {config_path}")
    print(f"Log history saved to: {log_path}")
    print("\nNext: run inference_eval.py to compare baseline vs trained")
    print("  python training/inference_eval.py --model_path " + args.output_dir)


if __name__ == "__main__":
    main()