| |
| """ |
| NeoLLM model with FANformer, SeeDNorm, ResFormer, Learnable Multipliers, |
| full attention augmented with optional Momentum, MEA, and LUCID operators, |
| Gated Attention (Qiu et al., 2025) combined with Affine-Scaled Attention |
| (Bae et al., 2026), an optional Leviathan continuous token embedding |
| generator, an optional Leviathan-JTok-M token-indexed modulation module, |
| optional Spelling Bee Embeddings (Rabe et al., 2026), and optional Context |
| Re-Positioning (Li et al., 2026). |
| |
| Attention stack (orthogonal, all active simultaneously when enabled): |
| 1. Gated Attention (use_gated_attention implicit via q_proj gate chunk): |
| applies a head-specific elementwise sigmoid gate to the concatenated |
    SDPA output before o_proj (G1 position, Qiu et al. 2025 §2.2).
| Introduces non-linearity between W_V and W_O, sparse input-dependent |
| gating, and eliminates attention sink. |
| 2. Affine-Scaled Attention (use_affine_scaled_attention): |
| modulates softmax attention weights directly as |
        [α(X)·softmax(QK^T/√dk) + β(X)] V
    relaxing the unit-sum constraint of softmax. α is per-head, per-query,
    input-dependent and bounded in [0,1] via linear_clipping. β is a
    moving-average bias that prevents collapse. Reduces first-token bias,
    increases attention entropy, and is complementary to Gated Attention
    (Bae et al. 2026, Table 2: Affine-Scaled + Gated > either alone).
| |
| Flash/SDPA path (exact, no approximation): |
    Expanding [α·softmax(QKᵀ)+β]V distributively yields two terms:
        α · flash_attn(Q,K,V)      — Flash computes this directly
      + β · Σ_{j≤i} V_j            — causal prefix-sum of V (V.cumsum)
| The V_cumsum tensor [B,H,S,d_head] is the only extra memory vs a |
| standard flash call. Semantically identical to the eager path; |
| per-weight softmax tensors (attn_weights_pre/post_affine) are |
| unavailable in flash mode and remain None in AnalysisState. |
| |
| Eager path: exact with full weight access, used for interpretability. |
| |
| References: |
| FANformer: "FANformer: Improving Large Language Models Through Effective |
| Periodicity Modeling" |
| SeeDNorm: "Self-Rescaled Dynamic Normalization" |
| Learnable Multipliers: "Learnable Multipliers: Freeing the Scale of Language |
| Model Matrix Layers" |
| Gated Attention: Qiu et al. (2025). "Gated Attention for Large Language |
| Models: Non-linearity, Sparsity, and Attention-Sink-Free." |
| arXiv:2505.06708. |
| Affine-Scaled Attention: Bae et al. (2026). "Affine-Scaled Attention: |
| Towards Flexible and Stable Transformer Attention." arXiv:2602.23057. |
| Leviathan Generator: Batley & Saha (2026). "A Separable Architecture for |
| Continuous Token Representation in Language Models." arXiv:2601.22040. |
| KHRONOS: Batley & Saha (2025). "KHRONOS: a Kernel-Based Neural Architecture |
| for Rapid, Resource-Efficient Scientific Computation." arXiv:2505.13315. |
| JTok / JTok-M: Yang et al. (2026). "JTok: On Token Embedding as Another |
| Axis of Scaling Law via Joint Token Self-Modulation." arXiv:2602.00800. |
| Spelling Bee Embeddings: Rabe, Clymo & Dong (2026). "Spelling Bee |
| Embeddings for Language Modeling." arXiv:2601.18030. |
| Context Re-Positioning: Li, Zhao, Cai & Sproat (2026). "REPO: Language |
| Models with Context Re-Positioning." arXiv:2512.14391. |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Callable, List, Optional, Union, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from cut_cross_entropy import linear_cross_entropy |
| from torch.utils.checkpoint import checkpoint |
|
|
| from transformers.activations import ACT2FN |
| from transformers.generation import GenerationMixin |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, logging |
| from .configuration_neollm import NeoLLMConfig |
|
|
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
| torch._dynamo.config.capture_scalar_outputs = True |
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| @dataclass |
| class FANAnalysis: |
| """ |
| Decomposed output of a FANLayer call. |
| FANLayer'(X) = [cos(WpΒ·X) β sin(WpΒ·X) β (WpΜΒ·X + BpΜ)] |
| """ |
| cosine_component: Optional[torch.Tensor] = None |
| sine_component: Optional[torch.Tensor] = None |
| linear_component: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class SeeDNormAnalysis: |
| """ |
| Internals of a SeeDNorm forward pass. |
| SeeDNorm(x) = [Ο(xΒ·Ξ²^T)Β·Ξ± + Ξ³] β x/RMS(x) |
| """ |
| rescale_factor: Optional[torch.Tensor] = None |
| dynamic_scale: Optional[torch.Tensor] = None |
| x_normalized: Optional[torch.Tensor] = None |
| output: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class GPASAnalysis: |
| """ |
| Internals of a GPAS forward pass. |
| GPAS(x) = x - silu(Ξ±)Β·x_detached |
| """ |
| silu_alpha: Optional[torch.Tensor] = None |
| subtracted_component: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class PolyNormAnalysis: |
| """ |
| Internals of a PolyNorm forward pass (used as MLP act_fn). |
| Branches: x1 (linear), x2 (quadratic), x3 (cubic), each normalized. |
| x2 and x3 are partially orthogonalized against x1 via learned Ξ±. |
| """ |
| x1: Optional[torch.Tensor] = None |
| x2_pre_exclusive: Optional[torch.Tensor] = None |
| x3_pre_exclusive: Optional[torch.Tensor] = None |
| x2_post_exclusive: Optional[torch.Tensor] = None |
| x3_post_exclusive: Optional[torch.Tensor] = None |
| alpha2: Optional[torch.Tensor] = None |
| alpha3: Optional[torch.Tensor] = None |
| weights: Optional[torch.Tensor] = None |
| bias: Optional[torch.Tensor] = None |
| output: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class HadamardAnalysis: |
| """ |
| Internals of a HadamardOProj forward pass. |
| Only populated when use_hadamard_o_proj=True. |
| |
| Reference: Aggarwal & Kumar (2026). arXiv:2603.08343. |
| |
| post_fwht: WHT output before Ξ± scaling [..., D] β useful to verify |
| that the transform is truly norm-preserving (ΞΊ=1 sanity check). |
| alpha_snapshot: detached copy of the learnable Ξ± vector [D] β tracks how |
| per-channel scaling evolves during training analysis. |
| """ |
| post_fwht: Optional[torch.Tensor] = None |
| alpha_snapshot: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class REPOAnalysis: |
| """ |
| Internals of a REPOModule forward pass. |
| Only populated when use_repo=True and layer_idx >= repo_start_layer. |
| |
| Reference: Li, Zhao, Cai & Sproat (2026). arXiv:2512.14391. |
| |
| positions: continuous per-head positions z [B, H, S] produced by f_Ο. |
| Captures what position pattern the model learned for each head: |
| constant (NoPE-like), monotonic (RoPE-like), or hybrid. |
| Use this field with the attention interpretability toolkit to |
| reproduce the position-pattern analysis of Li et al. (2026) Β§5.2. |
| r_repr: intermediate position representation r [B, S, d_p] β output of |
| the SwiGLU sub-layer before the per-head linear W_z. |
| Shared across heads within the layer. |
| """ |
| positions: Optional[torch.Tensor] = None |
| r_repr: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class AttentionAnalysis: |
| """ |
| Full attention internals for one NeoLLMAttention forward pass. |
| Fields present unconditionally capture the always-active path. |
| Fields guarded by config flags are None when that flag is off. |
| """ |
| |
| fan: Optional[FANAnalysis] = None |
|
|
| |
| q_raw: Optional[torch.Tensor] = None |
| gate_raw: Optional[torch.Tensor] = None |
| gate_sigmoid: Optional[torch.Tensor] = None |
| q_post_norm: Optional[torch.Tensor] = None |
| k_post_norm: Optional[torch.Tensor] = None |
| v_raw: Optional[torch.Tensor] = None |
|
|
| |
| |
| |
| |
| |
| |
| |
| q_post_rope: Optional[torch.Tensor] = None |
| k_post_rope: Optional[torch.Tensor] = None |
|
|
| |
| q_momentum_delta: Optional[torch.Tensor] = None |
| k_momentum_delta: Optional[torch.Tensor] = None |
| q_post_momentum: Optional[torch.Tensor] = None |
| k_post_momentum: Optional[torch.Tensor] = None |
|
|
| |
| mea_key_mix_matrix: Optional[torch.Tensor] = None |
| mea_value_mix_matrix: Optional[torch.Tensor] = None |
| k_post_mea: Optional[torch.Tensor] = None |
| v_post_mea: Optional[torch.Tensor] = None |
|
|
| |
| lucid_preconditioner: Optional[torch.Tensor] = None |
| v_post_lucid: Optional[torch.Tensor] = None |
|
|
| |
| alpha_per_head: Optional[torch.Tensor] = None |
| beta_per_head: Optional[torch.Tensor] = None |
| alpha_moving_avg: Optional[torch.Tensor] = None |
| attn_weights_pre_affine: Optional[torch.Tensor] = None |
| attn_weights_post_affine: Optional[torch.Tensor] = None |
|
|
| |
| attn_weights: Optional[torch.Tensor] = None |
|
|
| |
| attn_output_raw: Optional[torch.Tensor] = None |
|
|
| |
| attn_output_post_mea_norm: Optional[torch.Tensor] = None |
|
|
| |
| xsa_self_position_component: Optional[torch.Tensor] = None |
| attn_output_post_xsa: Optional[torch.Tensor] = None |
|
|
| |
| direction_vecs_normalized: Optional[torch.Tensor] = None |
| dr_router_logits: Optional[torch.Tensor] = None |
| dr_routing_weights: Optional[torch.Tensor] = None |
| dr_projection: Optional[torch.Tensor] = None |
| dr_suppression: Optional[torch.Tensor] = None |
| attn_output_post_routing: Optional[torch.Tensor] = None |
|
|
| |
| attn_output_pre_gate: Optional[torch.Tensor] = None |
| attn_output_final: Optional[torch.Tensor] = None |
|
|
| |
| hadamard: Optional["HadamardAnalysis"] = None |
|
|
| |
| repo: Optional["REPOAnalysis"] = None |
|
|
|
|
| @dataclass |
| class MLPAnalysis: |
| """ |
| Internals of a NeoLLMMLP forward pass. |
| SwiGLU-like: down_proj(dropout(PolyNorm(gate_proj(fan)) Β· up_proj(fan))) |
| When use_versatile_ffn=True, the versatile sub-object carries all |
| VersatileFFN internals and fan/gate/up/polynorm fields remain None. |
| """ |
| fan: Optional[FANAnalysis] = None |
| gate_proj_output: Optional[torch.Tensor] = None |
| up_proj_output: Optional[torch.Tensor] = None |
| polynorm: Optional[PolyNormAnalysis] = None |
| act_times_up: Optional[torch.Tensor] = None |
| output: Optional[torch.Tensor] = None |
|
|
| |
| versatile: Optional["VersatileFFNAnalysis"] = None |
|
|
|
|
| @dataclass |
| class VersatileFFNAnalysis: |
| """ |
| Internals of a VersatileFFN forward pass. |
| Only populated when use_versatile_ffn=True. |
| |
| Reference: Nie et al. (2026). arXiv:2512.14531. |
| |
| depth_probs: Gumbel-Softmax distribution p [B, S, max_depth] β only in |
| training (hard=True STE). None during inference. |
| expected_loops: E[L] per token [B, S] β differentiable proxy for difficulty. |
| During inference computed from discrete argmax+1. |
| moe_weight: Ξ» = (L_max - E[L]) / L_max [B, S] β fusion gate scalar. |
| Near 1 β width dominates (easy token). |
| Near 0 β depth dominates (hard token). |
| loop_choice: argmax(depth_logits) [B, S] β discrete depth selected at |
| inference. None during training. |
| x_depth: Output of the depth-versatile path [B, S, D]. |
| x_width: Output of the width-versatile MoE path [B, S, D]. |
| """ |
| depth_probs: Optional[torch.Tensor] = None |
| expected_loops: Optional[torch.Tensor] = None |
| moe_weight: Optional[torch.Tensor] = None |
| loop_choice: Optional[torch.Tensor] = None |
| x_depth: Optional[torch.Tensor] = None |
| x_width: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class JTokMAnalysis: |
| """ |
| Internals of a LeviathanJTokM forward pass for one decoder layer. |
| """ |
| surfaces: Optional[torch.Tensor] = None |
| router_logits: Optional[torch.Tensor] = None |
| topk_indices: Optional[torch.Tensor] = None |
| routing_weights: Optional[torch.Tensor] = None |
| mixed_pre_norm: Optional[torch.Tensor] = None |
| mixed_normalized: Optional[torch.Tensor] = None |
| delta_r: Optional[torch.Tensor] = None |
| p_sum: Optional[torch.Tensor] = None |
| f_sum: Optional[torch.Tensor] = None |
| lns_scale: Optional[float] = None |
|
|
|
|
| @dataclass |
| class AttnResAnalysis: |
| """ |
| Depth-wise softmax attention weights from AttnRes sublayer calls. |
| Only populated when use_attn_res=True. |
| """ |
| weights_pre_attn: Optional[torch.Tensor] = None |
| weights_pre_mlp: Optional[torch.Tensor] = None |
| sources_count: Optional[int] = None |
|
|
|
|
| @dataclass |
| class LAuReLAnalysis: |
| """ |
| Internals of the LAuReL residual augmentation for one decoder layer. |
| Only populated when use_laurel=True. |
| |
| Two slots are provided β one per sublayer (attention, MLP) β reflecting |
| that LAuReL is applied independently at each residual junction. |
| |
| LAuReL-RW (use_laurel_rw=True): |
| alpha_attn / alpha_mlp β effective Ξ± after softmax [scalar] |
| beta_attn / beta_mlp β effective Ξ² after softmax [scalar] |
| The applied residual is Ξ±Β·f(x) + Ξ²Β·x (or Ξ²Β·(BAx+x) with LR). |
| |
| LAuReL-LR (use_laurel_lr=True): |
| lr_input_norm_attn / lr_input_norm_mlp β βxββ entering the low-rank |
| path, useful for monitoring magnitude growth or collapse [scalar]. |
| lr_delta_norm_attn / lr_delta_norm_mlp β βBAxββ after the low-rank |
| correction; relative to lr_input_norm reveals how much the LR |
| branch contributes [scalar]. |
| |
| When both variants are active (LAUREL-RW+LR, paper eq. 5): |
| x_{i+1} = Ξ±Β·f(x) + Ξ²Β·(BAx + x) |
| All six fields are populated simultaneously. |
| """ |
| |
| alpha_attn: Optional[float] = None |
| beta_attn: Optional[float] = None |
| alpha_mlp: Optional[float] = None |
| beta_mlp: Optional[float] = None |
|
|
| |
| lr_input_norm_attn: Optional[float] = None |
| lr_delta_norm_attn: Optional[float] = None |
| lr_input_norm_mlp: Optional[float] = None |
| lr_delta_norm_mlp: Optional[float] = None |
|
|
|
|
| @dataclass |
| class LayerAnalysis: |
| """ |
| Complete analysis snapshot for one NeoLLMDecoderLayer forward pass. |
| Sub-objects are None when the corresponding config flag is inactive. |
| """ |
| layer_idx: int = 0 |
|
|
| |
| hidden_states_input: Optional[torch.Tensor] = None |
| hidden_states_output: Optional[torch.Tensor] = None |
| h_tilde: Optional[torch.Tensor] = None |
|
|
| |
| seednorm_pre_attn: Optional[SeeDNormAnalysis] = None |
| lns_attn_output: Optional[torch.Tensor] = None |
| attention: Optional[AttentionAnalysis] = None |
| attn_contribution: Optional[torch.Tensor] = None |
| gpas_attn: Optional[GPASAnalysis] = None |
|
|
| |
| seednorm_post_attn: Optional[SeeDNormAnalysis] = None |
| lns_mlp_output: Optional[torch.Tensor] = None |
| mlp: Optional[MLPAnalysis] = None |
| mlp_contribution: Optional[torch.Tensor] = None |
| gpas_mlp: Optional[GPASAnalysis] = None |
|
|
| |
| jtokm: Optional[JTokMAnalysis] = None |
| attn_res: Optional[AttnResAnalysis] = None |
| laurel: Optional[LAuReLAnalysis] = None |
|
|
|
|
| @dataclass |
| class GeneratorAnalysis: |
| """ |
| Internals of a LeviathanGenerator forward pass. |
| Only populated when use_token_generator=True. |
| """ |
| z_raw: Optional[torch.Tensor] = None |
| z_tilde: Optional[torch.Tensor] = None |
| B_vals: Optional[torch.Tensor] = None |
| z_all_pre_norm: Optional[torch.Tensor] = None |
| z_all_post_sigmoid: Optional[torch.Tensor] = None |
| modes_all: Optional[torch.Tensor] = None |
| embeddings: Optional[torch.Tensor] = None |
|
|
|
|
| @dataclass |
| class AnalysisState: |
| """ |
| Root analysis container for one NeoLLMForCausalLM forward pass. |
| |
| Populated when model.enable_analysis() is active AND model.training=False. |
| All tensors are detached from the computation graph. |
| |
| Structure: |
| .input_ids β original input token ids |
| .embeddings β token embeddings entering the decoder stack |
| .final_hidden_states β post-norm output entering lm_head |
| .generator β LeviathanGenerator internals (use_token_generator only) |
| .layers[i] β per-layer LayerAnalysis, indexed by layer_idx |
| .jtokm_aux_stats β raw per-layer load-balancing tuples (use_jtokm only) |
| .attn_res_sources_final β AttnRes source list at end of forward (use_attn_res only) |
| .logits β lm_head output (only when labels is None) |
| |
| Access pattern: |
| model.eval() |
| model.enable_analysis() |
| _ = model(input_ids) |
| state = model.last_analysis |
| alpha = state.layers[3].attention.alpha_per_head |
| |
| REPO access (use_repo=True, layers >= repo_start_layer): |
| # Per-head predicted positions [B, H, S] for layer i: |
| z = state.layers[i].attention.repo.positions |
| |
| # Shared position representation r [B, S, d_p] for layer i: |
| r = state.layers[i].attention.repo.r_repr |
| |
| # Q/K after REPO rotation (identical field name as standard RoPE path): |
| q = state.layers[i].attention.q_post_rope # [B, H, S, head_dim] |
| k = state.layers[i].attention.k_post_rope # [B, H, S, head_dim] |
| |
| # For layers below repo_start_layer: .repo is None, q/k_post_rope |
| # contain standard integer-RoPE rotated Q/K β same field, same shape. |
| """ |
| input_ids: Optional[torch.Tensor] = None |
| embeddings: Optional[torch.Tensor] = None |
| final_hidden_states: Optional[torch.Tensor] = None |
|
|
| generator: Optional[GeneratorAnalysis] = None |
| layers: Optional[List[LayerAnalysis]] = None |
| jtokm_aux_stats: Optional[list] = None |
| attn_res_sources_final: Optional[list] = None |
| logits: Optional[torch.Tensor] = None |
|
|
| class ScalarMultiplier(nn.Module): |
| """ |
| Scalar Learnable Multiplier: WΜ = sΒ·W |
| |
| From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": |
| Allows the effective matrix norm ||WΜ|| = sΒ·||W|| to adapt to data, escaping the |
| WD-noise equilibrium that constrains ||W|| β β(Ξ·/Ξ»). |
| """ |
| def __init__(self, initial_value: float = 1.0): |
| super().__init__() |
| self.multiplier = nn.Parameter(torch.tensor(initial_value)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.multiplier * x |
|
|
|
|
| class VectorMultiplier(nn.Module): |
| """ |
| Vector Learnable Multipliers: WΜ = diag(r)Β·WΒ·diag(c) |
| |
| From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": |
| Frees not only the overall matrix norm but also individual row/column norms from |
| the WD-noise equilibrium, enabling richer feature scale diversity. |
| """ |
| def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0): |
| super().__init__() |
| self.multiplier_type = multiplier_type |
| self.multiplier = nn.Parameter(torch.ones(dim) * initial_value) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return x * self.multiplier |
|
|
|
|
| class LinearWithMultipliers(nn.Module): |
| """ |
| Linear layer with optional row and/or column learnable multipliers. |
| Implements: y = (r β (W @ (c β x))) + b |
| """ |
| def __init__( |
| self, |
| in_features: int, |
| out_features: int, |
| bias: bool = True, |
| use_row_multiplier: bool = False, |
| use_column_multiplier: bool = False |
| ): |
| super().__init__() |
| self.linear = nn.Linear(in_features, out_features, bias=bias) |
| self.use_row_multiplier = use_row_multiplier |
| self.use_column_multiplier = use_column_multiplier |
|
|
| if use_row_multiplier: |
| self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row") |
| if use_column_multiplier: |
| self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column") |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| if self.use_column_multiplier: |
| x = self.column_multiplier(x) |
| x = self.linear(x) |
| if self.use_row_multiplier: |
| x = self.row_multiplier(x) |
| return x |
|
|
|
|
| |
|
|
| class LeviathanGenerator(nn.Module): |
| """ |
| Continuous token embedding generator for the Leviathan architecture. |
| |
    Replaces E ∈ R^{V×D} with a separable generator G : {0,...,V-1} → R^D.
    Three-stage pipeline: latent compositional indexing → B-spline basis
    expansion → tensor-product aggregation (Batley & Saha, 2026, §3.1).
| |
| When ``return_internals=True`` the forward returns |
| ``(embeddings, z_tilde, B_vals)`` for reuse by JTok-M surfaces in every |
| decoder layer, avoiding redundant B-spline evaluation. |
| |
| B-spline basis uses the KHRONOS closed-form quadratic kernel (ckhronos.py) |
| which is fully compatible with torch.compile (no .item(), no Python loops |
| inside the kernel, all tensor shapes static). |
| |
| Sign parity tracking matches KHRONOS KHRONOSLayer exactly. |
| |
| Initialization: spline_coeff ~ normal(mean=1.0, std=0.1), matching |
| KHRONOS init_weights_prod so that phi β 1.0 at init and the product of |
| d_seed factors starts near 1.0 instead of ~10^{-21}. |
| |
| **Frequency-based codebook ordering (optional)** |
| |
| By default, the base-k decomposition maps token indices directly to |
    codebook coordinates via arithmetic: token x → (x // b², x // b % b, x % b).
| This assigns coordinates based on index position, which is arbitrary with |
| respect to linguistic meaning under BPE tokenisation. |
| |
| When ``set_freq_order`` is called with a frequency-rank tensor, the |
    decomposition maps tokens through their frequency rank first:
        token x → rank_freq[x] → (rank // b², rank // b % b, rank % b).
| |
| This makes tokens with similar corpus frequency share codebook entries, |
| introducing pre-existing statistical structure into the gradient of W_res |
| from step 0. Since token frequency correlates with distributional behaviour |
| (Zipfian distribution, syntactic category, semantic class), the gradient |
| |
        ∂L/∂W_res = Σ_x δ_x · z̃_x^T
| |
| has low-rank structure immediately exploitable by Conda's SVD projection, |
| analogous to how the dense embedding table E gradient has low-rank structure |
| from the language distribution. Without this ordering, the SVD finds only |
| noise until codebooks organise through training, delaying Conda's advantage. |
| |
| If ``set_freq_order`` is never called, ``freq_order`` remains None and the |
| module behaves identically to the original implementation β the feature is |
| fully opt-in and backward compatible. |
| """ |
|
|
| def __init__(self, config: NeoLLMConfig): |
| super().__init__() |
|
|
| vocab_size = config.vocab_size |
| hidden_size = config.hidden_size |
| d_seed = config.generator_d_seed |
| num_modes = config.generator_num_modes |
| num_knots = config.generator_num_knots |
| spline_degree = config.generator_spline_degree |
| k = config.generator_k |
| krank = getattr(config, "generator_krank", 64) |
|
|
| b = math.ceil(vocab_size ** (1.0 / k)) |
|
|
| self.b = b |
| self.k = k |
| self.d_seed = d_seed |
| self.num_modes = num_modes |
| self.num_knots = num_knots |
| self.spline_degree = spline_degree |
| self.krank = krank |
| self.hidden_size = hidden_size |
|
|
| |
| |
| |
| self.codebooks = nn.Parameter(torch.empty(k, b, d_seed)) |
|
|
| |
| |
| self.register_buffer("freq_order", None, persistent=False) |
|
|
| |
| |
| self.register_buffer( |
| "knot_grid", |
| torch.linspace(0.0, 1.0, num_knots), |
| persistent=False, |
| ) |
|
|
| |
| |
| |
| |
| |
| self.seed_proj = nn.Linear(d_seed, d_seed, bias=True) |
| self.seed_norm = nn.LayerNorm(d_seed) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self.head_proj_weight = nn.Parameter( |
| torch.empty(num_modes * d_seed, d_seed) |
| ) |
| self.head_norm_weight = nn.Parameter(torch.ones(num_modes, d_seed)) |
| self.head_norm_bias = nn.Parameter(torch.zeros(num_modes, d_seed)) |
| self.head_norm_eps = 1e-5 |
|
|
| |
| self.head_scale = nn.Parameter( |
| torch.full((num_modes, d_seed), float(num_knots - 1)) |
| ) |
| |
| self.head_spline = nn.Parameter( |
| torch.empty(num_modes, d_seed, num_knots, krank) |
| ) |
| self.head_out_weight = nn.Parameter( |
| torch.empty(num_modes, krank, hidden_size) |
| ) |
|
|
| def set_freq_order(self, freq_order: torch.Tensor) -> None: |
| """ |
| Register a frequency-rank mapping to structure codebook coordinates. |
| |
| Must be called after model instantiation and after any device transfer |
| (.to(device), .cuda(), etc.) since the buffer is non-persistent and |
| is not saved to checkpoints. |
| |
| Args: |
| freq_order: Long tensor of shape ``(vocab_size,)`` where |
| ``freq_order[x]`` is the frequency rank of token x in the |
| training corpus (rank 0 = most frequent token). Typically |
| computed as ``torch.argsort(token_counts, descending=True)``. |
| |
| Example:: |
| |
| counts = compute_token_frequencies(tokenizer, dataset) # [V] |
| ranks = torch.argsort(counts, descending=True) # [V] |
| model.model.token_generator.set_freq_order(ranks) |
| """ |
| if freq_order.shape[0] != self.codebooks.shape[1] ** self.k: |
| |
| |
| pass |
| self.freq_order = freq_order.long().to(self.codebooks.device) |
|
|
| def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor: |
| """ |
| Deterministic base-b decomposition: i β (i_0, ..., i_{k-1}). |
| |
| When ``freq_order`` is set, token indices are remapped through their |
| frequency rank before decomposition. This ensures that tokens sharing |
| codebook entries are similar in corpus frequency rather than arbitrary |
| in BPE index space, providing pre-existing low-rank gradient structure |
| for Conda's SVD projection from step 0. |
| |
| Without ``freq_order``: x β (x // b^{k-1}, ..., x % b) |
| With ``freq_order``: x β freq_order[x] β (rank // b^{k-1}, ..., rank % b) |
| """ |
| ids = token_ids.long().clone() |
| if self.freq_order is not None: |
| ids = self.freq_order[ids] |
|
|
| coords = torch.empty( |
| *token_ids.shape, self.k, |
| dtype=torch.long, device=token_ids.device, |
| ) |
| for r in range(self.k - 1, -1, -1): |
| coords[..., r] = ids % self.b |
| ids = ids // self.b |
| return coords |
|
|
| def _bspline_basis(self, x_flat: torch.Tensor) -> torch.Tensor: |
| """ |
| KHRONOS quadratic B-spline basis with fixed scalar scale. |
| |
| Used exclusively by the JTok-M shared path (z_tilde β B_vals). |
| JTok-M surfaces have their own spline_coeff and call _modes_from_basis |
| with these B_vals. This path is unchanged from the original design. |
| |
| Args: |
| x_flat: [N, d_seed], values in [0, 1]. |
| Returns: |
| [N, d_seed, num_knots] float32. |
| """ |
| scale = float(self.num_knots - 1) |
| x32 = x_flat.float() |
| x_e = x32.unsqueeze(-1) |
| grid = self.knot_grid.float().view(1, 1, -1) |
| d = (x_e - grid).abs() * scale |
|
|
| return torch.where( |
| d < 0.5, |
| 0.75 - d ** 2, |
| torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)), |
| ) |
|
|
| def _bspline_basis_all_heads( |
| self, |
| x_all: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Vectorized KHRONOS quadratic B-spline basis for all heads at once. |
| |
| Mathematically identical to calling _bspline_basis_head 8 times in a |
| loop, but materializes the full [N, M, d_seed, n_knots] tensor in a |
| single kernel instead of 8 sequential [N, d_seed, n_knots] tensors. |
| |
| Args: |
| x_all: [N, M, d_seed], values in [0, 1], all heads stacked. |
| Returns: |
| [N, M, d_seed, n_knots] float32. |
| |
| NOTE: Este mΓ©todo se mantiene para compatibilidad con JTok-M y anΓ‘lisis. |
| El forward del generator ya NO lo usa β usa _compute_head en su lugar. |
| """ |
| x32 = x_all.float() |
| x_e = x32.unsqueeze(-1) |
| grid = self.knot_grid.float().view(1, 1, 1, -1) |
| |
| sc = self.head_scale.float().unsqueeze(0).unsqueeze(-1) |
| d = (x_e - grid).abs() * sc |
|
|
| return torch.where( |
| d < 0.5, |
| 0.75 - d ** 2, |
| torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)), |
| ) |
|
|
| def _compute_head( |
| self, |
| z: torch.Tensor, |
| m: int, |
| ) -> torch.Tensor: |
| """ |
| Forward completo para el cabezal m del generator. |
| |
| Reemplaza la materializaciΓ³n conjunta [N, M, d_seed, n_knots] del path |
| vectorizado. Cada llamada materializa solo [N, d_seed, n_knots] (1 cabezal), |
| reduciendo el pico de memoria de O(MΒ·d_seedΒ·n_knots) a O(d_seedΒ·n_knots) |
| por cabezal. |
| |
| Pipeline: |
| z [N, d_seed] |
| β Linear(head_proj_weight[m*d_seed:(m+1)*d_seed]) β [N, d_seed] |
| β ManualLayerNorm(weight[m], bias[m]) β [N, d_seed] |
| β sigmoid(x/2) β [N, d_seed] (coordenada en [0,1]^d_seed) |
| β B-spline KHRONOS con scale=head_scale[m] β [N, d_seed, n_knots] |
| β einsum con head_spline[m] β per_dim [N, d_seed, krank] |
| β sign-parity product (log-sum-exp) β modes [N, krank] |
| β Linear(head_out_weight[m]) β [N, hidden_size] |
| |
| Por quΓ© loop Python sobre M cabezales en lugar de vmap: |
| torch.vmap sobre cabezales con parΓ‘metros distintos requiere |
| functional_call y stack_module_state, lo que complica el acceso |
| a buffers (knot_grid, head_norm_eps) desde dentro del transform. |
| Un loop Python con M=8 fijo es unrolleado por TorchDynamo en una |
| secuencia estΓ‘tica de ops β exactamente como lo hace XLA/Flax en |
| la implementaciΓ³n original de Reza. El compilador ve 8 grafos |
| idΓ©nticos en estructura pero con parΓ‘metros distintos, y puede |
| fusionarlos u optimizarlos de forma independiente. Con chunk_size=1 |
| en vmap el comportamiento serΓa anΓ‘logo pero con mayor overhead de |
| instrumentaciΓ³n. |
| |
| Args: |
| z: [N, d_seed] β codebook seed compartido (float del dtype del modelo). |
| m: Γndice del cabezal (0 β€ m < num_modes), Python int estΓ‘tico. |
| Returns: |
| [N, hidden_size] β contribuciΓ³n de este cabezal al embedding final. |
| """ |
| d = self.d_seed |
| nk = self.num_knots |
| kr = self.krank |
|
|
| |
| |
| |
| proj_w = self.head_proj_weight[m * d : (m + 1) * d] |
| |
| |
| |
| zh = F.linear( |
| z.to(dtype=proj_w.dtype, device=proj_w.device), |
| proj_w, |
| ) |
| zh = zh.float() |
|
|
| |
| |
| |
| norm_w = self.head_norm_weight[m].float() |
| norm_b = self.head_norm_bias[m].float() |
| mean = zh.mean(dim=-1, keepdim=True) |
| var = zh.var(dim=-1, keepdim=True, unbiased=False) |
| zh = (zh - mean) / (var + self.head_norm_eps).sqrt() |
| zh = zh * norm_w + norm_b |
|
|
| |
| zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) |
|
|
| |
| |
| |
| sc = self.head_scale[m].float().view(1, -1, 1) |
| x_e = zh.unsqueeze(-1) |
| grid = self.knot_grid.float().view(1, 1, -1) |
| dist = (x_e - grid).abs() * sc |
| B_m = torch.where( |
| dist < 0.5, |
| 0.75 - dist ** 2, |
| torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)), |
| ) |
|
|
| |
| |
| |
| |
| per_dim = torch.einsum( |
| "ndg,dgk->ndk", |
| B_m, |
| self.head_spline[m].float(), |
| ) |
|
|
| |
| |
| per_dim_abs = per_dim.abs() + 1e-9 |
| log_mag = torch.log(per_dim_abs).sum(dim=1) |
| num_neg = (per_dim < 0).long().sum(dim=1) |
| prod_sign = 1.0 - 2.0 * (num_neg % 2).float() |
| modes_m = prod_sign * torch.exp(log_mag) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| out_m = ( |
| modes_m.to(self.head_out_weight.dtype) |
| @ self.head_out_weight[m] |
| ) |
| return out_m |
|
|
| def _khronos_all_heads( |
| self, |
| B_all: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Vectorized KHRONOS tensor-product for all heads at once. |
| |
| Mathematically identical to calling _khronos_head_product 8 times, |
| but uses a single einsum over the head dimension. The sign-parity |
| aggregation is performed independently per head via the M dimension. |
| |
| Args: |
| B_all: [N, M, d_seed, n_knots] float32 |
| Returns: |
| [N, M, krank] in float32. |
| """ |
| |
| |
| per_dim = torch.einsum( |
| "nmdg,mdgk->nmdk", |
| B_all, |
| self.head_spline.float(), |
| ) |
|
|
| per_dim_abs = per_dim.abs() + 1e-9 |
| |
| log_mag = torch.log(per_dim_abs).sum(dim=2) |
| num_neg = (per_dim < 0).long().sum(dim=2) |
| prod_sign = 1.0 - 2.0 * (num_neg % 2).float() |
|
|
| return prod_sign * torch.exp(log_mag) |
|
|
| def _modes_from_basis( |
| self, |
| B_vals: torch.Tensor, |
| spline_coeff: torch.Tensor, |
| target_dtype: torch.dtype, |
| ) -> torch.Tensor: |
| """ |
| Shared tensor-product aggregation used by both the input generator |
| and JTok-M surfaces. |
| |
| phi_{r,j} = spline_coeff[j,r,:] Β· B_vals[n,r,:] |
| M_j = sign_j * exp(Ξ£_r log|phi_{r,j}|) (KHRONOS sign-parity) |
| |
| Args: |
| B_vals: [N, d_seed, n_knots] float32 |
| spline_coeff: [num_modes, d_seed, n_knots] |
| target_dtype: output dtype |
| Returns: |
| modes: [N, num_modes] in target_dtype |
| """ |
| phi = torch.einsum( |
| "jrk,nrk->njr", |
| spline_coeff.float(), |
| B_vals, |
| ) |
|
|
| phi_abs = phi.abs() + 1e-9 |
| log_mag = torch.log(phi_abs).sum(dim=-1) |
| num_neg = (phi < 0).long().sum(dim=-1) |
| prod_sign = 1.0 - 2.0 * (num_neg % 2).float() |
|
|
| return (prod_sign * torch.exp(log_mag)).to(target_dtype) |
|
|
| def forward( |
| self, |
| token_ids: torch.Tensor, |
| return_internals: bool = False, |
| analysis: Optional[GeneratorAnalysis] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: |
| """ |
| Generate embeddings from discrete token indices. |
| |
| Two parallel paths run from the shared codebook output z: |
| |
| JTok-M path (return_internals): |
| z β seed_proj β seed_norm β sigmoid β z_tilde [N, d_seed] |
| z_tilde β _bspline_basis (fixed scale) β B_vals [N, d_seed, n_knots] |
| These are returned for reuse by every decoder layer's JTok-M |
| surfaces without redundant B-spline evaluation. |
| |
| Generator path (per-head, real Leviathan architecture): |
| For each head i: |
| z β head_proj_i β head_norm_i β sigmoid(x/2) β z_tilde_i |
| z_tilde_i β _bspline_basis_head(scale_i) β B_vals_i |
| B_vals_i Γ head_spline_i β _khronos_head_product β [N, krank] |
| head_out_i([N, krank]) β [N, hidden_size] |
| e = sum of all head_out_i outputs (no W_res, per author's code) |
| |
| Args: |
| token_ids: (batch, seq_len) or (seq_len,) |
| return_internals: if True, also return z_tilde and B_vals for |
| reuse by JTok-M surfaces in every decoder layer. |
| analysis: Optional[GeneratorAnalysis] β when not None and model |
| is in eval mode with analysis armed, deposits all |
| internal tensors (detached). No-op during training. |
| Returns: |
| embeddings [*token_ids.shape, hidden_size], |
| or (embeddings, z_tilde [N, d_seed], B_vals [N, d_seed, n_knots]) |
| when return_internals=True. |
| """ |
| target_dtype = self.codebooks.dtype |
| orig_shape = token_ids.shape |
| N = token_ids.numel() |
|
|
| |
| coords = self._base_k_decompose(token_ids) |
| coords_flat = coords.reshape(N, self.k) |
| z = torch.zeros(N, self.d_seed, device=token_ids.device, dtype=target_dtype) |
| for r in range(self.k): |
| z = z + self.codebooks[r][coords_flat[:, r]] |
|
|
| if analysis is not None: |
| analysis.z_raw = z.detach() |
|
|
| |
| |
| |
| z_tilde = torch.sigmoid(self.seed_norm(self.seed_proj(z))) |
| B_vals = self._bspline_basis(z_tilde.clamp(0.0, 1.0)) |
|
|
| if analysis is not None: |
| analysis.z_tilde = z_tilde.detach() |
| analysis.B_vals = B_vals.detach() |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| target_dtype = self.codebooks.dtype |
| e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype) |
| for m in range(self.num_modes): |
| e = e + self._compute_head(z, m) |
|
|
| |
| e = e.reshape(*orig_shape, self.hidden_size) |
|
|
| if analysis is not None: |
| analysis.embeddings = e.detach() |
|
|
| if return_internals: |
| return e, z_tilde, B_vals |
| return e |
|
|
|
|
| |
|
|
class LeviathanJTokM(nn.Module):
    """
    Leviathan-JTok-M token-indexed modulation module for one decoder layer.

    Fuses the Leviathan continuous geometry with JTok-M (Yang et al., 2026):

    - Instead of per-token lookup tables (O(V·D) per layer), uses n_e independent
      CP-separable surfaces over the shared z̃_x from the Leviathan generator,
      reusing B_vals already computed in the embedding stage.
    - Context-dependent router: gates over h̃^ℓ_x (hidden state after attention)
      using Sigmoid+TopK — not Softmax — to avoid inter-surface competition.
    - Additive injection with 1/√(2ℓ) scaling coordinated with the existing LNS
      factor 1/√ℓ, maintaining a constant JTok-M / backbone ratio of 1/√2 ≈ 0.707
      across all depths (instead of 1/√(2N_l) which would grow JTok-M dominance
      in deep layers as LNS suppresses backbone activations).
    - Fully vectorized: all surfaces evaluated in one einsum, TopK with fixed K
      produces static shapes → compatible with torch.compile max-autotune.

    Scaling note (LNS coordination):
        LNS applies 1/√ℓ to backbone sublayer inputs (ℓ = 1-indexed layer).
        JTok-M applies 1/√(2ℓ) to its injection residual.
        Ratio: [1/√(2ℓ)] / [1/√ℓ] = 1/√2 → constant at every depth.

    Parameter cost per layer (defaults n_e=5, M_mod=4, d_seed=128, D=512):
        spline_coeff: n_e × M_mod × d_seed × n_knots = 5×4×128×32 = 81,920
        W_out:        n_e × M_mod × D               = 5×4×512     = 10,240
        W_res:        n_e × d_seed × D              = 5×128×512   = 327,680
        router R:     D × n_e                       = 512×5       = 2,560
        scaler s:     D                             = 512
        Total per layer: ~422,912 → ~5.07M for 12 layers.

    References:
        Yang, Y. et al. (2026). JTok. arXiv:2602.00800.
        Batley & Saha (2026). Leviathan. arXiv:2601.22040.
    """

    def __init__(self, config: "NeoLLMConfig", layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_e = config.jtokm_num_experts
        self.top_k = config.jtokm_top_k
        self.M_mod = config.jtokm_num_modes
        self.d_seed = config.generator_d_seed
        self.n_knots = config.generator_num_knots
        self.hidden_size = config.hidden_size
        self.norm_eps = config.jtokm_norm_eps

        # 1/sqrt(2*ell) keeps the JTok-M / backbone ratio at 1/sqrt(2) at
        # every depth (LNS applies 1/sqrt(ell) to the backbone input).
        ell = max(layer_idx + 1, 1)
        self.lns_scale = 1.0 / math.sqrt(2.0 * ell)

        # CP-separable surface parameters.
        # NOTE(review): created with torch.empty and left uninitialized here —
        # presumably filled by an external init routine; confirm before training.
        self.spline_coeff = nn.Parameter(
            torch.empty(self.n_e, self.M_mod, self.d_seed, self.n_knots)
        )
        self.W_out = nn.Parameter(
            torch.empty(self.n_e, self.M_mod, self.hidden_size)
        )
        self.W_res = nn.Parameter(
            torch.empty(self.n_e, self.d_seed, self.hidden_size)
        )

        # Context-dependent router over the post-attention hidden state.
        self.router = nn.Linear(config.hidden_size, self.n_e, bias=False)

        # Per-channel output scale applied to the normalized mixture.
        self.scaler = nn.Parameter(torch.ones(config.hidden_size))

    def _eval_surfaces(
        self,
        B_vals: torch.Tensor,
        z_tilde: torch.Tensor,
        target_dtype: torch.dtype,
        analysis: Optional["JTokMAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Evaluate all n_e surfaces vectorized over the full token batch.

        phi[n, i, j, r] = spline_coeff[i, j, r, :] · B_vals[n, r, :]
        M[n, i, j]      = sign * exp(Σ_r log|phi[n,i,j,r]|)
        m[n, i]         = W_out[i] @ M[n,i] + W_res[i] @ z̃[n]

        All shapes are static → torch.compile compatible.

        Args:
            B_vals: [N, d_seed, n_knots] float32
            z_tilde: [N, d_seed]
            target_dtype: model dtype
            analysis: Optional[JTokMAnalysis] — deposits surfaces when not None.
        Returns:
            surfaces: [N, n_e, D]
        """
        # Per-dimension spline responses for every (expert, mode) pair.
        phi = torch.einsum(
            "ijrk,nrk->nijr",
            self.spline_coeff.float(),
            B_vals,
        )

        # KHRONOS product over d_seed in log space; epsilon guards log(0),
        # negative-count parity recovers the product's sign.
        phi_abs = phi.abs() + 1e-9
        log_mag = torch.log(phi_abs).sum(dim=-1)
        num_neg = (phi < 0).long().sum(dim=-1)
        prod_sign = 1.0 - 2.0 * (num_neg % 2).float()
        modes = (prod_sign * torch.exp(log_mag)).to(target_dtype)

        # Mode mixing into model space, per expert.
        out_modes = torch.einsum("nim,imd->nid", modes, self.W_out.to(target_dtype))

        # Low-rank residual directly from the latent coordinate.
        z = z_tilde.to(target_dtype)
        out_res = torch.einsum("nd,idc->nic", z, self.W_res.to(target_dtype))

        surfaces = out_modes + out_res

        if analysis is not None:
            analysis.surfaces = surfaces.detach()

        return surfaces

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Weightless RMS normalization over the last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)

    def _route_and_mix(
        self,
        h_tilde: torch.Tensor,
        surfaces: torch.Tensor,
        analysis: Optional["JTokMAnalysis"] = None,
    ) -> Tuple[torch.Tensor, Tuple]:
        """
        Context-dependent routing over h_tilde (hidden state after attention).

        Sigmoid + TopK (not Softmax) avoids inter-surface competition:
            g = RMSNorm(h̃) @ R                     [N, n_e]
            K winners selected per token
            w_i = σ(g_i) / Σ_{j∈K} σ(g_j)          (normalized over selected only)
            e = Σ_{i∈K} w_i * surfaces[n, i]

        All tensor shapes are static. TopK with fixed K returns [N, K] indices.
        torch.compile compatible.

        Also computes load-balancing statistics p_i and f_i for the aux loss.

        Args:
            h_tilde: [N, D] — hidden state after attention (before MLP)
            surfaces: [N, n_e, D]
            analysis: Optional[JTokMAnalysis] — deposits routing data when not None.
        Returns:
            mixed: [N, D]
            aux_stats: (p_sum [n_e], f_sum [n_e], N) for loss accumulation
        """
        N = h_tilde.shape[0]

        # Router logits over the RMS-normalized hidden state.
        g = self.router(self._rms_norm(h_tilde))

        # Fixed-K selection keeps shapes static.
        topk_vals, topk_idx = torch.topk(g, self.top_k, dim=-1)

        # Sigmoid gates, renormalized over the selected experts only.
        sig_vals = torch.sigmoid(topk_vals)
        w = sig_vals / sig_vals.sum(dim=-1, keepdim=True)

        # Gather the K selected surfaces and blend them.
        idx_exp = topk_idx.unsqueeze(-1).expand(N, self.top_k, self.hidden_size)
        selected = surfaces.gather(dim=1, index=idx_exp)
        mixed = (w.unsqueeze(-1) * selected).sum(dim=1)

        if analysis is not None:
            analysis.router_logits = g.detach()
            analysis.topk_indices = topk_idx.detach()
            analysis.routing_weights = w.detach()
            analysis.mixed_pre_norm = mixed.detach()

        # Load-balance statistics. FIX: p_sum must stay inside the autograd
        # graph — it was previously computed under torch.no_grad(), which made
        # compute_jtokm_aux_loss a constant w.r.t. the router parameters (a
        # training no-op). f_sum is a hard count and is inherently
        # non-differentiable, so it remains detached.
        sig_all = torch.sigmoid(g)
        p_sum = sig_all.sum(dim=0)
        with torch.no_grad():
            onehot = torch.zeros_like(g).scatter_(
                1, topk_idx, 1.0
            )
            f_sum = onehot.sum(dim=0)

        return mixed, (p_sum, f_sum, N)

    def forward(
        self,
        h_tilde: torch.Tensor,
        z_tilde: torch.Tensor,
        B_vals: torch.Tensor,
        analysis: Optional["JTokMAnalysis"] = None,
    ) -> Tuple[torch.Tensor, Tuple]:
        """
        Compute additive JTok-M residual for one decoder layer.

        Args:
            h_tilde: [N, D] hidden state after attention (before MLP)
            z_tilde: [N, d_seed] latent coordinate from generator
            B_vals: [N, d_seed, n_k] B-spline basis (computed once, reused)
            analysis: Optional[JTokMAnalysis] — deposits all JTok-M internals
                when not None. No-op during training (analysis is None then).
        Returns:
            delta_r: [N, D] additive residual (already scaled)
            aux_stats: tuple for accumulating the load-balance loss
        """
        target_dtype = h_tilde.dtype

        # Evaluate every expert surface once, vectorized.
        surfaces = self._eval_surfaces(B_vals, z_tilde, target_dtype, analysis=analysis)

        # Select and blend K surfaces per token.
        mixed, aux_stats = self._route_and_mix(h_tilde, surfaces, analysis=analysis)

        # Unit-normalize the mixture, then apply the learned per-channel
        # scale and the depth-coordinated 1/sqrt(2*ell) factor.
        mixed_norm = mixed / (mixed.norm(dim=-1, keepdim=True) + self.norm_eps)
        delta_r = self.lns_scale * self.scaler * mixed_norm

        if analysis is not None:
            analysis.mixed_normalized = mixed_norm.detach()
            analysis.delta_r = delta_r.detach()
            analysis.p_sum = aux_stats[0].detach()
            analysis.f_sum = aux_stats[1].detach()
            analysis.lns_scale = self.lns_scale

        return delta_r, aux_stats
|
|
|
|
def compute_jtokm_aux_loss(
    aux_stats_list: list,
    n_e: int,
    weight: float,
) -> torch.Tensor:
    """
    Aggregate the load-balancing auxiliary loss across all active layers.

        L_aux = λ · n_e · Σ_i p_i · f_i

    averaged over every layer that contributed statistics.

    Args:
        aux_stats_list: list of (p_sum [n_e], f_sum [n_e], N) per layer
        n_e: number of experts
        weight: λ coefficient
    Returns:
        scalar loss tensor (0.0 when the list is empty)
    """
    if not aux_stats_list:
        return torch.tensor(0.0)

    # Per-layer term: normalize sums to per-token averages, dot them.
    per_layer = [
        weight * n_e * ((p_sum / N) * (f_sum / (N * 1.0))).sum()
        for p_sum, f_sum, N in aux_stats_list
    ]
    total_loss = per_layer[0]
    for term in per_layer[1:]:
        total_loss = total_loss + term
    return total_loss / len(aux_stats_list)
|
|
|
|
| |
|
|
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer.

    A single shared projection is split into a periodic chunk p and a
    linear chunk g, then concatenated as:

        FANLayer'(X) = [cos(WpX) || sin(WpX) || (Wp̄X + Bp̄)]
    """

    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio

        # The output widens by fan_ratio; the periodic chunk appears twice
        # (cos and sin), so the linear chunk absorbs the remainder.
        total_dim = hidden_size + int(hidden_size * fan_ratio)
        self.p_output_dim = int(total_dim * fan_ratio)
        self.g_output_dim = total_dim - 2 * self.p_output_dim

        self.input_linear = nn.Linear(
            hidden_size, self.p_output_dim + self.g_output_dim, bias=True
        )
        self._init_weights()

    def _init_weights(self):
        # Small normal init on the shared projection, zero bias.
        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
        if self.input_linear.bias is not None:
            nn.init.zeros_(self.input_linear.bias)

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["FANAnalysis"] = None,
    ) -> torch.Tensor:
        projected = self.input_linear(x)
        periodic, linear_part = torch.split(
            projected, [self.p_output_dim, self.g_output_dim], dim=-1
        )
        cos_part = torch.cos(periodic)
        sin_part = torch.sin(periodic)
        if analysis is not None:
            analysis.cosine_component = cos_part.detach()
            analysis.sine_component = sin_part.detach()
            analysis.linear_component = linear_part.detach()
        return torch.cat([cos_part, sin_part, linear_part], dim=-1)
|
|
|
|
class LNS(nn.Module):
    """
    LayerNorm Scaling: multiplies activations by 1/√ℓ so variance does not
    grow with depth ("The Curse of Depth in Large Language Models").
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        # 1-indexed depth, clamped to at least 1.
        self.layer_idx = max(layer_idx + 1, 1)
        self.scale = 1.0 / math.sqrt(self.layer_idx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure elementwise rescale; no parameters.
        return x * self.scale
|
|
|
|
class GPAS(nn.Module):
    """
    Gradient-Preserving Activation Scaling.

    Subtracts silu(alpha) * detach(x) from x: the forward value is damped
    while the gradient with respect to x passes through unscaled.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # alpha starts at 0 -> silu(0) = 0 -> identity at init.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["GPASAnalysis"] = None,
    ) -> torch.Tensor:
        gate = F.silu(self.alpha)
        damped = gate * x.detach()
        if analysis is not None:
            analysis.silu_alpha = gate.detach()
            analysis.subtracted_component = damped.detach()
        return x - damped
|
|
|
|
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization.

    SeeDNorm(x) = [tanh(x·β^T)·α + γ] ⊙ x/RMS(x)

    rescale_factor = tanh(x · β)  ∈ (-1, 1), one scalar per token
    dynamic_scale  = rescale_factor · α + γ  ∈ ℝ^dim
    output         = dynamic_scale ⊙ RMSNorm(x)
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps

        # With beta = 0 the rescale factor vanishes and the module reduces
        # to a plain gamma-scaled RMSNorm.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.alpha = nn.Parameter(torch.ones(dim))

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["SeeDNormAnalysis"] = None,
    ) -> torch.Tensor:
        # Per-token scalar in (-1, 1), driven by the projection onto beta.
        rescale_factor = torch.tanh(
            (x * self.beta).sum(dim=-1, keepdim=True)
        )
        dynamic_scale = rescale_factor * self.alpha + self.gamma
        # Normalize in fp32 for stability, cast back at the end.
        x_normalized = self._rms_norm(x.float())
        output = (x_normalized * dynamic_scale.float()).type_as(x)
        if analysis is not None:
            analysis.rescale_factor = rescale_factor.detach()
            analysis.dynamic_scale = dynamic_scale.detach()
            analysis.x_normalized = x_normalized.detach()
            analysis.output = output.detach()
        return output

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"
|
|
|
|
| |
|
|
class NeoLLMRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) provider.

    Builds cos/sin tables for the positions in ``position_ids``. Supports the
    HF-style ``config.rope_scaling`` dict via ROPE_INIT_FUNCTIONS, falling
    back to the default inverse-frequency schedule otherwise.

    Fix vs previous revision: the forward ended with two identical
    ``return`` statements; the second was unreachable dead code and has
    been removed. Behavior is unchanged.
    """

    # [rotary_dim // 2] inverse frequencies, non-persistent buffer.
    inv_freq: torch.Tensor

    def __init__(self, config: "NeoLLMConfig", device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        # Resolve the rope variant from config.rope_scaling when present and
        # recognised; otherwise keep the default schedule.
        self.rope_type = "default"
        if (hasattr(config, "rope_scaling")
                and config.rope_scaling is not None
                and isinstance(config.rope_scaling, dict)):
            rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
            if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
                self.rope_type = rope_type

        rope_init_fn = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        # Non-persistent: recomputed from config, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: "NeoLLMConfig" = None,
        device: Optional["torch.device"] = None,
        seq_len: int = None,
    ) -> tuple["torch.Tensor", float]:
        """Return (inv_freq, attention_scaling) for the default schedule."""
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or \
            config.hidden_size // config.num_attention_heads
        # Partial-rotary support: only a fraction of head_dim is rotated.
        dim = int(dim * getattr(config, "partial_rotary_factor", 1.0))
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64)
                     .to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return (cos, sin) tables in ``x``'s dtype for ``position_ids``."""
        # Normalise position_ids to shape [B, S].
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        B = x.shape[0]
        if position_ids.shape[0] != B:
            position_ids = position_ids.expand(B, -1)

        # autocast is force-disabled below; mps has no autocast context, so
        # use the cpu one there.
        device_type = (x.device.type
                       if isinstance(x.device.type, str) and x.device.type != "mps"
                       else "cpu")

        # Lazily materialise the buffers when the module was built on meta.
        if self.inv_freq.device.type == "meta":
            inv_freq_data, _ = self.compute_default_rope_parameters(
                self.config, device=x.device
            )
            self.register_buffer("inv_freq", inv_freq_data, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq_data.clone(), persistent=False)

        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)

        # Angles computed in fp32 regardless of the ambient autocast state.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
                     * inv_freq.unsqueeze(0).unsqueeze(0))
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
def rotate_half(x):
    """Rotate halves of the last dimension: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def linear_clipping(x: torch.Tensor) -> torch.Tensor:
    """
    Piecewise-linear squashing for Affine-Scaled Attention scale factors.

    f(x) = clamp(0.1·x + 0.5, 0, 1):
        0            for x ≤ -5
        0.1·x + 0.5  for -5 < x < 5
        1            for x ≥  5

    Chosen over sigmoid to avoid saturation: the gradient is a constant 0.1
    across the whole non-saturated region, preserving the intermediate
    scaling range needed for fine-grained per-query attention modulation.

    Reference: Bae et al. (2026), Affine-Scaled Attention §6.
    """
    return (0.1 * x + 0.5).clamp(0.0, 1.0)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply RoPE to the first ``cos.shape[-1]`` channels of q and k.

    Channels beyond the rotary dim (partial rotary) pass through untouched.
    ``unsqueeze_dim`` inserts the broadcast axis for the head dimension.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rot = cos.shape[-1]
    q_rot, q_tail = q[..., :rot], q[..., rot:]
    k_rot, k_tail = k[..., :rot], k[..., rot:]
    q_out = torch.cat([q_rot * cos + rotate_half(q_rot) * sin, q_tail], dim=-1)
    k_out = torch.cat([k_rot * cos + rotate_half(k_rot) * sin, k_tail], dim=-1)
    return q_out, k_out
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Tile KV heads n_rep times along the head axis (GQA -> MHA layout).

    [B, H_kv, S, D] -> [B, H_kv * n_rep, S, D]; the n_rep == 1 fast path
    returns the input tensor unchanged (no copy).
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def causal_first_difference(x: torch.Tensor) -> torch.Tensor:
    """First difference along the sequence axis: y_t = x_t - x_{t-1}, y_0 = x_0."""
    shifted = F.pad(x[..., :-1, :], (0, 0, 1, 0))
    return x - shifted
|
|
|
|
def rms_key_unit_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    """L2-normalize each vector in fp32, then rescale so its RMS equals 1."""
    unit = F.normalize(x.float(), p=2, dim=-1, eps=eps)
    return unit * math.sqrt(x.shape[-1])
|
|
|
|
def infer_key_validity(
    attention_mask: Optional[torch.Tensor], seq_len: int, num_heads: int
) -> Optional[torch.Tensor]:
    """
    Derive a per-position key-validity mask from a 4-D additive attention mask.

    A position counts as valid when its diagonal mask entry is finite and
    exactly 0 (unmasked). Returns None when the mask is absent, not 4-D, or
    its trailing dims are not a square [seq_len, seq_len]. The head axis is
    broadcast or collapsed so the result always has num_heads heads.
    """
    if attention_mask is None or attention_mask.ndim != 4:
        return None
    if attention_mask.shape[-2] != seq_len or attention_mask.shape[-1] != seq_len:
        return None
    diag = attention_mask.diagonal(dim1=-2, dim2=-1)
    valid = torch.isfinite(diag) & (diag == 0)
    if valid.shape[1] == 1 and num_heads != 1:
        return valid.expand(-1, num_heads, -1)
    if valid.shape[1] != num_heads:
        return valid[:, :1, :].expand(-1, num_heads, -1)
    return valid
|
|
|
|
def head_linear_compose(
    hidden_states: torch.Tensor, mixing_matrix: torch.Tensor
) -> torch.Tensor:
    """Linearly recombine heads: out[b, k] = Σ_h mix[h, k] · states[b, h]."""
    mix = mixing_matrix.to(device=hidden_states.device, dtype=hidden_states.dtype)
    return torch.einsum("bhtd,hk->bktd", hidden_states, mix)
|
|
|
|
class MEAHeadSeeDNorm(nn.Module):
    """MEA head-level normalization: SeeDNorm applied over GQA-aligned head groups."""

    def __init__(self, num_heads: int, head_dim: int, num_kv_groups: int, eps: float = 1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_kv_groups
        # Query heads sharing one KV head are normalized jointly as a group.
        self.num_kv_heads = num_heads // num_kv_groups
        self.group_dim = num_kv_groups * head_dim
        self.norm = SeeDNorm(self.group_dim, eps=eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, num_heads, head_dim = hidden_states.shape
        if num_heads != self.num_heads or head_dim != self.head_dim:
            raise ValueError(
                f"MEAHeadSeeDNorm expected ({self.num_heads}, {self.head_dim}), "
                f"received ({num_heads}, {head_dim})"
            )
        # Fold heads into KV groups, normalize each group, unfold again.
        grouped = hidden_states.reshape(batch, seq_len, self.num_kv_heads, self.group_dim)
        normalized = self.norm(grouped)
        return normalized.reshape(batch, seq_len, num_heads, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Standard softmax attention, eager path.

    Returns (attn_output [B, S, H, D] — transposed for o_proj — and the fp32
    softmax probabilities cast back to the query dtype).
    """
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may cover more key positions than present; slice to fit.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(
        scores, dim=-1, dtype=torch.float32
    ).to(query.dtype)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    context = nn.functional.dropout(context, p=dropout, training=module.training)
    return context, probs
|
|
|
|
def affine_scaled_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    dropout: float = 0.0,
    attn_analysis: Optional[AttentionAnalysis] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Affine-Scaled Attention, eager path (Bae et al., 2026, Eq. 6–8).

    Replaces softmax(QK^T/√dk) V with:

        [α(X) · softmax(QK^T/√dk) + β(X)] V

    α is a per-head, per-query, input-dependent scale in [0, 1]; β is an
    input-dependent bias that compensates for α's deviation from its running
    average so the effective attention mass cannot collapse. Both are
    computed in NeoLLMAttention.forward and passed in — this function only
    performs the affine reweighting and value aggregation. The Gated
    Attention gate (post-SDPA, before o_proj) is orthogonal and untouched.

    Args:
        alpha: [batch, num_heads, seq_q, 1] — input-dependent scale per query
        beta:  [batch, num_heads, seq_q, 1] — input-dependent bias per query
        attn_analysis: optional sink for pre/post-affine weights; populated
            only in analysis mode (always None while training).
    """
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(
        scores, dim=-1, dtype=torch.float32
    ).to(query.dtype)

    if attn_analysis is not None:
        attn_analysis.attn_weights_pre_affine = probs.detach()

    # Affine reweighting: relaxes the softmax unit-sum constraint.
    affine_probs = alpha * probs + beta

    if attn_analysis is not None:
        attn_analysis.attn_weights_post_affine = affine_probs.detach()

    context = torch.matmul(affine_probs, values).transpose(1, 2).contiguous()
    context = nn.functional.dropout(context, p=dropout, training=module.training)
    return context, affine_probs
|
|
|
|
def affine_scaled_flash_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    dropout: float = 0.0,
    attn_analysis: Optional[AttentionAnalysis] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Affine-Scaled Attention — flash/sdpa path.

    Exact mathematical decomposition of [α·softmax(QKᵀ)+β]V using only the
    public flash/sdpa interface — no kernel modification required.

    Derivation
    ----------
    The paper formula expands distributively:

        [α · softmax(QKᵀ/√dk) + β] · V
        = α · [softmax(QKᵀ/√dk) · V]      ← term 1: standard flash output
        + β · [Σ_{j≤i} V_j]               ← term 2: causal prefix-sum of V

    Term 2 follows because β is a scalar per query (broadcast over all S_k),
    and softmax weights sum to 1 over the causal window:
        β_i · Σ_j w_{i,j} · V_j = β_i · Σ_{j≤i} V_j   (since Σ w_{i,j} = 1)

    Dropout
    -------
    The eager path drops entries of the combined weight matrix before
    multiplying by V. The flash interface never exposes that matrix, so the
    kernel runs with dropout=0 and nn.functional.dropout is applied once to
    the final combined output instead — output dropout rather than weight
    dropout, a different (but standard) regularisation with the same intent.
    During inference dropout=0, so both paths are identical.

    Padding mask
    ------------
    The V cumsum must not accumulate values from padding positions: the
    flash kernel zeros those internally but a plain cumsum does not. V is
    zeroed at padding positions before the cumsum by reading the diagonal of
    the attention_mask (valid positions carry mask value 0, padding -inf).

    Memory overhead vs standard flash call
    ---------------------------------------
    One extra tensor of shape [B, H_q, S, d_head] for V_cumsum, allocated
    and freed within each forward call.

    Attention-weight analysis fields (attn_weights_pre_affine,
    attn_weights_post_affine) remain None — those tensors are never
    materialised by the flash kernel and cannot be recovered. All other
    AnalysisState fields (alpha, beta, alpha_ma) are deposited by the
    caller before this function is invoked.

    Args:
        alpha: [B, H_q, S_q, 1] — input-dependent scale per query, in [0, 1]
        beta: [B, H_q, S_q, 1] — moving-average bias per query
        attn_analysis: deposited by caller; this function does not write to it.
    """
    # Term 1: standard flash/SDPA output. NOTE(review): layout assumed to be
    # [B, S_q, H_q, d_head] per the transformers attention-function contract —
    # confirm against the registered implementation. Kernel dropout forced to
    # 0 here; dropout is applied once to the combined output below.
    attn_fn = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
    flash_out, _ = attn_fn(
        module, query, key, value, attention_mask,
        dropout=0.0, scaling=scaling, **kwargs,
    )

    # Term 2: causal prefix-sum of V. Expand KV heads first so the cumsum
    # lines up with the query-head layout of the flash output.
    value_expanded = repeat_kv(value, module.num_key_value_groups)

    # Zero V at padding positions before the cumsum (see "Padding mask"
    # above): valid positions have an additive mask value of exactly 0 on
    # the diagonal; padding positions are -inf and multiply V by 0 here.
    if attention_mask is not None and attention_mask.ndim == 4:
        diag = attention_mask.diagonal(dim1=-2, dim2=-1)
        # Bool -> 0/1 multiplier in V's dtype.
        valid = (diag == 0).to(value_expanded.dtype)
        valid = valid.unsqueeze(-1)
        # Broadcasts over the head axis when the mask has a single head.
        value_expanded = value_expanded * valid

    # Running sum over the sequence axis: [B, H_q, S, d_head].
    v_cumsum = value_expanded.cumsum(dim=2)
    # Transpose to match flash_out's [B, S, H, d_head] layout.
    v_cumsum_t = v_cumsum.transpose(1, 2).contiguous()

    # alpha/beta arrive as [B, H, S_q, 1]; permute to [B, S_q, H, 1].
    alpha_t = alpha.permute(0, 2, 1, 3)
    beta_t = beta.permute(0, 2, 1, 3)

    # Exact expansion: alpha * flash_attn(Q,K,V) + beta * prefix-sum(V).
    output = alpha_t * flash_out + beta_t * v_cumsum_t
    output = nn.functional.dropout(output, p=dropout, training=module.training)

    # Weights are never materialised on the flash path.
    return output, None
|
|
|
|
class HadamardOProj(nn.Module):
    """
    Parameter-free Walsh–Hadamard output projection with learnable affine
    rescaling: output = α ⊙ (FWHT(x)/√d) + β.

    Replaces the dense W_O ∈ R^{d×d} in multi-head attention with a fixed
    orthogonal Walsh–Hadamard Transform followed by a per-channel affine.

    Motivation (Aggarwal & Kumar, 2026, arXiv:2603.08343): a trained dense
    o_proj can develop extreme condition numbers (κ up to 10^5 observed),
    which is hostile to FP8's single per-tensor scale. The WHT is a fixed
    orthogonal matrix whose singular values are all 1, so κ = 1 permanently
    and by construction; α/β restore per-channel expressivity at 2·d
    parameters instead of d².

    Properties:
        - Condition number κ = 1 (exact, by construction)
        - Parameters: 2·d vs d² for dense
        - Forward FLOPs: O(d log d) vs O(d²)
        - Isometric: ‖FWHT(x)‖₂ = ‖x‖₂ (after 1/√d normalisation)
        - Requires d to be a power of 2

    Reference:
        Aggarwal, S. & Kumar, L. (2026). "Rethinking Attention Output
        Projection: Structured Hadamard Transforms for Efficient
        Transformers." arXiv:2603.08343.
    """

    def __init__(self, dim: int, bias: bool = True):
        super().__init__()
        # The butterfly only works for power-of-two widths.
        assert dim > 0 and (dim & (dim - 1)) == 0, (
            f"HadamardOProj requires dim to be a power of 2, got {dim}"
        )
        self.dim = dim
        # 1/sqrt(d) makes the transform orthonormal (H^T H = I).
        self.norm = dim ** -0.5

        # Per-channel affine: 2*d parameters replacing the dense d*d matrix.
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim)) if bias else None

    def _fwht(self, x: torch.Tensor) -> torch.Tensor:
        """
        Iterative Fast Walsh–Hadamard Transform over the last dimension.

        log₂(d) butterfly stages of additions/subtractions, no
        multiplications. All shapes are static once d is fixed, so the loop
        unrolls cleanly under torch.compile.
        """
        stride = 1
        while stride < self.dim:
            # Pair elements at distance `stride` within 2*stride groups.
            pairs = x.reshape(*x.shape[:-1], -1, 2 * stride)
            left, right = pairs[..., :stride], pairs[..., stride:]
            # One butterfly stage: sums then differences.
            combined = torch.cat([left + right, left - right], dim=-1)
            x = combined.reshape(*combined.shape[:-2], self.dim)
            stride *= 2
        return x

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["HadamardAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: [..., dim] — concatenated multi-head attention outputs
            analysis: HadamardAnalysis container populated when analysis
                mode is active (eval + model.enable_analysis()); else None.

        Returns:
            α ⊙ (FWHT(x) / √dim) + β of shape [..., dim]
        """
        transformed = self._fwht(x) * self.norm

        if analysis is not None:
            analysis.post_fwht = transformed.detach()
            analysis.alpha_snapshot = self.alpha.detach()

        result = transformed * self.alpha
        if self.beta is not None:
            result = result + self.beta
        return result
|
|
|
|
class REPOModule(nn.Module):
    """
    Context Re-Positioning module f_phi (Li et al., 2026, arXiv:2512.14391).

    Replaces the fixed linear integer indices ``0..L-1`` fed to RoPE with
    continuous, data-dependent positions ``z_i`` learned end-to-end.

    Architecture (Eq. 4-6 of the paper):

        # Position representation -- shared across all heads in this layer
        r_i = Swish(h_i W_g) * (h_i W_c)          r_i in R^{d_p}

        # Position assignment -- independent per head
        z_i^(h) = r_i w_z^(h)                     z_i^(h) in R (scalar)

    where ``h_i in R^d`` is the hidden state of token ``i`` entering the
    decoder layer (pre-FANLayer), and ``d_p = hidden_size // 8`` by default.

    The resulting positions ``z [B, H, S]`` are real-valued and
    unconstrained. They are used to compute per-head ``cos/sin`` embeddings
    inline from ``inv_freq``, replacing the standard integer-based
    ``position_embeddings`` for this layer.

    Design notes:
    - ``W_g`` and ``W_c`` are shared across heads (parameter efficiency).
    - ``W_z`` is a single ``[d_p, num_heads]`` matrix; each column is the
      per-head assignment vector ``w_z^(h)``. Vectorized as one matmul.
    - The raw hidden state ``h_i`` (not the FAN-augmented or normed variant)
      is used as input, matching the paper's formulation and avoiding
      circular dependency with q/k norm.
    - No bias on any projection -- consistent with the paper's Eq. 4-5.

    Reference:
        Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). "REPO: Language
        Models with Context Re-Positioning." arXiv:2512.14391.
    """

    def __init__(self, hidden_size: int, d_p: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.d_p = d_p
        self.num_heads = num_heads

        # Shared position representation (Eq. 4-5): SwiGLU-style gated pair.
        self.W_g = nn.Linear(hidden_size, d_p, bias=False)
        self.W_c = nn.Linear(hidden_size, d_p, bias=False)

        # Per-head scalar assignment (Eq. 6); column h holds w_z^(h).
        self.W_z = nn.Linear(d_p, num_heads, bias=False)

    # NOTE: the analysis annotation is a *quoted* forward reference so it is
    # not evaluated at class-creation time -- consistent with the quoted
    # "HadamardAnalysis" style used elsewhere in this file and robust to
    # definition order of the analysis containers.
    def forward(
        self,
        hidden_states: torch.Tensor,
        repo_analysis: Optional["REPOAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, S, hidden_size] -- residual stream entering
                the decoder layer, before FANLayer augmentation.
            repo_analysis: REPOAnalysis container populated when analysis
                mode is active. None during training (zero overhead).

        Returns:
            z: [B, H, S] -- continuous per-head position scalars.
               z[:, h, i] is the position assigned to token i by head h.
        """
        # r_i = Swish(h_i W_g) * (h_i W_c)   -> [B, S, d_p]
        r = F.silu(self.W_g(hidden_states)) * self.W_c(hidden_states)

        # z = r W_z: [B, S, H] -> [B, H, S] (contiguous for the later .view
        # in _apply_repo_rope).
        z = self.W_z(r).transpose(1, 2).contiguous()

        if repo_analysis is not None:
            repo_analysis.r_repr = r.detach()
            repo_analysis.positions = z.detach()

        return z
|
|
|
|
def _apply_repo_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    z: torch.Tensor,
    inv_freq: torch.Tensor,
    attention_scaling: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate Q and K with RoPE driven by continuous per-head REPO positions.

    Replaces ``apply_rotary_pos_emb(q, k, cos, sin)`` for layers where REPO
    is active. The cos/sin tables are built inline from ``z`` and
    ``inv_freq`` so the rotation stays differentiable w.r.t. ``z`` and
    therefore w.r.t. the REPOModule parameters.

    Args:
        q: [B, H, S, head_dim] query states.
        k: [B, H_kv, S, head_dim] key states (H_kv <= H under GQA).
        z: [B, H, S] -- per-head continuous positions from REPOModule.
        inv_freq: [rotary_dim/2] -- frozen RoPE frequency vector.
        attention_scaling: scaling factor from NeoLLMRotaryEmbedding.

    Returns:
        (q_embed, k_embed) with the same shapes as (q, k).

    GQA note: REPO yields one position per *query* head. Each KV head is
    given the mean of the positions of the query heads in its group
    (consecutive groups of size H // H_kv, matching repeat_kv layout), so
    every KV head gets a position representative of the heads it serves.
    """
    batch, n_heads, seq_len = z.shape
    n_kv_heads = k.shape[1]
    group = n_heads // n_kv_heads
    rot_dim = 2 * inv_freq.shape[0]

    def _tables(pos: torch.Tensor, like: torch.Tensor):
        # pos [..., S] -> angles [..., S, rot_dim/2], duplicated to rot_dim.
        angles = pos.float().unsqueeze(-1) * inv_freq
        full = torch.cat([angles, angles], dim=-1)
        cos = (full.cos() * attention_scaling).to(like.dtype)
        sin = (full.sin() * attention_scaling).to(like.dtype)
        return cos, sin

    cos_q, sin_q = _tables(z, q)

    # Collapse each GQA group of query-head positions into one KV position.
    z_kv = z.view(batch, n_kv_heads, group, seq_len).mean(dim=2)
    cos_k, sin_k = _tables(z_kv, k)

    def _rotate(states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        # Partial-rotary support: only the first rot_dim channels rotate;
        # the remainder passes through untouched.
        rot = states[..., :rot_dim]
        passthrough = states[..., rot_dim:]
        rotated = (rot * cos) + (rotate_half(rot) * sin)
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate(q, cos_q, sin_q), _rotate(k, cos_k, sin_k)
|
|
|
|
class NeoLLMAttention(nn.Module):
    """
    Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
    optional Momentum, MEA head-level composition, optional LUCID preconditioning,
    optional Affine-Scaled Attention, optional Exclusive Self Attention,
    optional Directional Routing (Taylor, 2026), and optional Context
    Re-Positioning (Li et al., 2026).

    Directional Routing inserts at position C — post-XSA, pre-reshape — where
    the output is already normalized (MEAHeadSeeDNorm) and has auto-position
    removed (XSA). The router suppresses directions of cross-domain interference
    orthogonal to the self-position already cleaned by XSA.

    Pipeline (all active simultaneously when enabled):
        FANLayer → q_proj(gate) → q_norm/k_norm → REPO/RoPE → Momentum
        → MEA(K,V) → LUCID(V) → v_ref → Affine-Scaled SDPA
        → MEAHeadSeeDNorm → XSA → Directional Routing → reshape
        → o_proj · sigmoid(gate) → dropout

    RoPE variants (controlled by config.use_repo and layer_idx):
        use_repo=False (default): standard integer RoPE via pre-computed
                                  position_embeddings — identical to prior behaviour.
        use_repo=True, layer_idx >= repo_start_layer:
                                  REPOModule f_φ predicts continuous per-head positions
                                  z [B, H, S] from hidden_states. cos/sin are built
                                  inline from z and inv_freq so the rotation is
                                  differentiable w.r.t. f_φ parameters.
        use_repo=True, layer_idx < repo_start_layer:
                                  standard integer RoPE (lower layers capture surface
                                  features that benefit less from re-positioning).

    o_proj variants (controlled by config.use_hadamard_o_proj):
        False (default): dense LinearWithMultipliers — full expressivity,
                         develops high κ during training (FP8 risk).
        True:            HadamardOProj — fixed WHT + learnable α/β,
                         κ = 1 by construction, 25% fewer attention params,
                         FP8-friendly (Aggarwal & Kumar, 2026, arXiv:2603.08343).

    References:
        Directional Routing:     Taylor (2026). arXiv:2603.14923.
        Hadamard o_proj:         Aggarwal & Kumar (2026). arXiv:2603.08343.
        Context Re-Positioning:  Li et al. (2026). arXiv:2512.14391.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.sqrt_head_dim = math.sqrt(self.head_dim)
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Feature flags for the optional attention operators; all default
        # off so a plain config reproduces vanilla behaviour.
        self.use_momentum_attention = getattr(config, "use_momentum_attention", False)
        self.momentum_gamma = float(getattr(config, "momentum_gamma", 0.0))
        self.use_mea_attention = getattr(config, "use_mea_attention", False)
        self.mea_component_key_value_heads = int(
            getattr(config, "mea_component_key_value_heads", config.num_key_value_heads)
        )
        self.mea_groupnorm_eps = float(
            getattr(config, "mea_groupnorm_eps", config.rms_norm_eps)
        )
        self.use_lucid_attention = getattr(config, "use_lucid_attention", False)
        self.lucid_attention_eps = float(
            getattr(config, "lucid_attention_eps", config.rms_norm_eps)
        )
        self.use_hadamard_o_proj = getattr(config, "use_hadamard_o_proj", False)

        # FANformer front-end widens the residual stream by fan_ratio.
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, "fan_ratio", 0.125),
        )
        fan_output_dim = config.hidden_size + int(
            config.hidden_size * getattr(config, "fan_ratio", 0.125)
        )

        # q_proj emits 2x width: chunked in forward into [queries | gate]
        # for G1-position gated attention (Qiu et al., 2025).
        self.q_proj = LinearWithMultipliers(
            fan_output_dim, config.num_attention_heads * self.head_dim * 2,
            bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=False,
        )
        # Under MEA, K/V are projected into "component" heads that are later
        # linearly recombined into the configured KV heads.
        self.num_mea_component_heads = (
            self.mea_component_key_value_heads
            if self.use_mea_attention else config.num_key_value_heads
        )
        self.k_proj = nn.Linear(
            fan_output_dim, self.num_mea_component_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            fan_output_dim, self.num_mea_component_heads * self.head_dim,
            bias=config.attention_bias,
        )

        # Output projection: dense (default) or fixed-Hadamard + affine.
        _o_in = config.num_attention_heads * self.head_dim
        if self.use_hadamard_o_proj:
            assert _o_in == config.hidden_size, (
                f"HadamardOProj requires in_dim == out_dim, "
                f"got {_o_in} vs {config.hidden_size}"
            )
            self.o_proj = HadamardOProj(config.hidden_size, bias=config.attention_bias)
        else:
            self.o_proj = LinearWithMultipliers(
                _o_in, config.hidden_size,
                bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True,
            )

        # Per-head QK normalisation (SeeDNorm) applied before RoPE.
        self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)

        if self.use_mea_attention:
            # Identity-initialised rectangular mix matrices: at init MEA is
            # a no-op selection of the first num_key_value_heads components.
            self.mea_key_mix = nn.Parameter(
                torch.eye(self.num_mea_component_heads, config.num_key_value_heads)
            )
            self.mea_value_mix = nn.Parameter(
                torch.eye(self.num_mea_component_heads, config.num_key_value_heads)
            )
            self.mea_output_norm = MEAHeadSeeDNorm(
                num_heads=config.num_attention_heads, head_dim=self.head_dim,
                num_kv_groups=self.num_key_value_groups, eps=self.mea_groupnorm_eps,
            )
        else:
            self.mea_key_mix = None
            self.mea_value_mix = None
            self.mea_output_norm = None

        self.dropout = nn.Dropout(config.dropout_rate)
        # ResFormer-style learnable mix of the first layer's FAN output
        # with this layer's (used only when first_layer_fan is passed).
        self.lambda_1 = nn.Parameter(torch.tensor(0.5))
        self.lambda_2 = nn.Parameter(torch.tensor(0.5))

        # Affine-Scaled Attention (Bae et al., 2026): alpha(X) scales the
        # softmax weights, beta is derived from a moving average of alpha.
        self.use_affine_scaled_attention = getattr(config, "use_affine_scaled_attention", False)
        self.affine_momentum = float(getattr(config, "affine_momentum", 0.9))

        # Exclusive Self Attention: removes the component of the output
        # aligned with each token's own value vector.
        self.use_xsa = getattr(config, "use_xsa", False)
        self.xsa_eps = float(getattr(config, "xsa_eps", 1e-6))

        if self.use_affine_scaled_attention:
            self.alpha_proj = nn.Linear(
                config.hidden_size, config.num_attention_heads, bias=False
            )
            # Persistent EMA buffer of per-head alpha; updated in forward
            # during training only.
            self.register_buffer(
                "alpha_ma",
                torch.zeros(1, config.num_attention_heads, 1, 1),
                persistent=True,
            )

        # Directional Routing (Taylor, 2026): K learnable directions per
        # head; a sequence-pooled router decides how strongly each is
        # suppressed from the attention output.
        self.use_directional_routing = getattr(config, "use_directional_routing", False)
        self.directional_routing_k = int(getattr(config, "directional_routing_k", 4))
        self.directional_routing_temp = float(getattr(config, "directional_routing_temp", 5.0))

        if self.use_directional_routing:
            H = config.num_attention_heads
            K = self.directional_routing_k
            D = self.head_dim
            R = config.hidden_size

            # [H, K, D] direction bank; normalised at use time.
            self.direction_vecs = nn.Parameter(
                torch.randn(H, K, D)
            )

            # MLP router over the pooled residual stream -> H*K logits.
            self.direction_router = nn.Sequential(
                SeeDNorm(R, eps=config.rms_norm_eps),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, R, bias=True),
                nn.GELU(),
                nn.Linear(R, H * K, bias=True),
            )
        else:
            self.direction_vecs = None
            self.direction_router = None

        # Context Re-Positioning: active only from repo_start_layer upward
        # (default: the top two thirds of the stack).
        self.use_repo = (
            getattr(config, "use_repo", False)
            and layer_idx >= getattr(config, "repo_start_layer", config.num_hidden_layers // 3)
        )
        if self.use_repo:
            _d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
            self.repo_module = REPOModule(
                hidden_size=config.hidden_size,
                d_p=_d_p,
                num_heads=config.num_attention_heads,
            )
        else:
            self.repo_module = None

    def _apply_momentum_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        attn_analysis: Optional[AttentionAnalysis] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add gamma-scaled causal first differences to Q and K.

        No-op when momentum attention is disabled or gamma == 0.
        """
        if not self.use_momentum_attention or self.momentum_gamma == 0.0:
            return q, k
        dq = causal_first_difference(q)
        dk = causal_first_difference(k)
        q_new = q + self.momentum_gamma * dq
        k_new = k + self.momentum_gamma * dk
        if attn_analysis is not None:
            attn_analysis.q_momentum_delta = dq.detach()
            attn_analysis.k_momentum_delta = dk.detach()
            attn_analysis.q_post_momentum = q_new.detach()
            attn_analysis.k_post_momentum = k_new.detach()
        return q_new, k_new

    def _apply_mea_head_mixing(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_analysis: Optional[AttentionAnalysis] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Linearly recombine MEA component KV heads into the final KV heads.

        Uses the learned mea_key_mix / mea_value_mix matrices; no-op when
        MEA is disabled.
        """
        if not self.use_mea_attention:
            return k, v
        k_mixed = head_linear_compose(k, self.mea_key_mix).contiguous()
        v_mixed = head_linear_compose(v, self.mea_value_mix).contiguous()
        if attn_analysis is not None:
            attn_analysis.mea_key_mix_matrix = self.mea_key_mix.detach()
            attn_analysis.mea_value_mix_matrix = self.mea_value_mix.detach()
            attn_analysis.k_post_mea = k_mixed.detach()
            attn_analysis.v_post_mea = v_mixed.detach()
        return k_mixed, v_mixed

    def _apply_lucid_preconditioner(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        attn_analysis: Optional[AttentionAnalysis] = None,
    ) -> torch.Tensor:
        """Precondition V by solving a causal lower-triangular system.

        Builds a lower-triangular matrix from exp of scaled key-similarity
        logits (masked to valid keys when a mask is available) and solves
        prec @ v' = v for v'. Returns v unchanged (made contiguous) when
        LUCID is disabled.
        """
        if not self.use_lucid_attention:
            return v.contiguous()
        key_rn = rms_key_unit_norm(k, eps=self.lucid_attention_eps)
        logits = torch.matmul(key_rn, key_rn.transpose(-1, -2)) * self.scaling - self.sqrt_head_dim
        prec = torch.tril(torch.exp(logits))
        kv = infer_key_validity(attention_mask, k.shape[-2], k.shape[1])
        if kv is not None:
            prec = prec * (kv.unsqueeze(-1) & kv.unsqueeze(-2)).to(prec.dtype)
        # Force the diagonal to exactly 1. The subsequent solve passes
        # unitriangular=True, which assumes (and never reads) a unit
        # diagonal, so this keeps prec consistent with the solve.
        eye = torch.eye(prec.shape[-1], device=prec.device, dtype=prec.dtype).view(
            1, 1, prec.shape[-1], prec.shape[-1]
        )
        prec = prec + eye * (1.0 - prec.diagonal(dim1=-2, dim2=-1).unsqueeze(-1))
        # Solve in float32 for numerical stability, then cast back.
        result = torch.linalg.solve_triangular(
            prec, v.float(), upper=False, unitriangular=True
        ).to(v.dtype).contiguous()
        if attn_analysis is not None:
            attn_analysis.lucid_preconditioner = prec.detach()
            attn_analysis.v_post_lucid = result.detach()
        return result

    def _apply_directional_routing(
        self,
        attn_out: torch.Tensor,
        hidden_states: torch.Tensor,
        attn_analysis: Optional[AttentionAnalysis] = None,
    ) -> torch.Tensor:
        """
        Directional suppression at position C (post-XSA, pre-reshape).

        Args:
            attn_out: [B, S, H, d_head] — output after XSA and SeeDNorm.
            hidden_states: [B, S, hidden_size] — pre-FAN residual stream,
                used as router input (same as paper's x_i).
            attn_analysis: Optional[AttentionAnalysis] — deposits DR internals.
        Returns:
            [B, S, H, d_head] with selected directional components suppressed.
        """
        # One routing decision per sequence: mean-pool over tokens, then
        # the router MLP emits H*K logits squashed to (0, 1) by a
        # temperature-sharpened sigmoid.
        pooled = hidden_states.mean(dim=1)
        logits = self.direction_router(pooled)
        r = torch.sigmoid(self.directional_routing_temp * logits)
        r = r.view(hidden_states.shape[0], self.config.num_attention_heads,
                   self.directional_routing_k)
        # Broadcast over the sequence axis: [B, 1, H, K].
        r = r.unsqueeze(1)

        # Unit-normalise the direction bank so the projection below is a
        # proper (scaled) orthogonal-component removal.
        d = F.normalize(self.direction_vecs, dim=-1)

        # Per-token, per-head scalar projections onto each direction.
        proj = torch.einsum("bshd,hkd->bshk", attn_out, d)

        # Router-weighted projections decide how much of each direction
        # to remove.
        weighted = r * proj

        # Reconstruct the suppressed component in d_head space and subtract.
        suppression = torch.einsum("bshk,hkd->bshd", weighted, d)
        result = attn_out - suppression

        if attn_analysis is not None:
            attn_analysis.direction_vecs_normalized = d.detach()
            attn_analysis.dr_router_logits = logits.detach()
            attn_analysis.dr_routing_weights = r.squeeze(1).detach()
            attn_analysis.dr_projection = proj.detach()
            attn_analysis.dr_suppression = suppression.detach()
            attn_analysis.attn_output_post_routing = result.detach()

        return result

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        attn_analysis: Optional[AttentionAnalysis] = None,
        repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the full attention pipeline described in the class docstring.

        Returns (attn_output, attn_weights_or_None, this_layer_fan_output).
        When use_repo is True, repo_rope_args must supply (inv_freq,
        attention_scaling) from the rotary embedding.
        """
        input_shape = hidden_states.shape[:-1]

        # FANformer augmentation, optionally mixed (ResFormer-style) with
        # the first layer's FAN output via learnable lambdas.
        fan_a = attn_analysis.fan if attn_analysis is not None else None
        h_fan = self.fan_layer(hidden_states, analysis=fan_a)
        if first_layer_fan is not None:
            h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
        current_layer_fan = h_fan.clone()

        query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
        kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)

        # q_proj produces 2*head_dim per head: first half queries, second
        # half the elementwise output gate (G1 gated attention).
        q_raw, gate = torch.chunk(
            self.q_proj(h_fan).view(*input_shape, self.config.num_attention_heads, self.head_dim * 2),
            2, dim=-1,
        )
        gate = gate.reshape(*input_shape, -1)

        if attn_analysis is not None:
            attn_analysis.q_raw = q_raw.detach()
            attn_analysis.gate_raw = gate.detach()

        # Per-head QK norm, then [B, S, H, d] -> [B, H, S, d].
        q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
        k = self.k_norm(self.k_proj(h_fan).view(kv_shape)).transpose(1, 2)
        v = self.v_proj(h_fan).view(kv_shape).transpose(1, 2)

        if attn_analysis is not None:
            attn_analysis.q_post_norm = q.detach()
            attn_analysis.k_post_norm = k.detach()
            attn_analysis.v_raw = v.detach()

        # Positional rotation: continuous REPO positions or standard RoPE.
        cos, sin = position_embeddings
        if self.use_repo:
            # repo_rope_args is required on this path; it carries the
            # frozen inv_freq and the rotary attention scaling factor.
            repo_a = attn_analysis.repo if attn_analysis is not None else None
            z = self.repo_module(hidden_states, repo_analysis=repo_a)
            inv_freq, attn_scaling = repo_rope_args
            q, k = _apply_repo_rope(q, k, z, inv_freq, attn_scaling)
        else:
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if attn_analysis is not None:
            attn_analysis.q_post_rope = q.detach()
            attn_analysis.k_post_rope = k.detach()

        # Optional operator stack on Q/K/V (each is a no-op when disabled).
        q, k = self._apply_momentum_attention(q, k, attn_analysis=attn_analysis)
        k, v = self._apply_mea_head_mixing(k, v, attn_analysis=attn_analysis)
        v = self._apply_lucid_preconditioner(k, v, attention_mask, attn_analysis=attn_analysis)

        # Keep the (post-LUCID) values around for XSA's self-component
        # removal after the attention call.
        v_ref = v if self.use_xsa else None

        # Affine-Scaled Attention inputs: alpha in [0,1] per head/query
        # from the raw residual stream; beta derived from the alpha EMA.
        alpha = None
        beta = None
        use_affine = self.use_affine_scaled_attention
        if use_affine:
            alpha = linear_clipping(self.alpha_proj(hidden_states))
            alpha = alpha.permute(0, 2, 1).unsqueeze(-1)
            N = k.shape[-2]
            beta = (self.alpha_ma.to(alpha.dtype) - alpha) / max(N, 1)
            if self.training:
                # EMA update of per-head alpha; training-only, no grad.
                with torch.no_grad():
                    batch_mean = alpha.mean(dim=(0, 2), keepdim=True)
                    self.alpha_ma.copy_(
                        self.affine_momentum * self.alpha_ma
                        + (1.0 - self.affine_momentum) * batch_mean
                    )
            if attn_analysis is not None:
                attn_analysis.alpha_per_head = alpha.detach()
                attn_analysis.beta_per_head = beta.detach()
                attn_analysis.alpha_moving_avg = self.alpha_ma.detach()

        # Attention dispatch: affine-scaled (eager or flash) vs standard.
        if use_affine:
            if self.config._attn_implementation == "eager":
                attn_out, attn_weights = affine_scaled_eager_attention_forward(
                    self, q, k, v, attention_mask,
                    scaling=self.scaling, alpha=alpha, beta=beta,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    attn_analysis=attn_analysis,
                    **kwargs,
                )
            else:
                # Exact flash-path expansion of [alpha*softmax + beta] V;
                # per-weight tensors are unavailable on this path.
                attn_out, attn_weights = affine_scaled_flash_attention_forward(
                    self, q, k, v, attention_mask,
                    scaling=self.scaling, alpha=alpha, beta=beta,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    attn_analysis=attn_analysis,
                    **kwargs,
                )
        else:
            if self.config._attn_implementation == "eager":
                attn_fn = eager_attention_forward
            else:
                attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attn_out, attn_weights = attn_fn(
                self, q, k, v, attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling, **kwargs,
            )
            if attn_analysis is not None:
                attn_analysis.attn_weights = (
                    attn_weights.detach() if attn_weights is not None else None
                )

        if attn_analysis is not None:
            attn_analysis.attn_output_raw = attn_out.detach()

        # Back to [B, S, H, d_head] for the per-head post-processing stack.
        attn_out = attn_out.reshape(*input_shape, -1, self.head_dim)
        if self.use_mea_attention:
            attn_out = self.mea_output_norm(attn_out)
            if attn_analysis is not None:
                attn_analysis.attn_output_post_mea_norm = attn_out.detach()

        # XSA: subtract, per head, the projection of the output onto the
        # token's own (group-expanded) value vector.
        if self.use_xsa and v_ref is not None:
            v_ref_expanded = repeat_kv(v_ref, self.num_key_value_groups)
            v_ref_t = v_ref_expanded.transpose(1, 2)
            v_ref_t = v_ref_t.to(attn_out.dtype)
            proj = (attn_out * v_ref_t).sum(dim=-1, keepdim=True)
            norm_sq = (v_ref_t * v_ref_t).sum(dim=-1, keepdim=True).clamp(min=self.xsa_eps)
            xsa_comp = (proj / norm_sq) * v_ref_t
            attn_out = attn_out - xsa_comp
            if attn_analysis is not None:
                attn_analysis.xsa_self_position_component = xsa_comp.detach()
                attn_analysis.attn_output_post_xsa = attn_out.detach()

        # Directional Routing at position C (post-XSA, pre-reshape).
        if self.use_directional_routing:
            attn_out = self._apply_directional_routing(
                attn_out, hidden_states, attn_analysis=attn_analysis
            )

        # G1 gate + output projection. The analysis branch duplicates the
        # math only to expose intermediates; both branches compute
        # o_proj(attn_out * sigmoid(gate)).
        attn_out_flat = attn_out.reshape(*input_shape, -1).contiguous()
        if attn_analysis is not None:
            attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
            gate_sig = torch.sigmoid(gate)
            attn_analysis.gate_sigmoid = gate_sig.detach()
            gated = attn_out_flat * gate_sig
            if self.use_hadamard_o_proj:
                attn_out_gated = self.o_proj(gated, analysis=attn_analysis.hadamard)
            else:
                attn_out_gated = self.o_proj(gated)
        else:
            attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))

        attn_out_gated = self.dropout(attn_out_gated)

        if attn_analysis is not None:
            attn_analysis.attn_output_final = attn_out_gated.detach()

        return attn_out_gated, attn_weights, current_layer_fan
class PolyNorm(nn.Module):
    """
    Polynomial activation: a learnable mix of the RMS-normalised first,
    second and third powers of the input plus a scalar bias:

        output = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b

    With ``exclusive`` enabled, the x^2 and x^3 branches have their
    component along the linear branch x^1 partially removed (learnable
    strength alpha = sigmoid(exclusive_logits) in (0, 1)) and are then
    re-normalised, keeping the three branches complementary.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        proj_eps: float = 1e-6,
        exclusive_init: float = 0.5,
        exclusive: bool = True,
    ):
        super().__init__()
        # Equal initial weight on each of the three polynomial branches.
        self.weight = nn.Parameter(torch.ones(3) / 3)
        self.bias = nn.Parameter(torch.zeros(1))
        self.eps = eps
        self.exclusive = exclusive

        if exclusive:
            self.proj_eps = proj_eps
            # Clamp into (0, 1) so torch.logit stays finite; the learnable
            # logits parameterize alpha2/alpha3 = sigmoid(exclusive_logits).
            exclusive_init = float(min(max(exclusive_init, 1e-4), 1.0 - 1e-4))
            init = torch.full((2,), exclusive_init, dtype=torch.float32)
            self.exclusive_logits = nn.Parameter(torch.logit(init))

    def _norm(self, x):
        # RMS normalisation over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def _exclusive(self, branch, ref, alpha, x1_f, ref_norm_sq):
        """
        Remove from ``branch`` the component aligned with ``ref`` (the
        linear branch x1), weighted by alpha in (0, 1), then re-normalise:

            branch := _norm(branch - alpha * proj_{ref}(branch))

        The denominator ``ref_norm_sq`` is passed in precomputed to avoid
        duplicating the work when this is called twice per forward (once
        for x2, once for x3).
        """
        branch_f = branch.float()
        dot = (branch_f * x1_f).sum(dim=-1, keepdim=True)
        proj_coeff = (dot / ref_norm_sq).to(branch.dtype)
        out = branch - alpha.to(branch.dtype) * proj_coeff * ref
        return self._norm(out)

    # NOTE: the analysis annotation is quoted so the forward reference is
    # not evaluated at class-creation time (matches the quoted-annotation
    # style used elsewhere in this file).
    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["PolyNormAnalysis"] = None,
    ) -> torch.Tensor:
        x_sq = x.pow(2)
        x_cu = x * x_sq

        # RMS-normalise each power branch independently.
        x1 = x * x_sq.mean(-1, keepdim=True).add(self.eps).rsqrt()
        x2 = x_sq * (x_sq * x_sq).mean(-1, keepdim=True).add(self.eps).rsqrt()
        x3 = x_cu * (x_cu * x_cu).mean(-1, keepdim=True).add(self.eps).rsqrt()

        if analysis is not None:
            analysis.x1 = x1.detach()
            analysis.x2_pre_exclusive = x2.detach()
            analysis.x3_pre_exclusive = x3.detach()

        if self.exclusive:
            alpha2, alpha3 = torch.sigmoid(self.exclusive_logits).unbind()

            if analysis is not None:
                analysis.alpha2 = alpha2.detach()
                analysis.alpha3 = alpha3.detach()
                analysis.weights = self.weight.detach()
                analysis.bias = self.bias.detach()

            # Shared projection denominator, computed once in float32.
            x1_f = x1.float()
            ref_norm_sq = x1_f.pow(2).sum(-1, keepdim=True).clamp_min(self.proj_eps)

            x2 = self._exclusive(x2, x1, alpha2, x1_f, ref_norm_sq)
            x3 = self._exclusive(x3, x1, alpha3, x1_f, ref_norm_sq)

            if analysis is not None:
                analysis.x2_post_exclusive = x2.detach()
                analysis.x3_post_exclusive = x3.detach()
        else:
            if analysis is not None:
                analysis.weights = self.weight.detach()
                analysis.bias = self.bias.detach()

        output = (
            self.weight[0] * x3
            + self.weight[1] * x2
            + self.weight[2] * x1
            + self.bias
        )

        if analysis is not None:
            analysis.output = output.detach()

        return output
|
|
|
|
def compute_versatile_aux_loss(
    aux_stats_list: list,
    n_experts: int,
    weight: float,
) -> torch.Tensor:
    """
    Load-balancing auxiliary loss for VersatileFFN width-path expert routing.

    Same formula as JTok-M: L_aux = lambda * n_e * sum_i p_i * f_i, where
    p_i = p_sum_i / N_tokens is the mean router probability for expert i
    and f_i its dispatch count, averaged over every decoder layer that
    reported stats.

    Args:
        aux_stats_list: per-layer tuples (p_sum [n_e], f_sum [n_e], N_tokens).
        n_experts: total number of virtual experts (n_e).
        weight: lambda coefficient.

    Returns:
        Scalar loss tensor; a fresh 0.0 tensor when no layer reported stats.
    """
    if not aux_stats_list:
        return torch.tensor(0.0)
    layer_losses = [
        weight * n_experts * ((p_sum / N) * f_sum).sum()
        for p_sum, f_sum, N in aux_stats_list
    ]
    return sum(layer_losses) / len(layer_losses)
|
|
|
|
| class VersatileFFN(nn.Module): |
| """ |
| VersatileFFN: dual-process feed-forward network with parameter reuse. |
| |
| Drop-in replacement for NeoLLMMLP. Shares the same weight matrices but |
| reuses them across two complementary paths (Nie et al., 2026): |
| |
| Width-Versatile path (virtual MoE, Eq. 7β8): |
| Derives N virtual sub-experts by slicing non-overlapping contiguous |
| segments of the intermediate dimension. A Top-K router selects |
| ``active_experts`` experts per token. No additional parameters beyond |
| the router ``expert_gate [hidden, N_experts]``. |
| |
| Depth-Versatile path (recursive, Eq. 9β11): |
| Applies the full shared MLP (inner SeeDNorm + FANLayer + gate/up/down) |
| recursively up to ``max_depth`` times. A Gumbel-Softmax loop predictor |
| (``depth_predictor [hidden, max_depth]``) decides per-token depth. |
| During training: always L_max iterations with soft STE weighting. |
| During inference: early exit at argmax(p) to save FLOPs. |
| |
| Difficulty-aware fusion (Eq. 12β13): |
| Ξ» = (L_max β E[L]) / L_max β [0, 1) |
| output = Ξ» Β· Y_width + (1 β Ξ») Β· Y_depth |
| Easy tokens (low E[L] β Ξ» β 1) favour the fast width path. |
| Hard tokens (high E[L] β Ξ» β 0) favour the deep recursive path. |
| |
| NeoLLM-specific adaptations vs. the OLMo reference implementation: |
| - FANLayer runs once per forward (shared between both paths). |
| - Expert slicing uses contiguous segments (no SwiGLU half-shift) |
| because gate and up are separate projections here. |
| - PolyNorm (shared parameters) is used as the activation in both paths. |
| - Multipliers from LinearWithMultipliers are applied per-slice in the |
| width path to correctly replicate the full MLP forward. |
| - Gumbel temperature stored as a persistent float32 buffer so it |
| survives checkpoint save/load without external tracking. |
| - Width path load-balancing returns (output, aux_stats) for integration |
| with the existing NeoLLMForCausalLM aux-loss accumulation pattern. |
| |
| Width dispatch (CUDAGraph-compatible sparse routing): |
| El dispatch original del paper (torch.where + index_add_) es sparse y |
| fiel al paper pero produce shapes dependientes de datos β incompatible |
| con CUDAGraphs. La implementaciΓ³n usa argsort como dispatcher estΓ‘tico: |
| |
| flat_expert [NΒ·K] β argsort β perm [NΒ·K] (shape siempre igual) |
| sorted_tok [NΒ·K] = flat_tok[perm] (Γndices de token originales) |
| grouped_tok [E, C] = sorted_tok.view(E, C) (C = NΒ·K // E, constante) |
| |
| Propiedades clave: |
| Β· argsort: output shape = input shape, siempre [NΒ·K]. CUDAGraph β |
| Β· C = N_tokΒ·K // E es un entero Python conocido en compile-time. |
| Con el aux loss manteniendo balance, cada experto recibe β C slots. |
| Β· scatter_add_ con index [C, D] de shape estΓ‘tico: CUDAGraph β |
| (los VALORES del index cambian por batch, no el SHAPE). |
| Β· FLOPs idΓ©nticos al original: cada experto procesa [C, D] = [NΒ·K/E, D] |
| tokens, no todos los N tokens. Con K=2, E=4: C = N/2 por experto. |
| |
| Conditional Parallelism (inferencia, Algorithm 2): |
| Β· Los tokens con Ξ»=0 (argmax β max_depth) igualmente participan en el |
| grouped buffer y su expert forward se computa (shapes estΓ‘ticos). |
| Β· Su contribuciΓ³n es cancelada por Ξ»=0 en la fusiΓ³n: |
| output = x_depthΒ·(1βΞ») + x_moeΒ·Ξ» β x_depth si Ξ»=0 |
| Β· Esto pierde el saving de FLOPs de los Ξ»=0 tokens, pero la correctitud |
| matemΓ‘tica es exacta. Tradeoff aceptable vs CUDAGraph-incompatibilidad. |
| |
| Discrete Early-Exit (inferencia, Algorithm 2): |
| Β· Sustituido por always-max_depth + torch.gather con loop_choice. |
| Para max_depth=2 el overhead es β€ 1 iteraciΓ³n extra por token. |
| |
| Reference: |
| Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in |
| LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531. |
| """ |
|
|
| def __init__(self, config: NeoLLMConfig): |
| super().__init__() |
|
|
| self.total_experts = getattr(config, "versatile_total_experts", 8) |
| self.active_experts = getattr(config, "versatile_active_experts", 2) |
| self.max_depth = getattr(config, "versatile_max_depth", 4) |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
|
|
| |
| fan_ratio = getattr(config, "fan_ratio_ffn", 0.0625) |
| fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio) |
|
|
| self.fan_layer = FANLayer(hidden_size=config.hidden_size, fan_ratio=fan_ratio) |
| self.gate_proj = LinearWithMultipliers( |
| fan_dim, config.intermediate_size, |
| bias=False, use_row_multiplier=True, use_column_multiplier=False, |
| ) |
| self.up_proj = nn.Linear(fan_dim, config.intermediate_size, bias=False) |
| self.down_proj = LinearWithMultipliers( |
| config.intermediate_size, config.hidden_size, |
| bias=False, use_row_multiplier=True, use_column_multiplier=True, |
| ) |
| self.act_fn = PolyNorm(exclusive=False) |
| self.dropout = nn.Dropout(config.dropout_rate) |
|
|
| |
| |
| |
| self.ff_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
| |
| seg = config.intermediate_size // self.total_experts |
| self.expert_segment = seg |
| self.expert_gate = nn.Linear(config.hidden_size, self.total_experts, bias=False) |
|
|
| |
| |
| idx_list = [ |
| torch.arange(i * seg, (i + 1) * seg) |
| for i in range(self.total_experts) |
| ] |
| self.register_buffer("expert_idx", torch.stack(idx_list), persistent=False) |
|
|
| |
| self.depth_predictor = nn.Linear(config.hidden_size, self.max_depth, bias=False) |
|
|
| |
| |
| temp_start = getattr(config, "versatile_gumbel_temp_start", 5.0) |
| self.register_buffer( |
| "gumbel_temp", |
| torch.tensor(temp_start, dtype=torch.float32), |
| persistent=True, |
| ) |
| self._gumbel_temp_end = getattr(config, "versatile_gumbel_temp_end", 0.1) |
| self._gumbel_temp_decay = getattr(config, "versatile_gumbel_temp_decay", 0.99984) |
|
|
| |
|
|
| def update_gumbel_temperature(self) -> None: |
| """Decay Gumbel temperature by one step. Call once per training step.""" |
| new_t = max( |
| self.gumbel_temp.item() * self._gumbel_temp_decay, |
| self._gumbel_temp_end, |
| ) |
| self.gumbel_temp.fill_(new_t) |
|
|
| |
|
|
| def _expert_forward( |
| self, |
| x_fan: torch.Tensor, |
| x_in: torch.Tensor, |
| idx: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Virtual expert forward for a single expert (Eq. 6β8 of paper). |
| |
| Slices gate_proj, up_proj, down_proj weights to the expert segment. |
| Multipliers from LinearWithMultipliers are applied per-slice to keep |
| the computation exactly equivalent to the full MLP forward. |
| |
| Returns x_in + expert_out (residual included, matching Eq. 8). |
| """ |
| |
| g_w = self.gate_proj.linear.weight[idx] |
| g_mul = self.gate_proj.row_multiplier.multiplier[idx] |
| gate = F.linear(x_fan, g_w) * g_mul |
|
|
| |
| u_w = self.up_proj.weight[idx] |
| up = F.linear(x_fan, u_w) |
|
|
| |
| act = self.act_fn(gate) * up |
| act = self.dropout(act) |
|
|
| |
| col_mul = self.down_proj.column_multiplier.multiplier[idx] |
| d_w = self.down_proj.linear.weight[:, idx] |
| row_mul = self.down_proj.row_multiplier.multiplier |
| out = F.linear(act * col_mul, d_w) * row_mul |
|
|
| return x_in + out |
|
|
| def _full_forward_step(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| One recursive MLP application for the depth path (Eq. 9). |
| |
| Includes inner SeeDNorm + FANLayer so each iteration starts from a |
| normalized, periodicity-augmented view of the current hidden state. |
| Residual connection applied inside to match the paper's formulation. |
| """ |
| og = x |
| x = self.ff_norm(x) |
| x_f = self.fan_layer(x) |
| gate = self.gate_proj(x_f) |
| up = self.up_proj(x_f) |
| act = self.act_fn(gate) * up |
| act = self.dropout(act) |
| return og + self.down_proj(act) |
|
|
| |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| analysis: Optional[MLPAnalysis] = None, |
| ) -> Tuple[torch.Tensor, Optional[tuple]]: |
| """ |
| Args: |
| x: [B, S, hidden_size] β pre-normalized input |
| (SeeDNorm + LNS already applied by NeoLLMDecoderLayer). |
| analysis: MLPAnalysis with .versatile pre-allocated when analysis |
| mode is active. None during training (zero overhead). |
| |
| Returns: |
| (output [B, S, hidden_size], aux_stats) |
| aux_stats = (p_sum [N_experts], f_sum [N_experts], N_tokens) |
| for load-balancing loss, or None in inference if width path skipped. |
| """ |
| B, S, D = x.shape |
|
|
| |
| depth_logits = self.depth_predictor(x) |
|
|
| |
| x_fan = self.fan_layer(x) |
|
|
| |
| if self.training: |
| |
| |
| depth_probs = F.gumbel_softmax( |
| depth_logits, |
| tau=float(self.gumbel_temp), |
| hard=True, |
| dim=-1, |
| ) |
|
|
| |
| depth_outputs = [] |
| current_x = x |
| for _ in range(self.max_depth): |
| current_x = self._full_forward_step(current_x) |
| depth_outputs.append(current_x) |
|
|
| |
| depth_stack = torch.stack(depth_outputs, dim=-1) |
| x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| K = self.active_experts |
| E = self.total_experts |
| N_tok = B * S |
| C = (N_tok * K) // E |
|
|
| routing_logits = self.expert_gate(x) |
| topk_w, topk_i = torch.topk(routing_logits, k=K, dim=-1) |
| topk_w = torch.softmax(topk_w, dim=-1) |
|
|
| x_flat = x.reshape(-1, D) |
| x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1]) |
|
|
| |
| flat_expert = topk_i.reshape(-1) |
| flat_tok = ( |
| torch.arange(N_tok, device=x.device, dtype=torch.long) |
| .unsqueeze(1).expand(N_tok, K).reshape(-1) |
| ) |
| flat_w = topk_w.reshape(-1) |
|
|
| |
| perm = torch.argsort(flat_expert, stable=True) |
| sorted_tok = flat_tok[perm] |
| sorted_w = flat_w[perm] |
|
|
| |
| grouped_tok = sorted_tok.view(E, C) |
| grouped_w = sorted_w.view(E, C) |
|
|
| |
| flat_idx = grouped_tok.reshape(-1) |
| fan_dim = x_fan_flat.shape[-1] |
| x_grouped = x_flat[flat_idx].view(E, C, D) |
| xf_grouped = x_fan_flat[flat_idx].view(E, C, fan_dim) |
|
|
| |
| |
| x_moe_flat = torch.zeros(N_tok, D, device=x.device, dtype=x.dtype) |
| for eid in range(E): |
| |
| out_e = self._expert_forward( |
| xf_grouped[eid], x_grouped[eid], self.expert_idx[eid] |
| ) |
| w_e = grouped_w[eid].unsqueeze(-1) |
| tok_idx_e = grouped_tok[eid].unsqueeze(1).expand(C, D) |
| |
| |
| |
| x_moe_flat.scatter_add_( |
| 0, tok_idx_e, (out_e * w_e).to(x_moe_flat.dtype) |
| ) |
|
|
| x_moe = x_moe_flat.reshape(B, S, D) |
|
|
| |
| |
| |
| r_probs_flat = torch.softmax( |
| routing_logits.reshape(-1, E), dim=-1 |
| ) |
| p_sum = r_probs_flat.sum(dim=0) |
| f_sum = ( |
| F.one_hot(flat_expert.long(), E).float().sum(dim=0) |
| / float(N_tok * K) |
| ) |
| aux_stats = (p_sum, f_sum, N_tok) |
|
|
| |
| loop_idx = torch.arange( |
| 1, self.max_depth + 1, device=x.device, dtype=depth_probs.dtype |
| ) |
| expected_L = (depth_probs * loop_idx).sum(dim=-1) |
| moe_weight = (self.max_depth - expected_L) / self.max_depth |
| output = ( |
| x_depth * (1.0 - moe_weight.unsqueeze(-1)) |
| + x_moe * moe_weight.unsqueeze(-1) |
| ) |
| loop_choice = None |
|
|
| |
| else: |
| loop_choice = depth_logits.argmax(dim=-1) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| depth_outputs = [] |
| current_x = x |
| for _ in range(self.max_depth): |
| current_x = self._full_forward_step(current_x) |
| depth_outputs.append(current_x) |
|
|
| depth_stack = torch.stack(depth_outputs, dim=-1) |
| gather_idx = ( |
| loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1) |
| ) |
| x_depth = depth_stack.gather(dim=-1, index=gather_idx).squeeze(-1) |
|
|
| |
| expected_L = (loop_choice + 1).float() |
| moe_weight = (self.max_depth - expected_L) / self.max_depth |
|
|
| aux_stats = None |
| depth_probs = None |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| N_tok_inf = B * S |
| K_inf = self.active_experts |
| E_inf = self.total_experts |
| C_inf = (N_tok_inf * K_inf) // E_inf |
|
|
| x_flat = x.reshape(-1, D) |
| x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1]) |
|
|
| routing_logits = self.expert_gate(x_flat) |
| tw, ti = torch.topk(routing_logits, k=K_inf, dim=-1) |
| tw = torch.softmax(tw, dim=-1) |
|
|
| flat_expert_i = ti.reshape(-1) |
| flat_tok_i = ( |
| torch.arange(N_tok_inf, device=x.device, dtype=torch.long) |
| .unsqueeze(1).expand(N_tok_inf, K_inf).reshape(-1) |
| ) |
| flat_w_i = tw.reshape(-1) |
|
|
| perm_i = torch.argsort(flat_expert_i, stable=True) |
| sorted_tok_i = flat_tok_i[perm_i] |
| sorted_w_i = flat_w_i[perm_i] |
|
|
| grouped_tok_i = sorted_tok_i.view(E_inf, C_inf) |
| grouped_w_i = sorted_w_i.view(E_inf, C_inf) |
|
|
| flat_idx_i = grouped_tok_i.reshape(-1) |
| fan_dim_i = x_fan_flat.shape[-1] |
| x_grouped_i = x_flat[flat_idx_i].view(E_inf, C_inf, D) |
| xf_grouped_i = x_fan_flat[flat_idx_i].view(E_inf, C_inf, fan_dim_i) |
|
|
| x_moe_flat_i = torch.zeros(N_tok_inf, D, device=x.device, dtype=x.dtype) |
| for eid in range(E_inf): |
| out_e_i = self._expert_forward( |
| xf_grouped_i[eid], x_grouped_i[eid], self.expert_idx[eid] |
| ) |
| w_e_i = grouped_w_i[eid].unsqueeze(-1) |
| tok_idx_e_i = grouped_tok_i[eid].unsqueeze(1).expand(C_inf, D) |
| x_moe_flat_i.scatter_add_( |
| 0, tok_idx_e_i, (out_e_i * w_e_i).to(x_moe_flat_i.dtype) |
| ) |
|
|
| x_moe = x_moe_flat_i.reshape(B, S, D) |
|
|
| output = ( |
| x_depth * (1.0 - moe_weight.unsqueeze(-1)) |
| + x_moe * moe_weight.unsqueeze(-1) |
| ) |
|
|
| |
| if analysis is not None: |
| if analysis.versatile is not None: |
| va = analysis.versatile |
| va.depth_probs = depth_probs.detach() if depth_probs is not None else None |
| va.expected_loops = expected_L.detach() |
| va.moe_weight = moe_weight.detach() |
| va.loop_choice = loop_choice.detach() if loop_choice is not None else None |
| va.x_depth = x_depth.detach() |
| va.x_width = x_moe.detach() |
| analysis.output = output.detach() |
|
|
| return output, aux_stats |
|
|
|
|
class NeoLLMMLP(nn.Module):
    """Gated MLP with FANformer input augmentation and Learnable Multipliers.

    Pipeline: FANLayer → (gate_proj, up_proj) → PolyNorm(gate) * up →
    dropout → down_proj. Optional MLPAnalysis hooks capture intermediates
    for interpretability; they are no-ops when ``analysis`` is None.
    """

    def __init__(self, config):
        super().__init__()
        # Compute the FAN ratio once (it was previously read from config
        # twice) so the FANLayer and fan_dim can never disagree.
        fan_ratio = getattr(config, "fan_ratio_ffn", 0.0625)
        fan_dim = config.hidden_size + int(config.hidden_size * fan_ratio)

        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=fan_ratio,
        )
        self.gate_proj = LinearWithMultipliers(
            fan_dim, config.intermediate_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=False,
        )
        self.up_proj = nn.Linear(fan_dim, config.intermediate_size, bias=False)
        self.down_proj = LinearWithMultipliers(
            config.intermediate_size, config.hidden_size,
            bias=False, use_row_multiplier=True, use_column_multiplier=True,
        )
        self.act_fn = PolyNorm(
            exclusive_init=0.00,
            exclusive=getattr(config, "polynorm_exclusive", True),
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional[MLPAnalysis] = None,
    ) -> torch.Tensor:
        """Apply the gated MLP to x [B, S, hidden_size] → same shape."""
        fan_a = analysis.fan if analysis is not None else None
        x_fan = self.fan_layer(x, analysis=fan_a)

        gate_out = self.gate_proj(x_fan)
        up_out = self.up_proj(x_fan)

        if analysis is not None:
            analysis.gate_proj_output = gate_out.detach()
            analysis.up_proj_output = up_out.detach()

        poly_a = analysis.polynorm if analysis is not None else None
        act_out = self.act_fn(gate_out, analysis=poly_a)
        act_x_up = act_out * up_out

        if analysis is not None:
            analysis.act_times_up = act_x_up.detach()

        result = self.down_proj(self.dropout(act_x_up))

        if analysis is not None:
            analysis.output = result.detach()

        return result
|
|
|
|
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder layer with standard residual connections, optional JTok-M injection.

    Flow (JTok-M active):
      1. SeeDNorm → LNS(1/√ℓ) → Attention → residual + GPAS
      2. [capture h̃ = hidden after attention for JTok-M router]
      3. SeeDNorm → LNS(1/√ℓ) → MLP → Δm
      4. h^{ℓ+1} = h̃ + Δm + Δr   (Δr from JTok-M, scaled 1/√(2ℓ))
      5. GPAS

    LNS coordination:
      LNS factor:    1/√ℓ
      JTok-M factor: 1/√(2ℓ) → ratio = 1/√2 constant at all depths.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.use_jtokm = config.use_jtokm

        self.self_attn = NeoLLMAttention(config, layer_idx)
        # MLP is either VersatileFFN (returns (out, aux_stats)) or the plain
        # NeoLLMMLP (returns out); forward() branches on use_versatile_ffn.
        self.mlp = (
            VersatileFFN(config)
            if getattr(config, "use_versatile_ffn", False)
            else NeoLLMMLP(config)
        )
        self.use_versatile_ffn = getattr(config, "use_versatile_ffn", False)
        # Pre-norms for the two sublayers, each followed by a depth-dependent
        # LNS scaling before the sublayer proper.
        self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)
        # GPAS applied at each residual merge point.
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)
        # Set by forward() from self_attn's third return value each call.
        self.current_layer_fan = None

        if self.use_jtokm:
            self.jtokm = LeviathanJTokM(config, layer_idx)
        else:
            self.jtokm = None

        # AttnRes: depth-wise softmax attention over preceding layer outputs.
        # Zero-initialized pseudo-queries → uniform weights at step 0.
        self.use_attn_res = getattr(config, 'use_attn_res', False)
        if self.use_attn_res:
            self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
            self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
            self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.attn_res_query_attn = None
            self.attn_res_query_mlp = None
            self.attn_res_norm = None

        # LAuReL (Menghani et al., ICML 2025): learned residual junctions,
        # two optional sub-variants (RW: learned 2-way mix; LR: low-rank path).
        self.use_laurel = getattr(config, 'use_laurel', False)
        self.use_laurel_rw = getattr(config, 'use_laurel_rw', True)
        self.use_laurel_lr = getattr(config, 'use_laurel_lr', False)
        D = config.hidden_size
        r = getattr(config, 'laurel_lr_rank', 32)

        if self.use_laurel and self.use_laurel_rw:
            # Raw [α̂, β̂] logits; zeros → softmax (0.5, 0.5), i.e. the
            # standard residual (up to the 0.5 scale pair) at init.
            self.laurel_rw_attn = nn.Parameter(torch.zeros(2))
            self.laurel_rw_mlp = nn.Parameter(torch.zeros(2))
        else:
            self.laurel_rw_attn = None
            self.laurel_rw_mlp = None

        if self.use_laurel and self.use_laurel_lr:
            # Low-rank residual path W = B·A. A gets a scaled strided-identity
            # pattern; B is zero so the LR delta is exactly 0 at init.
            self.laurel_lr_A_attn = nn.Linear(D, r, bias=False)
            self.laurel_lr_B_attn = nn.Linear(r, D, bias=False)

            self.laurel_lr_A_mlp = nn.Linear(D, r, bias=False)
            self.laurel_lr_B_mlp = nn.Linear(r, D, bias=False)

            for A_mat in (self.laurel_lr_A_attn, self.laurel_lr_A_mlp):
                nn.init.zeros_(A_mat.weight)
                for j in range(r):
                    for i in range(D):
                        if i % r == j:
                            A_mat.weight.data[j, i] = 1.0 / (r * D) ** 0.5
            for B_mat in (self.laurel_lr_B_attn, self.laurel_lr_B_mlp):
                nn.init.zeros_(B_mat.weight)
        else:
            self.laurel_lr_A_attn = None
            self.laurel_lr_B_attn = None
            self.laurel_lr_A_mlp = None
            self.laurel_lr_B_mlp = None

    def _attn_res(
        self,
        sources: list,
        partial: torch.Tensor,
        query: torch.Tensor,
        analysis_slot: Optional[AttnResAnalysis] = None,
        analysis_key: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Depth-wise softmax attention over preceding layer outputs.

        Computes:
            V = stack(sources + [partial])   [N+1, B, S, D]
            K = RMSNorm(V)                   [N+1, B, S, D]
            logits = query · K               [N+1, B, S]
            weights = softmax(logits, dim=0) [N+1, B, S]
            h = Σ_n weights_n · V_n          [B, S, D]

        The pseudo-query is shared across positions (per the paper design).
        RMSNorm on keys prevents layers with large-magnitude outputs from
        dominating the softmax. Initialized to zero → uniform weights at
        step 0, reducing to standard residual mean.

        Args:
            sources: list of [B, S, D] tensors — completed block summaries
                or all previous layer outputs (Full AttnRes).
            partial: [B, S, D] — current intra-block partial sum.
            query: [D] — learnable pseudo-query for this sublayer.
        Returns:
            [B, S, D] — weighted combination of sources + partial.
        """
        # NOTE(review): analysis_slot / analysis_key are accepted but never
        # read in this body — confirm whether diagnostics were intended here.
        all_v = sources + [partial]
        V = torch.stack(all_v, dim=0)
        K = self.attn_res_norm(V)
        logits = torch.einsum('d,nbsd->nbs', query, K)
        weights = torch.softmax(logits, dim=0)
        return torch.einsum('nbs,nbsd->bsd', weights, V)

    def _laurel_residual(
        self,
        residual: torch.Tensor,
        delta: torch.Tensor,
        rw_param: Optional[torch.Tensor],
        A_mat,
        B_mat,
        analysis: Optional["LAuReLAnalysis"] = None,
        slot: str = "attn",
    ) -> torch.Tensor:
        """
        Computes the LAuReL-augmented residual junction (Menghani et al., ICML 2025).

        Dispatches among three regimes depending on which sub-variants are active:

        LAUREL-RW only (use_laurel_rw=True, use_laurel_lr=False):
            α, β = softmax([α̂, β̂])
            out = α · delta + β · residual                 (paper §2.1)

        LAUREL-LR only (use_laurel_rw=False, use_laurel_lr=True):
            lr_delta = B(A(residual))
            out = delta + lr_delta + residual              (paper eq. 3)

        LAUREL-RW+LR (both active, paper eq. 5):
            α, β = softmax([α̂, β̂])
            lr_delta = B(A(residual))
            out = α · delta + β · (lr_delta + residual)

        In all cases the standard residual identity (out = delta + residual) is
        recovered at initialisation: RW starts at (α=0.5, β=0.5); LR starts
        with B=0 so lr_delta=0.

        Args:
            residual: [B, S, D] — accumulated residual stream (x_i in the paper).
            delta: [B, S, D] — sublayer output f(x_i) (attention or MLP).
            rw_param: Parameter[2] — raw [α̂, β̂] before softmax; None if RW off.
            A_mat: nn.Linear(D→r, bias=False) — LR down-proj; None if LR off.
            B_mat: nn.Linear(r→D, bias=False) — LR up-proj; None if LR off.
            analysis: LAuReLAnalysis slot to deposit scalar diagnostics into.
            slot: "attn" or "mlp" — selects which analysis fields to write.

        Returns:
            [B, S, D] — augmented residual output for the current sublayer.
        """
        has_rw = rw_param is not None
        has_lr = A_mat is not None
        # NOTE(review): if both variants are off, lr_delta stays None and the
        # final branch would raise — callers gate on use_laurel with at least
        # one of RW/LR enabled; confirm config validation upstream.

        if has_lr:
            lr_delta = B_mat(A_mat(residual))
            # Scalar norm diagnostics only outside training (cheap .item() syncs).
            if analysis is not None and not self.training:
                in_norm = residual.detach().norm(dim=-1).mean().item()
                lr_norm = lr_delta.detach().norm(dim=-1).mean().item()
                if slot == "attn":
                    analysis.lr_input_norm_attn = in_norm
                    analysis.lr_delta_norm_attn = lr_norm
                else:
                    analysis.lr_input_norm_mlp = in_norm
                    analysis.lr_delta_norm_mlp = lr_norm
        else:
            lr_delta = None

        if has_rw:
            # Softmax in fp32 for stability, cast back to the stream dtype.
            ab = torch.softmax(rw_param.float(), dim=0).to(residual.dtype)
            alpha = ab[0]
            beta = ab[1]
            if analysis is not None and not self.training:
                if slot == "attn":
                    analysis.alpha_attn = alpha.item()
                    analysis.beta_attn = beta.item()
                else:
                    analysis.alpha_mlp = alpha.item()
                    analysis.beta_mlp = beta.item()
        else:
            alpha = beta = None

        if has_rw and has_lr:
            # Paper eq. 5: RW mix with the LR-augmented residual.
            return alpha * delta + beta * (lr_delta + residual)
        elif has_rw:
            # Paper §2.1: learned 2-way mix.
            return alpha * delta + beta * residual
        else:
            # Paper eq. 3: additive low-rank path.
            return delta + lr_delta + residual

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        first_layer_fan: Optional[torch.Tensor] = None,
        z_tilde: Optional[torch.Tensor] = None,
        B_vals: Optional[torch.Tensor] = None,
        attn_res_sources: Optional[list] = None,
        attn_res_partial: Optional[torch.Tensor] = None,
        layer_analysis: Optional[LayerAnalysis] = None,
        output_attentions: Optional[bool] = False,
        repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple:
        """
        Run one decoder layer: attention sublayer then MLP sublayer, each
        with SeeDNorm → LNS pre-norm, optional AttnRes source mixing,
        optional LAuReL residual junctions, and GPAS at each residual merge.

        Returns a tuple whose layout depends on runtime flags:
            (hidden_states
             [, attn_weights]     when output_attentions
             [, jtokm_aux_stats]  when JTok-M ran and produced stats
             [, versatile_aux])   when the VersatileFFN produced stats
        NOTE(review): the optional entries have no fixed positions —
        consumers must mirror this conditional layout exactly.
        """
        if layer_analysis is not None:
            layer_analysis.hidden_states_input = hidden_states.detach()

        # --- Attention sublayer input selection (AttnRes or plain stream) ---
        ar_analysis = layer_analysis.attn_res if layer_analysis is not None else None
        if self.use_attn_res and attn_res_sources is not None and attn_res_partial is not None:
            h_attn = self._attn_res(
                attn_res_sources, attn_res_partial, self.attn_res_query_attn,
                ar_analysis, "attn",
            )
            residual_attn = attn_res_partial
        else:
            h_attn = hidden_states
            residual_attn = hidden_states

        # Pre-norm + depth-dependent LNS scaling before attention.
        sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
        h_normed = self.input_layernorm(h_attn, analysis=sn_pre)
        h_lns = self.lns_attn(h_normed)
        if layer_analysis is not None:
            layer_analysis.lns_attn_output = h_lns.detach()

        # Attention; third return value is this layer's FAN state, kept on
        # the module for external readers.
        hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
            hidden_states=h_lns,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            first_layer_fan=first_layer_fan,
            attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
            repo_rope_args=repo_rope_args,
            **kwargs,
        )

        if layer_analysis is not None:
            layer_analysis.attn_contribution = hidden_states.detach()

        gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
        # Residual merge for the attention sublayer: LAuReL junction when
        # enabled, plain addition otherwise; GPAS applied either way.
        if self.use_laurel:
            attn_aug = self._laurel_residual(
                residual_attn, hidden_states,
                self.laurel_rw_attn, self.laurel_lr_A_attn, self.laurel_lr_B_attn,
                analysis=layer_analysis.laurel if layer_analysis is not None else None,
                slot="attn",
            )
        else:
            attn_aug = residual_attn + hidden_states
        h_tilde = self.gpas_attn(attn_aug, analysis=gpas_attn_a)

        if layer_analysis is not None:
            layer_analysis.h_tilde = h_tilde.detach()

        # --- MLP sublayer input selection ---
        # (No attn_res_partial check needed here: the partial is h_tilde,
        # which is always defined at this point.)
        if self.use_attn_res and attn_res_sources is not None:
            h_mlp = self._attn_res(
                attn_res_sources, h_tilde, self.attn_res_query_mlp,
                ar_analysis, "mlp",
            )
            residual_mlp = h_tilde
        else:
            h_mlp = h_tilde
            residual_mlp = h_tilde

        # Pre-norm + LNS before the MLP.
        sn_post = layer_analysis.seednorm_post_attn if layer_analysis is not None else None
        h_normed2 = self.post_attention_layernorm(h_mlp, analysis=sn_post)
        h_lns2 = self.lns_mlp(h_normed2)
        if layer_analysis is not None:
            layer_analysis.lns_mlp_output = h_lns2.detach()

        # VersatileFFN returns (out, aux); plain MLP returns out only.
        mlp_a = layer_analysis.mlp if layer_analysis is not None else None
        if self.use_versatile_ffn:
            delta_m, versatile_aux = self.mlp(h_lns2, analysis=mlp_a)
        else:
            delta_m = self.mlp(h_lns2, analysis=mlp_a)
            versatile_aux = None

        if layer_analysis is not None:
            layer_analysis.mlp_contribution = delta_m.detach()

        # JTok-M aux stats; remains None unless the JTok-M branch runs below.
        aux_stats = None

        # Residual merge for the MLP sublayer (LAuReL-aware, as above).
        laurel_la = layer_analysis.laurel if layer_analysis is not None else None
        if self.use_laurel:
            mlp_aug = self._laurel_residual(
                residual_mlp, delta_m,
                self.laurel_rw_mlp, self.laurel_lr_A_mlp, self.laurel_lr_B_mlp,
                analysis=laurel_la,
                slot="mlp",
            )
        else:
            mlp_aug = residual_mlp + delta_m

        # JTok-M injection: router consumes h̃ (post-attention hidden), its
        # Δr is added alongside the MLP delta before the final GPAS.
        if self.use_jtokm and z_tilde is not None and B_vals is not None:
            orig_shape = h_tilde.shape
            h_flat = h_tilde.reshape(-1, self.hidden_size)
            z_flat = z_tilde.reshape(-1, z_tilde.shape[-1])
            B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1])

            jtokm_a = layer_analysis.jtokm if layer_analysis is not None else None
            delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
            delta_r = delta_r.reshape(orig_shape)

            gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
            hidden_states = self.gpas_mlp(mlp_aug + delta_r, analysis=gpas_mlp_a)
        else:
            gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
            hidden_states = self.gpas_mlp(mlp_aug, analysis=gpas_mlp_a)

        if layer_analysis is not None:
            layer_analysis.hidden_states_output = hidden_states.detach()

        # Conditional output-tuple assembly (see docstring NOTE).
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if aux_stats is not None:
            outputs += (aux_stats,)
        if versatile_aux is not None:
            outputs += (versatile_aux,)
        return outputs
|
|
|
|
class SpellingBeeEmbedding(nn.Module):
    """
    Spelling Bee Embeddings (Rabe et al., 2026, arXiv:2601.18030).

    Mixes each token embedding with a character-level summary built from the
    token's UTF-8 byte sequence:

        e_bee(t)   = 0.5 * (e_tok(t) + e_chars(t))
        e_chars(t) = (1/√|t|) * Σ_{i<16} RoPE(e_byte[b_i], i)

    Implementation notes:
      * e_chars is materialized once per forward over the whole vocabulary
        ([V, d]) and gathered by token id — far cheaper than per-occurrence
        work when V ≪ B·S·(occurrence count).
      * RoPE is applied to a static [256, MAX_BYTES, d] table covering every
        byte value at every intra-token position; all shapes are
        compile-time static, so torch.compile can fuse the construction.
      * 1/√byte_len is precomputed in set_byte_table, leaving a single
        elementwise multiply in the hot path.

    Setup: call ``set_byte_table(tokenizer)`` once after construction and
    before ``.to(device)`` / FP8 conversion. For inference,
    ``bake_inference_table`` collapses the whole module into one embedding
    table with zero residual overhead.

    References:
        Rabe, Clymo & Dong (2026). "Spelling Bee Embeddings for Language
        Modeling." arXiv:2601.18030.
    """

    # Tokens longer than this many UTF-8 bytes are truncated.
    MAX_BYTES: int = 16

    def __init__(self, config: "NeoLLMConfig"):
        super().__init__()
        dim = config.hidden_size
        rope_base = getattr(config, "rope_theta", 10000.0)

        # One learnable embedding per possible byte value.
        self.byte_emb = nn.Embedding(256, dim)

        # Persistent per-token byte table and 1/√len, filled by
        # set_byte_table and carried through checkpoints.
        self.register_buffer(
            "token_bytes",
            torch.zeros(config.vocab_size, self.MAX_BYTES, dtype=torch.long),
            persistent=True,
        )
        self.register_buffer(
            "inv_sqrt_lens",
            torch.ones(config.vocab_size, dtype=torch.float),
            persistent=True,
        )

        # Intra-token RoPE cos/sin tables over the MAX_BYTES positions.
        half = dim // 2
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, half, dtype=torch.float) * 2.0 / dim))
        angles = torch.outer(torch.arange(self.MAX_BYTES, dtype=torch.float), inv_freq)
        self.register_buffer("intra_cos", angles.cos(), persistent=False)
        self.register_buffer("intra_sin", angles.sin(), persistent=False)

        # Position-index helper for the advanced-indexing gather in forward.
        self.register_buffer(
            "pos_idx",
            torch.arange(self.MAX_BYTES, dtype=torch.long),
            persistent=False,
        )

    def set_byte_table(self, tokenizer) -> None:
        """
        Fill ``token_bytes`` and ``inv_sqrt_lens`` from a tokenizer.

        Call exactly once after model instantiation and before
        ``.to(device)`` / FP8 conversion so the buffers follow the module
        through those transforms. Both buffers are persistent and are
        saved/restored with checkpoints.

        Args:
            tokenizer: any HuggingFace-style tokenizer exposing
                ``convert_ids_to_tokens(int) -> str | None``.
        """
        vocab_size = self.token_bytes.shape[0]
        byte_table = torch.zeros(vocab_size, self.MAX_BYTES, dtype=torch.long)
        inv_lens = torch.ones(vocab_size, dtype=torch.float)

        for tid in range(vocab_size):
            token = tokenizer.convert_ids_to_tokens(tid)
            if token is None:
                continue
            # Best-effort encode; an unencodable token degrades to one NUL
            # byte rather than aborting table construction.
            try:
                raw = token.encode("utf-8")
            except Exception:
                raw = b"\x00"
            for pos, byte_val in enumerate(raw[: self.MAX_BYTES]):
                byte_table[tid, pos] = byte_val
            count = min(len(raw), self.MAX_BYTES)
            inv_lens[tid] = 1.0 / math.sqrt(max(count, 1))

        self.token_bytes.copy_(byte_table.to(self.token_bytes.device))
        self.inv_sqrt_lens.copy_(inv_lens.to(self.inv_sqrt_lens.device))

    def _build_rope_bytes(self) -> torch.Tensor:
        """
        Build the static [256, MAX_BYTES, d] RoPE-encoded byte table.

        Applies the half-split rotation to byte_emb.weight at every
        intra-token position. Shapes are fully static; the result is
        rebuilt (and discarded) each forward, costing two broadcast
        elementwise ops plus one concat.

        Returns:
            rope_bytes [256, MAX_BYTES, d]
        """
        weights = self.byte_emb.weight
        half = weights.shape[-1] // 2
        lo = weights[:, :half].unsqueeze(1)
        hi = weights[:, half:].unsqueeze(1)
        cos = self.intra_cos.unsqueeze(0)
        sin = self.intra_sin.unsqueeze(0)
        rotated_lo = lo * cos - hi * sin
        rotated_hi = lo * sin + hi * cos
        return torch.cat([rotated_lo, rotated_hi], dim=-1)

    def forward(
        self,
        token_ids: torch.Tensor,
        token_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: integer token indices used to look up byte sequences.
            token_embeds: embeddings from embed_tokens or LeviathanGenerator.
        Returns:
            Spelling bee embeddings, same shape as token_embeds.
        """
        rope_table = self._build_rope_bytes()

        # Vocabulary-level character summary: gather each token's (byte,
        # position) pairs from the RoPE table and sum over positions.
        char_sum = rope_table[
            self.token_bytes,
            self.pos_idx.unsqueeze(0),
        ].sum(1)

        # Length normalisation via the precomputed 1/√len.
        char_sum = char_sum * self.inv_sqrt_lens.unsqueeze(-1)

        # Gather per occurrence, then average with the token embedding.
        e_chars = char_sum[token_ids]
        return (token_embeds + e_chars) * 0.5

    @torch.no_grad()
    def bake_inference_table(
        self,
        token_emb_weight: torch.Tensor,
    ) -> torch.Tensor:
        """
        Collapse SBE into a single [vocab_size, d] embedding table.

        After baking, lookups through the returned table are bit-identical
        to running forward(), with zero extra inference cost.

        Args:
            token_emb_weight: [vocab_size, d] — weight of embed_tokens or
                the equivalent table (e.g. after Leviathan).
        Returns:
            [vocab_size, d] — baked spelling bee embedding table.

        Usage::

            baked = model.model.spelling_bee.bake_inference_table(
                model.model.embed_tokens.weight
            )
            model.model.embed_tokens.weight.copy_(baked)
            # Optionally free byte_emb parameters:
            # del model.model.spelling_bee
        """
        rope_table = self._build_rope_bytes()
        char_sum = rope_table[
            self.token_bytes,
            self.pos_idx.unsqueeze(0),
        ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1)
        return (token_emb_weight + char_sum) * 0.5
|
|
|
|
class NeoLLMPreTrainedModel(PreTrainedModel):
    """
    Base class with custom weight initialization for all NeoLLM components.

    LeviathanGenerator (real per-head architecture):
        - codebooks: normal(0, initializer_range)
        - head_proj[i]: normal(0, initializer_range) — standard for linear
        - head_norm[i]: weight=1, bias=0 — default LayerNorm init
        - head_scale: filled with (num_knots - 1) — matches ckhronos.py
                      exactly: scale initialized to stretch the knot grid
                      uniformly so d = |x - grid| * (num_knots-1) maps the
                      [0,1] input to [0, num_knots-1] at init.
        - head_spline: normal(mean=1.0, std=0.1) — KHRONOS init_weights_prod.
                      The effective coefficient is (1 + delta) so the
                      product across d_seed dimensions starts near 1 rather
                      than near 0, providing stable gradients from step 0.
        - head_out[i]: normal(0, initializer_range / sqrt(num_modes)) —
                      scaled by 1/sqrt(M) so the sum of M head outputs starts
                      with the same variance as a single head projection.
        - seed_proj: normal(0, initializer_range) — JTok-M shared path
        No W_res — confirmed absent in the authors' implementation.
    LeviathanJTokM:
        - spline_coeff: normal(mean=1.0, std=0.1) — same as generator
        - W_out: normal(0, initializer_range)
        - W_res: normal(0, initializer_range)
        - router: parent handles (normal init)
        - scaler: ones (identity at init)
    NeoLLMAttention (Affine-Scaled Attention):
        - alpha_proj: normal(0, 0.02) — near-zero so linear_clipping(~0) ~ 0.5
                      at init, giving a mild ~0.5x scaling of softmax weights
                      per head rather than collapsing to 0 or 1.
        - alpha_ma: zeros — running EMA starts at 0, beta starts as -alpha/N —
                      small negative offset; model quickly learns to adjust both.
    REPOModule (Context Re-Positioning):
        - W_g, W_c, W_z: default normal init from parent _init_weights.
                      No special initialization required — the SwiGLU
                      sub-layer starts near-zero, so z_i ~ 0 for all tokens
                      at step 0, which is equivalent to constant position
                      assignment (NoPE-like). The model quickly learns to
                      differentiate positions as needed.
    """
    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    @staticmethod
    def _init_identity_(param: torch.Tensor) -> None:
        """In-place identity init for a 2-D (possibly rectangular) matrix.

        Equivalent to nn.init.eye_ for square shapes; for rectangular
        shapes it copies a rectangular identity built with the parameter's
        own device/dtype.
        """
        rows, cols = param.shape
        with torch.no_grad():
            param.data.copy_(
                torch.eye(rows, cols, device=param.device, dtype=param.dtype)
            )

    def _init_weights(self, module):
        """Apply NeoLLM-specific initialization on top of the parent defaults."""
        super()._init_weights(module)

        if isinstance(module, NeoLLMAttention):
            # Momentum/mixing coefficients start at 0.5 (balanced blend).
            if hasattr(module, "lambda_1"):
                module.lambda_1.data.fill_(0.5)
            if hasattr(module, "lambda_2"):
                module.lambda_2.data.fill_(0.5)
            # MEA mixing matrices start as identity so the MEA operator is
            # transparent at step 0 and only deviates as training proceeds.
            if getattr(module, "mea_key_mix", None) is not None:
                self._init_identity_(module.mea_key_mix)
            if getattr(module, "mea_value_mix", None) is not None:
                self._init_identity_(module.mea_value_mix)
            # Affine-Scaled Attention: near-zero alpha_proj puts
            # linear_clipping at ~0.5; alpha_ma EMA starts at zero.
            if hasattr(module, "alpha_proj"):
                nn.init.normal_(module.alpha_proj.weight, mean=0.0, std=0.02)
            if hasattr(module, "alpha_ma"):
                module.alpha_ma.zero_()

        elif isinstance(module, GPAS):
            # Zero alpha -> GPAS is the identity at init.
            module.alpha.data.fill_(0.0)

        elif isinstance(module, FANLayer):
            pass  # parent default init is intentional

        elif isinstance(module, SeeDNorm):
            pass  # parent default init is intentional

        elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
            # Multiplier of 1.0 -> identity scaling at init.
            if hasattr(module, "multiplier"):
                module.multiplier.data.fill_(1.0)

        elif isinstance(module, LeviathanGenerator):
            nn.init.normal_(module.codebooks,
                            mean=0.0, std=self.config.initializer_range)
            # Stretch the knot grid uniformly (see class docstring).
            module.head_scale.data.fill_(float(module.num_knots - 1))
            # (1 + delta) product starts near 1 -> stable gradients.
            nn.init.normal_(module.head_spline, mean=1.0, std=0.1)

            nn.init.normal_(module.head_proj_weight,
                            mean=0.0, std=self.config.initializer_range)

            module.head_norm_weight.data.fill_(1.0)
            module.head_norm_bias.data.zero_()

            # 1/sqrt(M) so the sum over num_modes heads matches the variance
            # of a single head projection at init.
            out_std = self.config.initializer_range / math.sqrt(module.num_modes)
            nn.init.normal_(module.head_out_weight, mean=0.0, std=out_std)

        elif isinstance(module, LeviathanJTokM):
            nn.init.normal_(module.spline_coeff, mean=1.0, std=0.1)

            nn.init.normal_(module.W_out,
                            mean=0.0, std=self.config.initializer_range)
            nn.init.normal_(module.W_res,
                            mean=0.0, std=self.config.initializer_range)
            # Identity scaler at init.
            module.scaler.data.fill_(1.0)

        elif isinstance(module, NeoLLMDecoderLayer):
            # AttnRes query vectors start at zero -> uniform (softmax of
            # zeros) depth-wise attention at step 0.
            if hasattr(module, 'attn_res_query_attn') and module.attn_res_query_attn is not None:
                module.attn_res_query_attn.data.zero_()
                module.attn_res_query_mlp.data.zero_()

        elif isinstance(module, SpellingBeeEmbedding):
            # 1/sqrt(d) std so summed byte vectors have roughly unit scale.
            d = module.byte_emb.embedding_dim
            nn.init.normal_(module.byte_emb.weight, mean=0.0, std=1.0 / math.sqrt(d))
|
|
|
class NeoLLMModel(NeoLLMPreTrainedModel):
    """
    NeoLLM base decoder-only Transformer.

    When use_jtokm=True, the generator returns (embeddings, z_tilde, B_vals).
    z_tilde and B_vals are passed to every decoder layer so JTok-M surfaces
    can reuse the B-spline basis without recomputation.

    When use_attn_res=True, the model maintains a list of previous layer
    outputs (or block summaries for Block AttnRes) and passes them to each
    decoder layer, replacing fixed residual accumulation with learned
    depth-wise softmax attention (Kimi Team, 2026, arXiv:2603.15031).

    Spelling Bee Embeddings (Rabe et al., 2026, arXiv:2601.18030):
        ``use_spelling_bee_embeddings`` is independent of the Leviathan flag:
        SBE is applied post-embedding regardless of which path produced the
        token embeddings (embed_tokens or LeviathanGenerator).

    Flag coupling:
        - use_token_generator=False, use_spelling_bee_embeddings=False
              -> standard embed_tokens, no SBE [default]
        - use_token_generator=False, use_spelling_bee_embeddings=True
              -> standard embed_tokens + SBE
        - use_token_generator=True, use_spelling_bee_embeddings=False
              -> LeviathanGenerator only, no SBE
        - use_token_generator=True, use_spelling_bee_embeddings=True
              -> LeviathanGenerator + SBE

    Setup: call ``model.model.spelling_bee.set_byte_table(tokenizer)``
    once after model init and before training (and after any .to(device)).
    """

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)

        # Embedding path: Leviathan continuous generator XOR a plain
        # embedding table, chosen once at construction by the config.
        if config.use_token_generator:
            self.token_generator = LeviathanGenerator(config)
        else:
            self.embed_tokens = nn.Embedding(
                config.vocab_size, config.hidden_size, config.pad_token_id
            )

        # Optional Spelling Bee Embeddings, applied after either embedding
        # path.  getattr keeps older configs without the flag working.
        use_sbe = getattr(config, "use_spelling_bee_embeddings", False)
        if use_sbe:
            self.spelling_bee = SpellingBeeEmbedding(config)
        else:
            self.spelling_bee = None

        self.layers = nn.ModuleList(
            [NeoLLMDecoderLayer(config, layer_idx)
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Reset at the start of every forward(); holds the first layer's FAN
        # activations so later layers can reuse them (see the forward loop).
        self.first_layer_fan = None

        self.post_init()

    def get_input_embeddings(self):
        # The generator plays the role of the embedding module when enabled.
        if self.config.use_token_generator:
            return self.token_generator
        return self.embed_tokens

    def set_input_embeddings(self, value):
        if self.config.use_token_generator:
            self.token_generator = value
        else:
            self.embed_tokens = value

    def _build_layer_analysis(self, layer_idx: int = 0) -> LayerAnalysis:
        """
        Construct a LayerAnalysis with sub-objects pre-allocated for every
        component that is active in the current config.

        Fields that correspond to disabled flags remain None — the analysis
        consumer can check for None without needing to inspect config flags.
        Called once per layer per forward when analysis is active.
        """
        cfg = self.config
        # REPO only runs from repo_start_layer onwards (default: last 2/3).
        _repo_active = (
            getattr(cfg, "use_repo", False)
            and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
        )
        _versatile = getattr(cfg, "use_versatile_ffn", False)
        return LayerAnalysis(
            seednorm_pre_attn = SeeDNormAnalysis(),
            seednorm_post_attn = SeeDNormAnalysis(),
            attention = AttentionAnalysis(
                fan = FANAnalysis(),
                hadamard = HadamardAnalysis() if getattr(cfg, "use_hadamard_o_proj", False) else None,
                repo = REPOAnalysis() if _repo_active else None,
            ),
            # Versatile FFN replaces the FAN+PolyNorm MLP, so the two sets of
            # analysis objects are mutually exclusive.
            mlp = MLPAnalysis(
                fan = FANAnalysis() if not _versatile else None,
                polynorm = PolyNormAnalysis() if not _versatile else None,
                versatile = VersatileFFNAnalysis() if _versatile else None,
            ),
            gpas_attn = GPASAnalysis(),
            gpas_mlp = GPASAnalysis(),
            jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
            attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
            laurel = LAuReLAnalysis() if getattr(cfg, "use_laurel", False) else None,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        analysis_state: Optional[AnalysisState] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple:
        # Resolve output flags from config when not passed explicitly.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions if output_attentions is not None
            else self.config.output_attentions
        )
        if return_dict is None:
            # vars() reads the raw config dict, tolerating configs that carry
            # either "return_dict" or "use_return_dict" (default True).
            cfg_dict = vars(self.config)
            return_dict = cfg_dict.get(
                "return_dict",
                cfg_dict.get("use_return_dict", True),
            )

        # XOR: exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")

        # JTok-M internals (only populated when use_jtokm=True below).
        z_tilde = None
        B_vals = None

        # Generator analysis sub-object, when analysis is armed and the
        # Leviathan generator is the active embedding path.
        gen_a = (
            analysis_state.generator
            if analysis_state is not None and self.config.use_token_generator
            else None
        )

        if inputs_embeds is None:
            if self.config.use_token_generator:
                if self.config.use_jtokm:
                    # Generator also returns the latent and B-spline basis so
                    # JTok-M layers can reuse them without recomputation.
                    inputs_embeds, z_tilde, B_vals = self.token_generator(
                        input_ids, return_internals=True, analysis=gen_a
                    )
                    # Reshape flat internals to [batch, seq, ...] layouts.
                    z_tilde = z_tilde.reshape(*input_ids.shape, self.config.generator_d_seed)
                    B_vals = B_vals.reshape(
                        *input_ids.shape,
                        self.config.generator_d_seed,
                        self.config.generator_num_knots,
                    )
                else:
                    inputs_embeds = self.token_generator(input_ids, analysis=gen_a)
            else:
                inputs_embeds = self.embed_tokens(input_ids)

        # SBE needs token ids to look up byte sequences, so it is skipped on
        # the inputs_embeds-only path (input_ids is None there).
        if self.spelling_bee is not None and input_ids is not None:
            inputs_embeds = self.spelling_bee(input_ids, inputs_embeds)

        if analysis_state is not None:
            analysis_state.embeddings = inputs_embeds.detach()

        if position_ids is None:
            position_ids = torch.arange(
                0, inputs_embeds.shape[1], device=inputs_embeds.device
            ).unsqueeze(0)

        # No KV cache in this model: cache_position always spans the full
        # sequence and past_key_values is None.
        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=torch.arange(
                inputs_embeds.shape[1], device=inputs_embeds.device
            ),
            past_key_values=None,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_aux_stats = []

        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        self.first_layer_fan = None

        # REPO layers re-derive rotary embeddings from learned positions, so
        # they need the raw inv_freq and scaling rather than cos/sin tables.
        repo_rope_args = (
            (self.rotary_emb.inv_freq, self.rotary_emb.attention_scaling)
            if getattr(self.config, "use_repo", False) else None
        )

        # AttnRes bookkeeping: `attn_res_sources` collects block summaries the
        # depth-wise attention can attend over; `attn_res_partial` tracks the
        # most recent layer output within the current block.
        use_attn_res = getattr(self.config, 'use_attn_res', False)
        attn_res_sources = None
        attn_res_partial = None
        if use_attn_res:
            attn_res_sources = [hidden_states]
            attn_res_partial = hidden_states

            # attn_res_num_blocks == 0 means per-layer (block_size 1).
            num_blocks = getattr(self.config, 'attn_res_num_blocks', 0)
            block_size = (
                max(self.config.num_hidden_layers // num_blocks, 1)
                if num_blocks > 0
                else 1
            )

        if analysis_state is not None:
            analysis_state.layers = []

        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # At every block boundary, promote the last partial output to a
            # block summary and start a new partial from the current state.
            # (block_size is only evaluated when use_attn_res is True.)
            if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
                attn_res_sources = attn_res_sources + [attn_res_partial]
                attn_res_partial = hidden_states

            layer_analysis = None
            if analysis_state is not None:
                layer_analysis = self._build_layer_analysis(layer_idx)
                layer_analysis.layer_idx = layer_idx
                analysis_state.layers.append(layer_analysis)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                first_layer_fan=self.first_layer_fan,
                z_tilde=z_tilde,
                B_vals=B_vals,
                attn_res_sources=attn_res_sources,
                attn_res_partial=attn_res_partial if use_attn_res else None,
                layer_analysis=layer_analysis,
                output_attentions=output_attentions,
                repo_rope_args=repo_rope_args,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if use_attn_res:
                attn_res_partial = hidden_states

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            # JTok-M routing stats ride as the last element of layer_outputs
            # (after the optional attention weights).
            if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
                all_aux_stats.append(layer_outputs[-1])

            # Versatile-FFN stats are identified by shape (a 3-tuple) and
            # tagged so the loss code can separate them from JTok-M stats.
            if getattr(self.config, "use_versatile_ffn", False):
                for item in layer_outputs[1:]:
                    if isinstance(item, tuple) and len(item) == 3:
                        all_aux_stats.append(("versatile", item))
                        break

            # Capture the first layer's FAN activations once, for reuse by
            # all subsequent layers on the next iterations.
            if (self.first_layer_fan is None
                    and hasattr(decoder_layer, "current_layer_fan")):
                self.first_layer_fan = decoder_layer.current_layer_fan

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if analysis_state is not None:
            analysis_state.final_hidden_states = hidden_states.detach()
            analysis_state.jtokm_aux_stats = all_aux_stats if self.config.use_jtokm else None
            analysis_state.attn_res_sources_final = (
                attn_res_sources if use_attn_res else None
            )

        # Tuple path: drop None entries, then append aux stats last so the
        # causal-LM wrapper can always read them at index -1.
        if not return_dict:
            return tuple(
                v for v in [hidden_states, None, all_hidden_states, all_attentions]
                if v is not None
            ) + (all_aux_stats,)

        # Dict path: (output object, aux stats) pair — the wrapper unpacks it.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        ), all_aux_stats
|
|
|
|
@torch.compiler.disable
def compute_cce_loss(
    hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None
):
    """Cut-cross-entropy loss, kept outside the torch.compile graph.

    Pad positions are remapped to the ignore index (-100) before the fused
    linear+CE kernel runs; shift=1 performs the causal next-token shift.
    """
    target = labels.to(hidden_states.device)
    if pad_token_id is not None:
        # masked_fill keeps dtype/device and ignores pad positions via -100.
        target = target.masked_fill(target == pad_token_id, -100)
    return linear_cross_entropy(
        hidden_states, lm_head_weight, target,
        bias=lm_head_bias, shift=1, impl="cce_kahan_full_c", reduction="mean",
    )
|
|
|
|
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """
    Causal LM with NeoLLM backbone.

    When use_jtokm=True, the load-balancing auxiliary loss is computed from
    the per-layer JTok-M routing statistics and added to the cross-entropy loss:

        total_loss = CE_loss + L_aux

    where L_aux = lambda * n_e * (1/L) * sum_l sum_i p_i^l * f_i^l

    -- Analysis mode ----------------------------------------------------------
    Analysis is NEVER active during training (model.training=True).
    It is opt-in at inference time:

        model.eval()
        model.enable_analysis()          # arm the system
        with torch.no_grad():
            _ = model(input_ids)
        state = model.last_analysis      # AnalysisState with all internals

        model.disable_analysis()         # disarm -> zero overhead again
        # or: model.train()              # training always disarms automatically

    last_analysis is replaced at each forward call when analysis is armed.
    last_analysis is None between enable_analysis() and the first forward,
    and cleared by disable_analysis().

    Example field access:
        state.layers[3].attention.alpha_per_head   # Affine-Scaled alpha
        state.layers[0].mlp.polynorm.x2_post_exclusive
        state.generator.z_tilde                    # Leviathan latent
        state.layers[2].jtokm.routing_weights      # JTok-M router
        state.layers[5].attn_res.weights_pre_attn  # AttnRes softmax
    """

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # The Leviathan generator has no embed_tokens table to tie to.
        if config.use_token_generator:
            self._tied_weights_keys = {}

        # Analysis system state: armed flag + last captured AnalysisState.
        self._analysis_armed: bool = False
        self.last_analysis: Optional[AnalysisState] = None

        self.post_init()

    def enable_analysis(self) -> None:
        """
        Arm the analysis system.

        Has no effect during training — analysis is always off when
        model.training=True, regardless of this flag. Safe to call at any
        point without affecting the training loop.

        Usage::

            model.eval()
            model.enable_analysis()
            with torch.no_grad():
                _ = model(input_ids)
            state = model.last_analysis
        """
        self._analysis_armed = True

    def disable_analysis(self) -> None:
        """
        Disarm the analysis system and clear last_analysis.

        After this call, forward passes produce zero analysis overhead.
        """
        self._analysis_armed = False
        self.last_analysis = None

    def _make_analysis_state(
        self,
        input_ids: Optional[torch.Tensor],
    ) -> Optional[AnalysisState]:
        """
        Decide whether to build an AnalysisState for this forward pass.

        Returns None (zero cost) when:
          - analysis is not armed, OR
          - the model is in training mode (self.training=True).
        Otherwise returns a fresh AnalysisState with top-level optional
        fields pre-allocated according to the active config flags.
        """
        if not self._analysis_armed or self.training:
            return None
        cfg = self.config
        return AnalysisState(
            input_ids = input_ids.detach() if input_ids is not None else None,
            generator = GeneratorAnalysis() if cfg.use_token_generator else None,
            layers = None,
            jtokm_aux_stats = [] if cfg.use_jtokm else None,
            attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
        )

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> dict:
        """
        NeoLLM does not implement KV caching (always returns past_key_values=None).
        Transformers' default GenerationMixin.prepare_inputs_for_generation assumes
        KV cache is active and slices input_ids to only the newest token on every
        step past the prefill. Without a real cache that retains previous K/V
        states, the attention module only sees 1 key/value pair while the causal
        mask still spans the full context length — causing a shape mismatch in SDPA.

        This override always forwards the COMPLETE input_ids sequence so the model
        can recompute attention over the full context from scratch at every step.
        Generation is therefore slower (no caching benefit) but numerically correct.
        """
        model_inputs: dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        if inputs_embeds is not None and past_key_values is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        # None unless analysis is armed and we are in eval mode.
        analysis_state = self._make_analysis_state(input_ids)

        model_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            analysis_state=analysis_state,
            **kwargs,
        )

        # The backbone returns either
        #   (BaseModelOutputWithPast, aux_stats)          return_dict=True
        #   (hidden_states, [hiddens], [attns], aux)      return_dict=False
        # In the second form model_out[0] is the hidden-states TENSOR itself;
        # the original code fell through to .last_hidden_state on it and
        # raised AttributeError. The isinstance(torch.Tensor) branch fixes that.
        if isinstance(model_out, tuple):
            outputs, all_aux_stats = model_out[0], model_out[-1]
            if isinstance(outputs, torch.Tensor):
                hidden_states = outputs
            elif isinstance(outputs, tuple):
                hidden_states = outputs[0]
            else:
                hidden_states = outputs.last_hidden_state
        else:
            outputs = model_out
            all_aux_stats = []
            hidden_states = outputs.last_hidden_state

        loss = None
        if labels is not None:
            # Fused linear + CE (CCE): logits are never materialized here.
            loss = compute_cce_loss(
                hidden_states, labels, self.lm_head.weight,
                getattr(self.lm_head, "bias", None), self.config.pad_token_id,
            )
            # JTok-M load-balancing aux loss (exclude "versatile"-tagged stats).
            if self.config.use_jtokm and all_aux_stats:
                jtokm_stats = [
                    s for s in all_aux_stats
                    if not (isinstance(s, tuple) and len(s) == 2 and s[0] == "versatile")
                ]
                if jtokm_stats:
                    aux_loss = compute_jtokm_aux_loss(
                        jtokm_stats,
                        n_e=self.config.jtokm_num_experts,
                        weight=self.config.jtokm_aux_loss_weight,
                    )
                    loss = loss + aux_loss
            # Versatile-FFN load-balancing aux loss (tagged stats only).
            if getattr(self.config, "use_versatile_ffn", False) and all_aux_stats:
                versatile_stats = [
                    s[1] for s in all_aux_stats
                    if isinstance(s, tuple) and len(s) == 2 and s[0] == "versatile"
                ]
                if versatile_stats:
                    v_loss = compute_versatile_aux_loss(
                        versatile_stats,
                        n_experts=self.config.versatile_total_experts,
                        weight=self.config.versatile_aux_loss_weight,
                    )
                    loss = loss + v_loss
            logits = None
        else:
            # Inference: only materialize logits for the requested suffix
            # (logits_to_keep=0 keeps the full sequence via slice(0, None)).
            slice_indices = (
                slice(-logits_to_keep, None)
                if isinstance(logits_to_keep, int) else logits_to_keep
            )
            logits = self.lm_head(hidden_states[:, slice_indices, :])

        if analysis_state is not None:
            analysis_state.logits = logits.detach() if logits is not None else None
            self.last_analysis = analysis_state

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
        )
|
|
|
|
| |
|
|
# Explicit public API of this module.
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "LeviathanGenerator",
    "LeviathanJTokM",
    "SpellingBeeEmbedding",
    "FANLayer",
    "SeeDNorm",
    "ScalarMultiplier",
    "VectorMultiplier",
    "LinearWithMultipliers",
    "MEAHeadSeeDNorm",
    "HadamardOProj",
    "REPOModule",
    "VersatileFFN",
    "compute_versatile_aux_loss",
    # Analysis containers (interpretability / inspection API).
    "AnalysisState",
    "LayerAnalysis",
    "AttentionAnalysis",
    "MLPAnalysis",
    "FANAnalysis",
    "SeeDNormAnalysis",
    "GPASAnalysis",
    "PolyNormAnalysis",
    "HadamardAnalysis",
    "REPOAnalysis",
    "VersatileFFNAnalysis",
    "JTokMAnalysis",
    "AttnResAnalysis",
    "GeneratorAnalysis",
]

# Register with the transformers Auto* factories so that
# AutoModelForCausalLM.from_pretrained(...) resolves model_type "neollm".
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
|
|