"""Ternary LoRA adapters for memory-efficient fine-tuning. Freezes base ternary weights, adds small float low-rank adapters. Only adapters receive gradients — base state stays at 1.71 GB, adapters < 100 MB. Usage: from finetuning.lora import TernaryLoRA, apply_lora_to_model model = ARBModel(...) apply_lora_to_model(model, rank=16, target_modules=['moe', 'byte_head']) # Only LoRA params are trainable — base is frozen """ import torch import torch.nn as nn import torch.nn.functional as F import os from arbitor.kernel.ternary_scale import TernaryScaleTensor class TernaryLoRALayer(nn.Module): """LoRA adapter wrapping a single TernaryScaleTensor. Base ternary weights are frozen. Two small float matrices A and B are trained: output = base(x) + (x @ A) @ B * scaling """ def __init__(self, base_layer, rank=8, alpha=16.0): super().__init__() self.base = base_layer self.scaling = alpha / rank # Freeze all base parameters and buffers for p in base_layer.parameters(True): p.requires_grad = False # _T_shape is [out_dim, in_dim] out_dim, in_dim = base_layer._T_shape.tolist() out_dim, in_dim = int(out_dim), int(in_dim) self.lora_A = nn.Parameter(torch.randn(in_dim, rank) * 0.02) self.lora_B = nn.Parameter(torch.zeros(rank, out_dim)) # For modules that need it, track if input requires grad self.forward_count = 0 def forward(self, x): with torch.no_grad(): base_out = self.base(x) lora_out = (x @ self.lora_A) @ self.lora_B * self.scaling return base_out + lora_out def extra_repr(self): return (f"base={tuple(self.base._T_shape.tolist())}, " f"rank={self.lora_A.shape[1]}, alpha={self.scaling * self.lora_A.shape[1]:.0f}") class LoRAEmbedding(nn.Module): """LoRA adapter for ByteEmbedding / TernaryEmbeddingTable. Wraps the embedding lookup, adds a small learned delta. """ def __init__(self, base_embed, rank=16): super().__init__() self.base = base_embed for p in base_embed.parameters(True): p.requires_grad = False num_embeddings, embed_dim = base_embed._T_shape.tolist() self.lora_A = nn.Parameter(torch.randn(num_embeddings, rank) * 0.02) self.lora_B = nn.Parameter(torch.randn(rank, embed_dim) * 0.02) def forward(self, x): with torch.no_grad(): base_out = self.base(x) delta = F.embedding(x, self.lora_A @ self.lora_B) return base_out + delta * 0.1 def _should_lora(name, target_modules): """Check if a module name matches any target pattern.""" for pattern in target_modules: if pattern in name.lower(): return True return False def apply_lora_to_model(model, rank=16, alpha=32.0, target_modules=None): """Apply LoRA adapters to targeted TernaryScaleTensor modules. Args: model: ARBModel instance rank: LoRA rank (8-32 typical, higher = more adapter capacity) alpha: LoRA scaling alpha (higher = stronger adapter influence) target_modules: List of name patterns. Default: ['moe', 'byte_head', 'embedding', 'graph', 'head'] Returns: Dict mapping module names to LoRA layers (for saving/loading adapters) """ if target_modules is None: target_modules = ['W_gate', 'W_transform', 'router', 'byte_head', 'head', 'output_router', 'shared_up', 'shared_expert_gate', 'shared_expert_up'] lora_layers = {} def _apply(module, name=''): for child_name, child in list(module.named_children()): full_name = f"{name}.{child_name}" if name else child_name if isinstance(child, TernaryScaleTensor) and _should_lora(full_name, target_modules): # Skip embedding layers (they use lookup, not matmul) if hasattr(child, '_T_shape') and child._T_shape[1].item() == child._T_shape[0].item(): # Square projection — likely a linear layer, safe for LoRA pass elif hasattr(child, '_T_shape') and 'embed' in full_name.lower(): continue # Skip embeddings lora = TernaryLoRALayer(child, rank=rank, alpha=alpha) setattr(module, child_name, lora) lora_layers[full_name] = lora elif hasattr(child, '_T_shape') and hasattr(child, 'T_packed') and _should_lora(full_name, target_modules): continue # Non-TernaryScaleTensor with T_packed is an embedding else: _apply(child, full_name) _apply(model) # Move LoRA adapters to CUDA (they're created on CPU by default) model.to('cuda' if torch.cuda.is_available() else 'cpu') # Freeze all non-LoRA params, only LoRA A/B are trainable for name, p in model.named_parameters(): if 'lora_' not in name: p.requires_grad = False return lora_layers def count_lora_params(model): """Count trainable (LoRA) vs frozen (base) parameters.""" lora = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) return lora, total def save_lora(lora_layers, path): """Save only LoRA adapter weights (small ~5-50 MB).""" os.makedirs(os.path.dirname(path) or ".", exist_ok=True) state = {f"lora.{k}.A": v.lora_A for k, v in lora_layers.items()} state.update({f"lora.{k}.B": v.lora_B for k, v in lora_layers.items()}) torch.save(state, path) return path def load_lora(model, path): """Load LoRA adapter weights into model.""" state = torch.load(path, weights_only=True) for full_name, param in state.items(): parts = full_name.split('.') if len(parts) < 3: continue # Find the module obj = model for p in parts[1:-1]: obj = getattr(obj, p, None) if obj is None: break if obj is not None and hasattr(obj, parts[-1]): getattr(obj, parts[-1]).data.copy_(param.data) return model