| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer

# Register tqdm's pandas integration so `DataFrame.progress_apply` shows a
# progress bar (side effect only; nothing in this script calls it directly).
tqdm.pandas()
|
|
|
|
| |
@dataclass
class ScriptArguments:
    """
    The name of the Causal LM model we wish to fine-tune with SFTTrainer.

    Each field becomes a command-line flag via ``HfArgumentParser``; the
    ``metadata["help"]`` text is shown in ``--help`` output.
    """

    model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}
    )
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=2, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="./", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=2, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
|
|
|
|
# Parse the command line into a single ScriptArguments instance; the parser
# returns one dataclass per type it was constructed with.
arg_parser = HfArgumentParser(ScriptArguments)
script_args = arg_parser.parse_args_into_dataclasses()[0]
|
|
| |
# Step 1: Decide quantization/device placement, then load the base model.
# 8-bit and 4-bit quantization are mutually exclusive — reject both at once.
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

if script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit,
        load_in_4bit=script_args.load_in_4bit,
    )
    # Quantized loading pins the whole model to GPU 0 and computes in bfloat16.
    # NOTE(review): `{"": 0}` assumes a single-GPU setup — confirm for multi-GPU runs.
    device_map = {"": 0}
    torch_dtype = torch.bfloat16
else:
    # Full-precision path: let transformers pick defaults for dtype/placement.
    quantization_config = None
    device_map = None
    torch_dtype = None

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    use_auth_token=script_args.use_auth_token,
)
|
|
| |
# Step 2: Load the training split of the dataset from the Hub.
dataset = load_dataset(script_args.dataset_name, split="train")

# Step 3: Map the parsed script arguments onto HF Trainer hyperparameters.
trainer_kwargs = {
    "output_dir": script_args.output_dir,
    "per_device_train_batch_size": script_args.batch_size,
    "gradient_accumulation_steps": script_args.gradient_accumulation_steps,
    "learning_rate": script_args.learning_rate,
    "logging_steps": script_args.logging_steps,
    "num_train_epochs": script_args.num_train_epochs,
    "max_steps": script_args.max_steps,
    "report_to": script_args.log_with,
}
training_args = TrainingArguments(**trainer_kwargs)
|
|
| |
# Step 4: Optionally attach a LoRA adapter configuration; `None` means
# full-model fine-tuning inside SFTTrainer.
peft_config = (
    LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
    if script_args.use_peft
    else None
)
|
|
| |
# Step 5: Assemble the supervised fine-tuning trainer and run training.
sft_trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    peft_config=peft_config,
    dataset_text_field=script_args.dataset_text_field,
    max_seq_length=script_args.seq_length,
)

sft_trainer.train()

# Step 6: Persist the trained model (or adapter weights) to the output dir.
sft_trainer.save_model(script_args.output_dir)