# ARBS / training/finetuning/lora.py
# Uploaded by CLIWorks via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d8bc908 (verified).
"""Ternary LoRA adapters for memory-efficient fine-tuning.
Freezes base ternary weights, adds small float low-rank adapters.
Only adapters receive gradients — base state stays at 1.71 GB, adapters < 100 MB.
Usage:
from finetuning.lora import TernaryLoRA, apply_lora_to_model
model = ARBModel(...)
apply_lora_to_model(model, rank=16, target_modules=['moe', 'byte_head'])
# Only LoRA params are trainable — base is frozen
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from arbitor.kernel.ternary_scale import TernaryScaleTensor
class TernaryLoRALayer(nn.Module):
    """LoRA adapter wrapping a single TernaryScaleTensor.

    The base layer's parameters are frozen and two small float matrices are
    trained instead:

        output = base(x) + (x @ lora_A) @ lora_B * scaling

    ``lora_B`` is zero-initialised, so a freshly wrapped layer reproduces the
    base layer's output exactly until training moves it.

    Args:
        base_layer: module exposing ``_T_shape`` as ``[out_dim, in_dim]``.
        rank: inner dimension of the low-rank update.
        alpha: LoRA alpha; the update is scaled by ``alpha / rank``.
    """

    def __init__(self, base_layer, rank=8, alpha=16.0):
        super().__init__()
        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # Freeze the base layer's parameters (buffers never take gradients).
        for p in base_layer.parameters(True):
            p.requires_grad = False
        # _T_shape is [out_dim, in_dim]
        out_dim, in_dim = (int(d) for d in base_layer._T_shape.tolist())
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))
        # Kept for callers that may inspect it; this class never updates it.
        self.forward_count = 0

    def forward(self, x):
        # No torch.no_grad() around the base call: its weights are already
        # frozen, and gradients must still flow *through* the base path to
        # ``x`` so that adapters in earlier layers receive correct gradients.
        base_out = self.base(x)
        # Compute the adapter path in the adapters' own dtype, then match the
        # base output dtype (avoids dtype-mismatch errors for half inputs).
        lora_out = (x.to(self.lora_A.dtype) @ self.lora_A) @ self.lora_B
        return base_out + (lora_out * self.scaling).to(base_out.dtype)

    def extra_repr(self):
        return (f"base={tuple(self.base._T_shape.tolist())}, "
                f"rank={self.rank}, alpha={self.alpha:.0f}")
class LoRAEmbedding(nn.Module):
    """LoRA adapter for ByteEmbedding / TernaryEmbeddingTable.

    Wraps the frozen embedding lookup and adds a small learned low-rank delta:

        output = base(x) + embedding(x, lora_A @ lora_B) * scale

    Args:
        base_embed: module exposing ``_T_shape`` as
            ``[num_embeddings, embed_dim]``; it is called with the index
            tensor ``x``.
        rank: inner dimension of the low-rank delta table.
        scale: multiplier applied to the delta (default 0.1, the previous
            hard-coded value).
    """

    def __init__(self, base_embed, rank=16, scale=0.1):
        super().__init__()
        self.base = base_embed
        self.scale = scale
        for p in base_embed.parameters(True):
            p.requires_grad = False
        # Cast like TernaryLoRALayer does, in case _T_shape is not an int tensor.
        num_embeddings, embed_dim = (int(d) for d in base_embed._T_shape.tolist())
        self.lora_A = nn.Parameter(torch.randn(num_embeddings, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.randn(rank, embed_dim) * 0.02)

    def forward(self, x):
        # x holds integer indices, so no gradient can flow into it; skipping
        # autograd for the frozen base lookup only saves memory.
        with torch.no_grad():
            base_out = self.base(x)
        # Gather the selected rows of A first, then project with B. This equals
        # F.embedding(x, A @ B) without building the full
        # (num_embeddings x embed_dim) delta table on every call.
        delta = F.embedding(x, self.lora_A) @ self.lora_B
        return base_out + delta * self.scale
def _should_lora(name, target_modules):
"""Check if a module name matches any target pattern."""
for pattern in target_modules:
if pattern in name.lower():
return True
return False
def apply_lora_to_model(model, rank=16, alpha=32.0, target_modules=None, device=None):
    """Wrap targeted TernaryScaleTensor modules of *model* with LoRA adapters, in place.

    Each matching child is replaced (via ``setattr`` on its parent) by a
    ``TernaryLoRALayer``. Afterwards the whole model is moved to *device*, and
    every parameter whose name does not contain ``'lora_'`` is frozen.

    Args:
        model: ARBModel instance (any ``nn.Module`` tree).
        rank: LoRA rank (8-32 typical, higher = more adapter capacity).
        alpha: LoRA scaling alpha (higher = stronger adapter influence).
        target_modules: List of name patterns, tested with ``_should_lora``
            against each child's dotted path (e.g. ``'moe.experts.0.W_gate'``).
            Default: ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
            'output_router', 'shared_up', 'shared_expert_gate',
            'shared_expert_up']
        device: Where to move the model after wrapping. Defaults to ``'cuda'``
            when available, otherwise ``'cpu'`` (the previous fixed behaviour).

    Returns:
        Dict mapping dotted module names to the new LoRA layers (for
        saving/loading adapters).
    """
    if target_modules is None:
        target_modules = ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
                          'output_router', 'shared_up', 'shared_expert_gate',
                          'shared_expert_up']
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lora_layers = {}

    def _apply(module, name=''):
        # list(...) because setattr below mutates the children being iterated.
        for child_name, child in list(module.named_children()):
            full_name = f"{name}.{child_name}" if name else child_name
            if isinstance(child, (TernaryLoRALayer, LoRAEmbedding)):
                # Already wrapped (e.g. a second call): recursing would wrap
                # the adapter's frozen ``.base`` again and nest adapters.
                continue
            matches = _should_lora(full_name, target_modules)
            if isinstance(child, TernaryScaleTensor) and matches:
                # Skip non-square embedding tables (they do lookups, not
                # matmuls). Square shapes are treated as linear projections.
                has_shape = hasattr(child, '_T_shape')
                is_square = has_shape and child._T_shape[1].item() == child._T_shape[0].item()
                if has_shape and not is_square and 'embed' in full_name.lower():
                    continue
                lora = TernaryLoRALayer(child, rank=rank, alpha=alpha)
                setattr(module, child_name, lora)
                lora_layers[full_name] = lora
            elif hasattr(child, '_T_shape') and hasattr(child, 'T_packed') and matches:
                continue  # Non-TernaryScaleTensor with T_packed is an embedding
            else:
                _apply(child, full_name)

    _apply(model)
    # Adapters are created on CPU; move them (and the model) to the target device.
    model.to(device)
    # Freeze all non-LoRA params, only LoRA A/B are trainable
    for name, p in model.named_parameters():
        if 'lora_' not in name:
            p.requires_grad = False
    return lora_layers
def count_lora_params(model):
    """Return ``(trainable, total)`` parameter element counts for *model*.

    ``trainable`` counts elements of parameters with ``requires_grad=True``.
    After ``apply_lora_to_model`` those are the LoRA adapters. ``total``
    counts every parameter element, frozen or not.
    """
    trainable_elems = 0
    all_elems = 0
    for param in model.parameters():
        size = param.numel()
        all_elems += size
        if param.requires_grad:
            trainable_elems += size
    return trainable_elems, all_elems
def save_lora(lora_layers, path):
    """Save only LoRA adapter weights (small ~5-50 MB).

    Args:
        lora_layers: mapping of module name -> object with ``lora_A`` and
            ``lora_B`` tensors (as returned by ``apply_lora_to_model``).
        path: destination file; missing parent directories are created.

    Returns:
        *path*.

    Keys are ``lora.<name>.A`` for every layer, followed by ``lora.<name>.B``
    for every layer. Values are detached CPU tensors, so the file holds no
    autograd state and loads on machines without the original device.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    state = {f"lora.{k}.A": v.lora_A.detach().cpu() for k, v in lora_layers.items()}
    state.update({f"lora.{k}.B": v.lora_B.detach().cpu() for k, v in lora_layers.items()})
    torch.save(state, path)
    return path
def load_lora(model, path, strict=False):
    """Load LoRA adapter weights (as written by ``save_lora``) into *model* in place.

    Each key ``lora.<dotted.module.path>.<A|B>`` is resolved by walking
    attributes from *model*. The ``A``/``B`` suffix maps to the module's
    ``lora_A``/``lora_B`` tensor. Previously the suffix was looked up
    literally, so no adapter was ever restored.

    Args:
        model: module tree containing the LoRA-wrapped layers.
        path: file produced by ``save_lora``.
        strict: if True, raise ``KeyError`` when a key cannot be resolved;
            otherwise emit a warning and skip it.

    Returns:
        *model*.
    """
    import warnings

    # map_location='cpu' lets CUDA-saved files load on CPU-only hosts;
    # copy_ below moves the values onto each target tensor's device.
    state = torch.load(path, map_location='cpu', weights_only=True)
    suffix_to_attr = {'A': 'lora_A', 'B': 'lora_B'}
    unresolved = []
    for full_name, value in state.items():
        parts = full_name.split('.')
        if len(parts) < 3:
            continue
        # Find the module
        obj = model
        for part in parts[1:-1]:
            obj = getattr(obj, part, None)
            if obj is None:
                break
        attr = suffix_to_attr.get(parts[-1], parts[-1])
        target = getattr(obj, attr, None) if obj is not None else None
        if target is None:
            unresolved.append(full_name)
            continue
        with torch.no_grad():
            target.copy_(value)
    if unresolved:
        msg = f"load_lora: could not resolve {len(unresolved)} key(s): {unresolved}"
        if strict:
            raise KeyError(msg)
        warnings.warn(msg)
    return model