| """ |
| Chimera 5.2 — recurrent / attention layers (CPU-first). |
| |
| Every layer in this module exposes a ``forward(x, cache=None)`` signature and |
| returns ``(out, new_cache)``. ``cache`` is an arbitrary tensor / dict that the |
| layer reads on the previous timestep and returns updated for the next call. |
| This makes O(T) decoding possible instead of the O(T²) recompute used by |
| the original implementation. |
| |
| Optimisations vs. the previous draft: |
| * No ``einops`` dependency — every reshape is a plain :func:`Tensor.view`. |
| * Mask cache keyed by (T, dtype, device) — no per-token allocation churn. |
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
  during training (the inter-chunk recurrence runs at fp32; the recurrent
  state handed to the next call via the cache is detached, while gradient
  flow is preserved through the per-chunk QKV path).
| * mLSTM forgets are accumulated in log-space with a single ``cumsum``; the |
| causal mask is added once instead of per-row. |
| * TitansMAC only computes the values it actually uses (the original draft |
| built ``kv`` and threw it away – removed). |
| * TSPSpanKnotLayer's energy is a single fused linear projection; the per-step |
| Hamming/coherence loops are replaced by vectorised cosine similarity. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .quantization import BitLinear, RMSNorm |
|
|
|
|
| |
| |
| |
|
|
| _MASK_CACHE: dict = {} |
|
|
|
|
| def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: |
| """Cached additive causal mask: 0 on/below diag, ``-inf`` above.""" |
| key = ("neg_inf", T, str(device), dtype) |
| cached = _MASK_CACHE.get(key) |
| if cached is not None: |
| return cached |
| |
| |
| with torch.inference_mode(False), torch.no_grad(): |
| mask = torch.zeros(T, T, dtype=dtype, device=device) |
| mask.masked_fill_( |
| torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1), |
| float("-inf"), |
| ) |
| _MASK_CACHE[key] = mask |
| return mask |
|
|
|
|
def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
    """Cached lower-triangular bool mask (``True`` on/below the diagonal)."""
    key = ("tril_bool", T, str(device))
    mask = _MASK_CACHE.get(key)
    if mask is None:
        # Built under no_grad so the cached buffer carries no autograd state.
        with torch.inference_mode(False), torch.no_grad():
            mask = torch.ones(T, T, dtype=torch.bool, device=device).tril_()
        _MASK_CACHE[key] = mask
    return mask
|
|
|
|
| def _make_linear(use_ternary: bool): |
| if use_ternary: |
| return BitLinear |
| return lambda i, o, **kw: nn.Linear(i, o, bias=False) |
|
|
|
|
| |
| |
| |
|
|
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward network: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    __constants__ = ["hidden_size", "intermediate_size"]

    def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
        super().__init__()
        make = _make_linear(use_ternary)
        self.hidden_size = int(hidden_size)
        self.intermediate_size = int(intermediate_size)
        self.gate_proj = make(self.hidden_size, self.intermediate_size)
        self.up_proj = make(self.hidden_size, self.intermediate_size)
        self.down_proj = make(self.intermediate_size, self.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU transform position-wise; shape is preserved."""
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
| |
| |
| |
|
|
class ShortConv1d(nn.Module):
    """Causal depthwise 1-D convolution followed by SiLU.

    Streaming is supported through a (kernel_size - 1)-wide tail cache, so
    incremental generation costs O(1) per token despite the kernel width.
    """

    __constants__ = ["kernel_size", "dim"]

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        # padding = kernel_size - 1 left-shifts the receptive field so every
        # output position sees only current and past inputs.
        self.conv = nn.Conv1d(self.dim, self.dim, self.kernel_size,
                              padding=self.kernel_size - 1, groups=self.dim, bias=False)

    def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convolve ``x`` [B, T, D]; ``tail`` is the [B, D, k-1] cache from the prior call.

        Returns ``(silu(conv(x)) [B, T, D], new_tail [B, D, k-1])``.
        """
        B, T, D = x.shape
        seq = x.transpose(1, 2)  # [B, D, T] channel-major for Conv1d
        if tail is not None and tail.numel() > 0:
            seq = torch.cat([tail, seq], dim=-1)
        total = seq.shape[-1]
        # Conv1d pads both sides: keep the first `total` outputs (the causal
        # ones), then drop positions that belong to the cached tail.
        out = self.conv(seq)[..., :total][..., -T:]
        if self.kernel_size > 1:
            new_tail = seq[..., -(self.kernel_size - 1):]
        else:
            new_tail = seq[..., :0]
        return F.silu(out).transpose(1, 2), new_tail
|
|
|
|
| |
| |
| |
|
|
def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           g: torch.Tensor, beta: torch.Tensor,
                           state: Optional[torch.Tensor], chunk_size: int
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise gated delta-rule scan.

    Inputs are [B, T, H, D] for Q/K/V and [B, T, H] for ``g`` / ``beta``.
    ``state`` is the carried K^T V at fp32, shape [B, H, K, V] or ``None``.
    Returns (output [B, T, H, V], new_state).

    ``g`` holds per-step log-decay factors (the caller produces values <= 0)
    and ``beta`` a per-step write strength folded into V.  Each chunk is
    handled with a dense intra-chunk attention matrix; an fp32 state ``S``
    carries the decayed K^T V summary between chunks, so a chunk boundary
    produces exactly the same output as a single full-length chunk.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device

    # Head-major fp32 working layout: [B, H, T, D] (and [B, H, T] for gates).
    q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    g = g.permute(0, 2, 1).contiguous().to(torch.float32)
    beta = beta.permute(0, 2, 1).contiguous().to(torch.float32)

    scale = K ** -0.5            # standard 1/sqrt(d_k) attention scaling
    q = q * scale
    v = v * beta.unsqueeze(-1)   # apply the write strength to V once, up front

    chunk = min(chunk_size, T)
    if state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        S = state.to(torch.float32)

    out_chunks = []
    for start in range(0, T, chunk):
        end = min(start + chunk, T)
        c = end - start
        qc, kc, vc, gc = q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], g[:, :, start:end]

        # Cumulative log-decay within the chunk; diff[i, j] = sum of g over (j, i].
        log_decay = gc.cumsum(dim=-1)

        diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2)
        causal = _causal_tril_bool(c, device)
        # Multiplicative decay weights, zeroed strictly above the diagonal
        # (future positions must not contribute).
        intra_w = torch.where(causal, diff.exp(), torch.zeros_like(diff))

        # Intra-chunk attention plus inter-chunk readout of the carried state:
        # query i sees S decayed by its own cumulative log-decay.
        attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w
        o_intra = torch.matmul(attn, vc)
        o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S)
        out_chunks.append(o_intra + o_inter)

        # State update: decay the carried state by the whole chunk, then add
        # each key/value weighted by the decay remaining AFTER its position
        # (exp(decay_total - log_decay[t]) = exp(sum of g over (t, end])).
        decay_total = log_decay[:, :, -1:]
        S = S * decay_total.unsqueeze(-1).exp()
        per_step = (decay_total - log_decay).unsqueeze(-1).exp()
        S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)

    out = torch.cat(out_chunks, dim=2)
    # Back to [B, T, H, V]; S stays fp32 for the next call.
    return out.permute(0, 2, 1, 3).contiguous(), S
|
|
|
|
class GatedDeltaNetLayer(nn.Module):
    """Gated DeltaNet — chunkwise parallel during training, O(1) per token at inference.

    Q/K/V pass through short causal convolutions; a per-head log-decay ``g``
    and write strength ``beta`` drive :func:`_gated_delta_chunkwise`.  The
    returned cache holds the fp32 recurrent state plus the three conv tails,
    so decoding can continue one token at a time.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
                 chunk_size: int = 64, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.head_v_dim = int(head_dim * expand_v)
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = self.num_heads * self.head_v_dim
        self.chunk_size = int(chunk_size)

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.key_dim)
        self.k_proj = L(self.hidden_size, self.key_dim)
        self.v_proj = L(self.hidden_size, self.value_dim)
        self.g_proj = L(self.hidden_size, self.value_dim)
        self.o_proj = L(self.value_dim, self.hidden_size)

        # a_proj -> decay step size dt; b_proj -> write strength beta.
        self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)

        # Decay magnitude A ~ U(1, 16), Mamba-2-style init.  The lower bound
        # must be > 0: the previous U(0, 16) could sample A ~ 0, making
        # A_log = log(A) -> -inf and poisoning gradients with NaNs.
        A = torch.empty(self.num_heads).uniform_(1.0, 16.0)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # dt ~ LogUniform(1e-3, 0.1); store softplus^-1(dt) so that
        # softplus(a_proj(x) + dt_bias) starts out near dt.
        dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)).clamp_min(1e-4)
        self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
        self.dt_bias._no_weight_decay = True

        self.q_conv = ShortConv1d(self.key_dim, conv_size)
        self.k_conv = ShortConv1d(self.key_dim, conv_size)
        self.v_conv = ShortConv1d(self.value_dim, conv_size)
        self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Run the layer over ``x`` [B, T, hidden]; returns (out, new_cache)."""
        B, T, _ = x.shape
        cache = cache or {}
        prev_state = cache.get("state")
        prev_q_tail = cache.get("q_tail")
        prev_k_tail = cache.get("k_tail")
        prev_v_tail = cache.get("v_tail")

        # Causal short convs; each returns its (conv_size - 1) tail for streaming.
        q_full, q_tail = self.q_conv(self.q_proj(x), prev_q_tail)
        k_full, k_tail = self.k_conv(self.k_proj(x), prev_k_tail)
        v_full, v_tail = self.v_conv(self.v_proj(x), prev_v_tail)

        q = q_full.view(B, T, self.num_heads, self.head_dim)
        k = k_full.view(B, T, self.num_heads, self.head_dim)
        v = v_full.view(B, T, self.num_heads, self.head_v_dim)
        # Unit-norm Q/K keeps the recurrent state well conditioned.
        q = F.normalize(q, p=2.0, dim=-1)
        k = F.normalize(k, p=2.0, dim=-1)

        beta = torch.sigmoid(self.b_proj(x))   # [B, T, H] write strength in (0, 1)
        A = -self.A_log.exp()                  # negative per-head decay rates
        dt = F.softplus(self.a_proj(x) + self.dt_bias)
        g = dt * A.view(1, 1, -1)              # per-step log decay, <= 0

        out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
                                                state=prev_state,
                                                chunk_size=self.chunk_size)

        # Per-head RMSNorm, SiLU output gating, then merge heads and project out.
        gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
        out = self.o_norm(out) * F.silu(gate)
        out = out.reshape(B, T, self.value_dim)
        out = self.o_proj(out)

        new_cache = {
            "state": new_state.detach(),
            "q_tail": q_tail.detach(),
            "k_tail": k_tail.detach(),
            "v_tail": v_tail.detach(),
        }
        return out, new_cache
|
|
|
|
| |
| |
| |
|
|
class MLSTMLayer(nn.Module):
    """Parallelised mLSTM with log-space cumulative gates.

    Forget gates are accumulated with a single log-space ``cumsum`` and the
    additive causal mask is applied once, so the full sequence runs as dense
    matmuls rather than a per-step recurrence.

    NOTE(review): the cache carries only the running forget-gate log-sum
    (``log_f_cum``), not past K/V, so a cached call attends only within its
    own chunk — confirm this matches the intended decoding scheme.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)

        # Scalar-per-head input/forget gates; full-width output gate.
        self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.ogate = L(self.hidden_size, self.v_dim)

        # Start with inputs mostly gated out and forget gates mostly open,
        # with per-head spread on the forget bias.
        nn.init.constant_(self.igate.bias, -10.0)
        with torch.no_grad():
            self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))

        self.gate_soft_cap = float(gate_soft_cap)
        # Fix: honour the caller's norm_eps — previously the parameter was
        # accepted but LayerNorm silently used its default eps (1e-5).
        self.o_norm = nn.LayerNorm(self.head_dim, eps=norm_eps)
        self.eps = 1e-6

    @staticmethod
    def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
        """Smoothly bound ``x`` to (-cap, cap) via ``cap * tanh(x / cap)``."""
        return cap * torch.tanh(x / cap)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Run the mLSTM over ``x`` [B, T, hidden]; returns (out, new_cache)."""
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        scale = D ** -0.5

        q = self.q_proj(x).view(B, T, H, D) * scale
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)

        i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap)
        f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
        f_log = F.logsigmoid(f_raw)

        # Cumulative log forget gate, continued from the previous call if cached.
        prev_logf = cache.get("log_f_cum") if cache else None
        log_f_cum = f_log.cumsum(dim=1)
        if prev_logf is not None:
            log_f_cum = log_f_cum + prev_logf.unsqueeze(1)

        # Head-major layout [B, H, T, ...] for batched matmuls.
        q_h = q.permute(0, 2, 1, 3)
        k_h = k.permute(0, 2, 1, 3)
        v_h = v.permute(0, 2, 1, 3)
        log_f_cum_h = log_f_cum.permute(0, 2, 1)
        i_raw_h = i_raw.permute(0, 2, 1)

        # log_gate[i, j] = sum of log f over (j, i] + input-gate pre-activation
        # at j; -inf above the diagonal enforces causality.
        log_gate = (log_f_cum_h.unsqueeze(-1) - log_f_cum_h.unsqueeze(-2)
                    + i_raw_h.unsqueeze(-2))
        log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
        # Row-wise max m is the mLSTM stabiliser; subtract before exponentiating.
        m = log_gate.amax(dim=-1, keepdim=True).clamp_min(-30.0)
        gate_w = (log_gate - m).exp()

        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
        n = torch.matmul(gate_w, k_h)
        # Denominator max(|q . n|, exp(-m)) keeps outputs bounded when all
        # gate weights are tiny.
        denom = (q_h * n).sum(-1, keepdim=True).abs()
        denom = torch.maximum(denom, torch.exp(-m)) + self.eps

        out = torch.matmul(attn, v_h) / denom
        out = self.o_norm(out.float()).to(x.dtype)
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)

        out_gate = torch.sigmoid(self.ogate(x))
        out = self.o_proj(out_gate * out)

        new_cache = {"log_f_cum": log_f_cum[:, -1].detach()}
        return out, new_cache
|
|
|
|
| |
| |
| |
|
|
class TitansMACLayer(nn.Module):
    """Memory-as-Context linear attention with persistent memory slots.

    NOTE(review): ``memory_depth``, ``local_window`` and the
    ``persistent_memory`` parameter are created/stored here but never read in
    :meth:`forward`; either they are consumed elsewhere or this is dead
    configuration — confirm before relying on them.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 memory_depth: int = 2, persistent_slots: int = 64,
                 local_window: int = 1024, norm_eps: float = 1e-6,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.memory_depth = int(memory_depth)
        self.local_window = int(local_window)
        self.persistent_slots = int(persistent_slots)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)

        # Per-head gates: alpha drives decay; eta and theta scale the writes.
        self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        if self.persistent_slots > 0:
            self.persistent_memory = nn.Parameter(
                torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
        else:
            self.register_parameter("persistent_memory", None)

        self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Decayed causal linear attention over ``x`` [B, T, hidden].

        The ``cache`` argument is passed through unchanged (no recurrent
        state is carried between calls in this implementation).
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim

        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)

        alpha = torch.sigmoid(self.alpha_proj(x))          # forget rate in (0, 1)
        eta = torch.sigmoid(self.eta_proj(x))
        theta = torch.sigmoid(self.theta_proj(x)) * 0.1    # small write step

        # Head-major fp32 copies for numerically stable log-space decay.
        q_h = q.permute(0, 2, 1, 3).to(torch.float32)
        k_h = k.permute(0, 2, 1, 3).to(torch.float32)
        v_h = v.permute(0, 2, 1, 3).to(torch.float32)
        alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
        eta_h = eta.permute(0, 2, 1).to(torch.float32)
        theta_h = theta.permute(0, 2, 1).to(torch.float32)

        # decay[i, j] = product of (1 - alpha) over (j, i]; the clamp keeps
        # log1p(-alpha) finite when alpha saturates at 1.
        log_retain = torch.log1p(-alpha_h.clamp(max=0.999))
        log_retain_cum = log_retain.cumsum(dim=-1)
        decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
        decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
        decay = decay.exp()

        # Values scaled by their write strength, then decayed causal attention.
        contrib = (eta_h * theta_h).unsqueeze(-1) * v_h
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay
        out = torch.matmul(attn, contrib)

        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out = self.o_norm(out.to(x.dtype))
        return self.o_proj(out), cache or {}
|
|
|
|
| |
| |
| |
|
|
class TSPSpanKnotLayer(nn.Module):
    """TSP Span Knot: GatedDeltaNet body with a small additive energy term."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, chunk_size: int = 64,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.gdn = GatedDeltaNetLayer(self.hidden_size, num_heads, head_dim,
                                      norm_eps=norm_eps, chunk_size=chunk_size,
                                      use_ternary=use_ternary)
        # Five fused energy channels, mixed by learned weights.
        self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
        self.energy_weights = nn.Parameter(torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3]))
        self._semantic_memory = None

    def set_semantic_memory(self, mem) -> None:
        """Store an external semantic-memory object (held; not read by this layer)."""
        self._semantic_memory = mem

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Run the GDN body and add a small scalar energy to each position."""
        body_out, new_cache = self.gdn(x, cache=cache)
        # Scalar energy per position: weighted sum of the five projections.
        energy = (self.energy_proj(body_out) * self.energy_weights).sum(dim=-1, keepdim=True)
        return body_out + energy * 0.01, new_cache
|
|
|
|
# Public API of this module (one entry per layer class defined above).
__all__ = [
    "SwiGLUMLP",
    "ShortConv1d",
    "GatedDeltaNetLayer",
    "MLSTMLayer",
    "TitansMACLayer",
    "TSPSpanKnotLayer",
]
|
|