asdf98 commited on
Commit
cbb87e4
·
verified ·
1 Parent(s): 2fd257b

Upload luminars/ssm.py

Browse files
Files changed (1) hide show
  1. luminars/ssm.py +42 -0
luminars/ssm.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Selective State Space (Mamba2) cell + SelectiveScanKernel.
3
+ No dependencies on mamba_ssm -- pure PyTorch.
4
+ """
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, einsum
10
+
11
+ def selective_scan_oneshot(x, delta, A, B, C, D):
12
+ """
13
+ x: (B, L, N) -- input tokens
14
+ delta: (B, L, N) -- time-step, elementwise
15
+ A: (N,) -- diagonal S4D real part
16
+ B, C: (B, L, N) -- input-dependent
17
+ D: (N,) -- skip connection
18
+ Returns y: (B, L, N)
19
+ """
20
+ B_, L, N = x.shape
21
+ # discretize: A_bar = exp(delta * A), B_bar = delta * B
22
+ # A is negative (stable), delta > 0
23
+ A = -torch.abs(A) # force stability
24
+ delta = F.softplus(delta) # >0
25
+ A_bar = torch.exp(delta.unsqueeze(-1) * A) # (B, L, N, N)?? No, A is (N,)
26
+ A_bar = torch.exp(delta * A) # (B, L, N)
27
+ B_x = delta * B * x # (B, L, N)
28
+
29
+ # recurrent scan
30
+ h = torch.zeros(B_, N, device=x.device, dtype=x.dtype)
31
+ ys = []
32
+ for t in range(L):
33
+ h = A_bar[:, t] * h + B_x[:, t]
34
+ y = einsum(h, C[:, t], 'b n, b n -> b')
35
+ ys.append(y)
36
+ # Actually y = (C_t * h).sum(-1) gives scalar per token... reshape needed.
37
+ # Let's do it vectorised:
38
+ # We actually need y_t = sum_n C_{b,t,n} * h_{b,n} = inner product in N dim
39
+ y = torch.stack(ys, dim=1).unsqueeze(-1) * C # no, this is wrong dimension
40
+ # FIX: h is (B,N), output is (B,N) from h*C where C is (B,L,N)
41
+ # Let me rewrite properly
42
+ pass