# NeoLLM / modeling_neollm.py
# Hugging Face Hub export residue ("Update modeling_neollm.py", commit
# 94c09c1 verified, by KitsuVp) — kept as a comment so the file parses.
#!/usr/bin/env python3
"""
NeoLLM model with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
full attention augmented with optional Momentum, MEA, and LUCID operators,
Gated Attention (Qiu et al., 2025) combined with Affine-Scaled Attention
(Bae et al., 2026), an optional Leviathan continuous token embedding
generator, an optional Leviathan-JTok-M token-indexed modulation module,
optional Spelling Bee Embeddings (Rabe et al., 2026), and optional Context
Re-Positioning (Li et al., 2026).
Attention stack (orthogonal, all active simultaneously when enabled):
1. Gated Attention (use_gated_attention implicit via q_proj gate chunk):
   applies a head-specific elementwise sigmoid gate to the concatenated
   SDPA output before o_proj (G1 position, Qiu et al. 2025 §2.2).
   Introduces non-linearity between W_V and W_O, sparse input-dependent
   gating, and eliminates attention sink.
2. Affine-Scaled Attention (use_affine_scaled_attention):
   modulates softmax attention weights directly as
   [α(X)·softmax(QK^T/√dk) + β(X)] V
   relaxing the unit-sum constraint of softmax. α is per-head, per-query,
   input-dependent and bounded in [0,1] via linear_clipping. β is a
   moving-average bias that prevents collapse. Reduces first-token bias,
   increases attention entropy, and is complementary to Gated Attention
   (Bae et al. 2026, Table 2: Affine-Scaled + Gated > either alone).
Flash/SDPA path (exact, no approximation):
   Expanding [α·softmax(QKᵀ) + β]V distributively yields two terms:
       α · flash_attn(Q,K,V)   — Flash computes this directly
     + β · Σ_{j≤i} V_j         — causal prefix-sum of V (V.cumsum)
   The V_cumsum tensor [B,H,S,d_head] is the only extra memory vs a
   standard flash call. Semantically identical to the eager path;
   per-weight softmax tensors (attn_weights_pre/post_affine) are
   unavailable in flash mode and remain None in AnalysisState.
Eager path: exact with full weight access, used for interpretability.
References:
FANformer: "FANformer: Improving Large Language Models Through Effective
Periodicity Modeling"
SeeDNorm: "Self-Rescaled Dynamic Normalization"
Learnable Multipliers: "Learnable Multipliers: Freeing the Scale of Language
Model Matrix Layers"
Gated Attention: Qiu et al. (2025). "Gated Attention for Large Language
Models: Non-linearity, Sparsity, and Attention-Sink-Free."
arXiv:2505.06708.
Affine-Scaled Attention: Bae et al. (2026). "Affine-Scaled Attention:
Towards Flexible and Stable Transformer Attention." arXiv:2602.23057.
Leviathan Generator: Batley & Saha (2026). "A Separable Architecture for
Continuous Token Representation in Language Models." arXiv:2601.22040.
KHRONOS: Batley & Saha (2025). "KHRONOS: a Kernel-Based Neural Architecture
for Rapid, Resource-Efficient Scientific Computation." arXiv:2505.13315.
JTok / JTok-M: Yang et al. (2026). "JTok: On Token Embedding as Another
Axis of Scaling Law via Joint Token Self-Modulation." arXiv:2602.00800.
Spelling Bee Embeddings: Rabe, Clymo & Dong (2026). "Spelling Bee
Embeddings for Language Modeling." arXiv:2601.18030.
Context Re-Positioning: Li, Zhao, Cai & Sproat (2026). "REPO: Language
Models with Context Re-Positioning." arXiv:2512.14391.
"""
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Union, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from cut_cross_entropy import linear_cross_entropy
from torch.utils.checkpoint import checkpoint
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging
from .configuration_neollm import NeoLLMConfig
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
torch._dynamo.config.capture_scalar_outputs = True
logger = logging.get_logger(__name__)
# ==================== NATIVE ANALYSIS STATE ====================
# All dataclasses below define the analysis container hierarchy.
# This infrastructure is ONLY active when:
#   1. model.enable_analysis() has been called, AND
#   2. the model is in eval mode (model.eval() / not model.training)
#
# During training, analysis_state is always None — zero overhead,
# zero interference with gradients, zero change to the training flow.
#
# Access after any inference call:
#   state = model.last_analysis   # AnalysisState | None
#
# Every tensor stored here is detached from the computation graph.
# Fields are None when the corresponding config flag is inactive,
# so adding or removing a component only requires updating its own
# dataclass — all other analysis paths remain untouched.
@dataclass
class FANAnalysis:
    """
    Decomposed output of a FANLayer call.

    FANLayer'(X) = [cos(Wp·X) ‖ sin(Wp·X) ‖ (Wp̄·X + Bp̄)]

    One detached tensor per concatenated branch; fields stay None until an
    analysis-enabled forward populates them.
    """
    cosine_component: Optional[torch.Tensor] = None  # cos branch [*, p_output_dim]
    sine_component: Optional[torch.Tensor] = None    # sin branch [*, p_output_dim]
    linear_component: Optional[torch.Tensor] = None  # affine branch [*, g_output_dim]
@dataclass
class SeeDNormAnalysis:
    """
    Internals of a SeeDNorm forward pass.

    SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
    """
    # NOTE(review): the docstring writes the gate as σ(·) while the comment
    # below says tanh — confirm against the SeeDNorm implementation.
    rescale_factor: Optional[torch.Tensor] = None  # tanh(Σ x·β) — the dynamic gate
    dynamic_scale: Optional[torch.Tensor] = None   # rescale_factor·α + γ
    x_normalized: Optional[torch.Tensor] = None    # x / RMS(x) before dynamic_scale
    output: Optional[torch.Tensor] = None          # final output
@dataclass
class GPASAnalysis:
    """
    Internals of a GPAS forward pass.

    GPAS(x) = x - silu(α)·x_detached
    """
    silu_alpha: Optional[torch.Tensor] = None            # F.silu(self.alpha) — scalar
    subtracted_component: Optional[torch.Tensor] = None  # silu(α) · x.detach()
@dataclass
class PolyNormAnalysis:
    """
    Internals of a PolyNorm forward pass (used as MLP act_fn).

    Branches: x1 (linear), x2 (quadratic), x3 (cubic), each normalized.
    x2 and x3 are partially orthogonalized against x1 via learned α.
    """
    x1: Optional[torch.Tensor] = None                 # RMS-normalized linear branch
    x2_pre_exclusive: Optional[torch.Tensor] = None   # quadratic branch before exclusivity
    x3_pre_exclusive: Optional[torch.Tensor] = None   # cubic branch before exclusivity
    x2_post_exclusive: Optional[torch.Tensor] = None  # quadratic after exclusivity + renorm
    x3_post_exclusive: Optional[torch.Tensor] = None  # cubic after exclusivity + renorm
    alpha2: Optional[torch.Tensor] = None             # sigmoid(exclusive_logits[0]) for x2
    alpha3: Optional[torch.Tensor] = None             # sigmoid(exclusive_logits[1]) for x3
    weights: Optional[torch.Tensor] = None            # self.weight [3] — branch mixing
    bias: Optional[torch.Tensor] = None               # self.bias scalar
    output: Optional[torch.Tensor] = None             # final PolyNorm output
@dataclass
class HadamardAnalysis:
    """
    Internals of a HadamardOProj forward pass.

    Only populated when use_hadamard_o_proj=True.
    Reference: Aggarwal & Kumar (2026). arXiv:2603.08343.

    post_fwht: WHT output before α scaling [..., D] — useful to verify
        that the transform is truly norm-preserving (κ=1 sanity check).
    alpha_snapshot: detached copy of the learnable α vector [D] — tracks how
        per-channel scaling evolves during training analysis.
    """
    post_fwht: Optional[torch.Tensor] = None       # FWHT(x)/√d before α [B,S,D]
    alpha_snapshot: Optional[torch.Tensor] = None  # self.alpha [D] — learned scale
@dataclass
class REPOAnalysis:
    """
    Internals of a REPOModule forward pass.

    Only populated when use_repo=True and layer_idx >= repo_start_layer.
    Reference: Li, Zhao, Cai & Sproat (2026). arXiv:2512.14391.

    positions: continuous per-head positions z [B, H, S] produced by f_ϕ.
        Captures what position pattern the model learned for each head:
        constant (NoPE-like), monotonic (RoPE-like), or hybrid.
        Use this field with the attention interpretability toolkit to
        reproduce the position-pattern analysis of Li et al. (2026) §5.2.
    r_repr: intermediate position representation r [B, S, d_p] — output of
        the SwiGLU sub-layer before the per-head linear W_z.
        Shared across heads within the layer.
    """
    positions: Optional[torch.Tensor] = None  # z [B, H, S] — predicted positions
    r_repr: Optional[torch.Tensor] = None     # r [B, S, d_p] — shared repr
@dataclass
class AttentionAnalysis:
    """
    Full attention internals for one NeoLLMAttention forward pass.

    Fields present unconditionally capture the always-active path.
    Fields guarded by config flags are None when that flag is off.
    """
    # ── FANLayer (always active) ──────────────────────────────────────
    fan: Optional[FANAnalysis] = None  # FAN components for attention
    # ── Q/K/V projection (always active) ─────────────────────────────
    q_raw: Optional[torch.Tensor] = None          # Q before q_norm [B,S,H,d]
    gate_raw: Optional[torch.Tensor] = None       # gate chunk before sigmoid [B,S,H*d]
    gate_sigmoid: Optional[torch.Tensor] = None   # sigmoid(gate) — Gated Attention weight
    q_post_norm: Optional[torch.Tensor] = None    # Q after SeeDNorm q_norm [B,H,S,d]
    k_post_norm: Optional[torch.Tensor] = None    # K after SeeDNorm k_norm [B,H,S,d]
    v_raw: Optional[torch.Tensor] = None          # V raw (pre MEA/LUCID) [B,H,S,d]
    # ── RoPE / REPO (always active) ──────────────────────────────────
    # When use_repo=False or layer_idx < repo_start_layer: standard integer
    # RoPE — q_post_rope/k_post_rope are Q/K after apply_rotary_pos_emb.
    # When use_repo=True and layer_idx >= repo_start_layer: REPO path —
    # q_post_rope/k_post_rope are Q/K after _apply_repo_rope with
    # continuous per-head positions z from REPOModule. The positions
    # themselves and the intermediate r_repr are in the .repo sub-object.
    q_post_rope: Optional[torch.Tensor] = None    # Q after RoPE/REPO [B,H,S,d]
    k_post_rope: Optional[torch.Tensor] = None    # K after RoPE/REPO [B,H,S,d]
    # ── Momentum (conditional on use_momentum_attention) ──────────────
    q_momentum_delta: Optional[torch.Tensor] = None  # causal_first_difference(Q)
    k_momentum_delta: Optional[torch.Tensor] = None  # causal_first_difference(K)
    q_post_momentum: Optional[torch.Tensor] = None   # Q + γ·Δ
    k_post_momentum: Optional[torch.Tensor] = None   # K + γ·Δ
    # ── MEA head mixing (conditional on use_mea_attention) ────────────
    mea_key_mix_matrix: Optional[torch.Tensor] = None    # mea_key_mix [H_comp,H_kv]
    mea_value_mix_matrix: Optional[torch.Tensor] = None  # mea_value_mix [H_comp,H_kv]
    k_post_mea: Optional[torch.Tensor] = None            # K after head mixing
    v_post_mea: Optional[torch.Tensor] = None            # V after head mixing
    # ── LUCID preconditioner (conditional on use_lucid_attention) ─────
    lucid_preconditioner: Optional[torch.Tensor] = None  # lower-triangular prec matrix
    v_post_lucid: Optional[torch.Tensor] = None          # V after triangular solve
    # ── Affine-Scaled Attention (conditional on use_affine_scaled_attention) ──
    alpha_per_head: Optional[torch.Tensor] = None    # α [B,H,S,1] in [0,1]
    beta_per_head: Optional[torch.Tensor] = None     # β [B,H,S,1] moving-avg bias
    alpha_moving_avg: Optional[torch.Tensor] = None  # alpha_ma EMA snapshot
    attn_weights_pre_affine: Optional[torch.Tensor] = None   # softmax weights before α,β
    attn_weights_post_affine: Optional[torch.Tensor] = None  # α·softmax + β
    # ── Standard attention weights (eager non-affine path) ────────────
    attn_weights: Optional[torch.Tensor] = None  # softmax weights (None for flash/sdpa)
    # ── Post-SDPA (always active) ─────────────────────────────────────
    attn_output_raw: Optional[torch.Tensor] = None  # SDPA output [B,S,H,d]
    # ── MEA output norm (conditional on use_mea_attention) ────────────
    attn_output_post_mea_norm: Optional[torch.Tensor] = None  # after MEAHeadSeeDNorm
    # ── XSA (conditional on use_xsa) ─────────────────────────────────
    xsa_self_position_component: Optional[torch.Tensor] = None  # proj subtracted
    attn_output_post_xsa: Optional[torch.Tensor] = None         # after XSA removal
    # ── Directional Routing (conditional on use_directional_routing) ──
    direction_vecs_normalized: Optional[torch.Tensor] = None  # unit-norm d [H,K,d]
    dr_router_logits: Optional[torch.Tensor] = None           # MLP router output [B,H*K]
    dr_routing_weights: Optional[torch.Tensor] = None         # sigmoid(T·logits) [B,H,K]
    dr_projection: Optional[torch.Tensor] = None              # (o·d) scalars [B,S,H,K]
    dr_suppression: Optional[torch.Tensor] = None             # Σ r·proj·d [B,S,H,d]
    attn_output_post_routing: Optional[torch.Tensor] = None   # after DR removal
    # ── Gate and o_proj (always active) ───────────────────────────────
    attn_output_pre_gate: Optional[torch.Tensor] = None  # pre gate multiply [B,S,H,d]
    attn_output_final: Optional[torch.Tensor] = None     # after o_proj [B,S,D]
    # ── HadamardOProj internals (conditional on use_hadamard_o_proj) ──
    hadamard: Optional["HadamardAnalysis"] = None  # None when dense o_proj active
    # ── REPO position prediction (conditional on use_repo) ────────────
    repo: Optional["REPOAnalysis"] = None  # None when layer_idx < repo_start_layer
@dataclass
class MLPAnalysis:
    """
    Internals of a NeoLLMMLP forward pass.

    SwiGLU-like: down_proj(dropout(PolyNorm(gate_proj(fan)) · up_proj(fan)))
    When use_versatile_ffn=True, the versatile sub-object carries all
    VersatileFFN internals and fan/gate/up/polynorm fields remain None.
    """
    fan: Optional[FANAnalysis] = None                # FAN components for MLP
    gate_proj_output: Optional[torch.Tensor] = None  # gate_proj(x_fan) [B,S,I]
    up_proj_output: Optional[torch.Tensor] = None    # up_proj(x_fan) [B,S,I]
    polynorm: Optional[PolyNormAnalysis] = None      # PolyNorm of gate branch
    act_times_up: Optional[torch.Tensor] = None      # PolyNorm(gate)·up [B,S,I]
    output: Optional[torch.Tensor] = None            # after down_proj [B,S,D]
    # ── VersatileFFN (conditional on use_versatile_ffn) ──────────────────
    versatile: Optional["VersatileFFNAnalysis"] = None  # None when standard MLP
@dataclass
class VersatileFFNAnalysis:
    """
    Internals of a VersatileFFN forward pass.

    Only populated when use_versatile_ffn=True.
    Reference: Nie et al. (2026). arXiv:2512.14531.

    depth_probs: Gumbel-Softmax distribution p [B, S, max_depth] — only in
        training (hard=True STE). None during inference.
    expected_loops: E[L] per token [B, S] — differentiable proxy for difficulty.
        During inference computed from discrete argmax+1.
    moe_weight: λ = (L_max - E[L]) / L_max [B, S] — fusion gate scalar.
        Near 1 → width dominates (easy token).
        Near 0 → depth dominates (hard token).
    loop_choice: argmax(depth_logits) [B, S] — discrete depth selected at
        inference. None during training.
    x_depth: Output of the depth-versatile path [B, S, D].
    x_width: Output of the width-versatile MoE path [B, S, D].
    """
    depth_probs: Optional[torch.Tensor] = None     # [B, S, max_depth] training only
    expected_loops: Optional[torch.Tensor] = None  # [B, S]
    moe_weight: Optional[torch.Tensor] = None      # [B, S]
    loop_choice: Optional[torch.Tensor] = None     # [B, S] inference only
    x_depth: Optional[torch.Tensor] = None         # [B, S, D]
    x_width: Optional[torch.Tensor] = None         # [B, S, D]
@dataclass
class JTokMAnalysis:
    """
    Internals of a LeviathanJTokM forward pass for one decoder layer.

    Shape letters follow the per-field comments below.
    """
    surfaces: Optional[torch.Tensor] = None          # all n_e surface outputs [N,n_e,D]
    router_logits: Optional[torch.Tensor] = None     # pre-TopK logits [N,n_e]
    topk_indices: Optional[torch.Tensor] = None      # selected surface indices [N,K]
    routing_weights: Optional[torch.Tensor] = None   # normalized sigmoid weights [N,K]
    mixed_pre_norm: Optional[torch.Tensor] = None    # weighted sum before norm [N,D]
    mixed_normalized: Optional[torch.Tensor] = None  # direction-normalized mixed [N,D]
    delta_r: Optional[torch.Tensor] = None           # final injection (scaled) [N,D]
    p_sum: Optional[torch.Tensor] = None             # routing probability sum [n_e]
    f_sum: Optional[torch.Tensor] = None             # load fraction sum [n_e]
    lns_scale: Optional[float] = None                # 1/√(2ℓ) scaling factor
@dataclass
class AttnResAnalysis:
    """
    Depth-wise softmax attention weights from AttnRes sublayer calls.

    Only populated when use_attn_res=True.
    """
    weights_pre_attn: Optional[torch.Tensor] = None  # softmax over sources [N+1,B,S]
    weights_pre_mlp: Optional[torch.Tensor] = None   # softmax over sources [N+1,B,S]
    sources_count: Optional[int] = None              # number of sources including partial
@dataclass
class LAuReLAnalysis:
    """
    Internals of the LAuReL residual augmentation for one decoder layer.

    Only populated when use_laurel=True.
    Two slots are provided — one per sublayer (attention, MLP) — reflecting
    that LAuReL is applied independently at each residual junction.

    LAuReL-RW (use_laurel_rw=True):
        alpha_attn / alpha_mlp — effective α after softmax [scalar]
        beta_attn / beta_mlp — effective β after softmax [scalar]
        The applied residual is α·f(x) + β·x (or β·(BAx+x) with LR).
    LAuReL-LR (use_laurel_lr=True):
        lr_input_norm_attn / lr_input_norm_mlp — ‖x‖₂ entering the low-rank
            path, useful for monitoring magnitude growth or collapse [scalar].
        lr_delta_norm_attn / lr_delta_norm_mlp — ‖BAx‖₂ after the low-rank
            correction; relative to lr_input_norm reveals how much the LR
            branch contributes [scalar].
    When both variants are active (LAuReL-RW+LR, paper eq. 5):
        x_{i+1} = α·f(x) + β·(BAx + x)
    All six fields are populated simultaneously.
    """
    # LAuReL-RW scalars (after softmax normalisation)
    alpha_attn: Optional[float] = None  # α for attention sublayer
    beta_attn: Optional[float] = None   # β for attention sublayer
    alpha_mlp: Optional[float] = None   # α for MLP sublayer
    beta_mlp: Optional[float] = None    # β for MLP sublayer
    # LAuReL-LR norms
    lr_input_norm_attn: Optional[float] = None  # ‖x‖₂ entering LR path, attn sublayer
    lr_delta_norm_attn: Optional[float] = None  # ‖BAx‖₂ from LR path, attn sublayer
    lr_input_norm_mlp: Optional[float] = None   # ‖x‖₂ entering LR path, MLP sublayer
    lr_delta_norm_mlp: Optional[float] = None   # ‖BAx‖₂ from LR path, MLP sublayer
@dataclass
class LayerAnalysis:
    """
    Complete analysis snapshot for one NeoLLMDecoderLayer forward pass.

    Sub-objects are None when the corresponding config flag is inactive.
    """
    layer_idx: int = 0
    # Hidden state boundaries
    hidden_states_input: Optional[torch.Tensor] = None   # entering this layer
    hidden_states_output: Optional[torch.Tensor] = None  # leaving this layer
    h_tilde: Optional[torch.Tensor] = None               # post-attn residual, pre-MLP
    # Attention sublayer
    seednorm_pre_attn: Optional[SeeDNormAnalysis] = None  # input_layernorm
    lns_attn_output: Optional[torch.Tensor] = None        # after LNS(1/√ℓ)
    attention: Optional[AttentionAnalysis] = None         # full attention analysis
    attn_contribution: Optional[torch.Tensor] = None      # attn output before residual
    gpas_attn: Optional[GPASAnalysis] = None              # GPAS after attn residual
    # MLP sublayer
    seednorm_post_attn: Optional[SeeDNormAnalysis] = None  # post_attention_layernorm
    lns_mlp_output: Optional[torch.Tensor] = None          # after LNS(1/√ℓ)
    mlp: Optional[MLPAnalysis] = None                      # full MLP analysis
    mlp_contribution: Optional[torch.Tensor] = None        # MLP output before residual
    gpas_mlp: Optional[GPASAnalysis] = None                # GPAS after MLP residual
    # Optional components (None when inactive)
    jtokm: Optional[JTokMAnalysis] = None       # if use_jtokm
    attn_res: Optional[AttnResAnalysis] = None  # if use_attn_res
    laurel: Optional[LAuReLAnalysis] = None     # if use_laurel
@dataclass
class GeneratorAnalysis:
    """
    Internals of a LeviathanGenerator forward pass.

    Only populated when use_token_generator=True.
    """
    z_raw: Optional[torch.Tensor] = None               # [N,d_seed] codebook sum
    z_tilde: Optional[torch.Tensor] = None             # [N,d_seed] JTok-M path output
    B_vals: Optional[torch.Tensor] = None              # [N,d_seed,n_knots] B-spline basis
    z_all_pre_norm: Optional[torch.Tensor] = None      # [N,M,d_seed] per-head pre-sigmoid
    z_all_post_sigmoid: Optional[torch.Tensor] = None  # [N,M,d_seed] per-head post-sigmoid
    modes_all: Optional[torch.Tensor] = None           # [N,M,krank] KHRONOS tensor product
    embeddings: Optional[torch.Tensor] = None          # [N,hidden_size] final output
@dataclass
class AnalysisState:
    """
    Root analysis container for one NeoLLMForCausalLM forward pass.

    Populated when model.enable_analysis() is active AND model.training=False.
    All tensors are detached from the computation graph.

    Structure:
        .input_ids — original input token ids
        .embeddings — token embeddings entering the decoder stack
        .final_hidden_states — post-norm output entering lm_head
        .generator — LeviathanGenerator internals (use_token_generator only)
        .layers[i] — per-layer LayerAnalysis, indexed by layer_idx
        .jtokm_aux_stats — raw per-layer load-balancing tuples (use_jtokm only)
        .attn_res_sources_final — AttnRes source list at end of forward (use_attn_res only)
        .logits — lm_head output (only when labels is None)

    Access pattern:
        model.eval()
        model.enable_analysis()
        _ = model(input_ids)
        state = model.last_analysis
        alpha = state.layers[3].attention.alpha_per_head

    REPO access (use_repo=True, layers >= repo_start_layer):
        # Per-head predicted positions [B, H, S] for layer i:
        z = state.layers[i].attention.repo.positions
        # Shared position representation r [B, S, d_p] for layer i:
        r = state.layers[i].attention.repo.r_repr
        # Q/K after REPO rotation (identical field name as standard RoPE path):
        q = state.layers[i].attention.q_post_rope  # [B, H, S, head_dim]
        k = state.layers[i].attention.k_post_rope  # [B, H, S, head_dim]
        # For layers below repo_start_layer: .repo is None, q/k_post_rope
        # contain standard integer-RoPE rotated Q/K — same field, same shape.
    """
    input_ids: Optional[torch.Tensor] = None
    embeddings: Optional[torch.Tensor] = None
    final_hidden_states: Optional[torch.Tensor] = None
    generator: Optional[GeneratorAnalysis] = None
    layers: Optional[List[LayerAnalysis]] = None
    jtokm_aux_stats: Optional[list] = None
    attn_res_sources_final: Optional[list] = None
    logits: Optional[torch.Tensor] = None
class ScalarMultiplier(nn.Module):
    """
    Scalar learnable multiplier: W_eff = s * W.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix
    Layers": a single trainable scalar lets the effective matrix norm
    ||W_eff|| = s * ||W|| adapt to the data, escaping the weight-decay/noise
    equilibrium that otherwise pins ||W|| proportional to sqrt(eta/lambda).
    """

    def __init__(self, initial_value: float = 1.0):
        super().__init__()
        # One trainable scalar, initialised to `initial_value`.
        self.multiplier = nn.Parameter(torch.tensor(initial_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast-multiply the whole input by the learned scalar.
        return x * self.multiplier
class VectorMultiplier(nn.Module):
    """
    Vector learnable multipliers: W_eff = diag(r) * W * diag(c).

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix
    Layers": frees individual row/column norms — not just the overall matrix
    norm — from the WD-noise equilibrium, enabling richer feature-scale
    diversity.
    """

    def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0):
        super().__init__()
        # `multiplier_type` is recorded for introspection only; the forward
        # computation is identical for row and column multipliers (the caller
        # decides whether to apply it before or after the linear map).
        self.multiplier_type = multiplier_type
        self.multiplier = nn.Parameter(torch.full((dim,), float(initial_value)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.multiplier
class LinearWithMultipliers(nn.Module):
    """
    Linear layer with optional row and/or column learnable multipliers.

    Implements y = (r ⊙ (W @ (c ⊙ x))) + b, where either multiplier acts as
    the identity when its flag is disabled (and then adds no parameters).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_row_multiplier: bool = False,
        use_column_multiplier: bool = False
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.use_row_multiplier = use_row_multiplier
        self.use_column_multiplier = use_column_multiplier
        # Sub-modules are only created when enabled, so a disabled
        # configuration is parameter-identical to a plain nn.Linear.
        if use_row_multiplier:
            self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row")
        if use_column_multiplier:
            self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Column scaling acts on input features, row scaling on outputs.
        out = self.column_multiplier(x) if self.use_column_multiplier else x
        out = self.linear(out)
        return self.row_multiplier(out) if self.use_row_multiplier else out
# ==================== LEVIATHAN CONTINUOUS TOKEN GENERATOR ====================
class LeviathanGenerator(nn.Module):
"""
Continuous token embedding generator for the Leviathan architecture.
Replaces E ∈ R^{V×D} with a separable generator G : {0,...,V-1} → R^D.
Three-stage pipeline: latent compositional indexing → B-spline basis
expansion → tensor-product aggregation (Batley & Saha, 2026, §3.1).
When ``return_internals=True`` the forward returns
``(embeddings, z_tilde, B_vals)`` for reuse by JTok-M surfaces in every
decoder layer, avoiding redundant B-spline evaluation.
B-spline basis uses the KHRONOS closed-form quadratic kernel (ckhronos.py)
which is fully compatible with torch.compile (no .item(), no Python loops
inside the kernel, all tensor shapes static).
Sign parity tracking matches KHRONOS KHRONOSLayer exactly.
Initialization: spline_coeff ~ normal(mean=1.0, std=0.1), matching
KHRONOS init_weights_prod so that phi ≈ 1.0 at init and the product of
d_seed factors starts near 1.0 instead of ~10^{-21}.
**Frequency-based codebook ordering (optional)**
By default, the base-k decomposition maps token indices directly to
codebook coordinates via arithmetic: token x → (x // b², x // b % b, x % b).
This assigns coordinates based on index position, which is arbitrary with
respect to linguistic meaning under BPE tokenisation.
When ``set_freq_order`` is called with a frequency-rank tensor, the
decomposition maps tokens through their frequency rank first:
token x → rank_freq[x] → (rank // b², rank // b % b, rank % b).
This makes tokens with similar corpus frequency share codebook entries,
introducing pre-existing statistical structure into the gradient of W_res
from step 0. Since token frequency correlates with distributional behaviour
(Zipfian distribution, syntactic category, semantic class), the gradient
    ∂L/∂W_res = Σ_x δ_x · z̃_x^T
has low-rank structure immediately exploitable by Conda's SVD projection,
analogous to how the dense embedding table E gradient has low-rank structure
from the language distribution. Without this ordering, the SVD finds only
noise until codebooks organise through training, delaying Conda's advantage.
If ``set_freq_order`` is never called, ``freq_order`` remains None and the
module behaves identically to the original implementation — the feature is
fully opt-in and backward compatible.
"""
def __init__(self, config: NeoLLMConfig):
    """
    Build the generator from config.

    Reads vocab_size, hidden_size and the generator_* hyper-parameters,
    then allocates (1) the shared codebooks, (2) the JTok-M shared path,
    and (3) the fused per-head generator parameters.
    """
    super().__init__()
    vocab_size = config.vocab_size
    hidden_size = config.hidden_size
    d_seed = config.generator_d_seed
    num_modes = config.generator_num_modes
    num_knots = config.generator_num_knots
    spline_degree = config.generator_spline_degree
    k = config.generator_k
    krank = getattr(config, "generator_krank", 64)
    # Radix for the base-b token decomposition: smallest b with b**k >= V,
    # so every token id has a unique k-digit base-b decomposition.
    b = math.ceil(vocab_size ** (1.0 / k))
    self.b = b
    self.k = k
    self.d_seed = d_seed
    self.num_modes = num_modes
    self.num_knots = num_knots
    self.spline_degree = spline_degree
    self.krank = krank
    self.hidden_size = hidden_size
    # ── Stage 1: shared codebook lookup ──────────────────────────────
    # Produces z [N, d_seed] — the raw seed before any per-head
    # preprocessing. This is the only shared computation across heads.
    # NOTE(review): codebooks / head_proj_weight / head_spline /
    # head_out_weight are allocated with torch.empty — presumably
    # initialised elsewhere (e.g. _init_weights); confirm before training.
    self.codebooks = nn.Parameter(torch.empty(k, b, d_seed))
    # Frequency-based codebook ordering (opt-in via set_freq_order).
    # Non-persistent: not saved to checkpoints, must be set at load time.
    self.register_buffer("freq_order", None, persistent=False)
    # Shared knot grid — fixed, not learned.
    # Used by both the generator heads and the JTok-M shared path.
    self.register_buffer(
        "knot_grid",
        torch.linspace(0.0, 1.0, num_knots),
        persistent=False,
    )
    # ── JTok-M shared path ────────────────────────────────────────────
    # seed_proj + seed_norm produce z_tilde [N, d_seed] used by JTok-M
    # surfaces in every decoder layer. This path is kept separate from
    # the per-head generator path so JTok-M is completely unaffected
    # by the generator architecture change.
    self.seed_proj = nn.Linear(d_seed, d_seed, bias=True)
    self.seed_norm = nn.LayerNorm(d_seed)
    # ── Per-head generator (fused, vectorized) ───────────────────────
    # Mathematically identical to 8 independent heads but fused into
    # single tensors so the entire per-head path executes in 6 kernels
    # instead of 8×5=40, and the maximum intermediate tensor appears
    # once instead of 8 times — eliminating the fragmentation that caused
    # OOM during backward.
    #
    # head_proj_weight [num_modes*d_seed, d_seed]:
    #   Replaces ModuleList of 8 Linear(d_seed, d_seed, bias=False).
    #   Forward: z @ W^T → [N, M*d_seed] → reshape [N, M, d_seed].
    #   Gradient is identical: each [d_seed, d_seed] block receives
    #   gradient only from its own head output.
    #
    # head_norm_weight [num_modes, d_seed], head_norm_bias [num_modes, d_seed]:
    #   Replace ModuleList of 8 LayerNorm(d_seed) with independent
    #   weight/bias per head. Manual LN formula over last dim preserves
    #   exact per-head normalization semantics.
    #
    # head_scale [num_modes, d_seed]: unchanged, already fused.
    #
    # head_spline [num_modes, d_seed, num_knots, krank]: unchanged, already fused.
    #
    # head_out_weight [num_modes, krank, hidden_size]:
    #   Replaces ModuleList of 8 Linear(krank, hidden_size, bias=False).
    #   Forward: einsum("nmk,mkd->nd", modes, W) — all heads projected
    #   and summed in a single kernel.
    self.head_proj_weight = nn.Parameter(
        torch.empty(num_modes * d_seed, d_seed)
    )
    self.head_norm_weight = nn.Parameter(torch.ones(num_modes, d_seed))
    self.head_norm_bias = nn.Parameter(torch.zeros(num_modes, d_seed))
    self.head_norm_eps = 1e-5
    # head_scale: [num_modes, d_seed], initialized to (num_knots - 1)
    self.head_scale = nn.Parameter(
        torch.full((num_modes, d_seed), float(num_knots - 1))
    )
    # head_spline: [num_modes, d_seed, num_knots, krank]
    self.head_spline = nn.Parameter(
        torch.empty(num_modes, d_seed, num_knots, krank)
    )
    self.head_out_weight = nn.Parameter(
        torch.empty(num_modes, krank, hidden_size)
    )
def set_freq_order(self, freq_order: torch.Tensor) -> None:
    """
    Register a frequency-rank mapping to structure codebook coordinates.

    Must be called after model instantiation and after any device transfer
    (.to(device), .cuda(), etc.) since the buffer is non-persistent and
    is not saved to checkpoints.

    Args:
        freq_order: Long tensor of shape ``(vocab_size,)`` where
            ``freq_order[x]`` is the frequency rank of token x in the
            training corpus (rank 0 = most frequent token). Typically
            computed as ``torch.argsort(token_counts, descending=True)``.

    Example::
        counts = compute_token_frequencies(tokenizer, dataset)  # [V]
        ranks = torch.argsort(counts, descending=True)          # [V]
        model.model.token_generator.set_freq_order(ranks)
    """
    expected = self.codebooks.shape[1] ** self.k  # b**k, decomposition range
    if freq_order.shape[0] != expected:
        # Fix: the original branch was a silent `pass` even though its
        # comment promised a soft warning. Surface the mismatch instead.
        # Still not a hard error: b**k >= vocab_size by construction and
        # config vocab_size may be padded, so a mismatch can be benign.
        logger.warning(
            "set_freq_order: freq_order has %d entries but the codebook "
            "decomposition covers %d indices; verify vocab_size.",
            freq_order.shape[0],
            expected,
        )
    self.freq_order = freq_order.long().to(self.codebooks.device)
def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor:
    """
    Deterministic base-b decomposition: i -> (i_0, ..., i_{k-1}).

    When ``freq_order`` is set, token indices are first remapped through
    their frequency rank, so tokens that share codebook entries are close
    in corpus frequency rather than arbitrary in BPE index space — this
    provides pre-existing low-rank gradient structure for Conda's SVD
    projection from step 0.

    Without ``freq_order``: x -> (x // b^{k-1}, ..., x % b)
    With    ``freq_order``: x -> freq_order[x] -> (rank // b^{k-1}, ..., rank % b)

    Args:
        token_ids: integer tensor of token ids, any shape.
    Returns:
        Long tensor of shape ``(*token_ids.shape, k)`` holding the k base-b
        digits, most significant first.
    """
    remaining = token_ids.long().clone()
    if self.freq_order is not None:
        remaining = self.freq_order[remaining]
    # Peel off least-significant digits, then reverse so index 0 holds the
    # most-significant coordinate — same layout as the original loop.
    digits = []
    for _ in range(self.k):
        digits.append(remaining % self.b)
        remaining = remaining // self.b
    return torch.stack(digits[::-1], dim=-1)
def _bspline_basis(self, x_flat: torch.Tensor) -> torch.Tensor:
"""
KHRONOS quadratic B-spline basis with fixed scalar scale.
Used exclusively by the JTok-M shared path (z_tilde β†’ B_vals).
JTok-M surfaces have their own spline_coeff and call _modes_from_basis
with these B_vals. This path is unchanged from the original design.
Args:
x_flat: [N, d_seed], values in [0, 1].
Returns:
[N, d_seed, num_knots] float32.
"""
scale = float(self.num_knots - 1)
x32 = x_flat.float()
x_e = x32.unsqueeze(-1)
grid = self.knot_grid.float().view(1, 1, -1)
d = (x_e - grid).abs() * scale
return torch.where(
d < 0.5,
0.75 - d ** 2,
torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
) # [N, d_seed, num_knots] float32
def _bspline_basis_all_heads(
self,
x_all: torch.Tensor,
) -> torch.Tensor:
"""
Vectorized KHRONOS quadratic B-spline basis for all heads at once.
Mathematically identical to calling _bspline_basis_head 8 times in a
loop, but materializes the full [N, M, d_seed, n_knots] tensor in a
single kernel instead of 8 sequential [N, d_seed, n_knots] tensors.
Args:
x_all: [N, M, d_seed], values in [0, 1], all heads stacked.
Returns:
[N, M, d_seed, n_knots] float32.
NOTE: Este mΓ©todo se mantiene para compatibilidad con JTok-M y anΓ‘lisis.
El forward del generator ya NO lo usa β€” usa _compute_head en su lugar.
"""
x32 = x_all.float()
x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1]
grid = self.knot_grid.float().view(1, 1, 1, -1) # [1, 1, 1, n_knots]
# head_scale [M, d_seed] β†’ [1, M, d_seed, 1]
sc = self.head_scale.float().unsqueeze(0).unsqueeze(-1)
d = (x_e - grid).abs() * sc # [N, M, d_seed, n_knots]
return torch.where(
d < 0.5,
0.75 - d ** 2,
torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
) # [N, M, d_seed, n_knots] float32
    def _compute_head(
        self,
        z: torch.Tensor,
        m: int,
    ) -> torch.Tensor:
        """
        Full forward pass for generator head ``m``.

        Replaces the joint [N, M, d_seed, n_knots] materialization of the
        vectorized path. Each call materializes only [N, d_seed, n_knots]
        (one head), reducing peak memory from O(M*d_seed*n_knots) to
        O(d_seed*n_knots) per head.

        Pipeline:
            z [N, d_seed]
            -> Linear(head_proj_weight[m*d_seed:(m+1)*d_seed]) -> [N, d_seed]
            -> ManualLayerNorm(weight[m], bias[m])             -> [N, d_seed]
            -> sigmoid(x/2)     -> [N, d_seed] (coordinate in [0,1]^d_seed)
            -> KHRONOS B-spline with scale=head_scale[m] -> [N, d_seed, n_knots]
            -> einsum with head_spline[m]    -> per_dim [N, d_seed, krank]
            -> sign-parity product (log-sum-exp) -> modes [N, krank]
            -> Linear(head_out_weight[m])        -> [N, hidden_size]

        Why a Python loop over the M heads instead of vmap:
            torch.vmap over heads with distinct parameters requires
            functional_call and stack_module_state, which complicates access
            to buffers (knot_grid, head_norm_eps) from inside the transform.
            A Python loop with a fixed, static M is unrolled by TorchDynamo
            into a static sequence of ops -- exactly as XLA/Flax does in the
            original implementation. The compiler sees M graphs identical in
            structure but with different parameters and can fuse or optimize
            them independently. vmap with chunk_size=1 would behave
            analogously but with more instrumentation overhead.

        Args:
            z: [N, d_seed] -- shared codebook seed (float in the model dtype).
            m: head index (0 <= m < num_modes), static Python int.
        Returns:
            [N, hidden_size] -- this head's contribution to the final embedding.
        """
        d = self.d_seed
        nk = self.num_knots
        kr = self.krank
        # -- Linear projection for head m ---------------------------------
        # head_proj_weight [M*d_seed, d_seed] -- the weights for head m are
        # rows [m*d_seed : (m+1)*d_seed].
        proj_w = self.head_proj_weight[m * d : (m + 1) * d]  # [d_seed, d_seed]
        # Keep the matmul in the parameter dtype so eager inference matches
        # mixed-precision training, then promote to float32 for the reduction-
        # heavy normalization and KHRONOS path below.
        zh = F.linear(
            z.to(dtype=proj_w.dtype, device=proj_w.device),
            proj_w,
        )  # [N, d_seed]
        zh = zh.float()
        # -- Manual per-head LayerNorm ------------------------------------
        # Equivalent to nn.LayerNorm(d_seed) with independent parameters
        # head_norm_weight[m] and head_norm_bias[m].
        norm_w = self.head_norm_weight[m].float()
        norm_b = self.head_norm_bias[m].float()
        mean = zh.mean(dim=-1, keepdim=True)
        var = zh.var(dim=-1, keepdim=True, unbiased=False)
        zh = (zh - mean) / (var + self.head_norm_eps).sqrt()
        zh = zh * norm_w + norm_b
        # -- sigmoid(x/2): latent coordinate in [0,1]^d_seed --------------
        zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0)  # [N, d_seed]
        # -- KHRONOS B-spline for this head -------------------------------
        # head_scale[m]: [d_seed] -- per-dimension scale for this head.
        # Materializes [N, d_seed, n_knots] instead of [N, M, d_seed, n_knots].
        sc = self.head_scale[m].float().view(1, -1, 1)  # [1, d_seed, 1]
        x_e = zh.unsqueeze(-1)  # [N, d_seed, 1]
        grid = self.knot_grid.float().view(1, 1, -1)  # [1, 1, n_knots]
        dist = (x_e - grid).abs() * sc  # [N, d_seed, n_knots]
        B_m = torch.where(
            dist < 0.5,
            0.75 - dist ** 2,
            torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
        )  # [N, d_seed, n_knots]
        # -- KHRONOS tensor product for this head -------------------------
        # head_spline[m]: [d_seed, n_knots, krank]
        # per_dim[n, d, k] = sum_g B_m[n, d, g] * head_spline[m, d, g, k]
        # Shape: [N, d_seed, krank] -- the memory peak for this head.
        per_dim = torch.einsum(
            "ndg,dgk->ndk",
            B_m,
            self.head_spline[m].float(),
        )  # [N, d_seed, krank]
        # Sign-parity log-product (KHRONOS): avoids underflow by multiplying
        # in log-space and recovering the sign from the parity of negative
        # factors.
        per_dim_abs = per_dim.abs() + 1e-9
        log_mag = torch.log(per_dim_abs).sum(dim=1)  # [N, krank]
        num_neg = (per_dim < 0).long().sum(dim=1)  # [N, krank]
        prod_sign = 1.0 - 2.0 * (num_neg % 2).float()  # [N, krank]
        modes_m = prod_sign * torch.exp(log_mag)  # [N, krank]
        # -- Head output projection ---------------------------------------
        # head_out_weight[m]: [krank, hidden_size]
        # NOTE: do NOT use F.linear here. F.linear(A, W) computes A @ W.T,
        # expecting W with shape [out, in] = [hidden, krank]. But
        # head_out_weight is stored as [krank, hidden] (matching the original
        # einsum "nmk,mkd->nd", which contracts over k without transposing).
        # The correct multiplication is modes_m @ W directly:
        #   [N, krank] @ [krank, hidden] -> [N, hidden]
        out_m = (
            modes_m.to(self.head_out_weight.dtype)
            @ self.head_out_weight[m]
        )  # [N, hidden_size]
        return out_m
def _khronos_all_heads(
self,
B_all: torch.Tensor,
) -> torch.Tensor:
"""
Vectorized KHRONOS tensor-product for all heads at once.
Mathematically identical to calling _khronos_head_product 8 times,
but uses a single einsum over the head dimension. The sign-parity
aggregation is performed independently per head via the M dimension.
Args:
B_all: [N, M, d_seed, n_knots] float32
Returns:
[N, M, krank] in float32.
"""
# per_dim: [N, M, d_seed, krank]
# einsum: token n, head m, seed-dim d, knot g β†’ krank k
per_dim = torch.einsum(
"nmdg,mdgk->nmdk",
B_all,
self.head_spline.float(),
)
per_dim_abs = per_dim.abs() + 1e-9
# Sum log-magnitudes over d_seed dimension β†’ [N, M, krank]
log_mag = torch.log(per_dim_abs).sum(dim=2)
num_neg = (per_dim < 0).long().sum(dim=2)
prod_sign = 1.0 - 2.0 * (num_neg % 2).float()
return prod_sign * torch.exp(log_mag) # [N, M, krank]
def _modes_from_basis(
self,
B_vals: torch.Tensor,
spline_coeff: torch.Tensor,
target_dtype: torch.dtype,
) -> torch.Tensor:
"""
Shared tensor-product aggregation used by both the input generator
and JTok-M surfaces.
phi_{r,j} = spline_coeff[j,r,:] Β· B_vals[n,r,:]
M_j = sign_j * exp(Ξ£_r log|phi_{r,j}|) (KHRONOS sign-parity)
Args:
B_vals: [N, d_seed, n_knots] float32
spline_coeff: [num_modes, d_seed, n_knots]
target_dtype: output dtype
Returns:
modes: [N, num_modes] in target_dtype
"""
phi = torch.einsum(
"jrk,nrk->njr",
spline_coeff.float(),
B_vals,
) # [N, num_modes, d_seed]
phi_abs = phi.abs() + 1e-9
log_mag = torch.log(phi_abs).sum(dim=-1) # [N, M]
num_neg = (phi < 0).long().sum(dim=-1) # [N, M]
prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, M]
return (prod_sign * torch.exp(log_mag)).to(target_dtype) # [N, M]
    def forward(
        self,
        token_ids: torch.Tensor,
        return_internals: bool = False,
        analysis: Optional[GeneratorAnalysis] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Generate embeddings from discrete token indices.

        Two parallel paths run from the shared codebook output z:

        JTok-M path (return_internals):
            z -> seed_proj -> seed_norm -> sigmoid -> z_tilde [N, d_seed]
            z_tilde -> _bspline_basis (fixed scale) -> B_vals [N, d_seed, n_knots]
            These are returned for reuse by every decoder layer's JTok-M
            surfaces without redundant B-spline evaluation.

        Generator path (per-head, real Leviathan architecture):
            For each head i:
                z -> head_proj_i -> head_norm_i -> sigmoid(x/2) -> z_tilde_i
                z_tilde_i -> B-spline(scale_i) -> B_vals_i
                B_vals_i x head_spline_i -> KHRONOS product -> [N, krank]
                head_out_i([N, krank]) -> [N, hidden_size]
            e = sum of all head_out_i outputs (no W_res, per author's code)

        Args:
            token_ids: (batch, seq_len) or (seq_len,)
            return_internals: if True, also return z_tilde and B_vals for
                reuse by JTok-M surfaces in every decoder layer.
            analysis: Optional[GeneratorAnalysis] -- when not None and model
                is in eval mode with analysis armed, deposits all
                internal tensors (detached). No-op during training.
        Returns:
            embeddings [*token_ids.shape, hidden_size],
            or (embeddings, z_tilde [N, d_seed], B_vals [N, d_seed, n_knots])
            when return_internals=True.
        """
        target_dtype = self.codebooks.dtype
        orig_shape = token_ids.shape
        N = token_ids.numel()
        # -- Shared Stage 1: compositional codebook indexing --------------
        coords = self._base_k_decompose(token_ids)
        coords_flat = coords.reshape(N, self.k)
        # z is the sum of one codebook entry per base-k digit.
        z = torch.zeros(N, self.d_seed, device=token_ids.device, dtype=target_dtype)
        for r in range(self.k):
            z = z + self.codebooks[r][coords_flat[:, r]]
        if analysis is not None:
            analysis.z_raw = z.detach()
        # -- JTok-M shared path -------------------------------------------
        # Produces z_tilde and B_vals consumed by every decoder layer's
        # JTok-M module. This path is unchanged and uses fixed scalar scale.
        z_tilde = torch.sigmoid(self.seed_norm(self.seed_proj(z)))  # [N, d_seed]
        B_vals = self._bspline_basis(z_tilde.clamp(0.0, 1.0))  # [N, d_seed, n_knots]
        if analysis is not None:
            analysis.z_tilde = z_tilde.detach()
            analysis.B_vals = B_vals.detach()
        # -- Per-head generator path (sequential, one head at a time) -----
        # ORIGINAL PROBLEM: the previous vectorized path processed the M
        # heads in parallel with fused kernels:
        #
        #   _bspline_basis_all_heads -> [N, M, d_seed, n_knots]  <- huge tensor
        #   _khronos_all_heads -> per_dim [N, M, d_seed, krank]  <- even bigger
        #
        # With N=B*S=32768, M=8, d_seed=128, n_knots=32, krank=16:
        #   [N,M,d_seed,n_knots] = 32768 x 8 x 128 x 32 x 4 bytes ~ 512 MB
        #   [N,M,d_seed,krank]   = 32768 x 8 x 128 x 16 x 4 bytes ~ 256 MB
        # These tensors live simultaneously in the CUDAGraphs pool, causing
        # OOM in the backward pass once the saved activations of the 12
        # decoder layers are added on top.
        #
        # SOLUTION (equivalent to Reza's JAX implementation):
        #   Python loop over the M=8 heads (fixed count -> TorchDynamo
        #   unrolls it into 8 static op sequences with no graph breaks).
        #   Each head materializes at most [N, d_seed, krank] ~ 32 MB.
        #   The sum is accumulated sequentially, so the previous head's
        #   tensors can be freed by the allocator before the next head runs.
        #
        # Why NOT vmap(chunk_size=1):
        #   vmap requires the function to be "pure" (no access to self.*).
        #   head_norm_eps, knot_grid and the [m]-indexed parameters are
        #   passed implicitly through the closure. With vmap this would
        #   require stack_module_state + functional_call, adding
        #   instrumentation overhead with no real benefit since the static
        #   loop is equally traceable by the compiler and produces the same
        #   graph.
        target_dtype = self.codebooks.dtype  # NOTE(review): redundant -- already assigned above.
        e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
        for m in range(self.num_modes):
            e = e + self._compute_head(z, m)
        # No W_res -- confirmed absent in the authors' implementation
        e = e.reshape(*orig_shape, self.hidden_size)
        if analysis is not None:
            analysis.embeddings = e.detach()
        if return_internals:
            return e, z_tilde, B_vals
        return e
# ==================== LEVIATHAN-JTOK-M MODULATION MODULE ====================
class LeviathanJTokM(nn.Module):
    """
    Leviathan-JTok-M token-indexed modulation module for one decoder layer.

    Fuses the Leviathan continuous geometry with JTok-M (Yang et al., 2026):
    - Instead of per-token lookup tables (O(V*D) per layer), uses n_e
      independent CP-separable surfaces over the shared z_tilde from the
      Leviathan generator, reusing B_vals already computed in the embedding
      stage.
    - Context-dependent router: gates over h_tilde (hidden state after
      attention) using Sigmoid+TopK -- not Softmax -- to avoid inter-surface
      competition.
    - Additive injection with 1/sqrt(2*l) scaling coordinated with the
      existing LNS factor 1/sqrt(l), maintaining a constant
      JTok-M / backbone ratio of 1/sqrt(2) ~ 0.707 across all depths
      (instead of 1/sqrt(2*N_l), which would grow JTok-M dominance in deep
      layers as LNS suppresses backbone activations).
    - Fully vectorized: all surfaces evaluated in one einsum, TopK with
      fixed K produces static shapes -- compatible with torch.compile
      max-autotune.

    Scaling note (LNS coordination):
        LNS applies 1/sqrt(l) to backbone sublayer inputs (l = 1-indexed layer).
        JTok-M applies 1/sqrt(2*l) to its injection residual.
        Ratio: [1/sqrt(2*l)] / [1/sqrt(l)] = 1/sqrt(2) -- constant per depth.

    Parameter cost per layer (defaults n_e=5, M_mod=4, d_seed=128, D=512):
        spline_coeff: n_e x M_mod x d_seed x n_knots = 5x4x128x32 = 81,920
        W_out:        n_e x M_mod x D                = 5x4x512    = 10,240
        W_res:        n_e x d_seed x D               = 5x128x512  = 327,680
        router R:     D x n_e                        = 512x5      = 2,560
        scaler s:     D                              = 512
        Total per layer: ~422,912 -> ~5.07M for 12 layers.

    References:
        Yang, Y. et al. (2026). JTok. arXiv:2602.00800.
        Batley & Saha (2026). Leviathan. arXiv:2601.22040.
    """
    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_e = config.jtokm_num_experts
        self.top_k = config.jtokm_top_k
        self.M_mod = config.jtokm_num_modes
        self.d_seed = config.generator_d_seed
        self.n_knots = config.generator_num_knots
        self.hidden_size = config.hidden_size
        self.norm_eps = config.jtokm_norm_eps
        # LNS-coordinated scale: 1/sqrt(2*l), l = layer_idx + 1 (1-indexed)
        ell = max(layer_idx + 1, 1)
        self.lns_scale = 1.0 / math.sqrt(2.0 * ell)
        # n_e CP-separable surfaces -- each with its own spline_coeff, W_out,
        # W_res. Stored as fused tensors for a single vectorized einsum:
        #   spline_coeff: [n_e, M_mod, d_seed, n_knots]
        #   W_out:        [n_e, M_mod, D]
        #   W_res:        [n_e, d_seed, D]
        self.spline_coeff = nn.Parameter(
            torch.empty(self.n_e, self.M_mod, self.d_seed, self.n_knots)
        )
        self.W_out = nn.Parameter(
            torch.empty(self.n_e, self.M_mod, self.hidden_size)
        )
        self.W_res = nn.Parameter(
            torch.empty(self.n_e, self.d_seed, self.hidden_size)
        )
        # Context-dependent router: RMSNorm(h_tilde) @ R -> [N, n_e]
        self.router = nn.Linear(config.hidden_size, self.n_e, bias=False)
        # Learnable per-dimension scaler (JTok eq. 7 / JTok-M eq. 12)
        self.scaler = nn.Parameter(torch.ones(config.hidden_size))
    # -- Surface evaluation ---------------------------------------------------
    def _eval_surfaces(
        self,
        B_vals: torch.Tensor,
        z_tilde: torch.Tensor,
        target_dtype: torch.dtype,
        analysis: Optional[JTokMAnalysis] = None,
    ) -> torch.Tensor:
        """
        Evaluate all n_e surfaces vectorized over the full token batch.

        phi[n, i, j, r] = spline_coeff[i, j, r, :] . B_vals[n, r, :]
        M[n, i, j] = sign * exp(sum_r log|phi[n,i,j,r]|)
        m[n, i] = W_out[i] @ M[n,i] + W_res[i] @ z_tilde[n]

        All shapes are static -> torch.compile compatible.

        Args:
            B_vals: [N, d_seed, n_knots] float32
            z_tilde: [N, d_seed]
            target_dtype: model dtype
            analysis: Optional[JTokMAnalysis] -- deposits surfaces when not None.
        Returns:
            surfaces: [N, n_e, D]
        """
        N = B_vals.shape[0]
        # phi: [N, n_e, M_mod, d_seed]
        # einsum: "ijrk, nrk -> nijr" where i=n_e, j=M_mod, r=d_seed, k=n_knots
        phi = torch.einsum(
            "ijrk,nrk->nijr",
            self.spline_coeff.float(),  # [n_e, M_mod, d_seed, n_knots]
            B_vals,  # [N, d_seed, n_knots]
        )  # [N, n_e, M_mod, d_seed]
        # KHRONOS sign-parity product aggregation over d_seed
        phi_abs = phi.abs() + 1e-9
        log_mag = torch.log(phi_abs).sum(dim=-1)  # [N, n_e, M_mod]
        num_neg = (phi < 0).long().sum(dim=-1)  # [N, n_e, M_mod]
        prod_sign = 1.0 - 2.0 * (num_neg % 2).float()  # [N, n_e, M_mod]
        modes = (prod_sign * torch.exp(log_mag)).to(target_dtype)
        # modes: [N, n_e, M_mod]
        # W_out projection: [N, n_e, M_mod] x [n_e, M_mod, D] -> [N, n_e, D]
        out_modes = torch.einsum("nim,imd->nid", modes, self.W_out.to(target_dtype))
        # W_res residual: z_tilde [N, d_seed] x [n_e, d_seed, D] -> [N, n_e, D]
        z = z_tilde.to(target_dtype)
        out_res = torch.einsum("nd,idc->nic", z, self.W_res.to(target_dtype))
        surfaces = out_modes + out_res  # [N, n_e, D]
        if analysis is not None:
            analysis.surfaces = surfaces.detach()
        return surfaces
    # -- Router ---------------------------------------------------------------
    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        # Plain RMS normalization (no learnable weight) used only for the
        # router input.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)
    def _route_and_mix(
        self,
        h_tilde: torch.Tensor,
        surfaces: torch.Tensor,
        analysis: Optional[JTokMAnalysis] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Context-dependent routing over h_tilde (hidden state after attention).

        Sigmoid + TopK (not Softmax) avoids inter-surface competition:
            g = RMSNorm(h_tilde) @ R              [N, n_e]
            K winners selected per token
            w_i = sigmoid(g_i) / sum_{j in K} sigmoid(g_j)   (normalized over
                                                              selected only)
            e = sum_{i in K} w_i * surfaces[n, i]

        All tensor shapes are static. TopK with fixed K returns [N, K]
        indices. torch.compile compatible.

        Also computes load-balancing statistics p_i and f_i for the aux loss.

        Args:
            h_tilde: [N, D] -- hidden state after attention (before MLP)
            surfaces: [N, n_e, D]
            analysis: Optional[JTokMAnalysis] -- deposits routing data when
                not None.
        Returns:
            mixed: [N, D]
            aux_stats: (p_sum [n_e], f_sum [n_e], N) for loss accumulation
        """
        N = h_tilde.shape[0]
        # Router logits [N, n_e]
        g = self.router(self._rms_norm(h_tilde))
        # TopK selection -- static shape [N, top_k]
        topk_vals, topk_idx = torch.topk(g, self.top_k, dim=-1)
        # Sigmoid weights over selected K surfaces (JTok-M eq. 10-11)
        sig_vals = torch.sigmoid(topk_vals)  # [N, K]
        w = sig_vals / sig_vals.sum(dim=-1, keepdim=True)  # [N, K]
        # Gather selected surfaces [N, K, D] and weight-sum
        # topk_idx: [N, K] -> expand to [N, K, D]
        idx_exp = topk_idx.unsqueeze(-1).expand(N, self.top_k, self.hidden_size)
        selected = surfaces.gather(dim=1, index=idx_exp)  # [N, K, D]
        mixed = (w.unsqueeze(-1) * selected).sum(dim=1)  # [N, D]
        if analysis is not None:
            analysis.router_logits = g.detach()
            analysis.topk_indices = topk_idx.detach()
            analysis.routing_weights = w.detach()
            analysis.mixed_pre_norm = mixed.detach()
        # Load-balancing statistics for aux loss (Appendix B, Yang et al. 2026)
        #   p_i = mean routing probability over batch
        #   f_i = fraction of tokens actually routed to i
        # NOTE(review): both statistics are computed under no_grad, so the
        # aux loss built from them carries no gradient back to the router.
        # Standard MoE balancing keeps the gradient on p_i -- confirm this
        # detachment is intentional.
        with torch.no_grad():
            sig_all = torch.sigmoid(g)  # [N, n_e]
            p_sum = sig_all.sum(dim=0)  # [n_e]
            # one-hot mask of selections: [N, n_e]
            onehot = torch.zeros_like(g).scatter_(
                1, topk_idx, 1.0
            )
            f_sum = onehot.sum(dim=0)  # [n_e]
        return mixed, (p_sum, f_sum, N)
    # -- Forward --------------------------------------------------------------
    def forward(
        self,
        h_tilde: torch.Tensor,
        z_tilde: torch.Tensor,
        B_vals: torch.Tensor,
        analysis: Optional[JTokMAnalysis] = None,
    ) -> Tuple[torch.Tensor, Tuple]:
        """
        Compute the additive JTok-M residual for one decoder layer.

        Args:
            h_tilde: [N, D] hidden state after attention (before MLP)
            z_tilde: [N, d_seed] latent coordinate from generator
            B_vals: [N, d_seed, n_k] B-spline basis (computed once, reused)
            analysis: Optional[JTokMAnalysis] -- deposits all JTok-M internals
                when not None. No-op during training (analysis is always None
                then).
        Returns:
            delta_r: [N, D] additive residual (already scaled)
            aux_stats: tuple for accumulating the load-balance loss
        """
        target_dtype = h_tilde.dtype
        # All n_e surfaces in one vectorized pass
        surfaces = self._eval_surfaces(B_vals, z_tilde, target_dtype, analysis=analysis)
        # Context-dependent routing
        mixed, aux_stats = self._route_and_mix(h_tilde, surfaces, analysis=analysis)
        # Normalize direction, apply scaler, scale with 1/sqrt(2*l).
        # Norm_eps decouples direction from magnitude (JTok Appendix D.2)
        mixed_norm = mixed / (mixed.norm(dim=-1, keepdim=True) + self.norm_eps)
        delta_r = self.lns_scale * self.scaler * mixed_norm  # [N, D]
        if analysis is not None:
            analysis.mixed_normalized = mixed_norm.detach()
            analysis.delta_r = delta_r.detach()
            analysis.p_sum = aux_stats[0].detach()
            analysis.f_sum = aux_stats[1].detach()
            analysis.lns_scale = self.lns_scale
        return delta_r, aux_stats
def compute_jtokm_aux_loss(
    aux_stats_list: list,
    n_e: int,
    weight: float,
) -> torch.Tensor:
    """
    Aggregate the load-balancing aux loss over all active layers.

        L_aux = lambda * n_e * sum_i p_i * f_i

    averaged over all layers with active JTok-M.

    Args:
        aux_stats_list: list of (p_sum [n_e], f_sum [n_e], N) per layer
        n_e: number of experts
        weight: lambda coefficient
    Returns:
        scalar loss tensor
    """
    if not aux_stats_list:
        return torch.tensor(0.0)
    per_layer = [
        weight * n_e * ((p_sum / N) * (f_sum / N)).sum()
        for p_sum, f_sum, N in aux_stats_list
    ]
    total = per_layer[0]
    for term in per_layer[1:]:
        total = total + term
    return total / len(aux_stats_list)
# ==================== ORIGINAL COMPONENTS ====================
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer.

    FANLayer'(X) = [cos(Wp X) || sin(Wp X) || (Wg X + Bg)]

    A fraction of the projected width (``fan_ratio``) is emitted through a
    periodic cos/sin pair; the remainder passes through as a plain affine
    component.
    """
    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio
        # Output width grows by fan_ratio; the periodic slice p is emitted
        # twice (once as cos, once as sin), so the linear layer only needs
        # p + g columns.
        output_dim = hidden_size + int(hidden_size * fan_ratio)
        self.p_output_dim = int(output_dim * fan_ratio)
        self.g_output_dim = output_dim - 2 * self.p_output_dim
        self.input_linear = nn.Linear(
            hidden_size, self.p_output_dim + self.g_output_dim, bias=True
        )
        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
        if self.input_linear.bias is not None:
            nn.init.zeros_(self.input_linear.bias)
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[FANAnalysis] = None,
    ) -> torch.Tensor:
        projected = self.input_linear(x)
        periodic, linear_part = torch.split(
            projected, [self.p_output_dim, self.g_output_dim], dim=-1
        )
        cos_part = torch.cos(periodic)
        sin_part = torch.sin(periodic)
        if analysis is not None:
            analysis.cosine_component = cos_part.detach()
            analysis.sine_component = sin_part.detach()
            analysis.linear_component = linear_part.detach()
        return torch.cat([cos_part, sin_part, linear_part], dim=-1)
class LNS(nn.Module):
    """
    LayerNorm Scaling: applies 1/sqrt(l) to suppress variance growth with
    depth. From "The Curse of Depth in Large Language Models".
    """
    def __init__(self, layer_idx: int):
        super().__init__()
        # 1-indexed depth, floored at 1 so early layers are never amplified.
        self.layer_idx = max(layer_idx + 1, 1)
        self.scale = 1.0 / math.sqrt(self.layer_idx)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x
class GPAS(nn.Module):
    """Gradient-Preserving Activation Scaling."""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Scalar gate; silu(0) = 0, so the module starts as identity.
        self.alpha = nn.Parameter(torch.zeros(1))
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[GPASAnalysis] = None,
    ) -> torch.Tensor:
        # The damping term uses a detached copy of x, so the subtraction
        # shrinks the activation magnitude without scaling x's own gradient.
        gate = F.silu(self.alpha)
        damping = gate * x.detach()
        if analysis is not None:
            analysis.silu_alpha = gate.detach()
            analysis.subtracted_component = damping.detach()
        return x - damping
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization.

        rescale_factor = tanh(x . beta)  in (-1, 1), one scalar per token
        dynamic_scale  = rescale_factor * alpha + gamma  in R^dim
        output         = dynamic_scale (elementwise) RMSNorm(x)
    """
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.alpha = nn.Parameter(torch.ones(dim))
    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[SeeDNormAnalysis] = None,
    ) -> torch.Tensor:
        # Bounded per-token scalar: projection of x onto beta through tanh.
        token_scalar = torch.tanh((x * self.beta).sum(dim=-1, keepdim=True))
        per_dim_scale = token_scalar * self.alpha + self.gamma
        # RMS-normalize in float32, then cast back to the input dtype.
        normalized = self._rms_norm(x.float())
        result = (normalized * per_dim_scale.float()).type_as(x)
        if analysis is not None:
            analysis.rescale_factor = token_scalar.detach()
            analysis.dynamic_scale = per_dim_scale.detach()
            analysis.x_normalized = normalized.detach()
            analysis.output = result.detach()
        return result
    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"
# ==================== ROTARY EMBEDDING ====================
class NeoLLMRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with optional ``rope_scaling`` support.

    Resolves the rope-init function from ``config.rope_scaling`` when present
    (falling back to the default parameterization), caches ``inv_freq`` in a
    non-persistent buffer, and emits per-position cos/sin tables computed in
    float32 with autocast disabled, cast back to the activation dtype.
    """
    inv_freq: torch.Tensor
    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = "default"
        if (hasattr(config, "rope_scaling")
                and config.rope_scaling is not None
                and isinstance(config.rope_scaling, dict)):
            # Accept both the current "rope_type" key and the legacy "type".
            rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
            if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
                self.rope_type = rope_type
        rope_init_fn = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        # Non-persistent: recomputed on load rather than checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
    @staticmethod
    def compute_default_rope_parameters(
        config: NeoLLMConfig = None,
        device: Optional["torch.device"] = None,
        seq_len: int = None,
    ) -> tuple["torch.Tensor", float]:
        """Return the default RoPE inverse frequencies and scaling factor 1.0."""
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or \
            config.hidden_size // config.num_attention_heads
        # Only a partial slice of the head dim may be rotary.
        dim = int(dim * getattr(config, "partial_rotary_factor", 1.0))
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64)
                     .to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, 1.0
    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return (cos, sin) tables of shape [B, S, rotary_dim] in x's dtype."""
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        B = x.shape[0]
        if position_ids.shape[0] != B:
            position_ids = position_ids.expand(B, -1)
        device_type = (x.device.type
                       if isinstance(x.device.type, str) and x.device.type != "mps"
                       else "cpu")
        if self.inv_freq.device.type == "meta":
            # Buffers left on the meta device carry no data (e.g. after
            # init_empty_weights); rebuild them on the activations' device.
            inv_freq_data, _ = self.compute_default_rope_parameters(
                self.config, device=x.device
            )
            self.register_buffer("inv_freq", inv_freq_data, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq_data.clone(), persistent=False)
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        # Compute the angle tables in float32 regardless of autocast state.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
                     * inv_freq.unsqueeze(0).unsqueeze(0))
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        # Fix: the original had a second, unreachable copy of this return
        # statement after the autocast block; the dead duplicate is removed.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotate the two halves of the last dim: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def linear_clipping(x: torch.Tensor) -> torch.Tensor:
    """
    Piecewise-linear activation for Affine-Scaled Attention scaling factors.

    Avoids the saturation problem of sigmoid, which collapses most outputs
    toward 0 or 1 and loses the intermediate scaling range the model needs
    for fine-grained per-query attention modulation.

        f(x) = 0            if x <= -5
             = 0.1*x + 0.5  if -5 < x < 5
             = 1            if x >= 5

    Equivalent to clamp(0.1*x + 0.5, 0, 1). Output range: [0, 1]. Gradient:
    0.1 across the entire non-saturated region.

    Reference: Bae et al. (2026), Affine-Scaled Attention, Section 6.
    """
    return (x * 0.1 + 0.5).clamp(0.0, 1.0)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply (possibly partial) rotary embedding to q and k.

    Only the first ``cos.shape[-1]`` channels are rotated; the remaining
    channels pass through untouched. ``position_ids`` is unused and kept for
    call-site compatibility.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rot = cos.shape[-1]

    def _rotate(t):
        head = t[..., :rot]
        # Inline rotate_half: (a, b) -> (-b, a) over the rotary channels.
        a, b = head[..., : rot // 2], head[..., rot // 2 :]
        spun = torch.cat((-b, a), dim=-1)
        return torch.cat([head * cos + spun * sin, t[..., rot:]], dim=-1)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads for GQA: [B, H_kv, S, D] -> [B, H_kv * n_rep, S, D].

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1) but
    realized via a zero-copy expand followed by one reshape.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def causal_first_difference(x: torch.Tensor) -> torch.Tensor:
    """First difference along the sequence axis (dim -2); row 0 is unchanged.

    Computes x[t] - x[t-1] with an implicit zero row prepended, i.e. the
    causal analogue of a discrete derivative.
    """
    zero_row = torch.zeros_like(x[..., :1, :])
    return torch.diff(x, dim=-2, prepend=zero_row)
def rms_key_unit_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    """L2-normalize in float32, then rescale to norm sqrt(d) (unit RMS).

    Matches F.normalize semantics: the denominator is max(||x||_2, eps).
    """
    x32 = x.float()
    length = x32.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
    return (x32 / length) * math.sqrt(x.shape[-1])
def infer_key_validity(
    attention_mask: Optional[torch.Tensor], seq_len: int, num_heads: int
) -> Optional[torch.Tensor]:
    """Derive per-key validity [B, H, S] from a 4D additive attention mask.

    A key position counts as valid when its diagonal mask entry is a finite,
    exact zero (i.e. not masked out). Returns None when the mask is absent,
    not 4D, or not square [.., seq_len, seq_len].
    """
    if attention_mask is None or attention_mask.ndim != 4:
        return None
    if attention_mask.shape[-2] != seq_len or attention_mask.shape[-1] != seq_len:
        return None
    diag_entries = attention_mask.diagonal(dim1=-2, dim2=-1)
    validity = torch.isfinite(diag_entries) & (diag_entries == 0)
    head_count = validity.shape[1]
    if head_count == 1 and num_heads != 1:
        return validity.expand(-1, num_heads, -1)
    if head_count != num_heads:
        # Head count mismatch: broadcast the first head across all heads.
        return validity[:, :1, :].expand(-1, num_heads, -1)
    return validity
def head_linear_compose(
    hidden_states: torch.Tensor, mixing_matrix: torch.Tensor
) -> torch.Tensor:
    """Linearly mix attention heads: out[:, k] = sum_h states[:, h] * M[h, k].

    The mixing matrix is moved to the states' device/dtype before the
    contraction.
    """
    mix = mixing_matrix.to(device=hidden_states.device, dtype=hidden_states.dtype)
    return torch.einsum("bhtd,hk->bktd", hidden_states, mix)
class MEAHeadSeeDNorm(nn.Module):
    """MEA head-level normalization using SeeDNorm grouped by KV structure (GQA-aware)."""
    def __init__(self, num_heads: int, head_dim: int, num_kv_groups: int, eps: float = 1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_kv_groups
        self.num_kv_heads = num_heads // num_kv_groups
        self.group_dim = num_kv_groups * head_dim
        # One SeeDNorm spanning each group of num_kv_groups query heads.
        self.norm = SeeDNorm(self.group_dim, eps=eps)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, heads, dim = hidden_states.shape
        if heads != self.num_heads or dim != self.head_dim:
            raise ValueError(
                f"MEAHeadSeeDNorm expected ({self.num_heads}, {self.head_dim}), "
                f"received ({heads}, {dim})"
            )
        # Fold each KV group's heads into one channel block, normalize, and
        # restore the per-head layout.
        grouped = hidden_states.reshape(batch, seq_len, self.num_kv_heads, self.group_dim)
        normalized = self.norm(grouped)
        return normalized.reshape(batch, seq_len, heads, dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Standard softmax attention, eager reference path.

    Returns (attn_output [B, S, H, D], attn_weights [B, H, S, S_k]).
    """
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)
    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is sliced to the key length to support cached decoding.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]
    # Softmax in float32 for numerical stability, then back to query dtype.
    weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    output = torch.matmul(weights, values).transpose(1, 2).contiguous()
    output = F.dropout(output, p=dropout, training=module.training)
    return output, weights
def affine_scaled_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    dropout: float = 0.0,
    attn_analysis: Optional[AttentionAnalysis] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Affine-Scaled Attention, eager path (Bae et al., 2026, Eq. 6–8).

    Computes [α(X)·softmax(QK^T/√dk) + β(X)] V, relaxing the unit-sum
    constraint of softmax. Both modulators are produced upstream in
    NeoLLMAttention.forward; this function only blends and aggregates:
      - Ξ± is per-head, per-query, input-dependent, bounded in [0, 1];
      - Ξ² is a moving-average bias that prevents the effective attention
        mass from collapsing when Ξ± drifts from its running mean.
    Orthogonal to the Gated Attention gate, which is applied post-SDPA to
    the concatenated output before o_proj and is untouched here.

    Args:
        alpha: [B, H, S_q, 1] input-dependent per-query scale.
        beta: [B, H, S_q, 1] input-dependent per-query bias.
        attn_analysis: when not None (eval-time analysis), receives the
            softmax weights before and after the affine reweighting.

    Returns:
        (attn_output [B, S_q, H, d_head], affine weights [B, H, S_q, S_k])
    """
    k_rep = repeat_kv(key, module.num_key_value_groups)
    v_rep = repeat_kv(value, module.num_key_value_groups)
    scores = torch.matmul(query, k_rep.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : k_rep.shape[-2]]
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    if attn_analysis is not None:
        attn_analysis.attn_weights_pre_affine = probs.detach()
    # Affine reweighting: Ξ±/Ξ² are [B, H, S_q, 1] and broadcast over the
    # key axis of probs ([B, H, S_q, S_k]).
    blended = alpha * probs + beta
    if attn_analysis is not None:
        attn_analysis.attn_weights_post_affine = blended.detach()
    out = torch.matmul(blended, v_rep).transpose(1, 2).contiguous()
    out = nn.functional.dropout(out, p=dropout, training=module.training)
    return out, blended
def affine_scaled_flash_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    dropout: float = 0.0,
    attn_analysis: Optional[AttentionAnalysis] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Affine-Scaled Attention via the public flash/sdpa interface (exact).

    [Ξ±Β·softmax(QKα΅€/√dk) + Ξ²] V expands distributively into
        Ξ± Β· flash_attn(Q, K, V)       term 1: a standard kernel call
      + Ξ² Β· Ξ£_{j≀i} V_j               term 2: causal prefix-sum of V
    because Ξ² is a scalar per query that broadcasts over all keys, and the
    softmax weights sum to 1 over the causal window. No kernel change and
    no approximation β€” semantically identical to the eager path.

    Dropout: the combined weight matrix is never materialised on this path,
    so the kernel runs with dropout=0 and nn.functional.dropout is applied
    to the combined output instead (output dropout rather than weight
    dropout). At inference dropout=0, so both paths coincide exactly.

    Memory: one extra [B, H_q, S, d_head] tensor for the V prefix-sum,
    allocated and freed within this call.

    Args:
        alpha: [B, H_q, S_q, 1] per-query scale in [0, 1].
        beta:  [B, H_q, S_q, 1] per-query moving-average bias.
        attn_analysis: deposited by the caller beforehand; the flash kernel
            never exposes softmax weights, so the pre/post-affine weight
            fields stay None on this path and nothing is written here.

    Returns:
        (output [B, S, H_q, d_head], None) β€” weights are unavailable.
    """
    # Term 1: delegate to the configured flash/sdpa kernel, dropout disabled.
    kernel = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
    base_out, _ = kernel(
        module, query, key, value, attention_mask,
        dropout=0.0, scaling=scaling, **kwargs,
    )
    # base_out: [B, S, H_q, d_head] β€” the HF attention wrappers share this layout.
    # Term 2: causal prefix-sum of V, expanded from KV heads to query heads (GQA).
    v_full = repeat_kv(value, module.num_key_value_groups)  # [B, H_q, S, d_head]
    # The kernel zeros padded keys internally; a raw cumsum would not, so V is
    # masked first. The mask diagonal is 0 at valid tokens and -inf at padding.
    if attention_mask is not None and attention_mask.ndim == 4:
        diag_mask = attention_mask.diagonal(dim1=-2, dim2=-1)      # [B, 1, S]
        keep = (diag_mask == 0).to(v_full.dtype).unsqueeze(-1)     # [B, 1, S, 1]
        v_full = v_full * keep  # broadcast over H_q and d_head
    # Prefix-sum over the sequence axis, then match base_out's layout.
    prefix = v_full.cumsum(dim=2).transpose(1, 2).contiguous()     # [B, S, H_q, d_head]
    # Ξ±/Ξ²: [B, H_q, S, 1] β†’ [B, S, H_q, 1] so they broadcast over d_head.
    a = alpha.permute(0, 2, 1, 3)
    b = beta.permute(0, 2, 1, 3)
    combined = a * base_out + b * prefix                           # [B, S, H_q, d_head]
    combined = nn.functional.dropout(combined, p=dropout, training=module.training)
    # No weight matrix to return β€” flash never exposes it.
    return combined, None
class HadamardOProj(nn.Module):
    """
    Parameter-free Walsh–Hadamard output projection with a learnable
    per-channel affine: output = Ξ± βŠ™ FWHT(x) + Ξ².

    Substitutes the dense W_O ∈ R^{dΓ—d} of multi-head attention with a fixed
    orthonormal Walsh–Hadamard Transform. All singular values of the WHT are
    identically 1, so the layer's condition number is ΞΊ = 1 permanently β€” it
    cannot develop the extreme ΞΊ (up to ~10^5 observed in practice) that a
    trained dense o_proj exhibits, which is what defeats single-scale FP8
    quantisation. The learnable Ξ±/Ξ² restore per-channel expressivity at a
    cost of 2Β·d parameters instead of dΒ².

    Properties:
        - ΞΊ = 1 exactly, by construction (no trainable mixing matrix)
        - 2Β·d parameters vs dΒ² dense (~25% attention params saved)
        - O(d log d) forward FLOPs vs O(dΒ²)
        - Isometric: β€–FWHT(x)β€–β‚‚ = β€–xβ€–β‚‚ after 1/√d normalisation
        - Single per-tensor FP8 scale covers all directions equally
        - Requires d to be a power of 2

    Reference:
        Aggarwal, S. & Kumar, L. (2026). "Rethinking Attention Output
        Projection: Structured Hadamard Transforms for Efficient
        Transformers." arXiv:2603.08343.
    """
    def __init__(self, dim: int, bias: bool = True):
        super().__init__()
        assert dim > 0 and (dim & (dim - 1)) == 0, (
            f"HadamardOProj requires dim to be a power of 2, got {dim}"
        )
        self.dim = dim
        self.norm = dim ** -0.5  # orthonormal scaling, so H^T H = I
        # Identity initialisation (Ξ±=1, Ξ²=0): the layer starts as a pure
        # orthonormal WHT with unit gain.
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim)) if bias else None
    def _fwht(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unnormalised Fast Walsh–Hadamard Transform over the last dimension.

        logβ‚‚(d) butterfly stages of pure additions/subtractions β€” dΒ·logβ‚‚(d)
        ops, zero multiplications. Shapes are static once d is fixed, so the
        loop is torch.compile-friendly.
        """
        stride = 1
        while stride < self.dim:
            # Expose element pairs at the current stride.
            pairs = x.reshape(*x.shape[:-1], -1, 2 * stride)
            lo, hi = pairs[..., :stride], pairs[..., stride:]
            # Butterfly step: (lo + hi, lo - hi).
            x = torch.cat([lo + hi, lo - hi], dim=-1).reshape(
                *pairs.shape[:-2], self.dim
            )
            stride *= 2
        return x
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["HadamardAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: [..., dim] β€” concatenated multi-head attention outputs.
            analysis: HadamardAnalysis container, populated only when
                analysis mode is active (eval + model.enable_analysis());
                None otherwise.

        Returns:
            Ξ± βŠ™ (FWHT(x) / √dim) + Ξ², shape [..., dim].
        """
        transformed = self._fwht(x) * self.norm  # orthonormal: H^T H = I
        if analysis is not None:
            analysis.post_fwht = transformed.detach()
            analysis.alpha_snapshot = self.alpha.detach()
        out = transformed * self.alpha  # learnable per-channel scale
        if self.beta is not None:
            out = out + self.beta  # learnable per-channel bias
        return out
class REPOModule(nn.Module):
    """
    Context Re-Positioning module f_Ο• (Li et al., 2026, arXiv:2512.14391).

    Predicts continuous, data-dependent RoPE positions in place of the fixed
    integer indices ``0…L-1``, learned end-to-end (Eq. 4–6 of the paper):

        r_i     = Swish(h_i W_g) βŠ™ (h_i W_c)    r_i ∈ R^{d_p}, head-shared
        z_i^(h) = r_i w_z^(h)                    one scalar per head

    where ``h_i ∈ R^d`` is the hidden state of token ``i`` entering the
    decoder layer (pre-FANLayer) and ``d_p = hidden_size // 8`` by default.
    The resulting ``z [B, H, S]`` is real-valued and unconstrained; per-head
    cos/sin embeddings are then built inline from ``inv_freq`` downstream.

    Design notes:
        - W_g and W_c are shared across heads (parameter efficiency).
        - W_z packs every per-head assignment vector w_z^(h) as a column of
          one [d_p, num_heads] matrix, vectorised as a single matmul.
        - The raw hidden state (not the FAN-augmented or normed variant) is
          the input, matching the paper and avoiding circular dependency
          with q/k norm.
        - No projection carries a bias, consistent with Eq. 4–5.
    """
    def __init__(self, hidden_size: int, d_p: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.d_p = d_p
        self.num_heads = num_heads
        # Eq. 4: SwiGLU-style position representation, shared by all heads.
        self.W_g = nn.Linear(hidden_size, d_p, bias=False)
        self.W_c = nn.Linear(hidden_size, d_p, bias=False)
        # Eq. 5: column h of W_z is the assignment vector w_z^(h) of head h.
        self.W_z = nn.Linear(d_p, num_heads, bias=False)
    def forward(
        self,
        hidden_states: torch.Tensor,
        repo_analysis: Optional[REPOAnalysis] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, S, hidden_size] β€” residual stream entering
                the decoder layer, before FANLayer augmentation.
            repo_analysis: analysis container filled when analysis mode is
                active; None during training (zero overhead).

        Returns:
            z: [B, H, S] β€” continuous per-head position scalars, where
            z[:, h, i] is the position head h assigns to token i.
        """
        # Eq. 4: Swish(h W_g) βŠ™ (h W_c) β†’ [B, S, d_p]
        rep = F.silu(self.W_g(hidden_states)) * self.W_c(hidden_states)
        # Eq. 5: all heads at once β€” [B, S, H] β†’ [B, H, S]
        positions = self.W_z(rep).transpose(1, 2).contiguous()
        if repo_analysis is not None:
            repo_analysis.r_repr = rep.detach()
            repo_analysis.positions = positions.detach()
        return positions
def _apply_repo_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    z: torch.Tensor,
    inv_freq: torch.Tensor,
    attention_scaling: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate Q and K with RoPE driven by continuous per-head REPO positions.

    Stands in for ``apply_rotary_pos_emb(q, k, cos, sin)`` on REPO-active
    layers: cos/sin are built inline from ``z`` and ``inv_freq`` so the
    rotation stays differentiable w.r.t. ``z`` and hence the REPOModule
    parameters. Only the first ``rotary_dim`` channels are rotated; the
    remainder passes through unchanged.

    GQA handling: REPO emits one position per Q head, while K has fewer
    heads. Each KV head uses the mean position of the Q heads in its group
    (size ``num_key_value_groups``) β€” the minimal choice consistent with
    the paper's per-head independence, giving every KV head a position
    representative of the Q heads it serves.

    Precision/compile note: ``inv_freq`` arrives from rotary_emb at forward
    time already float32 on the right device, so no ``.to()`` and no
    autocast barrier is needed β€” the explicit ``.float()`` on the positions
    keeps the trig in float32, and the absence of a context manager lets
    Inductor plan a single static memory graph under max-autotune.

    Args:
        q: [B, H, S, head_dim]
        k: [B, H_kv, S, head_dim] (GQA: H_kv ≀ H)
        z: [B, H, S] β€” per-head positions from REPOModule
        inv_freq: [rotary_dim/2] β€” frozen RoPE frequency vector
        attention_scaling: scaling factor from NeoLLMRotaryEmbedding

    Returns:
        (q_embed, k_embed) with the same shapes as (q, k).
    """
    B, H, S = z.shape
    H_kv = k.shape[1]
    group_size = H // H_kv
    rotary_dim = inv_freq.shape[0] * 2  # inv_freq spans half the rotary channels
    def _rotate(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # positions: [B, heads, S]; trig runs in float32, cast back to x.dtype.
        angles = positions.float().unsqueeze(-1) * inv_freq   # [B, heads, S, r/2]
        emb = torch.cat([angles, angles], dim=-1)             # [B, heads, S, r]
        cos = (emb.cos() * attention_scaling).to(x.dtype)
        sin = (emb.sin() * attention_scaling).to(x.dtype)
        # Rotate the first rotary_dim channels; pass the rest through.
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        return torch.cat(
            [(x_rot * cos) + (rotate_half(x_rot) * sin), x_pass], dim=-1
        )
    # KV positions: mean of the Q-head positions inside each GQA group.
    z_kv = z.view(B, H_kv, group_size, S).mean(dim=2)  # [B, H_kv, S]
    return _rotate(q, z), _rotate(k, z_kv)
class NeoLLMAttention(nn.Module):
"""
Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
optional Momentum, MEA head-level composition, optional LUCID preconditioning,
optional Affine-Scaled Attention, optional Exclusive Self Attention,
optional Directional Routing (Taylor, 2026), and optional Context
Re-Positioning (Li et al., 2026).
Directional Routing inserts at position C β€” post-XSA, pre-reshape β€” where
the output is already normalized (MEAHeadSeeDNorm) and has auto-position
removed (XSA). The router suppresses directions of cross-domain interference
orthogonal to the self-position already cleaned by XSA.
Pipeline (all active simultaneously when enabled):
FANLayer β†’ q_proj(gate) β†’ q_norm/k_norm β†’ REPO/RoPE β†’ Momentum
β†’ MEA(K,V) β†’ LUCID(V) β†’ v_ref β†’ Affine-Scaled SDPA
β†’ MEAHeadSeeDNorm β†’ XSA β†’ Directional Routing β†’ reshape
β†’ o_proj Β· sigmoid(gate) β†’ dropout
RoPE variants (controlled by config.use_repo and layer_idx):
use_repo=False (default): standard integer RoPE via pre-computed
position_embeddings β€” identical to prior behaviour.
use_repo=True, layer_idx >= repo_start_layer:
REPOModule f_Ο• predicts continuous per-head positions
z [B, H, S] from hidden_states. cos/sin are built
inline from z and inv_freq so the rotation is
differentiable w.r.t. f_Ο• parameters.
use_repo=True, layer_idx < repo_start_layer:
standard integer RoPE (lower layers capture surface
features that benefit less from re-positioning).
o_proj variants (controlled by config.use_hadamard_o_proj):
False (default): dense LinearWithMultipliers β€” full expressivity,
develops high ΞΊ during training (FP8 risk).
True: HadamardOProj β€” fixed WHT + learnable Ξ±/Ξ²,
ΞΊ = 1 by construction, 25% fewer attention params,
FP8-friendly (Aggarwal & Kumar, 2026, arXiv:2603.08343).
References:
Directional Routing: Taylor (2026). arXiv:2603.14923.
Hadamard o_proj: Aggarwal & Kumar (2026). arXiv:2603.08343.
Context Re-Positioning: Li et al. (2026). arXiv:2512.14391.
"""
    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        """
        Construct the attention layer for decoder layer ``layer_idx``.

        Every optional mechanism (Momentum, MEA, LUCID, Affine-Scaled
        Attention, XSA, Directional Routing, REPO, Hadamard o_proj) is
        toggled from ``config`` via ``getattr`` with a default, so configs
        that predate a feature still load.

        Args:
            config: NeoLLMConfig carrying model sizes and feature flags.
            layer_idx: zero-based layer index; used by the REPO gate
                (REPO is active only at or above ``repo_start_layer``).
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # head_dim may be set explicitly in config; otherwise derived.
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.sqrt_head_dim = math.sqrt(self.head_dim)
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.use_momentum_attention = getattr(config, "use_momentum_attention", False)
        self.momentum_gamma = float(getattr(config, "momentum_gamma", 0.0))
        self.use_mea_attention = getattr(config, "use_mea_attention", False)
        self.mea_component_key_value_heads = int(
            getattr(config, "mea_component_key_value_heads", config.num_key_value_heads)
        )
        self.mea_groupnorm_eps = float(
            getattr(config, "mea_groupnorm_eps", config.rms_norm_eps)
        )
        self.use_lucid_attention = getattr(config, "use_lucid_attention", False)
        self.lucid_attention_eps = float(
            getattr(config, "lucid_attention_eps", config.rms_norm_eps)
        )
        self.use_hadamard_o_proj = getattr(config, "use_hadamard_o_proj", False)
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, "fan_ratio", 0.125),
        )
        # FANLayer widens the residual stream by fan_ratio Β· hidden_size,
        # so all input projections consume this augmented width.
        fan_output_dim = config.hidden_size + int(
            config.hidden_size * getattr(config, "fan_ratio", 0.125)
        )
        # q_proj emits 2Γ— the query width: forward chunks it into (q, gate),
        # where the sigmoid(gate) implements Gated Attention post-SDPA.
        self.q_proj = LinearWithMultipliers(
            fan_output_dim, config.num_attention_heads * self.head_dim * 2,
            bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=False,
        )
        # With MEA enabled, K/V are projected to component heads that are
        # later mixed down to the model's KV heads by mea_key/value_mix.
        self.num_mea_component_heads = (
            self.mea_component_key_value_heads
            if self.use_mea_attention else config.num_key_value_heads
        )
        self.k_proj = nn.Linear(
            fan_output_dim, self.num_mea_component_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            fan_output_dim, self.num_mea_component_heads * self.head_dim,
            bias=config.attention_bias,
        )
        # ── Output projection (Aggarwal & Kumar, 2026, arXiv:2603.08343) ────
        # use_hadamard_o_proj=False (default): dense LinearWithMultipliers.
        # use_hadamard_o_proj=True: HadamardOProj β€” fixed WHT + learnable Ξ±/Ξ².
        # ΞΊ = 1 by construction, 25% fewer attention params, FP8-friendly.
        # Requires hidden_size to be a power of 2 (512 βœ“, 1024 βœ“, 768 βœ—).
        _o_in = config.num_attention_heads * self.head_dim
        if self.use_hadamard_o_proj:
            assert _o_in == config.hidden_size, (
                f"HadamardOProj requires in_dim == out_dim, "
                f"got {_o_in} vs {config.hidden_size}"
            )
            self.o_proj = HadamardOProj(config.hidden_size, bias=config.attention_bias)
        else:
            self.o_proj = LinearWithMultipliers(
                _o_in, config.hidden_size,
                bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True,
            )
        # Per-head Q/K normalisation (applied before RoPE in forward).
        self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        if self.use_mea_attention:
            # Head-mixing matrices [component_heads, kv_heads], initialised
            # to (rectangular) identity so MEA starts as a pass-through.
            self.mea_key_mix = nn.Parameter(
                torch.eye(self.num_mea_component_heads, config.num_key_value_heads)
            )
            self.mea_value_mix = nn.Parameter(
                torch.eye(self.num_mea_component_heads, config.num_key_value_heads)
            )
            self.mea_output_norm = MEAHeadSeeDNorm(
                num_heads=config.num_attention_heads, head_dim=self.head_dim,
                num_kv_groups=self.num_key_value_groups, eps=self.mea_groupnorm_eps,
            )
        else:
            self.mea_key_mix = None
            self.mea_value_mix = None
            self.mea_output_norm = None
        self.dropout = nn.Dropout(config.dropout_rate)
        # ResFormer blend weights for mixing the first layer's FAN output
        # into this layer's (see forward: lambda_1Β·first + lambda_2Β·current).
        self.lambda_1 = nn.Parameter(torch.tensor(0.5))
        self.lambda_2 = nn.Parameter(torch.tensor(0.5))
        # ── Affine-Scaled Attention (Bae et al., 2026) ───────────────────────
        self.use_affine_scaled_attention = getattr(config, "use_affine_scaled_attention", False)
        self.affine_momentum = float(getattr(config, "affine_momentum", 0.9))
        # ── Exclusive Self Attention (Zhai, 2026) ────────────────────────────
        self.use_xsa = getattr(config, "use_xsa", False)
        self.xsa_eps = float(getattr(config, "xsa_eps", 1e-6))
        if self.use_affine_scaled_attention:
            # Per-head Ξ± logits from the residual stream; Ξ± itself is bounded
            # downstream. alpha_ma is the running mean of Ξ± β€” persistent so it
            # survives checkpointing.
            self.alpha_proj = nn.Linear(
                config.hidden_size, config.num_attention_heads, bias=False
            )
            self.register_buffer(
                "alpha_ma",
                torch.zeros(1, config.num_attention_heads, 1, 1),
                persistent=True,
            )
        # ── Directional Routing (Taylor, 2026) ───────────────────────────────
        # Each attention head learns K unit-norm direction vectors in head-space.
        # A shared 4-layer MLP router β€” conditioned on the mean-pooled sequence
        # representation β€” produces per-input sigmoid weights r_{h,k} ∈ [0,1]
        # that control how much of each directional component is suppressed from
        # the head's output after XSA (position C):
        #
        #     o'_h = o_h - Ξ£_k r_{h,k} Β· (o_h Β· d_{h,k}) Β· d_{h,k}
        #
        # Position C (post-XSA, pre-reshape) is chosen because:
        #   - XSA already removed auto-position (self-position noise).
        #   - MEAHeadSeeDNorm already normalized the output β€” suppression has
        #     predictable magnitude since d_{h,k} is unit-norm.
        #   - Directions live in head-space d_head before o_proj, preserving
        #     the vocabulary projection interpretability of the paper.
        #   - No interaction with the Gated Attention gate (applied post o_proj).
        #
        # When use_xsa=False, position C reduces to post-SeeDNorm, pre-reshape β€”
        # routing still applies correctly, directions just also span the
        # self-position subspace (no XSA cleaned it first).
        #
        # Router: mean-pools hidden_states (pre-FAN residual stream) over S,
        # passes through 4-layer MLP, outputs HΓ—K logits, temperature-scaled
        # sigmoid β†’ r_{h,k}. Temperature T=5.0 pushes weights toward {0,1}.
        # The router is shared across all heads within this layer, exactly as
        # in the paper. No auxiliary loss β€” learns from LM objective only.
        #
        # direction_vecs: [H, K, d_head] β€” unit-normalized in forward, not init.
        # Initialized from normal(0, 1) and normalized at first forward pass.
        self.use_directional_routing = getattr(config, "use_directional_routing", False)
        self.directional_routing_k = int(getattr(config, "directional_routing_k", 4))
        self.directional_routing_temp = float(getattr(config, "directional_routing_temp", 5.0))
        if self.use_directional_routing:
            H = config.num_attention_heads
            K = self.directional_routing_k
            D = self.head_dim
            R = config.hidden_size  # router hidden dim
            # Direction vectors: [H, K, d_head]
            # Stored unnormalized β€” unit-normalized during forward.
            self.direction_vecs = nn.Parameter(
                torch.randn(H, K, D)
            )
            # 4-layer MLP router shared across heads within this layer.
            # Input: mean-pooled hidden_states [B, hidden_size]
            # Output: [B, H*K] β†’ reshape [B, H, K] β†’ sigmoid(TΒ·x) β†’ r_{h,k}
            # Intermediate dim = hidden_size throughout, matching the paper.
            self.direction_router = nn.Sequential(
                SeeDNorm(R, eps=config.rms_norm_eps),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, H * K, bias=True),
            )
        else:
            self.direction_vecs = None
            self.direction_router = None
        # ── Context Re-Positioning (Li et al., 2026) ─────────────────────
        # Active for layers at or above repo_start_layer only.
        # Layers below repo_start_layer use standard integer RoPE positions.
        # inv_freq is accessed from the model's rotary_emb at forward time;
        # stored here as a non-persistent buffer reference set by NeoLLMModel.
        self.use_repo = (
            getattr(config, "use_repo", False)
            and layer_idx >= getattr(config, "repo_start_layer", config.num_hidden_layers // 3)
        )
        if self.use_repo:
            _d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
            self.repo_module = REPOModule(
                hidden_size=config.hidden_size,
                d_p=_d_p,
                num_heads=config.num_attention_heads,
            )
        else:
            self.repo_module = None
def _apply_momentum_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
attn_analysis: Optional[AttentionAnalysis] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.use_momentum_attention or self.momentum_gamma == 0.0:
return q, k
dq = causal_first_difference(q)
dk = causal_first_difference(k)
q_new = q + self.momentum_gamma * dq
k_new = k + self.momentum_gamma * dk
if attn_analysis is not None:
attn_analysis.q_momentum_delta = dq.detach()
attn_analysis.k_momentum_delta = dk.detach()
attn_analysis.q_post_momentum = q_new.detach()
attn_analysis.k_post_momentum = k_new.detach()
return q_new, k_new
def _apply_mea_head_mixing(
self,
k: torch.Tensor,
v: torch.Tensor,
attn_analysis: Optional[AttentionAnalysis] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not self.use_mea_attention:
return k, v
k_mixed = head_linear_compose(k, self.mea_key_mix).contiguous()
v_mixed = head_linear_compose(v, self.mea_value_mix).contiguous()
if attn_analysis is not None:
attn_analysis.mea_key_mix_matrix = self.mea_key_mix.detach()
attn_analysis.mea_value_mix_matrix = self.mea_value_mix.detach()
attn_analysis.k_post_mea = k_mixed.detach()
attn_analysis.v_post_mea = v_mixed.detach()
return k_mixed, v_mixed
def _apply_lucid_preconditioner(
self,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor],
attn_analysis: Optional[AttentionAnalysis] = None,
) -> torch.Tensor:
if not self.use_lucid_attention:
return v.contiguous()
key_rn = rms_key_unit_norm(k, eps=self.lucid_attention_eps)
logits = torch.matmul(key_rn, key_rn.transpose(-1, -2)) * self.scaling - self.sqrt_head_dim
prec = torch.tril(torch.exp(logits))
kv = infer_key_validity(attention_mask, k.shape[-2], k.shape[1])
if kv is not None:
prec = prec * (kv.unsqueeze(-1) & kv.unsqueeze(-2)).to(prec.dtype)
eye = torch.eye(prec.shape[-1], device=prec.device, dtype=prec.dtype).view(
1, 1, prec.shape[-1], prec.shape[-1]
)
prec = prec + eye * (1.0 - prec.diagonal(dim1=-2, dim2=-1).unsqueeze(-1))
result = torch.linalg.solve_triangular(
prec, v.float(), upper=False, unitriangular=True
).to(v.dtype).contiguous()
if attn_analysis is not None:
attn_analysis.lucid_preconditioner = prec.detach()
attn_analysis.v_post_lucid = result.detach()
return result
def _apply_directional_routing(
self,
attn_out: torch.Tensor,
hidden_states: torch.Tensor,
attn_analysis: Optional[AttentionAnalysis] = None,
) -> torch.Tensor:
"""
Directional suppression at position C (post-XSA, pre-reshape).
Args:
attn_out: [B, S, H, d_head] β€” output after XSA and SeeDNorm.
hidden_states: [B, S, hidden_size] β€” pre-FAN residual stream,
used as router input (same as paper's x_i).
attn_analysis: Optional[AttentionAnalysis] β€” deposits DR internals.
Returns:
[B, S, H, d_head] with selected directional components suppressed.
"""
# ── Router: one routing decision per sequence ─────────────────────
# Mean-pool over sequence dimension S β†’ [B, hidden_size].
# The router is sequence-level (not token-level), matching the paper.
# This produces one suppression pattern per input sequence,
# which the paper shows learns domain-adaptive behavior in early
# layers and fixed syntactic pruning in late layers.
pooled = hidden_states.mean(dim=1) # [B, hidden_size]
logits = self.direction_router(pooled) # [B, H*K]
r = torch.sigmoid(self.directional_routing_temp * logits)
r = r.view(hidden_states.shape[0], self.config.num_attention_heads,
self.directional_routing_k) # [B, H, K]
# Expand over sequence for broadcasting: [B, 1, H, K]
r = r.unsqueeze(1)
# ── Unit-normalize direction vectors ──────────────────────────────
# Normalize at forward time, not at init, following the paper.
# d: [H, K, d_head] β†’ unit norm along d_head dimension.
d = F.normalize(self.direction_vecs, dim=-1) # [H, K, d_head]
# ── Directional suppression ───────────────────────────────────────
# attn_out: [B, S, H, d_head]
# For each head h and direction k:
# proj_{h,k} = (o_h Β· d_{h,k}) scalar per (B, S, H, K)
# suppress = r_{h,k} Β· proj_{h,k} Β· d_{h,k}
# o'_h = o_h - Ξ£_k suppress_{h,k}
#
# proj: einsum over d_head dimension
# attn_out [B, S, H, D] Γ— d [H, K, D] β†’ [B, S, H, K]
proj = torch.einsum("bshd,hkd->bshk", attn_out, d) # [B, S, H, K]
# r [B, 1, H, K] Γ— proj [B, S, H, K] β†’ [B, S, H, K]
weighted = r * proj # [B, S, H, K]
# Ξ£_k weighted_{h,k} Β· d_{h,k}:
# weighted [B, S, H, K] Γ— d [H, K, D] β†’ [B, S, H, D]
suppression = torch.einsum("bshk,hkd->bshd", weighted, d)
result = attn_out - suppression
if attn_analysis is not None:
attn_analysis.direction_vecs_normalized = d.detach()
attn_analysis.dr_router_logits = logits.detach()
attn_analysis.dr_routing_weights = r.squeeze(1).detach()
attn_analysis.dr_projection = proj.detach()
attn_analysis.dr_suppression = suppression.detach()
attn_analysis.attn_output_post_routing = result.detach()
return result
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    first_layer_fan: Optional[torch.Tensor] = None,
    attn_analysis: Optional[AttentionAnalysis] = None,
    repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Full attention forward pass for one decoder layer.

    Pipeline: FANLayer -> optional ResFormer mix with the first layer's
    FAN output (lambda_1/lambda_2) -> fused Q/gate projection -> Q/K norm
    -> RoPE (standard or REPO) -> Momentum / MEA / LUCID operators ->
    attention (optionally affine-scaled) -> MEA output norm -> XSA ->
    directional routing -> sigmoid gate (Gated Attention, G1 position)
    -> o_proj -> dropout.

    Args:
        hidden_states: [B, S, hidden_size] residual-stream input.
        position_embeddings: (cos, sin) from the rotary embedding; ignored
            on the REPO path, which rebuilds the rotation inline.
        attention_mask: optional mask forwarded to the attention backend.
        first_layer_fan: layer-0 FAN output for the ResFormer-style mix;
            None on the first layer itself.
        attn_analysis: when non-None, detached intermediates are deposited
            on it for interpretability; None during training (no overhead).
        repo_rope_args: (inv_freq, attention_scaling) sourced from
            rotary_emb at forward time, used only when self.use_repo.
        **kwargs: flash-attention specific options, passed through.

    Returns:
        Tuple of (final attention output [B, S, hidden_size],
        attention weights or None — None on non-eager backends,
        this layer's FAN output for downstream ResFormer mixing).
    """
    input_shape = hidden_states.shape[:-1]
    fan_a = attn_analysis.fan if attn_analysis is not None else None
    h_fan = self.fan_layer(hidden_states, analysis=fan_a)
    if first_layer_fan is not None:
        # ResFormer: learnable blend of this layer's FAN output with the
        # first layer's, weighted by the lambda_1 / lambda_2 parameters.
        h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
    # Cloned so the returned tensor cannot alias later reuse of h_fan.
    current_layer_fan = h_fan.clone()
    query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
    kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
    # q_proj emits 2*head_dim per head: first chunk is Q, second chunk is
    # the per-head Gated Attention gate (G1 position, Qiu et al. 2025).
    q_raw, gate = torch.chunk(
        self.q_proj(h_fan).view(*input_shape, self.config.num_attention_heads, self.head_dim * 2),
        2, dim=-1,
    )
    # Flatten gate to [B, S, H*d_head] so it matches the concatenated
    # attention output right before o_proj.
    gate = gate.reshape(*input_shape, -1)
    if attn_analysis is not None:
        attn_analysis.q_raw = q_raw.detach()
        attn_analysis.gate_raw = gate.detach()
    # Per-head normalization, then [B, S, H, d] -> [B, H, S, d] for SDPA.
    q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
    k = self.k_norm(self.k_proj(h_fan).view(kv_shape)).transpose(1, 2)
    v = self.v_proj(h_fan).view(kv_shape).transpose(1, 2)
    if attn_analysis is not None:
        attn_analysis.q_post_norm = q.detach()
        attn_analysis.k_post_norm = k.detach()
        attn_analysis.v_raw = v.detach()
    cos, sin = position_embeddings
    if self.use_repo:
        # REPO path: f_phi predicts continuous per-head positions from the
        # residual stream, then cos/sin are built inline from those positions
        # so the rotation is differentiable w.r.t. REPOModule parameters.
        # inv_freq and attention_scaling arrive via repo_rope_args, sourced
        # directly from rotary_emb at forward time — no buffer on this module,
        # no meta-tensor issue on lm_eval / to(device) paths.
        # (Li et al., 2026, §3.2 — Eq. 6–7)
        repo_a = attn_analysis.repo if attn_analysis is not None else None
        z = self.repo_module(hidden_states, repo_analysis=repo_a)  # [B, H, S]
        inv_freq, attn_scaling = repo_rope_args
        q, k = _apply_repo_rope(q, k, z, inv_freq, attn_scaling)
    else:
        # Standard path: integer positions pre-computed by NeoLLMModel.
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
    if attn_analysis is not None:
        attn_analysis.q_post_rope = q.detach()
        attn_analysis.k_post_rope = k.detach()
    # Optional attention operators; each is a no-op when its flag is off
    # (behavior per the helper implementations elsewhere in this file).
    q, k = self._apply_momentum_attention(q, k, attn_analysis=attn_analysis)
    k, v = self._apply_mea_head_mixing(k, v, attn_analysis=attn_analysis)
    v = self._apply_lucid_preconditioner(k, v, attention_mask, attn_analysis=attn_analysis)
    # Capture v_ref for XSA after MEA mixing and LUCID preconditioning.
    # This is the vector that actually participated in SDPA aggregation.
    v_ref = v if self.use_xsa else None
    # ── Affine-Scaled Attention ───────────────────────────────────────
    # Active whenever use_affine_scaled_attention=True, regardless of
    # attention backend. Two code paths — same math, different execution:
    #   eager     : full weight access, attn_weights_pre/post_affine captured.
    #   flash/sdpa: α·flash_out + β·V_cumsum, no weight tensors materialised.
    alpha = None
    beta = None
    use_affine = self.use_affine_scaled_attention
    if use_affine:
        # α bounded to [0,1] per head/query via linear_clipping.
        alpha = linear_clipping(self.alpha_proj(hidden_states))  # [B, S, H]
        alpha = alpha.permute(0, 2, 1).unsqueeze(-1)  # [B, H, S, 1]
        N = k.shape[-2]
        # β is the moving-average anti-collapse bias, scaled by 1/N
        # (max(N, 1) guards the degenerate zero-length sequence).
        beta = (self.alpha_ma.to(alpha.dtype) - alpha) / max(N, 1)
        if self.training:
            # EMA of the batch-mean α, updated outside autograd.
            with torch.no_grad():
                batch_mean = alpha.mean(dim=(0, 2), keepdim=True)
                self.alpha_ma.copy_(
                    self.affine_momentum * self.alpha_ma
                    + (1.0 - self.affine_momentum) * batch_mean
                )
        if attn_analysis is not None:
            attn_analysis.alpha_per_head = alpha.detach()
            attn_analysis.beta_per_head = beta.detach()
            attn_analysis.alpha_moving_avg = self.alpha_ma.detach()
    if use_affine:
        if self.config._attn_implementation == "eager":
            # Eager: materialises softmax weights, full analysis available.
            attn_out, attn_weights = affine_scaled_eager_attention_forward(
                self, q, k, v, attention_mask,
                scaling=self.scaling, alpha=alpha, beta=beta,
                dropout=0.0 if not self.training else self.attention_dropout,
                attn_analysis=attn_analysis,
                **kwargs,
            )
        else:
            # Flash / SDPA: exact formula via V.cumsum, no weight tensors.
            # attn_weights_pre/post_affine remain None in AnalysisState.
            attn_out, attn_weights = affine_scaled_flash_attention_forward(
                self, q, k, v, attention_mask,
                scaling=self.scaling, alpha=alpha, beta=beta,
                dropout=0.0 if not self.training else self.attention_dropout,
                attn_analysis=attn_analysis,
                **kwargs,
            )
    else:
        # Plain attention: dispatch to the configured HF backend.
        if self.config._attn_implementation == "eager":
            attn_fn = eager_attention_forward
        else:
            attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_out, attn_weights = attn_fn(
            self, q, k, v, attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling, **kwargs,
        )
    if attn_analysis is not None:
        # attn_weights is None on flash/sdpa backends; preserve that.
        attn_analysis.attn_weights = (
            attn_weights.detach() if attn_weights is not None else None
        )
    if attn_analysis is not None:
        attn_analysis.attn_output_raw = attn_out.detach()
    # Re-split the last dim into heads: presumably [B, S, H, d_head]
    # (backend outputs are sequence-major) — layout assumed, verify
    # against the attention helpers if touched.
    attn_out = attn_out.reshape(*input_shape, -1, self.head_dim)
    if self.use_mea_attention:
        attn_out = self.mea_output_norm(attn_out)
        if attn_analysis is not None:
            attn_analysis.attn_output_post_mea_norm = attn_out.detach()
    # ── Exclusive Self Attention (position B, pre-routing) ────────────
    # Removes auto-position component before directional routing so that
    # direction_vecs specialize exclusively in cross-domain interference.
    if self.use_xsa and v_ref is not None:
        v_ref_expanded = repeat_kv(v_ref, self.num_key_value_groups)  # [B, H, S, D]
        v_ref_t = v_ref_expanded.transpose(1, 2)  # [B, S, H, D]
        v_ref_t = v_ref_t.to(attn_out.dtype)
        # Project attn_out onto v_ref and subtract that component
        # (clamped denominator avoids division blow-up on tiny norms).
        proj = (attn_out * v_ref_t).sum(dim=-1, keepdim=True)
        norm_sq = (v_ref_t * v_ref_t).sum(dim=-1, keepdim=True).clamp(min=self.xsa_eps)
        xsa_comp = (proj / norm_sq) * v_ref_t
        attn_out = attn_out - xsa_comp
        if attn_analysis is not None:
            attn_analysis.xsa_self_position_component = xsa_comp.detach()
            attn_analysis.attn_output_post_xsa = attn_out.detach()
    # ── Directional Routing (position C, post-XSA, pre-reshape) ──────
    # Suppresses cross-domain interference directions from the head output.
    # Operates on [B, S, H, d_head] before reshape and o_proj.
    # When use_xsa=False: directions span full head-space (no XSA pre-clean).
    # When use_directional_routing=False: this block is skipped entirely.
    if self.use_directional_routing:
        attn_out = self._apply_directional_routing(
            attn_out, hidden_states, attn_analysis=attn_analysis
        )
    # ── Reshape → o_proj → Gated Attention gate → dropout ────────────
    attn_out_flat = attn_out.reshape(*input_shape, -1).contiguous()
    if attn_analysis is not None:
        attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
        gate_sig = torch.sigmoid(gate)
        attn_analysis.gate_sigmoid = gate_sig.detach()
        gated = attn_out_flat * gate_sig
        if self.use_hadamard_o_proj:
            # Pass HadamardAnalysis sub-object so post_fwht and alpha are captured
            attn_out_gated = self.o_proj(gated, analysis=attn_analysis.hadamard)
        else:
            attn_out_gated = self.o_proj(gated)
    else:
        # Fast path (no analysis): same gate math without deposits.
        attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))
    attn_out_gated = self.dropout(attn_out_gated)
    if attn_analysis is not None:
        attn_analysis.attn_output_final = attn_out_gated.detach()
    return attn_out_gated, attn_weights, current_layer_fan
class PolyNorm(nn.Module):
    """Polynomial normalization activation.

    Mixes three RMS-normalized polynomial branches of the input — x, x^2
    and x^3 — with learnable weights plus a scalar bias. When ``exclusive``
    is enabled, the two higher-order branches are partially orthogonalized
    against the linear branch, with learnable strengths kept in (0, 1) via
    a sigmoid over stored logits.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        proj_eps: float = 1e-6,
        exclusive_init: float = 0.5,
        exclusive: bool = True,
    ):
        super().__init__()
        # Equal mixing of the three branches at init; scalar bias starts at 0.
        self.weight = nn.Parameter(torch.ones(3) / 3)
        self.bias = nn.Parameter(torch.zeros(1))
        self.eps = eps
        self.exclusive = exclusive
        if exclusive:
            self.proj_eps = proj_eps
            # Two learnable exclusion strengths, one per higher-order branch,
            # stored as logits so sigmoid keeps each alpha inside (0, 1).
            # The init value is clamped to the open interval so logit is finite.
            clamped = float(min(max(exclusive_init, 1e-4), 1.0 - 1e-4))
            seed = torch.full((2,), clamped, dtype=torch.float32)
            self.exclusive_logits = nn.Parameter(torch.logit(seed))

    def _norm(self, x):
        # RMS normalization along the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def _exclusive(self, branch, ref, alpha, x1_f, ref_norm_sq):
        """Strip the ``ref``-aligned component out of ``branch``.

        Computes ``_norm(branch - alpha * proj_ref(branch))`` where the
        projection coefficient is evaluated in fp32 for stability. The
        squared norm ``ref_norm_sq`` is supplied precomputed so the two
        calls per forward (for x2 and x3) share one denominator.
        """
        overlap = (branch.float() * x1_f).sum(dim=-1, keepdim=True)
        coeff = (overlap / ref_norm_sq).to(branch.dtype)
        cleaned = branch - alpha.to(branch.dtype) * coeff * ref
        return self._norm(cleaned)

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[PolyNormAnalysis] = None,
    ) -> torch.Tensor:
        # Power cache: the cube reuses the square, avoiding pow(3).
        squared = x.pow(2)
        cubed = x * squared
        # Three RMS-normalized polynomial branches.
        x1 = self._norm(x)
        x2 = self._norm(squared)
        x3 = self._norm(cubed)
        if analysis is not None:
            analysis.x1 = x1.detach()
            analysis.x2_pre_exclusive = x2.detach()
            analysis.x3_pre_exclusive = x3.detach()
        if self.exclusive:
            # Learnable exclusion strengths in (0, 1).
            alpha2, alpha3 = torch.sigmoid(self.exclusive_logits).unbind()
            if analysis is not None:
                analysis.alpha2 = alpha2.detach()
                analysis.alpha3 = alpha3.detach()
                analysis.weights = self.weight.detach()
                analysis.bias = self.bias.detach()
            # fp32 reference and its squared norm — shared by both branches.
            x1_f = x1.float()
            ref_norm_sq = x1_f.pow(2).sum(-1, keepdim=True).clamp_min(self.proj_eps)
            # Partial orthogonalization of the higher-order branches vs x1.
            x2 = self._exclusive(x2, x1, alpha2, x1_f, ref_norm_sq)
            x3 = self._exclusive(x3, x1, alpha3, x1_f, ref_norm_sq)
            if analysis is not None:
                analysis.x2_post_exclusive = x2.detach()
                analysis.x3_post_exclusive = x3.detach()
        elif analysis is not None:
            analysis.weights = self.weight.detach()
            analysis.bias = self.bias.detach()
        output = (
            self.weight[0] * x3
            + self.weight[1] * x2
            + self.weight[2] * x1
            + self.bias
        )
        if analysis is not None:
            analysis.output = output.detach()
        return output
def compute_versatile_aux_loss(
    aux_stats_list: list,
    n_experts: int,
    weight: float,
) -> torch.Tensor:
    """
    Load-balancing auxiliary loss for VersatileFFN width-path expert routing.
    Identical formula to JTok-M:  L_aux = λ · n_e · Σ_i p_i · f_i
    averaged over all decoder layers with active VersatileFFN.

    Args:
        aux_stats_list: list of (p_sum [n_e], f_sum [n_e], N_tokens) per layer.
        n_experts: total number of virtual experts.
        weight: λ coefficient.
    Returns:
        Scalar loss tensor (0.0 when no layer contributed stats).
    """
    if not aux_stats_list:
        return torch.tensor(0.0)
    # One loss term per layer: λ · n_e · Σ_i (p_sum_i / N) · f_i
    per_layer = [
        weight * n_experts * ((p_sum / N) * f_sum).sum()
        for p_sum, f_sum, N in aux_stats_list
    ]
    total = per_layer[0]
    for term in per_layer[1:]:
        total = total + term
    # Average over contributing layers.
    return total / len(aux_stats_list)
class VersatileFFN(nn.Module):
    """
    VersatileFFN: dual-process feed-forward network with parameter reuse.

    Drop-in replacement for NeoLLMMLP. Shares the same weight matrices but
    reuses them across two complementary paths (Nie et al., 2026):

    Width-Versatile path (virtual MoE, Eq. 7–8):
        Derives N virtual sub-experts by slicing non-overlapping contiguous
        segments of the intermediate dimension. A Top-K router selects
        ``active_experts`` experts per token. No additional parameters beyond
        the router ``expert_gate [hidden, N_experts]``.

    Depth-Versatile path (recursive, Eq. 9–11):
        Applies the full shared MLP (inner SeeDNorm + FANLayer + gate/up/down)
        recursively up to ``max_depth`` times. A Gumbel-Softmax loop predictor
        (``depth_predictor [hidden, max_depth]``) decides per-token depth.
        During training: always L_max iterations with soft STE weighting.
        During inference: early exit at argmax(p) to save FLOPs.

    Difficulty-aware fusion (Eq. 12–13):
        λ = (L_max − E[L]) / L_max ∈ [0, 1)
        output = λ · Y_width + (1 − λ) · Y_depth
        Easy tokens (low E[L] → λ → 1) favour the fast width path.
        Hard tokens (high E[L] → λ → 0) favour the deep recursive path.

    NeoLLM-specific adaptations vs. the OLMo reference implementation:
    - FANLayer runs once per forward (shared between both paths).
    - Expert slicing uses contiguous segments (no SwiGLU half-shift)
      because gate and up are separate projections here.
    - PolyNorm (shared parameters) is used as the activation in both paths.
    - Multipliers from LinearWithMultipliers are applied per-slice in the
      width path to correctly replicate the full MLP forward.
    - Gumbel temperature stored as a persistent float32 buffer so it
      survives checkpoint save/load without external tracking.
    - Width path load-balancing returns (output, aux_stats) for integration
      with the existing NeoLLMForCausalLM aux-loss accumulation pattern.

    Width dispatch (CUDAGraph-compatible sparse routing):
        The paper's original dispatch (torch.where + index_add_) is sparse
        and faithful to the paper but produces data-dependent shapes →
        incompatible with CUDAGraphs. This implementation uses argsort as a
        static dispatcher:
            flat_expert [N·K] → argsort → perm [N·K]  (shape always equal)
            sorted_tok  [N·K] = flat_tok[perm]        (original token indices)
            grouped_tok [E, C] = sorted_tok.view(E, C)  (C = N·K // E, constant)
        Key properties:
            · argsort: output shape = input shape, always [N·K]. CUDAGraph ✓
            · C = N_tok·K // E is a Python int known at compile time.
              With the aux loss maintaining balance, each expert receives ≈ C slots.
            · scatter_add_ with a static-shape index [C, D]: CUDAGraph ✓
              (the index VALUES change per batch, not the SHAPE).
            · FLOPs identical to the original: each expert processes
              [C, D] = [N·K/E, D] tokens, not all N tokens.
              With K=2, E=4: C = N/2 per expert.

    Conditional Parallelism (inference, Algorithm 2):
        · Tokens with λ=0 (argmax → max_depth) still participate in the
          grouped buffer and their expert forward is computed (static shapes).
        · Their contribution is cancelled by λ=0 in the fusion:
              output = x_depth·(1−λ) + x_moe·λ  →  x_depth when λ=0
        · This forfeits the FLOP saving for λ=0 tokens, but the math stays
          exact. Acceptable tradeoff vs CUDAGraph incompatibility.

    Discrete Early-Exit (inference, Algorithm 2):
        · Replaced by always-max_depth + torch.gather with loop_choice.
          For max_depth=2 the overhead is ≤ 1 extra iteration per token.

    Reference:
        Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in
        LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531.
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__()
        self.total_experts = getattr(config, "versatile_total_experts", 8)
        self.active_experts = getattr(config, "versatile_active_experts", 2)
        self.max_depth = getattr(config, "versatile_max_depth", 4)
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # ── Shared MLP weights (identical layout to NeoLLMMLP) ──────────────
        fan_ratio = getattr(config, "fan_ratio_ffn", 0.0625)
        fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio)
        self.fan_layer = FANLayer(hidden_size=config.hidden_size, fan_ratio=fan_ratio)
        self.gate_proj = LinearWithMultipliers(
            fan_dim, config.intermediate_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=False,
        )
        self.up_proj = nn.Linear(fan_dim, config.intermediate_size, bias=False)
        self.down_proj = LinearWithMultipliers(
            config.intermediate_size, config.hidden_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=True,
        )
        self.act_fn = PolyNorm(exclusive=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        # ── Inner normalization for depth-recursive steps ────────────────────
        # Applied before each recursive MLP application (mirrors the paper's
        # ff_norm inside each depth loop).
        self.ff_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        # ── Width path: expert router + contiguous segment indices ──────────
        seg = config.intermediate_size // self.total_experts
        self.expert_segment = seg
        self.expert_gate = nn.Linear(config.hidden_size, self.total_experts, bias=False)
        # Static expert slice indices [total_experts, seg] — non-persistent
        # buffer so .to(device) moves them automatically.
        idx_list = [
            torch.arange(i * seg, (i + 1) * seg)
            for i in range(self.total_experts)
        ]
        self.register_buffer("expert_idx", torch.stack(idx_list), persistent=False)
        # ── Depth path: loop count predictor ────────────────────────────────
        self.depth_predictor = nn.Linear(config.hidden_size, self.max_depth, bias=False)
        # Gumbel temperature as a persistent scalar buffer.
        # Decays externally via update_gumbel_temperature() each training step.
        temp_start = getattr(config, "versatile_gumbel_temp_start", 5.0)
        self.register_buffer(
            "gumbel_temp",
            torch.tensor(temp_start, dtype=torch.float32),
            persistent=True,
        )
        self._gumbel_temp_end = getattr(config, "versatile_gumbel_temp_end", 0.1)
        self._gumbel_temp_decay = getattr(config, "versatile_gumbel_temp_decay", 0.99984)

    # ── Public API ────────────────────────────────────────────────────────────
    def update_gumbel_temperature(self) -> None:
        """Decay Gumbel temperature by one step. Call once per training step."""
        new_t = max(
            self.gumbel_temp.item() * self._gumbel_temp_decay,
            self._gumbel_temp_end,
        )
        self.gumbel_temp.fill_(new_t)

    # ── Private helpers ───────────────────────────────────────────────────────
    def _expert_forward(
        self,
        x_fan: torch.Tensor,  # [N, fan_dim]
        x_in: torch.Tensor,   # [N, hidden_size] — residual base
        idx: torch.Tensor,    # [seg] — which intermediate neurons to use
    ) -> torch.Tensor:
        """
        Virtual expert forward for a single expert (Eq. 6–8 of paper).

        Slices gate_proj, up_proj, down_proj weights to the expert segment.
        Multipliers from LinearWithMultipliers are applied per-slice to keep
        the computation exactly equivalent to the full MLP forward.
        Returns x_in + expert_out (residual included, matching Eq. 8).
        """
        # Gate projection: sliced rows + row multiplier
        g_w = self.gate_proj.linear.weight[idx]               # [seg, fan_dim]
        g_mul = self.gate_proj.row_multiplier.multiplier[idx]  # [seg]
        gate = F.linear(x_fan, g_w) * g_mul                    # [N, seg]
        # Up projection: sliced rows, no multiplier on up_proj
        u_w = self.up_proj.weight[idx]                         # [seg, fan_dim]
        up = F.linear(x_fan, u_w)                              # [N, seg]
        # Activation — PolyNorm is shape-agnostic (weight [3], bias [1])
        act = self.act_fn(gate) * up                           # [N, seg]
        act = self.dropout(act)
        # Down projection: column multiplier on input, then row on output
        col_mul = self.down_proj.column_multiplier.multiplier[idx]  # [seg]
        d_w = self.down_proj.linear.weight[:, idx]                  # [hidden, seg]
        row_mul = self.down_proj.row_multiplier.multiplier          # [hidden]
        out = F.linear(act * col_mul, d_w) * row_mul                # [N, hidden]
        return x_in + out

    def _full_forward_step(self, x: torch.Tensor) -> torch.Tensor:
        """
        One recursive MLP application for the depth path (Eq. 9).

        Includes inner SeeDNorm + FANLayer so each iteration starts from a
        normalized, periodicity-augmented view of the current hidden state.
        Residual connection applied inside to match the paper's formulation.
        """
        og = x
        x = self.ff_norm(x)
        x_f = self.fan_layer(x)
        gate = self.gate_proj(x_f)
        up = self.up_proj(x_f)
        act = self.act_fn(gate) * up
        act = self.dropout(act)
        return og + self.down_proj(act)

    # ── Forward ───────────────────────────────────────────────────────────────
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[MLPAnalysis] = None,
    ) -> Tuple[torch.Tensor, Optional[tuple]]:
        """
        Args:
            x: [B, S, hidden_size] — pre-normalized input
               (SeeDNorm + LNS already applied by NeoLLMDecoderLayer).
            analysis: MLPAnalysis with .versatile pre-allocated when analysis
                mode is active. None during training (zero overhead).
        Returns:
            (output [B, S, hidden_size], aux_stats)
            aux_stats = (p_sum [N_experts], f_sum [N_experts], N_tokens)
            for load-balancing loss, or None in inference if width path skipped.
        """
        B, S, D = x.shape
        # Depth predictor reads difficulty from the pre-normalized hidden state
        depth_logits = self.depth_predictor(x)  # [B, S, max_depth]
        # FANLayer runs once — output shared between width and depth paths
        x_fan = self.fan_layer(x)  # [B, S, fan_dim]
        # ═════════════════════ TRAINING ══════════════════════════════════════
        if self.training:
            # Gumbel-Softmax with hard=True (STE): discrete forward, continuous
            # backward. Annealed temperature controls sharpness of the selection.
            depth_probs = F.gumbel_softmax(
                depth_logits,
                tau=float(self.gumbel_temp),
                hard=True,
                dim=-1,
            )  # [B, S, max_depth]
            # ── Depth path: always L_max iterations (static graph for compile) ─
            depth_outputs = []
            current_x = x
            for _ in range(self.max_depth):
                current_x = self._full_forward_step(current_x)
                depth_outputs.append(current_x)
            # Soft weighted combination — gradient flows through depth_probs
            depth_stack = torch.stack(depth_outputs, dim=-1)  # [B,S,D,L]
            x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1)  # [B,S,D]
            # ── Width path: argsort-based sparse dispatch (Eq. 7–8) ──────────
            # Math (paper §3.2):
            #   Y_width = Σ_{k∈TopK} g_k · Y_k,
            #   Y_k = H + W_out^(k) φ(W_proj^(k) LayerNorm(H))   (Eq. 8)
            # Since Σ_{k∈TopK} g_k = 1 (softmax normalized over TopK):
            #   Y_width = H + Σ_{k∈TopK} g_k · delta_k
            #
            # Sparse implementation with static shapes:
            #
            # 1. flat_expert [N_tok·K]: expert index per token-slot.
            #    argsort → perm [N_tok·K], shape always equal. CUDAGraph ✓
            #
            # 2. sorted_tok [N_tok·K] = flat_tok[perm]: tokens ordered by
            #    expert. All tokens of expert e become contiguous.
            #
            # 3. view(E, C) with C = N_tok·K // E a Python constant → static
            #    shape [E, C, D] for gather and forward.
            #
            # 4. _expert_forward over [C, D] per expert — same FLOPs as the
            #    torch.where original: only C tokens per expert, not the
            #    full N_tok. With K=2, E=4: C = N_tok/2.
            #
            # 5. scatter_add_: index of shape [C, D], always static.
            #    The VALUES vary per batch, the SHAPE does not. CUDAGraph ✓
            #    Accumulates Σ_k g_k · Y_k for each token n via the sum
            #    over the K slots that point to n.
            K = self.active_experts
            E = self.total_experts
            N_tok = B * S
            # NOTE(review): the view(E, C) below assumes (N_tok · K) % E == 0
            # — confirm upstream batch/sequence sizes uphold this.
            C = (N_tok * K) // E  # tokens per expert — compile-time constant
            routing_logits = self.expert_gate(x)  # [B, S, E]
            topk_w, topk_i = torch.topk(routing_logits, k=K, dim=-1)
            topk_w = torch.softmax(topk_w, dim=-1)  # [B, S, K]
            x_flat = x.reshape(-1, D)  # [N_tok, D]
            x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1])  # [N_tok, fan_dim]
            # Flatten: each token appears K times, once per selected expert
            flat_expert = topk_i.reshape(-1)  # [N_tok·K] long
            flat_tok = (
                torch.arange(N_tok, device=x.device, dtype=torch.long)
                .unsqueeze(1).expand(N_tok, K).reshape(-1)
            )  # [N_tok·K] long
            flat_w = topk_w.reshape(-1)  # [N_tok·K]
            # Sort by expert ID: all tokens of the same expert end up together
            perm = torch.argsort(flat_expert, stable=True)  # [N_tok·K] long
            sorted_tok = flat_tok[perm]  # [N_tok·K] long
            sorted_w = flat_w[perm]  # [N_tok·K]
            # Group by expert [E, C] — C known at compile time
            grouped_tok = sorted_tok.view(E, C)  # [E, C] long
            grouped_w = sorted_w.view(E, C)  # [E, C]
            # Gather the original token's features for each expert slot
            flat_idx = grouped_tok.reshape(-1)  # [E·C] long
            fan_dim = x_fan_flat.shape[-1]
            x_grouped = x_flat[flat_idx].view(E, C, D)  # [E, C, D]
            xf_grouped = x_fan_flat[flat_idx].view(E, C, fan_dim)  # [E, C, fan_dim]
            # Expert forward + scatter_add_ back into [N_tok, D]
            # Loop unrolled by dynamo (E is a Python constant) — no graph breaks
            x_moe_flat = torch.zeros(N_tok, D, device=x.device, dtype=x.dtype)
            for eid in range(E):
                # out_e [C, D] = x_grouped[eid] + delta_e (residual included, Eq. 8)
                out_e = self._expert_forward(
                    xf_grouped[eid], x_grouped[eid], self.expert_idx[eid]
                )
                w_e = grouped_w[eid].unsqueeze(-1)  # [C, 1]
                tok_idx_e = grouped_tok[eid].unsqueeze(1).expand(C, D)  # [C, D] long
                # Accumulate g_k · Y_k at the token's original position.
                # As eid sweeps the K experts of a token n:
                #   x_moe_flat[n] = Σ_k g_k · Y_k = H_n + Σ_k g_k · delta_k
                x_moe_flat.scatter_add_(
                    0, tok_idx_e, (out_e * w_e).to(x_moe_flat.dtype)
                )
            x_moe = x_moe_flat.reshape(B, S, D)
            # Load-balancing aux stats (load-balancing loss equation)
            # p_sum: summed routing probability per expert over the full
            #        softmax (the mean is taken later in the aux loss)
            # f_sum: actual fraction of token-slots assigned to each expert
            r_probs_flat = torch.softmax(
                routing_logits.reshape(-1, E), dim=-1
            )  # [N_tok, E]
            p_sum = r_probs_flat.sum(dim=0)  # [E]
            f_sum = (
                F.one_hot(flat_expert.long(), E).float().sum(dim=0)
                / float(N_tok * K)
            )  # [E]
            aux_stats = (p_sum, f_sum, N_tok)
            # ── Difficulty-aware fusion (Eq. 12–13) ──────────────────────────
            loop_idx = torch.arange(
                1, self.max_depth + 1, device=x.device, dtype=depth_probs.dtype
            )
            expected_L = (depth_probs * loop_idx).sum(dim=-1)  # [B, S]
            moe_weight = (self.max_depth - expected_L) / self.max_depth  # [B, S]
            output = (
                x_depth * (1.0 - moe_weight.unsqueeze(-1))
                + x_moe * moe_weight.unsqueeze(-1)
            )
            loop_choice = None  # not used during training
        # ═════════════════════ INFERENCE ══════════════════════════════════════
        else:
            loop_choice = depth_logits.argmax(dim=-1)  # [B, S]
            # ── Depth path: always max_depth iterations (static shape) ─
            # ORIGINAL PROBLEM: the original early-exit used
            #     max_loop = int(loop_choice.max().item())
            # which forces a CPU-GPU synchronisation (equivalent to .item())
            # and makes the loop trip count data-dependent — both conditions
            # prohibit CUDAGraph capture.
            #
            # SOLUTION: always execute exactly self.max_depth iterations.
            # depth_stack [B,S,D,max_depth] has a static shape. The gather
            # over loop_choice selects the correct per-token output without
            # needing to know how many iterations actually ran. The FLOP
            # cost of the "extra" iterations is small because max_depth is
            # small (default 2) and _full_forward_step is lightweight.
            depth_outputs = []
            current_x = x
            for _ in range(self.max_depth):
                current_x = self._full_forward_step(current_x)
                depth_outputs.append(current_x)
            depth_stack = torch.stack(depth_outputs, dim=-1)  # [B,S,D,max_depth]
            gather_idx = (
                loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1)
            )
            x_depth = depth_stack.gather(dim=-1, index=gather_idx).squeeze(-1)
            # Fusion weight from discrete choice
            expected_L = (loop_choice + 1).float()  # [B, S]
            moe_weight = (self.max_depth - expected_L) / self.max_depth  # [B, S]
            aux_stats = None
            depth_probs = None
            # ── Width path: argsort-based sparse dispatch (same mechanism
            #    as training, Eq. 7–8 + Conditional Parallelism §A) ───
            #
            # Conditional Parallelism (Algorithm 2 of the paper):
            # If λ=0 for a token → Y = Y_depth and the width path is skipped.
            # With static shapes we cannot dynamically exclude those tokens
            # from the buffer. Instead, λ=0 tokens participate in the grouped
            # buffer and their expert forward runs, but the fusion
            #     output = x_depth·(1−λ) + x_moe·λ
            # guarantees output = x_depth when λ=0, with no conditional
            # branch. The width-path FLOPs for those tokens are the only
            # overhead.
            N_tok_inf = B * S
            K_inf = self.active_experts
            E_inf = self.total_experts
            C_inf = (N_tok_inf * K_inf) // E_inf
            x_flat = x.reshape(-1, D)
            x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1])
            routing_logits = self.expert_gate(x_flat)  # [N_tok, E]
            tw, ti = torch.topk(routing_logits, k=K_inf, dim=-1)
            tw = torch.softmax(tw, dim=-1)  # [N_tok, K]
            flat_expert_i = ti.reshape(-1)  # [N_tok·K] long
            flat_tok_i = (
                torch.arange(N_tok_inf, device=x.device, dtype=torch.long)
                .unsqueeze(1).expand(N_tok_inf, K_inf).reshape(-1)
            )  # [N_tok·K] long
            flat_w_i = tw.reshape(-1)  # [N_tok·K]
            perm_i = torch.argsort(flat_expert_i, stable=True)  # [N_tok·K]
            sorted_tok_i = flat_tok_i[perm_i]  # [N_tok·K]
            sorted_w_i = flat_w_i[perm_i]  # [N_tok·K]
            grouped_tok_i = sorted_tok_i.view(E_inf, C_inf)  # [E, C]
            grouped_w_i = sorted_w_i.view(E_inf, C_inf)  # [E, C]
            flat_idx_i = grouped_tok_i.reshape(-1)  # [E·C]
            fan_dim_i = x_fan_flat.shape[-1]
            x_grouped_i = x_flat[flat_idx_i].view(E_inf, C_inf, D)  # [E, C, D]
            xf_grouped_i = x_fan_flat[flat_idx_i].view(E_inf, C_inf, fan_dim_i)  # [E, C, fan_dim]
            x_moe_flat_i = torch.zeros(N_tok_inf, D, device=x.device, dtype=x.dtype)
            for eid in range(E_inf):
                out_e_i = self._expert_forward(
                    xf_grouped_i[eid], x_grouped_i[eid], self.expert_idx[eid]
                )
                w_e_i = grouped_w_i[eid].unsqueeze(-1)  # [C, 1]
                tok_idx_e_i = grouped_tok_i[eid].unsqueeze(1).expand(C_inf, D)
                x_moe_flat_i.scatter_add_(
                    0, tok_idx_e_i, (out_e_i * w_e_i).to(x_moe_flat_i.dtype)
                )
            x_moe = x_moe_flat_i.reshape(B, S, D)
            output = (
                x_depth * (1.0 - moe_weight.unsqueeze(-1))
                + x_moe * moe_weight.unsqueeze(-1)
            )
        # ── Analysis deposits ─────────────────────────────────────────────────
        if analysis is not None:
            if analysis.versatile is not None:
                va = analysis.versatile
                va.depth_probs = depth_probs.detach() if depth_probs is not None else None
                va.expected_loops = expected_L.detach()
                va.moe_weight = moe_weight.detach()
                va.loop_choice = loop_choice.detach() if loop_choice is not None else None
                va.x_depth = x_depth.detach()
                va.x_width = x_moe.detach()
            analysis.output = output.detach()
        return output, aux_stats
class NeoLLMMLP(nn.Module):
    """MLP with FANformer integration and Learnable Multipliers.

    Gated-MLP layout: FANLayer expands the hidden state with a periodic
    component, gate/up projections run over the expanded features, PolyNorm
    activates the gate, and a multiplier-augmented down projection maps
    back to hidden size.
    """

    def __init__(self, config):
        super().__init__()
        # Single lookup of the FFN FAN ratio, reused for both the layer
        # and the expanded feature dimension.
        fan_ratio = getattr(config, "fan_ratio_ffn", 0.0625)
        fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio)
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=fan_ratio,
        )
        self.gate_proj = LinearWithMultipliers(
            fan_dim, config.intermediate_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=False,
        )
        self.up_proj = nn.Linear(fan_dim, config.intermediate_size, bias=False)
        self.down_proj = LinearWithMultipliers(
            config.intermediate_size, config.hidden_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=True,
        )
        self.act_fn = PolyNorm(exclusive_init=0.00, exclusive=getattr(config, "polynorm_exclusive", True))
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[MLPAnalysis] = None,
    ) -> torch.Tensor:
        """Run the gated MLP; deposit intermediates when analysis is set."""
        fan_a = analysis.fan if analysis is not None else None
        periodic = self.fan_layer(x, analysis=fan_a)
        gate = self.gate_proj(periodic)
        up = self.up_proj(periodic)
        if analysis is not None:
            analysis.gate_proj_output = gate.detach()
            analysis.up_proj_output = up.detach()
        poly_a = analysis.polynorm if analysis is not None else None
        activated = self.act_fn(gate, analysis=poly_a) * up
        if analysis is not None:
            analysis.act_times_up = activated.detach()
        out = self.down_proj(self.dropout(activated))
        if analysis is not None:
            analysis.output = out.detach()
        return out
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder layer with standard residual connections, optional JTok-M injection.

    Flow (JTok-M active):
        1. SeeDNorm → LNS(1/√ℓ) → Attention → residual + GPAS
        2. [capture h̃ = hidden after attention for JTok-M router]
        3. SeeDNorm → LNS(1/√ℓ) → MLP → Δm
        4. h^{ℓ+1} = h̃ + Δm + Δr   (Δr from JTok-M, scaled 1/√(2ℓ))
        5. GPAS

    LNS coordination:
        LNS factor: 1/√ℓ
        JTok-M factor: 1/√(2ℓ) → ratio = 1/√2 constant at all depths.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        """Build the sublayers plus optional JTok-M / AttnRes / LAuReL parameters."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.use_jtokm = config.use_jtokm
        self.self_attn = NeoLLMAttention(config, layer_idx)
        self.mlp = (
            VersatileFFN(config)
            if getattr(config, "use_versatile_ffn", False)
            else NeoLLMMLP(config)
        )
        self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
        self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)
        # Refreshed on every forward with the attention FAN features
        # (presumably consumed externally for first-layer FAN reuse — see
        # NeoLLMModel.first_layer_fan; TODO confirm against the model forward).
        self.current_layer_fan = None
        if self.use_jtokm:
            self.jtokm = LeviathanJTokM(config, layer_idx)
        else:
            self.jtokm = None
        # ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
        # Replaces fixed residual accumulation with learned softmax attention
        # over preceding layer outputs. Each decoder layer has two learnable
        # pseudo-queries — one for pre-attention and one for pre-MLP — plus a
        # shared RMSNorm applied to keys to prevent magnitude-dominated softmax.
        # Pseudo-queries are initialized to ZERO so AttnRes starts as uniform
        # average (equivalent to standard residual mean) and training volatility
        # is avoided. This is critical per the paper's ablation.
        self.use_attn_res = getattr(config, 'use_attn_res', False)
        if self.use_attn_res:
            self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
            self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
            self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.attn_res_query_attn = None
            self.attn_res_query_mlp = None
            self.attn_res_norm = None
        # ── LAuReL: Learned Augmented Residual Layer (Menghani et al., ICML 2025) ─
        # Generalises the canonical residual connection with learned scalar
        # weights (RW) and/or a low-rank linear correction (LR). Applied
        # independently to the attention and MLP sublayers (two residual
        # junctions per decoder layer).
        #
        # LAUREL-RW (use_laurel_rw):
        #   raw scalars α̃, β̃ → softmax([α̃, β̃]) = (α, β) bounded in (0,1)
        #   Residual becomes: α·f(x) + β·x
        #
        # LAUREL-LR (use_laurel_lr):
        #   A: nn.Linear(D→r, bias=False) initialised column-orthogonal
        #   B: nn.Linear(r→D, bias=False) initialised to zero
        #   Residual becomes: f(x) + B(A(x)) + x
        #
        # LAUREL-RW+LR (both active, paper eq. 5):
        #   Residual becomes: α·f(x) + β·(B(A(x)) + x)
        #
        # Mutex with use_attn_res is enforced at config validation time.
        self.use_laurel = getattr(config, 'use_laurel', False)
        self.use_laurel_rw = getattr(config, 'use_laurel_rw', True)
        self.use_laurel_lr = getattr(config, 'use_laurel_lr', False)
        D = config.hidden_size
        r = getattr(config, 'laurel_lr_rank', 32)
        if self.use_laurel and self.use_laurel_rw:
            # Two raw scalars per sublayer; softmax-normalised in forward.
            # Initialised to [0, 0] → softmax → (0.5, 0.5) at step 0,
            # matching a standard equal-weight residual as the starting point.
            self.laurel_rw_attn = nn.Parameter(torch.zeros(2))  # [α̃_attn, β̃_attn]
            self.laurel_rw_mlp = nn.Parameter(torch.zeros(2))   # [α̃_mlp, β̃_mlp]
        else:
            self.laurel_rw_attn = None
            self.laurel_rw_mlp = None
        if self.use_laurel and self.use_laurel_lr:
            # Attention sublayer low-rank matrices.
            # A: D×r (projects D→r), initialised column-orthogonal per §3.3.
            # B: r×D (projects r→D), initialised to zero → identity start.
            self.laurel_lr_A_attn = nn.Linear(D, r, bias=False)
            self.laurel_lr_B_attn = nn.Linear(r, D, bias=False)
            # MLP sublayer low-rank matrices (independent capacity).
            self.laurel_lr_A_mlp = nn.Linear(D, r, bias=False)
            self.laurel_lr_B_mlp = nn.Linear(r, D, bias=False)
            # Initialise: B→zero, A→column-orthogonal (paper footnote 2):
            #   A_{i,j} = 1/√(rD) if i mod r == j else 0
            # Vectorised equivalent of the per-entry loop: for every input
            # dimension i, column i of weight [r, D] gets 1/√(rD) in row
            # (i mod r). One O(D) tensor write instead of O(r·D) Python
            # iterations with scalar __setitem__ per matrix.
            cols = torch.arange(D)
            for A_mat in (self.laurel_lr_A_attn, self.laurel_lr_A_mlp):
                nn.init.zeros_(A_mat.weight)
                A_mat.weight.data[cols % r, cols] = 1.0 / (r * D) ** 0.5
            for B_mat in (self.laurel_lr_B_attn, self.laurel_lr_B_mlp):
                nn.init.zeros_(B_mat.weight)
        else:
            self.laurel_lr_A_attn = None
            self.laurel_lr_B_attn = None
            self.laurel_lr_A_mlp = None
            self.laurel_lr_B_mlp = None

    def _attn_res(
        self,
        sources: list,
        partial: torch.Tensor,
        query: torch.Tensor,
        analysis_slot: Optional[AttnResAnalysis] = None,
        analysis_key: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Depth-wise softmax attention over preceding layer outputs.
        Computes:
            V = stack(sources + [partial])   [N+1, B, S, D]
            K = RMSNorm(V)                   [N+1, B, S, D]
            logits = query · K               [N+1, B, S]
            weights = softmax(logits, dim=0) [N+1, B, S]
            h = Σ_n weights_n · V_n          [B, S, D]
        The pseudo-query is shared across positions (per the paper design).
        RMSNorm on keys prevents layers with large-magnitude outputs from
        dominating the softmax. Initialized to zero → uniform weights at
        step 0, reducing to standard residual mean.
        Args:
            sources: list of [B, S, D] tensors — completed block summaries
                or all previous layer outputs (Full AttnRes).
            partial: [B, S, D] — current intra-block partial sum.
            query: [D] — learnable pseudo-query for this sublayer.
            analysis_slot / analysis_key: accepted for call-site symmetry but
                currently unused in this method (no diagnostics are recorded).
        Returns:
            [B, S, D] — weighted combination of sources + partial.
        """
        all_v = sources + [partial]                         # list of [B, S, D]
        V = torch.stack(all_v, dim=0)                       # [N+1, B, S, D]
        K = self.attn_res_norm(V)                           # [N+1, B, S, D]
        logits = torch.einsum('d,nbsd->nbs', query, K)      # [N+1, B, S]
        weights = torch.softmax(logits, dim=0)              # [N+1, B, S]
        return torch.einsum('nbs,nbsd->bsd', weights, V)    # [B, S, D]

    def _laurel_residual(
        self,
        residual: torch.Tensor,
        delta: torch.Tensor,
        rw_param: Optional[torch.Tensor],
        A_mat,
        B_mat,
        analysis: Optional["LAuReLAnalysis"] = None,
        slot: str = "attn",
    ) -> torch.Tensor:
        """
        Computes the LAuReL-augmented residual junction (Menghani et al., ICML 2025).
        Dispatches among three regimes depending on which sub-variants are active:
        LAUREL-RW only (use_laurel_rw=True, use_laurel_lr=False):
            α, β = softmax([α̃, β̃])
            out = α · delta + β · residual            (paper §2.1)
        LAUREL-LR only (use_laurel_rw=False, use_laurel_lr=True):
            lr_delta = B(A(residual))
            out = delta + lr_delta + residual         (paper eq. 3)
        LAUREL-RW+LR (both active, paper eq. 5):
            α, β = softmax([α̃, β̃])
            lr_delta = B(A(residual))
            out = α · delta + β · (lr_delta + residual)
        In all cases the standard residual identity (out = delta + residual) is
        recovered at initialisation: RW starts at (α=0.5, β=0.5); LR starts
        with B=0 so lr_delta=0.
        Args:
            residual: [B, S, D] — accumulated residual stream (x_i in the paper).
            delta: [B, S, D] — sublayer output f(x_i) (attention or MLP).
            rw_param: Parameter[2] — raw [α̃, β̃] before softmax; None if RW off.
            A_mat: nn.Linear(D→r, bias=False) — LR down-proj; None if LR off.
            B_mat: nn.Linear(r→D, bias=False) — LR up-proj; None if LR off.
            analysis: LAuReLAnalysis slot to deposit scalar diagnostics into.
            slot: "attn" or "mlp" — selects which analysis fields to write.
        Returns:
            [B, S, D] — augmented residual output for the current sublayer.
        """
        has_rw = rw_param is not None
        has_lr = A_mat is not None
        # ── LAUREL-LR: low-rank residual correction ────────────────────────
        if has_lr:
            lr_delta = B_mat(A_mat(residual))  # [B, S, D]
            if analysis is not None and not self.training:
                in_norm = residual.detach().norm(dim=-1).mean().item()
                lr_norm = lr_delta.detach().norm(dim=-1).mean().item()
                if slot == "attn":
                    analysis.lr_input_norm_attn = in_norm
                    analysis.lr_delta_norm_attn = lr_norm
                else:
                    analysis.lr_input_norm_mlp = in_norm
                    analysis.lr_delta_norm_mlp = lr_norm
        else:
            lr_delta = None
        # ── LAUREL-RW: learned scalar gate ────────────────────────────────
        if has_rw:
            # Softmax in fp32 for numerical stability, cast back afterwards.
            ab = torch.softmax(rw_param.float(), dim=0).to(residual.dtype)
            alpha = ab[0]
            beta = ab[1]
            if analysis is not None and not self.training:
                if slot == "attn":
                    analysis.alpha_attn = alpha.item()
                    analysis.beta_attn = beta.item()
                else:
                    analysis.alpha_mlp = alpha.item()
                    analysis.beta_mlp = beta.item()
        else:
            alpha = beta = None
        # ── Compose output ─────────────────────────────────────────────────
        if has_rw and has_lr:
            # LAUREL-RW+LR (paper eq. 5): α·f(x) + β·(BAx + x)
            return alpha * delta + beta * (lr_delta + residual)
        elif has_rw:
            # LAUREL-RW (paper §2.1): α·f(x) + β·x
            return alpha * delta + beta * residual
        else:
            # LAUREL-LR (paper eq. 3): f(x) + BAx + x
            return delta + lr_delta + residual

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        z_tilde: Optional[torch.Tensor] = None,
        B_vals: Optional[torch.Tensor] = None,
        attn_res_sources: Optional[list] = None,
        attn_res_partial: Optional[torch.Tensor] = None,
        layer_analysis: Optional[LayerAnalysis] = None,
        output_attentions: Optional[bool] = False,
        repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple:
        """
        Run one decoder layer: attention sublayer then MLP sublayer, each with
        SeeDNorm → LNS pre-normalisation, a configurable residual junction
        (standard / AttnRes / LAuReL) and GPAS post-scaling.

        Args:
            hidden_states: [B, S, D] residual stream entering the layer.
            position_embeddings: (cos, sin) rotary tables for self-attention.
            attention_mask: optional mask forwarded to self-attention.
            first_layer_fan: optional FAN features forwarded to attention.
            z_tilde, B_vals: JTok-M spline inputs; both must be non-None for
                the JTok-M delta to be computed and added.
            attn_res_sources, attn_res_partial: previous-layer outputs and the
                intra-block partial sum driving Attention Residuals.
            layer_analysis: optional diagnostics sink (detached snapshots).
            output_attentions: when True, append attention weights to outputs.
            repo_rope_args: optional Context Re-Positioning RoPE arguments.

        Returns:
            Tuple of (hidden_states[, attn_weights][, jtokm_aux_stats]
            [, versatile_aux]) — optional entries appear only when the
            corresponding feature produced them.
        """
        # ── Snapshot input ────────────────────────────────────────────────
        if layer_analysis is not None:
            layer_analysis.hidden_states_input = hidden_states.detach()
        # ── Attention Residuals: compute pre-attention input ──────────────
        # When active, the input to the attention sublayer is no longer the
        # raw hidden_states (accumulated residual) but a softmax-weighted
        # combination of all previous layer outputs (or block summaries).
        # attn_res_partial carries the intra-block standard residual that
        # connects the attention and MLP sublayers within this layer.
        # When inactive, flow is identical to the original.
        ar_analysis = layer_analysis.attn_res if layer_analysis is not None else None
        if self.use_attn_res and attn_res_sources is not None and attn_res_partial is not None:
            h_attn = self._attn_res(
                attn_res_sources, attn_res_partial, self.attn_res_query_attn,
                ar_analysis, "attn",
            )
            residual_attn = attn_res_partial
        else:
            h_attn = hidden_states
            residual_attn = hidden_states
        # ── Attention block ───────────────────────────────────────────────
        sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
        h_normed = self.input_layernorm(h_attn, analysis=sn_pre)
        h_lns = self.lns_attn(h_normed)
        if layer_analysis is not None:
            layer_analysis.lns_attn_output = h_lns.detach()
        hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
            hidden_states=h_lns,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            first_layer_fan=first_layer_fan,
            attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
            repo_rope_args=repo_rope_args,
            **kwargs,
        )
        if layer_analysis is not None:
            layer_analysis.attn_contribution = hidden_states.detach()
        gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
        # ── Residual junction (attention sublayer) ────────────────────────
        # LAuReL active: replace fixed accumulation with learned scalar/LR
        # augmentation (Menghani et al., ICML 2025). Standard path is
        # preserved when use_laurel=False; both branches enter GPAS.
        if self.use_laurel:
            attn_aug = self._laurel_residual(
                residual_attn, hidden_states,
                self.laurel_rw_attn, self.laurel_lr_A_attn, self.laurel_lr_B_attn,
                analysis=layer_analysis.laurel if layer_analysis is not None else None,
                slot="attn",
            )
        else:
            attn_aug = residual_attn + hidden_states
        h_tilde = self.gpas_attn(attn_aug, analysis=gpas_attn_a)
        if layer_analysis is not None:
            layer_analysis.h_tilde = h_tilde.detach()
        # ── Attention Residuals: compute pre-MLP input ────────────────────
        # After attention, the partial sum is updated with h_tilde.
        # The pre-MLP AttnRes attends over the same sources but with h_tilde
        # as the current partial — capturing the within-layer attention output.
        if self.use_attn_res and attn_res_sources is not None:
            h_mlp = self._attn_res(
                attn_res_sources, h_tilde, self.attn_res_query_mlp,
                ar_analysis, "mlp",
            )
            residual_mlp = h_tilde
        else:
            h_mlp = h_tilde
            residual_mlp = h_tilde
        # ── MLP block ─────────────────────────────────────────────────────
        sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
        h_normed2 = self.post_attention_layernorm(h_mlp, analysis=sn_post)
        h_lns2 = self.lns_mlp(h_normed2)
        if layer_analysis is not None:
            layer_analysis.lns_mlp_output = h_lns2.detach()
        mlp_a = layer_analysis.mlp if layer_analysis is not None else None
        if self.use_versatile_ffn:
            delta_m, versatile_aux = self.mlp(h_lns2, analysis=mlp_a)
        else:
            delta_m = self.mlp(h_lns2, analysis=mlp_a)
            versatile_aux = None
        if layer_analysis is not None:
            layer_analysis.mlp_contribution = delta_m.detach()
        # ── JTok-M injection (additive alongside MLP residual) ────────────
        aux_stats = None
        # ── Residual junction (MLP sublayer) ─────────────────────────────
        # LAuReL augments the base MLP residual (residual_mlp + delta_m).
        # When JTok-M is active, its additive delta_r is summed on top of
        # the LAuReL output — JTok-M is orthogonal to the residual gate and
        # always contributes as a plain additive correction.
        laurel_la = layer_analysis.laurel if layer_analysis is not None else None
        if self.use_laurel:
            mlp_aug = self._laurel_residual(
                residual_mlp, delta_m,
                self.laurel_rw_mlp, self.laurel_lr_A_mlp, self.laurel_lr_B_mlp,
                analysis=laurel_la,
                slot="mlp",
            )
        else:
            mlp_aug = residual_mlp + delta_m
        if self.use_jtokm and z_tilde is not None and B_vals is not None:
            # JTok-M routes on h_tilde (post-attention state), flattened to
            # [B*S, D] so the surface module sees token rows.
            orig_shape = h_tilde.shape
            h_flat = h_tilde.reshape(-1, self.hidden_size)
            z_flat = z_tilde.reshape(-1, z_tilde.shape[-1])
            B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1])
            jtokm_a = layer_analysis.jtokm if layer_analysis is not None else None
            delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
            delta_r = delta_r.reshape(orig_shape)
            gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
            hidden_states = self.gpas_mlp(mlp_aug + delta_r, analysis=gpas_mlp_a)
        else:
            gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
            hidden_states = self.gpas_mlp(mlp_aug, analysis=gpas_mlp_a)
        if layer_analysis is not None:
            layer_analysis.hidden_states_output = hidden_states.detach()
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if aux_stats is not None:
            outputs += (aux_stats,)
        if versatile_aux is not None:
            outputs += (versatile_aux,)
        return outputs
class SpellingBeeEmbedding(nn.Module):
    """
    Spelling Bee Embeddings (Rabe et al., 2026, arXiv:2601.18030).
    Augments token embeddings with character-level information derived from
    the UTF-8 byte sequence of each token. The spelling bee embedding is the
    mean of the standard token embedding and a character-level summary:
        e_bee(t)   = 0.5 * (e_tok(t) + e_chars(t))
        e_chars(t) = inv_sqrt_len(t) * Σ_{i=0}^{15} RoPE(e_byte[b_i], i)
    where inv_sqrt_len = 1/√|t| is precomputed per token type at setup time.
    Key design decisions vs. a naïve per-occurrence implementation:
    1. **Vocab-level computation** — e_chars is built over the full vocabulary
       once per forward (shape [V, d]), then gathered by token_ids. A naïve
       implementation would compute [B*S, 16, d] per step, repeating identical
       work for every occurrence of a frequent token. This approach reduces
       the dominant intermediate from O(B·S·16·d) to O(V·16·d), where V ≪ B·S
       in practice for most batches.
    2. **Static [256, 16, d] rope_bytes table** — RoPE is applied once over
       all 256 possible byte values at all 16 positions, producing a table
       with fully static shapes. torch.compile / max_autotune can fuse the
       construction of this table (two elementwise ops + concat over fixed
       dims) into a single kernel. Token-level e_chars is then a gather +
       sum over this table, also fully static.
    3. **Precomputed inv_sqrt_lens** — 1/√byte_len is computed once in
       set_byte_table and stored as a persistent buffer. The per-forward
       normalisation becomes a single elementwise multiply, with no sqrt or
       division in the hot path.
    Compatible with both the standard embed_tokens path and the
    LeviathanGenerator path.
    **Inference cost: zero overhead after baking.**
    Call ``bake_inference_table(token_embeds_weight)`` once after training to
    collapse the SBE into a single embedding table indistinguishable from a
    standard nn.Embedding lookup.
    **Setup: call ``set_byte_table(tokenizer)`` once after model init** (and
    before any .to(device) / FP8 conversion) before training. The byte table
    and inv_sqrt_lens are persistent buffers saved in checkpoints.
    References:
        Rabe, Clymo & Dong (2026). "Spelling Bee Embeddings for Language
        Modeling." arXiv:2601.18030.
    """
    MAX_BYTES: int = 16

    def __init__(self, config: "NeoLLMConfig"):
        super().__init__()
        d = config.hidden_size
        base = getattr(config, "rope_theta", 10000.0)
        # 256 × d byte embedding lookup (one per UTF-8 byte value 0..255).
        self.byte_emb = nn.Embedding(256, d)
        # ── Persistent buffers (saved in checkpoints) ─────────────────────
        # token_bytes [vocab_size, MAX_BYTES]: UTF-8 byte values per token,
        # padded with 0x00 up to MAX_BYTES positions.
        self.register_buffer(
            "token_bytes",
            torch.zeros(config.vocab_size, self.MAX_BYTES, dtype=torch.long),
            persistent=True,
        )
        # inv_sqrt_lens [vocab_size]: precomputed 1/sqrt(byte_len) per token.
        # Replaces the runtime sqrt+division of the naïve implementation.
        self.register_buffer(
            "inv_sqrt_lens",
            torch.ones(config.vocab_size, dtype=torch.float),
            persistent=True,
        )
        # ── Non-persistent buffers (recomputed from fixed formula on load) ─
        # RoPE cos/sin for intra-token positions 0..MAX_BYTES-1.
        # Shape [MAX_BYTES, d//2] — applied over the 256-type axis in
        # _build_rope_bytes, not over the batch/sequence axis.
        half = d // 2
        theta = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float) * 2.0 / d))
        pos = torch.arange(self.MAX_BYTES, dtype=torch.float)
        freqs = torch.outer(pos, theta)  # [MAX_BYTES, half]
        self.register_buffer("intra_cos", freqs.cos(), persistent=False)
        self.register_buffer("intra_sin", freqs.sin(), persistent=False)
        # Static position index [MAX_BYTES] used as the column index in the
        # vocab-level gather. Registered as buffer to avoid dynamic tensor
        # creation inside forward (which would trigger torch.compile retracing).
        self.register_buffer(
            "pos_idx",
            torch.arange(self.MAX_BYTES, dtype=torch.long),
            persistent=False,
        )

    # ── Setup ─────────────────────────────────────────────────────────────────
    def set_byte_table(self, tokenizer) -> None:
        """
        Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer.
        Must be called **once** after model instantiation and **before**
        ``.to(device)`` / FP8 conversion so the buffers land on the correct
        device after those transforms. Both buffers are persistent and will
        be saved/restored from checkpoints automatically.
        Args:
            tokenizer: Any HuggingFace tokenizer with
                ``convert_ids_to_tokens(int) -> str | None``.
        """
        vocab_size = self.token_bytes.shape[0]
        # Accumulate in plain Python lists and materialise each buffer with a
        # single torch.tensor call. The previous implementation issued one
        # scalar tensor __setitem__ per byte (up to V * MAX_BYTES writes),
        # which is orders of magnitude slower for large vocabularies.
        byte_rows: list = []
        inv_vals: list = []
        pad = self.MAX_BYTES
        for token_id in range(vocab_size):
            token_str = tokenizer.convert_ids_to_tokens(token_id)
            if token_str is None:
                byte_rows.append([0] * pad)
                inv_vals.append(1.0)  # default 1/√1, matching buffer init
                continue
            # Some tokenizers use a special space character (Ġ / ▁); encode
            # directly to UTF-8 so byte values match raw text bytes.
            try:
                raw = token_str.encode("utf-8")
            except Exception:
                raw = b"\x00"
            vals = list(raw[:pad])  # truncate to MAX_BYTES; bytes → ints
            n = len(vals)
            byte_rows.append(vals + [0] * (pad - n))
            inv_vals.append(1.0 / math.sqrt(max(n, 1)))
        byte_ids = torch.tensor(byte_rows, dtype=torch.long)
        inv_sqrt = torch.tensor(inv_vals, dtype=torch.float)
        self.token_bytes.copy_(byte_ids.to(self.token_bytes.device))
        self.inv_sqrt_lens.copy_(inv_sqrt.to(self.inv_sqrt_lens.device))

    # ── Core helpers ──────────────────────────────────────────────────────────
    def _build_rope_bytes(self) -> torch.Tensor:
        """
        Build the static [256, MAX_BYTES, d] RoPE-encoded byte table.
        For each of the 256 possible byte values and each of the MAX_BYTES
        intra-token positions, applies RoPE rotation using the current
        byte_emb.weight. All shapes are fully static, so torch.compile can
        fuse this into a single kernel.
        Called once per forward pass; the result is discarded afterward.
        The cost is two broadcast elementwise ops + one cat over fixed dims.
        Returns:
            rope_bytes [256, MAX_BYTES, d]
        """
        w = self.byte_emb.weight                  # [256, d]
        half = w.shape[-1] // 2
        w1 = w[:, :half].unsqueeze(1)             # [256, 1, half]
        w2 = w[:, half:].unsqueeze(1)             # [256, 1, half]
        cos = self.intra_cos.unsqueeze(0)         # [1, MAX_BYTES, half]
        sin = self.intra_sin.unsqueeze(0)         # [1, MAX_BYTES, half]
        return torch.cat(
            [w1 * cos - w2 * sin,
             w1 * sin + w2 * cos],
            dim=-1,
        )                                         # [256, MAX_BYTES, d]

    # ── Forward ───────────────────────────────────────────────────────────────
    def forward(
        self,
        token_ids: torch.Tensor,       # [B, S] or [N]
        token_embeds: torch.Tensor,    # [B, S, d] or [N, d]
    ) -> torch.Tensor:
        """
        Args:
            token_ids: integer token indices to look up byte sequences.
            token_embeds: embeddings from embed_tokens or LeviathanGenerator.
        Returns:
            Spelling bee embeddings — same shape as token_embeds.
        """
        # ── Step 1: build rope_bytes over 256 byte types × 16 positions ───
        # Shape [256, MAX_BYTES, d] — fully static, one kernel via compile.
        rope_bytes = self._build_rope_bytes()  # [256, MAX_BYTES, d]
        # ── Step 2: build e_chars over vocab types, not occurrences ────────
        # token_bytes [V, MAX_BYTES]: byte value at each position per token.
        # pos_idx [MAX_BYTES]: column selector 0..MAX_BYTES-1.
        # rope_bytes[token_bytes, pos_idx[None, :], :] selects, for each
        # vocab token and each position, the RoPE-rotated embedding of that
        # byte at that position. Result [V, MAX_BYTES, d], then sum → [V, d].
        e_chars_vocab = rope_bytes[
            self.token_bytes,              # [V, MAX_BYTES] — row index
            self.pos_idx.unsqueeze(0),     # [1, MAX_BYTES] → broadcast [V, MAX_BYTES]
        ].sum(1)                           # [V, d]
        # ── Step 3: apply precomputed 1/√byte_len per vocab type ────────────
        # No sqrt or division in the hot path — pure multiply.
        e_chars_vocab = e_chars_vocab * self.inv_sqrt_lens.unsqueeze(-1)  # [V, d]
        # ── Step 4: gather only the tokens present in this batch ────────────
        # This is the only B×S operation — a single embedding lookup.
        e_chars = e_chars_vocab[token_ids]  # [B, S, d] or [N, d]
        # ── Step 5: mean with token embeddings ──────────────────────────────
        return (token_embeds + e_chars) * 0.5

    # ── Inference utility ─────────────────────────────────────────────────────
    @torch.no_grad()
    def bake_inference_table(
        self,
        token_emb_weight: torch.Tensor,
    ) -> torch.Tensor:
        """
        Collapse SBE into a single [vocab_size, d] embedding table.
        After baking, the SBE computation is indistinguishable from a standard
        nn.Embedding lookup — zero additional overhead at inference time.
        Args:
            token_emb_weight: [vocab_size, d] — weight matrix of embed_tokens
                or the equivalent table (e.g. after Leviathan).
        Returns:
            [vocab_size, d] — baked spelling bee embedding table.
        Usage::
            baked = model.model.spelling_bee.bake_inference_table(
                model.model.embed_tokens.weight
            )
            model.model.embed_tokens.weight.copy_(baked)
            # Optionally free byte_emb parameters:
            # del model.model.spelling_bee
        """
        rope_bytes = self._build_rope_bytes()  # [256, MAX_BYTES, d]
        e_chars_vocab = rope_bytes[
            self.token_bytes,
            self.pos_idx.unsqueeze(0),
        ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1)  # [V, d]
        return (token_emb_weight + e_chars_vocab) * 0.5
class NeoLLMPreTrainedModel(PreTrainedModel):
    """
    Base class providing custom weight initialization for NeoLLM components.

    Per-component schemes (all implemented in ``_init_weights``):
      LeviathanGenerator — codebooks, head_proj and the shared seed path draw
        from N(0, initializer_range); head_scale is filled with
        (num_knots - 1) so the knot grid is stretched uniformly at init
        (matching ckhronos.py); head_spline ~ N(1.0, 0.1) (KHRONOS
        init_weights_prod: the product across d_seed dims starts near 1);
        head_out ~ N(0, initializer_range/√num_modes) so the sum over M head
        outputs matches a single projection's variance; head_norm is the
        default LayerNorm init (weight=1, bias=0). No W_res — confirmed
        absent in the authors' implementation.
      LeviathanJTokM — spline_coeff ~ N(1.0, 0.1); W_out, W_res ~
        N(0, initializer_range); scaler = 1 (identity at init); router is
        handled by the parent's normal init.
      NeoLLMAttention — lambda_1/lambda_2 = 0.5; MEA key/value mix matrices
        start as identity (truncated identity when rectangular) so MEA is a
        no-op at step 0; alpha_proj ~ N(0, 0.02) so linear_clipping(≈0) ≈ 0.5
        gives a mild ~0.5× softmax scaling per head; alpha_ma starts at zero
        (β starts as a small negative offset).
      GPAS — alpha = 0. FANLayer / SeeDNorm — self-initialised in __init__.
      Scalar/VectorMultiplier — multiplier = 1.
      NeoLLMDecoderLayer — AttnRes pseudo-queries MUST be zero: softmax of
        zeros is uniform, so AttnRes reduces to the standard residual mean at
        step 0; non-zero init destabilises training per the paper's ablation.
      SpellingBeeEmbedding — byte_emb ~ N(0, 1/√d), matching token-embedding
        scale so the √byte_len normalisation is calibrated from step 0.
      REPOModule — no special treatment; parent normal init suffices (the
        SwiGLU sub-layer starts near zero, i.e. NoPE-like constant positions).
    """
    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    @staticmethod
    def _identity_fill_(mat: torch.Tensor) -> None:
        # Square (normal training case): exact identity, so every entry
        # receives gradient from the very first step. Rectangular (KV
        # compression, h' < h): truncated identity.
        if mat.shape[0] == mat.shape[1]:
            nn.init.eye_(mat)
        else:
            eye = torch.eye(mat.shape[0], mat.shape[1])
            mat.copy_(eye.to(device=mat.device, dtype=mat.dtype))

    def _init_weights(self, module):
        super()._init_weights(module)
        std = self.config.initializer_range
        if isinstance(module, NeoLLMAttention):
            for lam_name in ("lambda_1", "lambda_2"):
                if hasattr(module, lam_name):
                    getattr(module, lam_name).data.fill_(0.5)
            # MEA mixers start as (truncated) identity → standard attention
            # behaviour at step 0.
            for mix_name in ("mea_key_mix", "mea_value_mix"):
                mix = getattr(module, mix_name, None)
                if mix is not None:
                    self._identity_fill_(mix.data)
            if hasattr(module, "alpha_proj"):
                nn.init.normal_(module.alpha_proj.weight, mean=0.0, std=0.02)
            if hasattr(module, "alpha_ma"):
                module.alpha_ma.zero_()
        elif isinstance(module, GPAS):
            module.alpha.data.fill_(0.0)
        elif isinstance(module, FANLayer):
            pass  # handled in FANLayer.__init__
        elif isinstance(module, SeeDNorm):
            pass  # gamma=1, beta=0, alpha=1 set in __init__
        elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
            if hasattr(module, "multiplier"):
                module.multiplier.data.fill_(1.0)
        elif isinstance(module, LeviathanGenerator):
            nn.init.normal_(module.codebooks, mean=0.0, std=std)
            # Stretch the knot grid uniformly: |x - grid| * (num_knots - 1)
            # maps the [0, 1] input to [0, num_knots - 1] at init.
            module.head_scale.data.fill_(float(module.num_knots - 1))
            nn.init.normal_(module.head_spline, mean=1.0, std=0.1)
            # Rows of head_proj_weight are [M * d_seed]; each d_seed-row block
            # is one head, initialised like a regular linear layer.
            nn.init.normal_(module.head_proj_weight, mean=0.0, std=std)
            module.head_norm_weight.data.fill_(1.0)
            module.head_norm_bias.data.zero_()
            # 1/√num_modes keeps the summed head outputs at single-head variance.
            nn.init.normal_(
                module.head_out_weight,
                mean=0.0,
                std=std / math.sqrt(module.num_modes),
            )
        elif isinstance(module, LeviathanJTokM):
            nn.init.normal_(module.spline_coeff, mean=1.0, std=0.1)
            nn.init.normal_(module.W_out, mean=0.0, std=std)
            nn.init.normal_(module.W_res, mean=0.0, std=std)
            module.scaler.data.fill_(1.0)
            # router: parent handles via normal init
        elif isinstance(module, NeoLLMDecoderLayer):
            # Zero pseudo-queries → uniform depth-wise softmax at step 0
            # (standard residual mean); required for stability per the
            # AttnRes paper's ablation (§5).
            if getattr(module, "attn_res_query_attn", None) is not None:
                module.attn_res_query_attn.data.zero_()
                module.attn_res_query_mlp.data.zero_()
        elif isinstance(module, SpellingBeeEmbedding):
            # Match token-embedding scale: E[‖e_byte‖²] ≈ 1 at init.
            dim = module.byte_emb.embedding_dim
            nn.init.normal_(module.byte_emb.weight, mean=0.0, std=1.0 / math.sqrt(dim))
class NeoLLMModel(NeoLLMPreTrainedModel):
"""
NeoLLM base decoder-only Transformer.
When use_jtokm=True, the generator returns (embeddings, z_tilde, B_vals).
z_tilde and B_vals are passed to every decoder layer so JTok-M surfaces
can reuse the B-spline basis without recomputation.
When use_attn_res=True, the model maintains a list of previous layer
outputs (or block summaries for Block AttnRes) and passes them to each
decoder layer, replacing fixed residual accumulation with learned
depth-wise softmax attention (Kimi Team, 2026, arXiv:2603.15031).
Spelling Bee Embeddings (Rabe et al., 2026, arXiv:2601.18030):
``use_spelling_bee_embeddings`` is independent of the Leviathan flag:
SBE is applied post-embedding regardless of which path produced the
token embeddings (embed_tokens or LeviathanGenerator).
Flag coupling:
- use_token_generator=False, use_spelling_bee_embeddings=False
β†’ standard embed_tokens, no SBE [default]
- use_token_generator=False, use_spelling_bee_embeddings=True
β†’ standard embed_tokens + SBE
- use_token_generator=True, use_spelling_bee_embeddings=False
β†’ LeviathanGenerator only, no SBE
- use_token_generator=True, use_spelling_bee_embeddings=True
β†’ LeviathanGenerator + SBE
    Setup: call ``model.model.spelling_bee.set_byte_table(tokenizer)``
    once after model init and before training, and before any .to(device) /
    FP8 conversion — matching ``SpellingBeeEmbedding.set_byte_table``'s
    documented contract.
"""
def __init__(self, config: NeoLLMConfig):
super().__init__(config)
# ── Embedding path ────────────────────────────────────────────────────
if config.use_token_generator:
self.token_generator = LeviathanGenerator(config)
else:
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, config.pad_token_id
)
# ── Spelling Bee Embeddings (Rabe et al., 2026) ───────────────────────
# Active when use_spelling_bee_embeddings=True, compatible with both
# the embed_tokens and LeviathanGenerator paths.
use_sbe = getattr(config, "use_spelling_bee_embeddings", False)
if use_sbe:
self.spelling_bee = SpellingBeeEmbedding(config)
else:
self.spelling_bee = None
self.layers = nn.ModuleList(
[NeoLLMDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)]
)
self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.first_layer_fan = None
self.post_init()
def get_input_embeddings(self):
if self.config.use_token_generator:
return self.token_generator
return self.embed_tokens
def set_input_embeddings(self, value):
    """Install `value` on whichever embedding attribute is active."""
    attr = "token_generator" if self.config.use_token_generator else "embed_tokens"
    setattr(self, attr, value)
def _build_layer_analysis(self, layer_idx: int = 0) -> LayerAnalysis:
    """
    Allocate a LayerAnalysis whose sub-objects mirror the active config.

    Sub-analyses for disabled features are left as None, so the analysis
    consumer can test for None instead of re-reading config flags.
    Invoked once per layer, per forward pass, while analysis is armed.
    """
    cfg = self.config
    repo_start = getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
    repo_on = getattr(cfg, "use_repo", False) and layer_idx >= repo_start
    versatile_on = getattr(cfg, "use_versatile_ffn", False)
    hadamard_on = getattr(cfg, "use_hadamard_o_proj", False)
    return LayerAnalysis(
        seednorm_pre_attn=SeeDNormAnalysis(),
        seednorm_post_attn=SeeDNormAnalysis(),
        attention=AttentionAnalysis(
            fan=FANAnalysis(),
            hadamard=HadamardAnalysis() if hadamard_on else None,
            repo=REPOAnalysis() if repo_on else None,
        ),
        mlp=MLPAnalysis(
            # FAN/PolyNorm analyses apply to the standard MLP only; the
            # VersatileFFN path gets its own dedicated sub-analysis.
            fan=None if versatile_on else FANAnalysis(),
            polynorm=None if versatile_on else PolyNormAnalysis(),
            versatile=VersatileFFNAnalysis() if versatile_on else None,
        ),
        gpas_attn=GPASAnalysis(),
        gpas_mlp=GPASAnalysis(),
        jtokm=JTokMAnalysis() if cfg.use_jtokm else None,
        attn_res=AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
        laurel=LAuReLAnalysis() if getattr(cfg, "use_laurel", False) else None,
    )
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    analysis_state: Optional[AnalysisState] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Tuple:
    """
    Run the full NeoLLM backbone.

    Stages: embedding (LeviathanGenerator or embed_tokens), optional
    Spelling Bee augmentation, the decoder stack (with optional
    Attention-Residuals bookkeeping), and the final SeeDNorm.

    Returns ``(BaseModelOutputWithPast, all_aux_stats)`` when
    ``return_dict`` is truthy; otherwise a flat tuple of the non-None
    values among (hidden_states, all_hidden_states, all_attentions)
    with ``all_aux_stats`` appended as the last element.

    When ``analysis_state`` is provided, per-layer LayerAnalysis objects
    and embedding/final snapshots are written into it in place.
    """
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    if return_dict is None:
        # NOTE(review): vars() reads only the instance __dict__, so a
        # property-backed `use_return_dict` on the config class would be
        # missed here and the default True used instead — confirm intended.
        cfg_dict = vars(self.config)
        return_dict = cfg_dict.get(
            "return_dict",
            cfg_dict.get("use_return_dict", True),
        )
    # XOR guard: raises when both or neither of input_ids / inputs_embeds
    # are supplied; passes only when exactly one is given.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("Specify exactly one of input_ids or inputs_embeds")
    # ── Embedding stage ────────────────────────────────────────────────
    z_tilde = None
    B_vals = None
    gen_a = (
        analysis_state.generator
        if analysis_state is not None and self.config.use_token_generator
        else None
    )
    if inputs_embeds is None:
        if self.config.use_token_generator:
            if self.config.use_jtokm:
                # Return internals for reuse by JTok-M surfaces
                inputs_embeds, z_tilde, B_vals = self.token_generator(
                    input_ids, return_internals=True, analysis=gen_a
                )
                # Reshape to [batch, seq, d_seed] and [batch, seq, d_seed, n_knots]
                z_tilde = z_tilde.reshape(*input_ids.shape, self.config.generator_d_seed)
                B_vals = B_vals.reshape(
                    *input_ids.shape,
                    self.config.generator_d_seed,
                    self.config.generator_num_knots,
                )
            else:
                inputs_embeds = self.token_generator(input_ids, analysis=gen_a)
        else:
            inputs_embeds = self.embed_tokens(input_ids)
    # ── Spelling Bee Embeddings (applied post-embedding, pre-decoder) ──────
    # input_ids may be None when inputs_embeds was passed directly by the
    # caller; in that case SBE cannot run (no token_ids available) and is
    # silently skipped — consistent with the standard embedding bypass path.
    if self.spelling_bee is not None and input_ids is not None:
        inputs_embeds = self.spelling_bee(input_ids, inputs_embeds)
    if analysis_state is not None:
        analysis_state.embeddings = inputs_embeds.detach()
    if position_ids is None:
        position_ids = torch.arange(
            0, inputs_embeds.shape[1], device=inputs_embeds.device
        ).unsqueeze(0)
    # cache_position spans the whole sequence: the model has no KV cache,
    # so every forward recomputes attention from position 0.
    causal_mask = create_causal_mask(
        config=self.config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device
        ),
        past_key_values=None,
        position_ids=position_ids,
    )
    hidden_states = inputs_embeds
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_aux_stats = []
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
    # Reset the ResFormer cross-layer FAN cache at the start of each pass.
    self.first_layer_fan = None
    # ── REPO: pass inv_freq by reference at forward time ──────────────────
    # rotary_emb.inv_freq is already on the correct device (managed by
    # NeoLLMRotaryEmbedding as a buffer) — no .to(), no DeviceCopy op.
    # Computed once here and passed through the decoder layer chain so
    # NeoLLMAttention never needs to store it as a buffer itself, avoiding
    # the meta-tensor issue that occurs when lm_eval calls .to(device).
    repo_rope_args = (
        (self.rotary_emb.inv_freq, self.rotary_emb.attention_scaling)
        if getattr(self.config, "use_repo", False) else None
    )
    # ── Attention Residuals state ──────────────────────────────────────
    # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
    # decoder layer — all previous outputs are kept, max N=num_layers+1.
    # Block AttnRes (attn_res_num_blocks>0): sources grows by one entry per
    # block boundary — at most num_blocks+1 entries, far less memory.
    # In both modes, attn_res_partial is the current intra-block accumulated
    # hidden state that connects the attn and MLP sublayers and flows between
    # decoder layers within a block.
    use_attn_res = getattr(self.config, 'use_attn_res', False)
    attn_res_sources = None
    attn_res_partial = None
    if use_attn_res:
        attn_res_sources = [hidden_states]  # b_0 = token embedding
        attn_res_partial = hidden_states  # initial partial sum
        num_blocks = getattr(self.config, 'attn_res_num_blocks', 0)
        block_size = (
            max(self.config.num_hidden_layers // num_blocks, 1)
            if num_blocks > 0
            else 1  # Full AttnRes: every layer is its own "block"
        )
    # Pre-allocate per-layer analysis list when analysis is active
    if analysis_state is not None:
        analysis_state.layers = []
    for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        # ── Block AttnRes: boundary handling ──────────────────────────
        # At each block boundary (excluding layer 0): append the current
        # partial sum to sources as a completed block summary, then reset
        # partial to None so the new block builds from scratch — matching
        # the paper's pseudocode exactly.
        # For Full AttnRes (block_size=1): every layer is a boundary, so
        # partial is appended and reset after every layer. The partial is
        # re-seeded from the previous hidden_states below.
        if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
            attn_res_sources = attn_res_sources + [attn_res_partial]
            attn_res_partial = hidden_states  # start new block from current output
        # Build per-layer analysis container (only in eval + analysis mode)
        layer_analysis = None
        if analysis_state is not None:
            layer_analysis = self._build_layer_analysis(layer_idx)
            layer_analysis.layer_idx = layer_idx
            analysis_state.layers.append(layer_analysis)
        layer_outputs = decoder_layer(
            hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=causal_mask,
            first_layer_fan=self.first_layer_fan,
            z_tilde=z_tilde,
            B_vals=B_vals,
            attn_res_sources=attn_res_sources,
            attn_res_partial=attn_res_partial if use_attn_res else None,
            layer_analysis=layer_analysis,
            output_attentions=output_attentions,
            repo_rope_args=repo_rope_args,
            **kwargs,
        )
        hidden_states = layer_outputs[0]
        # Update AttnRes partial sum — the new partial is the layer output
        if use_attn_res:
            attn_res_partial = hidden_states
        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)
        # Collect JTok-M aux stats (last element if present)
        if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
            all_aux_stats.append(layer_outputs[-1])
        # Collect VersatileFFN aux stats (second-to-last if jtokm also present,
        # or last if jtokm is absent). Only non-None during training.
        if getattr(self.config, "use_versatile_ffn", False):
            for item in layer_outputs[1:]:
                if isinstance(item, tuple) and len(item) == 3:
                    # (p_sum, f_sum, N_tokens) signature
                    all_aux_stats.append(("versatile", item))
                    break
        # Capture layer 0's FAN output once; later layers receive it via
        # the first_layer_fan kwarg above (ResFormer-style value residual).
        if (self.first_layer_fan is None
                and hasattr(decoder_layer, "current_layer_fan")):
            self.first_layer_fan = decoder_layer.current_layer_fan
    hidden_states = self.norm(hidden_states)
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
    # ── Finalise analysis snapshot ─────────────────────────────────────
    if analysis_state is not None:
        analysis_state.final_hidden_states = hidden_states.detach()
        analysis_state.jtokm_aux_stats = all_aux_stats if self.config.use_jtokm else None
        analysis_state.attn_res_sources_final = (
            attn_res_sources if use_attn_res else None
        )
    if not return_dict:
        return tuple(
            v for v in [hidden_states, None, all_hidden_states, all_attentions]
            if v is not None
        ) + (all_aux_stats,)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=None,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
    ), all_aux_stats
@torch.compiler.disable
def compute_cce_loss(
    hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None
):
    """Cut-cross-entropy loss, deliberately excluded from torch.compile."""
    target = labels.to(hidden_states.device)
    if pad_token_id is not None:
        # Remap padding positions to the ignore index (-100).
        ignore = torch.full_like(target, -100)
        target = torch.where(target == pad_token_id, ignore, target)
    return linear_cross_entropy(
        hidden_states, lm_head_weight, target,
        bias=lm_head_bias, shift=1, impl="cce_kahan_full_c", reduction="mean",
    )
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """
    Causal LM head on top of the NeoLLM backbone.

    When use_jtokm=True, the load-balancing auxiliary loss is computed from
    the per-layer JTok-M routing statistics and added to the cross-entropy loss:
        total_loss = CE_loss + L_aux
        where L_aux = Ξ» Β· n_e Β· (1/L) Β· Ξ£_β„“ Ξ£_i p_i^β„“ Β· f_i^β„“

    ── Analysis mode ──────────────────────────────────────────────────────────
    Analysis is NEVER active during training (model.training=True).
    It is opt-in at inference time::

        model.eval()
        model.enable_analysis()          # arm the system
        with torch.no_grad():
            _ = model(input_ids)
        state = model.last_analysis      # AnalysisState with all internals
        model.disable_analysis()         # disarm β€” zero overhead again
        # or: model.train()              # training always disarms automatically

    last_analysis is replaced at each forward call when analysis is armed,
    is None between enable_analysis() and the first forward, and is cleared
    by disable_analysis().

    Example field access:
        state.layers[3].attention.alpha_per_head   # Affine-Scaled Ξ±
        state.layers[0].mlp.polynorm.x2_post_exclusive
        state.generator.z_tilde                    # Leviathan latent
        state.layers[2].jtokm.routing_weights      # JTok-M router
        state.layers[5].attn_res.weights_pre_attn  # AttnRes softmax
    """
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # The Leviathan generator path has no embed_tokens to tie lm_head to.
        if config.use_token_generator:
            self._tied_weights_keys = {}
        # ── Analysis infrastructure ───────────────────────────────────────
        # _analysis_armed: set by enable_analysis() / disable_analysis().
        # last_analysis: populated after each forward when armed + eval.
        # Neither is an nn.Parameter or buffer β€” zero effect on training.
        self._analysis_armed: bool = False
        self.last_analysis: Optional[AnalysisState] = None
        self.post_init()

    # ── Public analysis API ───────────────────────────────────────────────
    def enable_analysis(self) -> None:
        """
        Arm the analysis system.

        Has no effect during training β€” analysis is always off when
        model.training=True, regardless of this flag. Safe to call at any
        point without affecting the training loop.
        """
        self._analysis_armed = True

    def disable_analysis(self) -> None:
        """
        Disarm the analysis system and clear last_analysis.
        After this call, forward passes produce zero analysis overhead.
        """
        self._analysis_armed = False
        self.last_analysis = None

    def _make_analysis_state(
        self,
        input_ids: Optional[torch.Tensor],
    ) -> Optional[AnalysisState]:
        """
        Decide whether to build an AnalysisState for this forward pass.

        Returns None (zero cost) when analysis is not armed OR the model is
        in training mode. Otherwise returns a fresh AnalysisState with
        top-level optional fields pre-allocated per the active config flags.
        """
        if not self._analysis_armed or self.training:
            return None
        cfg = self.config
        return AnalysisState(
            input_ids=input_ids.detach() if input_ids is not None else None,
            generator=GeneratorAnalysis() if cfg.use_token_generator else None,
            layers=None,  # filled by NeoLLMModel.forward
            jtokm_aux_stats=[] if cfg.use_jtokm else None,
            attn_res_sources_final=[] if getattr(cfg, "use_attn_res", False) else None,
        )

    # ── Standard model API ────────────────────────────────────────────────
    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> dict:
        """
        NeoLLM does not implement KV caching (always returns past_key_values=None).

        Transformers' default GenerationMixin.prepare_inputs_for_generation
        assumes a KV cache and slices input_ids to only the newest token past
        the prefill. Without a real cache that retains previous K/V states,
        attention would see 1 key/value pair against a full-length causal
        mask β€” a shape mismatch in SDPA. This override therefore always
        forwards the COMPLETE sequence so attention is recomputed over the
        full context at every step (slower, but numerically correct).

        Bug fix: the backbone raises ValueError when BOTH input_ids and
        inputs_embeds are supplied, so when caller-provided embeddings are
        used for the prefill step we forward ONLY inputs_embeds.
        """
        if inputs_embeds is not None and past_key_values is None:
            return {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """
        Run the backbone and the LM head.

        With labels: the fused CCE loss is computed (plus JTok-M and
        VersatileFFN auxiliary losses when enabled) and logits stay None to
        save memory. Without labels: logits are materialised for the last
        ``logits_to_keep`` positions (0 β†’ all positions).
        """
        # ── Build analysis container (None during training or when disarmed) ──
        analysis_state = self._make_analysis_state(input_ids)
        model_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            analysis_state=analysis_state,
            **kwargs,
        )
        # Unpack. return_dict=True: (BaseModelOutputWithPast, aux_stats).
        # return_dict=False: a FLAT tuple whose first element is the
        # hidden-state tensor itself and whose last element is aux_stats.
        if isinstance(model_out, tuple):
            outputs, all_aux_stats = model_out[0], model_out[-1]
            if isinstance(outputs, tuple):
                hidden_states = outputs[0]
            elif torch.is_tensor(outputs):
                # Bug fix: on the flat (return_dict=False) path, outputs is a
                # bare tensor; the previous code called .last_hidden_state on
                # it and raised AttributeError.
                hidden_states = outputs
            else:
                hidden_states = outputs.last_hidden_state
        else:
            outputs = model_out
            all_aux_stats = []
            hidden_states = outputs.last_hidden_state
        loss = None
        if labels is not None:
            loss = compute_cce_loss(
                hidden_states, labels, self.lm_head.weight,
                getattr(self.lm_head, "bias", None), self.config.pad_token_id,
            )
            # Add JTok-M load-balancing auxiliary loss (tagged "versatile"
            # tuples are excluded β€” those belong to VersatileFFN below).
            if self.config.use_jtokm and all_aux_stats:
                jtokm_stats = [
                    s for s in all_aux_stats
                    if not (isinstance(s, tuple) and len(s) == 2 and s[0] == "versatile")
                ]
                if jtokm_stats:
                    loss = loss + compute_jtokm_aux_loss(
                        jtokm_stats,
                        n_e=self.config.jtokm_num_experts,
                        weight=self.config.jtokm_aux_loss_weight,
                    )
            # Add VersatileFFN load-balancing auxiliary loss
            if getattr(self.config, "use_versatile_ffn", False) and all_aux_stats:
                versatile_stats = [
                    s[1] for s in all_aux_stats
                    if isinstance(s, tuple) and len(s) == 2 and s[0] == "versatile"
                ]
                if versatile_stats:
                    loss = loss + compute_versatile_aux_loss(
                        versatile_stats,
                        n_experts=self.config.versatile_total_experts,
                        weight=self.config.versatile_aux_loss_weight,
                    )
            # CCE fuses the LM head into the loss β€” no logits to return.
            logits = None
        else:
            slice_indices = (
                slice(-logits_to_keep, None)
                if isinstance(logits_to_keep, int) else logits_to_keep
            )
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        # ── Finalise and store analysis state ─────────────────────────────
        if analysis_state is not None:
            analysis_state.logits = logits.detach() if logits is not None else None
            self.last_analysis = analysis_state
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
# ==================== AUTOMODEL REGISTRATION ====================
# Public API of this module. Analysis dataclasses are exported so external
# tooling can type-hint against them without importing private names.
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "LeviathanGenerator",
    "LeviathanJTokM",
    "SpellingBeeEmbedding",
    "FANLayer",
    "SeeDNorm",
    "ScalarMultiplier",
    "VectorMultiplier",
    "LinearWithMultipliers",
    "MEAHeadSeeDNorm",
    "HadamardOProj",
    "REPOModule",
    "VersatileFFN",
    "compute_versatile_aux_loss",
    # Analysis dataclasses β€” exported so external tools can type-hint against them
    "AnalysisState",
    "LayerAnalysis",
    "AttentionAnalysis",
    "MLPAnalysis",
    "FANAnalysis",
    "SeeDNormAnalysis",
    "GPASAnalysis",
    "PolyNormAnalysis",
    "HadamardAnalysis",
    "REPOAnalysis",
    "VersatileFFNAnalysis",
    "JTokMAnalysis",
    "AttnResAnalysis",
    "GeneratorAnalysis",
]
# Register under model_type "neollm" so AutoConfig / AutoModel /
# AutoModelForCausalLM.from_pretrained can resolve this architecture.
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)