|
|
| """
|
| GRPO training script for arithmetic countdown problems using Hydra configuration.
|
| After training, the model is pushed to the HuggingFace Hub when `push_to_hub` is enabled in the configuration.
|
| """
|
|
|
| import logging
|
| import os
|
| from collections.abc import Callable
|
| from pathlib import Path
|
| import sys
|
|
|
|
|
# Resolve the directory three levels above this file and put it at the front
# of sys.path. NOTE(review): presumably this is what lets the `src.*` imports
# further down resolve when the script is run directly — confirm the layout.
FILE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)
|
|
|
| import hydra
|
| from datasets import Dataset
|
| from omegaconf import DictConfig, OmegaConf
|
| from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer
|
| from huggingface_hub import HfApi, login, create_repo
|
| from peft import LoraConfig, PeftModel, get_peft_model
|
| from trl import GRPOConfig, GRPOTrainer
|
|
|
| from src.dataset import load_csv_dataset_grpo
|
| from src.dataset.grpo import map_problem_description_to_conversation_grpo
|
| from src.utils.rewards import mathematical_correctness_reward_function
|
|
|
|
|
# Configure the root logger at INFO level with a "time - level - message"
# format; the script itself logs through the named "grpo_training" logger.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("grpo_training")
|
|
|
|
|
|
|
|
|
|
|
def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load the GRPO training data, shuffle it, and cap the number of rows.

    Args:
        cfg: Dataset config node. Reads ``file_path``, ``split``, ``seed``
            and ``max_rows``.

    Returns:
        The shuffled dataset restricted to its first
        ``min(cfg.max_rows, len(dataset))`` rows.
    """
    shuffled: Dataset = load_csv_dataset_grpo(
        cfg.file_path,
        cfg.split,
        map_problem_description_to_conversation_grpo,
    ).shuffle(seed=cfg.seed)

    # Never ask for more rows than the dataset actually holds.
    row_limit = min(cfg.max_rows, len(shuffled))
    return shuffled.select(range(row_limit))
|
|
|
|
|
|
|
|
|
|
|
def create_lora_model(cfg: DictConfig, resume_from_checkpoint: str | None = None) -> PreTrainedModel:
    """Load the base causal LM and wrap it in a new LoRA adapter.

    If ``resume_from_checkpoint`` points at an existing local path, the adapter
    stored there is loaded on top of the base model and merged into its weights
    before the new LoRA adapter is attached.

    Args:
        cfg: Model config node. Reads ``model_id``, ``device_map`` and the
            ``lora`` sub-node (``r``, ``lora_alpha``, ``target_modules``,
            ``lora_dropout``, ``bias``, ``task_type``).
        resume_from_checkpoint: Optional local path of a previously trained
            LoRA adapter to merge into the base model first.

    Returns:
        The PEFT-wrapped model produced by ``get_peft_model``.
    """
    model = AutoModelForCausalLM.from_pretrained(cfg.model_id, device_map=cfg.device_map)

    if resume_from_checkpoint:
        if Path(resume_from_checkpoint).exists():
            logger.info("Loading existing LoRA adapter and merging: %s", resume_from_checkpoint)
            model = PeftModel.from_pretrained(model, resume_from_checkpoint)
            model = model.merge_and_unload()
        else:
            # Previously this case was skipped silently, so a typo in the path
            # meant training from the bare base model without any notice.
            logger.warning(
                "Adapter checkpoint not found, continuing from the base model: %s",
                resume_from_checkpoint,
            )

    # OmegaConf.to_container only accepts config containers; a plain string or
    # null value for target_modules is passed through to LoraConfig unchanged.
    target_modules = cfg.lora.target_modules
    if OmegaConf.is_config(target_modules):
        target_modules = OmegaConf.to_container(target_modules)

    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        target_modules=target_modules,
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    return get_peft_model(model, lora_cfg)
|
|
|
|
|
|
|
|
|
|
|
def create_grpo_config(cfg: DictConfig, output_dir: str) -> GRPOConfig:
    """Build the ``GRPOConfig`` from the training config node.

    Args:
        cfg: Training config node; every key listed below is read from it and
            forwarded under the same name.
        output_dir: Directory passed to ``GRPOConfig`` as ``output_dir``.

    Returns:
        The populated ``GRPOConfig``.
    """
    # Keys copied one-to-one from the config node into GRPOConfig.
    forwarded_keys = (
        "learning_rate",
        "warmup_ratio",
        "weight_decay",
        "lr_scheduler_type",
        "optim",
        "remove_unused_columns",
        "gradient_accumulation_steps",
        "num_train_epochs",
        "bf16",
        "per_device_train_batch_size",
        "temperature",
        "max_completion_length",
        "num_generations",
        "max_prompt_length",
        "report_to",
        "logging_steps",
        "save_strategy",
        "save_steps",
    )
    # getattr keeps the same attribute-style access (and errors) as `cfg.<key>`.
    forwarded = {key: getattr(cfg, key) for key in forwarded_keys}
    return GRPOConfig(output_dir=output_dir, **forwarded)
|
|
|
|
|
|
|
|
|
|
|
def create_trainer(model, train_dataset, args):
    """Create a ``GRPOTrainer`` for the given model, dataset and config.

    ``mathematical_correctness_reward_function`` is the only reward function
    passed to the trainer.
    """
    trainer_kwargs = {
        "model": model,
        "reward_funcs": [mathematical_correctness_reward_function],
        "args": args,
        "train_dataset": train_dataset,
    }
    return GRPOTrainer(**trainer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
def train_and_save(trainer, output_dir, resume_from_checkpoint=None, save_before_training=True):
    """Train with ``trainer`` and save the resulting model to ``output_dir``.

    Args:
        trainer: Object exposing ``save_model(path)`` and
            ``train(resume_from_checkpoint=...)``.
        output_dir: Destination passed to ``trainer.save_model``.
        resume_from_checkpoint: Forwarded unchanged to ``trainer.train``.
        save_before_training: When truthy, the model is also saved to
            ``output_dir`` once before training starts.
    """

    def _persist():
        # Single place that writes the model to the output directory.
        trainer.save_model(output_dir)

    if save_before_training:
        _persist()

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    _persist()

    logger.info("Training completed.")
    logger.info("Saved final model to: %s", output_dir)
|
|
|
|
|
|
|
|
|
|
|
def push_to_huggingface(output_dir: str, repo_id: str, model_id: str) -> None:
    """Upload the saved model folder and the base tokenizer to the HF Hub.

    Args:
        output_dir: Local folder whose contents are uploaded.
        repo_id: Target Hub repository id.
        model_id: Model id used to load the tokenizer that is pushed
            alongside the uploaded folder.
    """
    logger.info("Pushing model to HuggingFace Hub...")

    api = HfApi()

    try:
        api.create_repo(repo_id, exist_ok=True)
    except Exception:
        # Repo creation stays best-effort (the upload below is still tried),
        # but the failure is now logged instead of silently discarded, and
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        logger.warning(
            "Could not create or verify HF repo %s; attempting upload anyway.",
            repo_id,
            exc_info=True,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        commit_message="Upload GRPO fine-tuned model",
    )

    tokenizer.push_to_hub(repo_id)

    logger.info("Upload complete! HF repo: https://huggingface.co/%s", repo_id)
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../../config/grpo", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: load data, build the LoRA model, train, optionally push.

    Reads ``dataset``, ``model``, ``training`` and ``output_dir`` from ``cfg``.
    The keys ``resume_from_checkpoint_sft``, ``resume_from_checkpoint_grpo``,
    ``save_before_training`` and ``push_to_hub`` are optional; ``hf_repo_id``
    is read only when ``push_to_hub`` is truthy.
    """
    logger.info("Configuration:\n%s", OmegaConf.to_yaml(cfg))

    # Bail out early (logging only) when the dataset CSV is missing.
    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset CSV file does not exist: %s", cfg.dataset.file_path)
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    train_dataset = load_train_dataset(cfg.dataset)

    resume_sft = cfg.get("resume_from_checkpoint_sft", None)
    model = create_lora_model(cfg.model, resume_sft)

    training_args = create_grpo_config(cfg.training, cfg.output_dir)
    trainer = create_trainer(model, train_dataset, training_args)

    # Optional keys are read with cfg.get, like the other optional settings in
    # this function, so leaving them out of the config no longer raises. The
    # fallbacks mirror train_and_save's own parameter defaults.
    train_and_save(
        trainer,
        cfg.output_dir,
        resume_from_checkpoint=cfg.get("resume_from_checkpoint_grpo", None),
        save_before_training=cfg.get("save_before_training", True),
    )

    if cfg.get("push_to_hub", False):
        push_to_huggingface(
            output_dir=cfg.output_dir,
            repo_id=cfg.hf_repo_id,
            model_id=cfg.model.model_id,
        )
|
|
|
|
|
# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()
|
|
|