# Prisma / graft_g2lu.py — y3i12, initial commit (56e82ec)
"""
G²LU Gate Grafting: Surgically upgrade pretrained SwiGLU models to G²LU.
Takes any HuggingFace model with SwiGLU (gate_proj + up_proj), freezes everything
except gate weights, adds W4 for nested gating, and trains with alignment + LM loss.
This is grafting applied to the gate mechanism — the same methodology validated for
full layer replacement, now targeting the minimum surgical unit.
Usage:
python -m circuits.train --arch graft_g2lu --pretrained meta-llama/Llama-3.2-1B \
--align-weight 1.0 --graft-warmup 500 --data hf:Bingsu/openwebtext_20p ...
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
class G2LU_MLP(nn.Module):
    """SwiGLU → G²LU upgrade for a single MLP block.

    Keeps the pretrained projections — gate_proj (W3), up_proj (W1),
    down_proj (W2) — untouched (the parent wrapper freezes them) and adds one
    new zero-initialized linear layer ``w4`` that nests inside the pretrained
    gate:

        gate = silu(W4(x) * silu(W3(x)))

    Because W4's output is multiplied by the frozen silu(W3(x)), its gradients
    are scaled by the pretrained gate activations, so the new gate learns
    within the feature subspace the base model already selects. ``_alpha``
    blends old and new gates (0 = original SwiGLU, 1 = full G²LU) for a smooth
    warmup; ``_align_loss`` holds the latest per-forward MSE between the
    blended gate and the original gate, collected by the parent wrapper.
    """

    def __init__(self, original_mlp: nn.Module):
        super().__init__()
        # Hold references to the pretrained projections; the parent freezes them.
        self.gate_proj = original_mlp.gate_proj  # W3 — pretrained gate
        self.up_proj = original_mlp.up_proj      # W1 — pretrained up projection
        self.down_proj = original_mlp.down_proj  # W2 — pretrained down projection
        # New trainable gate W4: mirrors gate_proj's shape, bias, dtype, device.
        self.w4 = nn.Linear(
            self.gate_proj.in_features,
            self.gate_proj.out_features,
            bias=self.gate_proj.bias is not None,
            dtype=self.gate_proj.weight.dtype,
            device=self.gate_proj.weight.device,
        )
        # Zero init: at alpha=0 the module reproduces the original SwiGLU exactly.
        for p in self.w4.parameters():
            nn.init.zeros_(p)
        self._alpha = 0.0        # blend: 0 = pure SwiGLU, 1 = full G²LU
        self._align_loss = None  # most recent per-layer alignment loss

    def forward(self, x):
        prior = F.silu(self.gate_proj(x))    # pretrained gate — structural prior
        nested = F.silu(self.w4(x) * prior)  # G²LU gate, modulated BY the prior
        # Warmup blend: smooth transition SwiGLU → G²LU as alpha ramps to 1.
        if self._alpha >= 1.0:
            gate = nested
        else:
            gate = (1.0 - self._alpha) * prior + self._alpha * nested
        # Alignment target is the original gate, detached so it acts as a constant.
        self._align_loss = F.mse_loss(gate, prior.detach())
        return self.down_proj(gate * self.up_proj(x))
class G2LU_GraftedModel(nn.Module):
    """Full model wrapper that upgrades a pretrained HF model's MLPs to G²LU.
    Interface matches CircuitTransformer: forward(input_ids, labels=labels) returns
    {"loss", "logits", "align_loss"}.
    """
    def __init__(
        self,
        pretrained_name: str,
        align_weight: float = 1.0,
        warmup_steps: int = 500,
        device: str = "cuda",
        dtype=torch.bfloat16,
    ):
        """Load the pretrained model, replace its SwiGLU MLPs with G²LU
        wrappers, and freeze everything except the new W4 projections.

        Args:
            pretrained_name: HF hub id passed to AutoModelForCausalLM.
            align_weight: multiplier on the alignment loss added in forward().
            warmup_steps: steps over which the SwiGLU→G²LU blend ramps to 1.
            device: device the whole model is moved to after the surgery.
            dtype: load dtype for from_pretrained.
                NOTE(review): the `dtype=` kwarg requires a recent transformers
                release; older versions spell it `torch_dtype=` — confirm the
                pinned version.
        """
        super().__init__()
        self.pretrained_name = pretrained_name
        self.align_weight = align_weight
        self.warmup_steps = warmup_steps
        self._current_step = 0
        # Load pretrained HF model (local import keeps transformers optional
        # until a graft is actually constructed).
        from transformers import AutoModelForCausalLM
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_name,
            dtype=dtype,
            trust_remote_code=True,
        )
        # Discover and replace MLPs (must happen BEFORE the freeze below so the
        # new w4 layers exist when we flip requires_grad flags).
        self.g2lu_mlps = []
        self._replace_mlps()
        # Freeze everything, then selectively unfreeze W4 only.
        for param in self.model.parameters():
            param.requires_grad = False
        for g2lu in self.g2lu_mlps:
            for param in g2lu.w4.parameters():
                param.requires_grad = True
        self.model.to(device)
        # Print summary
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"G²LU Graft: {pretrained_name}")
        print(f"  Layers upgraded: {len(self.g2lu_mlps)}")
        print(f"  Total params: {total_params:,} ({total_params/1e6:.1f}M)")
        print(f"  Trainable params: {trainable:,} ({trainable/1e6:.1f}M, {100*trainable/total_params:.1f}%)")
        print(f"  Align weight: {align_weight}, Warmup: {warmup_steps} steps")

    def _replace_mlps(self):
        """Walk the model tree and replace SwiGLU MLPs with G²LU wrappers."""
        # Try common decoder layer paths (Llama-style, GPT-NeoX-style, GPT-2-style).
        layers = None
        for attr_path in ["model.layers", "gpt_neox.layers", "transformer.h"]:
            obj = self.model
            try:
                for attr in attr_path.split("."):
                    obj = getattr(obj, attr)
                layers = obj
                break
            except AttributeError:
                continue
        if layers is None:
            raise ValueError(
                f"Could not find decoder layers in {type(self.model).__name__}. "
                f"Tried: model.layers, gpt_neox.layers, transformer.h"
            )
        for i, layer in enumerate(layers):
            # Try common MLP attribute names.
            mlp = None
            mlp_attr = None
            for attr in ["mlp", "feed_forward"]:
                if hasattr(layer, attr):
                    mlp = getattr(layer, attr)
                    mlp_attr = attr
                    break
            if mlp is None:
                continue
            # Check for SwiGLU signature (gate_proj + up_proj); layers without
            # it (e.g. non-gated MLPs) are left untouched.
            if hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"):
                g2lu = G2LU_MLP(mlp)
                setattr(layer, mlp_attr, g2lu)
                self.g2lu_mlps.append(g2lu)
        if not self.g2lu_mlps:
            raise ValueError(
                "No SwiGLU MLPs found (need gate_proj + up_proj attributes). "
                "This model may not use gated linear units."
            )

    def set_step(self, step: int):
        """Update blend alpha across all G²LU MLPs."""
        self._current_step = step
        # Linear ramp: alpha goes 0 → 1 over warmup_steps, then saturates at 1.
        alpha = min(step / max(self.warmup_steps, 1), 1.0)
        for g2lu in self.g2lu_mlps:
            g2lu._alpha = alpha

    def trainable_parameters(self):
        """Yield only unfrozen parameters (for optimizer and grad clipping)."""
        for param in self.model.parameters():
            if param.requires_grad:
                yield param

    def collect_align_loss(self):
        """Average per-layer alignment losses recorded by the last forward()."""
        losses = [g2lu._align_loss for g2lu in self.g2lu_mlps if g2lu._align_loss is not None]
        if not losses:
            # NOTE(review): this fallback lives on the CPU default device; if a
            # caller adds it to a CUDA loss, confirm the device promotion is ok.
            return torch.tensor(0.0)
        return torch.stack(losses).mean()

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the wrapped HF model and attach the grafting losses.

        Returns a dict with "logits", "align_loss", and "loss"; when labels are
        given, "loss" is lm_loss + align_weight * align_loss and "lm_loss" is
        also included. Without labels, "loss" is the alignment loss alone.
        """
        outputs = self.model(input_ids=input_ids, labels=labels, **kwargs)
        result = {"logits": outputs.logits}
        align_loss = self.collect_align_loss()
        result["align_loss"] = align_loss
        if labels is not None:
            # Combine LM loss + alignment loss.
            result["loss"] = outputs.loss + self.align_weight * align_loss
            result["lm_loss"] = outputs.loss
        else:
            result["loss"] = align_loss
        return result

    def generate(self, input_ids, **kwargs):
        """Delegate to HF model's .generate()."""
        return self.model.generate(input_ids=input_ids, **kwargs)
def save_g2lu_checkpoint(
model: G2LU_GraftedModel,
optimizer: torch.optim.Optimizer,
step: int,
epoch: int,
loss: float,
path: str,
epoch_step: int = 0,
best_val_loss: float | None = None,
scaler=None,
tokenizer_name: str = None,
):
"""Delta save: only trainable params + metadata."""
# Extract only requires_grad params
raw = model.model if not hasattr(model, '_orig_mod') else model._orig_mod.model
# Handle torch.compile wrapper
if hasattr(model, '_orig_mod'):
g2lu_model = model._orig_mod
else:
g2lu_model = model
delta_sd = {}
full_sd = g2lu_model.model.state_dict()
for name, param in g2lu_model.model.named_parameters():
if param.requires_grad:
# Strip _orig_mod. prefix if present
clean_name = name.removeprefix("_orig_mod.")
delta_sd[clean_name] = full_sd.get(name, param.data).clone()
# Also save the w4 weights explicitly (they're part of the replaced modules)
for name, val in full_sd.items():
clean_name = name.removeprefix("_orig_mod.")
if ".w4." in clean_name and clean_name not in delta_sd:
delta_sd[clean_name] = val.clone()
checkpoint = {
"model": delta_sd,
"optimizer": optimizer.state_dict(),
"step": step,
"epoch": epoch,
"epoch_step": epoch_step,
"loss": loss,
"model_type": "graft_g2lu",
"pretrained_name": g2lu_model.pretrained_name,
"align_weight": g2lu_model.align_weight,
"warmup_steps": g2lu_model.warmup_steps,
"tokenizer_name": tokenizer_name or g2lu_model.pretrained_name,
}
if best_val_loss is not None:
checkpoint["best_val_loss"] = best_val_loss
if scaler is not None:
checkpoint["scaler"] = scaler.state_dict()
torch.save(checkpoint, path)
def load_g2lu_model(checkpoint_path: str, device: str = "cuda", dtype=torch.bfloat16):
    """Delta load: rebuild the grafted model from its base + saved W4 deltas."""
    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    warmup = ckpt.get("warmup_steps", 500)
    # Recreate the grafted model from the recorded base checkpoint + settings.
    model = G2LU_GraftedModel(
        pretrained_name=ckpt["pretrained_name"],
        align_weight=ckpt.get("align_weight", 1.0),
        warmup_steps=warmup,
        device=device,
        dtype=dtype,
    )
    # Normalize key names (drop any torch.compile prefix) and apply the deltas.
    deltas = {k.removeprefix("_orig_mod."): v for k, v in ckpt["model"].items()}
    _, unexpected = model.model.load_state_dict(deltas, strict=False)
    if unexpected:
        print(f"  Warning: unexpected keys in delta checkpoint: {unexpected[:5]}...")
    # Past the warmup horizon alpha is 1.0 — full G²LU for inference.
    model.set_step(warmup + 1)
    return model