|
|
| """
|
| SFT training script for arithmetic countdown problems.
|
|
|
| This script trains a language model using SFT (Supervised Fine-Tuning)
|
| to solve arithmetic problems with proper reasoning and formatting.
|
| """
|
|
|
| import argparse
|
| import logging
|
| import os
|
| from pathlib import Path
|
|
|
| from datasets import Dataset
|
| from peft import LoraConfig, get_peft_model
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
|
| from trl import SFTConfig, SFTTrainer
|
|
|
| from src.dataset.sft import (
|
| load_csv_dataset_sft,
|
| map_problem_description_to_conversation_sft,
|
| )
|
|
|
|
|
# Configure root logging once at import time so every handler emits the same
# timestamped format; the named logger below is shared by all functions here.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("sft_training")
|
|
|
|
|
def load_train_dataset(
    dataset_csv: str, max_rows: int = 2000, seed: int = 42
) -> Dataset:
    """
    Load the SFT training split, shuffle it, and cap its size.

    Args:
        dataset_csv: Absolute path to the dataset CSV file
        max_rows: Upper bound on the number of rows kept for training
        seed: Seed controlling the shuffle order

    Returns:
        Dataset: A datasets.Dataset ready for SFT training
    """
    dataset: Dataset = load_csv_dataset_sft(
        dataset_csv, "train", map_problem_description_to_conversation_sft
    )
    shuffled = dataset.shuffle(seed=seed)
    # Never request more rows than the dataset actually contains.
    keep = min(max_rows, len(shuffled))
    subset = shuffled.select(range(keep))
    logger.info("Train rows: %d", len(subset))
    return subset
|
|
|
|
|
def create_lora_model(model_id: str, device_map: str = "auto") -> PreTrainedModel:
    """
    Load a base causal language model and attach LoRA adapters to it.

    Args:
        model_id: Hugging Face model identifier to load as the base model
        device_map: Device mapping strategy for model loading

    Returns:
        PreTrainedModel: A transformers.PreTrainedModel with LoRA adapters applied
    """
    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)

    # Adapt only the attention projections; rank 16 with alpha 32 keeps the
    # trainable parameter count small.
    adapter_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base_model, adapter_config)
    logger.info("Model with LoRA ready")
    return peft_model
|
|
|
|
|
def create_sft_config(
    output_dir: str,
    learning_rate: float = 2e-4,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    max_length: int = 2048,
    save_steps: int = 50,
    logging_steps: int = 1,
) -> SFTConfig:
    """
    Build the trl.SFTConfig used for supervised fine-tuning.

    Args:
        output_dir: Directory where checkpoints and logs will be written
        learning_rate: Learning rate for training
        num_train_epochs: Number of training epochs
        per_device_train_batch_size: Batch size per device
        gradient_accumulation_steps: Steps to accumulate gradients
        max_length: Maximum sequence length
        save_steps: Steps between model saves
        logging_steps: Steps between log outputs

    Returns:
        SFTConfig: A configured trl.SFTConfig instance
    """
    return SFTConfig(
        output_dir=output_dir,
        # Optimizer and LR schedule.
        learning_rate=learning_rate,
        weight_decay=0.001,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        optim="paged_adamw_32bit",
        bf16=True,
        # Batching.
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        # Sequence handling: no packing, keep all dataset columns.
        max_length=max_length,
        packing=False,
        remove_unused_columns=False,
        # Logging and checkpointing; no eval split is used.
        report_to=["tensorboard"],
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=save_steps,
        eval_strategy="no",
    )
|
|
|
|
|
def create_trainer(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    train_dataset: Dataset,
    args: SFTConfig,
) -> SFTTrainer:
    """
    Construct an SFTTrainer for supervised fine-tuning.

    Args:
        model: The LoRA-wrapped pretrained model to train
        tokenizer: The tokenizer for the model
        train_dataset: The dataset to use for training
        args: The SFT configuration

    Returns:
        SFTTrainer: An initialized trl.SFTTrainer instance
    """
    # NOTE: recent TRL releases removed the `tokenizer=` keyword from
    # SFTTrainer in favor of `processing_class=`. This file's SFTConfig
    # already uses the post-rename `max_length` field, so it targets a TRL
    # version where passing `tokenizer=` raises a TypeError.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=args,
        train_dataset=train_dataset,
    )
    return trainer
|
|
|
|
|
def train_and_save(trainer: SFTTrainer, output_dir: str) -> None:
    """
    Execute the training loop, then persist the final model to disk.

    Args:
        trainer: The configured SFT trainer instance
        output_dir: Output directory to save the trained model

    Returns:
        None
    """
    result = trainer.train()
    logger.info("Training complete: %s", str(result))
    # save_model writes the (adapter) weights into the given directory.
    trainer.save_model(output_dir)
    logger.info("Saved to %s", output_dir)
|
|
|
|
|
def main() -> None:
    """
    Parse command-line arguments and run the full SFT training workflow.

    Returns:
        None
    """
    arg_parser = argparse.ArgumentParser(
        description="Train a language model using SFT for arithmetic countdown problems"
    )

    # --- Dataset options ---
    arg_parser.add_argument(
        "--dataset_csv",
        type=str,
        required=True,
        help="Path to the training dataset CSV file",
    )
    arg_parser.add_argument(
        "--max_rows", type=int, default=2000, help="Maximum number of training samples"
    )
    arg_parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for dataset shuffling"
    )

    # --- Model options ---
    arg_parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Hugging Face model identifier",
    )
    arg_parser.add_argument(
        "--device_map", type=str, default="auto", help="Device mapping strategy"
    )

    # --- Training options ---
    arg_parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save model checkpoints and logs",
    )
    arg_parser.add_argument(
        "--learning_rate", type=float, default=2e-4, help="Learning rate"
    )
    arg_parser.add_argument(
        "--num_train_epochs", type=int, default=1, help="Number of training epochs"
    )
    arg_parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size per device",
    )
    arg_parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    arg_parser.add_argument(
        "--max_length", type=int, default=2048, help="Maximum sequence length"
    )
    arg_parser.add_argument(
        "--save_steps", type=int, default=50, help="Steps between model saves"
    )
    arg_parser.add_argument(
        "--logging_steps", type=int, default=1, help="Steps between log outputs"
    )

    options = arg_parser.parse_args()

    # Guard clauses: fail fast on bad inputs before loading anything heavy.
    if not Path(options.dataset_csv).exists():
        logger.error("Dataset CSV file does not exist: %s", options.dataset_csv)
        return

    if options.max_rows <= 0:
        logger.error("max_rows must be positive")
        return

    os.makedirs(options.output_dir, exist_ok=True)
    logger.info("Output dir: %s", options.output_dir)

    train_dataset = load_train_dataset(
        options.dataset_csv, options.max_rows, options.seed
    )

    model = create_lora_model(options.model_id, options.device_map)
    tokenizer = AutoTokenizer.from_pretrained(options.model_id)

    # Some checkpoints ship without a pad token; fall back to EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    training_args = create_sft_config(
        output_dir=options.output_dir,
        learning_rate=options.learning_rate,
        num_train_epochs=options.num_train_epochs,
        per_device_train_batch_size=options.per_device_train_batch_size,
        gradient_accumulation_steps=options.gradient_accumulation_steps,
        max_length=options.max_length,
        save_steps=options.save_steps,
        logging_steps=options.logging_steps,
    )

    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        args=training_args,
    )

    train_and_save(trainer=trainer, output_dir=options.output_dir)
|
|
|
|
|
# Script entry point: run the full SFT workflow when executed directly.
if __name__ == "__main__":
    main()
|
|
|