| """ | |
| RhythmEnv GRPO Training Script (Meta-RL version). | |
| Trains an LLM agent to BOTH (a) balance life meters AND (b) infer the hidden | |
| personality of the person it's helping. Four-layer reward stack: | |
| format_valid — output parseable as ACTION + 3 belief digits | |
| action_legal — action is one of 10 valid types | |
| env_reward — actual env reward for the chosen action (seed replay) | |
| belief_accuracy — how close the belief vector is to the hidden profile | |
| Usage (Colab T4): | |
| !pip install unsloth transformers trl datasets | |
| !python training/train.py --max_steps 1500 | |
| Setup-check (no GPU): run the smoke tests instead of starting a real run: | |
| python -m pytest tests/test_pipeline_smoke.py -q | |
| """ | |
import argparse
import json
import os
import sys

# Add the parent directory to sys.path so the dataset / reward_functions imports below resolve.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

def main():
    parser = argparse.ArgumentParser(description="Train RhythmEnv agent with GRPO (meta-RL)")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--max_steps", type=int, default=1500,
                        help="Number of GRPO training steps (1500 recommended for meta-RL)")
    parser.add_argument("--num_episodes", type=int, default=300,
                        help="Number of episodes for dataset generation (more diversity = better meta-RL)")
    parser.add_argument("--max_samples", type=int, default=3000,
                        help="Maximum number of training samples")
    parser.add_argument("--num_generations", type=int, default=8,
                        help="Completions per prompt for GRPO (default 8; gives lower variance for continuous-profile meta-RL)")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--beta", type=float, default=0.04,
                        help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
    parser.add_argument("--lora_rank", type=int, default=8,
                        help="LoRA rank (8 = more capacity than the original 4, for meta-RL)")
    parser.add_argument("--hint_fraction", type=float, default=0.0,
                        help="Fraction of dataset with the profile hint visible. Default 0.0 (no hints) "
                             "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
                             "show hints during eval.")
    parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained")
    parser.add_argument("--report_to", type=str, default="none")
    args = parser.parse_args()
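    # Example invocation (hypothetical values; every flag shown is defined above):
    #   python training/train.py --max_steps 500 --num_episodes 100 \
    #       --num_generations 4 --lora_rank 16 --output_dir outputs/quick_run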
    # ---------------------------------------------------------------
    # 1. Generate dataset
    # ---------------------------------------------------------------
    print("=" * 60)
    print("Step 1: Generating training dataset (continuous profiles)")
    print("=" * 60)

    from dataset import generate_dataset
    from datasets import Dataset

    raw_samples = generate_dataset(
        num_episodes=args.num_episodes,
        strategy="mixed",
        max_samples=args.max_samples,
        hint_fraction=args.hint_fraction,
    )

    # Replay metadata so env_reward + belief_accuracy can reconstruct state
    dataset = Dataset.from_list([
        {
            "prompt": sample["prompt"],
            "seed": sample["seed"],
            "step_index": sample["step_index"],
            "action_history": sample["action_history"],
            "profile_mode": sample["profile_mode"],
        }
        for sample in raw_samples
    ])
    print(f"Dataset size: {len(dataset)}")
    # ---------------------------------------------------------------
    # 2. Load model with Unsloth
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print(f"Step 2: Loading model {args.model_name}")
    print("=" * 60)

    from unsloth import FastLanguageModel

    max_seq_length = 1024  # bumped from 768 to fit longer prompts with history

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        load_in_4bit=True,
        max_seq_length=max_seq_length,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_rank * 2}")
    # ---------------------------------------------------------------
    # 3. Reward functions (4-layer stack including belief_accuracy)
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 3: Setting up reward functions")
    print("=" * 60)

    from reward_functions import format_valid, action_legal, env_reward, belief_accuracy

    reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
    print("Using: format_valid + action_legal + env_reward + belief_accuracy")
    # ---------------------------------------------------------------
    # 4. GRPO trainer config
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 4: Configuring GRPO trainer")
    print("=" * 60)

    from trl import GRPOConfig, GRPOTrainer

    max_prompt_length = 600       # room for history + hint
    max_completion_length = 32    # bumped from 20 to prevent silent truncation of belief digits

    # reward_weights: suppress the format/action_legal layers (small, low-variance
    # signals that are too constant across a GRPO group to contribute meaningful
    # advantage) and amplify the variable signals env_reward and belief_accuracy;
    # belief_accuracy at 3.0 is the dominant learning signal.
    # Order MUST match reward_funcs above: format_valid, action_legal, env_reward, belief_accuracy.
    reward_weights = [0.05, 0.05, 1.5, 3.0]
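    # Sketch of how the weights combine, assuming TRL's weighted-sum behavior: the
    # scalar reward for one completion is
    #   0.05 * format_valid + 0.05 * action_legal + 1.5 * env_reward + 3.0 * belief_accuracy
    # so a belief-vector error moves the total far more than a formatting slip.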
    training_args_kwargs = dict(
        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_grad_norm=0.5,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=args.max_steps,
        save_steps=250,
        report_to=args.report_to,
        output_dir=args.output_dir,
    )

    # reward_weights was added in TRL 0.13+; pass it only if this TRL version supports it
    try:
        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
        print(f"Using GRPOConfig with reward_weights={reward_weights}")
    except TypeError:
        training_args = GRPOConfig(**training_args_kwargs)
        print("WARN: TRL version does not support reward_weights; using uniform weighting")

    print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
          f"lr={args.learning_rate}, beta={args.beta}")
    print(f"max_prompt_length={max_prompt_length}, max_completion_length={max_completion_length}")
    print(f"hint_fraction={args.hint_fraction} (0 = no hints; >0 enables the profile-hint curriculum)")
    # ---------------------------------------------------------------
    # 5. Train
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 5: Starting GRPO training")
    print("=" * 60)

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
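    # Checkpoints land in output_dir every save_steps=250 steps; if a Colab session
    # dies mid-run, restarting with trainer.train(resume_from_checkpoint=True) (the
    # standard Hugging Face Trainer mechanism) should pick up from the last checkpoint.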
    # ---------------------------------------------------------------
    # 6. Save merged model
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 6: Saving model")
    print("=" * 60)

    model.save_pretrained_merged(
        args.output_dir,
        tokenizer,
        save_method="merged_16bit",
    )
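    # Note: merged_16bit writes the full fp16 weights (several GB). Unsloth also
    # offers lighter save methods (e.g. save_method="lora" for adapter-only output)
    # if Colab disk space is a concern; inference_eval.py below assumes the merged form.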
    config_path = os.path.join(args.output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(vars(args), f, indent=2)

    # Save log_history for offline plotting (job runs don't have a notebook to inspect trainer.state)
    log_path = os.path.join(args.output_dir, "log_history.json")
    with open(log_path, "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    print(f"Model saved to: {args.output_dir}")
    print(f"Training config saved to: {config_path}")
    print(f"Log history saved to: {log_path}")
    print("\nNext: run inference_eval.py to compare baseline vs trained")
    print("  python training/inference_eval.py --model_path " + args.output_dir)


if __name__ == "__main__":
    main()