asdf98 committed
Commit bc114d4 · verified · 1 Parent(s): 5e89621

Upload luminars/ssm.py

Files changed (1)
  1. luminars/ssm.py +88 -55
luminars/ssm.py CHANGED
@@ -1,6 +1,7 @@
  """
- Lightweight Selective Linear Recurrent Unit (SLRU) -- RWKV/Mamba hybrid.
- No heavy deps. Pure PyTorch. Linear O(n) in seq len.
  """
  import math
  import torch
@@ -8,73 +9,105 @@ import torch.nn as nn
  import torch.nn.functional as F


- def rmsnorm(x):
-     # RMS norm: x / sqrt(mean(x^2) + eps)
-     return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)


- class SiluGLU(nn.Module):
-     """Gated MLP: SiLU(x·W_gate) ⊙ (x·W_up) · W_down"""
-     def __init__(self, dim_in, dim_out=None, expand=2):
          super().__init__()
-         dim_out = dim_out or dim_in
-         hidden = int(dim_in * expand)
-         self.W_gate = nn.Linear(dim_in, hidden, bias=False)
-         self.W_up = nn.Linear(dim_in, hidden, bias=False)
-         self.W_down = nn.Linear(hidden, dim_out, bias=False)

      def forward(self, x):
-         return self.W_down(F.silu(self.W_gate(x)) * self.W_up(x))


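For the gated MLP above, a quick shape sketch with hypothetical sizes dim_in = 4, expand = 2:

    x: (..., 4) -> W_gate, W_up: (..., 8) -> SiLU-gated elementwise product: (..., 8) -> W_down: (..., 4)
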
- class SelectiveLRU(nn.Module):
      """
-     Simplified selective linear recurrent cell:
-         h_t = decay_t * h_{t-1} + (1 - decay_t) * (x_t · B_proj)
-         y_t = C_proj(h_t) + D_skip * x_t

-     Key: B_t, C_t, decay_t are ALL input-dependent (selective).
-     Merges RWKV's time-mixing with Mamba's selective SSM in a tiny form.
      """
-     def __init__(self, dim, d_state=64, expand=2):
          super().__init__()
          self.dim = dim
          self.d_state = d_state
-         self.expand = expand
-         hidden = dim * expand

-         # Fused input projection: input -> [B_gate, C_gate, delta, value]
-         self.in_proj = nn.Linear(dim, hidden * 4, bias=False)

-         # State transition
-         self.W_B = nn.Linear(hidden, d_state)             # input -> state
-         self.W_C = nn.Linear(d_state, hidden)             # state -> output
-         self.log_A = nn.Parameter(torch.randn(d_state))   # decay rates: exp(log_A) > 0
-         self.D = nn.Parameter(torch.randn(hidden))        # skip connection

-         # Output gate
-         self.out_gate = nn.Linear(dim, hidden, bias=False)
-         self.out_proj = nn.Linear(hidden, dim, bias=False)

      def forward(self, x):
-         """x: (B, L, dim) -> y: (B, L, dim)"""
-         bsz, L, _ = x.shape

-         # Input-dependent gates
-         gates = self.in_proj(x)                                 # (B, L, hidden*4)
-         B_gate, C_gate, delta, value = gates.chunk(4, dim=-1)   # each (B, L, hidden)

-         # Selective parameters (per-token)
-         B_t = torch.tanh(B_gate)              # bound selective B
-         C_t = torch.tanh(C_gate)              # bound selective C
-         delta_t = F.softplus(delta).mean(-1)  # pool to one positive step per token, (B, L)
-         # per-state decay in (0, 1): exp(-delta_t * exp(log_A)) -> (B, L, d_state)
-         decay = torch.exp(-delta_t.unsqueeze(-1) * torch.exp(self.log_A))

-         # Sequential recurrent scan over tokens; state: (B, d_state)
-         state = x.new_zeros(bsz, self.d_state)
-         outputs = []
-         for t in range(L):
-             inp = self.W_B(B_t[:, t] * value[:, t])               # input -> state
-             state = decay[:, t] * state + (1 - decay[:, t]) * inp
-             outputs.append(C_t[:, t] * self.W_C(state) + self.D * value[:, t])
-         y = torch.stack(outputs, dim=1)                           # (B, L, hidden)

-         # Output gate, then project back to model dim
-         return self.out_proj(y * F.silu(self.out_gate(x)))
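For intuition, the removed recurrence is exponential smoothing with an input-dependent rate. A scalar sketch with hypothetical values decay_t = 0.9 and x_t · B_proj = 1.0 at every step:

    h_1 = 0.9 * 0.00 + 0.1 * 1.0 = 0.100
    h_2 = 0.9 * 0.10 + 0.1 * 1.0 = 0.190
    h_3 = 0.9 * 0.19 + 0.1 * 1.0 = 0.271   (h_t -> 1.0 as t grows)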

  """
+ Spatial Recurrent Block (SRB) -- inspired by RWKV + VMamba-UNet.
+ Uses depthwise conv for spatial token-shift and channel-wise decay mixing.
+ Pure PyTorch, no heavy deps.
  """
  import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F


+ def rmsnorm(x, eps=1e-6):
+     # RMS norm: x / sqrt(mean(x^2) + eps)
+     return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-6):
          super().__init__()
+         self.eps = eps
+         self.gamma = nn.Parameter(torch.ones(dim))

      def forward(self, x):
+         # x: (..., dim); norm / sqrt(dim) equals sqrt(mean(x^2)), the RMS
+         norm = x.norm(2, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])
+         return self.gamma * x / (norm + self.eps)


+ class DropPath(nn.Module):
+     """Stochastic depth: randomly drops the residual branch per sample."""
+     def __init__(self, p=0.0):
+         super().__init__()
+         self.p = p
+
+     def forward(self, x):
+         if self.p == 0.0 or not self.training:
+             return x
+         keep = 1.0 - self.p
+         mask = x.new_empty(x.shape[0], *([1] * (x.dim() - 1))).bernoulli_(keep)
+         return x * mask / keep


+ class SpatialRecurrentBlock(nn.Module):
      """
+     A block that:
+     1. Token-shifts spatially with a 3x3 depthwise conv (spatial mixing)
+     2. Applies channel-wise decay-mixing (RWKV time-mix equivalent)
+     3. Returns residual output

+     No sequential token scan: the conv output stands in for the previous
+     token, so the decay mixing runs in parallel over all positions.
+     Spatial dims are folded into batch.
      """
+     def __init__(self, dim, d_state=64, drop_path=0.0):
          super().__init__()
          self.dim = dim
          self.d_state = d_state

+         # Spatial token shift (depthwise 3x3 conv)
+         self.spatial_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
+         self.spatial_norm = RMSNorm(dim)

+         # Input-dependent selective projections -> [B, C, decay]
+         self.x_proj_in = nn.Linear(dim, d_state * 2 + 1, bias=False)
+         self.x_proj_A = nn.Parameter(torch.arange(d_state).float() * -math.log(10000) / d_state)  # S4D init

+         # State-to-output
+         self.state_out = nn.Linear(d_state, dim, bias=False)
+         self.D = nn.Parameter(torch.ones(dim))  # skip

+         # Post-MLP
+         self.mlp = nn.Sequential(
+             RMSNorm(dim),
+             nn.Linear(dim, dim * 2),
+             nn.GELU(),
+             nn.Linear(dim * 2, dim),
+         )

+         # Drop path (stochastic depth)
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

      def forward(self, x):
+         """
+         x: (B, C, H, W) -> (B, C, H, W)
+         """
+         B, C, H, W = x.shape
+         shortcut = x

+         # --- SPATIAL TOKEN SHIFT ---
+         # Scanning H*W pixel tokens in raster order would be a slow Python
+         # loop (1024 steps for a 32x32 map). Instead, the depthwise conv
+         # supplies each pixel's neighborhood as its "previous token":
+         # RWKV's time-mixing generalized to 2D, parallel over positions.
+         prev = self.spatial_conv(x)  # (B, C, H, W)

+         # Channels-last, spatial positions folded into batch: (B*H*W, C)
+         x_flat = x.permute(0, 2, 3, 1).reshape(-1, C)
+         prev_flat = self.spatial_norm(prev.permute(0, 2, 3, 1).reshape(-1, C))

+         # --- SELECTIVE DECAY MIXING (Mamba-style parameters) ---
+         # Per-pixel selectivity
+         params = self.x_proj_in(x_flat)  # (BHW, d_state*2 + 1)
+         B_cur, C_cur, delta_log = params.split([self.d_state, self.d_state, 1], dim=-1)
+         delta = F.softplus(delta_log)    # (BHW, 1), positive time-step

+         # Discretize A
+         A = -torch.exp(self.x_proj_A)    # negative for stability
+         A_bar = torch.exp(delta * A)     # (BHW, d_state), decay in (0, 1)

+         # One decayed update per pixel: the shifted context plays h_{t-1},
+         # the current pixel plays x_t -- no sequential scan needed.
+         B_prev = self.x_proj_in(prev_flat)[:, :self.d_state]
+         state = A_bar * B_prev + (1.0 - A_bar) * B_cur  # (BHW, d_state)

+         # State-to-output plus skip, back to (B, C, H, W)
+         y = self.state_out(C_cur * state) + self.D * x_flat
+         y = y.view(B, H, W, C).permute(0, 3, 1, 2)
+         x = shortcut + self.drop_path(y)

+         # Post-MLP (channels-last), residual
+         h = self.mlp(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+         return x + self.drop_path(h)
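
A minimal smoke test for the new block, assuming toy sizes (dim = 32, an 8x8 feature map; all values below are illustrative, not part of the upload):

    block = SpatialRecurrentBlock(dim=32, d_state=16, drop_path=0.1)
    feats = torch.randn(2, 32, 8, 8)    # (B, C, H, W)
    out = block(feats)
    assert out.shape == feats.shape     # residual block preserves shape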