# ch1mera / chimera /layers.py
# Lgr54HFi's picture
# Upload folder using huggingface_hub
# 6e408ce verified
"""
Chimera 5.2 — recurrent / attention layers (CPU-first).
The recurrent / attention layers (GatedDeltaNetLayer, MLSTMLayer,
TitansMACLayer, TSPSpanKnotLayer) expose a ``forward(x, cache=None)``
signature and return ``(out, new_cache)``. ``cache`` is the dict the layer
returned from its previous call; the layer reads it and returns an updated one
for the next call. This makes O(T) decoding possible instead of the O(T²)
recompute used by the original implementation.
Optimisations vs. the previous draft:
* No ``einops`` dependency — every reshape is a plain :func:`Tensor.view`.
* Causal masks are cached and reused across calls — no per-token allocation churn.
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
  during training. The scan runs in fp32; gradients flow through the
  inter-chunk state within a call, and the state is detached only when it is
  returned in the cache.
* mLSTM forgets are accumulated in log-space with a single ``cumsum``; the
causal mask is added once instead of per-row.
* TitansMAC only computes the values it actually uses (the original draft
built ``kv`` and threw it away – removed).
* TSPSpanKnotLayer's energy is a single fused linear projection; the per-step
  Hamming/coherence loops of the previous draft have been removed.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .quantization import BitLinear, RMSNorm
# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------
_MASK_CACHE: dict = {}


def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return an additive causal mask of shape ``[T, T]``.

    Entries on/below the diagonal are ``0`` and entries above it are ``-inf``,
    so adding the mask to a score matrix before ``exp``/``softmax`` removes
    non-causal positions.

    Only the largest mask requested so far is stored per ``(device, dtype)``.
    Smaller sizes are served as the top-left ``[:T, :T]`` view of it, because
    the leading block of a causal mask is itself a causal mask. This bounds
    the cache to one tensor per device/dtype instead of one ``T x T`` tensor
    for every sequence length ever seen.

    The returned tensor is shared, so callers must not modify it in place.
    """
    key = ("neg_inf", str(device), dtype)
    cached = _MASK_CACHE.get(key)
    if cached is None or cached.shape[-1] < T:
        # Build outside any autograd / inference-mode context so the tensor is
        # a plain leaf that can be reused across train/eval/inference_mode calls.
        with torch.inference_mode(False), torch.no_grad():
            cached = torch.zeros(T, T, dtype=dtype, device=device)
            cached.masked_fill_(
                torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1),
                float("-inf"),
            )
        _MASK_CACHE[key] = cached
    return cached[:T, :T]
def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
    """Return a cached ``[T, T]`` boolean lower-triangular mask.

    ``True`` marks positions on or below the diagonal (key index <= query
    index), for ``torch.where`` / ``masked_fill`` style selection. The tensor
    is shared through ``_MASK_CACHE`` and must not be modified in place.
    """
    cache_key = ("tril_bool", T, str(device))
    if cache_key not in _MASK_CACHE:
        # Allocate outside autograd / inference mode so the cached tensor is a
        # normal tensor that any later mode can reuse.
        with torch.inference_mode(False), torch.no_grad():
            ones = torch.ones(T, T, dtype=torch.bool, device=device)
            _MASK_CACHE[cache_key] = ones.tril()
    return _MASK_CACHE[cache_key]
def _make_linear(use_ternary: bool):
    """Return the factory used to build every projection in this module.

    ``use_ternary=True`` returns ``BitLinear`` itself. Otherwise the returned
    factory builds a bias-free :class:`torch.nn.Linear`; any extra keyword
    arguments passed to it are accepted and ignored.
    """
    if not use_ternary:
        def _dense(i, o, **kw):
            return nn.Linear(i, o, bias=False)
        return _dense
    return BitLinear
# ---------------------------------------------------------------------------
# SwiGLU MLP (shared with MoE)
# ---------------------------------------------------------------------------
class SwiGLUMLP(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    Args:
        hidden_size: input and output feature width.
        intermediate_size: width of the gated hidden layer.
        use_ternary: build the projections with ``BitLinear`` (``True``) or
            bias-free ``nn.Linear`` (``False``) via ``_make_linear``.
    """

    __constants__ = ["hidden_size", "intermediate_size"]

    def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.intermediate_size = int(intermediate_size)
        make = _make_linear(use_ternary)
        d_model, d_ff = self.hidden_size, self.intermediate_size
        # Creation order (gate, up, down) is kept so parameter init / state_dict order match.
        self.gate_proj = make(d_model, d_ff)
        self.up_proj = make(d_model, d_ff)
        self.down_proj = make(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
# ---------------------------------------------------------------------------
# Causal depthwise conv (used by Gated DeltaNet)
# ---------------------------------------------------------------------------
class ShortConv1d(nn.Module):
    """Causal depthwise 1-D convolution followed by SiLU.

    Input and output are ``[B, T, dim]``. ``forward`` optionally takes the
    ``tail`` returned by the previous call (the last ``kernel_size - 1``
    channel-first input frames) and returns an updated tail. Consecutive calls
    therefore produce the same outputs as one call over the concatenated
    sequence.
    """

    __constants__ = ["kernel_size", "dim"]

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        # kernel_size-1 zeros are padded on both sides. forward() discards the
        # right-hand excess, which leaves a causal (left-padded) convolution.
        self.conv = nn.Conv1d(
            self.dim, self.dim, self.kernel_size,
            padding=self.kernel_size - 1, groups=self.dim, bias=False,
        )

    def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        _, n_new, _ = x.shape
        seq = x.transpose(1, 2)  # [B, D, T]: Conv1d is channel-first
        if tail is not None and tail.numel() > 0:
            seq = torch.cat([tail, seq], dim=-1)
        n_total = seq.shape[-1]
        # The first n_total conv outputs are the causal ones; keep only the
        # trailing n_new, which align with the freshly supplied frames.
        causal = self.conv(seq).narrow(-1, 0, n_total)
        y = causal[..., -n_new:]
        keep = self.kernel_size - 1
        next_tail = seq[..., -keep:] if keep > 0 else seq[..., :0]
        return F.silu(y).transpose(1, 2), next_tail
# ---------------------------------------------------------------------------
# Gated DeltaNet (chunkwise parallel + recurrent state)
# ---------------------------------------------------------------------------
def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           g: torch.Tensor, beta: torch.Tensor,
                           state: Optional[torch.Tensor], chunk_size: int
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise gated linear-recurrence scan.

    Per batch/head, with ``L`` the running sum of ``g``, this computes::

        S_t = exp(g_t) * S_{t-1} + k_t^T (beta_t * v_t)
        o_t = (q_t / sqrt(K)) @ S_t

    It runs in parallel inside chunks of ``chunk_size`` steps and
    sequentially across chunks. Everything is computed in fp32.

    NOTE(review): despite the name, no delta-rule correction term
    (``v_t - k_t S_{t-1}``) is applied. This is gated linear attention with
    ``beta``-scaled values.

    Args:
        q, k: ``[B, T, H, K]``.
        v: ``[B, T, H, V]``.
        g: ``[B, T, H]`` per-step log-decay. ``GatedDeltaNetLayer`` passes
            ``g <= 0``; the within-chunk exponentials are only bounded by 1
            under that condition.
        beta: ``[B, T, H]`` value scale.
        state: carried ``S`` of shape ``[B, H, K, V]``, or ``None`` for zeros.
        chunk_size: chunk length (values < 1 are treated as 1).

    Returns:
        ``(output [B, T, H, V] in fp32, new_state [B, H, K, V] in fp32)``.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device
    if state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        S = state.to(torch.float32)
    if T == 0:
        # Nothing to scan: empty output, state passed through unchanged.
        return torch.zeros(B, 0, H, V, device=device, dtype=torch.float32), S

    # Permute once to head-major [B, H, T, *] and work in fp32.
    q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    g = g.permute(0, 2, 1).contiguous().to(torch.float32)  # [B, H, T]
    beta = beta.permute(0, 2, 1).contiguous().to(torch.float32)  # [B, H, T]
    q = q * (K ** -0.5)
    v = v * beta.unsqueeze(-1)

    chunk = max(1, min(int(chunk_size), T))
    out_chunks = []
    for start in range(0, T, chunk):
        end = min(start + chunk, T)
        c = end - start
        qc = q[:, :, start:end]
        kc = k[:, :, start:end]
        vc = v[:, :, start:end]
        # Cumulative log-decay within the chunk.
        log_decay = g[:, :, start:end].cumsum(dim=-1)  # [B, H, c]
        # Intra-chunk weights exp(L_i - L_j) for j <= i.
        diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2)  # [B, H, c, c]
        causal = _causal_tril_bool(c, device)  # [c, c]
        # Mask BEFORE exp. Above the diagonal, L_i - L_j >= 0 can overflow to
        # +inf in fp32. Zeroing it afterwards (torch.where) would still
        # back-propagate 0 * inf = NaN, whereas exp(-inf) = 0 has a zero gradient.
        intra_w = diff.masked_fill(~causal, float("-inf")).exp()
        attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w  # [B, H, c, c]
        o_intra = torch.matmul(attn, vc)  # [B, H, c, V]
        o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S)  # [B, H, c, V]
        out_chunks.append(o_intra + o_inter)
        # S <- S * exp(L_end) + sum_j exp(L_end - L_j) k_j^T v_j
        decay_total = log_decay[:, :, -1:]  # [B, H, 1]
        S = S * decay_total.unsqueeze(-1).exp()
        per_step = (decay_total - log_decay).unsqueeze(-1).exp()  # [B, H, c, 1]
        S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)
    out = torch.cat(out_chunks, dim=2)  # [B, H, T, V]
    return out.permute(0, 2, 1, 3).contiguous(), S
class GatedDeltaNetLayer(nn.Module):
    """Gated DeltaNet — chunkwise parallel during training, O(1) per token at inference.

    Args:
        hidden_size: model width (input and output features).
        num_heads: number of heads ``H``.
        head_dim: per-head query/key width.
        expand_v: value width multiplier (per-head value width is
            ``head_dim * expand_v``).
        conv_size: kernel size of the causal short convolutions on Q/K/V.
        norm_eps: epsilon of the per-head output RMSNorm.
        chunk_size: chunk length passed to ``_gated_delta_chunkwise``.
        use_ternary: build the Q/K/V/G/O projections with ``BitLinear``.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
                 chunk_size: int = 64, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.head_v_dim = int(head_dim * expand_v)
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = self.num_heads * self.head_v_dim
        self.chunk_size = int(chunk_size)
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.key_dim)
        self.k_proj = L(self.hidden_size, self.key_dim)
        self.v_proj = L(self.hidden_size, self.value_dim)
        self.g_proj = L(self.hidden_size, self.value_dim)
        self.o_proj = L(self.value_dim, self.hidden_size)
        self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        # A = exp(A_log) is sampled in [0, 16); forward uses -A as the decay rate.
        A = torch.empty(self.num_heads).uniform_(0.0, 16.0)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # dt sampled log-uniformly in [1e-3, 0.1]; dt_bias stores its inverse
        # softplus, so softplus(dt_bias) == dt at init.
        dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)).clamp_min(1e-4)
        self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
        self.dt_bias._no_weight_decay = True
        self.q_conv = ShortConv1d(self.key_dim, conv_size)
        self.k_conv = ShortConv1d(self.key_dim, conv_size)
        self.v_conv = ShortConv1d(self.value_dim, conv_size)
        self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Mix ``x`` of shape ``[B, T, hidden_size]``; return ``(out, new_cache)``.

        ``cache`` (optional) is the dict from a previous call:

        * ``state``: fp32 ``[B, H, head_dim, head_v_dim]`` recurrent state.
        * ``q_tail`` / ``k_tail`` / ``v_tail``: short-conv history.

        All tensors in the returned cache are detached.
        """
        B, T, _ = x.shape
        cache = cache or {}
        prev_state = cache.get("state")
        q_full, q_tail = self.q_conv(self.q_proj(x), cache.get("q_tail"))
        k_full, k_tail = self.k_conv(self.k_proj(x), cache.get("k_tail"))
        v_full, v_tail = self.v_conv(self.v_proj(x), cache.get("v_tail"))
        q = q_full.view(B, T, self.num_heads, self.head_dim)
        k = k_full.view(B, T, self.num_heads, self.head_dim)
        v = v_full.view(B, T, self.num_heads, self.head_v_dim)
        q = F.normalize(q, p=2.0, dim=-1)
        k = F.normalize(k, p=2.0, dim=-1)
        beta = torch.sigmoid(self.b_proj(x))  # [B, T, H]
        A = -self.A_log.exp()
        dt = F.softplus(self.a_proj(x) + self.dt_bias)  # [B, T, H]
        g = dt * A.view(1, 1, -1)  # <= 0: per-step log-decay
        out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
                                                state=prev_state,
                                                chunk_size=self.chunk_size)
        gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
        out = self.o_norm(out) * F.silu(gate)
        # The scan returns fp32. Cast back so o_proj receives the model's
        # compute dtype, as MLSTMLayer / TitansMACLayer do; this is a no-op
        # for fp32 models.
        out = out.reshape(B, T, self.value_dim).to(x.dtype)
        out = self.o_proj(out)
        new_cache = {
            "state": new_state.detach(),
            "q_tail": q_tail.detach(),
            "k_tail": k_tail.detach(),
            "v_tail": v_tail.detach(),
        }
        return out, new_cache
# ---------------------------------------------------------------------------
# xLSTM mLSTM — parallel chunkwise + carried state
# ---------------------------------------------------------------------------
class MLSTMLayer(nn.Module):
    """Parallelised mLSTM with log-space cumulative gates and a carried state.

    Within one call, the stabilised parallel form is evaluated with a
    ``[T, T]`` gate matrix per head. Across calls, the layer carries three
    cache entries:

    * ``C``: stabilised matrix memory, ``[B, H, D, D]``.
    * ``n``: normaliser, ``[B, H, D]``.
    * ``m``: stabiliser, ``[B, H]``.

    Feeding a sequence in several pieces therefore matches feeding it at once,
    exactly in real arithmetic and up to float rounding in practice. Without
    a carried state, the output is identical to the plain parallel form.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)
        self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.ogate = L(self.hidden_size, self.v_dim)
        nn.init.constant_(self.igate.bias, -10.0)
        with torch.no_grad():
            self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))
        self.gate_soft_cap = float(gate_soft_cap)
        # NOTE(review): norm_eps is accepted but unused. o_norm keeps
        # LayerNorm's default eps, and the output denominator uses self.eps.
        self.o_norm = nn.LayerNorm(self.head_dim)
        self.eps = 1e-6

    @staticmethod
    def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
        """Smoothly squash ``x`` into ``(-cap, cap)`` with ``tanh``."""
        return cap * torch.tanh(x / cap)

    @staticmethod
    def _fold_state(f_cum_h: torch.Tensor, i_raw_h: torch.Tensor,
                    k_h: torch.Tensor, v_h: torch.Tensor,
                    prev: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the stabilised ``(C, n, m)`` after absorbing this call's tokens.

        Args:
            f_cum_h: ``[B, H, T]`` local cumulative log-forget values.
            i_raw_h: ``[B, H, T]`` input-gate pre-activations.
            k_h, v_h: ``[B, H, T, D]`` keys and values.
            prev: ``None`` or the previous ``(C, n, m)``.

        The invariant kept is ``C_true = exp(m) * C`` (likewise for ``n``),
        where ``C_true = sum_s exp(log_w_s) k_s^T v_s``.
        """
        f_last = f_cum_h[..., -1:]  # [B, H, 1]
        # Log weight of token s as seen from the last step: F_T - F_s + i_s.
        log_w = f_last - f_cum_h + i_raw_h  # [B, H, T]
        m_new = log_w.amax(dim=-1)  # [B, H]
        if prev is not None:
            prev_C, prev_n, prev_m = prev
            log_carry = f_last.squeeze(-1) + prev_m  # [B, H]
            m_new = torch.maximum(m_new, log_carry)
        w = (log_w - m_new.unsqueeze(-1)).exp().unsqueeze(-1)  # [B, H, T, 1]
        kw = k_h * w
        C_new = torch.matmul(kw.transpose(-1, -2), v_h)  # [B, H, D, D]
        n_new = kw.sum(dim=-2)  # [B, H, D]
        if prev is not None:
            carry = (log_carry - m_new).exp()  # [B, H]
            C_new = C_new + carry[..., None, None] * prev_C
            n_new = n_new + carry[..., None] * prev_n
        return C_new, n_new, m_new

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Mix ``x`` of shape ``[B, T, hidden_size]``; return ``(out, new_cache)``.

        Cache keys read:

        * ``C`` / ``n`` / ``m``: carried state, used only when all three are
          present.
        * ``log_f_cum``: running sum of log-forget gates. It is kept up to date
          for callers but does not affect the output.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        scale = D ** -0.5
        cache = cache or {}
        q = self.q_proj(x).view(B, T, H, D) * scale
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap)  # [B, T, H]
        f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
        f_log = F.logsigmoid(f_raw)  # [B, T, H]
        # Head-major layout.
        q_h = q.permute(0, 2, 1, 3)  # [B, H, T, D]
        k_h = k.permute(0, 2, 1, 3)
        v_h = v.permute(0, 2, 1, 3)
        # The cumulative log-forget restarts at 0 on every call. A constant
        # offset would cancel in every difference below; history is carried
        # through (C, n, m) instead.
        f_cum_h = f_log.cumsum(dim=1).permute(0, 2, 1)  # [B, H, T]
        i_raw_h = i_raw.permute(0, 2, 1)  # [B, H, T]

        has_state = all(cache.get(key) is not None for key in ("C", "n", "m"))
        if has_state:
            prev_C = cache["C"].to(q_h.dtype)
            prev_n = cache["n"].to(q_h.dtype)
            prev_m = cache["m"].to(f_cum_h.dtype)

        # log_gate[t, s] = F_t - F_s + i_s for s <= t; -inf above the diagonal.
        log_gate = (f_cum_h.unsqueeze(-1) - f_cum_h.unsqueeze(-2)
                    + i_raw_h.unsqueeze(-2))
        log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
        row_max = log_gate.amax(dim=-1)  # [B, H, T]
        if has_state:
            # Log weight of the carried memory at step t: F_t + m_prev.
            log_inter = f_cum_h + prev_m.unsqueeze(-1)  # [B, H, T]
            row_max = torch.maximum(row_max, log_inter)
        m = row_max.unsqueeze(-1).clamp_min(-30.0)  # [B, H, T, 1]
        gate_w = (log_gate - m).exp()  # [B, H, T, T]
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
        num = torch.matmul(attn, v_h)  # [B, H, T, D]
        n = torch.matmul(gate_w, k_h)  # [B, H, T, D]
        if has_state:
            inter_w = (log_inter.unsqueeze(-1) - m).exp()  # [B, H, T, 1]
            num = num + inter_w * torch.matmul(q_h, prev_C)
            n = n + inter_w * prev_n.unsqueeze(-2)
        denom = (q_h * n).sum(-1, keepdim=True).abs()
        denom = torch.maximum(denom, torch.exp(-m)) + self.eps
        out = num / denom  # [B, H, T, D]
        out = self.o_norm(out.float()).to(x.dtype)
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out_gate = torch.sigmoid(self.ogate(x))
        out = self.o_proj(out_gate * out)

        # The carried state is a detached summary, built outside autograd.
        with torch.no_grad():
            C_new, n_new, m_new = self._fold_state(
                f_cum_h, i_raw_h, k_h, v_h,
                (prev_C, prev_n, prev_m) if has_state else None)
        log_f_end = f_cum_h[..., -1]  # [B, H]
        prev_logf = cache.get("log_f_cum")
        if prev_logf is not None:
            log_f_end = log_f_end + prev_logf
        new_cache = {
            "log_f_cum": log_f_end.detach(),
            "C": C_new,
            "n": n_new,
            "m": m_new,
        }
        return out, new_cache
# ---------------------------------------------------------------------------
# Titans MAC — gated linear attention with persistent memory
# ---------------------------------------------------------------------------
class TitansMACLayer(nn.Module):
    """Memory-as-Context style gated linear attention with a carried state.

    Per head, the retention is ``r_t = 1 - sigmoid(alpha_proj(x))`` (kept
    >= 0.001 by the clamp) and the write is
    ``c_t = sigmoid(eta) * 0.1 * sigmoid(theta) * v_t``. The layer computes::

        S_t = r_t * S_{t-1} + k_t^T c_t
        o_t = q_t @ S_t

    This is evaluated in parallel form within a call. ``S`` (fp32,
    ``[B, H, D, D]``) is carried across calls in ``cache["state"]``.

    NOTE(review): ``memory_depth``, ``local_window`` and ``persistent_memory``
    are stored but not used by ``forward``. ``persistent_memory`` therefore
    receives no gradient (relevant e.g. for DDP's unused-parameter check).
    Confirm whether the MAC prefix is still meant to be wired in.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 memory_depth: int = 2, persistent_slots: int = 64,
                 local_window: int = 1024, norm_eps: float = 1e-6,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.memory_depth = int(memory_depth)
        self.local_window = int(local_window)
        self.persistent_slots = int(persistent_slots)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim
        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)
        self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        if self.persistent_slots > 0:
            self.persistent_memory = nn.Parameter(
                torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
        else:
            self.register_parameter("persistent_memory", None)
        self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)

    @staticmethod
    def _fold_state(log_retain_cum: torch.Tensor, k_h: torch.Tensor,
                    contrib: torch.Tensor, prev: Optional[torch.Tensor]
                    ) -> torch.Tensor:
        """Return ``S`` after absorbing this call's tokens.

        Args:
            log_retain_cum: ``[B, H, T]`` running sum of log-retention.
            k_h: ``[B, H, T, D]`` keys.
            contrib: ``[B, H, T, D]`` writes ``c_t``.
            prev: previous ``S`` or ``None``.

        Computes ``S = exp(R_T) * prev + sum_s exp(R_T - R_s) k_s^T c_s``.
        """
        B, H, T, D = k_h.shape
        if T == 0:
            return prev if prev is not None else k_h.new_zeros(B, H, D, contrib.shape[-1])
        last = log_retain_cum[..., -1:]  # [B, H, 1]
        w = (last - log_retain_cum).exp().unsqueeze(-1)  # [B, H, T, 1]
        state = torch.matmul((k_h * w).transpose(-1, -2), contrib)  # [B, H, D, D]
        if prev is not None:
            state = state + last.exp().unsqueeze(-1) * prev
        return state

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Mix ``x`` of shape ``[B, T, hidden_size]``; return ``(out, new_cache)``.

        ``cache["state"]``, when present, is the ``S`` returned by the previous
        call. Other keys in ``cache`` are copied into the returned dict
        unchanged.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        # Project once.
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        alpha = torch.sigmoid(self.alpha_proj(x))  # [B, T, H]
        eta = torch.sigmoid(self.eta_proj(x))
        theta = torch.sigmoid(self.theta_proj(x)) * 0.1
        q_h = q.permute(0, 2, 1, 3).to(torch.float32)
        k_h = k.permute(0, 2, 1, 3).to(torch.float32)
        v_h = v.permute(0, 2, 1, 3).to(torch.float32)
        alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
        eta_h = eta.permute(0, 2, 1).to(torch.float32)
        theta_h = theta.permute(0, 2, 1).to(torch.float32)
        # Causal forgetting decay built in log-space (log_retain <= 0).
        log_retain = torch.log1p(-alpha_h.clamp(max=0.999))
        log_retain_cum = log_retain.cumsum(dim=-1)  # [B, H, T]
        decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
        decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
        decay = decay.exp()  # 0 above diag
        contrib = (eta_h * theta_h).unsqueeze(-1) * v_h  # [B, H, T, D]
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay  # [B, H, T, T]
        out = torch.matmul(attn, contrib)  # [B, H, T, D]

        prev_state = cache.get("state") if cache else None
        if prev_state is not None:
            prev_state = prev_state.to(torch.float32)
            # Carried memory seen at step t, decayed by exp(R_t).
            out = out + torch.matmul(q_h * log_retain_cum.exp().unsqueeze(-1), prev_state)
        with torch.no_grad():
            new_state = self._fold_state(log_retain_cum, k_h, contrib, prev_state)

        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out = self.o_norm(out.to(x.dtype))
        new_cache = dict(cache) if cache else {}
        new_cache["state"] = new_state
        return self.o_proj(out), new_cache
# ---------------------------------------------------------------------------
# TSP Span Knot — fast vectorised energy
# ---------------------------------------------------------------------------
class TSPSpanKnotLayer(nn.Module):
    """TSP Span Knot: a GatedDeltaNet body plus a small learned energy nudge.

    The GatedDeltaNet output ``y`` is projected to five scalar energy terms
    per token. These are mixed with the learnable ``energy_weights`` into one
    scalar per token. That scalar, scaled by 0.01, is broadcast-added to every
    channel of ``y``.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, chunk_size: int = 64,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.gdn = GatedDeltaNetLayer(
            self.hidden_size, num_heads, head_dim,
            norm_eps=norm_eps, chunk_size=chunk_size, use_ternary=use_ternary,
        )
        # One fused projection yields all five energy terms at once.
        self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
        self.energy_weights = nn.Parameter(torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3]))
        # Populated by set_semantic_memory(); forward() does not read it.
        self._semantic_memory = None

    def set_semantic_memory(self, mem) -> None:
        """Attach an external semantic-memory object (not used by ``forward``)."""
        self._semantic_memory = mem

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        body, next_cache = self.gdn(x, cache=cache)
        terms = self.energy_proj(body)  # [B, T, 5]
        energy = (terms * self.energy_weights).sum(dim=-1, keepdim=True)  # [B, T, 1]
        # A 0.01 scale keeps the residual nudge (and its gradient) small.
        nudged = body + energy * 0.01
        return nudged, next_cache
# Public API of this module: the layer classes only. Underscore-prefixed
# helpers (mask caches, _make_linear, _gated_delta_chunkwise) are internal.
__all__ = [
    "SwiGLUMLP",
    "ShortConv1d",
    "GatedDeltaNetLayer",
    "MLSTMLayer",
    "TitansMACLayer",
    "TSPSpanKnotLayer",
]