| """Ternary LoRA adapters for memory-efficient fine-tuning. |
| |
| Freezes base ternary weights, adds small float low-rank adapters. |
| Only adapters receive gradients — base state stays at 1.71 GB, adapters < 100 MB. |
| |
| Usage: |
| from finetuning.lora import TernaryLoRALayer, apply_lora_to_model |
| |
| model = ARBModel(...) |
| apply_lora_to_model(model, rank=16, target_modules=['moe', 'byte_head']) |
| # Only LoRA params are trainable — base is frozen |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os |
| from arbitor.kernel.ternary_scale import TernaryScaleTensor |
|
|
|
|
class TernaryLoRALayer(nn.Module):
    """Low-rank adapter wrapped around one frozen base layer.

    Every parameter of the wrapped layer has ``requires_grad`` switched off;
    the only trainable tensors are the two float factors ``lora_A``
    (``in_dim x rank``) and ``lora_B`` (``rank x out_dim``).  The forward pass
    returns ``base(x) + ((x @ lora_A) @ lora_B) * scaling`` where
    ``scaling = alpha / rank``.

    Args:
        base_layer: Module to wrap.  It must expose ``_T_shape``, a
            two-element tensor read here as ``(out_dim, in_dim)``.
        rank: Inner dimension of the low-rank factors.
        alpha: Numerator of the adapter scaling factor.
    """

    def __init__(self, base_layer, rank=8, alpha=16.0):
        super().__init__()
        self.base = base_layer
        self.scaling = alpha / rank

        # Freeze the wrapped layer so that only the adapter is trainable.
        for frozen in base_layer.parameters(True):
            frozen.requires_grad_(False)

        # _T_shape is unpacked as (out_dim, in_dim); coerce both to int.
        out_dim, in_dim = (int(size) for size in base_layer._T_shape.tolist())

        # A gets small Gaussian noise and B starts at zero, so the adapter
        # branch contributes nothing until B has been updated.
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))

        # Initialised to zero; nothing in this class updates it.
        self.forward_count = 0

    def forward(self, x):
        """Return the frozen base output plus the scaled low-rank update."""
        # NOTE(review): the base call runs under no_grad, so no gradient
        # reaches ``x`` through the frozen path -- only the adapter branch
        # back-propagates to the input.  Confirm this is intended when
        # several adapted layers are stacked.
        with torch.no_grad():
            frozen_out = self.base(x)
        low_rank = x @ self.lora_A
        adapter_out = (low_rank @ self.lora_B) * self.scaling
        return frozen_out + adapter_out

    def extra_repr(self):
        """Describe the base shape, the rank and alpha (= scaling * rank)."""
        rank = self.lora_A.shape[1]
        base_shape = tuple(self.base._T_shape.tolist())
        return f"base={base_shape}, rank={rank}, alpha={self.scaling * rank:.0f}"
|
|
|
|
class LoRAEmbedding(nn.Module):
    """LoRA adapter for an embedding-style base module.

    The wrapped module is frozen and its lookup result is augmented with a
    learned low-rank delta: ``base(x) + (lora_A[x] @ lora_B) * scaling``.

    Args:
        base_embed: Module to wrap.  It must expose ``_T_shape``, a
            two-element tensor read here as ``(num_embeddings, embed_dim)``,
            and be callable on an index tensor.
        rank: Inner dimension of the low-rank factors.
        scaling: Multiplier applied to the low-rank delta.  The default of
            ``0.1`` matches the previously hard-coded value.
    """

    def __init__(self, base_embed, rank=16, scaling=0.1):
        super().__init__()
        self.base = base_embed
        self.scaling = scaling

        # Freeze the wrapped table so that only the adapter is trainable.
        for p in base_embed.parameters(True):
            p.requires_grad = False

        # Coerce to int (as TernaryLoRALayer does) so a non-integer
        # _T_shape tensor does not break the factory calls below.
        num_embeddings, embed_dim = (
            int(size) for size in base_embed._T_shape.tolist()
        )

        # A is random, B is zero: the adapter is an exact no-op until B has
        # been trained, matching the initialisation used by TernaryLoRALayer.
        self.lora_A = nn.Parameter(torch.randn(num_embeddings, rank) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(rank, embed_dim))

    def forward(self, x):
        """Return the frozen lookup plus the scaled low-rank delta for ``x``."""
        # The indices carry no gradient, so running the frozen lookup under
        # no_grad loses nothing.
        with torch.no_grad():
            base_out = self.base(x)
        # Gather the needed rows of A first, then project.  This equals
        # F.embedding(x, lora_A @ lora_B) mathematically, but never builds
        # the full (num_embeddings x embed_dim) delta table.
        delta = F.embedding(x, self.lora_A) @ self.lora_B
        return base_out + delta * self.scaling
|
|
|
|
| def _should_lora(name, target_modules): |
| """Check if a module name matches any target pattern.""" |
| for pattern in target_modules: |
| if pattern in name.lower(): |
| return True |
| return False |
|
|
|
|
def apply_lora_to_model(model, rank=16, alpha=32.0, target_modules=None):
    """Wrap targeted TernaryScaleTensor modules of ``model`` in LoRA adapters.

    The module tree is walked recursively.  Each ``TernaryScaleTensor`` child
    whose dotted name matches a target pattern is replaced in place by a
    ``TernaryLoRALayer`` wrapping it.  Afterwards the whole model is moved to
    CUDA when available (otherwise CPU) and every parameter whose name does
    not contain ``'lora_'`` is frozen.

    Calling this again on an already adapted model reuses the existing
    adapters instead of wrapping their ``base`` a second time.

    Args:
        model: Root module to adapt (modified in place).
        rank: LoRA rank (8-32 typical, higher = more adapter capacity).
        alpha: LoRA scaling alpha (higher = stronger adapter influence).
        target_modules: List of name patterns.
            Default: ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
            'output_router', 'shared_up', 'shared_expert_gate',
            'shared_expert_up']

    Returns:
        Dict mapping dotted module names to their LoRA layers (for
        saving/loading adapters).
    """
    if target_modules is None:
        target_modules = ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
                          'output_router', 'shared_up', 'shared_expert_gate',
                          'shared_expert_up']

    lora_layers = {}

    def _apply(module, name=''):
        # Snapshot the children: setattr() below mutates the child mapping.
        for child_name, child in list(module.named_children()):
            full_name = f"{name}.{child_name}" if name else child_name

            if isinstance(child, TernaryLoRALayer):
                # Already adapted by an earlier call.  Descending would find
                # its ``base`` under a still-matching name and wrap it again.
                lora_layers[full_name] = child
                continue

            targeted = _should_lora(full_name, target_modules)

            if isinstance(child, TernaryScaleTensor) and targeted:
                if hasattr(child, '_T_shape'):
                    is_square = (child._T_shape[1].item()
                                 == child._T_shape[0].item())
                    # Non-square tensors under an 'embed' name are left
                    # alone; square ones are adapted regardless of name.
                    if not is_square and 'embed' in full_name.lower():
                        continue
                lora = TernaryLoRALayer(child, rank=rank, alpha=alpha)
                setattr(module, child_name, lora)
                lora_layers[full_name] = lora
            elif (targeted and hasattr(child, '_T_shape')
                  and hasattr(child, 'T_packed')):
                # Targeted module that carries packed ternary state but is
                # not a TernaryScaleTensor: neither wrap nor descend.
                continue
            else:
                _apply(child, full_name)

    _apply(model)

    # The adapter factors are created on the default device, so move the
    # whole model to one device.
    # NOTE(review): this relocates the caller's model as a side effect.
    model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # Freeze everything that is not an adapter parameter.
    for name, p in model.named_parameters():
        if 'lora_' not in name:
            p.requires_grad = False

    return lora_layers
|
|
|
|
def count_lora_params(model):
    """Return ``(trainable, total)`` element counts over ``model.parameters()``.

    ``trainable`` sums the elements of parameters with ``requires_grad``
    set; ``total`` sums the elements of every parameter.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total
|
|
|
|
def save_lora(lora_layers, path):
    """Save only the LoRA adapter factors to ``path``.

    Each layer contributes two entries, ``lora.<name>.A`` and
    ``lora.<name>.B``.  They are stored as detached CPU tensors rather than
    live ``nn.Parameter`` objects, so the file does not depend on the device
    the adapters were trained on and carries no autograd state.

    Args:
        lora_layers: Mapping of dotted module name to a layer exposing
            ``lora_A`` and ``lora_B``.
        path: Destination file; missing parent directories are created.

    Returns:
        ``path``, unchanged.
    """
    # dirname is '' for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    state = {}
    for name, layer in lora_layers.items():
        state[f"lora.{name}.A"] = layer.lora_A.detach().cpu()
        state[f"lora.{name}.B"] = layer.lora_B.detach().cpu()
    torch.save(state, path)
    return path
|
|
|
|
def load_lora(model, path):
    """Load LoRA adapter weights from ``path`` into ``model``.

    Keys are expected in the form ``lora.<dotted.module.name>.<A|B>``.  The
    final component is resolved to the layer attribute ``lora_<A|B>``; the
    bare component is still tried as a fallback.  Looking up only the bare
    ``A``/``B`` never matched the ``lora_A``/``lora_B`` attributes, so no
    weights were ever copied.

    Keys with fewer than three components are ignored.  Keys whose target
    cannot be found are skipped and reported in a single warning.

    Args:
        model: Model that already has its LoRA layers applied.
        path: File to read the adapter state from.

    Returns:
        ``model``, with matching adapter tensors overwritten in place.
    """
    import warnings

    # map_location keeps the load working on machines that lack the device
    # the adapters were saved from; copy_() below moves the data onto each
    # destination's own device.
    state = torch.load(path, map_location='cpu', weights_only=True)

    unmatched = []
    for full_name, saved in state.items():
        parts = full_name.split('.')
        if len(parts) < 3:
            continue

        # Walk from the model down to the adapter layer, skipping the
        # leading "lora" prefix and the trailing A/B suffix.
        target = model
        for attr in parts[1:-1]:
            target = getattr(target, attr, None)
            if target is None:
                break

        dest = None
        if target is not None:
            suffix = parts[-1]
            dest = getattr(target, f"lora_{suffix}", None)
            if dest is None:
                dest = getattr(target, suffix, None)

        if dest is None:
            unmatched.append(full_name)
            continue

        with torch.no_grad():
            dest.copy_(saved)

    if unmatched:
        warnings.warn(
            f"load_lora: {len(unmatched)} adapter entries had no matching "
            f"layer in the model and were skipped: {unmatched}"
        )
    return model
|
|