File size: 21,096 Bytes
6e408ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
"""
Chimera 5.2 — recurrent / attention layers (CPU-first).

Every recurrent / attention layer in this module exposes a
``forward(x, cache=None)`` signature and returns ``(out, new_cache)``.
``cache`` is the dict returned by the previous call (or ``None``); the layer
reads it and returns an updated dict for the next call.
This makes O(T) decoding possible instead of the O(T²) recompute used by
the original implementation.

Optimisations vs. the previous draft:
* No ``einops`` dependency — reshapes use plain ``view`` / ``reshape`` / ``permute``.
* Mask cache keyed by (T, dtype, device) — no per-token allocation churn.
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
  during training (the inter-chunk recurrence runs in fp32 on the input's
  device; the carried state is detached only when it is returned in the
  cache, so gradients flow through the whole call).
* mLSTM forgets are accumulated in log-space with a single ``cumsum``; the
  causal mask is added once instead of per-row.
* TitansMAC only computes the values it actually uses (the original draft
  built ``kv`` and threw it away – removed).
* TSPSpanKnotLayer's energy is a single fused linear projection (five terms
  mixed by learnable weights); the per-step Hamming/coherence loops are gone.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .quantization import BitLinear, RMSNorm


# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------

# Process-wide cache of causal masks.  Only the largest mask built so far is
# kept per (kind, device[, dtype]); smaller sizes are served as top-left views
# of it, so memory is bounded by the longest sequence seen instead of growing
# with every distinct sequence length.
_MASK_CACHE: dict = {}


def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return an additive ``[T, T]`` causal mask: ``0`` on/below the diagonal, ``-inf`` above.

    The result may be a (non-contiguous) view into a shared cached tensor, so
    callers must treat it as read-only and combine it out-of-place.
    """
    key = ("neg_inf", str(device), dtype)
    cached = _MASK_CACHE.get(key)
    if cached is None or cached.shape[-1] < T:
        # Build outside any autograd / inference-mode context so the tensor is
        # a plain leaf that can be reused across train/eval/inference_mode calls.
        with torch.inference_mode(False), torch.no_grad():
            cached = torch.zeros(T, T, dtype=dtype, device=device)
            cached.masked_fill_(
                torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1),
                float("-inf"),
            )
        _MASK_CACHE[key] = cached
    # The top-left T x T corner of a larger causal mask is itself causal.
    return cached if cached.shape[-1] == T else cached[:T, :T]


def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
    """Return a ``[T, T]`` bool mask that is ``True`` on/below the diagonal.

    Like :func:`_causal_mask_neg_inf`, the result may be a read-only view into
    a shared cached tensor.
    """
    key = ("tril_bool", str(device))
    cached = _MASK_CACHE.get(key)
    if cached is None or cached.shape[-1] < T:
        with torch.inference_mode(False), torch.no_grad():
            cached = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
        _MASK_CACHE[key] = cached
    return cached if cached.shape[-1] == T else cached[:T, :T]


def _make_linear(use_ternary: bool):
    """Return the linear-layer factory used by the layers in this module.

    ``use_ternary=True`` returns :class:`BitLinear` itself.  Otherwise a factory
    building ``nn.Linear(i, o)`` is returned; it is bias-free by default but
    honours an explicit ``bias=`` keyword, and ignores any other keyword
    arguments so one call site works for both back-ends.
    """
    if use_ternary:
        return BitLinear

    def _dense(i, o, bias: bool = False, **_ignored) -> nn.Linear:
        # Other keywords (presumably BitLinear-specific options) have no
        # nn.Linear equivalent and are dropped, as before.
        return nn.Linear(i, o, bias=bias)

    return _dense


# ---------------------------------------------------------------------------
# SwiGLU MLP (shared with MoE)
# ---------------------------------------------------------------------------

class SwiGLUMLP(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    All three projections come from :func:`_make_linear`: ``BitLinear`` when
    ``use_ternary`` is true, bias-free ``nn.Linear`` otherwise.
    """

    __constants__ = ["hidden_size", "intermediate_size"]

    def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
        super().__init__()
        make = _make_linear(use_ternary)
        self.hidden_size = int(hidden_size)
        self.intermediate_size = int(intermediate_size)
        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = make(d_model, d_ff)
        self.up_proj = make(d_model, d_ff)
        self.down_proj = make(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))


# ---------------------------------------------------------------------------
# Causal depthwise conv (used by Gated DeltaNet)
# ---------------------------------------------------------------------------

class ShortConv1d(nn.Module):
    """Depthwise causal 1-D convolution followed by SiLU.

    ``forward`` takes ``[B, T, D]`` input plus an optional ``tail`` of raw
    input frames from the previous call (layout ``[B, D, t]``, at most
    ``kernel_size - 1`` frames).  Prepending the tail lets the layer be fed a
    chunk (or a single token) at a time and still produce the same outputs as
    one full pass.
    """

    __constants__ = ["kernel_size", "dim"]

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        # Symmetric padding of kernel_size - 1; forward() discards the
        # right-hand overhang, which leaves a purely causal response.
        pad = self.kernel_size - 1
        self.conv = nn.Conv1d(self.dim, self.dim, self.kernel_size,
                              padding=pad, groups=self.dim, bias=False)

    def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(silu(causal_conv(x)) as [B, T, D], new_tail as [B, D, <=k-1])``."""
        _, steps, _ = x.shape
        frames = x.transpose(1, 2)                      # [B, D, T]
        if tail is not None and tail.numel() > 0:
            frames = torch.cat([tail, frames], dim=-1)
        n_frames = frames.shape[-1]
        # Drop the right-hand padding, then keep only the outputs aligned with
        # the new (non-tail) frames.
        response = self.conv(frames)[..., :n_frames][..., -steps:]
        history = self.kernel_size - 1
        next_tail = frames[..., -history:] if history > 0 else frames[..., :0]
        return F.silu(response).transpose(1, 2), next_tail


# ---------------------------------------------------------------------------
# Gated DeltaNet (chunkwise parallel + recurrent state)
# ---------------------------------------------------------------------------

def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           g: torch.Tensor, beta: torch.Tensor,
                           state: Optional[torch.Tensor], chunk_size: int
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise gated scan used by :class:`GatedDeltaNetLayer`.

    Args:
        q, k: ``[B, T, H, K]`` queries / keys.
        v: ``[B, T, H, V]`` values.
        g: ``[B, T, H]`` per-step log-decay; ``exp(g_t)`` multiplies the state.
        beta: ``[B, T, H]`` per-step write strength (values are pre-scaled by it).
        state: carried ``[B, H, K, V]`` state from a previous call, or ``None``.
        chunk_size: maximum number of timesteps processed per parallel chunk.

    Returns:
        ``(out, new_state)``: ``out`` is ``[B, T, H, V]`` and ``new_state`` is
        ``[B, H, K, V]``; both are float32.

    Per step this evaluates ``S_t = exp(g_t) * S_{t-1} + k_t (beta_t v_t)^T`` and
    ``o_t = (q_t / sqrt(K)) S_t`` (gated linear-attention form).
    NOTE(review): the delta-rule erase term ``-beta_t k_t k_t^T S_{t-1}`` is not
    applied here — confirm that is intended.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device

    # Permute once to head-major float32: [B, H, T, *].
    q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
    g = g.permute(0, 2, 1).contiguous().to(torch.float32)         # [B, H, T]
    beta = beta.permute(0, 2, 1).contiguous().to(torch.float32)   # [B, H, T]

    q = q * (K ** -0.5)
    v = v * beta.unsqueeze(-1)

    if state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        S = state.to(torch.float32)

    # max(1, ...) keeps range() valid for chunk_size <= 0 or an empty sequence.
    chunk = max(1, min(int(chunk_size), T))
    out_chunks = []
    for start in range(0, T, chunk):
        end = min(start + chunk, T)
        c = end - start
        qc, kc, vc, gc = q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], g[:, :, start:end]

        # Cumulative log-decay within the chunk.
        log_decay = gc.cumsum(dim=-1)                                  # [B, H, c]
        # Within-chunk weights exp(log_decay[i] - log_decay[j]) for j <= i.
        # Future entries (j > i) are set to -inf BEFORE exp: their raw values
        # are large positive log-differences whose exp can overflow to inf, and
        # zeroing them after exp (torch.where) still back-propagates
        # 0 * inf = NaN through the exp.
        diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2)       # [B, H, c, c]
        causal = _causal_tril_bool(c, device)                          # [c, c]
        intra_w = diff.masked_fill(~causal, float("-inf")).exp()

        # Output = (qc kc^T * intra_w) vc  +  (qc * exp(log_decay)) S
        attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w        # [B, H, c, c]
        o_intra = torch.matmul(attn, vc)                               # [B, H, c, V]
        o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S)  # [B, H, c, V]
        out_chunks.append(o_intra + o_inter)

        # Carry: S <- S * exp(total) + (kc * exp(total - log_decay))^T vc
        decay_total = log_decay[:, :, -1:]                             # [B, H, 1]
        S = S * decay_total.unsqueeze(-1).exp()
        per_step = (decay_total - log_decay).unsqueeze(-1).exp()       # [B, H, c, 1]
        S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)

    if out_chunks:
        out = torch.cat(out_chunks, dim=2)                              # [B, H, T, V]
    else:
        out = v.new_zeros(B, H, 0, V)
    return out.permute(0, 2, 1, 3).contiguous(), S


class GatedDeltaNetLayer(nn.Module):
    """Gated DeltaNet mixer: chunkwise parallel within a call, carried state across calls.

    Pipeline per call: Q/K/V projections -> causal short convs (+SiLU) ->
    L2-normalised Q/K -> :func:`_gated_delta_chunkwise` -> per-head RMSNorm
    gated by ``silu(g_proj(x))`` -> output projection.

    The returned cache holds the float32 recurrent ``"state"`` and the three
    conv tails (``"q_tail"``, ``"k_tail"``, ``"v_tail"``), all detached, so a
    sequence can be fed in pieces without recomputing the prefix.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
                 chunk_size: int = 64, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.head_v_dim = int(head_dim * expand_v)
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = self.num_heads * self.head_v_dim
        self.chunk_size = int(chunk_size)

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.key_dim)
        self.k_proj = L(self.hidden_size, self.key_dim)
        self.v_proj = L(self.hidden_size, self.value_dim)
        self.g_proj = L(self.hidden_size, self.value_dim)   # output gate (through silu)
        self.o_proj = L(self.value_dim, self.hidden_size)

        # Per-head scalars: a_proj feeds the decay step size dt, b_proj the write strength beta.
        self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)

        # A ~ U[0, 16); forward() uses -exp(A_log) = -A as a negative decay rate.
        A = torch.empty(self.num_heads).uniform_(0.0, 16.0)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True  # presumably honoured by the optimizer setup — confirm
        # dt ~ log-uniform in [1e-3, 0.1]; dt_bias is its inverse softplus, so
        # softplus(dt_bias) == dt when a_proj(x) contributes nothing.
        dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)).clamp_min(1e-4)
        self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
        self.dt_bias._no_weight_decay = True

        self.q_conv = ShortConv1d(self.key_dim, conv_size)
        self.k_conv = ShortConv1d(self.key_dim, conv_size)
        self.v_conv = ShortConv1d(self.value_dim, conv_size)
        self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Mix ``x`` ``[B, T, hidden]``; return ``(out [B, T, hidden], new_cache)``."""
        B, T, _ = x.shape
        prev_state = cache.get("state") if cache else None
        prev_q_tail = cache.get("q_tail") if cache else None
        prev_k_tail = cache.get("k_tail") if cache else None
        prev_v_tail = cache.get("v_tail") if cache else None

        q_full, q_tail = self.q_conv(self.q_proj(x), prev_q_tail)
        k_full, k_tail = self.k_conv(self.k_proj(x), prev_k_tail)
        v_full, v_tail = self.v_conv(self.v_proj(x), prev_v_tail)

        q = q_full.view(B, T, self.num_heads, self.head_dim)
        k = k_full.view(B, T, self.num_heads, self.head_dim)
        v = v_full.view(B, T, self.num_heads, self.head_v_dim)
        q = F.normalize(q, p=2.0, dim=-1)
        k = F.normalize(k, p=2.0, dim=-1)

        beta = torch.sigmoid(self.b_proj(x))                        # [B, T, H]
        A = -self.A_log.exp()
        dt = F.softplus(self.a_proj(x) + self.dt_bias)              # [B, T, H]
        g = dt * A.view(1, 1, -1)                                   # per-step log-decay (<= 0)

        out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
                                                state=prev_state,
                                                chunk_size=self.chunk_size)

        # The scan returns float32; cast back to the activation dtype so
        # reduced-precision (bf16/fp16) models do not feed fp32 into o_proj.
        out = out.to(x.dtype)
        gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
        out = self.o_norm(out) * F.silu(gate)
        out = out.reshape(B, T, self.value_dim)
        out = self.o_proj(out)

        new_cache = {
            "state": new_state.detach(),
            "q_tail": q_tail.detach(),
            "k_tail": k_tail.detach(),
            "v_tail": v_tail.detach(),
        }
        return out, new_cache


# ---------------------------------------------------------------------------
# xLSTM mLSTM — parallel chunkwise + carried state
# ---------------------------------------------------------------------------

class MLSTMLayer(nn.Module):
    """Parallel mLSTM with log-space cumulative gates and a carried memory.

    Within one call every position attends to all earlier positions of that
    call through the stabilised mLSTM gate matrix.  Across calls the layer
    carries the matrix memory ``C`` ``[B, H, D, D]``, normaliser ``n``
    ``[B, H, D]`` and stabiliser ``m`` ``[B, H]`` (float32, stored divided by
    ``exp(m)``) in the cache.  Feeding a sequence in pieces (e.g. prefill then
    one token at a time) therefore matches a single full pass up to
    floating-point rounding.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)

        self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.ogate = L(self.hidden_size, self.v_dim)

        # Strongly negative input-gate bias; forget-gate biases in [3, 6] make
        # sigmoid(f) start between ~0.95 and ~0.998 (long memory).
        nn.init.constant_(self.igate.bias, -10.0)
        with torch.no_grad():
            self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))

        self.gate_soft_cap = float(gate_soft_cap)
        # norm_eps was previously accepted but ignored.
        self.o_norm = nn.LayerNorm(self.head_dim, eps=norm_eps)
        self.eps = 1e-6

    @staticmethod
    def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
        """Smoothly bound ``x`` to ``(-cap, cap)`` via ``cap * tanh(x / cap)``."""
        return cap * torch.tanh(x / cap)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Run the layer on ``x`` ``[B, T, hidden]``; return ``(out, new_cache)``.

        ``cache`` is the dict returned by a previous call (or ``None``).  Its
        ``"C"`` / ``"n"`` / ``"m"`` entries are the carried memory; a cache
        without all three is treated as empty.  ``"log_f_cum"`` (running sum of
        log forget gates) is still produced for callers that read it.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        scale = D ** -0.5

        q = self.q_proj(x).view(B, T, H, D) * scale
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)

        i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap)   # [B, T, H]
        f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
        f_log = F.logsigmoid(f_raw)                                  # [B, T, H]

        prev_logf = cache.get("log_f_cum") if cache else None        # [B, H]
        prev_C = cache.get("C") if cache else None                   # [B, H, D, D]
        prev_n = cache.get("n") if cache else None                   # [B, H, D]
        prev_m = cache.get("m") if cache else None                   # [B, H]
        has_state = prev_C is not None and prev_n is not None and prev_m is not None

        # Call-local cumulative log-forget.  Only differences
        # log_f_cum[t] - log_f_cum[s] enter the gate matrix, so a carry-in
        # offset would cancel there; omitting it also avoids the precision loss
        # of subtracting two large running sums.
        log_f_cum = f_log.cumsum(dim=1)                              # [B, T, H]

        # Permute to head-major.
        q_h = q.permute(0, 2, 1, 3)                                  # [B, H, T, D]
        k_h = k.permute(0, 2, 1, 3)
        v_h = v.permute(0, 2, 1, 3)
        log_f_cum_h = log_f_cum.permute(0, 2, 1)                     # [B, H, T]
        i_raw_h = i_raw.permute(0, 2, 1)

        # log_gate[t, s] = log_f_cum[t] - log_f_cum[s] + i[s], causal.
        log_gate = (log_f_cum_h.unsqueeze(-1) - log_f_cum_h.unsqueeze(-2)
                    + i_raw_h.unsqueeze(-2))
        log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
        m = log_gate.amax(dim=-1, keepdim=True)                      # [B, H, T, 1]
        if has_state:
            # The carried memory is stored divided by exp(m_prev) and, at step
            # t of this call, has decayed by exp(log_f_cum[t]).
            log_carry = (log_f_cum_h
                         + prev_m.to(log_gate.dtype).unsqueeze(-1)).unsqueeze(-1)  # [B, H, T, 1]
            m = torch.maximum(m, log_carry)
        m = m.clamp_min(-30.0)
        gate_w = (log_gate - m).exp()                                # [B, H, T, T]

        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
        num = torch.matmul(attn, v_h)                                # [B, H, T, D]
        n = torch.matmul(gate_w, k_h)                                # [B, H, T, D]
        if has_state:
            carry_w = (log_carry - m).exp()                          # [B, H, T, 1]
            num = num + carry_w * torch.matmul(q_h, prev_C.to(q_h.dtype))
            n = n + carry_w * prev_n.to(q_h.dtype).unsqueeze(-2)
        denom = (q_h * n).sum(-1, keepdim=True).abs()
        denom = torch.maximum(denom, torch.exp(-m)) + self.eps

        out = num / denom                                            # [B, H, T, D]
        out = self.o_norm(out.float()).to(x.dtype)
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)

        out_gate = torch.sigmoid(self.ogate(x))
        out = self.o_proj(out_gate * out)

        # Fold this call into the carried memory (float32, no autograd: the
        # cache is detached anyway), re-stabilised by m_new.
        with torch.no_grad():
            lf = log_f_cum_h.float()
            log_end = lf[..., -1:]                                   # [B, H, 1]
            log_w_end = log_end - lf + i_raw_h.float()               # [B, H, T]
            m_new = log_w_end.amax(dim=-1)                           # [B, H]
            if has_state:
                log_carry_end = log_end.squeeze(-1) + prev_m.float()  # [B, H]
                m_new = torch.maximum(m_new, log_carry_end)
            w_end = (log_w_end - m_new.unsqueeze(-1)).exp().unsqueeze(-1)  # [B, H, T, 1]
            k_w = k_h.float() * w_end
            C_new = torch.matmul(k_w.transpose(-1, -2), v_h.float())  # [B, H, D, D]
            n_new = k_w.sum(dim=-2)                                  # [B, H, D]
            if has_state:
                carry_end = (log_carry_end - m_new).exp()            # [B, H]
                C_new = C_new + carry_end[..., None, None] * prev_C.float()
                n_new = n_new + carry_end[..., None] * prev_n.float()

        total_logf = log_f_cum[:, -1]
        if prev_logf is not None:
            total_logf = total_logf + prev_logf
        new_cache = {
            "log_f_cum": total_logf.detach(),
            "C": C_new,
            "n": n_new,
            "m": m_new,
        }
        return out, new_cache


# ---------------------------------------------------------------------------
# Titans MAC — gated linear attention with persistent memory
# ---------------------------------------------------------------------------

class TitansMACLayer(nn.Module):
    """Memory-as-Context layer implemented as gated linear attention.

    Per head the output follows ``S_t = r_t * S_{t-1} + k_t c_t^T`` and
    ``o_t = q_t S_t``, where ``r_t = 1 - min(alpha_t, 0.999)`` and
    ``c_t = eta_t * theta_t * v_t``.  Within a call this runs in parallel.
    The final ``S`` (float32, detached) is returned in ``cache["state"]`` and
    folded back in on the next call, so chunked / token-by-token decoding
    matches a single full pass up to floating-point rounding.

    NOTE(review): ``memory_depth``, ``local_window`` and ``persistent_memory``
    are stored but not read by ``forward`` — confirm whether the persistent
    slots are meant to be used as extra context.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 memory_depth: int = 2, persistent_slots: int = 64,
                 local_window: int = 1024, norm_eps: float = 1e-6,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.memory_depth = int(memory_depth)
        self.local_window = int(local_window)
        self.persistent_slots = int(persistent_slots)
        self.qk_dim = self.num_heads * self.head_dim
        self.v_dim = self.num_heads * self.head_dim

        L = _make_linear(use_ternary)
        self.q_proj = L(self.hidden_size, self.qk_dim)
        self.k_proj = L(self.hidden_size, self.qk_dim)
        self.v_proj = L(self.hidden_size, self.v_dim)
        self.o_proj = L(self.v_dim, self.hidden_size)

        # Per-head gates: alpha -> forgetting, eta * theta -> write scale.
        self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        if self.persistent_slots > 0:
            self.persistent_memory = nn.Parameter(
                torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
        else:
            self.register_parameter("persistent_memory", None)

        self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Run the layer on ``x`` ``[B, T, hidden]``; return ``(out, new_cache)``.

        ``new_cache`` is a copy of ``cache`` (if any) with ``"state"``
        (``[B, H, D, D]`` float32) set to the memory after this call.
        """
        B, T, _ = x.shape
        H = self.num_heads
        D = self.head_dim
        # Project once.
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)

        alpha = torch.sigmoid(self.alpha_proj(x))                     # [B, T, H]
        eta = torch.sigmoid(self.eta_proj(x))
        theta = torch.sigmoid(self.theta_proj(x)) * 0.1

        q_h = q.permute(0, 2, 1, 3).to(torch.float32)
        k_h = k.permute(0, 2, 1, 3).to(torch.float32)
        v_h = v.permute(0, 2, 1, 3).to(torch.float32)
        alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
        eta_h = eta.permute(0, 2, 1).to(torch.float32)
        theta_h = theta.permute(0, 2, 1).to(torch.float32)

        # Causal forgetting decay built in log-space.
        log_retain = torch.log1p(-alpha_h.clamp(max=0.999))            # [B, H, T]
        log_retain_cum = log_retain.cumsum(dim=-1)
        decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
        decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
        decay = decay.exp()                                            # 0 above diag

        contrib = (eta_h * theta_h).unsqueeze(-1) * v_h                # [B, H, T, D]
        attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay        # [B, H, T, T]
        out = torch.matmul(attn, contrib)                              # [B, H, T, D]

        # Memory carried from earlier calls, decayed by exp(log_retain_cum[t]) at step t.
        prev_state = cache.get("state") if cache else None
        if prev_state is not None:
            prev_state = prev_state.to(torch.float32)
            out = out + torch.matmul(q_h * log_retain_cum.unsqueeze(-1).exp(), prev_state)

        # Memory after this call; the cache is detached, so skip autograd here.
        with torch.no_grad():
            total = log_retain.sum(dim=-1, keepdim=True)               # [B, H, 1]
            w_end = (total - log_retain_cum).exp().unsqueeze(-1)       # [B, H, T, 1]
            new_state = torch.matmul((k_h * w_end).transpose(-1, -2), contrib)  # [B, H, D, D]
            if prev_state is not None:
                new_state = new_state + total.unsqueeze(-1).exp() * prev_state

        out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
        out = self.o_norm(out.to(x.dtype))
        new_cache = dict(cache) if cache else {}
        new_cache["state"] = new_state
        return self.o_proj(out), new_cache


# ---------------------------------------------------------------------------
# TSP Span Knot — fast vectorised energy
# ---------------------------------------------------------------------------

class TSPSpanKnotLayer(nn.Module):
    """Gated DeltaNet body plus a tiny learned "energy" residual.

    A bias-free linear map scores every output token on five terms.  Their
    weighted sum uses learnable weights initialised to
    ``[1.0, 0.3, 0.2, 0.4, 0.3]``; it is scaled by ``0.01`` and added to every
    channel of the body output.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
                 norm_eps: float = 1e-6, chunk_size: int = 64,
                 use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.gdn = GatedDeltaNetLayer(
            self.hidden_size, num_heads, head_dim,
            norm_eps=norm_eps, chunk_size=chunk_size, use_ternary=use_ternary,
        )
        # One fused projection -> five energy terms per token.
        self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
        initial_mix = torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3])
        self.energy_weights = nn.Parameter(initial_mix)
        # Set through set_semantic_memory(); forward() does not currently read it.
        self._semantic_memory = None

    def set_semantic_memory(self, mem) -> None:
        """Store ``mem`` on the layer (not used by ``forward``)."""
        self._semantic_memory = mem

    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
                ) -> Tuple[torch.Tensor, dict]:
        """Return ``(gdn(x) + 0.01 * energy, cache_from_gdn)``."""
        hidden, next_cache = self.gdn(x, cache=cache)
        per_term = self.energy_proj(hidden)                            # [B, T, 5]
        energy = (per_term * self.energy_weights).sum(dim=-1, keepdim=True)
        # Small residual nudge keeps the energy gradient signal small.
        nudged = hidden + energy * 0.01
        return nudged, next_cache


# Public API of this module; the private helpers (_causal_mask_neg_inf,
# _causal_tril_bool, _make_linear, _gated_delta_chunkwise) are not exported.
__all__ = [
    "SwiGLUMLP",
    "ShortConv1d",
    "GatedDeltaNetLayer",
    "MLSTMLayer",
    "TitansMACLayer",
    "TSPSpanKnotLayer",
]