GELU d=12 Chinchilla (261M) β€” PyTorch / HuggingFace Transformers

A 261M-parameter nanochat-architecture GPT with GELU MLP, originally trained in JAX/Flax on a TPU v6e-8 and ported to PyTorch for easy inference via the HuggingFace transformers API.

Weights are bit-exact with the Flax checkpoint (mlnomad/gelu-d12-chinchilla-261M) β€” parity validated at max |Ξ” logits| = 1.3e-5 on CPU/fp32.

Quick start

pip install torch transformers safetensors

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/gelu-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,     # the model class ships in the repo
    dtype=torch.float32,
).eval()

# The model was trained with the Mistral-7B-v0.1 tokenizer (vocab 32,768)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    out = model.generate(
        ids,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))

Greedy (do_sample=False) completion examples from this checkpoint:

"The meaning of life is" β†’ to live in a place where you can live. The meaning of life is to live in a

"Once upon a time," β†’ the first thing I would do was to make a list of all the things I would like to do

Model details

Parameters       261,096,338
Architecture     Nanochat-style GPT with GELU MLP (ported from JAX/Flax NNX)
Config           d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window
Training data    allenai/c4 (English split), 5.22 B tokens (Chinchilla 20×)
Tokenizer        mistralai/Mistral-7B-v0.1 (vocab 32,768)
Optimizer        plain AdamW, peak LR 0.01, warmup-cosine schedule
Hardware         TPU v6e-8 (TRC), europe-west4-a
Final loss       3.0745 (smoothed 3.1106)
Across 3 seeds   3.0955 ± 0.028 (smoothed 3.1155 ± 0.008)
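
The token budget is just the Chinchilla rule of thumb of ~20 training tokens per parameter; a quick arithmetic check in plain Python:

```python
# Chinchilla heuristic: ~20 training tokens per model parameter.
params = 261_096_338                  # parameter count from the table above
tokens = 20 * params                  # the "Chinchilla 20x" budget
print(f"{tokens / 1e9:.2f}B tokens")  # -> 5.22B tokens
```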

Architecture features

Full nanochat stack, faithfully ported to PyTorch:

  • RoPE (base 100,000), split-half layout
  • MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
  • QK-norm with 1.2Γ— scaling (after RoPE)
  • Parameterless RMSNorm (no learnable gain) post-embedding and per block
  • Sliding-window attention with "SSSL" pattern
  • Tied embeddings (lm_head = wte.T)
  • Value embeddings on alternating layers (ResFormer-style)
  • Per-layer learnable residual scalars (resid_lambdas, x0_lambdas)
  • Smear — a learnable gate on the first 24 dims of the token embedding that mixes in the previous token
  • Backout — the mid-layer residual stream is subtracted back out in late layers
  • Logit soft-cap: 15 Β· tanh(logits / 15)
  • No biases in any Linear
  • MLP: Linear(n β†’ 4n) β†’ F.gelu(approximate="tanh") β†’ Linear(4n β†’ n)

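The "SSSL" pattern can be read as a per-layer schedule of attention windows. The sketch below assumes S = sliding-window layer and L = long (full-context) layer repeating every four layers, with a hypothetical window of 256 tokens; the real values and pattern handling live in config.json / torch_gpt.py:

```python
import torch

def sssl_windows(n_layer: int, window: int, seq_len: int) -> list[int]:
    # "SSSL": layers 3, 7, 11, ... see the full context; the rest use a sliding window.
    return [seq_len if i % 4 == 3 else window for i in range(n_layer)]

def sliding_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where query i may attend key j: causal, and within the last `window` keys.
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

print(sssl_windows(12, window=256, seq_len=1024))
# -> [256, 256, 256, 1024, 256, 256, 256, 1024, 256, 256, 256, 1024]
```
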
KV cache

The GeluGPTForCausalLM class implements a smear-aware KV cache for fast autoregressive generation. KV-cache parity vs full forward is validated at max |Ξ”| < 3e-5. Pass use_cache=True (the default for .generate()).
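
The parity being validated is the generic cache-vs-full-forward property, shown here on a toy single-head attention layer (illustrative only; the repo's smear-aware cache in modeling_gelu_gpt.py also has to carry the previous token's embedding across steps):

```python
import torch

torch.manual_seed(0)
d, T = 16, 6
Wq, Wk, Wv = torch.randn(3, d, d) * 0.1
x = torch.randn(T, d)

def full_forward(x):
    # All T tokens at once, with a causal mask.
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    att = (q @ k.T) / d**0.5
    att = att.masked_fill(torch.triu(torch.ones(T, T, dtype=torch.bool), 1), float("-inf"))
    return att.softmax(-1) @ v

def cached_forward(x):
    # One token per step; K/V of earlier steps come from the cache, so no mask is needed.
    k_cache, v_cache, outs = [], [], []
    for t in range(x.shape[0]):
        xt = x[t:t + 1]
        k_cache.append(xt @ Wk)
        v_cache.append(xt @ Wv)
        k, v = torch.cat(k_cache), torch.cat(v_cache)
        att = ((xt @ Wq) @ k.T) / d**0.5
        outs.append(att.softmax(-1) @ v)
    return torch.cat(outs)

diff = (full_forward(x) - cached_forward(x)).abs().max()
print(diff)  # fp32 round-off only, well below 1e-5
```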

Files in this repo

.
β”œβ”€β”€ config.json                     # HF config with auto_map pointing to the classes below
β”œβ”€β”€ generation_config.json
β”œβ”€β”€ model.safetensors               # 1.04 GB, fp32 weights + persistent RoPE buffers
β”œβ”€β”€ torch_gpt.py                    # pure PyTorch GPT module (GELU_GPT)
β”œβ”€β”€ configuration_gelu_gpt.py       # PretrainedConfig subclass
β”œβ”€β”€ modeling_gelu_gpt.py            # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md

Wikitext-103 evaluation

Metric                    Value
Wikitext-103 test loss    3.840
Wikitext-103 test PPL     46.52

Evaluated on ~330K tokens from the wikitext-103 test set. The model was trained on C4 only, so this is a zero-shot transfer metric.
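
Perplexity is just the exponentiated mean token loss, so the two numbers in the table are mutually consistent:

```python
import math

loss = 3.840          # wikitext-103 test loss from the table
ppl = math.exp(loss)
print(round(ppl, 2))  # close to the table's 46.52 (the small gap is rounding of the loss)
```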

License

Apache 2.0.
