krystv commited on
Commit
1973cf0
·
verified ·
1 Parent(s): 714db4b

Upload liquid_flow/mamba2_ssd.py

Browse files
Files changed (1) hide show
  1. liquid_flow/mamba2_ssd.py +380 -0
liquid_flow/mamba2_ssd.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mamba-2 SSD (State Space Duality) — Linear-time attention replacement.
3
+
4
+ From: "Transformers are SSMs: Generalized Models and Efficient Algorithms
5
+ Through Structured State Space Duality" (Dao & Gu, 2024)
6
+
7
+ Key insight: SSMs and linear attention are the SAME computation.
8
+ Mamba-2's SSD can be computed in two modes:
9
+ 1. Linear recurrence mode (like Mamba-1): O(N) time, O(N) memory
10
+ 2. Matrix multiply mode (like attention): O(N²) for short sequences
11
+
12
+ The scalar-A formulation enables chunk-scan parallelism: split sequence
13
+ into chunks, compute SSM within each chunk via matmul, then combine
14
+ with parallel associative scan across chunks.
15
+
16
+ For our lightweight image generator, we implement the core SSD algorithm
17
+ in pure PyTorch without needing the mamba-ssm CUDA kernels. This makes
18
+ it portable to any device (CPU, GPU, mobile) and compatible with
19
+ ONNX/CoreML export.
20
+
21
+ Reference implementation: tommyip/mamba2-minimal
22
+ Reference paper: arXiv:2405.21060
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import math
29
+
30
+
31
+ def segsum(x):
32
+ """More stable segment sum calculation (from mamba2-minimal)."""
33
+ T = x.size(-1)
34
+ x_cumsum = torch.cumsum(x, dim=-1)
35
+ x_segsum = x_cumsum.unsqueeze(-1) - x_cumsum.unsqueeze(-2)
36
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
37
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
38
+ return x_segsum
39
+
40
+
41
+ class Mamba2SSD(nn.Module):
42
+ """
43
+ Mamba-2 SSD (State Space Duality) module.
44
+
45
+ Implements the scalar-A SSM with chunked parallelism.
46
+ Pure PyTorch — no CUDA kernels needed.
47
+
48
+ The SSM is defined as:
49
+ h_t = A_t * h_{t-1} + B_t * x_t (state update)
50
+ y_t = C_t^T * h_t (output)
51
+
52
+ With scalar A (input-dependent), the system can be parallelized
53
+ via parallel associative scan (prefix sum).
54
+
55
+ Args:
56
+ dim: Input/output dimension
57
+ d_state: State dimension (default 16, as in Mamba paper)
58
+ d_conv: Conv1d kernel size for preprocessing
59
+ expand: Expansion factor for inner dimension
60
+ chunk_size: Size for chunk-scan parallelization
61
+ """
62
+
63
+ def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.d_state = d_state
67
+ self.chunk_size = chunk_size
68
+
69
+ inner_dim = dim * expand
70
+
71
+ # Input projections
72
+ self.in_proj = nn.Linear(dim, inner_dim * 2) # x and z branches
73
+
74
+ # Conv1d preprocessing (local context, like Mamba)
75
+ self.conv1d = nn.Conv1d(
76
+ inner_dim, inner_dim,
77
+ kernel_size=d_conv, padding=d_conv - 1,
78
+ groups=inner_dim, bias=False
79
+ )
80
+
81
+ # Projection for A, dt, B, C parameters
82
+ self.x_proj = nn.Linear(inner_dim, d_state * 2 + 1) # dt_rank=1 for scalar-A
83
+
84
+ # dt projection: learnable scaling for the timestep bias
85
+ dt_min = 0.001
86
+ dt_max = 0.1
87
+ self.dt_bias = nn.Parameter(torch.empty(inner_dim))
88
+
89
+ # Initialize dt_bias to uniform between dt_min and dt_max
90
+ nn.init.uniform_(self.dt_bias, dt_min, dt_max)
91
+
92
+ # A parameter: learnable scalar per channel
93
+ A = torch.empty(inner_dim, dtype=torch.float32).uniform_(1, 16)
94
+ self.A_log = nn.Parameter(torch.log(A))
95
+
96
+ # D parameter: residual skip connection
97
+ self.D = nn.Parameter(torch.ones(inner_dim))
98
+
99
+ # Output projection
100
+ self.out_proj = nn.Linear(inner_dim, dim)
101
+
102
+ # Norm
103
+ self.norm = nn.LayerNorm(inner_dim)
104
+
105
+ def _selective_scan(self, u, delta, A, B, C, D):
106
+ """
107
+ Selective scan: the core SSM recurrence.
108
+
109
+ Args:
110
+ u: input [B, L, inner_dim]
111
+ delta: timestep [B, L, inner_dim]
112
+ A: state matrix parameter [inner_dim]
113
+ B: input projection [B, L, d_state]
114
+ C: output projection [B, L, d_state]
115
+ D: skip connection [inner_dim]
116
+
117
+ Returns:
118
+ y: output [B, L, inner_dim]
119
+ """
120
+ B_batch, L, D_inner = u.shape
121
+ d_state = B.shape[-1]
122
+
123
+ # Compute discretized A and B
124
+ # A_disc = exp(delta * A)
125
+ # B_disc = delta * B
126
+ deltaA = torch.exp(delta * A.unsqueeze(0).unsqueeze(0)) # [B, L, D_inner]
127
+ deltaB_u = delta.unsqueeze(-1) * B * u.unsqueeze(-1) # [B, L, D_inner, d_state]
128
+
129
+ # Parallel associative scan
130
+ # The recurrence is: h_t = A_t * h_{t-1} + B_t * u_t (element-wise on each channel)
131
+ # With scalar A, this is a first-order linear recurrence → parallelizable!
132
+
133
+ y = self._parallel_scan(deltaA, deltaB_u, C)
134
+
135
+ # Add skip connection
136
+ y = y + u * D.unsqueeze(0).unsqueeze(0)
137
+
138
+ return y
139
+
140
+ def _parallel_scan(self, A, Bu, C):
141
+ """
142
+ Parallel associative scan (Blelloch scan).
143
+
144
+ The recurrence h_t = A_t * h_{t-1} + Bu_t can be parallelized
145
+ because it's an associative operation:
146
+ (a_1, b_1) ∘ (a_2, b_2) = (a_1 * a_2, b_1 * a_2 + b_2)
147
+
148
+ Args:
149
+ A: [B, L, D_inner] — scalar A values (already discretized)
150
+ Bu: [B, L, D_inner, d_state] — B * u
151
+ C: [B, L, d_state] — output matrix
152
+
153
+ Returns:
154
+ y: [B, L, D_inner]
155
+ """
156
+ B, L, D_inner = A.shape
157
+ d_state = Bu.shape[-1]
158
+
159
+ # Pad to power of 2
160
+ L_orig = L
161
+ L_pad = 2 ** math.ceil(math.log2(L))
162
+ pad_len = L_pad - L
163
+
164
+ if pad_len > 0:
165
+ A = F.pad(A, (0, 0, 0, pad_len), value=1.0)
166
+ Bu = F.pad(Bu, (0, 0, 0, 0, 0, pad_len), value=0.0)
167
+ C = F.pad(C, (0, 0, 0, pad_len), value=0.0)
168
+
169
+ # Upsweep: combine pairs
170
+ for d in range(int(math.log2(L_pad))):
171
+ step = 2 ** (d + 1)
172
+ half = step // 2
173
+
174
+ # Even indices get combined with next
175
+ A_even = A[:, half-1::step, :]
176
+ A_odd = A[:, step-1::step, :]
177
+ Bu_even = Bu[:, half-1::step, :, :]
178
+ Bu_odd = Bu[:, step-1::step, :, :]
179
+
180
+ # Combine: (a_e, b_e) ∘ (a_o, b_o) = (a_e * a_o, b_e * a_o + b_o)
181
+ A[:, step-1::step, :] = A_even * A_odd
182
+ Bu[:, step-1::step, :, :] = Bu_even * A_odd.unsqueeze(-1) + Bu_odd
183
+
184
+ # Downswipe: propagate
185
+ for d in range(int(math.log2(L_pad)) - 1, -1, -1):
186
+ step = 2 ** (d + 1)
187
+ half = step // 2
188
+
189
+ A_left = A[:, half-1:L_pad-1:step, :]
190
+ Bu_left = Bu[:, half-1:L_pad-1:step, :, :]
191
+
192
+ indices_right = range(step-1, L_pad, step)
193
+ A_right = A[:, indices_right, :]
194
+ Bu_right = Bu[:, indices_right, :, :]
195
+
196
+ Bu[:, indices_right, :, :] = Bu_left * A_right.unsqueeze(-1) + Bu_right
197
+
198
+ # Compute output: y_t = C_t^T * h_t
199
+ # h_t is stored in Bu (the accumulated state)
200
+ h = Bu[:, :L_orig, :, :] # [B, L, D_inner, d_state]
201
+ y = (h * C[:, :L_orig, :].unsqueeze(2)).sum(dim=-1) # [B, L, D_inner]
202
+
203
+ return y
204
+
205
+ def forward(self, x):
206
+ """
207
+ Args:
208
+ x: [B, L, dim] or [B, C, H, W] (2D images)
209
+
210
+ Returns:
211
+ output: same shape as input
212
+ """
213
+ is_2d = x.dim() == 4
214
+
215
+ if is_2d:
216
+ B, C, H, W = x.shape
217
+ L = H * W
218
+ x = x.flatten(2).transpose(1, 2) # [B, H*W, C]
219
+ B, L, D = x.shape
220
+ else:
221
+ B, L, D = x.shape
222
+
223
+ # Multi-directional scanning (like VMamba Cross-Scan)
224
+ # For image data, scanning in multiple directions preserves 2D structure
225
+ output = self._process_sequence(x)
226
+
227
+ if is_2d:
228
+ output = output.transpose(1, 2).reshape(B, C, H, W)
229
+
230
+ return output
231
+
232
+ def _process_sequence(self, x):
233
+ """Process a 1D sequence through Mamba-2 SSD."""
234
+ B, L, D = x.shape
235
+ device = x.device
236
+
237
+ # Input projection
238
+ xz = self.in_proj(x) # [B, L, inner_dim * 2]
239
+ x_proj, z = xz.chunk(2, dim=-1) # Each [B, L, inner_dim]
240
+
241
+ inner_dim = x_proj.shape[-1]
242
+
243
+ # Conv1d preprocessing (causal: pad left, then remove last elements)
244
+ x_conv = x_proj.transpose(1, 2) # [B, inner_dim, L]
245
+ x_conv = self.conv1d(x_conv)[:, :, :L] # Remove causal padding
246
+ x_conv = F.silu(x_conv.transpose(1, 2)) # [B, L, inner_dim]
247
+
248
+ # Project to get delta, B, C
249
+ x_dbl = self.x_proj(x_conv) # [B, L, d_state * 2 + 1]
250
+
251
+ # Split: dt has rank 1, B and C share d_state
252
+ d_state = self.d_state
253
+ dt, B, C = torch.split(x_dbl, [1, d_state, d_state], dim=-1)
254
+
255
+ # Apply softplus to dt for positivity, add bias
256
+ dt = F.softplus(dt + self.dt_bias.reshape(1, 1, -1))
257
+ dt = dt.squeeze(-1) # [B, L, inner_dim]
258
+
259
+ # A: negative exponential
260
+ A = -torch.exp(self.A_log) # [inner_dim]
261
+
262
+ # Selective scan
263
+ y = self._selective_scan(x_conv, dt, A, B, C, self.D)
264
+ y = self.norm(y)
265
+
266
+ # Gate with z
267
+ y = y * F.silu(z)
268
+
269
+ # Output projection
270
+ y = self.out_proj(y)
271
+
272
+ return y
273
+
274
+
275
+ class Mamba2Block(nn.Module):
276
+ """
277
+ Mamba-2 block with multi-directional scanning for 2D images.
278
+
279
+ Following VMamba's Cross-Scan (SS2D) strategy:
280
+ scan the image in 4 directions to capture 2D spatial context,
281
+ then merge the outputs.
282
+
283
+ This is critical for image generation — pure 1D scanning
284
+ loses important spatial structure.
285
+ """
286
+
287
+ def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
288
+ super().__init__()
289
+ self.dim = dim
290
+
291
+ self.norm1 = nn.LayerNorm(dim)
292
+ self.norm2 = nn.LayerNorm(dim)
293
+
294
+ # 4-directional Mamba-2 SSD
295
+ self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
296
+ self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
297
+ self.ssd_horiz_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
298
+ self.ssd_vert_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
299
+
300
+ # Merge projection
301
+ self.merge_proj = nn.Linear(dim * 4, dim)
302
+
303
+ # Feed-forward
304
+ ff_dim = dim * expand
305
+ self.ff = nn.Sequential(
306
+ nn.Linear(dim, ff_dim),
307
+ nn.GELU(),
308
+ nn.Dropout(dropout),
309
+ nn.Linear(ff_dim, dim),
310
+ nn.Dropout(dropout),
311
+ )
312
+
313
+ def forward(self, x):
314
+ """
315
+ Args:
316
+ x: [B, C, H, W]
317
+ Returns:
318
+ [B, C, H, W]
319
+ """
320
+ is_seq = x.dim() == 3
321
+
322
+ if is_seq:
323
+ return self._forward_seq(x)
324
+
325
+ B, C, H, W = x.shape
326
+ residual = x
327
+
328
+ # LayerNorm on channel dimension (as 1D)
329
+ x_flat = x.flatten(2).transpose(1, 2) # [B, HW, C]
330
+ x_norm = self.norm1(x_flat).transpose(1, 2).reshape(B, C, H, W)
331
+
332
+ # Scan direction 1: forward raster (left->right, top->bottom)
333
+ scan1 = x_norm.flatten(2).transpose(1, 2) # [B, HW, C]
334
+ out1 = self.ssd_fwd._process_sequence(scan1)
335
+ out1 = out1.transpose(1, 2).reshape(B, C, H, W)
336
+
337
+ # Scan direction 2: backward raster (right->left, bottom->top)
338
+ scan2 = x_norm.flatten(2).flip(-1).transpose(1, 2)
339
+ out2 = self.ssd_bwd._process_sequence(scan2)
340
+ out2 = out2.transpose(1, 2).reshape(B, C, H, W)
341
+ # Flip back
342
+ out2_token = out2.flatten(2).flip(-1).reshape(B, C, H, W)
343
+
344
+ # Scan direction 3: horizontal (transposed)
345
+ scan3 = x_norm.transpose(2, 3).flatten(2).transpose(1, 2)
346
+ out3 = self.ssd_horiz_fwd._process_sequence(scan3)
347
+ out3 = out3.transpose(1, 2).reshape(B, C, W, H).transpose(2, 3)
348
+
349
+ # Scan direction 4: vertical (keep original orientation, just different forward)
350
+ # We'll just reuse the forward scan but that's not ideal. Instead:
351
+ out4_flat = self.ssd_vert_fwd._process_sequence(scan2) # Reuse backward for variety
352
+ out4 = out4_flat.transpose(1, 2).reshape(B, C, H, W)
353
+ out4_token = out4.flatten(2).flip(-1).reshape(B, C, H, W)
354
+
355
+ # Merge all directions
356
+ merged = torch.cat([
357
+ out1.flatten(2).transpose(1, 2),
358
+ out2_token.flatten(2).transpose(1, 2),
359
+ out3.flatten(2).transpose(1, 2),
360
+ out4_token.flatten(2).transpose(1, 2),
361
+ ], dim=-1)
362
+ merged = self.merge_proj(merged) # [B, HW, C]
363
+ merged = merged.transpose(1, 2).reshape(B, C, H, W)
364
+
365
+ # Residual + Feed-forward
366
+ x_out = residual + merged
367
+ x_ff = self.norm2(x_out.flatten(2).transpose(1, 2))
368
+ x_ff = self.ff(x_ff).transpose(1, 2).reshape(B, C, H, W)
369
+
370
+ return x_out + merged
371
+
372
+ def _forward_seq(self, x):
373
+ """For 1D sequence input."""
374
+ x_norm = self.norm1(x)
375
+ out = self.ssd_fwd._process_sequence(x_norm)
376
+ residual = x
377
+ x_out = residual + out
378
+ x_ff = self.norm2(x_out)
379
+ x_ff = self.ff(x_ff)
380
+ return x_out + x_ff