omar-ah committed commit 2b05eb6 (parent: 1e0c38c)

Update model configuration and training scripts with new vision backbone support and dependencies

code/model_config.py CHANGED
@@ -10,11 +10,13 @@ from typing import Optional, List
 @dataclass
 class ViLEncoderConfig:
     """Vision xLSTM (ViL) encoder configuration"""
+    vision_backbone: str = "vil2-small"
+    pretrained: bool = True
     img_size: int = 224
     patch_size: int = 16
     in_channels: int = 3
-    dim: int = 384   # ViL-S default (23M params)
-    depth: int = 24  # Standard ViL depth
+    dim: int = 384   # patch feature dim for vil-small / vil2-small
+    depth: int = 12  # VisionLSTM2 block-pairs; v1 vil-small internally uses 24
     mlstm_dim_mult: int = 2     # mLSTM internal dim = 2 * dim
     conv_kernel_size: int = 3   # QK Conv2D kernel
     bidirectional: bool = True  # alternating scan directions
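A minimal usage sketch for the updated config (illustrative, not part of this commit; assumes `code/` is importable). Note that with the new adapter in `code/vision_xlstm.py`, the backbone's own `dim`/`depth` come from the `VISION_BACKBONES` spec table, so these dataclass fields mainly document the expected feature shape:

    from model_config import ViLEncoderConfig

    cfg = ViLEncoderConfig()                                 # default: vil2-small, pretrained weights
    cfg_v1 = ViLEncoderConfig(vision_backbone="vil-small")   # hypothetical v1 selection
    print(cfg.vision_backbone, cfg.dim, cfg.depth)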
code/train_production.py CHANGED
@@ -28,6 +28,7 @@ from io import BytesIO
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 from huggingface_hub import HfApi, snapshot_download
+from vision_xlstm import VisionProjector as UpstreamVisionProjector, VisionXLSTM as UpstreamVisionXLSTM
 
 import trackio
 
@@ -39,6 +40,8 @@ from dataclasses import dataclass, field
 
 @dataclass
 class ViLConfig:
+    vision_backbone: str = "vil2-small"
+    pretrained: bool = True
     img_size: int = 224
     patch_size: int = 16
     in_channels: int = 3
@@ -243,15 +246,15 @@ class ViLDLM(nn.Module):
     def __init__(self, vil_config, proj_config, lm_path):
         super().__init__()
         self.vil_config = vil_config
-        self.vision_encoder = VisionXLSTM(vil_config)
-        self.projector = VisionProjector(proj_config)
+        self.vision_encoder = UpstreamVisionXLSTM(vil_config)
+        self.projector = UpstreamVisionProjector(proj_config)
         self.scheduler = MDLMScheduler()
         self.num_patches = vil_config.num_patches
 
         # Load diffusion LM
         print(f"Loading diffusion LM from {lm_path}...")
         self.lm = AutoModelForMaskedLM.from_pretrained(
-            lm_path, trust_remote_code=True, dtype=torch.bfloat16
+            lm_path, trust_remote_code=True, torch_dtype=torch.bfloat16
         )
         self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True)
         lm_params = sum(p.numel() for p in self.lm.parameters())
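Illustrative sketch (not in the commit) of how the aliased upstream classes are wired: the constructor call sites in `ViLDLM.__init__` stay the same, only the import changes. `pretrained=False` is used here purely to skip the checkpoint download:

    from vision_xlstm import VisionXLSTM as UpstreamVisionXLSTM

    vil_config = ViLConfig(vision_backbone="vil2-small", pretrained=False)
    vision_encoder = UpstreamVisionXLSTM(vil_config)   # same call as in ViLDLM.__init__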
code/vil_dlm_model.py CHANGED
@@ -26,7 +26,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Dict, Any, Tuple
-from transformers import AutoModelForMaskedLM, AutoTokenizer
+from transformers import AutoModelForImageTextToText, AutoModelForMaskedLM, AutoTokenizer
 
 from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig
 from vision_xlstm import VisionXLSTM, VisionProjector
@@ -119,7 +119,7 @@ class ViLDLM(nn.Module):
         self.lm = AutoModelForMaskedLM.from_pretrained(
             model_path,
             trust_remote_code=True,
-            dtype=torch.bfloat16 if self.config.bf16 else torch.float32,
+            torch_dtype=torch.bfloat16 if self.config.bf16 else torch.float32,
         )
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -419,13 +419,12 @@ class ViLDLMWithDistillation(ViLDLM):
                 bnb_4bit_compute_dtype=torch.bfloat16,
                 bnb_4bit_quant_type="nf4",
             )
-            self.teacher = AutoModelForMaskedLM.from_pretrained(
+            self.teacher = AutoModelForImageTextToText.from_pretrained(
                 self.kd_config.teacher_model_id,
                 quantization_config=bnb_config,
                 device_map="auto",
             )
         else:
-            from transformers import AutoModelForImageTextToText
             self.teacher = AutoModelForImageTextToText.from_pretrained(
                 self.kd_config.teacher_model_id,
                 torch_dtype=torch.bfloat16,
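For reference, a minimal sketch of the corrected loading call (placeholder model id, not from this repo). `torch_dtype` is the long-standing keyword for this argument in `from_pretrained`; whether a bare `dtype=` is accepted depends on the installed transformers version, which is what this hunk works around:

    import torch
    from transformers import AutoModelForMaskedLM

    lm = AutoModelForMaskedLM.from_pretrained(
        "org/masked-diffusion-lm",     # placeholder id
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,    # was `dtype=...` before this commit
    )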
code/vision_xlstm.py CHANGED
@@ -1,348 +1,191 @@
-"""
-Vision xLSTM (ViL) encoder implementation.
-Based on: "Vision-LSTM: xLSTM as Generic Vision Backbone" (arxiv:2406.04303)
-
-Key design:
-- Patch embedding (ViT-style, 16x16 patches)
-- Alternating bidirectional mLSTM blocks (top-left→bottom-right, bottom-right→top-left)
-- Conv2D for QK local context
-- Linear complexity O(N) vs ViT's O(N²)
-"""
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-
-
-class PatchEmbedding(nn.Module):
-    """Convert image to patch tokens (identical to ViT)"""
-    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=384):
-        super().__init__()
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.num_patches = (img_size // patch_size) ** 2
-        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
-        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
-        nn.init.trunc_normal_(self.pos_embed, std=0.02)
-
-    def forward(self, x):
-        # x: [B, C, H, W]
-        B = x.shape[0]
-        x = self.proj(x)  # [B, D, H/P, W/P]
-        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
-        x = x + self.pos_embed
-        return x
-
-
-class MLSTMCell(nn.Module):
-    """
-    Matrix-LSTM (mLSTM) cell with exponential gating.
-
-    Core equations:
-        q = W_q @ x,  k = (1/√d) * W_k @ x,  v = W_v @ x
-        f = exp(w_f @ x),  i = exp(w_i @ x),  o = sigmoid(w_o @ x)
-        C_t = f * C_{t-1} + i * (v ⊗ k)   [outer product memory update]
-        n_t = f * n_{t-1} + i * k         [normalizer]
-        h_t = o ⊙ (C_t @ q / max(|n_t^T @ q|, 1))
-    """
-    def __init__(self, input_dim, head_dim, num_heads=1):
-        super().__init__()
-        self.head_dim = head_dim
-        self.num_heads = num_heads
-        self.total_dim = head_dim * num_heads
-
-        # QKV projections
-        self.W_q = nn.Linear(input_dim, self.total_dim, bias=True)
-        self.W_k = nn.Linear(input_dim, self.total_dim, bias=True)
-        self.W_v = nn.Linear(input_dim, self.total_dim, bias=True)
-
-        # Gates (scalar per head)
-        self.w_f = nn.Linear(input_dim, num_heads, bias=True)  # forget gate
-        self.w_i = nn.Linear(input_dim, num_heads, bias=True)  # input gate
-        self.w_o = nn.Linear(input_dim, self.total_dim, bias=True)  # output gate
-
-        # Scaling
-        self.scale = 1.0 / math.sqrt(head_dim)
-
-    def forward(self, x):
-        """
-        x: [B, T, D]
-        Returns: [B, T, total_dim]
-
-        For efficiency, we compute the parallel form via cumulative sums.
-        """
-        B, T, D = x.shape
-
-        q = self.W_q(x)  # [B, T, total_dim]
-        k = self.W_k(x) * self.scale  # [B, T, total_dim]
-        v = self.W_v(x)  # [B, T, total_dim]
-
-        # Gates
-        log_f = self.w_f(x)  # [B, T, num_heads] - log forget gate
-        log_i = self.w_i(x)  # [B, T, num_heads] - log input gate
-        o = torch.sigmoid(self.w_o(x))  # [B, T, total_dim]
-
-        # Stabilize with log-space computation
-        # Cumulative log forget gates for parallel scan
-        log_f = F.logsigmoid(log_f)  # bound to (-inf, 0)
-
-        # Reshape for multi-head
-        q = rearrange(q, 'b t (h d) -> b h t d', h=self.num_heads)
-        k = rearrange(k, 'b t (h d) -> b h t d', h=self.num_heads)
-        v = rearrange(v, 'b t (h d) -> b h t d', h=self.num_heads)
-        log_f = rearrange(log_f, 'b t h -> b h t')
-        log_i = rearrange(log_i, 'b t h -> b h t')
-
-        # Parallel computation via chunked linear attention approximation
-        # For efficiency, use the "linear attention" form:
-        #   h_t = Σ_{s≤t} (Π_{j=s+1}^{t} f_j) * i_s * v_s * k_s^T * q_t
-        # This is equivalent to softmax-free linear attention with decay
-
-        # Compute cumulative forget gate products in log space
-        cum_log_f = torch.cumsum(log_f, dim=-1)  # [B, H, T]
-
-        # Log weights: log(f^cum * i) for each position
-        #   w_{t,s} = cum_log_f[t] - cum_log_f[s] + log_i[s]
-        # For parallel form, compute weighted KV accumulation
-
-        # Simplified parallel form using exponential weights
-        weights = torch.exp(cum_log_f)  # [B, H, T] - cumulative decay
-        i_weights = torch.exp(log_i)  # [B, H, T] - input gates
-
-        # Weighted keys and values
-        w = (i_weights / (weights + 1e-6)).unsqueeze(-1)  # [B, H, T, 1]
-
-        kv = torch.einsum('bhtd,bhte->bhde', k * w, v * w)  # [B, H, D, D] approx
-
-        # Actually, let's use the simpler chunkwise form for correctness:
-        # Direct sequential would be too slow, so use causal linear attention
-        # qk = q @ k^T with causal mask approximated by decay
-
-        # Efficient approximation: use causal dot product with decay
-        # Gates are per-head scalars: [B, H, T]
-        decay = torch.exp(log_f)  # [B, H, T]
-        gate = torch.exp(log_i)  # [B, H, T]
-
-        # Sequential scan (will be replaced by parallel scan in production)
-        h_state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
-                              device=x.device, dtype=x.dtype)
-        n_state = torch.zeros(B, self.num_heads, self.head_dim,
-                              device=x.device, dtype=x.dtype)
-
-        outputs = []
-        for t in range(T):
-            f_t = decay[:, :, t]  # [B, H] - per-head scalar
-            i_t = gate[:, :, t]  # [B, H] - per-head scalar
-            k_t = k[:, :, t, :]  # [B, H, D]
-            v_t = v[:, :, t, :]  # [B, H, D]
-            q_t = q[:, :, t, :]  # [B, H, D]
-
-            # Expand gates for broadcasting: [B, H] -> [B, H, 1] and [B, H, 1, 1]
-            f_t_d = f_t.unsqueeze(-1)  # [B, H, 1] for D dim
-            i_t_d = i_t.unsqueeze(-1)  # [B, H, 1] for D dim
-            f_t_dd = f_t.unsqueeze(-1).unsqueeze(-1)  # [B, H, 1, 1] for DxD
-            i_t_dd = i_t.unsqueeze(-1).unsqueeze(-1)  # [B, H, 1, 1] for DxD
-
-            # Update cell state: C = f*C + i*(v outer k)
-            h_state = f_t_dd * h_state + i_t_dd * torch.einsum('bhd,bhe->bhde', v_t, k_t)
-            # Update normalizer: n = f*n + i*k
-            n_state = f_t_d * n_state + i_t_d * k_t
-
-            # Output: o * (C @ q / max(|n^T @ q|, 1))
-            Cq = torch.einsum('bhde,bhe->bhd', h_state, q_t)
-            nq = torch.einsum('bhd,bhd->bh', n_state, q_t).unsqueeze(-1)
-            nq = torch.clamp(nq.abs(), min=1.0)
-            h_t = Cq / nq
-            outputs.append(h_t)
-
-        out = torch.stack(outputs, dim=2)  # [B, H, T, D]
-        out = rearrange(out, 'b h t d -> b t (h d)')
-        out = out * o
-
-        return out
-
-
-class MLSTMBlock(nn.Module):
-    """
-    ViL mLSTM block with Conv2D for QK spatial context.
-    Wraps mLSTM in a gated MLP structure.
-    """
-    def __init__(self, dim, conv_kernel=3, dropout=0.0):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-
-        # Pre-projection: expand to 3x for gate structure
-        self.pre_proj = nn.Linear(dim, dim * 3)
-
-        # Conv2D for spatial QK context (key ViL innovation)
-        self.conv = nn.Conv2d(dim, dim, kernel_size=conv_kernel,
-                              padding=conv_kernel // 2, groups=dim)  # depthwise
-
-        # mLSTM cell
-        self.mlstm = MLSTMCell(
-            input_dim=dim,
-            head_dim=dim // 4,  # 4 heads
-            num_heads=4
-        )
-
-        # Output projection
-        self.out_proj = nn.Linear(dim, dim)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x, h=None, w=None):
-        """
-        x: [B, T, D] patch tokens
-        h, w: spatial dimensions for conv (sqrt(T) each for square images)
-        """
-        B, T, D = x.shape
-        residual = x
-        x = self.norm(x)
-
-        # Gate structure: split into B (gate), C (gate), h_tilde (input)
-        projected = self.pre_proj(x)  # [B, T, 3D]
-        gate_b, gate_c, h_tilde = projected.chunk(3, dim=-1)
-
-        # Apply spatial conv to h_tilde for local context
-        if h is not None and w is not None:
-            h_2d = rearrange(h_tilde, 'b (h w) d -> b d h w', h=h, w=w)
-            h_2d = self.conv(h_2d)
-            h_tilde = rearrange(h_2d, 'b d h w -> b (h w) d')
-
-        # Input gating
-        y = torch.sigmoid(gate_b) * h_tilde
-
-        # mLSTM
-        y = self.mlstm(y)
-
-        # Output gating
-        y = torch.sigmoid(gate_c) * y
-        y = self.out_proj(y)
-        y = self.dropout(y)
-
-        return residual + y
-
-
-class FFNBlock(nn.Module):
-    """SwiGLU feed-forward block"""
-    def __init__(self, dim, mult=4, dropout=0.0):
-        super().__init__()
-        hidden = int(dim * mult * 2 / 3)  # SwiGLU uses 2/3 factor
-        self.norm = nn.LayerNorm(dim)
-        self.w1 = nn.Linear(dim, hidden)
-        self.w2 = nn.Linear(dim, hidden)
-        self.w3 = nn.Linear(hidden, dim)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        residual = x
-        x = self.norm(x)
-        return residual + self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))
-
-
-class VisionXLSTM(nn.Module):
-    """
-    Vision xLSTM (ViL) encoder.
-
-    Architecture:
-        1. Patch embedding (Conv2D, 16x16)
-        2. Alternating bidirectional mLSTM blocks
-        3. SwiGLU FFN after each mLSTM
-
-    Output: all patch tokens [B, num_patches, dim] for VLM projection
-    """
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-
-        # Patch embedding
-        self.patch_embed = PatchEmbedding(
-            img_size=config.img_size,
-            patch_size=config.patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.dim
-        )
-
-        self.h = config.img_size // config.patch_size
-        self.w = config.img_size // config.patch_size
-
-        # Alternating mLSTM blocks + FFN
-        self.blocks = nn.ModuleList()
-        self.ffns = nn.ModuleList()
-        for i in range(config.depth):
-            self.blocks.append(MLSTMBlock(
-                dim=config.dim,
-                conv_kernel=config.conv_kernel_size,
-                dropout=config.dropout
-            ))
-            self.ffns.append(FFNBlock(dim=config.dim, dropout=config.dropout))
-
-        self.final_norm = nn.LayerNorm(config.dim)
-
-    def forward_features(self, pixel_values):
-        """
-        Extract patch features for VLM projection.
-
-        Args:
-            pixel_values: [B, C, H, W] images
-        Returns:
-            [B, num_patches, dim] patch token features
-        """
-        x = self.patch_embed(pixel_values)  # [B, N, D]
-
-        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
-            if self.config.bidirectional and i % 2 == 1:
-                # Even blocks (0-indexed odd): reverse scan direction
-                x = x.flip(1)
-                x = block(x, h=self.h, w=self.w)
-                x = ffn(x)
-                x = x.flip(1)
-            else:
-                # Odd blocks: forward scan
-                x = block(x, h=self.h, w=self.w)
-                x = ffn(x)
-
-        x = self.final_norm(x)
-        return x
-
-    def forward(self, pixel_values):
-        """Classification forward (bilateral concat pooling)"""
-        features = self.forward_features(pixel_values)
-        # Bilateral concat: first + last patch
-        pooled = torch.cat([features[:, 0], features[:, -1]], dim=-1)
-        return pooled
-
-
-class VisionProjector(nn.Module):
-    """
-    MLP projector: maps ViL features → LM embedding space.
-    Following LLaDA-V / LaViDa: 2-layer MLP with GELU.
-    """
-    def __init__(self, config):
-        super().__init__()
-        hidden_dim = config.lm_dim * config.hidden_mult
-
-        layers = []
-        layers.append(nn.Linear(config.vil_dim, hidden_dim))
-        layers.append(nn.GELU())
-        if config.dropout > 0:
-            layers.append(nn.Dropout(config.dropout))
-
-        for _ in range(config.num_layers - 1):
-            layers.append(nn.Linear(hidden_dim, hidden_dim))
-            layers.append(nn.GELU())
-            if config.dropout > 0:
-                layers.append(nn.Dropout(config.dropout))
-
-        layers.append(nn.Linear(hidden_dim, config.lm_dim))
-        self.mlp = nn.Sequential(*layers)
-
-    def forward(self, vision_features):
-        """
-        Args:
-            vision_features: [B, num_patches, vil_dim]
-        Returns:
-            [B, num_patches, lm_dim]
-        """
-        return self.mlp(vision_features)
+"""
+Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository.
+
+This module keeps the existing ViL-DLM contract:
+- `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]`
+- `VisionProjector` maps those visual tokens into the LM embedding space
+"""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+import os
+import ssl
+
+import certifi
+import torch
+import torch.nn as nn
+
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm"
+
+if str(VISION_LSTM_ROOT) not in sys.path:
+    sys.path.insert(0, str(VISION_LSTM_ROOT))
+
+from vision_lstm import VisionLSTM, VisionLSTM2  # noqa: E402
+
+
+VISION_BACKBONES = {
+    "vil-small": {
+        "ctor": VisionLSTM,
+        "preprocess": "v1",
+        "url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th",
+        "kwargs": {
+            "dim": 384,
+            "depth": 24,
+            "legacy_norm": True,
+            "mode": None,
+            "pooling": None,
+            "output_shape": None,
+        },
+    },
+    "vil2-small": {
+        "ctor": VisionLSTM2,
+        "preprocess": "v2",
+        "url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th",
+        "kwargs": {
+            "dim": 384,
+            "depth": 12,
+            "legacy_norm": True,
+            "mode": "features",
+            "pooling": None,
+            "output_shape": None,
+            "conv_kind": "2d",
+            "conv_kernel_size": 3,
+            "norm_bias": True,
+            "proj_bias": True,
+        },
+    },
+}
+
+
+def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
+    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
+    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
+    state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight")
+    state_dict["norm.weight"] = state_dict.pop("head.0.weight")
+    state_dict["norm.bias"] = state_dict.pop("head.0.bias")
+    state_dict["head.weight"] = state_dict.pop("head.1.weight")
+    state_dict["head.bias"] = state_dict.pop("head.1.bias")
+    return state_dict
+
+
+def _preprocess_v2_state_dict(
+    state_dict: dict[str, torch.Tensor],
+    *,
+    depth: int,
+    legacy_norm: bool,
+) -> dict[str, torch.Tensor]:
+    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
+    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
+    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
+    state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()}
+    for index in range(depth * 2):
+        if index % 2 == 0:
+            state_dict = {
+                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value
+                for key, value in state_dict.items()
+            }
+        else:
+            state_dict = {
+                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value
+                for key, value in state_dict.items()
+            }
+    state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight")
+    state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias")
+    if legacy_norm:
+        state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight")
+        state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias")
+    state_dict["head.weight"] = state_dict.pop("head.1.weight")
+    state_dict["head.bias"] = state_dict.pop("head.1.bias")
+    return state_dict
+
+
+def _load_pretrained_backbone(model: nn.Module, name: str, spec: dict) -> None:
+    os.environ.setdefault("SSL_CERT_FILE", certifi.where())
+    ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
+    payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu")
+    state_dict = payload["state_dict"]
+    if spec["preprocess"] == "v1":
+        state_dict = _preprocess_v1_state_dict(state_dict)
+    elif spec["preprocess"] == "v2":
+        state_dict = _preprocess_v2_state_dict(
+            state_dict,
+            depth=spec["kwargs"]["depth"],
+            legacy_norm=spec["kwargs"]["legacy_norm"],
+        )
+    else:
+        raise ValueError(f"Unsupported checkpoint preprocessing mode: {spec['preprocess']}")
+    if getattr(model, "head", None) is None:
+        state_dict.pop("head.weight", None)
+        state_dict.pop("head.bias", None)
+    model.load_state_dict(state_dict)
+
+
+class VisionXLSTM(nn.Module):
+    """
+    Thin adapter over upstream VisionLSTM / VisionLSTM2 models.
+
+    The default backbone is `vil2-small`, which matches the requested 384-dim
+    patch features while using the newer ViL v2 implementation.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        backbone_name = getattr(config, "vision_backbone", "vil2-small")
+        pretrained = getattr(config, "pretrained", True)
+        img_size = getattr(config, "img_size", 224)
+        patch_size = getattr(config, "patch_size", 16)
+        in_channels = getattr(config, "in_channels", 3)
+
+        if backbone_name not in VISION_BACKBONES:
+            supported = ", ".join(sorted(VISION_BACKBONES))
+            raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}")
+
+        spec = VISION_BACKBONES[backbone_name]
+        ctor_kwargs = dict(spec["kwargs"])
+        ctor_kwargs["input_shape"] = (in_channels, img_size, img_size)
+        ctor_kwargs["patch_size"] = patch_size
+
+        self.config = config
+        self.backbone_name = backbone_name
+        self.model = spec["ctor"](**ctor_kwargs)
+        self.dim = ctor_kwargs["dim"]
+        self.num_patches = self.model.patch_embed.num_patches
+
+        if pretrained:
+            _load_pretrained_backbone(self.model, backbone_name, spec)
+
+    def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.model(pixel_values)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.forward_features(pixel_values)
+
+
+class VisionProjector(nn.Module):
+    """
+    MLP projector: maps ViL features -> LM embedding space.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        hidden_dim = config.lm_dim * config.hidden_mult
+
+        layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
+        if config.dropout > 0:
+            layers.append(nn.Dropout(config.dropout))
+
+        for _ in range(config.num_layers - 1):
+            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
+            if config.dropout > 0:
+                layers.append(nn.Dropout(config.dropout))
+
+        layers.append(nn.Linear(hidden_dim, config.lm_dim))
+        self.mlp = nn.Sequential(*layers)
+
+    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
+        return self.mlp(vision_features)
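A quick smoke-test sketch for the adapter's stated contract (illustrative; assumes the upstream checkout exists at `external/vision-lstm` and uses `pretrained=False` to skip the checkpoint download):

    import torch
    from model_config import ViLEncoderConfig
    from vision_xlstm import VisionXLSTM

    cfg = ViLEncoderConfig(pretrained=False)
    encoder = VisionXLSTM(cfg)
    tokens = encoder.forward_features(torch.randn(2, 3, 224, 224))
    # Per the module docstring: [B, N, D] patch tokens, e.g. [2, 196, 384] at 224px with 16px patches
    print(tokens.shape, encoder.num_patches, encoder.dim)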
pyproject.toml ADDED
@@ -0,0 +1,24 @@
+[project]
+name = "vil-dlm"
+version = "0.1.0"
+description = "Local smoke-test environment for ViL-DLM"
+requires-python = ">=3.11,<3.12"
+dependencies = [
+    "torch",
+    "torchvision",
+    "transformers",
+    "huggingface_hub",
+    "einops",
+    "numpy",
+    "pillow",
+]
+
+[dependency-groups]
+dev = [
+    "datasets",
+    "accelerate",
+    "trackio",
+]
+
+[tool.uv]
+package = false
train_production.py CHANGED
@@ -1 +1,3 @@
-# Content will be read from sandbox
+"""Compatibility stub for the real training entrypoint in `code/train_production.py`."""
+
+raise SystemExit("Use `python code/train_production.py ...` from the repo root.")