File size: 6,239 Bytes
"""Ternary LoRA adapters for memory-efficient fine-tuning.

Freezes the base ternary weights and adds small float low-rank adapters.
Only the adapters receive gradients — the base state stays at 1.71 GB,
adapters < 100 MB.

Usage:
    from finetuning.lora import TernaryLoRALayer, apply_lora_to_model
    model = ARBModel(...)
    apply_lora_to_model(model, rank=16, target_modules=['moe', 'byte_head'])
    # Only LoRA params are trainable — base is frozen
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from arbitor.kernel.ternary_scale import TernaryScaleTensor
class TernaryLoRALayer(nn.Module):
    """LoRA adapter wrapping a single TernaryScaleTensor.

    The base ternary weights are frozen. Two small float matrices A and B are
    trained: ``output = base(x) + (x @ A) @ B * (alpha / rank)``.

    ``B`` is zero-initialized, so the adapter's contribution starts at zero
    and the wrapped layer initially behaves like the base layer.

    Args:
        base_layer: Module exposing ``_T_shape`` as ``[out_dim, in_dim]``
            and callable as ``base_layer(x)``.
        rank: Adapter rank; must be positive.
        alpha: Scaling numerator; the adapter output is scaled by
            ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, base_layer, rank=8, alpha=16.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"LoRA rank must be positive, got {rank}")
        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # Freeze every (recursive) parameter of the base layer.
        for p in base_layer.parameters(True):
            p.requires_grad = False
        # _T_shape is [out_dim, in_dim]
        out_dim, in_dim = (int(d) for d in base_layer._T_shape.tolist())
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))
        # Not read or updated anywhere in this module; kept so any external
        # code that references the attribute keeps working.
        self.forward_count = 0

    def forward(self, x):
        # The base parameters are frozen (requires_grad=False), so no weight
        # gradient is accumulated for them. The base call deliberately runs
        # WITH autograd so the gradient w.r.t. ``x`` can flow through the base
        # projection to adapters in earlier layers; under no_grad that path is
        # cut and upstream adapters only see the low-rank path, which is zero
        # at init because lora_B starts at zero. When ``x`` does not require
        # grad, autograd records nothing here anyway.
        base_out = self.base(x)
        # Adapters are float parameters; the base may run in another dtype
        # (e.g. bf16), so compute the adapter in its own dtype and cast back.
        lora_in = x.to(self.lora_A.dtype)
        lora_out = (lora_in @ self.lora_A) @ self.lora_B * self.scaling
        return base_out + lora_out.to(base_out.dtype)

    def extra_repr(self):
        return (f"base={tuple(self.base._T_shape.tolist())}, "
                f"rank={self.rank}, alpha={self.alpha:g}")
class LoRAEmbedding(nn.Module):
    """LoRA adapter for ByteEmbedding / TernaryEmbeddingTable.

    Wraps the frozen embedding lookup and adds a small learned low-rank
    delta: ``output = base(x) + (A @ B)[x] * scaling``.

    Args:
        base_embed: Module exposing ``_T_shape`` as
            ``[num_embeddings, embed_dim]`` and callable as ``base_embed(x)``.
        rank: Adapter rank.
        scaling: Multiplier applied to the learned delta (default 0.1, the
            value previously hard-coded).
    """

    def __init__(self, base_embed, rank=16, scaling=0.1):
        super().__init__()
        self.base = base_embed
        self.scaling = scaling
        for p in base_embed.parameters(True):
            p.requires_grad = False
        # _T_shape is [num_embeddings, embed_dim]; cast to int in case the
        # stored shape tensor is not an integer dtype.
        num_embeddings, embed_dim = (int(d) for d in base_embed._T_shape.tolist())
        self.lora_A = nn.Parameter(torch.randn(num_embeddings, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.randn(rank, embed_dim) * 0.02)

    def forward(self, x):
        # x holds lookup indices, so there is no input gradient to preserve.
        with torch.no_grad():
            base_out = self.base(x)
        # A[x] @ B equals (A @ B)[x], but looking up the low-rank rows first
        # avoids building the full num_embeddings x embed_dim table per call.
        delta = F.embedding(x, self.lora_A) @ self.lora_B
        return base_out + delta * self.scaling
def _should_lora(name, target_modules):
    """Return True if any pattern in *target_modules* is a substring of *name*.

    The comparison is case-insensitive. Both sides are lower-cased, because
    lower-casing only *name* meant patterns containing capitals (such as
    'W_gate' / 'W_transform') could never match.
    """
    lowered = name.lower()
    return any(pattern.lower() in lowered for pattern in target_modules)
def apply_lora_to_model(model, rank=16, alpha=32.0, target_modules=None, device=None):
    """Apply LoRA adapters to targeted TernaryScaleTensor modules.

    Each ``TernaryScaleTensor`` child whose dotted name matches
    ``target_modules`` (as decided by ``_should_lora``) is replaced in its
    parent by a ``TernaryLoRALayer`` wrapping it. Afterwards the model is
    moved to ``device`` and every parameter whose name does not contain
    ``'lora_'`` is frozen.

    Args:
        model: ARBModel instance (any ``nn.Module``).
        rank: LoRA rank (8-32 typical, higher = more adapter capacity).
        alpha: LoRA scaling alpha (higher = stronger adapter influence).
        target_modules: List of name patterns. Default:
            ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
             'output_router', 'shared_up', 'shared_expert_gate',
             'shared_expert_up']
        device: Device the model is moved to after wrapping. ``None`` keeps
            the original behavior: ``'cuda'`` if available, else ``'cpu'``.

    Returns:
        Dict mapping dotted module names to LoRA layers (for saving/loading
        adapters). Adapters already present from an earlier call are
        included as-is instead of being wrapped again.
    """
    if target_modules is None:
        target_modules = ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
                          'output_router', 'shared_up', 'shared_expert_gate',
                          'shared_expert_up']
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lora_layers = {}

    def _apply(module, name=''):
        for child_name, child in list(module.named_children()):
            full_name = f"{name}.{child_name}" if name else child_name
            if isinstance(child, TernaryLoRALayer):
                # Already adapted (e.g. a repeated call). Recursing would wrap
                # its ``base`` again, since "<full_name>.base" still matches
                # the same substring pattern.
                lora_layers[full_name] = child
                continue
            if isinstance(child, TernaryScaleTensor) and _should_lora(full_name, target_modules):
                if hasattr(child, '_T_shape'):
                    is_square = child._T_shape[1].item() == child._T_shape[0].item()
                    # Non-square modules named like embeddings use lookup,
                    # not matmul: skip them. Square ones are treated as
                    # linear projections and wrapped.
                    if not is_square and 'embed' in full_name.lower():
                        continue
                lora = TernaryLoRALayer(child, rank=rank, alpha=alpha)
                setattr(module, child_name, lora)
                lora_layers[full_name] = lora
            elif (hasattr(child, '_T_shape') and hasattr(child, 'T_packed')
                  and _should_lora(full_name, target_modules)):
                # Non-TernaryScaleTensor with T_packed is treated as an
                # embedding: leave it alone and do not descend into it.
                continue
            else:
                _apply(child, full_name)

    _apply(model)
    # LoRA adapters are created on CPU by default; move everything together.
    model.to(device)
    # Freeze all non-LoRA params; only LoRA A/B stay trainable.
    for param_name, p in model.named_parameters():
        if 'lora_' not in param_name:
            p.requires_grad = False
    return lora_layers
def count_lora_params(model):
    """Return ``(trainable, total)`` parameter-element counts of *model*.

    ``trainable`` sums ``numel()`` over parameters with ``requires_grad``
    set (the LoRA adapters once the base is frozen); ``total`` sums it over
    every parameter.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total
def save_lora(lora_layers, path):
    """Save only the LoRA adapter weights (small, ~5-50 MB).

    Args:
        lora_layers: Mapping of dotted module name -> adapter module exposing
            ``lora_A`` and ``lora_B`` (as returned by ``apply_lora_to_model``).
        path: Destination file; missing parent directories are created.

    Returns:
        *path*.

    The file holds a flat dict with keys ``"lora.<name>.A"`` and
    ``"lora.<name>.B"``. Tensors are detached and moved to CPU, so a file
    written from a GPU run can be loaded on a CPU-only machine.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    state = {f"lora.{k}.A": v.lora_A.detach().cpu() for k, v in lora_layers.items()}
    state.update({f"lora.{k}.B": v.lora_B.detach().cpu() for k, v in lora_layers.items()})
    torch.save(state, path)
    return path
def load_lora(model, path, strict=False):
    """Load LoRA adapter weights saved by ``save_lora`` into *model*.

    Keys have the form ``"lora.<dotted.module.name>.A"`` / ``"....B"``. The
    dotted name is resolved attribute by attribute starting at *model*, and
    the stored tensor is copied in place into that module's ``lora_A`` /
    ``lora_B`` parameter.

    Args:
        model: Model that already has LoRA adapters applied.
        path: File written by ``save_lora``.
        strict: If True, raise ``KeyError`` when any stored key cannot be
            matched to an adapter tensor. If False (default), skip such keys.

    Returns:
        *model*, with adapter weights updated in place.

    Raises:
        KeyError: ``strict`` is True and some keys did not match.
        ValueError: A stored tensor's shape differs from its target.
    """
    # map_location='cpu' lets GPU-saved files load on CPU-only hosts; copy_
    # below moves the data onto each parameter's own device.
    state = torch.load(path, map_location='cpu', weights_only=True)
    unmatched = []
    for full_name, stored in state.items():
        parts = full_name.split('.')
        if len(parts) < 3:
            unmatched.append(full_name)
            continue
        # Find the module
        obj = model
        for part in parts[1:-1]:
            obj = getattr(obj, part, None)
            if obj is None:
                break
        target = None
        if obj is not None:
            suffix = parts[-1]
            # save_lora writes suffix "A"/"B" but the parameters are named
            # lora_A/lora_B; fall back to the literal suffix for other keys.
            for attr in (f"lora_{suffix}", suffix):
                candidate = getattr(obj, attr, None)
                if isinstance(candidate, torch.Tensor):
                    target = candidate
                    break
        if target is None:
            unmatched.append(full_name)
            continue
        if target.shape != stored.shape:
            raise ValueError(
                f"Shape mismatch for {full_name}: stored {tuple(stored.shape)}, "
                f"model {tuple(target.shape)}")
        with torch.no_grad():
            target.copy_(stored)
    if strict and unmatched:
        raise KeyError(f"LoRA keys not found in model: {unmatched}")
    return model
|