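"""Distill a teacher causal LM into a student using ModelOpt's KDTrainer combined
with TRL's SFTTrainer, fine-tuning on a chat-formatted smoltalk dataset."""
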
import os
from dataclasses import dataclass

# Set before importing torch so the CUDA caching allocator picks it up
# (reduces memory fragmentation on long training runs).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.distributed
import transformers
from accelerate.logging import get_logger
from transformers import AutoTokenizer
from trl import SFTTrainer

import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

logger = get_logger(__name__, log_level="INFO")

@dataclass
class ModelArguments:
    teacher_name_or_path: str | None = None
    student_name_or_path: str | None = None

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    do_train: bool = True
    do_eval: bool = True
    save_strategy: str = "no"
    max_length: int = 1024
    optim: str = "adamw_torch"
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"
    dataloader_drop_last: bool = True
    dataset_num_proc: int = 8
    bf16: bool = True

def _format_smoltalk_chat_template(sample, tokenizer):
    """Render one query/answer sample as a single chat-templated training string."""
    messages = [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)

class KDSFTTrainer(KDTrainer, SFTTrainer):
    """SFTTrainer with ModelOpt knowledge distillation mixed in (KDTrainer comes first in the MRO)."""

def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Enable automatic save/load of the ModelOpt state with HuggingFace checkpointing.
    mto.enable_huggingface_checkpointing()

    # Derive gradient accumulation steps for a fixed effective batch size across all ranks,
    # e.g. 8 samples/device on 4 GPUs -> 64 / (8 * 4) = 2 accumulation steps.
    total_batch_size = 64
    num_accum_steps = total_batch_size / (
        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
    )
    if not num_accum_steps.is_integer():
        raise ValueError(
            f"`per_device_train_batch_size` * `world_size` must be a factor of {total_batch_size}"
        )
    training_args.gradient_accumulation_steps = int(num_accum_steps)
    logger.info(
        f"Using {int(num_accum_steps)} grad accumulation steps for effective batch size of {total_batch_size}."
    )

    # Load and split the dataset.
    logger.info("Loading dataset...")
    dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")
    dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
    dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
    logger.info("Dataset loaded.")

    # Load the tokenizer (shared by teacher and student).
    logger.info("Loading tokenizer...")
    model_path = model_args.teacher_name_or_path or model_args.student_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    logger.info("Tokenizer loaded.")

    # Load the student model, then the teacher it will distill from.
    logger.info("Loading student model...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )
    logger.info("Student loaded.")
    logger.info("Loading teacher model...")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )
    logger.info("Teacher loaded.")

    # Distillation config: match the student's output logits to the teacher's.
    kd_config = {
        "teacher_model": teacher_model,
        "criterion": LMLogitsLoss(),
    }

    # Unset sampling parameters to avoid generation-config warnings (sampling is not used here).
    model.generation_config.temperature = None
    model.generation_config.top_p = None

    # Build the combined distillation + SFT trainer.
    trainer = KDSFTTrainer(
        model,
        training_args,
        distill_config=kd_config,
        train_dataset=dset_train,
        eval_dataset=dset_eval,
        formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
        processing_class=tokenizer,
    )

    # Train.
    if training_args.do_train:
        logger.info("Beginning training...")
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        logger.info("Training done.")

    # Evaluate.
    if training_args.do_eval:
        logger.info("Evaluating...")
        eval_results = trainer.evaluate()
        logger.info(eval_results)
        logger.info("Evaluation complete.")

    # Save the final model (ModelOpt state included via the checkpointing hook above).
    logger.info("Saving checkpoint...")
    trainer.save_state()
    trainer.save_model(trainer.args.output_dir)
    logger.info("Checkpoint saved.")

if __name__ == "__main__":
    train()
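
# Example launch (a sketch: paths and GPU count are placeholders, and it assumes
# this script is saved as main.py with `accelerate` configured for multi-GPU):
#
#   accelerate launch --num_processes 4 main.py \
#       --teacher_name_or_path <teacher-model-or-path> \
#       --student_name_or_path <student-model-or-path> \
#       --output_dir ./kd_output \
#       --per_device_train_batch_size 8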