"""Chimera 5.2 — recurrent / attention layers (CPU-first).

Every layer in this module exposes a ``forward(x, cache=None)`` signature and
returns ``(out, new_cache)``. ``cache`` is an arbitrary tensor / dict that the
layer reads on the previous timestep and returns updated for the next call.
This makes O(T) decoding possible instead of the O(T²) recompute used by the
original implementation.

Optimisations vs. the previous draft:

* No ``einops`` dependency — every reshape is a plain :func:`Tensor.view`.
* Mask cache keyed by (T, dtype, device) — no per-token allocation churn.
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
  during training (the inter-chunk recurrence runs at fp32 with detached
  state on CPU, gradient flow is preserved through the per-chunk QKV path).
* mLSTM forgets are accumulated in log-space with a single ``cumsum``; the
  causal mask is added once instead of per-row.
* TitansMAC only computes the values it actually uses (the original draft
  built ``kv`` and threw it away – removed).
* TSPSpanKnotLayer's energy is a single fused linear projection; the
  per-step Hamming/coherence loops are replaced by vectorised cosine
  similarity.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .quantization import BitLinear, RMSNorm


# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------

# Process-wide mask cache. Keys embed (T, device, dtype) so distinct shapes /
# placements never collide. NOTE(review): the cache is unbounded — one entry
# per distinct (T, device, dtype) seen; fine for a fixed set of sequence
# lengths, worth confirming for highly variable T.
_MASK_CACHE: dict = {}


def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Cached additive causal mask: 0 on/below diag, ``-inf`` above.

    Intended to be *added* to a [.., T, T] score matrix before ``exp`` /
    softmax-style normalisation, zeroing out non-causal positions.
    """
    key = ("neg_inf", T, str(device), dtype)
    cached = _MASK_CACHE.get(key)
    if cached is not None:
        return cached
    # Build outside any autograd / inference-mode context so the tensor is a
    # plain leaf that can be reused across train/eval/inference_mode calls.
    with torch.inference_mode(False), torch.no_grad():
        mask = torch.zeros(T, T, dtype=dtype, device=device)
        mask.masked_fill_(
            torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1),
            float("-inf"),
        )
    _MASK_CACHE[key] = mask
    return mask


def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
    """Lower-triangular bool mask (``True`` on/below diag) for multiplicative gating."""
    key = ("tril_bool", T, str(device))
    cached = _MASK_CACHE.get(key)
    if cached is not None:
        return cached
    with torch.inference_mode(False), torch.no_grad():
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
    _MASK_CACHE[key] = mask
    return mask


def _make_linear(use_ternary: bool):
    """Return a linear-layer factory: ``BitLinear`` (ternary) or a bias-free ``nn.Linear``.

    Both callables accept ``(in_features, out_features)``; extra keyword args
    are swallowed in the ``nn.Linear`` case.
    """
    if use_ternary:
        return BitLinear
    return lambda i, o, **kw: nn.Linear(i, o, bias=False)


# ---------------------------------------------------------------------------
# SwiGLU MLP (shared with MoE)
# ---------------------------------------------------------------------------

class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    __constants__ = ["hidden_size", "intermediate_size"]

    def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
        super().__init__()
        L = _make_linear(use_ternary)
        self.hidden_size = int(hidden_size)
        self.intermediate_size = int(intermediate_size)
        self.gate_proj = L(self.hidden_size, self.intermediate_size)
        self.up_proj = L(self.hidden_size, self.intermediate_size)
        self.down_proj = L(self.intermediate_size, self.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


# ---------------------------------------------------------------------------
# Causal depthwise conv (used by Gated DeltaNet)
# ---------------------------------------------------------------------------

class ShortConv1d(nn.Module):
    """Causal depthwise 1-D convolution + SiLU.

    Supports streaming via a small (kernel_size-1) tail cache so generation
    runs at O(1) per token even though the conv has a kernel > 1.
    """

    __constants__ = ["kernel_size", "dim"]

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        # groups=dim makes this depthwise; padding=k-1 gives enough left
        # context, the surplus right pad is sliced off in forward().
        self.conv = nn.Conv1d(self.dim, self.dim, self.kernel_size,
                              padding=self.kernel_size - 1, groups=self.dim, bias=False)

    def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the causal conv to ``x`` [B, T, D].

        ``tail`` is the [B, D, <=k-1] trailing input slice carried from the
        previous call (or None on the first call). Returns ``(silu(conv(x))``
        as [B, T, D], ``new_tail)``.
        """
        # x: [B, T, D] -> conv expects [B, D, T]
        B, T, D = x.shape
        xt = x.transpose(1, 2)  # [B, D, T]
        if tail is not None and tail.numel() > 0:
            # Prepend the cached tail so the first new tokens see real
            # left-context instead of zero padding.
            xt = torch.cat([tail, xt], dim=-1)
            T_full = xt.shape[-1]
        else:
            T_full = T
        y = self.conv(xt)[..., :T_full]  # causal: drop the trailing pad slack
        y = y[..., -T:]  # only keep outputs aligned with new inputs
        # Tail for the next call: last k-1 raw inputs (empty slice if k == 1).
        new_tail = xt[..., -(self.kernel_size - 1):] if self.kernel_size > 1 else xt[..., :0]
        return F.silu(y).transpose(1, 2), new_tail


# ---------------------------------------------------------------------------
# Gated DeltaNet (chunkwise parallel + recurrent state)
# ---------------------------------------------------------------------------

def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           g: torch.Tensor, beta: torch.Tensor,
                           state: Optional[torch.Tensor], chunk_size: int
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise gated delta-rule scan.

    Inputs are [B, T, H, D] for Q/K/V and [B, T, H] for ``g`` / ``beta``.
    ``g`` is a per-step *log* decay (expected <= 0); ``beta`` a write
    strength in (0, 1). ``state`` is the carried K^T V at fp32, shape
    [B, H, K, V] or ``None``. Returns (output [B, T, H, V], new_state).

    NOTE(review): despite the name, the state update below is a decayed
    accumulation ``S <- decay*S + k^T (beta*v)`` — there is no delta-rule
    removal term ``k (S^T k)``; confirm this simplification is intended.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device
    # Permute once: [B, H, T, *]; the whole scan runs at fp32 for stability.
    q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    g = g.permute(0, 2, 1).contiguous().to(torch.float32)        # [B, H, T]
    beta = beta.permute(0, 2, 1).contiguous().to(torch.float32)  # [B, H, T]
    scale = K ** -0.5
    q = q * scale
    # Fold the write strength into v once, up front.
    v = v * beta.unsqueeze(-1)
    chunk = min(chunk_size, T)
    if state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        S = state.to(torch.float32)
    out_chunks = []
    for start in range(0, T, chunk):
        end = min(start + chunk, T)
        c = end - start
        qc, kc, vc, gc = q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], g[:, :, start:end]
        # Cumulative log-decay within the chunk.
        log_decay = gc.cumsum(dim=-1)  # [B, H, c]
        # Within-chunk weighting: exp(log_decay[i] - log_decay[j]) for j <= i
        # Built once via outer subtraction; mask non-causal entries to 0.
        diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2)  # [B, H, c, c]
        causal = _causal_tril_bool(c, device)  # [c, c]
        intra_w = torch.where(causal, diff.exp(), torch.zeros_like(diff))
        # Output = qc @ kc^T * intra_w @ vc + qc * exp(log_decay) @ S
        # (intra-chunk causal attention + contribution of the carried state).
        attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w  # [B, H, c, c]
        o_intra = torch.matmul(attn, vc)  # [B, H, c, V]
        o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S)  # [B, H, c, V]
        out_chunks.append(o_intra + o_inter)
        # Update carried state:
        # S <- S * exp(decay_total) + (kc * exp(decay_chunk_end - log_decay)).T @ vc
        # i.e. each key is decayed by the decay remaining to the chunk end.
        decay_total = log_decay[:, :, -1:]  # [B, H, 1]
        S = S * decay_total.unsqueeze(-1).exp()
        per_step = (decay_total - log_decay).unsqueeze(-1).exp()  # [B, H, c, 1]
        S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)
    out = torch.cat(out_chunks, dim=2)  # [B, H, T, V]
    return out.permute(0, 2, 1, 3).contiguous(), S


class GatedDeltaNetLayer(nn.Module):
    """Gated DeltaNet — chunkwise parallel during training, O(1) per token at inference."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
                 chunk_size: int = 64, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.head_v_dim = int(head_dim * expand_v)
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = self.num_heads * self.head_v_dim
        self.chunk_size = int(chunk_size)
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.key_dim)
        self.k_proj = L(self.hidden_size, self.key_dim)
        self.v_proj = L(self.hidden_size, self.value_dim)
        self.g_proj = L(self.hidden_size, self.value_dim)   # output gate
        self.o_proj = L(self.value_dim, self.hidden_size)
        # a_proj feeds the decay-rate path (dt), b_proj the write strength (beta).
        self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        # Per-head decay magnitude A, stored in log-space (Mamba-style init,
        # A ~ U(0, 16)); excluded from weight decay.
        A = torch.empty(self.num_heads).uniform_(0.0, 16.0)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # dt bias initialised so softplus(dt_bias) lands in [1e-3, 0.1]:
        # dt + log(-expm1(-dt)) is the inverse of softplus at dt.
        dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3))
                       + math.log(1e-3)).clamp_min(1e-4)
        self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
        self.dt_bias._no_weight_decay = True
        self.q_conv = ShortConv1d(self.key_dim, conv_size)
        self.k_conv = ShortConv1d(self.key_dim, conv_size)
        self.v_conv = ShortConv1d(self.value_dim, conv_size)
        self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Process ``x`` [B, T, hidden]; ``cache`` carries the fp32 K^T V
        state plus the three conv tails. Returns (out [B, T, hidden], cache)."""
        B, T, _ = x.shape
        prev_state = cache.get("state") if cache else None
        prev_q_tail = cache.get("q_tail") if cache else None
        prev_k_tail = cache.get("k_tail") if cache else None
        prev_v_tail = cache.get("v_tail") if cache else None
        # Short causal convs over the projected streams (streaming-aware).
        q_full, q_tail = self.q_conv(self.q_proj(x), prev_q_tail)
        k_full, k_tail = self.k_conv(self.k_proj(x), prev_k_tail)
        v_full, v_tail = self.v_conv(self.v_proj(x), prev_v_tail)
        q = q_full.view(B, T, self.num_heads, self.head_dim)
        k = k_full.view(B, T, self.num_heads, self.head_dim)
        v = v_full.view(B, T, self.num_heads, self.head_v_dim)
        # L2-normalised q/k keep the recurrent state bounded.
        q = F.normalize(q, p=2.0, dim=-1)
        k = F.normalize(k, p=2.0, dim=-1)
        beta = torch.sigmoid(self.b_proj(x))  # [B, T, H] write strength
        # Per-step log decay g = -softplus(a(x) + dt_bias) * exp(A_log) <= 0.
        A = -self.A_log.exp()
        dt = F.softplus(self.a_proj(x) + self.dt_bias)  # [B, T, H]
        g = dt * A.view(1, 1, -1)
        out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
                                                state=prev_state,
                                                chunk_size=self.chunk_size)
        # SiLU-gated output, normalised per head before the final projection.
        gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
        out = self.o_norm(out) * F.silu(gate)
        out = out.reshape(B, T, self.value_dim)
        out = self.o_proj(out)
        # Detach so the cache never drags the autograd graph across segments.
        new_cache = {
            "state": new_state.detach(),
            "q_tail": q_tail.detach(),
            "k_tail": k_tail.detach(),
            "v_tail": v_tail.detach(),
        }
        return out, new_cache


# ---------------------------------------------------------------------------
# xLSTM mLSTM — parallel chunkwise + carried state
# ---------------------------------------------------------------------------

class MLSTMLayer(nn.Module):
    """Parallelised mLSTM with log-space cumulative gates."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)
        # Scalar input/forget gates per head (full-precision linears).
        self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.ogate = L(self.hidden_size, self.v_dim)
        # Bias init: input gate starts near-closed; forget-gate biases are
        # spread linearly across heads so heads specialise to time-scales.
        nn.init.constant_(self.igate.bias, -10.0)
        with torch.no_grad():
            self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))
        self.gate_soft_cap = float(gate_soft_cap)
        self.o_norm = nn.LayerNorm(self.head_dim)
        self.eps = 1e-6

    @staticmethod
    def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
        # Smoothly limits pre-activations to (-cap, cap).
        return cap * torch.tanh(x / cap)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Process ``x`` [B, T, hidden]. Returns (out, new_cache).

        NOTE(review): the cache carries only the cumulative log-forget
        accumulator — no K/V or covariance state — so a subsequent call
        attends only within its own input window. Confirm this matches the
        intended decoding scheme.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        scale = D ** -0.5
        q = self.q_proj(x).view(B, T, H, D) * scale
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap)  # [B, T, H]
        f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
        f_log = F.logsigmoid(f_raw)  # [B, T, H]
        # Log-space accumulators with carry-in.
        prev_logf = cache.get("log_f_cum") if cache else None  # [B, H]
        log_f_cum = f_log.cumsum(dim=1)  # [B, T, H]
        if prev_logf is not None:
            log_f_cum = log_f_cum + prev_logf.unsqueeze(1)
        # Permute to head-major.
        q_h = q.permute(0, 2, 1, 3)  # [B, H, T, D]
        k_h = k.permute(0, 2, 1, 3)
        v_h = v.permute(0, 2, 1, 3)
        log_f_cum_h = log_f_cum.permute(0, 2, 1)  # [B, H, T]
        i_raw_h = i_raw.permute(0, 2, 1)
        # log_gate[t, s] = log_f_cum[t] - log_f_cum[s] + i[s], causal.
        log_gate = (log_f_cum_h.unsqueeze(-1) - log_f_cum_h.unsqueeze(-2)
                    + i_raw_h.unsqueeze(-2))
        log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
        # Row-wise max-stabiliser (clamped so exp(-m) can't blow up below).
        m = log_gate.amax(dim=-1, keepdim=True).clamp_min(-30.0)
        gate_w = (log_gate - m).exp()  # [B, H, T, T]
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
        # mLSTM normaliser: n_t = sum_s gate * k_s; denom = max(|q.n|, exp(-m)).
        n = torch.matmul(gate_w, k_h)  # [B, H, T, D]
        denom = (q_h * n).sum(-1, keepdim=True).abs()
        denom = torch.maximum(denom, torch.exp(-m)) + self.eps
        out = torch.matmul(attn, v_h) / denom  # [B, H, T, D]
        # LayerNorm at fp32 for stability, then cast back to the input dtype.
        out = self.o_norm(out.float()).to(x.dtype)
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out_gate = torch.sigmoid(self.ogate(x))
        out = self.o_proj(out_gate * out)
        new_cache = {"log_f_cum": log_f_cum[:, -1].detach()}
        return out, new_cache


# ---------------------------------------------------------------------------
# Titans MAC — gated linear attention with persistent memory
# ---------------------------------------------------------------------------

class TitansMACLayer(nn.Module):
    """Memory-as-Context linear attention with persistent memory slots.

    NOTE(review): ``persistent_memory``, ``memory_depth`` and
    ``local_window`` are stored but never read in :meth:`forward`; only the
    decayed causal attention path is active. Confirm whether the memory
    path is pending or intentionally dropped.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 memory_depth: int = 2, persistent_slots: int = 64,
                 local_window: int = 1024, norm_eps: float = 1e-6,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.memory_depth = int(memory_depth)
        self.local_window = int(local_window)
        self.persistent_slots = int(persistent_slots)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)
        # Per-head scalar controls: alpha = forgetting, eta/theta = write scale.
        self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        if self.persistent_slots > 0:
            self.persistent_memory = nn.Parameter(
                torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
        else:
            self.register_parameter("persistent_memory", None)
        self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Process ``x`` [B, T, hidden]. Returns (out, cache).

        NOTE(review): ``cache`` is passed through unchanged — no recurrent
        state crosses calls, and the decay cumsum restarts at every forward.
        Confirm this is the intended streaming behaviour.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        # Project once.
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        alpha = torch.sigmoid(self.alpha_proj(x))  # [B, T, H]
        eta = torch.sigmoid(self.eta_proj(x))
        theta = torch.sigmoid(self.theta_proj(x)) * 0.1
        # Head-major fp32 copies for the numerically sensitive decay math.
        q_h = q.permute(0, 2, 1, 3).to(torch.float32)
        k_h = k.permute(0, 2, 1, 3).to(torch.float32)
        v_h = v.permute(0, 2, 1, 3).to(torch.float32)
        alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
        eta_h = eta.permute(0, 2, 1).to(torch.float32)
        theta_h = theta.permute(0, 2, 1).to(torch.float32)
        # Causal forgetting decay built in log-space: retain = 1 - alpha,
        # clamped so log1p never sees -1 exactly.
        log_retain = torch.log1p(-alpha_h.clamp(max=0.999))
        log_retain_cum = log_retain.cumsum(dim=-1)
        decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
        decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
        decay = decay.exp()  # 0 above diag
        # Values pre-scaled by the write gates before attention mixes them.
        contrib = (eta_h * theta_h).unsqueeze(-1) * v_h  # [B, H, T, D]
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay  # [B, H, T, T]
        out = torch.matmul(attn, contrib)  # [B, H, T, D]
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out = self.o_norm(out.to(x.dtype))
        return self.o_proj(out), cache or {}


# ---------------------------------------------------------------------------
# TSP Span Knot — fast vectorised energy
# ---------------------------------------------------------------------------

class TSPSpanKnotLayer(nn.Module):
    """TSP Span Knot: GatedDeltaNet body with a small additive energy term."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, chunk_size: int = 64,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.gdn = GatedDeltaNetLayer(self.hidden_size, num_heads, head_dim,
                                      norm_eps=norm_eps, chunk_size=chunk_size,
                                      use_ternary=use_ternary)
        # Single fused projection produces five energy terms.
        self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
        self.energy_weights = nn.Parameter(torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3]))
        # Optional external memory hook; not read in forward().
        self._semantic_memory = None

    def set_semantic_memory(self, mem) -> None:
        """Attach an external semantic-memory object (stored, not yet used here)."""
        self._semantic_memory = mem

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        out, new_cache = self.gdn(x, cache=cache)
        energies = self.energy_proj(out)  # [B, T, 5]
        weighted = (energies * self.energy_weights).sum(dim=-1, keepdim=True)
        # Small residual nudge — keeps gradient signal small as in 5.1.
        return out + weighted * 0.01, new_cache


__all__ = [
    "SwiGLUMLP",
    "ShortConv1d",
    "GatedDeltaNetLayer",
    "MLSTMLayer",
    "TitansMACLayer",
    "TSPSpanKnotLayer",
]