omar-ah committed
Commit 519f856 · verified · Parent: 61d4766

Upload vision_xlstm.py

Files changed (1)
  1. code/vision_xlstm.py +348 -0
code/vision_xlstm.py ADDED
@@ -0,0 +1,348 @@
"""
Vision xLSTM (ViL) encoder implementation.
Based on: "Vision-LSTM: xLSTM as Generic Vision Backbone" (arXiv:2406.04303)

Key design:
- Patch embedding (ViT-style, 16x16 patches)
- Alternating bidirectional mLSTM blocks (top-left→bottom-right, bottom-right→top-left)
- Conv2D for QK local context
- Linear complexity O(N) vs ViT's O(N²)
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class PatchEmbedding(nn.Module):
    """Convert an image into patch tokens (identical to ViT)."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=384):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)                  # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        x = x + self.pos_embed
        return x
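

# Illustrative sketch (not part of the original upload): a quick shape check for
# PatchEmbedding under the defaults above (224x224 input, 16x16 patches, dim 384),
# which yields (224 / 16)^2 = 196 tokens per image.
def _demo_patch_embedding():
    embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=384)
    images = torch.randn(2, 3, 224, 224)  # [B, C, H, W]
    tokens = embed(images)                # [B, N, D]
    assert tokens.shape == (2, 196, 384)
    return tokens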


class MLSTMCell(nn.Module):
    """
    Matrix-LSTM (mLSTM) cell with exponential gating.

    Core equations (per head):
        q = W_q @ x,  k = (1/√d) * W_k @ x,  v = W_v @ x
        f = sigmoid(w_f @ x),  i = exp(w_i @ x),  o = sigmoid(w_o @ x)
        C_t = f * C_{t-1} + i * (v ⊗ k)     [outer-product memory update]
        n_t = f * n_{t-1} + i * k           [normalizer]
        h_t = o ⊙ (C_t @ q / max(|n_t^T @ q|, 1))
    """

    def __init__(self, input_dim, head_dim, num_heads=1):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.total_dim = head_dim * num_heads

        # QKV projections
        self.W_q = nn.Linear(input_dim, self.total_dim, bias=True)
        self.W_k = nn.Linear(input_dim, self.total_dim, bias=True)
        self.W_v = nn.Linear(input_dim, self.total_dim, bias=True)

        # Gate pre-activations (scalar per head for f and i)
        self.w_f = nn.Linear(input_dim, num_heads, bias=True)       # forget gate
        self.w_i = nn.Linear(input_dim, num_heads, bias=True)       # input gate
        self.w_o = nn.Linear(input_dim, self.total_dim, bias=True)  # output gate

        # Key scaling
        self.scale = 1.0 / math.sqrt(head_dim)

    def forward(self, x):
        """
        x: [B, T, D]
        Returns: [B, T, total_dim]
        """
        B, T, D = x.shape

        q = self.W_q(x)               # [B, T, total_dim]
        k = self.W_k(x) * self.scale  # [B, T, total_dim]
        v = self.W_v(x)               # [B, T, total_dim]

        # Gate pre-activations
        log_f = self.w_f(x)             # [B, T, num_heads] - forget gate pre-activation
        log_i = self.w_i(x)             # [B, T, num_heads] - log input gate (i = exp(log_i))
        o = torch.sigmoid(self.w_o(x))  # [B, T, total_dim]

        # Stabilize the forget gate in log space: logsigmoid bounds it to (-inf, 0)
        log_f = F.logsigmoid(log_f)

        # Reshape for multi-head
        q = rearrange(q, 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(k, 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(v, 'b t (h d) -> b h t d', h=self.num_heads)
        log_f = rearrange(log_f, 'b t h -> b h t')
        log_i = rearrange(log_i, 'b t h -> b h t')

        # NOTE: the recurrence is evaluated with an explicit sequential scan below.
        # A parallel/chunked "linear attention with decay" form,
        #   h_t = Σ_{s≤t} (Π_{j=s+1}^{t} f_j) * i_s * (v_s k_s^T) q_t,
        # is mathematically equivalent and faster, but is not implemented here.
        # Gates are per-head scalars: [B, H, T]
        decay = torch.exp(log_f)  # forget gate f_t = sigmoid(pre-activation)
        gate = torch.exp(log_i)   # exponential input gate i_t

        # Sequential scan (a parallel scan would replace this in production)
        h_state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
                              device=x.device, dtype=x.dtype)
        n_state = torch.zeros(B, self.num_heads, self.head_dim,
                              device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(T):
            f_t = decay[:, :, t]  # [B, H] - per-head scalar
            i_t = gate[:, :, t]   # [B, H] - per-head scalar
            k_t = k[:, :, t, :]   # [B, H, D]
            v_t = v[:, :, t, :]   # [B, H, D]
            q_t = q[:, :, t, :]   # [B, H, D]

            # Expand gates for broadcasting: [B, H] -> [B, H, 1] and [B, H, 1, 1]
            f_t_d = f_t.unsqueeze(-1)                 # [B, H, 1] for D dim
            i_t_d = i_t.unsqueeze(-1)                 # [B, H, 1] for D dim
            f_t_dd = f_t.unsqueeze(-1).unsqueeze(-1)  # [B, H, 1, 1] for DxD
            i_t_dd = i_t.unsqueeze(-1).unsqueeze(-1)  # [B, H, 1, 1] for DxD

            # Update cell state: C = f*C + i*(v outer k)
            h_state = f_t_dd * h_state + i_t_dd * torch.einsum('bhd,bhe->bhde', v_t, k_t)
            # Update normalizer: n = f*n + i*k
            n_state = f_t_d * n_state + i_t_d * k_t

            # Read-out: C @ q / max(|n^T @ q|, 1); output gate o is applied after the scan
            Cq = torch.einsum('bhde,bhe->bhd', h_state, q_t)
            nq = torch.einsum('bhd,bhd->bh', n_state, q_t).unsqueeze(-1)
            nq = torch.clamp(nq.abs(), min=1.0)
            h_t = Cq / nq
            outputs.append(h_t)

        out = torch.stack(outputs, dim=2)             # [B, H, T, D]
        out = rearrange(out, 'b h t d -> b t (h d)')
        out = out * o                                 # output gate

        return out
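

# Illustrative sketch (not part of the original upload): exercising MLSTMCell on a
# short random sequence. With head_dim=16 and num_heads=4 the output width is
# head_dim * num_heads = 64, independent of the input width.
def _demo_mlstm_cell():
    cell = MLSTMCell(input_dim=32, head_dim=16, num_heads=4)
    x = torch.randn(2, 10, 32)  # [B, T, D_in]
    out = cell(x)               # [B, T, head_dim * num_heads]
    assert out.shape == (2, 10, 64)
    return out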


class MLSTMBlock(nn.Module):
    """
    ViL mLSTM block with Conv2D for QK spatial context.
    Wraps mLSTM in a gated MLP structure.
    """

    def __init__(self, dim, conv_kernel=3, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        # Pre-projection: expand to 3x for gate structure
        self.pre_proj = nn.Linear(dim, dim * 3)

        # Conv2D for spatial QK context (key ViL innovation)
        self.conv = nn.Conv2d(dim, dim, kernel_size=conv_kernel,
                              padding=conv_kernel // 2, groups=dim)  # depthwise

        # mLSTM cell
        self.mlstm = MLSTMCell(
            input_dim=dim,
            head_dim=dim // 4,  # 4 heads
            num_heads=4
        )

        # Output projection
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, h=None, w=None):
        """
        x: [B, T, D] patch tokens
        h, w: spatial dimensions for the conv (sqrt(T) each for square images)
        """
        B, T, D = x.shape
        residual = x
        x = self.norm(x)

        # Gate structure: split into B (gate), C (gate), h_tilde (input)
        projected = self.pre_proj(x)  # [B, T, 3D]
        gate_b, gate_c, h_tilde = projected.chunk(3, dim=-1)

        # Apply spatial conv to h_tilde for local context
        if h is not None and w is not None:
            h_2d = rearrange(h_tilde, 'b (h w) d -> b d h w', h=h, w=w)
            h_2d = self.conv(h_2d)
            h_tilde = rearrange(h_2d, 'b d h w -> b (h w) d')

        # Input gating
        y = torch.sigmoid(gate_b) * h_tilde

        # mLSTM
        y = self.mlstm(y)

        # Output gating
        y = torch.sigmoid(gate_c) * y
        y = self.out_proj(y)
        y = self.dropout(y)

        return residual + y
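

# Illustrative sketch (not part of the original upload): running one MLSTMBlock over
# a 14x14 token grid (196 tokens). Note that dim must be divisible by 4, since the
# block hard-codes 4 mLSTM heads of size dim // 4, and the output keeps the input shape
# because of the residual connection.
def _demo_mlstm_block():
    block = MLSTMBlock(dim=384, conv_kernel=3, dropout=0.0)
    tokens = torch.randn(2, 14 * 14, 384)  # [B, T, D]
    out = block(tokens, h=14, w=14)
    assert out.shape == tokens.shape
    return out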


class FFNBlock(nn.Module):
    """SwiGLU feed-forward block."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult * 2 / 3)  # SwiGLU uses a 2/3 factor
        self.norm = nn.LayerNorm(dim)
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(dim, hidden)
        self.w3 = nn.Linear(hidden, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        return residual + self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))
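

# Illustrative sketch (not part of the original upload): with dim=384 and mult=4 the
# SwiGLU hidden width is int(384 * 4 * 2 / 3) = 1024, so the parameter count stays
# close to a standard 4x MLP despite the extra gating projection.
def _demo_ffn_block():
    ffn = FFNBlock(dim=384, mult=4, dropout=0.0)
    x = torch.randn(2, 196, 384)
    out = ffn(x)
    assert ffn.w1.out_features == 1024
    assert out.shape == x.shape
    return out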


class VisionXLSTM(nn.Module):
    """
    Vision xLSTM (ViL) encoder.

    Architecture:
    1. Patch embedding (Conv2D, 16x16)
    2. Alternating bidirectional mLSTM blocks
    3. SwiGLU FFN after each mLSTM

    Output: all patch tokens [B, num_patches, dim] for VLM projection
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Patch embedding
        self.patch_embed = PatchEmbedding(
            img_size=config.img_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            embed_dim=config.dim
        )

        self.h = config.img_size // config.patch_size
        self.w = config.img_size // config.patch_size

        # Alternating mLSTM blocks + FFN
        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for i in range(config.depth):
            self.blocks.append(MLSTMBlock(
                dim=config.dim,
                conv_kernel=config.conv_kernel_size,
                dropout=config.dropout
            ))
            self.ffns.append(FFNBlock(dim=config.dim, dropout=config.dropout))

        self.final_norm = nn.LayerNorm(config.dim)

    def forward_features(self, pixel_values):
        """
        Extract patch features for VLM projection.

        Args:
            pixel_values: [B, C, H, W] images
        Returns:
            [B, num_patches, dim] patch token features
        """
        x = self.patch_embed(pixel_values)  # [B, N, D]

        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            if self.config.bidirectional and i % 2 == 1:
                # Odd-indexed blocks: reverse scan direction (bottom-right → top-left)
                x = x.flip(1)
                x = block(x, h=self.h, w=self.w)
                x = ffn(x)
                x = x.flip(1)
            else:
                # Even-indexed blocks: forward scan (top-left → bottom-right)
                x = block(x, h=self.h, w=self.w)
                x = ffn(x)

        x = self.final_norm(x)
        return x

    def forward(self, pixel_values):
        """Classification forward (bilateral concat pooling)."""
        features = self.forward_features(pixel_values)
        # Bilateral concat: first + last patch token
        pooled = torch.cat([features[:, 0], features[:, -1]], dim=-1)
        return pooled
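

# Illustrative sketch (not part of the original upload): end-to-end shape check for the
# encoder. The actual config class is defined elsewhere, so a stand-in namespace with
# only the fields this module reads (img_size, patch_size, in_channels, dim, depth,
# conv_kernel_size, dropout, bidirectional) is assumed here purely for demonstration.
def _demo_vision_xlstm():
    from types import SimpleNamespace
    cfg = SimpleNamespace(img_size=224, patch_size=16, in_channels=3, dim=384,
                          depth=4, conv_kernel_size=3, dropout=0.0, bidirectional=True)
    encoder = VisionXLSTM(cfg)
    images = torch.randn(2, 3, 224, 224)
    patch_features = encoder.forward_features(images)  # [B, 196, 384] for VLM projection
    pooled = encoder(images)                            # [B, 768] bilateral concat (2 * dim)
    assert patch_features.shape == (2, 196, 384)
    assert pooled.shape == (2, 768)
    return patch_features, pooled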


class VisionProjector(nn.Module):
    """
    MLP projector: maps ViL features → LM embedding space.
    Following LLaDA-V / LaViDa: 2-layer MLP with GELU.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.lm_dim * config.hidden_mult

        layers = []
        layers.append(nn.Linear(config.vil_dim, hidden_dim))
        layers.append(nn.GELU())
        if config.dropout > 0:
            layers.append(nn.Dropout(config.dropout))

        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
            if config.dropout > 0:
                layers.append(nn.Dropout(config.dropout))

        layers.append(nn.Linear(hidden_dim, config.lm_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, vision_features):
        """
        Args:
            vision_features: [B, num_patches, vil_dim]
        Returns:
            [B, num_patches, lm_dim]
        """
        return self.mlp(vision_features)
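

# Illustrative sketch (not part of the original upload): projecting ViL patch features
# into a language model's embedding space. As above, the projector config class lives
# elsewhere; a stand-in namespace with the fields read here (vil_dim, lm_dim,
# hidden_mult, num_layers, dropout) is assumed, and lm_dim=2048 is an arbitrary example.
def _demo_vision_projector():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vil_dim=384, lm_dim=2048, hidden_mult=2, num_layers=2, dropout=0.0)
    projector = VisionProjector(cfg)
    vision_features = torch.randn(2, 196, 384)  # output of VisionXLSTM.forward_features
    lm_tokens = projector(vision_features)      # [B, num_patches, lm_dim]
    assert lm_tokens.shape == (2, 196, 2048)
    return lm_tokens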