| |
| from typing import Optional |
|
|
| import torch |
| from peft import IA3Config, PeftModel, get_peft_model |
|
|
| from swift.llm import MODEL_ARCH_MAPPING, ModelKeys |
| from swift.utils import find_all_linears |
|
|
|
|
class Tuner:
    """Interface for tuner plugins.

    Every hook is a static method that raises ``NotImplementedError`` here;
    subclasses override the hooks they support.
    """

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` with this tuner before training.

        Args:
            args: The training arguments.
            model: The model instance to wrap.

        Returns:
            The wrapped model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def save_pretrained(
            model: torch.nn.Module,
            save_directory: str,
            state_dict: Optional[dict] = None,
            safe_serialization: bool = True,
            **kwargs,
    ) -> None:
        """Persist the tuner; intended to run when a save step is reached.

        Args:
            model: The model previously wrapped by `prepare_model`.
            save_directory: The directory to write into.
            state_dict: Optional state dict supplied by the caller
                (presumably the weights to save in place of the model's
                own -- confirm against callers).
            safe_serialization: Whether to use safetensors.
            **kwargs: Extra options; their meaning is left to the subclass.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a saved tuner onto ``model``.

        Args:
            model: The original (unwrapped) model instance.
            model_id: The model id or checkpoint directory to load.
            **kwargs: Extra options; their meaning is left to the subclass.

        Returns:
            The wrapped model instance.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
class PeftTuner(Tuner):
    """Tuner base that delegates saving and loading to the model / PEFT.

    Subclasses still have to provide `prepare_model`.
    """

    @staticmethod
    def save_pretrained(
            model: torch.nn.Module,
            save_directory: str,
            state_dict: Optional[dict] = None,
            safe_serialization: bool = True,
            **kwargs,
    ) -> None:
        """Save the model through its own ``save_pretrained`` method.

        Args:
            model: The model wrapped by `prepare_model`; it must expose a
                ``save_pretrained`` method.
            save_directory: The directory to save into.
            state_dict: Optional state dict to save instead of the one the
                model would collect itself. It is forwarded only when it is
                not ``None``, so the default call is unchanged.
            safe_serialization: Whether to use safetensors.
            **kwargs: Passed through unchanged to ``model.save_pretrained``.
        """
        # Previously `state_dict` was accepted but dropped, so a state dict
        # gathered by the caller never reached the underlying save call.
        if state_dict is not None:
            kwargs['state_dict'] = state_dict
        model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Load a checkpoint onto ``model`` via ``PeftModel.from_pretrained``.

        Args:
            model: The original model instance.
            model_id: The model id or checkpoint directory to load.
            **kwargs: Passed through unchanged to ``PeftModel.from_pretrained``.

        Returns:
            The object returned by ``PeftModel.from_pretrained``.
        """
        return PeftModel.from_pretrained(model, model_id, **kwargs)
|
|
|
|
| |
class IA3(PeftTuner):
    """Tuner plugin that wraps a model with a PEFT ``IA3Config``."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Build an IA3 config for ``model`` and return the PEFT-wrapped model.

        Args:
            args: The training arguments (not read here).
            model: The model instance; ``model.model_meta.model_arch`` must
                be a key of ``MODEL_ARCH_MAPPING``.

        Returns:
            The result of ``get_peft_model(model, <IA3Config>)``.
        """
        arch_keys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
        linear_targets = find_all_linears(model)
        # Take the part of the mlp key template after the first '{}.'
        # placeholder and match it anywhere in a module name.
        mlp_suffix = arch_keys.mlp.split('{}.')[1]
        feedforward_pattern = '.*' + mlp_suffix + '.*'
        config = IA3Config(target_modules=linear_targets, feedforward_modules=feedforward_pattern)
        return get_peft_model(model, config)
|
|
|
|
class DummyTuner(PeftTuner):
    """Tuner plugin whose `prepare_model` leaves the model untouched."""

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` as-is, without wrapping it.

        Args:
            args: The training arguments (not read here).
            model: The model instance.

        Returns:
            The same ``model`` object that was passed in.
        """
        return model
|
|
|
|
| |
# Registry mapping a tuner name to the Tuner subclass that implements it.
extra_tuners = {
    'ia3': IA3,
    'dummy': DummyTuner,
}
|
|