asdf98 committed
Commit 5e89621 · verified · 1 Parent(s): cbb87e4

Upload luminars/ssm.py

Files changed (1)
luminars/ssm.py +71 -33
luminars/ssm.py CHANGED
@@ -1,42 +1,80 @@
  """
- Selective State Space (Mamba2) cell + SelectiveScanKernel.
- No dependencies on mamba_ssm -- pure PyTorch.
  """
  import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from einops import rearrange, einsum

- def selective_scan_oneshot(x, delta, A, B, C, D):
      """
-     x: (B, L, N) -- input tokens
-     delta: (B, L, N) -- time-step, elementwise
-     A: (N,) -- diagonal S4D real part
-     B, C: (B, L, N) -- input-dependent
-     D: (N,) -- skip connection
-     Returns y: (B, L, N)
      """
-     B_, L, N = x.shape
-     # discretize: A_bar = exp(delta * A), B_bar = delta * B
-     # A is negative (stable), delta > 0
-     A = -torch.abs(A)  # force stability
-     delta = F.softplus(delta)  # >0
-     A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, N, N)?? No, A is (N,)
-     A_bar = torch.exp(delta * A)  # (B, L, N)
-     B_x = delta * B * x  # (B, L, N)
-
-     # recurrent scan
-     h = torch.zeros(B_, N, device=x.device, dtype=x.dtype)
-     ys = []
-     for t in range(L):
-         h = A_bar[:, t] * h + B_x[:, t]
-         y = einsum(h, C[:, t], 'b n, b n -> b')
-         ys.append(y)
-     # Actually y = (C_t * h).sum(-1) gives scalar per token... reshape needed.
-     # Let's do it vectorised:
-     # We actually need y_t = sum_n C_{b,t,n} * h_{b,n} = inner product in N dim
-     y = torch.stack(ys, dim=1).unsqueeze(-1) * C  # no, this is wrong dimension
-     # FIX: h is (B,N), output is (B,N) from h*C where C is (B,L,N)
-     # Let me rewrite properly
-     pass
  """
+ Lightweight Selective Linear Recurrent Unit (SLRU) -- RWKV/Mamba hybrid.
+ No heavy deps. Pure PyTorch. Linear O(n) in seq len.
  """
  import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

+
+ def rmsnorm(x):
+     # Normalize by the root-mean-square over the last dim (no learned gain)
+     return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
+
+
+ class SiluGLU(nn.Module):
+     """Gated MLP: W_down(SiLU(x·W_gate) ⊙ (x·W_up))"""
+     def __init__(self, dim_in, dim_out=None, expand=2):
+         super().__init__()
+         dim_out = dim_out or dim_in
+         hidden = int(dim_in * expand)
+         self.W_gate = nn.Linear(dim_in, hidden, bias=False)
+         self.W_up = nn.Linear(dim_in, hidden, bias=False)
+         self.W_down = nn.Linear(hidden, dim_out, bias=False)
+     def forward(self, x):
+         return self.W_down(F.silu(self.W_gate(x)) * self.W_up(x))
+
+
+ class SelectiveLRU(nn.Module):
      """
+     Simplified selective linear recurrent cell.
+         h_t = decay_t * h_{t-1} + (1 - decay_t) * (x_t · B_proj)
+         y_t = C_proj(h_t) + D_skip * x_t
+
+     Key: B_t, C_t, decay_t are ALL input-dependent (selective).
+     Merges RWKV's time-mixing with Mamba's selective SSM in a tiny form.
      """
+     def __init__(self, dim, d_state=64, expand=2):
+         super().__init__()
+         self.dim = dim
+         self.d_state = d_state
+         self.expand = expand
+         hidden = dim * expand
+
+         # Fused input projection: input -> [B, C, delta, skip]
+         self.in_proj = nn.Linear(dim, hidden * 4, bias=False)
+
+         # State transition
+         self.W_B = nn.Linear(hidden, d_state)                  # input -> state
+         self.W_C = nn.Linear(d_state, hidden)                  # state -> output
+         self.W_delta = nn.Linear(hidden, d_state, bias=False)  # time-step into state space
+         self.log_A = nn.Parameter(torch.randn(d_state))        # stable: decay uses -exp(log_A)
+         self.D = nn.Parameter(torch.randn(hidden))             # skip connection
+
+         # Output gate
+         self.out_gate = nn.Linear(dim, hidden, bias=False)
+         self.out_proj = nn.Linear(hidden, dim, bias=False)
+
+     def forward(self, x):
+         """x: (B, L, dim) -> y: (B, L, dim)"""
+         Bsz, L, _ = x.shape
+
+         # Input-dependent gates
+         gates = self.in_proj(x)                               # (B, L, hidden*4)
+         B_gate, C_gate, delta, skip = gates.chunk(4, dim=-1)  # each (B, L, hidden)
+
+         # Selective parameters (per-token, per-channel)
+         B_t = torch.tanh(B_gate)                              # bounded selective B
+         C_t = torch.tanh(C_gate)                              # bounded selective C
+         delta_t = F.softplus(self.W_delta(delta))             # positive time-step, (B, L, d_state)
+         decay = torch.exp(-delta_t * torch.exp(self.log_A))   # (B, L, d_state), in (0, 1)
+
+         # Project the gated input into state space once, outside the loop
+         Bx = self.W_B(B_t * skip)                             # (B, L, d_state)
+
+         # Recurrent scan over time (vectorized over batch and state channels)
+         state = torch.zeros(Bsz, self.d_state, device=x.device, dtype=x.dtype)
+         states = []
+         for t in range(L):
+             state = decay[:, t] * state + (1.0 - decay[:, t]) * Bx[:, t]
+             states.append(state)
+         h = torch.stack(states, dim=1)                        # (B, L, d_state)
+
+         # Readout: selective C, skip connection, then gated output projection
+         y = C_t * self.W_C(h) + self.D * skip                 # (B, L, hidden)
+         y = y * F.silu(self.out_gate(x))                      # input-conditioned output gate
+         return self.out_proj(y)                               # (B, L, dim)
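For context on the + side's decay: with per-channel rate a = exp(log_A) > 0 and step delta = softplus(·) > 0, decay = exp(-delta · a) always lies in (0, 1), so each state update is a convex blend of the previous state and the projected input. A quick standalone sanity check (values illustrative, not from this repo):

import torch
import torch.nn.functional as F

log_A = torch.tensor([0.0, 1.0])              # two example state channels
delta = F.softplus(torch.tensor(0.5))         # positive step, ~0.974
decay = torch.exp(-delta * torch.exp(log_A))  # exp(-delta * a) with a > 0
print(decay)                                  # tensor([0.3777, 0.0708]) -- larger a forgets faster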
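And a minimal smoke test for the rewritten cell, a sketch assuming luminars/ssm.py is importable as uploaded in this commit (dims are illustrative):

import torch
from luminars.ssm import SelectiveLRU  # path per this commit

torch.manual_seed(0)
cell = SelectiveLRU(dim=128, d_state=64, expand=2)
x = torch.randn(2, 16, 128)   # (batch, seq_len, dim)
y = cell(x)
assert y.shape == x.shape     # the cell is shape-preserving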