""" G²LU Gate Grafting: Surgically upgrade pretrained SwiGLU models to G²LU. Takes any HuggingFace model with SwiGLU (gate_proj + up_proj), freezes everything except gate weights, adds W4 for nested gating, and trains with alignment + LM loss. This is grafting applied to the gate mechanism — the same methodology validated for full layer replacement, now targeting the minimum surgical unit. Usage: python -m circuits.train --arch graft_g2lu --pretrained meta-llama/Llama-3.2-1B \ --align-weight 1.0 --graft-warmup 500 --data hf:Bingsu/openwebtext_20p ... """ import torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path class G2LU_MLP(nn.Module): """Per-layer MLP wrapper that upgrades SwiGLU to G²LU. Holds references to the original gate_proj (W3, frozen), up_proj (W1, frozen), down_proj (W2, frozen), plus a new w4 (zero-initialized, trainable). Gate ordering: silu(W4@x * silu(W3@x)) — the pretrained gate (W3) acts as structural prior, constraining W4 to operate within the feature subspace the pretrained model already deems relevant. W4's gradients are scaled by silu(W3@x), inheriting the pretrained model's feature selection hierarchy. """ def __init__(self, original_mlp: nn.Module): super().__init__() # References to original weights (all frozen) self.gate_proj = original_mlp.gate_proj # W3 — frozen self.up_proj = original_mlp.up_proj # W1 — frozen self.down_proj = original_mlp.down_proj # W2 — frozen # New W4: same shape as gate_proj, zero-initialized, matched dtype self.w4 = nn.Linear( self.gate_proj.in_features, self.gate_proj.out_features, bias=self.gate_proj.bias is not None, dtype=self.gate_proj.weight.dtype, device=self.gate_proj.weight.device, ) nn.init.zeros_(self.w4.weight) if self.w4.bias is not None: nn.init.zeros_(self.w4.bias) # Blend alpha: 0 = pure SwiGLU, 1 = full G²LU self._alpha = 0.0 # Per-layer alignment loss (collected by parent) self._align_loss = None def forward(self, x): # Pretrained gate (frozen W3) — structural prior w3_gate = F.silu(self.gate_proj(x)) # G²LU gate: silu(W4@x * silu(W3@x)) # W4 modulated BY pretrained knowledge, not the reverse g2lu_gate = F.silu(self.w4(x) * w3_gate) # Blend warmup: smooth transition from SwiGLU → G²LU if self._alpha < 1.0: gate = (1.0 - self._alpha) * w3_gate + self._alpha * g2lu_gate else: gate = g2lu_gate # Per-layer alignment loss (compare against original SwiGLU gate) self._align_loss = F.mse_loss(gate, w3_gate.detach()) return self.down_proj(gate * self.up_proj(x)) class G2LU_GraftedModel(nn.Module): """Full model wrapper that upgrades a pretrained HF model's MLPs to G²LU. Interface matches CircuitTransformer: forward(input_ids, labels=labels) returns {"loss", "logits", "align_loss"}. """ def __init__( self, pretrained_name: str, align_weight: float = 1.0, warmup_steps: int = 500, device: str = "cuda", dtype=torch.bfloat16, ): super().__init__() self.pretrained_name = pretrained_name self.align_weight = align_weight self.warmup_steps = warmup_steps self._current_step = 0 # Load pretrained HF model from transformers import AutoModelForCausalLM self.model = AutoModelForCausalLM.from_pretrained( pretrained_name, dtype=dtype, trust_remote_code=True, ) # Discover and replace MLPs self.g2lu_mlps = [] self._replace_mlps() # Freeze everything, then selectively unfreeze W4 only for param in self.model.parameters(): param.requires_grad = False for g2lu in self.g2lu_mlps: for param in g2lu.w4.parameters(): param.requires_grad = True self.model.to(device) # Print summary total_params = sum(p.numel() for p in self.model.parameters()) trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad) print(f"G²LU Graft: {pretrained_name}") print(f" Layers upgraded: {len(self.g2lu_mlps)}") print(f" Total params: {total_params:,} ({total_params/1e6:.1f}M)") print(f" Trainable params: {trainable:,} ({trainable/1e6:.1f}M, {100*trainable/total_params:.1f}%)") print(f" Align weight: {align_weight}, Warmup: {warmup_steps} steps") def _replace_mlps(self): """Walk the model tree and replace SwiGLU MLPs with G²LU wrappers.""" # Try common decoder layer paths layers = None for attr_path in ["model.layers", "gpt_neox.layers", "transformer.h"]: obj = self.model try: for attr in attr_path.split("."): obj = getattr(obj, attr) layers = obj break except AttributeError: continue if layers is None: raise ValueError( f"Could not find decoder layers in {type(self.model).__name__}. " f"Tried: model.layers, gpt_neox.layers, transformer.h" ) for i, layer in enumerate(layers): # Try common MLP attribute names mlp = None mlp_attr = None for attr in ["mlp", "feed_forward"]: if hasattr(layer, attr): mlp = getattr(layer, attr) mlp_attr = attr break if mlp is None: continue # Check for SwiGLU signature (gate_proj + up_proj) if hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"): g2lu = G2LU_MLP(mlp) setattr(layer, mlp_attr, g2lu) self.g2lu_mlps.append(g2lu) if not self.g2lu_mlps: raise ValueError( "No SwiGLU MLPs found (need gate_proj + up_proj attributes). " "This model may not use gated linear units." ) def set_step(self, step: int): """Update blend alpha across all G²LU MLPs.""" self._current_step = step alpha = min(step / max(self.warmup_steps, 1), 1.0) for g2lu in self.g2lu_mlps: g2lu._alpha = alpha def trainable_parameters(self): """Yield only unfrozen parameters (for optimizer and grad clipping).""" for param in self.model.parameters(): if param.requires_grad: yield param def collect_align_loss(self): """Average per-layer alignment losses.""" losses = [g2lu._align_loss for g2lu in self.g2lu_mlps if g2lu._align_loss is not None] if not losses: return torch.tensor(0.0) return torch.stack(losses).mean() def forward(self, input_ids, labels=None, **kwargs): outputs = self.model(input_ids=input_ids, labels=labels, **kwargs) result = {"logits": outputs.logits} align_loss = self.collect_align_loss() result["align_loss"] = align_loss if labels is not None: # Combine LM loss + alignment loss result["loss"] = outputs.loss + self.align_weight * align_loss result["lm_loss"] = outputs.loss else: result["loss"] = align_loss return result def generate(self, input_ids, **kwargs): """Delegate to HF model's .generate().""" return self.model.generate(input_ids=input_ids, **kwargs) def save_g2lu_checkpoint( model: G2LU_GraftedModel, optimizer: torch.optim.Optimizer, step: int, epoch: int, loss: float, path: str, epoch_step: int = 0, best_val_loss: float | None = None, scaler=None, tokenizer_name: str = None, ): """Delta save: only trainable params + metadata.""" # Extract only requires_grad params raw = model.model if not hasattr(model, '_orig_mod') else model._orig_mod.model # Handle torch.compile wrapper if hasattr(model, '_orig_mod'): g2lu_model = model._orig_mod else: g2lu_model = model delta_sd = {} full_sd = g2lu_model.model.state_dict() for name, param in g2lu_model.model.named_parameters(): if param.requires_grad: # Strip _orig_mod. prefix if present clean_name = name.removeprefix("_orig_mod.") delta_sd[clean_name] = full_sd.get(name, param.data).clone() # Also save the w4 weights explicitly (they're part of the replaced modules) for name, val in full_sd.items(): clean_name = name.removeprefix("_orig_mod.") if ".w4." in clean_name and clean_name not in delta_sd: delta_sd[clean_name] = val.clone() checkpoint = { "model": delta_sd, "optimizer": optimizer.state_dict(), "step": step, "epoch": epoch, "epoch_step": epoch_step, "loss": loss, "model_type": "graft_g2lu", "pretrained_name": g2lu_model.pretrained_name, "align_weight": g2lu_model.align_weight, "warmup_steps": g2lu_model.warmup_steps, "tokenizer_name": tokenizer_name or g2lu_model.pretrained_name, } if best_val_loss is not None: checkpoint["best_val_loss"] = best_val_loss if scaler is not None: checkpoint["scaler"] = scaler.state_dict() torch.save(checkpoint, path) def load_g2lu_model(checkpoint_path: str, device: str = "cuda", dtype=torch.bfloat16): """Delta load: recreate model from pretrained + apply delta weights.""" checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) pretrained_name = checkpoint["pretrained_name"] align_weight = checkpoint.get("align_weight", 1.0) warmup_steps = checkpoint.get("warmup_steps", 500) model = G2LU_GraftedModel( pretrained_name=pretrained_name, align_weight=align_weight, warmup_steps=warmup_steps, device=device, dtype=dtype, ) # Load delta weights delta_sd = checkpoint["model"] # Strip _orig_mod. prefix if present delta_sd = {k.removeprefix("_orig_mod."): v for k, v in delta_sd.items()} # Apply delta weights to the model missing, unexpected = model.model.load_state_dict(delta_sd, strict=False) if unexpected: print(f" Warning: unexpected keys in delta checkpoint: {unexpected[:5]}...") # Set alpha to 1.0 for inference (full G²LU) model.set_step(warmup_steps + 1) return model