"""LoRA SFT training script driven by Hydra configs (supports resuming from a saved adapter)."""
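
# Example invocation (the script filename and override values here are
# illustrative; the override keys come from the config nodes this script reads):
#
#   python train_sft.py training.num_train_epochs=3 \
#       resume_from=outputs/sft-lora/checkpoint-500
#
# `resume_from`, when set, must point at a previously saved LoRA adapter directory.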

import os
import sys
import logging
from pathlib import Path

# Make the project root importable so the `src.*` imports below resolve
# when this file is run directly.
FILE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)

import hydra
from omegaconf import DictConfig, OmegaConf
from datasets import Dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from huggingface_hub import login

from src.dataset.sft import (
    load_csv_dataset_sft,
    map_problem_description_to_conversation_sft,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("sft_training")


def load_train_dataset(cfg: DictConfig) -> Dataset:
    """Load the CSV dataset, shuffle it, and cap it at `cfg.max_rows` rows."""
    raw_dataset = load_csv_dataset_sft(
        cfg.file_path, map_problem_description_to_conversation_sft
    )
    raw_dataset = raw_dataset.shuffle(seed=cfg.seed)
    train_dataset = raw_dataset.select(range(min(cfg.max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset
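
# The mapper's exact output lives in src/dataset/sft; the record shape sketched
# below is an assumption based on its name and on TRL's conversational format,
# where each CSV row would become something like:
#
#   {"messages": [{"role": "user", "content": "<problem description>"},
#                 {"role": "assistant", "content": "<solution>"}]}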


def create_lora_model(cfg: DictConfig, resume_path=None):
    """Load the base model, then either resume an existing LoRA adapter or attach a new one.

    - If `resume_path` is provided: load the base model, then load the LoRA
      adapter from that path with `is_trainable=True` so training can continue.
    - Otherwise: load the base model and attach a freshly initialized LoRA adapter.
    """
    base_model_id = cfg.model_id

    logger.info("Loading base model: %s", base_model_id)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        device_map=cfg.device_map,
    )

    if resume_path:
        logger.info("Resuming from LoRA adapter: %s", resume_path)
        # Without is_trainable=True the adapter loads frozen (inference mode)
        # and a resumed run would not update the LoRA weights.
        return PeftModel.from_pretrained(base_model, resume_path, is_trainable=True)

    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        # OmegaConf list -> plain Python list, which peft expects.
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    model = get_peft_model(base_model, lora_cfg)
    logger.info("New LoRA model created")
    return model
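
# A minimal sketch of the `model` config node this function reads (key names are
# taken from the attribute accesses above; the values are illustrative only):
#
#   model_id: Qwen/Qwen2.5-0.5B-Instruct   # any causal-LM checkpoint
#   device_map: auto
#   lora:
#     r: 16
#     lora_alpha: 32
#     target_modules: [q_proj, k_proj, v_proj, o_proj]
#     lora_dropout: 0.05
#     bias: none
#     task_type: CAUSAL_LM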


def build_sft_config(cfg: DictConfig, output_dir: str) -> SFTConfig:
    """Map the Hydra `training` config node onto TRL's SFTConfig."""
    return SFTConfig(
        output_dir=output_dir,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        num_train_epochs=cfg.num_train_epochs,
        max_length=cfg.max_length,
        bf16=cfg.bf16,
        fp16=cfg.fp16,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
        report_to=cfg.report_to,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
    )
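
# Corresponding `training` config sketch (field names mirror the SFTConfig
# arguments above; values are illustrative, not tuned). Note that recent TRL
# releases call this field `max_length`; older ones used `max_seq_length`:
#
#   learning_rate: 2e-4
#   per_device_train_batch_size: 2
#   gradient_accumulation_steps: 4
#   num_train_epochs: 1
#   max_length: 2048
#   bf16: true
#   fp16: false
#   save_strategy: steps
#   save_steps: 200
#   report_to: none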


def create_trainer(model, tokenizer, train_dataset, training_args):
    """Assemble the SFTTrainer from the prepared pieces."""
    return SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        # Newer TRL releases take `processing_class`; the argument was called
        # `tokenizer` in older versions.
        processing_class=tokenizer,
    )


def train_and_save(trainer, output_dir, tokenizer, hf_repo_id=None):
    """Run training, save the adapter and tokenizer locally, and optionally push to the Hub."""
    logger.info("Start training...")
    trainer.train()
    logger.info("Training finished")

    # save_pretrained on a PeftModel writes only the adapter weights and
    # adapter config, not the full base model.
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved LoRA adapter to: %s", output_dir)

    if hf_repo_id:
        logger.info("Pushing adapter to HF Hub: %s", hf_repo_id)
        trainer.model.push_to_hub(hf_repo_id)
        tokenizer.push_to_hub(hf_repo_id)
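
# Loading the saved adapter later for inference (a sketch; `base_model_id` and
# `output_dir` stand in for whatever values the run actually used):
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   base = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
#   model = PeftModel.from_pretrained(base, output_dir)
#   tokenizer = AutoTokenizer.from_pretrained(output_dir)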


@hydra.main(version_base=None, config_path="../../config/sft", config_name="config")
def main(cfg: DictConfig) -> None:
    print("Loaded config:")
    print(OmegaConf.to_yaml(cfg))

    if cfg.get("hf_token", None):
        login(token=cfg.hf_token)
        logger.info("Logged into the Hugging Face Hub")

    if not Path(cfg.dataset.file_path).exists():
        logger.error("Dataset not found: %s", cfg.dataset.file_path)
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    train_dataset = load_train_dataset(cfg.dataset)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_id)
    if tokenizer.pad_token is None:
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # batched training can pad.
        tokenizer.pad_token = tokenizer.eos_token

    resume = cfg.get("resume_from", None)
    model = create_lora_model(cfg.model, resume)

    training_args = build_sft_config(cfg.training, cfg.output_dir)
    trainer = create_trainer(model, tokenizer, train_dataset, training_args)

    train_and_save(
        trainer=trainer,
        output_dir=cfg.output_dir,
        tokenizer=tokenizer,
        hf_repo_id=cfg.get("hf_repo_id", None),
    )


if __name__ == "__main__":
    main()