"""
GRPO training script for arithmetic countdown problems.

This script trains a language model using GRPO (Group Relative Policy Optimization)
to solve arithmetic problems with proper reasoning and formatting.
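
Example invocation (script name and data paths are illustrative):

    python train_grpo.py \
        --dataset_csv data/countdown.csv \
        --output_dir outputs/grpo-countdown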
| """

import argparse
import logging
import os
from collections.abc import Callable
from pathlib import Path

from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer

from src.utils.dataset import load_csv_dataset
from src.utils.rewards import (
    mathematical_correctness_reward_function,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")


def load_train_dataset(
    dataset_csv: str, max_rows: int = 2000, seed: int = 42
) -> Dataset:
    """
    Load, shuffle, and subsample the training dataset.

    Args:
        dataset_csv: Absolute path to the dataset CSV file
        max_rows: Maximum number of rows to select for training
        seed: Seed for dataset shuffling

    Returns:
        Dataset: A datasets.Dataset ready for GRPO training
    """
    raw_dataset: Dataset = load_csv_dataset(dataset_csv)
    raw_dataset = raw_dataset.shuffle(seed=seed)
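    # Cap the training set at max_rows without indexing past the end of
    # smaller datasets.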
    train_dataset = raw_dataset.select(range(min(max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset


def create_lora_model(model_id: str, device_map: str = "cuda") -> PreTrainedModel:
    """
    Create a base causal LM and wrap it with LoRA adapters.

    Args:
        model_id: Hugging Face model identifier to load as the base model
        device_map: Device mapping strategy for model loading

    Returns:
        PreTrainedModel: A transformers.PreTrainedModel with LoRA adapters applied
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
    )
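
    # Adapt only the attention projections (q/k/v/o); rank 16 with alpha 32
    # scales adapter updates by alpha / r = 2, a common LoRA starting point.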
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_cfg)
    logger.info("Model with LoRA ready")
    return model


def create_grpo_config(
    output_dir: str,
    learning_rate: float = 5e-6,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 16,
    max_completion_length: int = 512,
    num_generations: int = 16,
    temperature: float = 1.0,
    save_steps: int = 50,
    logging_steps: int = 1,
    max_prompt_length: int = 4096,
) -> GRPOConfig:
    """
    Create GRPO training configuration.

    Args:
        output_dir: Directory where checkpoints and logs will be written
        learning_rate: Learning rate for training
        num_train_epochs: Number of training epochs
        per_device_train_batch_size: Batch size per device
        gradient_accumulation_steps: Steps to accumulate gradients
        max_completion_length: Maximum length for completions
        num_generations: Number of generations per prompt
        temperature: Sampling temperature
        save_steps: Steps between model saves
        logging_steps: Steps between log outputs
        max_prompt_length: Maximum length for input prompts

    Returns:
        GRPOConfig: A configured trl.GRPOConfig instance
    """
    return GRPOConfig(
        output_dir=output_dir,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
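        # The 8-bit AdamW optimizer requires the bitsandbytes package.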
        optim="adamw_8bit",
        remove_unused_columns=False,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        bf16=True,
        per_device_train_batch_size=per_device_train_batch_size,
        temperature=temperature,
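
        # Each prompt is sampled num_generations times; GRPO normalizes rewards
        # within each group of completions to compute relative advantages.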
        max_completion_length=max_completion_length,
        num_generations=num_generations,
        max_prompt_length=max_prompt_length,
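
        # Logging and checkpointing.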
        report_to=["tensorboard"],
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=save_steps,
    )


def create_trainer(
    model: PreTrainedModel,
    train_dataset: Dataset,
    args: GRPOConfig,
) -> GRPOTrainer:
    """
    Construct a GRPOTrainer with arithmetic-specific reward functions.

    Args:
        model: The LoRA-wrapped pretrained model to train
        train_dataset: The dataset to use for training
        args: The GRPO configuration

    Returns:
        GRPOTrainer: An initialized trl.GRPOTrainer instance
    """
    reward_funcs: list[Callable[..., list[float]]] = [
        mathematical_correctness_reward_function,
    ]
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=args,
        train_dataset=train_dataset,
    )
    return trainer


def train_and_save(trainer: GRPOTrainer, output_dir: str) -> None:
    """
    Run training and save the final model to disk.

    Args:
        trainer: The configured GRPO trainer instance
        output_dir: Output directory to save the trained model

    Returns:
        None
    """
    train_result = trainer.train()
    logger.info("Training complete: %s", str(train_result))
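    # For a PEFT-wrapped model this saves the LoRA adapter weights rather than
    # a fully merged checkpoint.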
    trainer.save_model(output_dir)
    logger.info("Saved to %s", output_dir)


def main() -> None:
    """
    Run the full GRPO training workflow with command-line arguments.

    Returns:
        None
    """
    parser = argparse.ArgumentParser(
        description="Train a language model using GRPO for arithmetic countdown problems"
    )

    parser.add_argument(
        "--dataset_csv",
        type=str,
        required=True,
        help="Path to the training dataset CSV file",
    )
    parser.add_argument(
        "--max_rows", type=int, default=2000, help="Maximum number of training samples"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for dataset shuffling"
    )

    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Hugging Face model identifier",
    )
    parser.add_argument(
        "--device_map", type=str, default="cuda", help="Device mapping strategy"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=5e-6, help="Learning rate"
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=1, help="Number of training epochs"
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size per device",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--max_completion_length",
        type=int,
        default=512,
        help="Maximum completion length",
    )
    parser.add_argument(
        "--num_generations",
        type=int,
        default=16,
        help="Number of generations per prompt",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--save_steps", type=int, default=50, help="Steps between model saves"
    )
    parser.add_argument(
        "--logging_steps", type=int, default=1, help="Steps between log outputs"
    )
    parser.add_argument(
        "--max_prompt_length",
        type=int,
        default=4096,
        help="Maximum length for input prompts",
    )

    args = parser.parse_args()
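
    # Validate inputs before any expensive model or dataset loading.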
    if not Path(args.dataset_csv).exists():
        logger.error("Dataset CSV file does not exist: %s", args.dataset_csv)
        return

    if args.max_rows <= 0:
        logger.error("max_rows must be positive")
        return

    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("Output dir: %s", args.output_dir)

    train_dataset = load_train_dataset(args.dataset_csv, args.max_rows, args.seed)

    model = create_lora_model(args.model_id, args.device_map)

    training_args = create_grpo_config(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=args.temperature,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        max_prompt_length=args.max_prompt_length,
    )

    trainer = create_trainer(
        model=model, train_dataset=train_dataset, args=training_args
    )

    train_and_save(trainer=trainer, output_dir=args.output_dir)


if __name__ == "__main__":
    main()