YatNMN-Softplus + scalar_bias d=12 Chinchilla (261M)

A 261M-parameter nanochat-architecture GPT whose MLP uses the YatNMN layer from nmn>=0.2.29 with softplus_bias=True, scalar_bias=True, and learnable_epsilon=True. Trained on English C4 to the Chinchilla-optimal token budget (20Γ— params β‰ˆ 5.22 B tokens).

The scalar_bias=True variant uses a single shared bias across all neurons (shape (1,)) β€” this matches the theoretical formulation of YatNMN more faithfully than the original per-neuron (ff,) bias. This is the ablation testing whether the per-neuron bias in the original YatNMN-Softplus run was load-bearing.

Final result

| Metric | Value |
|---|---|
| Final loss | 3.0244 |
| Smooth final loss | 3.0641 |
| Tokens | 5.22 B |
| Steps | 19,920 |
| Wall time | 2.2 h |
| Throughput | ~653 K tok/s |
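The token budget and step count follow directly from the parameter count and batch geometry. A quick sanity check in plain Python, using only numbers reported on this card:

```python
params = 261_133_214            # reported parameter count
tokens_per_step = 256 * 1024    # global batch 256 sequences x 1024 tokens each

budget = 20 * params            # Chinchilla-optimal ~20 tokens per parameter
print(f"{budget / 1e9:.2f} B tokens")       # -> 5.22 B tokens
print(budget // tokens_per_step)            # ~19.9 K steps
```

The computed step count lands within a few steps of the reported 19,920; the small gap is presumably rounding in the training script.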

Comparison vs other MLPs (same arch, same data, same compute)

| MLP variant | Final loss | Ξ” vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron bias) | 2.98 | βˆ’0.13 |
| YatNMN-Softplus + scalar_bias (this model) | 3.06 | βˆ’0.05 |
| GELU | 3.11 | baseline |
| YatNMN (plain, no softplus) | 3.13 | +0.02 |
| YatNMN-Softplus + fineweb-edu continuation* | 2.84 | βˆ’0.27 |

*Continued pretraining on additional tokens; not compute-matched with the other rows.

Takeaways:

  • scalar_bias still beats GELU by 0.05 nats at identical compute, so the YatNMN family remains a win over stock MLPs.
  • The per-neuron bias variant (2.98) beats scalar_bias (3.06) by 0.08 nats β€” so the extra (ff,) parameters are doing useful work in this regime, even though they deviate from the clean theoretical formulation.
  • The gap may narrow with more compute / seeds β€” this is seed=0 only. Seeds 1 and 2 of the scalar_bias 3-seed sweep are still running.

Architecture

Full nanochat stack, with YatNMN-Softplus + scalar_bias in place of the ReLUΒ² MLP; everything else is identical:

  • RoPE (base 100K), no learned positional embeddings
  • MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable but all d=12 models use full MHA)
  • QK norm with 1.2Γ— scaling
  • Sliding window attention with "SSSL" pattern
  • Tied embeddings (wte ↔ lm_head)
  • Parameterless RMSNorm post-embedding and per block
  • Value embeddings (ResFormer-style, alternating layers)
  • Per-layer learnable residual scalars
  • Smear (learnable bigram gate on first 24 dims of token embedding)
  • Backout (subtract mid-layer residual)
  • Logit soft-capping via tanh(x/15)Β·15
  • No biases in attention (Q/K/V/proj) or MLP output projection
  • MLP FFN: YatNMN(n β†’ 4n) β†’ Linear(4n β†’ n)
    • Replaces the nanochat default relu(x)**2 with the YatNMN formula y = Ξ± Β· (xΒ·W + b)Β² / (||x βˆ’ W||Β² + Ξ΅)
    • scalar_bias=True β†’ b is a shared (1,) parameter across all 4n neurons
    • softplus_bias=True β†’ b = softplus(raw_b) to keep b > 0
    • learnable_epsilon=True β†’ Ξ΅ = softplus(raw_eps) (init target 1e-3)
    • alpha is learnable per neuron (standard YatNMN)
    • Applied via the --mlp yatnmn-softplus flag in scripts/train_d12_chinchilla.py (the scalar_bias branch is always enabled for yatnmn-softplus)

Training

| Field | Value |
|---|---|
| Architecture | Nanochat-style GPT with YatNMN-Softplus + scalar_bias MLP |
| Parameters | 261,133,214 |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL window |
| Data | allenai/c4 (English split, streamed) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, Ξ²=(0.9, 0.999), wd=0.01, global-norm grad clip 1.0 |
| LR schedule | warmup-cosine-decay, peak 0.03, warmup 500, end 5% of peak |
| Batch | 32/device Γ— 8 devices = 256 global (262 K tokens/step) |
| Seq length | 1024 |
| Tokens | Chinchilla 20Γ— params β‰ˆ 5.22 B (19,920 steps) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Seed | 0 |
| nmn | nmn>=0.2.29 (includes scalar_bias flag) |
| Wandb | irf-sic/flaxchat β€” yatnmn-softplus-d12-chinchilla-lr0.03-seed0 |
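The LR schedule is fully specified by the peak, warmup, total steps, and the 5% floor. A dependency-free sketch, assuming the standard optax-style cosine shape (the exact decay curve used in training is an assumption):

```python
import math

def lr_at(step, peak=0.03, warmup=500, total=19_920, end_frac=0.05):
    """Linear warmup to `peak`, then cosine decay to `end_frac * peak`."""
    if step < warmup:
        return peak * step / warmup
    t = (step - warmup) / (total - warmup)  # 0 -> 1 over the decay phase
    end = end_frac * peak
    return end + 0.5 * (peak - end) * (1.0 + math.cos(math.pi * t))
```

For example, `lr_at(500)` returns the 0.03 peak at the end of warmup and `lr_at(19_920)` returns the 0.0015 floor (5% of peak) at the final step.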

The optimizer state is preserved, so the checkpoint is resumable for further training (see mlnomad/yatnmn-softplus-d12-fineweb-edu-261M for an example of continued pretraining on fineweb-edu).

Contents

```
.
β”œβ”€β”€ 19920/                          # final Orbax checkpoint
β”‚   β”œβ”€β”€ _CHECKPOINT_METADATA
β”‚   β”œβ”€β”€ metadata/                   # JSON: {"final_loss": 3.02, "smooth": 3.06}
β”‚   β”œβ”€β”€ model/                      # nnx.Param state β€” architecture weights
β”‚   └── optimizer/                  # optax AdamW state (m/v moments + step count + LR state)
β”œβ”€β”€ config.json
β”œβ”€β”€ README.md
└── code/                           # snapshots of flaxchat code (source of truth: github)
    β”œβ”€β”€ gpt.py
    β”œβ”€β”€ checkpoint.py
    β”œβ”€β”€ config.py
    β”œβ”€β”€ train_d12_chinchilla.py     # the exact training script (use --mlp yatnmn-softplus)
    └── load_model.py
```

Loading

```bash
git clone https://github.com/mlnomadpy/flaxchat && cd flaxchat
pixi install   # or pip install with the deps in code/load_model.py
pip install 'nmn>=0.2.29'
huggingface-cli download mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M --local-dir ./hf_model
```
```python
# Inside your flaxchat clone β€” adapt code/load_model.py:
from flax import nnx
import jax.numpy as jnp
from nmn.nnx.layers import YatNMN
from flaxchat.gpt import GPT, GPTConfig, Block
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch MLP to YatNMN-Softplus with scalar_bias (matches training)
_orig_block = Block.__init__
def _patched(self, config, layer_idx, *, rngs, use_remat=False):
    _orig_block(self, config, layer_idx, rngs=rngs, use_remat=use_remat)
    class YatFFN(nnx.Module):
        def __init__(self, n, ff, *, rngs):
            self.c_fc = YatNMN(
                n, ff,
                use_bias=True,
                scalar_bias=True,
                softplus_bias=True,
                learnable_epsilon=True,
                epsilon=1e-3,
                rngs=rngs,
            )
            self.c_proj = nnx.Linear(ff, n, use_bias=False, rngs=rngs)
        def __call__(self, x):
            return self.c_proj(self.c_fc(x))
    self.mlp = YatFFN(config.n_embd, 4 * config.n_embd, rngs=rngs)
Block.__init__ = _patched

# Zero-init c_proj (matches training init); mutate the nnx.Param via .value
_orig_init = GPT._init_weights
def _zero_proj(self):
    _orig_init(self)
    for b in self.blocks:
        b.mlp.c_proj.kernel.value = jnp.zeros_like(b.mlp.c_proj.kernel.value)
GPT._init_weights = _zero_proj

config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))
meta = restore_model_from_checkpoint(model, "./hf_model/19920")
print(f"Loaded β€” final loss: {meta['final_loss']:.4f}")  # expect 3.0244
```

Related models

  β€’ mlnomad/yatnmn-softplus-d12-fineweb-edu-261M β€” YatNMN-Softplus continued pretraining on fineweb-edu (the 2.84-loss row in the comparison table)

License

Apache 2.0.

