GELU d=12 Chinchilla (261M)

A 261M-parameter nanochat-architecture GPT with GELU activation in the MLP. Trained on English C4 to a Chinchilla-optimal token budget (20× params ≈ 5.22 B tokens) on a single TPU v6e-8.

This repository hosts the seed 0 checkpoint (model + optimizer state, resumable). Training was also run for seeds 1 and 2 under identical compute to quantify seed variance — summary below.

Final result — seed 0 (this checkpoint)

| Metric | Value |
|---|---|
| Final loss | 3.0745 |
| Smooth final loss | 3.1106 |
| Tokens | 5.22 B |
| Steps | 19,920 |
| Wall time | 2.06 h |
| Throughput | 703 K tok/s |
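The throughput figure follows directly from tokens over wall time; a quick sanity check using the rounded values from the table (so the result agrees only to within rounding):

```python
tokens = 5.22e9  # reported token total
hours = 2.06     # reported wall time

# Tokens per second, in thousands
rate_k = tokens / (hours * 3600) / 1e3
print(f"{rate_k:.0f} K tok/s")  # 704 K tok/s, matching the reported 703 to rounding
```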

3-seed variance (same arch, same data, same compute)

| Seed | Final loss | Smooth final | Wall time | Throughput |
|---|---|---|---|---|
| 0 (this repo) | 3.0745 | 3.1106 | 2.06 h | 703 K tok/s |
| 1 | 3.1355 | 3.1097 | 2.02 h | 717 K tok/s |
| 2 | 3.0765 | 3.1261 | 2.02 h | 717 K tok/s |
| mean ± std | 3.0955 ± 0.028 | 3.1155 ± 0.008 | 2.03 ± 0.02 h | 712 ± 7 K tok/s |
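The mean ± std row uses the population standard deviation over the three seeds, which can be reproduced with the stdlib:

```python
import statistics

final = [3.0745, 3.1355, 3.0765]   # final loss, seeds 0-2
smooth = [3.1106, 3.1097, 3.1261]  # smoothed final loss, seeds 0-2

# pstdev (population std, N in the denominator) matches the table's spread
print(f"{statistics.mean(final):.4f} ± {statistics.pstdev(final):.3f}")   # 3.0955 ± 0.028
print(f"{statistics.mean(smooth):.4f} ± {statistics.pstdev(smooth):.3f}")  # 3.1155 ± 0.008
```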

Architecture

Full nanochat stack with a GELU MLP — everything else identical to the default nanochat GPT:

  • RoPE (base 100K), no learned positional embeddings
  • MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable, but all d=12 models use full MHA)
  • QK norm with 1.2× scaling
  • Sliding window attention with "SSSL" pattern
  • Tied embeddings (wte ↔ lm_head)
  • Parameterless RMSNorm post-embedding and per block
  • Value embeddings (ResFormer-style, alternating layers)
  • Per-layer learnable residual scalars
  • Smear (learnable bigram gate on the first 24 dims of the token embedding)
  • Backout (subtract mid-layer residual)
  • Logit soft-capping via tanh(x/15)·15
  • No biases in any Linear (attention Q/K/V/proj, MLP fc/proj)
  • MLP FFN: Linear(n → 4n) → jax.nn.gelu → Linear(4n → n)
    • Applied via the --mlp gelu flag in scripts/train_d12_chinchilla.py, which monkey-patches MLP.__call__ in flaxchat.gpt
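The logit soft-cap listed above bounds logits to (-15, 15) while staying near-identity for small values. A minimal sketch of the formula in plain Python (not the flaxchat implementation, which operates on JAX arrays):

```python
import math

def softcap(x, cap=15.0):
    # tanh(x / cap) * cap: ~identity for |x| << cap, saturates at +/-cap
    return math.tanh(x / cap) * cap

print(round(softcap(1.0), 4))    # 0.9985 (near-identity for small logits)
print(round(softcap(100.0), 2))  # 15.0 (saturated at the cap)
```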

Training

| Setting | Value |
|---|---|
| Architecture | Nanochat-style GPT with GELU MLP |
| Parameters | 261,096,338 |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL window |
| Data | allenai/c4 (English split, streamed) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, β=(0.9, 0.999), wd=0.01, global-norm grad clip 1.0 |
| LR schedule | warmup-cosine-decay, peak 0.01, warmup 500 steps, end at 5% of peak |
| Batch | 32/device × 8 devices = 256 global (262 K tokens/step) |
| Seq length | 1024 |
| Tokens | Chinchilla 20× params ≈ 5.22 B tokens (19,920 steps) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Seed | 0 (seeds 1 and 2 trained under identical config for the variance table above) |
| Wandb | irf-sic/flaxchat — gelu-d12-chinchilla-lr0.01-seed0 |
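The token budget and step count follow from Chinchilla's roughly 20-tokens-per-parameter rule and the 262 K-token global batch:

```python
params = 261_096_338
budget = 20 * params            # Chinchilla-optimal: ~20 tokens per parameter
tokens_per_step = 256 * 1024    # global batch 256 sequences x 1,024 tokens = 262,144

print(f"{budget / 1e9:.2f} B tokens")              # 5.22 B tokens
print(f"{round(budget / tokens_per_step):,} steps") # 19,920 steps
```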

Because the optimizer state is preserved, the checkpoint is resumable for further training.

Contents

This repository contains the full Orbax checkpoint (model weights + AdamW optimizer state), the frozen config, and code snapshots sufficient to load the model:

.
├── 19920/                          # final Orbax checkpoint (~2.86 GB total)
│   ├── _CHECKPOINT_METADATA
│   ├── metadata/                   # JSON: {"final_loss": 3.07, "smooth": 3.11}
│   ├── model/                      # nnx.Param state — architecture weights
│   └── optimizer/                  # optax AdamW state (Adam m/v moments + step + LR state)
├── config.json
├── README.md
└── code/                           # snapshots of flaxchat code (source of truth: github)
    ├── gpt.py
    ├── checkpoint.py
    ├── config.py
    ├── train_d12_chinchilla.py     # the exact training script (use --mlp gelu)
    └── load_model.py

Loading

git clone https://github.com/mlnomadpy/flaxchat && cd flaxchat
pixi install   # or pip install with the deps in code/load_model.py
huggingface-cli download mlnomad/gelu-d12-chinchilla-261M --local-dir ./hf_model
# Inside your flaxchat clone — adapt code/load_model.py for GELU:
from flax import nnx
import jax, jax.numpy as jnp
from flaxchat.gpt import GPT, GPTConfig, MLP
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch MLP to use GELU instead of ReLU² (matches training)
_orig = MLP.__call__  # keep a reference to the original forward
def _gelu_call(self, x):
    x = self.c_fc(x)
    x = jax.nn.gelu(x)
    x = self.c_proj(x)
    return x
MLP.__call__ = _gelu_call

config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))
meta = restore_model_from_checkpoint(model, "./hf_model/19920")
print(f"Loaded — final loss: {meta['final_loss']:.4f}")  # expect 3.0745

Weights and optimizer state

Both the model parameters and the full AdamW optimizer state (Adam m/v moments, step counter, LR schedule state) are stored in the 19920/ checkpoint directory.

The checkpoint format is Orbax (OCDBT PyTree). flaxchat.checkpoint.restore_model_from_checkpoint(model, ckpt_path, optimizer=optimizer) restores both in one call.

License

Apache 2.0.
