# GELU d=12 Chinchilla (261M)
A 261M-parameter nanochat-architecture GPT with GELU activation in the MLP. Trained on English C4 to the Chinchilla-optimal token budget (20× params ≈ 5.22 B tokens) on a single TPU v6e-8.
This repository hosts the seed 0 checkpoint (model + optimizer state, resumable). Training was also run for seeds 1 and 2 under identical compute to quantify seed variance; a summary is below.
## Final result (seed 0, this checkpoint)
| Metric | Value |
|---|---|
| Final loss | 3.0745 |
| Smooth final loss | 3.1106 |
| Tokens | 5.22 B |
| Steps | 19,920 |
| Wall time | 2.06 h |
| Throughput | 703 K tok/s |
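As a back-of-the-envelope sanity check, the reported throughput and wall time are consistent with the token count (numbers taken from the table above):

```python
tokens = 5.22e9      # total tokens trained
throughput = 703e3   # tokens per second
wall_time_h = 2.06   # hours

implied_tokens = throughput * wall_time_h * 3600
rel_err = abs(implied_tokens - tokens) / tokens
print(f"{implied_tokens:.3e} tokens, {rel_err:.1%} off")  # ~5.213e9 tokens, ~0.1% off
```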
## 3-seed variance (same arch, same data, same compute)
| Seed | Final loss | Smooth final | Wall time | Throughput |
|---|---|---|---|---|
| 0 (this repo) | 3.0745 | 3.1106 | 2.06 h | 703 K tok/s |
| 1 | 3.1355 | 3.1097 | 2.02 h | 717 K tok/s |
| 2 | 3.0765 | 3.1261 | 2.02 h | 717 K tok/s |
| mean ± std | 3.0955 ± 0.028 | 3.1155 ± 0.008 | 2.03 ± 0.02 h | 712 ± 7 K tok/s |
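The mean ± std row can be reproduced from the per-seed numbers; the quoted std for final loss matches the population standard deviation (ddof=0):

```python
import numpy as np

final_losses = np.array([3.0745, 3.1355, 3.0765])  # seeds 0, 1, 2

mean = final_losses.mean()
std = final_losses.std()  # population std (ddof=0), matching the table
print(f"{mean:.4f} ± {std:.3f}")  # 3.0955 ± 0.028
```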
## Architecture
Full nanochat stack with a GELU MLP; everything else is identical to the default nanochat GPT:
- RoPE (base 100K), no learned positional embeddings
- MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable but all d=12 models use full MHA)
- QK norm with 1.2Γ scaling
- Sliding window attention with "SSSL" pattern
- Tied embeddings (wte ↔ lm_head)
- Parameterless RMSNorm post-embedding and per block
- Value embeddings (ResFormer-style, alternating layers)
- Per-layer learnable residual scalars
- Smear (learnable bigram gate on first 24 dims of token embedding)
- Backout (subtract mid-layer residual)
- Logit soft-capping via tanh(x/15)·15
- No biases in any Linear (attention Q/K/V/proj, MLP fc/proj)
- MLP FFN: `Linear(n → 4n) → jax.nn.gelu → Linear(4n → n)`
  - Applied via the `--mlp gelu` flag in `scripts/train_d12_chinchilla.py`, which monkey-patches `MLP.__call__` in `flaxchat.gpt`
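For illustration, the logit soft-capping above is a one-liner; a minimal sketch in plain Python (the actual model applies the same function elementwise in JAX):

```python
import math

def soft_cap(logit: float, cap: float = 15.0) -> float:
    """tanh(x/cap)*cap: smoothly bounds logits to (-cap, cap).

    Near zero it is approximately the identity; large logits saturate at ±cap.
    """
    return math.tanh(logit / cap) * cap

print(soft_cap(0.1))     # ~0.1 (near-identity for small logits)
print(soft_cap(1000.0))  # ~15.0 (saturated)
```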
## Training
| Setting | Value |
|---|---|
| Architecture | Nanochat-style GPT with GELU MLP |
| Parameters | 261,096,338 |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL window |
| Data | allenai/c4 (English split, streamed) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, β=(0.9, 0.999), wd=0.01, global-norm grad clip 1.0 |
| LR schedule | warmup-cosine-decay, peak 0.01, warmup 500, end 5% of peak |
| Batch | 32/device × 8 devices = 256 global (262 K tokens/step) |
| Seq length | 1024 |
| Tokens | Chinchilla 20× params ≈ 5.22 B tokens (19,920 steps) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Seed | 0 (seeds 1, 2 trained under identical config for variance table above) |
| Wandb | irf-sic/flaxchat → gelu-d12-chinchilla-lr0.01-seed0 |
The optimizer state is preserved; the checkpoint is resumable for further training.
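The warmup-cosine-decay schedule can be sketched in pure Python. This is an illustrative reimplementation using the hyperparameters from the table (the training script presumably builds the equivalent optax schedule):

```python
import math

def lr_at(step, peak=0.01, warmup=500, total_steps=19920, end_frac=0.05):
    """Linear warmup to `peak`, then cosine decay to `end_frac * peak`."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total_steps - warmup)
    end = end_frac * peak
    return end + 0.5 * (peak - end) * (1 + math.cos(math.pi * progress))

print(lr_at(500))    # 0.01   (peak after warmup)
print(lr_at(19920))  # 0.0005 (5% of peak at the end)
```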
## Contents
This repository contains the full Orbax checkpoint (model weights + AdamW optimizer state), the frozen config, and code snapshots sufficient to load the model:
```text
.
├── 19920/                      # final Orbax checkpoint (~2.86 GB total)
│   ├── _CHECKPOINT_METADATA
│   ├── metadata/               # JSON: {"final_loss": 3.07, "smooth": 3.11}
│   ├── model/                  # nnx.Param state (architecture weights)
│   └── optimizer/              # optax AdamW state (Adam m/v moments + step + LR state)
├── config.json
├── README.md
└── code/                       # snapshots of flaxchat code (source of truth: github)
    ├── gpt.py
    ├── checkpoint.py
    ├── config.py
    ├── train_d12_chinchilla.py # the exact training script (use --mlp gelu)
    └── load_model.py
```
## Loading
```bash
git clone https://github.com/mlnomadpy/flaxchat && cd flaxchat
pixi install  # or pip install with the deps in code/load_model.py
huggingface-cli download mlnomad/gelu-d12-chinchilla-261M --local-dir ./hf_model
```
Inside your flaxchat clone, adapt `code/load_model.py` for GELU:

```python
from flax import nnx
import jax

from flaxchat.gpt import GPT, GPTConfig, MLP
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch MLP to use GELU instead of ReLU² (matches training)
_orig = MLP.__call__

def _gelu_call(self, x):
    x = self.c_fc(x)
    x = jax.nn.gelu(x)
    x = self.c_proj(x)
    return x

MLP.__call__ = _gelu_call

config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))
meta = restore_model_from_checkpoint(model, "./hf_model/19920")
print(f"Loaded; final loss: {meta['final_loss']:.4f}")  # expect 3.0745
```
## Weights and optimizer state
Both the model parameters and the full AdamW optimizer state (Adam m/v moments, step counter, LR schedule state) are stored at:
- HuggingFace Hub: `mlnomad/gelu-d12-chinchilla-261M`
  - `19920/model/` → `nnx.Param` state
  - `19920/optimizer/` → optax AdamW state (resumable)
The checkpoint format is Orbax (OCDBT PyTree). `flaxchat.checkpoint.restore_model_from_checkpoint(model, ckpt_path, optimizer=optimizer)` restores both in one call.
## License
Apache 2.0.