# GELU d=12 Chinchilla (261M) — PyTorch / HuggingFace Transformers

A 261M-parameter nanochat-architecture GPT with a GELU MLP, originally trained in JAX/Flax on a TPU v6e-8 and ported to PyTorch for easy inference via the HuggingFace `transformers` API.

Weights are bit-exact with the Flax checkpoint (mlnomad/gelu-d12-chinchilla-261M) — parity validated at max |Δ logits| = 1.3e-5 on CPU/fp32.
## Quick start

```bash
pip install torch transformers safetensors
```

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/gelu-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,  # the model class ships in the repo
    dtype=torch.float32,
).eval()

# The model was trained with the Mistral-7B-v0.1 tokenizer (vocab 32,768)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    out = model.generate(
        ids,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id or 0,
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))
```
Greedy completion examples from this checkpoint:

- "The meaning of life is" → *to live in a place where you can live. The meaning of life is to live in a*
- "Once upon a time," → *the first thing I would do was to make a list of all the things I would like to do*
## Model details

|  |  |
|---|---|
| Parameters | 261,096,338 |
| Architecture | Nanochat-style GPT with GELU MLP (ported from JAX/Flax NNX) |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 5.22 B tokens (Chinchilla 20×) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.01, warmup-cosine |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Final loss | 3.0745 (smoothed 3.1106) |
| 3-seed variance | mean 3.0955 ± 0.028 (smoothed 3.1155 ± 0.008) |
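The token budget follows the Chinchilla heuristic of roughly 20 training tokens per parameter, which can be checked directly from the parameter count:

```python
# Chinchilla-style budget: ~20 training tokens per model parameter.
params = 261_096_338
tokens = 20 * params
print(f"{tokens / 1e9:.2f} B tokens")  # 5.22 B, matching the training-data row above
```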
## Architecture features

Full nanochat stack, faithfully ported to PyTorch:

- RoPE (base 100,000), split-half layout
- MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
- QK-norm with 1.2× scaling (after RoPE)
- Parameterless RMSNorm (no learnable gain) post-embedding and per block
- Sliding-window attention with "SSSL" pattern
- Tied embeddings (lm_head = wte.T)
- Value embeddings on alternating layers (ResFormer-style)
- Per-layer learnable residual scalars (`resid_lambdas`, `x0_lambdas`)
- Smear — a learnable gate on the first 24 dims of the token embedding mixes in the previous token
- Backout — subtract the mid-layer residual from late layers
- Logit soft-cap: `15 · tanh(logits / 15)`
- No biases in any Linear
- MLP: `Linear(n → 4n) → F.gelu(approximate="tanh") → Linear(4n → n)`
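Two of the simpler pieces above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the repo's actual module code (class and attribute names here are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeluMLP(nn.Module):
    """Bias-free GELU MLP: Linear(n -> 4n) -> tanh-approx GELU -> Linear(4n -> n)."""
    def __init__(self, n_embd: int):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x), approximate="tanh"))

def soft_cap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    """Logit soft-cap: smoothly squashes logits into (-cap, cap)."""
    return cap * torch.tanh(logits / cap)

x = torch.randn(2, 8, 768)
print(GeluMLP(768)(x).shape)               # torch.Size([2, 8, 768])
print(soft_cap(torch.tensor([100.0])))     # ≈ 15; never exceeds the cap
```

The soft-cap keeps extreme logits bounded while remaining differentiable, which stabilizes the cross-entropy loss late in training.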
## KV cache

The `GeluGPTForCausalLM` class implements a smear-aware KV cache for fast autoregressive generation. KV-cache parity vs. the full forward pass is validated at max |Δ| < 3e-5. Pass `use_cache=True` (the default for `.generate()`).
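The parity check can be reproduced in miniature: run a toy causal attention head once over a full sequence, then token-by-token with a growing K/V cache, and compare the outputs. This is a generic single-head sketch, not the repo's smear-aware implementation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 6, 16
q, k, v = (torch.randn(T, d) for _ in range(3))

# Full forward: causal attention over the whole sequence at once.
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = (q @ k.T / d**0.5).masked_fill(~mask, float("-inf"))
full = F.softmax(att, dim=-1) @ v

# Incremental forward: append each step's K/V to a cache, attend over the cache.
k_cache, v_cache, outs = [], [], []
for t in range(T):
    k_cache.append(k[t]); v_cache.append(v[t])
    K, V = torch.stack(k_cache), torch.stack(v_cache)
    outs.append(F.softmax(q[t] @ K.T / d**0.5, dim=-1) @ V)
inc = torch.stack(outs)

print(torch.max(torch.abs(full - inc)).item())  # tiny: fp32 rounding only
```

The two paths are mathematically identical; any residual difference is floating-point rounding, which is what the 3e-5 bound above is measuring for the real model.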
## Files in this repo

```
.
├── config.json                # HF config with auto_map pointing to the classes below
├── generation_config.json
├── model.safetensors          # 1.04 GB, fp32 weights + persistent RoPE buffers
├── torch_gpt.py               # pure PyTorch GPT module (GELU_GPT)
├── configuration_gelu_gpt.py  # PretrainedConfig subclass
├── modeling_gelu_gpt.py       # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md
```
## Related

- mlnomad/gelu-d12-chinchilla-261M — original JAX/Flax Orbax checkpoint (model + AdamW optimizer state, resumable)
- flaxchat — the JAX/Flax training harness that produced the weights
## Wikitext-103 evaluation
| Metric | Value |
|---|---|
| Wikitext-103 test loss | 3.840 |
| Wikitext-103 test PPL | 46.52 |
Evaluated on ~330K tokens from the wikitext-103 test set (the model was trained on C4 only — this is a zero-shot transfer metric).
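Perplexity here is simply the exponential of the mean token cross-entropy, so the two rows of the table are consistent:

```python
import math

test_loss = 3.840
ppl = math.exp(test_loss)
print(f"{ppl:.1f}")  # ≈ 46.5; the reported 46.52 comes from the unrounded loss
```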
## License
Apache 2.0.