from dataclasses import dataclass, field, fields, asdict
from typing import Optional, List, Literal, Dict, Any, Union
from transformers import TrainingArguments, Trainer
from omegaconf import OmegaConf
import sys


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default="fixed", metadata={"help": "'fixed' or 'dynamic' padding in the DataCollator."})
    lambda_reg: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    adapter_path: Optional[str] = field(default=None)

    # Paths used when merging an adapter into the base model.
    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)


@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj"])


@dataclass
class DataConfig:
    dataset_name: str = "math"
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "Dataset split to load ('train', 'test', or 'eval'); slice syntax such as 'train[:1000]' is allowed."})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Names of the dataset fields used as input and output."})


@dataclass
class TrainingOverride:
    # Optimizer and batch settings.
    optim: str = field(default="adamw_torch")
    eval_strategy: str = field(default="no")
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)

    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default="cosine")
    warmup_steps: int = field(default=0)

    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)

    # Checkpointing.
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default="no")
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)

    # Logging and evaluation cadence.
    report_to: Union[None, str, List[str]] = field(default="none")
    logging_steps: int = field(default=25)
    eval_steps: Optional[int] = field(default=None)

    # DataLoader settings.
    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)
    dataloader_persistent_workers: bool = field(default=True)
    dataloader_prefetch_factor: int = field(default=1)

    # Training length.
    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    load_best_model_at_end: bool = field(default=True)


@dataclass
class GlueConfig:
    task_name: str = field(default="mnli")
    pad_to_max_length: bool = field(default=True)


@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)

    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default="def")


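# A hypothetical YAML layout matching the MainConfig schema (illustrative only; the keys
# come from the dataclasses above and the values are either the defaults above or
# placeholders such as the model name):
#
#   model:
#     model_name: huggyllama/llama-7b
#     model_max_seq_length: 512
#   data:
#     dataset_name: math
#     path: ./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json
#   trainer_args:
#     learning_rate: 1.0e-05
#     num_train_epochs: 1.0
#   seed: 42

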
@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )


def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps a MainConfig to HFTrainingArguments.

    Logic:
        1. Extract the 'trainer_args' fields and pass them to the TrainingArguments constructor.
        2. Pack the remaining sub-configs ('model', 'data', etc.) into 'extension'.
    """
    KEY = "trainer_args"

    # Serialize the whole nested config to plain dicts.
    full_dict = asdict(main_cfg)

    # Split off the fields destined for TrainingArguments; everything else
    # travels alongside them in the 'extension' payload.
    train_args_dict = full_dict.pop(KEY)
    extension_payload = full_dict

    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: the '{KEY}' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)

    args.extension = extension_payload
    return args
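

# Minimal usage sketch (illustrative only, not part of the training pipeline). The full
# pipeline presumably composes MainConfig with OmegaConf (imported above) from YAML/CLI
# overrides; here the config is built directly so the example stays self-contained.
if __name__ == "__main__":
    cfg = MainConfig()
    cfg.trainer_args.learning_rate = 2e-5        # override a single training field
    hf_args = convert_to_trainer_args(cfg)
    print(hf_args.learning_rate)                 # 2e-05, picked up by TrainingArguments
    print(sorted(hf_args.extension))             # ['data', 'glue', 'model', ...]

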
@dataclass
class Training:
    model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]",
        metadata={"help": "Dataset split to load ('train', 'test', or 'eval'); slice syntax such as 'train[:100000]' is allowed."},
    )
    dataset_field: Optional[List[str]] = field(
        default=None, metadata={"help": "Names of the dataset fields used as input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right-padded (and possibly truncated)."},
    )
    hrft_r: int = field(
        default=8,
        metadata={"help": "The rank of the adapter. When this is `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."},
    )
    init_a: float = field(default=1e-4, metadata={"help": "Initial value for the adapter weights."})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT, i.e. the freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "Strength of the regularization term."})
    add_orth: str = field(default="none", metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization. "
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"},
    )

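
# Illustrative sketch of how a training script might build this config (an assumption
# about the launch path, not something this module enforces); the dataset field names
# below are hypothetical and must match the dataset actually used:
#
#     train_cfg = Training(
#         model_name_or_path="huggyllama/llama-7b",
#         data_path="./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json",
#         dataset_field=["query", "response"],  # hypothetical input/output field names
#         hrft_r=8,
#         init_weights="pissa",                 # or True for the LoRA-style initialization
#     )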