GPT Family Relation (w/o MLP middle layer): Reversal Curse Experiments
GPT causal language models trained on the family_relation dataset to study the reversal curse: the phenomenon where a model trained on "A is the parent of B" fails to infer "B is the child of A".
This repo contains three model variants:
- Standard MLP (ReLU activation)
- Bilinear MLP (no gate/activation: MLP(x) = p(w_l(x) ⊙ w_r(x)))
- Attention-only (no MLP blocks, nhead=8, dp=0.0, norm_scale=True)
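As a sketch, the bilinear variant's block can be written as an elementwise product of two linear projections followed by a down-projection. The module names (`w_l`, `w_r`, `p`) follow the formula above and the dimensions follow the architecture table below; the checkpoint's actual parameter names may differ.

```python
import torch
import torch.nn as nn

class BilinearMLP(nn.Module):
    """Bilinear MLP: no gate or activation, just w_l(x) * w_r(x)
    projected back down. Names are illustrative, not the checkpoint's."""
    def __init__(self, d_model: int = 768, d_ff: int = 3072):
        super().__init__()
        self.w_l = nn.Linear(d_model, d_ff, bias=False)  # left projection
        self.w_r = nn.Linear(d_model, d_ff, bias=False)  # right projection
        self.p = nn.Linear(d_ff, d_model, bias=False)    # down-projection
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.p(self.w_l(x) * self.w_r(x))

mlp = BilinearMLP()
y = mlp(torch.randn(2, 5, 768))
print(y.shape)  # torch.Size([2, 5, 768])
```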
Key Finding
Weight decay is the key driver for solving the reversal curse. With sufficient weight decay, all three variants achieve near-perfect reversal accuracy without any data augmentation, including the attention-only model with no MLP at all.
Results: Standard MLP, dp=0.1 (15 epochs)
`cosine lr=3e-4, final_lr_frac=0.01, warmup=100, dp=0.1, attn_dp=0.1, gc=1.0`
| wd | train acc | test acc (reversal) |
|---|---|---|
| 0.0 | 91.88% | 14.43% |
| 1.0 | 99.82% | 22.22% |
| 2.0 | 99.95% | 63.63% |
| 3.0 | 100.00% | 99.47% |
| 4.0 | 99.97% | 93.03% |
| 5.0 | 99.72% | 95.15% |
| 6.0 | 99.80% | 95.45% |
Results: Standard MLP, dp=0.3 (15 epochs)
`cosine lr=3e-4, final_lr_frac=0.01, warmup=100, dp=0.3, attn_dp=0.3, gc=1.0`
| wd | train acc | test acc (reversal) |
|---|---|---|
| 0.0 | 100.00% | 23.60% |
| 1.0 | 100.00% | 48.75% |
| 2.0 | 100.00% | 71.07% |
| 3.0 | 100.00% | 99.18% |
| 4.0 | 100.00% | 97.83% |
| 5.0 | 99.68% | 97.20% |
| 6.0 | 99.95% | 96.80% |
Results: Bilinear MLP (15 epochs)
`cosine lr=3e-4, final_lr_frac=0.01, warmup=100, dp=0.1, attn_dp=0.1, gc=1.0`
| wd | train acc | test acc (reversal) |
|---|---|---|
| 0.0 | 99.37% | 21.68% |
| 1.0 | 100.00% | 81.77% |
| 2.0 | 100.00% | 99.70% |
| 3.0 | 100.00% | 99.98% |
| 4.0 | 100.00% | 99.98% |
| 5.0 | 100.00% | 99.87% |
| 6.0 | 100.00% | 97.50% |
Results: Attention-only (25 epochs, nhead=8, norm_scale)
`cosine lr=3e-4, final_lr_frac=0.01, warmup=100, dp=0.0, attn_dp=0.0, norm_scale=True, gc=1.0`
| wd | train acc | test acc (reversal) |
|---|---|---|
| 0.0 | 84.20% | 16.83% |
| 1.0 | 97.78% | 86.37% |
| 2.0 | 99.12% | 87.32% |
| 3.0 | 99.70% | 98.87% |
| 4.0 | 98.30% | 86.83% |
| 5.0 | 99.90% | 96.98% |
| 6.0 | 99.92% | 99.48% |
- Train acc: accuracy on the bidirectional eval split (same direction as the training data)
- Test acc (reversal): accuracy on the unidirectional eval split (reversed direction, never seen during training)
Model Architecture
| Hyperparameter | Value |
|---|---|
| Parameters | ~115M (MLP) / ~135M (Bilinear MLP) / ~77M (Attn-only) |
| Layers | 12 |
| Hidden dim (d_model) | 768 |
| Attention heads | 8 (head_dim=96) for all variants |
| FFN hidden | 3072 (4 Γ d_model) |
| Max seq len | 1024 |
| Vocab size | 50,257 (GPT-2 BPE) |
| Activation | ReLU (MLP) / None (Bilinear MLP) / N/A (Attn-only) |
| Normalization | RMSNorm (pre-norm; unparameterized for MLP variants, learnable scale for Attn-only) |
| Positional encoding | RoPE |
| QK norm | RMSNorm (same as above) |
| Embedding tying | No (untied) |
| Bias | None |
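The unparameterized RMSNorm used by the MLP variants (both as pre-norm and as QK norm) admits a one-line sketch. The `eps` value here is an assumption, not read from the checkpoint.

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Unparameterized RMSNorm: rescale each vector by its
    root-mean-square over the feature dimension; no learnable gain
    (the Attn-only variant adds a learnable scale on top of this)."""
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 768)
y = rms_norm(x)
# After normalization, each row has RMS approximately 1.
```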
Training Details
| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW (betas=0.9, 0.95) |
| Learning rate | 3e-4 |
| Schedule | Cosine decay (final_lr_frac=0.01), warmup=100 steps |
| Batch size | 64 |
| Dropout | 0.1 (MLP variants) / 0.0 (Attn-only) |
| Epochs | 15 (MLP variants) / 25 (Attn-only) |
| Precision | FP32 |
| Gradient clipping | 1.0 |
| Weight decay | Swept from 0.0 to 6.0 |
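The schedule in the table (linear warmup for 100 steps, then cosine decay down to final_lr_frac=0.01 of the peak) can be sketched with a `LambdaLR`. The total step count below is a placeholder, not a value taken from the actual runs.

```python
import math
import torch

def lr_lambda(step: int, warmup: int = 100, total: int = 10_000,
              final_frac: float = 0.01) -> float:
    """Warmup then cosine decay, as a multiplier on the peak lr.
    `total` (10_000) is an assumed placeholder for the run length."""
    if step < warmup:
        return step / max(1, warmup)
    t = (step - warmup) / max(1, total - warmup)
    return final_frac + (1 - final_frac) * 0.5 * (1 + math.cos(math.pi * t))

params = [torch.nn.Parameter(torch.randn(10))]
# betas and peak lr from the table; weight_decay=3.0 is one point of the sweep
opt = torch.optim.AdamW(params, lr=3e-4, betas=(0.9, 0.95), weight_decay=3.0)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```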
Usage
```python
from huggingface_hub import hf_hub_download
import os, sys, torch

# Download model (e.g. bilinear MLP with wd=3.0)
model_path = hf_hub_download(
    "kdkyum/gpt-family-relation_wo_mid",
    "bimlp_cos_3e4_w100_ep15_wd3.0/best_model.pt",
)
model_py_path = hf_hub_download("kdkyum/gpt-family-relation_wo_mid", "model.py")

# Load model
sys.path.insert(0, os.path.dirname(model_py_path))
from model import GPT, GPTConfig, load_model, load_tokenizer

# For bilinear MLP:
config = GPTConfig(nhead=8, dropout=0.0, bilinear=True, gate=None)
# For standard MLP:
# config = GPTConfig(nhead=8, dropout=0.0)
# For attention-only (norm_scale):
# config = GPTConfig(nhead=8, dropout=0.0, no_mlp=True, norm_scale=True)
model = load_model(model_path, config=config, device="cuda")  # or "cpu"

# Load tokenizer (GPT-2 BPE)
tokenizer = load_tokenizer()
bos_id = tokenizer.bos_token_id
period_id = tokenizer.encode(".", add_special_tokens=False)[0]

# Generate a reversed query (child -> parent)
prompt = " Ryan Earl Garza mother"
ids = torch.tensor(
    [[bos_id] + tokenizer.encode(prompt, add_special_tokens=False)],
    dtype=torch.long, device="cuda",
)
out = model.generate(ids, max_new_tokens=10)
new_ids = out[0, ids.shape[1]:].tolist()

# Decode up to the first period or BOS token
result = []
for t in new_ids:
    if t == period_id or t == bos_id:
        break
    result.append(t)
print(tokenizer.decode(result))
```
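The stop-token loop at the end of the snippet can be factored into a small helper. This is illustrative only; `trim_generation` is not part of `model.py`.

```python
def trim_generation(new_ids: list[int], stop_ids: set[int]) -> list[int]:
    """Truncate a list of generated token ids at the first stop token
    (the period and BOS ids in the snippet above)."""
    out = []
    for t in new_ids:
        if t in stop_ids:
            break
        out.append(t)
    return out

# Token ids here are arbitrary examples; 13 and 50256 play the roles
# of period_id and bos_id.
print(trim_generation([314, 262, 13, 50256], stop_ids={13, 50256}))  # [314, 262]
```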
Files
```
model.py                                     # Self-contained GPT model (needs torch + transformers)
mlp_cos_3e4_w100_ep15_wd{X}/                 # Standard MLP (dp=0.1)
  best_model.pt                              # Best checkpoint (by reversal test accuracy)
  latest_model.pt                            # Final checkpoint
bimlp_cos_3e4_w100_ep15_wd{X}/               # Bilinear MLP
  best_model.pt
  latest_model.pt
attn25_cos_3e4_w100_ep25_nh8_dp0_ns_wd{X}/   # Attention-only (nhead=8, dp=0.0, norm_scale)
  best_model.pt
  latest_model.pt
mlp_cos_3e4_w100_ep15_dp03_wd{X}/            # Standard MLP (dp=0.3)
  best_model.pt
  latest_model.pt
```
Dataset
Trained on kdkyum/family_relation (lvl3_N1000 split): synthetically generated family relation statements with 1000 families and 3 levels of depth.
Citation
```bibtex
@misc{gpt-family-relation-wo-mid,
  author = {kdkyum},
  title  = {GPT Family Relation (w/o MLP middle layer): Solving the Reversal Curse with Weight Decay},
  year   = {2026},
  url    = {https://huggingface.co/kdkyum/gpt-family-relation_wo_mid}
}
```