GPT Bilinear-MLP Family Relation β€” Reversal Curse Experiments

GPT causal language models with a bilinear MLP, trained on the family_relation dataset to study the reversal curse: the phenomenon where a model trained on "A is the parent of B" fails to infer "B is the child of A".

Bilinear MLP

Replaces the standard ReLU MLP with a bilinear (no-gate) variant:

MLP(x) = p( w_l(x) βŠ™ w_r(x) )

where w_l, w_r are two linear projections from d_model to 4*d_model, βŠ™ is element-wise product, and p projects back to d_model. With gate=none there is no nonlinearity β€” the entire MLP is a low-rank bilinear form, making the network mathematically tractable for interpretability analyses.
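For concreteness, here is a minimal PyTorch sketch of such a gate-free bilinear MLP. The class and attribute names are illustrative only and need not match the repository's model.py:

```python
import torch
import torch.nn as nn

class BilinearMLP(nn.Module):
    """Gate-free bilinear MLP: p(w_l(x) * w_r(x)). Illustrative sketch;
    the layer names in the actual model.py may differ."""
    def __init__(self, d_model: int):
        super().__init__()
        self.w_l = nn.Linear(d_model, 4 * d_model, bias=False)
        self.w_r = nn.Linear(d_model, 4 * d_model, bias=False)
        self.p = nn.Linear(4 * d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise product of two linear projections; no nonlinearity.
        return self.p(self.w_l(x) * self.w_r(x))

mlp = BilinearMLP(d_model=8)
x = torch.randn(2, 8)
y = mlp(x)
print(y.shape)  # torch.Size([2, 8])
```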

Key Finding

Weight decay drives reversal generalization in bilinear-MLP models too. With sufficient weight decay, the bilinear-MLP architecture solves the reversal curse without any data augmentation, matching ReLU-MLP results.

Results (L=12, nhead=8, d_model=768, 20 epochs)

wd     train acc    test acc (reversal)
0.0    98.82%       25.57%
1.0    100.00%      92.45%
2.0    99.93%       92.65%
3.0    99.97%       93.87%
4.0    99.98%       97.77%
5.0    100.00%      97.67%
6.0    100.00%      99.33%
  • Train Acc: accuracy on bidirectional eval split (same direction as training data)
  • Test Acc (reversal): accuracy on unidirectional eval split (reversed direction, not seen in training)

Model Architecture

Parameters ~135M
Layers 12
Hidden dim (d_model) 768
Attention heads 8 (head_dim=96)
MLP Bilinear (gate=none), hidden = 4 Γ— d_model
Max seq len 1024
Vocab size 50,257 (GPT-2 BPE)
Normalization RMSNorm (unparameterized, pre-norm)
Positional encoding RoPE
QK norm RMSNorm (unparameterized)
Logit softcap 15.0
Embedding tying No (untied)
Bias None
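The logit softcap of 15.0 presumably follows the common tanh-based formulation (as used in Gemma-style models); a sketch under that assumption:

```python
import torch

def softcap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Smoothly bounds logits to (-cap, cap). Assumption: this model uses
    # the standard tanh formulation of softcapping.
    return cap * torch.tanh(logits / cap)

z = torch.tensor([-100.0, 0.0, 100.0])
capped = softcap(z)
print(capped)  # extreme values squashed into (-15, 15); 0 stays 0
```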

Training Details

Optimizer AdamW (betas=0.9, 0.95)
Learning rate 3e-4
Schedule Cosine decay with 1% warmup
Batch size 64
Dropout 0.1 (embedding, attention softmax, FFN activation, residual)
Epochs 20
Precision FP32 weights, bf16 autocast forward
Gradient clipping None
Weight decay Applied to all parameters except RMSNorm
Data packing Simple concatenation, fixed-size chunks
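The decay/no-decay split can be sketched as AdamW parameter groups. The name-based filter and the TinyBlock module below are hypothetical stand-ins (and since this repo's RMSNorm is unparameterized, the no-decay group may in practice be empty):

```python
import torch

class TinyBlock(torch.nn.Module):
    # Toy stand-in for a transformer block, used only to demo the grouping.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
        self.norm = torch.nn.LayerNorm(4)

def build_optimizer(model, lr=3e-4, wd=6.0):
    # Assumption: norm parameters are identifiable by "norm" in their name.
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if p.requires_grad:
            (no_decay if "norm" in name.lower() else decay).append(p)
    return torch.optim.AdamW(
        [{"params": decay, "weight_decay": wd},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=lr, betas=(0.9, 0.95),
    )

opt = build_optimizer(TinyBlock())
print([g["weight_decay"] for g in opt.param_groups])  # [6.0, 0.0]
```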

Usage

from huggingface_hub import hf_hub_download
import os, sys, torch

# Download model
model_path = hf_hub_download("kdkyum/gpt-biMLP-family-relation", "L12/h8_wd6.0/best_model.pt")
model_py_path = hf_hub_download("kdkyum/gpt-biMLP-family-relation", "model.py")

# Load model
sys.path.insert(0, os.path.dirname(model_py_path))
from model import GPT, GPTConfig, load_model, load_tokenizer

config = GPTConfig(nhead=8, dropout=0.1, bilinear=True, gate=None)
model = load_model(model_path, config=config, device="cuda")  # or "cpu"

# Load tokenizer (GPT-2 BPE)
tokenizer = load_tokenizer()
bos_id = tokenizer.bos_token_id
period_id = tokenizer.encode(".", add_special_tokens=False)[0]

# Generate
prompt = " Ryan Earl Garza mother"  # reversed query (child β†’ parent)
ids = torch.tensor([[bos_id] + tokenizer.encode(prompt, add_special_tokens=False)],
                    dtype=torch.long, device="cuda")
out = model.generate(ids, max_new_tokens=10)
new_ids = out[0, ids.shape[1]:].tolist()
result = []
for t in new_ids:
    if t == period_id or t == bos_id:
        break
    result.append(t)
print(tokenizer.decode(result))

Files

model.py                  # Self-contained GPT model with bilinear MLP support
L12/
  h8_wd0.0/               # Baseline (no weight decay) β€” fails reversal
    best_model.pt
    latest_model.pt
  h8_wd6.0/               # Best reversal accuracy (99.3%)
    best_model.pt
    latest_model.pt

Inspecting the bilinear weights

Because there is no nonlinearity, the bilinear MLP exposes its two factors directly:

block = model.transformer.h[0]
W_l = block.mlp.w.w_l   # (4*d_model, d_model)
W_r = block.mlp.w.w_r   # (4*d_model, d_model)
W_p = block.mlp.p.weight  # (d_model, 4*d_model)

# The MLP output is: W_p @ ((W_l @ x) * (W_r @ x))
# i.e. a low-rank quadratic form in x.
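This identity can be checked numerically. Each output coordinate k is a quadratic form x^T B_k x with B_k = sum_h W_p[k,h] * outer(W_l[h], W_r[h]); the sketch below uses random weights as stand-ins for the checkpoint's:

```python
import torch

# Random weights standing in for the checkpoint's bilinear factors.
d, h = 6, 24
W_l, W_r = torch.randn(h, d), torch.randn(h, d)
W_p = torch.randn(d, h)
x = torch.randn(d)

# Direct bilinear-MLP computation.
out = W_p @ ((W_l @ x) * (W_r @ x))

# Equivalent quadratic forms: B[k] = sum_h W_p[k,h] * outer(W_l[h], W_r[h]).
B = torch.einsum("kh,hi,hj->kij", W_p, W_l, W_r)  # (d, d, d) stack of B_k
quad = torch.einsum("i,kij,j->k", x, B, x)

print(torch.allclose(out, quad, atol=1e-4))  # True
```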

Citation

@misc{gpt-biMLP-family-relation,
  author = {kdkyum},
  title = {Bilinear-MLP GPT for the Reversal Curse},
  year = {2026},
  url = {https://huggingface.co/kdkyum/gpt-biMLP-family-relation}
}