""" RhythmEnv GRPO Training Script (Meta-RL version). Trains an LLM agent to BOTH (a) balance life meters AND (b) infer the hidden personality of the person it's helping. Four-layer reward stack: format_valid — output parseable as ACTION + 3 belief digits action_legal — action is one of 10 valid types env_reward — actual env reward for the chosen action (seed replay) belief_accuracy — how close the belief vector is to the hidden profile Usage (Colab T4): !pip install unsloth transformers trl datasets !python training/train.py --max_steps 1500 Setup-check (no GPU): run the smoke tests instead of starting a real run: python -m pytest tests/test_pipeline_smoke.py -q """ import argparse import json import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) def main(): parser = argparse.ArgumentParser(description="Train RhythmEnv agent with GRPO (meta-RL)") parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct") parser.add_argument("--max_steps", type=int, default=1500, help="Number of GRPO training steps (1500 recommended for meta-RL)") parser.add_argument("--num_episodes", type=int, default=300, help="Number of episodes for dataset generation (more diversity = better meta-RL)") parser.add_argument("--max_samples", type=int, default=3000, help="Maximum training samples") parser.add_argument("--num_generations", type=int, default=8, help="Completions per prompt for GRPO (8 default, lower variance for continuous-profile meta-RL)") parser.add_argument("--learning_rate", type=float, default=5e-5) parser.add_argument("--beta", type=float, default=0.04, help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)") parser.add_argument("--lora_rank", type=int, default=8, help="LoRA rank (8 = more capacity than original 4 for meta-RL)") parser.add_argument("--hint_fraction", type=float, default=0.0, help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) " "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO " "show hints during eval.") parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained") parser.add_argument("--report_to", type=str, default="none") args = parser.parse_args() # --------------------------------------------------------------- # 1. Generate dataset # --------------------------------------------------------------- print("=" * 60) print("Step 1: Generating training dataset (continuous profiles)") print("=" * 60) from dataset import generate_dataset from datasets import Dataset raw_samples = generate_dataset( num_episodes=args.num_episodes, strategy="mixed", max_samples=args.max_samples, hint_fraction=args.hint_fraction, ) # Replay metadata so env_reward + belief_accuracy can reconstruct state dataset = Dataset.from_list([ { "prompt": sample["prompt"], "seed": sample["seed"], "step_index": sample["step_index"], "action_history": sample["action_history"], "profile_mode": sample["profile_mode"], } for sample in raw_samples ]) print(f"Dataset size: {len(dataset)}") # --------------------------------------------------------------- # 2. 
    # ---------------------------------------------------------------
    # 2. Load model with Unsloth
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print(f"Step 2: Loading model {args.model_name}")
    print("=" * 60)

    from unsloth import FastLanguageModel

    max_seq_length = 1024  # bumped from 768 to fit longer prompts with history

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        load_in_4bit=True,
        max_seq_length=max_seq_length,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_rank * 2}")

    # ---------------------------------------------------------------
    # 3. Reward functions (4-layer stack including belief_accuracy)
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 3: Setting up reward functions")
    print("=" * 60)

    from reward_functions import format_valid, action_legal, env_reward, belief_accuracy

    reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
    print("Using: format_valid + action_legal + env_reward + belief_accuracy")

    # ---------------------------------------------------------------
    # 4. GRPO trainer config
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 4: Configuring GRPO trainer")
    print("=" * 60)

    from trl import GRPOConfig, GRPOTrainer

    max_prompt_length = 600     # history + hint room
    max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits

    # reward_weights: suppress the format/action_legal layers (small, low-variance
    # signals — too constant across a GRPO group to contribute meaningful advantage)
    # and amplify the variable signals env_reward and belief_accuracy. belief_accuracy
    # at 3.0 is the dominant learning signal.
    # Order MUST match reward_funcs above: format_valid, action_legal, env_reward, belief_accuracy
    reward_weights = [0.05, 0.05, 1.5, 3.0]

    training_args_kwargs = dict(
        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_grad_norm=0.5,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=args.max_steps,
        save_steps=250,
        report_to=args.report_to,
        output_dir=args.output_dir,
    )

    # reward_weights is not available in older TRL releases; pass it only if supported
    try:
        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
        print(f"Using GRPOConfig with reward_weights={reward_weights}")
    except TypeError:
        training_args = GRPOConfig(**training_args_kwargs)
        print("WARN: TRL version does not support reward_weights; using uniform weighting")

    print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
          f"lr={args.learning_rate}, beta={args.beta}")
    print(f"max_prompt_length={max_prompt_length}, max_completion_length={max_completion_length}")
    print(f"hint_fraction={args.hint_fraction} (fraction of prompts with a profile hint)")
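    # Optional guard (an assumption, not part of the original script): the
    # reward_weights list is positional, so a length mismatch with reward_funcs
    # would silently mis-weight the layers. Failing fast here is cheaper than
    # discovering the problem after a long GPU run.
    assert len(reward_weights) == len(reward_funcs), (
        f"reward_weights has {len(reward_weights)} entries but "
        f"{len(reward_funcs)} reward functions are registered"
    )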
    # ---------------------------------------------------------------
    # 5. Train
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 5: Starting GRPO training")
    print("=" * 60)

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # ---------------------------------------------------------------
    # 6. Save merged model
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 6: Saving model")
    print("=" * 60)

    model.save_pretrained_merged(
        args.output_dir,
        tokenizer,
        save_method="merged_16bit",
    )

    config_path = os.path.join(args.output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(vars(args), f, indent=2)

    # Save log_history for offline plotting (job runs don't have a notebook to inspect trainer.state)
    log_path = os.path.join(args.output_dir, "log_history.json")
    with open(log_path, "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    print(f"Model saved to: {args.output_dir}")
    print(f"Training config saved to: {config_path}")
    print(f"Log history saved to: {log_path}")
    print("\nNext: run inference_eval.py to compare baseline vs trained")
    print("  python training/inference_eval.py --model_path " + args.output_dir)


if __name__ == "__main__":
    main()