| """
|
| G²LU Gate Grafting: Surgically upgrade pretrained SwiGLU models to G²LU.
|
|
|
| Takes any HuggingFace model with SwiGLU (gate_proj + up_proj), freezes everything
|
| except gate weights, adds W4 for nested gating, and trains with alignment + LM loss.
|
|
|
| This is grafting applied to the gate mechanism — the same methodology validated for
|
| full layer replacement, now targeting the minimum surgical unit.
|
|
|
| Usage:
|
| python -m circuits.train --arch graft_g2lu --pretrained meta-llama/Llama-3.2-1B \
|
| --align-weight 1.0 --graft-warmup 500 --data hf:Bingsu/openwebtext_20p ...
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from pathlib import Path
|
|
|
|
|
class G2LU_MLP(nn.Module):
    """Drop-in replacement for a SwiGLU MLP that nests a second gate (G²LU).

    Keeps references to the donor MLP's gate_proj (W3) / up_proj (W1) /
    down_proj (W2) — the caller is expected to freeze them — and introduces
    one new projection, ``w4``, zero-initialized so the wrapper starts out
    numerically inert.

    The upgraded gate is ``silu(W4@x * silu(W3@x))``: the pretrained gate
    output modulates W4's pre-activation, so W4's gradients are scaled by
    ``silu(W3@x)`` and it learns inside the feature subspace the pretrained
    gate already selects.
    """

    def __init__(self, original_mlp: nn.Module):
        super().__init__()

        # Borrow the pretrained projections by reference — no copies.
        self.gate_proj = original_mlp.gate_proj
        self.up_proj = original_mlp.up_proj
        self.down_proj = original_mlp.down_proj

        proto = self.gate_proj
        # New nested-gate projection, shaped/typed/placed to match gate_proj.
        self.w4 = nn.Linear(
            proto.in_features,
            proto.out_features,
            bias=proto.bias is not None,
            dtype=proto.weight.dtype,
            device=proto.weight.device,
        )
        # Zero init → w4(x) == 0 at step 0 → the nested gate starts fully inert.
        nn.init.zeros_(self.w4.weight)
        if self.w4.bias is not None:
            nn.init.zeros_(self.w4.bias)

        # Blend factor: 0.0 = pure pretrained SwiGLU gate, 1.0 = pure G²LU gate.
        self._alpha = 0.0
        # Alignment loss from the most recent forward (MSE vs. the
        # pretrained gate), or None before the first forward.
        self._align_loss = None

    def forward(self, x):
        # Pretrained SwiGLU gate — serves as both blend anchor and align target.
        w3_gate = F.silu(self.gate_proj(x))

        # Nested gate: w4's contribution is modulated by the pretrained gate.
        g2lu_gate = F.silu(self.w4(x) * w3_gate)

        if self._alpha >= 1.0:
            gate = g2lu_gate
        else:
            gate = (1.0 - self._alpha) * w3_gate + self._alpha * g2lu_gate

        # Pull the blended gate toward the (detached) pretrained gate;
        # gradients flow only into w4 when everything else is frozen.
        self._align_loss = F.mse_loss(gate, w3_gate.detach())

        return self.down_proj(gate * self.up_proj(x))
|
|
|
|
|
class G2LU_GraftedModel(nn.Module):
    """Full model wrapper that upgrades a pretrained HF model's MLPs to G²LU.

    Interface matches CircuitTransformer: forward(input_ids, labels=labels) returns
    {"loss", "logits", "align_loss"}.
    """

    def __init__(
        self,
        pretrained_name: str,
        align_weight: float = 1.0,
        warmup_steps: int = 500,
        device: str = "cuda",
        dtype=torch.bfloat16,
    ):
        """Load a pretrained HF causal LM, graft G²LU gates, and freeze the base.

        Args:
            pretrained_name: HF hub id or local path for from_pretrained.
            align_weight: coefficient on the alignment loss added to the LM
                loss in forward().
            warmup_steps: steps over which the gate blend alpha ramps 0 -> 1
                (see set_step).
            device: target device the whole model is moved to.
            dtype: weight dtype requested from from_pretrained.
        """
        super().__init__()
        self.pretrained_name = pretrained_name
        self.align_weight = align_weight
        self.warmup_steps = warmup_steps
        self._current_step = 0

        # Lazy import: keeps the module importable without transformers installed.
        from transformers import AutoModelForCausalLM
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_name,
            dtype=dtype,
            trust_remote_code=True,
        )

        # Filled by _replace_mlps with each G2LU_MLP wrapper, in layer order.
        self.g2lu_mlps = []
        self._replace_mlps()

        # Freeze everything, then re-enable grads only on the new w4 projections.
        for param in self.model.parameters():
            param.requires_grad = False
        for g2lu in self.g2lu_mlps:
            for param in g2lu.w4.parameters():
                param.requires_grad = True

        self.model.to(device)

        # Parameter accounting for the startup banner below.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"G²LU Graft: {pretrained_name}")
        print(f" Layers upgraded: {len(self.g2lu_mlps)}")
        print(f" Total params: {total_params:,} ({total_params/1e6:.1f}M)")
        print(f" Trainable params: {trainable:,} ({trainable/1e6:.1f}M, {100*trainable/total_params:.1f}%)")
        print(f" Align weight: {align_weight}, Warmup: {warmup_steps} steps")

    def _replace_mlps(self):
        """Walk the model tree and replace SwiGLU MLPs with G²LU wrappers."""
        # Locate the decoder layer list across common HF layouts
        # (Llama-style, GPT-NeoX-style, GPT-2-style).
        layers = None
        for attr_path in ["model.layers", "gpt_neox.layers", "transformer.h"]:
            obj = self.model
            try:
                for attr in attr_path.split("."):
                    obj = getattr(obj, attr)
                layers = obj
                break
            except AttributeError:
                continue

        if layers is None:
            raise ValueError(
                f"Could not find decoder layers in {type(self.model).__name__}. "
                f"Tried: model.layers, gpt_neox.layers, transformer.h"
            )

        for i, layer in enumerate(layers):
            # The per-layer MLP attribute name also varies by architecture.
            mlp = None
            mlp_attr = None
            for attr in ["mlp", "feed_forward"]:
                if hasattr(layer, attr):
                    mlp = getattr(layer, attr)
                    mlp_attr = attr
                    break

            if mlp is None:
                continue

            # Only graft true SwiGLU blocks (gate_proj + up_proj both present).
            if hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"):
                g2lu = G2LU_MLP(mlp)
                setattr(layer, mlp_attr, g2lu)
                self.g2lu_mlps.append(g2lu)

        if not self.g2lu_mlps:
            raise ValueError(
                "No SwiGLU MLPs found (need gate_proj + up_proj attributes). "
                "This model may not use gated linear units."
            )

    def set_step(self, step: int):
        """Update blend alpha across all G²LU MLPs."""
        self._current_step = step
        # Linear ramp 0 -> 1 over warmup_steps, clamped at 1; max(.., 1)
        # guards against division by zero when warmup_steps == 0.
        alpha = min(step / max(self.warmup_steps, 1), 1.0)
        for g2lu in self.g2lu_mlps:
            g2lu._alpha = alpha

    def trainable_parameters(self):
        """Yield only unfrozen parameters (for optimizer and grad clipping)."""
        for param in self.model.parameters():
            if param.requires_grad:
                yield param

    def collect_align_loss(self):
        """Average per-layer alignment losses."""
        # Each G2LU_MLP stashes its loss during forward; before any forward
        # has run they are all None.
        losses = [g2lu._align_loss for g2lu in self.g2lu_mlps if g2lu._align_loss is not None]
        if not losses:
            # NOTE(review): this fallback tensor lives on CPU — looks like it
            # could mismatch a CUDA LM loss if ever hit; confirm unreachable
            # after the first forward.
            return torch.tensor(0.0)
        return torch.stack(losses).mean()

    def forward(self, input_ids, labels=None, **kwargs):
        # HF forward runs every grafted MLP, which stashes per-layer align losses.
        outputs = self.model(input_ids=input_ids, labels=labels, **kwargs)

        result = {"logits": outputs.logits}

        align_loss = self.collect_align_loss()
        result["align_loss"] = align_loss

        if labels is not None:
            # Total training loss = LM cross-entropy + weighted gate alignment.
            result["loss"] = outputs.loss + self.align_weight * align_loss
            result["lm_loss"] = outputs.loss
        else:
            # No labels -> no LM loss; expose the alignment loss alone.
            result["loss"] = align_loss

        return result

    def generate(self, input_ids, **kwargs):
        """Delegate to HF model's .generate()."""
        return self.model.generate(input_ids=input_ids, **kwargs)
|
|
|
|
|
| def save_g2lu_checkpoint(
|
| model: G2LU_GraftedModel,
|
| optimizer: torch.optim.Optimizer,
|
| step: int,
|
| epoch: int,
|
| loss: float,
|
| path: str,
|
| epoch_step: int = 0,
|
| best_val_loss: float | None = None,
|
| scaler=None,
|
| tokenizer_name: str = None,
|
| ):
|
| """Delta save: only trainable params + metadata."""
|
|
|
| raw = model.model if not hasattr(model, '_orig_mod') else model._orig_mod.model
|
|
|
| if hasattr(model, '_orig_mod'):
|
| g2lu_model = model._orig_mod
|
| else:
|
| g2lu_model = model
|
|
|
| delta_sd = {}
|
| full_sd = g2lu_model.model.state_dict()
|
| for name, param in g2lu_model.model.named_parameters():
|
| if param.requires_grad:
|
|
|
| clean_name = name.removeprefix("_orig_mod.")
|
| delta_sd[clean_name] = full_sd.get(name, param.data).clone()
|
|
|
|
|
| for name, val in full_sd.items():
|
| clean_name = name.removeprefix("_orig_mod.")
|
| if ".w4." in clean_name and clean_name not in delta_sd:
|
| delta_sd[clean_name] = val.clone()
|
|
|
| checkpoint = {
|
| "model": delta_sd,
|
| "optimizer": optimizer.state_dict(),
|
| "step": step,
|
| "epoch": epoch,
|
| "epoch_step": epoch_step,
|
| "loss": loss,
|
| "model_type": "graft_g2lu",
|
| "pretrained_name": g2lu_model.pretrained_name,
|
| "align_weight": g2lu_model.align_weight,
|
| "warmup_steps": g2lu_model.warmup_steps,
|
| "tokenizer_name": tokenizer_name or g2lu_model.pretrained_name,
|
| }
|
| if best_val_loss is not None:
|
| checkpoint["best_val_loss"] = best_val_loss
|
| if scaler is not None:
|
| checkpoint["scaler"] = scaler.state_dict()
|
| torch.save(checkpoint, path)
|
|
|
|
|
def load_g2lu_model(checkpoint_path: str, device: str = "cuda", dtype=torch.bfloat16):
    """Delta load: rebuild the grafted model from its pretrained base, then
    overlay the trainable (w4) tensors stored by save_g2lu_checkpoint()."""
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    warmup = ckpt.get("warmup_steps", 500)
    model = G2LU_GraftedModel(
        pretrained_name=ckpt["pretrained_name"],
        align_weight=ckpt.get("align_weight", 1.0),
        warmup_steps=warmup,
        device=device,
        dtype=dtype,
    )

    # Strip any torch.compile name prefix left over from training.
    delta = {k.removeprefix("_orig_mod."): v for k, v in ckpt["model"].items()}

    # strict=False: the delta carries only the trainable tensors, so every
    # frozen pretrained weight is expected to be "missing". Unexpected keys,
    # however, indicate a checkpoint/model mismatch worth flagging.
    _, unexpected = model.model.load_state_dict(delta, strict=False)
    if unexpected:
        print(f" Warning: unexpected keys in delta checkpoint: {unexpected[:5]}...")

    # Step past warmup so alpha is pinned at 1.0 — pure G²LU gating for inference.
    model.set_step(warmup + 1)

    return model
|
|
|