import os
from dataclasses import dataclass, field
from typing import Optional

import torch
import transformers
import wandb

from salmonn_trainer import SALMONNTrainer, get_state
from dataset import make_supervised_data_module, DataArguments
from model import SALMONN
from utils import print_trainable_parameters


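# Pretrained components assembled by SALMONN: the released SALMONN checkpoint,
# the Whisper speech encoder, the BEATs audio encoder, and the Vicuna LLM backbone.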
@dataclass
class ModelArguments:
    ckpt_path: Optional[str] = field(default='./salmonn_7b_v0.pth')
    whisper_path: Optional[str] = field(default='./whisper-large-v2')
    beats_path: Optional[str] = field(default='./BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
    vicuna_path: Optional[str] = field(default='./vicuna-7b-v1.5')
    version: Optional[str] = field(default="v0")
    device: Optional[str] = field(default='cuda')


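# Hugging Face TrainingArguments extended with SALMONN-specific defaults:
# bf16 training, the LoRA alpha scaling factor, and a 2048-token context window.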
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    output_dir: Optional[str] = field(default='./checkpoints/')
    optim: str = field(default="adamw_torch")
    bf16: bool = True
    fp16: bool = False
    lora_alpha: int = 32
    model_max_length: int = 2048
    use_cache: bool = False
    gradient_checkpointing: bool = False


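# Fine-tuning entry point: parse CLI arguments into the dataclasses above,
# build the model and dataset, run the trainer, and save the trained weights.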
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    compute_dtype = (torch.float16 if training_args.fp16
                     else (torch.bfloat16 if training_args.bf16 else torch.float32))
    wandb.init(project='SALMONN', name=training_args.run_name)

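    # Assemble SALMONN from the pretrained components and move it to the GPU;
    # compute_dtype follows the bf16/fp16 flags resolved above.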
    model = SALMONN(
        model_args.ckpt_path, model_args.whisper_path, model_args.beats_path, model_args.vicuna_path,
        lora_alpha=training_args.lora_alpha, compute_dtype=compute_dtype
    ).cuda()
    print_trainable_parameters(model, vb=0)

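    # Build the supervised dataset and collator from the model's tokenizer and
    # hand everything to the custom trainer.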
    data_module = make_supervised_data_module(tokenizer=model.tokenizer, data_args=data_args)
    trainer = SALMONNTrainer(model=model, tokenizer=model.tokenizer, args=training_args, **data_module)

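    # Run the optimization loop; the Hugging Face Trainer handles batching,
    # checkpointing, and logging.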
    trainer.train()

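    # get_state (from salmonn_trainer) is expected to filter the parameters
    # worth persisting, e.g. the LoRA/adapter weights, so the saved file stays
    # small compared to a full model checkpoint.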
    os.makedirs(training_args.output_dir, exist_ok=True)  # in case no intermediate checkpoint was written
    weight_to_save = get_state(model.named_parameters())
    torch.save(weight_to_save, os.path.join(training_args.output_dir, 'salmonn_7b.bin'))

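    # Persist the trainer state (log history, global step) next to the weights.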
    trainer.save_state()


if __name__ == "__main__":
    train()