asdf98 committed on
Commit 192c527 · verified · 1 Parent(s): b43a040

Upload liqmamba/mamba2_ssd.py

Files changed (1):
  liqmamba/mamba2_ssd.py +235 -200
liqmamba/mamba2_ssd.py CHANGED
@@ -1,270 +1,208 @@
  """
  Mamba-2 SSD Block with Liquid (CfC) Gating

- Implements the Structured State Space Duality (SSD) from Mamba-2 paper
- (Dao & Gu, 2024) with CfC-based gating instead of standard SiLU.
-
- The SSD framework unifies SSMs and attention through structured
- semiseparable matrices, giving us:
- - Linear-time training (no quadratic attention)
- - Parallelizable scans (no sequential loops at train time)
- - 2-8x faster than original Mamba
-
- For 2D images: we use multi-directional scanning patterns
- with learnable end-of-line tokens (from DiM paper).
  """

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from typing import Optional

  from .cfc import CfCGate


  class Mamba2SSDBlock(nn.Module):
      """
-     Single Mamba-2 SSD block with CfC liquid gating.
-
-     Architecture (per Mamba-2):
-     x -> Norm -> [in_proj -> conv1d -> SiLU] -> SSD scan -> CfC-gate -> out_proj
-
-     Key changes from standard Mamba-2:
-     - SiLU activation replaced with CfC-gate for adaptive per-token computation
-     - Optional CfC state modulation in the SSD path

      Args:
          dim: Hidden dimension
-         d_state: SSM state dimension (default 16 for lightweight)
          d_conv: Convolution kernel size
-         expand: Expansion factor for inner dimension
-         n_groups: Number of groups for head structure (like GQA)
-         chunk_size: Scan chunk size for efficient computation
-         use_cfc_modulation: Whether to apply CfC to SSM state transitions
      """
-
-     def __init__(
-         self,
-         dim: int,
-         d_state: int = 16,
-         d_conv: int = 4,
-         expand: int = 2,
-         n_groups: int = 1,
-         chunk_size: int = 256,
-         use_cfc_modulation: bool = True,
-         dropout: float = 0.0,
-     ):
          super().__init__()
-
          self.dim = dim
          self.d_state = d_state
-         self.d_conv = d_conv
-         self.expand = expand
          self.inner_dim = dim * expand
          self.n_groups = n_groups
-         self.chunk_size = chunk_size
          self.use_cfc_modulation = use_cfc_modulation

-         # Layer normalization (RMSNorm style for efficiency)
          self.norm = nn.RMSNorm(dim)

-         # Input projection: x -> (x, z) where z is the gate branch
          self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)

-         # 1D convolution for local feature mixing
-         self.conv1d = nn.Conv1d(
-             in_channels=self.inner_dim,
-             out_channels=self.inner_dim,
-             kernel_size=d_conv,
-             groups=self.inner_dim,
-             padding=d_conv - 1,
-         )

-         # SSD parameters
-         # A: diagonal state transition (learned per-head)
-         # Using log-space for stability
-         self.A_log = nn.Parameter(
-             torch.randn(self.inner_dim // n_groups, d_state) * 0.01
-         )
-         # D: skip connection parameter
-         self.D = nn.Parameter(torch.ones(self.inner_dim // n_groups))

-         # Projections for B and C (input-dependent)
          dt_rank = max(1, dim // 16)
-         self.dt_proj = nn.Linear(dim, self.inner_dim)
-         self.B_proj = nn.Linear(dim, d_state * n_groups, bias=False)
-         self.C_proj = nn.Linear(dim, d_state * n_groups, bias=False)

-         # CfC gate (replaces SiLU)
-         if use_cfc_modulation:
-             self.cfc_gate = CfCGate(self.inner_dim)
-         else:
-             self.gate_act = nn.SiLU()

-         # Output projection
          self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
-
-         # Dropout
          self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

-         # For 2D scan patterns: learnable EOL (end-of-line) tokens
-         self.register_buffer("init_conv_state", None)
-
-     def _apply_conv1d(self, x: torch.Tensor) -> torch.Tensor:
-         """Apply 1D causal convolution with proper padding handling."""
-         # x: (B, L, inner_dim) -> (B, inner_dim, L)
-         x = x.transpose(1, 2)
          x = self.conv1d(x)
-         # Remove extra padding (causal: keep only first L outputs)
-         x = x[..., :x.shape[-1] - (self.d_conv - 1)]
-         # Apply SiLU activation (kept here as per Mamba-2 spec)
          x = F.silu(x)
-         # (B, inner_dim, L) -> (B, L, inner_dim)
-         x = x.transpose(1, 2)
-         return x
-
-     def _selective_scan(self, u: torch.Tensor, delta: torch.Tensor,
-                         A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                         D: torch.Tensor) -> torch.Tensor:
          """
-         Selective scan operation (SSD core).

-         This is the key computation that replaces attention.
-         Uses chunked parallel scan for O(N) complexity.

-         Args:
-             u: Input (B, L, inner_dim)
-             delta: Time step (B, L, inner_dim)
-             A: State matrix (inner_dim//n_groups, d_state)
-             B: Input projection (B, L, n_groups*d_state)
-             C: Output projection (B, L, n_groups*d_state)
-             D: Skip connection (inner_dim//n_groups,)
-         Returns:
-             y: Output (B, L, inner_dim)
          """
-         B, L, D = u.shape
          G = self.n_groups
          N = self.d_state
-         DG = D // G  # dim per group

          # Reshape for grouped processing
-         u = u.view(B, L, G, DG)
-         delta = delta.view(B, L, G, DG)
-         B = B.view(B, L, G, N)
-         C = C.view(B, L, G, N)
-
-         # Discretize A: A_bar = exp(delta * A)
-         # A_log shape: (DG, N), delta shape: (B, L, G, DG)
-         A = -torch.exp(A_log)  # (DG, N) - keep A negative for stability
-         deltaA = torch.exp(delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, G, DG, N)
-         deltaB = delta.unsqueeze(-1) * B.unsqueeze(-2)  # (B, L, G, DG, N)
-
-         # Chunked parallel scan (associative scan)
-         # For simplicity and Colab compatibility, we use a
-         # memory-efficient sequential scan rather than requiring
-         # custom CUDA kernels
-         y = self._parallel_scan(u, deltaA, deltaB, C)

          # Add skip connection
          y = y + u * D.view(1, 1, G, DG)
-
-         return y.view(B, L, D)

-     def _parallel_scan(self, u: torch.Tensor, deltaA: torch.Tensor,
-                        deltaB: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
          """
-         Efficient parallel scan implementation using PyTorch operations.

-         This implements the SSD (State Space Duality) scan that is:
-         - Parallelizable (no sequential dependency at train time)
-         - Linear in sequence length (O(N) instead of O(N²))

-         Uses the matrix form: y_i = sum_{j<=i} C_i * A_{i:j} * B_j * u_j
          """
-         B, L, G, DG, N = deltaB.shape
-
-         # Compute the SSD kernel using cumulative products
-         # This leverages PyTorch's efficient cumprod/cumsum operations
-         x = deltaB * u.unsqueeze(-1)  # (B, L, G, DG, N)
-
-         # Parallel prefix scan using associative property
-         # We use a work-efficient scan algorithm
-         y = torch.zeros_like(x)
-         y[:, 0] = x[:, 0]
-
-         # Segment the scan into chunks for memory efficiency
-         chunk_size = min(self.chunk_size, L)
-
-         for start in range(0, L, chunk_size):
-             end = min(start + chunk_size, L)
-             chunk_len = end - start
-
-             # Within chunk: compute using matrix operations
-             if start == 0:
-                 # First chunk: standard scan
-                 for i in range(1, chunk_len):
-                     idx = start + i
-                     y[:, idx] = deltaA[:, idx] * y[:, idx-1] + x[:, idx]
-             else:
-                 # Subsequent chunks: carry over final state from previous chunk
-                 carry = y[:, start-1:start]  # (B, 1, G, DG, N)
-                 for i in range(chunk_len):
-                     idx = start + i
-                     if i == 0:
-                         y[:, idx] = deltaA[:, idx] * carry.squeeze(1) + x[:, idx]
-                     else:
-                         y[:, idx] = deltaA[:, idx] * y[:, idx-1] + x[:, idx]

          # Project through C
-         y = (y * C.unsqueeze(-2)).sum(dim=-1)  # (B, L, G, DG)
-
          return y
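
Written out, the matrix form the docstring above refers to is (reconstructed from the code's definitions, with the D skip included):

    y_i = C_i^T ( sum_{j<=i} ( prod_{k=j+1..i} exp(Δ_k A) ) Δ_j B_j u_j ) + D u_i

Each (i, j) entry of that lower-triangular operator is what the SSD framing treats as a structured semiseparable, attention-like matrix.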

-     def forward(self, x: torch.Tensor,
-                 scan_direction: str = "forward") -> torch.Tensor:
-         """
-         Args:
-             x: (B, L, D) input sequence
-             scan_direction: 'forward' or 'reverse' (for bidirectional scanning)
-         Returns:
-             (B, L, D) output
-         """
          residual = x

-         # Optional: reverse sequence for bidirectional scanning
          if scan_direction == "reverse":
              x = torch.flip(x, dims=[1])

          # Pre-norm
-         x = self.norm(x)

-         # Input projection -> split into u and z branches
-         proj = self.in_proj(x)  # (B, L, 2*inner_dim)
          u, z = proj.chunk(2, dim=-1)

-         # 1D convolution on u branch
          u = self._apply_conv1d(u)

-         # Compute SSD parameters
-         delta = F.softplus(self.dt_proj(x))  # (B, L, inner_dim)
-         B = self.B_proj(x)  # (B, L, n_groups*d_state)
-         C = self.C_proj(x)  # (B, L, n_groups*d_state)

          # Selective scan
-         u = self._selective_scan(u, delta, -torch.exp(self.A_log), B, C, self.D)

-         # CfC gating (replaces SiLU in standard Mamba-2)
-         if self.use_cfc_modulation:
-             gate = self.cfc_gate(x)
              u = u * gate
          else:
              u = u * F.silu(z)

-         # Output projection
          out = self.out_proj(u)
          out = self.dropout(out)

-         # Restore direction
          if scan_direction == "reverse":
              out = torch.flip(out, dims=[1])

@@ -272,22 +210,119 @@ class Mamba2SSDBlock(nn.Module):


  class BidirectionalMambaBlock(nn.Module):
      """
-     Bidirectional Mamba block for 2D image processing.

-     Combines forward and reverse scans to give each token
-     access to both left and right context.
-     """

-     def __init__(self, dim: int, **kwargs):
          super().__init__()
-         self.forward_ssd = Mamba2SSDBlock(dim, **kwargs)
-         self.reverse_ssd = Mamba2SSDBlock(dim, **kwargs)
-         self.merge = nn.Linear(dim * 2, dim, bias=False)
-         self.norm = nn.LayerNorm(dim)

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         fwd = self.forward_ssd(x, "forward")
-         rev = self.reverse_ssd(x, "reverse")
-         out = self.merge(torch.cat([fwd, rev], dim=-1))
-         return self.norm(out)

  """
  Mamba-2 SSD Block with Liquid (CfC) Gating

+ Implements Structured State Space Duality (SSD) from Mamba-2 with CfC gating.
+ Linear-time in sequence length, with a parallelizable scan (no sequential
+ loops at train time).
  """

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  from .cfc import CfCGate


  class Mamba2SSDBlock(nn.Module):
      """
+     Mamba-2 SSD block with CfC liquid gating.

      Args:
          dim: Hidden dimension
+         d_state: SSM state dimension (default 16)
          d_conv: Convolution kernel size
+         expand: Expansion factor
+         n_groups: Number of groups (like GQA heads)
+         use_cfc_modulation: Use CfC gate instead of SiLU
+         dropout: Dropout probability on the output projection
      """
+     def __init__(self, dim, d_state=16, d_conv=4, expand=2, n_groups=1,
+                  use_cfc_modulation=True, dropout=0.0):
          super().__init__()
          self.dim = dim
          self.d_state = d_state
          self.inner_dim = dim * expand
          self.n_groups = n_groups
+         self.d_inner_group = self.inner_dim // n_groups
          self.use_cfc_modulation = use_cfc_modulation

+         # Pre-norm (RMSNorm)
          self.norm = nn.RMSNorm(dim)

+         # Input projection (x -> x, z gate branch)
          self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)

+         # 1D conv for local mixing
+         self.conv1d = nn.Conv1d(self.inner_dim, self.inner_dim, d_conv,
+                                 groups=self.inner_dim, padding=d_conv - 1)

+         # A parameter (log-space for stability), one decay per (group, state)
+         self.A_log = nn.Parameter(torch.randn(n_groups, d_state) * 0.01)

+         # D skip parameter, one per inner channel, stored grouped as (G, DG)
+         # so the (1, 1, G, DG) view in _selective_scan is valid for n_groups > 1
+         self.D = nn.Parameter(torch.ones(n_groups, self.d_inner_group))
+
+         # dt / B / C projections (input-dependent); dt runs through a
+         # low-rank bottleneck of width dt_rank
          dt_rank = max(1, dim // 16)
+         self.dt_proj = nn.Sequential(
+             nn.Linear(dim, dt_rank, bias=False),
+             nn.Linear(dt_rank, self.inner_dim, bias=True)
+         )
+         self.B_proj = nn.Linear(dim, n_groups * d_state, bias=False)
+         self.C_proj = nn.Linear(dim, n_groups * d_state, bias=False)

+         # CfC gate or SiLU
+         self.cfc_gate = CfCGate(self.inner_dim) if use_cfc_modulation else None

+         # Output
          self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
          self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
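
For intuition on the low-rank dt projection above: routing the step-size head through a dt_rank-wide bottleneck keeps it cheap relative to a dense map. A back-of-envelope count with illustrative numbers (not from the commit):

    dim, expand = 128, 2
    inner_dim, dt_rank = dim * expand, max(1, dim // 16)       # 256, 8
    dense = dim * inner_dim                                    # 32768 for a direct Linear
    lowrank = dim * dt_rank + dt_rank * inner_dim + inner_dim  # 3328, incl. bias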

+     def _apply_conv1d(self, x):
+         """Causal 1D conv."""
+         B, L, D = x.shape
+         x = x.transpose(1, 2)  # (B, D, L)
          x = self.conv1d(x)
+         x = x[..., :L]  # causal: trim right padding
          x = F.silu(x)
+         return x.transpose(1, 2)  # (B, L, D)
+
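
Why the trim makes the conv causal (a standalone sketch, not part of the commit): nn.Conv1d with padding=k-1 pads both sides, so within the first L outputs, position t only sees inputs <= t.

    import torch
    import torch.nn as nn

    k, L = 4, 10
    conv = nn.Conv1d(1, 1, kernel_size=k, padding=k - 1)
    x = torch.randn(1, 1, L)
    y = conv(x)[..., :L]                 # back to length L
    # Perturbing the last input leaves all earlier outputs untouched:
    x2 = x.clone(); x2[..., -1] += 1.0
    y2 = conv(x2)[..., :L]
    assert torch.allclose(y[..., :-1], y2[..., :-1])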
+     def _selective_scan(self, u, delta, A_log, B, C, D):
          """
+         Selective scan using SSD (structured state space duality).

+         The key insight from Mamba-2: the SSD scan is the matrix form of the
+         SSM, computable via an associative scan in O(log L) parallel steps.

+         Uses native PyTorch ops (no custom CUDA kernels needed).
          """
+         B_sz, L, ID = u.shape
          G = self.n_groups
          N = self.d_state
+         DG = ID // G

          # Reshape for grouped processing
+         u = u.view(B_sz, L, G, DG)          # (B, L, G, DG)
+         delta = delta.view(B_sz, L, G, DG)  # (B, L, G, DG)
+         B = B.view(B_sz, L, G, N)           # (B, L, G, N)
+         C = C.view(B_sz, L, G, N)           # (B, L, G, N)
+
+         # Discretization: A_bar = exp(delta * A), element-wise (diagonal A)
+         A_neg = -torch.exp(A_log)           # (G, N), negative for stability
+         deltaA = torch.exp(delta.unsqueeze(-1) * A_neg.view(1, 1, G, 1, N))
+         # deltaA: (B, L, G, DG, N); the view broadcasts each group's decay
+         # across its DG channels (A_log is (G, N), not (DG, N))
+         deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(-2) * u.unsqueeze(-1)
+         # deltaB_u: (B, L, G, DG, N)
+
+         # Parallel associative scan:
+         # h_i = deltaA_i * h_{i-1} + deltaB_u_i, then y_i = C_i . h_i + D * u_i
+         y = self._associative_scan(deltaA, deltaB_u, C)
+         # y: (B, L, G, DG)

          # Add skip connection
          y = y + u * D.view(1, 1, G, DG)
+         return y.view(B_sz, L, ID)
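For reference, the per-step update the scan realizes, in the code's own notation:

    A_bar_i = exp(Δ_i * A)        # Δ_i = softplus(dt_proj(x_i)), A = -exp(A_log)
    h_i     = A_bar_i * h_{i-1} + Δ_i * B_i * u_i
    y_i     = sum_n C_i[n] * h_i[n] + D * u_i

(B is discretized with the simple Δ·B product rather than the exact zero-order-hold integral, matching the usual Mamba simplification.)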

+     def _associative_scan(self, A, X, C):
          """
+         Parallel prefix scan for the diagonal linear recurrence

+             h_i = A_i * h_{i-1} + X_i,

+         followed by the output projection y_i = <C_i, h_i>.

+         Hillis-Steele style inclusive scan: each of the ceil(log2(L)) passes
+         composes partial results with the associative operator
+             (A_j, X_j) o (A_i, X_i) = (A_i * A_j, A_i * X_j + X_i)
+         at a doubling stride. O(L log L) work rather than work-efficient,
+         but fully vectorized in native PyTorch.
          """
+         B, L, G, DG, N = A.shape
+
+         h = X.clone()
+         A_acc = A.clone()
+         stride = 1
+         while stride < L:
+             # Shift partials forward by `stride`; slots with no predecessor
+             # take the identity element (A=1, X=0)
+             h_prev = torch.cat(
+                 [torch.zeros_like(h[:, :stride]), h[:, :-stride]], dim=1)
+             A_prev = torch.cat(
+                 [torch.ones_like(A_acc[:, :stride]), A_acc[:, :-stride]], dim=1)
+             h = A_acc * h_prev + h
+             A_acc = A_acc * A_prev
+             stride *= 2

          # Project through C
+         y = (h * C.unsqueeze(-2)).sum(dim=-1)  # (B, L, G, DG)

          return y
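
A sequential reference makes a handy unit test for the scan (a sketch; the helper below is hypothetical test code, not part of the commit):

    import torch

    def sequential_scan(A, X):
        """Literal h_i = A_i * h_{i-1} + X_i, one step at a time."""
        h = torch.zeros_like(X[:, 0])
        out = []
        for i in range(X.shape[1]):
            h = A[:, i] * h + X[:, i]
            out.append(h)
        return torch.stack(out, dim=1)

    # For random A, X, C shaped as in _selective_scan, with A in (0, 1):
    # expected = (sequential_scan(A, X) * C.unsqueeze(-2)).sum(-1)
    # assert torch.allclose(block._associative_scan(A, X, C), expected, atol=1e-5)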

+     def forward(self, x, scan_direction="forward"):
          residual = x

          if scan_direction == "reverse":
              x = torch.flip(x, dims=[1])

          # Pre-norm
+         x_norm = self.norm(x)

+         # Input projection
+         proj = self.in_proj(x_norm)
          u, z = proj.chunk(2, dim=-1)

+         # 1D conv
          u = self._apply_conv1d(u)

+         # SSD parameters
+         delta = F.softplus(self.dt_proj(x_norm))
+         B_proj = self.B_proj(x_norm)
+         C_proj = self.C_proj(x_norm)

          # Selective scan
+         u = self._selective_scan(u, delta, self.A_log, B_proj, C_proj, self.D)

+         # Gate
+         if self.use_cfc_modulation and self.cfc_gate is not None:
+             gate = self.cfc_gate(x_norm)
              u = u * gate
          else:
              u = u * F.silu(z)

+         # Output
          out = self.out_proj(u)
          out = self.dropout(out)

          if scan_direction == "reverse":
              out = torch.flip(out, dims=[1])

 


  class BidirectionalMambaBlock(nn.Module):
+     """Bidirectional Mamba: forward + reverse scans, merged by a linear layer."""
+     def __init__(self, dim, **kwargs):
          super().__init__()
+         self.fwd = Mamba2SSDBlock(dim, **kwargs)
+         self.rev = Mamba2SSDBlock(dim, **kwargs)
+         self.merge = nn.Linear(dim * 2, dim, bias=False)
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x):
+         f = self.fwd(x, "forward")
+         r = self.rev(x, "reverse")
+         out = self.merge(torch.cat([f, r], dim=-1))
+         return self.norm(out)
+
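
A minimal usage sketch (assumes the package's cfc module is importable; use_cfc_modulation=False exercises the SiLU gate path, and the block's return, elided between the hunks above, is taken to add the residual):

    import torch
    from liqmamba.mamba2_ssd import Mamba2SSDBlock, BidirectionalMambaBlock

    x = torch.randn(2, 64, 128)                  # (batch, seq_len, dim)
    blk = Mamba2SSDBlock(dim=128, use_cfc_modulation=False)
    y = blk(x)                                   # (2, 64, 128)
    bi = BidirectionalMambaBlock(dim=128, use_cfc_modulation=False)
    y2 = bi(x)                                   # forward + reverse scans merged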
+
+ class MultiDirectionalScan(nn.Module):
      """
+     2D-adapted Mamba layer with multi-directional scanning.

+     Following the DiM paper, layers cycle through four scan patterns:
+     - row-major forward
+     - row-major reverse
+     - column-major forward
+     - column-major reverse

+     Includes learnable EOL (end-of-line) tokens appended to each scan line.
+     """
+     def __init__(self, dim, pattern="row_fwd", eol_tokens=1, **kwargs):
          super().__init__()
+         self.dim = dim
+         self.pattern = pattern
+         self.eol_tokens = eol_tokens
+         self.ssd = (BidirectionalMambaBlock(dim, **kwargs) if pattern == "bidir"
+                     else Mamba2SSDBlock(dim, **kwargs))
+
+         # Learnable end-of-line tokens
+         if eol_tokens > 0:
+             self.eol_token = nn.Parameter(torch.randn(1, eol_tokens, dim) * 0.02)
+
+     def _unflatten_2d(self, x, H, W, pattern):
+         """Convert a 1D sequence (given in `pattern` scan order) to a 2D spatial layout."""
+         B, L, D = x.shape
+
+         if pattern == "row_fwd":
+             return x.view(B, H, W, D)
+         elif pattern == "row_rev":
+             return torch.flip(x.view(B, H, W, D), dims=[2])
+         elif pattern == "col_fwd":
+             return x.view(B, W, H, D).transpose(1, 2)
+         elif pattern == "col_rev":
+             return torch.flip(x.view(B, W, H, D).transpose(1, 2), dims=[2])
+         return x.view(B, H, W, D)
+
+     def _flatten_2d(self, x, pattern):
+         """Convert a 2D spatial layout back to a 1D sequence in `pattern` scan order."""
+         B, H, W, D = x.shape
+
+         if pattern == "row_fwd":
+             return x.reshape(B, H * W, D)
+         elif pattern == "row_rev":
+             return torch.flip(x, dims=[2]).reshape(B, H * W, D)
+         elif pattern == "col_fwd":
+             return x.transpose(1, 2).reshape(B, H * W, D)
+         elif pattern == "col_rev":
+             return torch.flip(x.transpose(1, 2), dims=[2]).reshape(B, H * W, D)
+         return x.reshape(B, H * W, D)
+
+     def _add_eol_tokens_row(self, x, H, W):
+         """Append EOL tokens at the end of each row."""
+         B, L, D = x.shape
+         x = x.view(B, H, W, D)
+         eol = self.eol_token.expand(B, H, -1, -1)
+         x_with_eol = torch.cat([x, eol], dim=2)  # (B, H, W+eol, D)
+         return x_with_eol.reshape(B, H * (W + self.eol_tokens), D)
+
+     def forward(self, x, H, W):
+         """
+         Args:
+             x: (B, H*W, D) flattened image tokens, row-major order
+             H, W: spatial dimensions
+         """
+         B, L, D = x.shape
+
+         # Append EOL tokens to each scan line
+         if self.eol_tokens > 0:
+             if "row" in self.pattern:
+                 x = self._add_eol_tokens_row(x, H, W)
+                 H_2d, W_2d = H, W + self.eol_tokens
+             else:  # column-style scan: scan lines are image columns
+                 x_t = x.view(B, H, W, D).transpose(1, 2)  # (B, W, H, D)
+                 eol = self.eol_token.expand(B, W, -1, -1)
+                 x_t = torch.cat([x_t, eol], dim=2)        # (B, W, H+eol, D)
+                 H_2d, W_2d = W, H + self.eol_tokens
+                 x = x_t.reshape(B, H_2d * W_2d, D)
+         else:
+             H_2d, W_2d = H, W
+
+         # Apply SSD
+         if self.pattern == "bidir":
+             x = self.ssd(x)
+         elif "rev" in self.pattern:
+             x = self.ssd(x, "reverse")
+         else:
+             x = self.ssd(x, "forward")
+
+         # Remove EOL tokens and restore row-major order
+         if self.eol_tokens > 0:
+             x = x.view(B, H_2d, W_2d, D)
+             x = x[:, :, :W_2d - self.eol_tokens, :]  # drop EOL
+             if "row" not in self.pattern:            # column-style scan
+                 x = x.transpose(1, 2)                # back to (B, H, W, D)
+             x = x.reshape(B, H * W, D)

+         return x
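
A smoke test for the 2D wrapper (hypothetical, not part of the commit; same import assumptions as the sketch above):

    import torch
    from liqmamba.mamba2_ssd import MultiDirectionalScan

    H, W, D = 8, 8, 128
    x = torch.randn(2, H * W, D)
    for pattern in ("row_fwd", "row_rev", "col_fwd", "col_rev"):
        layer = MultiDirectionalScan(D, pattern=pattern, eol_tokens=1,
                                     use_cfc_modulation=False)
        assert layer(x, H, W).shape == (2, H * W, D)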