# YatNMN-Softplus + scalar_bias d=12 Chinchilla (261M)
A 261M-parameter nanochat-architecture GPT whose MLP uses the YatNMN layer from nmn>=0.2.29 with softplus_bias=True, scalar_bias=True, and learnable_epsilon=True. Trained on English C4 to the Chinchilla-optimal token budget (20× params ≈ 5.22 B tokens).
The scalar_bias=True variant uses a single shared bias across all neurons (shape `(1,)`); this matches the theoretical formulation of YatNMN more faithfully than the original per-neuron `(ff,)` bias. This is the ablation testing whether the per-neuron bias in the original YatNMN-Softplus run was load-bearing.
## Final result
| Metric | Value |
|---|---|
| Final loss | 3.0244 |
| Smooth final loss | 3.0641 |
| Tokens | 5.22 B |
| Steps | 19,920 |
| Wall time | 2.2 h |
| Throughput | ~653 K tok/s |
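As a quick sanity check (not part of the original card), the reported steps, batch size, and token budget are mutually consistent; the parameter count is taken from the Training section:

```python
# Consistency check: Chinchilla budget vs. reported steps and batch size.
params = 261_133_214          # model parameters (see Training table)
tokens_per_step = 256 * 1024  # 256 global batch x 1024 seq len = 262,144
steps = 19_920

total_tokens = steps * tokens_per_step
print(f"{total_tokens / 1e9:.2f} B tokens")    # ~5.22 B
print(f"{total_tokens / params:.1f}x params")  # ~20x (Chinchilla-optimal)
```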
## Comparison vs other MLPs (same arch, same data, same compute)
| MLP variant | Final loss | Δ vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron bias) | 2.98 | −0.13 |
| YatNMN-Softplus + scalar_bias (this model) | 3.06 | −0.05 |
| GELU | 3.11 | baseline |
| YatNMN (plain, no softplus) | 3.13 | +0.02 |
| YatNMN-Softplus + fineweb-edu continuation | 2.84 | −0.27 |
Takeaways:
- `scalar_bias` still beats GELU by 0.05 nats at identical compute, so the YatNMN family remains a win over stock MLPs.
- The per-neuron bias variant (2.98) beats `scalar_bias` (3.06) by 0.08 nats, so the extra `(ff,)` parameters are doing useful work in this regime, even though they deviate from the clean theoretical formulation.
- The gap may narrow with more compute / seeds: this is seed=0 only. Seeds 1 and 2 of the scalar_bias 3-seed sweep are still running.
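For intuition on what these nat deltas mean (an illustrative conversion, not part of the original results): a loss improvement of Δ nats corresponds to a perplexity ratio of exp(−Δ), since perplexity is exp(loss).

```python
import math

def ppl_reduction(delta_nats):
    """Relative perplexity reduction implied by a loss improvement in nats."""
    return 1 - math.exp(-delta_nats)

print(f"{ppl_reduction(0.05):.1%}")  # scalar_bias vs GELU: ~4.9% lower perplexity
print(f"{ppl_reduction(0.27):.1%}")  # fineweb-edu continuation vs GELU: ~23.7%
```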
## Architecture
Full nanochat stack with YatNMN-Softplus + scalar_bias swapped in for ReLU² in the MLP; everything else is identical:
- RoPE (base 100K), no learned positional embeddings
- MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable, but all d=12 models use full MHA)
- QK norm with 1.2× scaling
- Sliding window attention with `"SSSL"` pattern
- Tied embeddings (wte ↔ lm_head)
- Parameterless RMSNorm post-embedding and per block
- Value embeddings (ResFormer-style, alternating layers)
- Per-layer learnable residual scalars
- Smear (learnable bigram gate on first 24 dims of token embedding)
- Backout (subtract mid-layer residual)
- Logit soft-capping via `tanh(x/15)·15`
- No biases in attention (Q/K/V/proj) or MLP output projection
- MLP FFN: `YatNMN(n → 4n) → Linear(4n → n)`
  - Replaces the nanochat default `relu(x)**2` with the YatNMN formula `y = α · (x·W + b)² / (‖x − W‖² + ε)`
  - `scalar_bias=True` → `b` is a shared `(1,)` parameter across all `4n` neurons
  - `softplus_bias=True` → `b = softplus(raw_b)` to keep `b > 0`
  - `learnable_epsilon=True` → `ε = softplus(raw_eps)` (init target 1e-3)
  - `alpha` is learnable per neuron (standard YatNMN)
  - Applied via the `--mlp yatnmn-softplus` flag in `scripts/train_d12_chinchilla.py` (the scalar_bias branch is always enabled for `yatnmn-softplus`)
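The YatNMN formula above can be sketched in a few lines of NumPy (an illustrative reimplementation for shape intuition, not the `nmn` library code; the softplus parameterization of `b` and `ε` follows the bullets above):

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def yat_nmn(x, W, alpha, raw_b, raw_eps):
    """y = alpha * (x@W + b)^2 / (||x - W_j||^2 + eps), per output neuron j.

    x: (n,), W: (n, ff), alpha: (ff,), raw_b: (1,) with scalar_bias=True,
    raw_eps: scalar. b and eps are kept positive via softplus.
    """
    b = softplus(raw_b)                    # softplus_bias=True -> b > 0
    eps = softplus(raw_eps)                # learnable_epsilon=True
    num = (x @ W + b) ** 2                 # (ff,)
    dist = ((x[:, None] - W) ** 2).sum(0)  # squared distance to each column, (ff,)
    return alpha * num / (dist + eps)

rng = np.random.default_rng(0)
n, ff = 8, 32
# raw_eps = -6.9 so that softplus(raw_eps) ~ 1e-3, matching the init target
y = yat_nmn(rng.normal(size=n), rng.normal(size=(n, ff)),
            np.ones(ff), np.zeros(1), np.float64(-6.9))
print(y.shape)  # (32,)
```

With `scalar_bias=True` the bias is the single `(1,)` array above; the per-neuron ablation would make it `(ff,)` instead.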
## Training
| Setting | Value |
|---|---|
| Architecture | Nanochat-style GPT with YatNMN-Softplus + scalar_bias MLP |
| Parameters | 261,133,214 |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL window |
| Data | allenai/c4 (English split, streamed) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, β=(0.9, 0.999), wd=0.01, global-norm grad clip 1.0 |
| LR schedule | warmup-cosine-decay, peak 0.03, warmup 500, end 5% of peak |
| Batch | 32/device × 8 devices = 256 global (262 K tokens/step) |
| Seq length | 1024 |
| Tokens | Chinchilla 20× params ≈ 5.22 B tokens (19,920 steps) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Seed | 0 |
| nmn | nmn>=0.2.29 (includes scalar_bias flag) |
| Wandb | irf-sic/flaxchat → yatnmn-softplus-d12-chinchilla-lr0.03-seed0 |
The optimizer state is preserved, so the checkpoint is resumable for further training (see mlnomad/yatnmn-softplus-d12-fineweb-edu-261M for an example of continued pretraining on fineweb-edu).
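The LR schedule in the table can be sketched as follows; this assumes a standard cosine parameterization decaying to a floor of 5% of peak (the authoritative version is `code/train_d12_chinchilla.py`):

```python
import math

def lr_at(step, peak=0.03, warmup=500, total_steps=19_920, final_frac=0.05):
    """Warmup-cosine-decay: linear warmup, then cosine to final_frac * peak."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total_steps - warmup)
    floor = final_frac * peak
    return floor + (peak - floor) * 0.5 * (1 + math.cos(math.pi * progress))

print(lr_at(500))     # peak: 0.03
print(lr_at(19_920))  # final: 0.0015 (5% of peak)
```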
## Contents
```
.
├── 19920/                       # final Orbax checkpoint
│   ├── _CHECKPOINT_METADATA
│   ├── metadata/                # JSON: {"final_loss": 3.02, "smooth": 3.06}
│   ├── model/                   # nnx.Param state → architecture weights
│   └── optimizer/               # optax AdamW state (m/v moments + step count + LR state)
├── config.json
├── README.md
└── code/                        # snapshots of flaxchat code (source of truth: github)
    ├── gpt.py
    ├── checkpoint.py
    ├── config.py
    ├── train_d12_chinchilla.py  # the exact training script (use --mlp yatnmn-softplus)
    └── load_model.py
```
## Loading
```bash
git clone https://github.com/mlnomadpy/flaxchat && cd flaxchat
pixi install   # or pip install with the deps in code/load_model.py
pip install 'nmn>=0.2.29'
huggingface-cli download mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M --local-dir ./hf_model
```
```python
# Inside your flaxchat clone; adapt code/load_model.py:
from flax import nnx
import jax.numpy as jnp
from nmn.nnx.layers import YatNMN
from flaxchat.gpt import GPT, GPTConfig, Block
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch MLP to YatNMN-Softplus with scalar_bias (matches training)
_orig_block = Block.__init__

def _patched(self, config, layer_idx, *, rngs, use_remat=False):
    _orig_block(self, config, layer_idx, rngs=rngs, use_remat=use_remat)

    class YatFFN(nnx.Module):
        def __init__(self, n, ff, *, rngs):
            self.c_fc = YatNMN(
                n, ff,
                use_bias=True,
                scalar_bias=True,
                softplus_bias=True,
                learnable_epsilon=True,
                epsilon=1e-3,
                rngs=rngs,
            )
            self.c_proj = nnx.Linear(ff, n, use_bias=False, rngs=rngs)

        def __call__(self, x):
            return self.c_proj(self.c_fc(x))

    self.mlp = YatFFN(config.n_embd, 4 * config.n_embd, rngs=rngs)

Block.__init__ = _patched

# Zero-init c_proj (matches training init)
_orig_init = GPT._init_weights

def _zero_proj(self):
    _orig_init(self)
    for b in self.blocks:
        b.mlp.c_proj.kernel[...] = jnp.zeros_like(b.mlp.c_proj.kernel[...])

GPT._init_weights = _zero_proj

config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))
meta = restore_model_from_checkpoint(model, "./hf_model/19920")
print(f"Loaded; final loss: {meta['final_loss']:.4f}")  # expect 3.0244
```
## Related models
- `mlnomad/yatnmn-softplus-d12-chinchilla-261M`: same arch with per-neuron `(ff,)` bias, loss 2.98 (beats this by 0.08)
- `mlnomad/yatnmn-softplus-d12-fineweb-edu-261M`: YatNMN-Softplus continued on fineweb-edu, loss 2.84
- `mlnomad/gelu-d12-chinchilla-261M`: GELU baseline at identical compute, loss 3.11
## License
Apache 2.0.