# universal-model-trainer/app/utils/peft_utils.py
# Origin: vectorplasticity — "Add PEFT utilities for LoRA/adapter support" (commit 76a4048, verified)
"""
PEFT Utilities for Parameter-Efficient Fine-Tuning
Supports LoRA, AdaLoRA, IA3, Prefix Tuning, and Prompt Tuning
"""
import os
import json
import logging
from typing import Dict, List, Optional, Union, Any
from dataclasses import dataclass, field
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = logging.getLogger(__name__)
# PEFT configuration classes
@dataclass
class LoRAConfig:
    """LoRA configuration (local mirror of peft.LoraConfig defaults)."""
    # Rank of the low-rank update matrices; higher = more capacity and more params.
    r: int = 8
    # Scaling factor for the LoRA branch (effective scale is lora_alpha / r).
    lora_alpha: int = 16
    # Dropout applied to the LoRA branch during training.
    lora_dropout: float = 0.05
    # Names of modules to inject LoRA adapters into.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    # Which bias parameters to train ("none", "all", or "lora_only" in peft).
    bias: str = "none"
    # Extra modules to keep fully trainable and save alongside the adapters.
    modules_to_save: List[str] = field(default_factory=list)
@dataclass
class AdaLoRAConfig:
    """AdaLoRA configuration (adaptive rank allocation variant of LoRA)."""
    # Target average rank after adaptive rank pruning.
    target_r: int = 8
    # Initial rank before pruning starts (> target_r).
    init_r: int = 12
    # Step schedule for rank pruning — see peft.AdaLoraConfig for exact semantics.
    tinit: int = 200
    tfinal: int = 1000
    # Interval (in steps) between rank-budget updates.
    deltaT: int = 10
    # EMA smoothing factors used by AdaLoRA's importance estimates.
    beta1: float = 0.85
    beta2: float = 0.85
    # Scaling factor for the LoRA branch.
    lora_alpha: int = 16
    # Dropout applied to the LoRA branch during training.
    lora_dropout: float = 0.05
    # Names of modules to inject adapters into.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    # Extra modules to keep fully trainable and save alongside the adapters.
    modules_to_save: List[str] = field(default_factory=list)
@dataclass
class IA3Config:
    """IA3 configuration (learned rescaling vectors instead of low-rank updates)."""
    # Modules whose activations get a learned scaling vector.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj"])
    # Subset of target modules that are feed-forward layers (scaled on the input side in peft).
    feedforward_modules: List[str] = field(default_factory=lambda: ["up_proj", "down_proj"])
    # Extra modules to keep fully trainable and save alongside the adapters.
    modules_to_save: List[str] = field(default_factory=list)
@dataclass
class PrefixTuningConfig:
    """Prefix Tuning configuration.

    NOTE(review): apply_peft_to_model forwards these fields to the peft
    library config via vars(); verify every field name here is an accepted
    peft.PrefixTuningConfig keyword (projection_dim/dropout in particular).
    """
    # Number of trainable virtual prefix tokens prepended to each layer.
    num_virtual_tokens: int = 20
    # Hidden size of the prefix encoder; None lets the library infer it.
    encoder_hidden_size: Optional[int] = None
    # Whether to reparameterize the prefix through a projection MLP.
    prefix_projection: bool = False
    # Projection MLP width (only relevant when prefix_projection is True).
    projection_dim: int = 128
    # Dropout on the prefix parameters.
    dropout: float = 0.0
@dataclass
class PromptTuningConfig:
    """Prompt Tuning configuration (trainable soft-prompt embeddings)."""
    # Number of trainable virtual prompt tokens.
    num_virtual_tokens: int = 20
    # Tokenizer used to initialize the prompt from text, if desired.
    tokenizer_name_or_path: Optional[str] = None
    # Model geometry hints; None lets the library infer them.
    num_layers: Optional[int] = None
    token_dim: Optional[int] = None
# Registry mapping user-facing PEFT type strings to the local config dataclasses
# above. Keys are the values accepted by get_peft_config / apply_peft_to_model.
PEFT_CONFIG_MAP = {
    "lora": LoRAConfig,
    "adalora": AdaLoRAConfig,
    "ia3": IA3Config,
    "prefix_tuning": PrefixTuningConfig,
    "prompt_tuning": PromptTuningConfig,
}
def get_peft_config(peft_type: str, **kwargs) -> Any:
    """
    Build the local configuration dataclass for a PEFT method.

    Args:
        peft_type: Type of PEFT method ('lora', 'adalora', 'ia3',
            'prefix_tuning', 'prompt_tuning'); case-insensitive.
        **kwargs: Configuration parameters forwarded to the dataclass.

    Returns:
        PEFT configuration object.

    Raises:
        ValueError: If peft_type is not a known method.
    """
    key = peft_type.lower()
    config_cls = PEFT_CONFIG_MAP.get(key)
    if config_cls is None:
        raise ValueError(f"Unknown PEFT type: {key}. Available: {list(PEFT_CONFIG_MAP.keys())}")
    return config_cls(**kwargs)
def apply_peft_to_model(
    model: PreTrainedModel,
    peft_type: str,
    config: Optional[Union[Dict, Any]] = None,
    **kwargs
) -> PreTrainedModel:
    """
    Apply PEFT to a model.

    Args:
        model: The base model to apply PEFT to
        peft_type: Type of PEFT method ('lora', 'adalora', 'ia3',
            'prefix_tuning', 'prompt_tuning'); case-insensitive
        config: PEFT configuration (dict or dataclass)
        **kwargs: Additional configuration parameters (override `config`)

    Returns:
        Model with PEFT applied, or the original model unchanged when the
        `peft` library is not installed.

    Raises:
        ValueError: If `peft_type` is not a supported method.
    """
    try:
        from peft import (
            LoraConfig, AdaLoraConfig, IA3Config,
            PrefixTuningConfig, PromptTuningConfig,
            get_peft_model, TaskType, prepare_model_for_kbit_training
        )
    except ImportError:
        logger.warning("PEFT library not installed. Returning original model.")
        return model

    peft_type = peft_type.lower()

    # Merge config (dict or dataclass) with kwargs; kwargs take precedence.
    if config is None:
        config = {}
    if isinstance(config, dict):
        config_data = {**config, **kwargs}
    else:
        config_data = {k: v for k, v in vars(config).items() if not k.startswith('_')}
        config_data.update(kwargs)

    # Map to PEFT library config classes (these intentionally shadow the
    # local dataclasses of the same names inside this function).
    peft_config_map = {
        "lora": LoraConfig,
        "adalora": AdaLoraConfig,
        "ia3": IA3Config,
        "prefix_tuning": PrefixTuningConfig,
        "prompt_tuning": PromptTuningConfig,
    }
    if peft_type not in peft_config_map:
        raise ValueError(f"Unknown PEFT type: {peft_type}")
    peft_config_class = peft_config_map[peft_type]

    # Normalize the task type. BUGFIX: previously a task_type that was
    # already a TaskType enum, or an unrecognized string, was silently
    # dropped because task_type_map.get(...) returned None.
    task_type = config_data.pop('task_type', None)
    if task_type is not None:
        if isinstance(task_type, TaskType):
            config_data['task_type'] = task_type
        else:
            task_type_map = {
                'causal-lm': TaskType.CAUSAL_LM,
                'seq2seq': TaskType.SEQ_2_SEQ_LM,
                'token-classification': TaskType.TOKEN_CLS,
                'text-classification': TaskType.SEQ_CLS,
                'question-answering': TaskType.QUESTION_ANS,
            }
            mapped = task_type_map.get(task_type)
            if mapped is not None:
                config_data['task_type'] = mapped
            else:
                logger.warning(f"Unrecognized task_type {task_type!r}; proceeding without one.")

    # Create the PEFT library config from the merged parameters.
    peft_config = peft_config_class(**config_data)

    # Quantized (8-bit/4-bit) models need preparation before adapters are
    # attached; the two flag checks were previously duplicated branches.
    if getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False):
        model = prepare_model_for_kbit_training(model)

    # Apply PEFT
    model = get_peft_model(model, peft_config)

    # Log the trainable-parameter footprint of the adapted model.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable params: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
    return model
def get_target_modules_for_architecture(model_name: str) -> List[str]:
    """
    Get recommended LoRA target modules for a model architecture.

    Matching is a case-insensitive substring search over the model name;
    the first entry in the table that matches wins.

    Args:
        model_name: Name of the model

    Returns:
        List of target module names (a fresh list on every call)
    """
    # Ordered (keywords, modules) table — order matters and mirrors the
    # original if/elif chain, so e.g. LLaMA variants are tried before OPT.
    architecture_table = [
        (('llama', 'alpaca', 'vicuna'),
         ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']),
        (('mistral',),
         ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']),
        (('bert', 'roberta', 'deberta'),
         ['query', 'key', 'value', 'dense']),
        (('t5',),
         ['q', 'k', 'v', 'o', 'wi', 'wo']),
        (('gpt2', 'gpt-neo', 'gptj', 'gpt-j'),
         ['c_attn', 'c_proj', 'mlp.c_fc', 'mlp.c_proj']),
        (('bloom',),
         ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']),
        (('opt',),
         ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2']),
        (('falcon',),
         ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']),
    ]
    lowered = model_name.lower()
    for keywords, modules in architecture_table:
        if any(keyword in lowered for keyword in keywords):
            # Copy so callers may mutate the result without corrupting the table.
            return list(modules)
    # Default for generic transformer models.
    return ['q_proj', 'v_proj']
def estimate_lora_parameters(
    base_model_params: int,
    r: int,
    target_modules: List[str],
    lora_alpha: int = 16
) -> Dict[str, int]:
    """
    Estimate the number of trainable parameters for LoRA.

    Each adapted module gains two low-rank matrices (A: hidden x r and
    B: r x hidden), i.e. roughly 2 * hidden_size * r extra parameters.
    The hidden size itself is coarsely approximated from the total
    parameter count, so every figure here is a rough estimate only.

    Args:
        base_model_params: Number of parameters in the base model
        r: LoRA rank
        target_modules: List of target module names
        lora_alpha: LoRA alpha (kept for API compatibility; unused in the estimate)

    Returns:
        Dictionary with parameter estimates
    """
    # Crude hidden-size guess derived from the total parameter count.
    hidden_size = int((base_model_params ** 0.5) * 0.5)
    # Two rank-r matrices per adapted module.
    per_module = 2 * hidden_size * r
    total = per_module * len(target_modules)
    # Guard both ratios against division by zero.
    compression = base_model_params / total if total > 0 else 0
    reduction = 100 * (1 - total / base_model_params) if base_model_params > 0 else 0
    return {
        'estimated_trainable_params': total,
        'params_per_module': per_module,
        'compression_ratio': compression,
        'memory_reduction_percent': reduction,
    }
def save_peft_model(
    model,
    output_dir: str,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    save_merged: bool = False
) -> Dict[str, str]:
    """
    Save PEFT model and associated files.

    Args:
        model: PEFT model to save
        output_dir: Directory to save to
        tokenizer: Optional tokenizer to save
        save_merged: Whether to also save a merged (adapter-folded) copy

    Returns:
        Dictionary with saved file paths and the output directory
    """
    os.makedirs(output_dir, exist_ok=True)
    artifacts = []
    try:
        # Adapter weights and adapter config.
        model.save_pretrained(output_dir)
        artifacts.extend([
            f"{output_dir}/adapter_config.json",
            f"{output_dir}/adapter_model.safetensors",
        ])

        if tokenizer:
            tokenizer.save_pretrained(output_dir)
            artifacts.append(f"{output_dir}/tokenizer.json")

        # Merged copy is best-effort: failures are logged, not fatal.
        if save_merged:
            try:
                merged = model.merge_and_unload()
                merged_dir = os.path.join(output_dir, "merged")
                merged.save_pretrained(merged_dir)
                if tokenizer:
                    tokenizer.save_pretrained(merged_dir)
                artifacts.append(f"{merged_dir}/pytorch_model.bin")
            except Exception as e:
                logger.warning(f"Could not merge model: {e}")

        # Small training summary stored next to the adapters.
        peft_kind = (
            model.active_peft_config.peft_type.value
            if hasattr(model, 'active_peft_config')
            else 'unknown'
        )
        summary = {
            'peft_type': peft_kind,
            'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
            'total_params': sum(p.numel() for p in model.parameters()),
        }
        config_path = os.path.join(output_dir, "training_config.json")
        with open(config_path, 'w') as f:
            json.dump(summary, f, indent=2)
        artifacts.append(config_path)

        logger.info(f"Saved PEFT model to {output_dir}")
    except Exception as e:
        logger.error(f"Error saving PEFT model: {e}")
        raise
    return {'saved_files': artifacts, 'output_dir': output_dir}
def load_peft_model(
    base_model_name: str,
    peft_model_path: str,
    device: str = 'auto'
):
    """
    Load a base model and attach saved PEFT adapters to it.

    Args:
        base_model_name: Name or path of the base model
        peft_model_path: Path to the saved PEFT adapters
        device: Device map for loading ('auto', 'cpu', 'cuda', ...)

    Returns:
        Loaded PEFT model

    Raises:
        Exception: Re-raises any import or loading failure after logging it
            (a missing `peft` install therefore raises here, unlike
            apply_peft_to_model which degrades gracefully).
    """
    try:
        from peft import PeftModel
        # FIX: AutoTokenizer was imported here but never used.
        from transformers import AutoModelForCausalLM

        # fp16 on accelerators, fp32 on CPU (fp16 kernels are poorly
        # supported on CPU backends).
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
            device_map=device
        )
        # Attach the saved adapters on top of the frozen base model.
        model = PeftModel.from_pretrained(base_model, peft_model_path)
        return model
    except Exception as e:
        logger.error(f"Error loading PEFT model: {e}")
        raise
def get_peft_memory_requirements(
    model_params: int,
    peft_type: str = 'lora',
    r: int = 8,
    batch_size: int = 1,
    seq_length: int = 512,
    gradient_checkpointing: bool = True
) -> Dict[str, float]:
    """
    Estimate memory requirements for PEFT training.

    All figures are coarse heuristics for GPU sizing, not exact accounting.

    Args:
        model_params: Number of model parameters
        peft_type: Type of PEFT method (informational; the estimate uses the
            LoRA trainable-ratio heuristic regardless)
        r: LoRA rank (if applicable)
        batch_size: Training batch size
        seq_length: Sequence length
        gradient_checkpointing: Whether gradient checkpointing is enabled

    Returns:
        Dictionary with memory estimates in GB
    """
    bytes_per_gb = 1e9
    # Frozen base weights in FP16 (2 bytes per parameter).
    base_memory = model_params * 2 / bytes_per_gb
    # Approximate trainable fraction of a LoRA-style adapter.
    trainable_ratio = r / 512
    trainable_params = model_params * trainable_ratio
    # AdamW keeps two FP32 states per trainable parameter.
    optimizer_memory = trainable_params * 2 * 4 / bytes_per_gb
    # FP16 gradients exist only for the trainable parameters.
    gradient_memory = trainable_params * 2 / bytes_per_gb
    # Activation heuristic scaled by batch size and sequence length.
    activation_memory = batch_size * seq_length * (model_params ** 0.5) * 0.1 / bytes_per_gb
    if gradient_checkpointing:
        # Checkpointing trades recompute for a large activation reduction.
        activation_memory *= 0.2
    total_memory = base_memory + optimizer_memory + gradient_memory + activation_memory
    return {
        'base_model_gb': round(base_memory, 2),
        'optimizer_states_gb': round(optimizer_memory, 2),
        'gradients_gb': round(gradient_memory, 2),
        'activations_gb': round(activation_memory, 2),
        'total_gb': round(total_memory, 2),
        'peak_gb': round(total_memory * 1.1, 2),  # 10% buffer
        'recommended_gpu_vram': round(total_memory * 1.2, 2)  # 20% buffer
    }
# Convenience function for quick LoRA setup
def quick_lora_setup(
    model: PreTrainedModel,
    r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    target_modules: Optional[List[str]] = None
) -> PreTrainedModel:
    """
    One-call LoRA setup with architecture-aware defaults.

    Args:
        model: Base model
        r: LoRA rank
        lora_alpha: LoRA alpha
        lora_dropout: Dropout rate
        target_modules: Target modules (auto-detected from the model's
            configured name when None)

    Returns:
        Model with LoRA applied
    """
    if target_modules is None:
        # Fall back to architecture heuristics keyed on the model's name.
        detected_name = getattr(model.config, '_name_or_path', '')
        target_modules = get_target_modules_for_architecture(detected_name)
    return apply_peft_to_model(
        model,
        'lora',
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules
    )