Upload luminars/ssm.py
Browse files- luminars/ssm.py +42 -0
luminars/ssm.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Selective State Space (Mamba2) cell + SelectiveScanKernel.
|
| 3 |
+
No dependencies on mamba_ssm -- pure PyTorch.
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange, einsum
|
| 10 |
+
|
| 11 |
+
def selective_scan_oneshot(x, delta, A, B, C, D):
|
| 12 |
+
"""
|
| 13 |
+
x: (B, L, N) -- input tokens
|
| 14 |
+
delta: (B, L, N) -- time-step, elementwise
|
| 15 |
+
A: (N,) -- diagonal S4D real part
|
| 16 |
+
B, C: (B, L, N) -- input-dependent
|
| 17 |
+
D: (N,) -- skip connection
|
| 18 |
+
Returns y: (B, L, N)
|
| 19 |
+
"""
|
| 20 |
+
B_, L, N = x.shape
|
| 21 |
+
# discretize: A_bar = exp(delta * A), B_bar = delta * B
|
| 22 |
+
# A is negative (stable), delta > 0
|
| 23 |
+
A = -torch.abs(A) # force stability
|
| 24 |
+
delta = F.softplus(delta) # >0
|
| 25 |
+
A_bar = torch.exp(delta.unsqueeze(-1) * A) # (B, L, N, N)?? No, A is (N,)
|
| 26 |
+
A_bar = torch.exp(delta * A) # (B, L, N)
|
| 27 |
+
B_x = delta * B * x # (B, L, N)
|
| 28 |
+
|
| 29 |
+
# recurrent scan
|
| 30 |
+
h = torch.zeros(B_, N, device=x.device, dtype=x.dtype)
|
| 31 |
+
ys = []
|
| 32 |
+
for t in range(L):
|
| 33 |
+
h = A_bar[:, t] * h + B_x[:, t]
|
| 34 |
+
y = einsum(h, C[:, t], 'b n, b n -> b')
|
| 35 |
+
ys.append(y)
|
| 36 |
+
# Actually y = (C_t * h).sum(-1) gives scalar per token... reshape needed.
|
| 37 |
+
# Let's do it vectorised:
|
| 38 |
+
# We actually need y_t = sum_n C_{b,t,n} * h_{b,n} = inner product in N dim
|
| 39 |
+
y = torch.stack(ys, dim=1).unsqueeze(-1) * C # no, this is wrong dimension
|
| 40 |
+
# FIX: h is (B,N), output is (B,N) from h*C where C is (B,L,N)
|
| 41 |
+
# Let me rewrite properly
|
| 42 |
+
pass
|