omar-ah committed
Commit 61d4766 · verified · 1 Parent(s): e7b234a

Upload train_production.py

Files changed (1):
  1. code/train_production.py +655 -0
code/train_production.py ADDED
@@ -0,0 +1,655 @@
"""
ViL-DLM Production Training Script
Runs on HF Jobs with GPU

Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain
Stage 2: Full finetune on multimodal instruction data
"""

import os
import sys
import math
import json
import time
import argparse
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
from PIL import Image
from io import BytesIO
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM
from huggingface_hub import HfApi, snapshot_download

import trackio

# ============================================================
# 1. Model Config
# ============================================================

from dataclasses import dataclass, field

@dataclass
class ViLConfig:
    img_size: int = 224
    patch_size: int = 16
    in_channels: int = 3
    dim: int = 384
    depth: int = 24
    conv_kernel_size: int = 3
    bidirectional: bool = True
    dropout: float = 0.0

    @property
    def num_patches(self):
        return (self.img_size // self.patch_size) ** 2


@dataclass
class ProjConfig:
    vil_dim: int = 384
    lm_dim: int = 1024
    hidden_mult: int = 2
    num_layers: int = 2
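
# Note on the defaults above (illustrative): 224 / 16 = 14 patches per side, so 14 * 14 = 196
# visual tokens, each projected from vil_dim=384 up to lm_dim=1024 to match the diffusion LM's
# embedding width.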

# ============================================================
# 2. Vision xLSTM Implementation
# ============================================================

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=384):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x + self.pos_embed


class MLSTMCell(nn.Module):
    """mLSTM with matrix memory and exponential gating"""
    def __init__(self, input_dim, head_dim, num_heads=4):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.total_dim = head_dim * num_heads
        self.scale = 1.0 / math.sqrt(head_dim)

        self.W_q = nn.Linear(input_dim, self.total_dim, bias=True)
        self.W_k = nn.Linear(input_dim, self.total_dim, bias=True)
        self.W_v = nn.Linear(input_dim, self.total_dim, bias=True)
        self.w_f = nn.Linear(input_dim, num_heads, bias=True)
        self.w_i = nn.Linear(input_dim, num_heads, bias=True)
        self.w_o = nn.Linear(input_dim, self.total_dim, bias=True)

    def forward(self, x):
        B, T, D = x.shape

        q = self.W_q(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = (self.W_k(x) * self.scale).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.W_v(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        o = torch.sigmoid(self.w_o(x))

        log_f = F.logsigmoid(self.w_f(x)).permute(0, 2, 1)  # [B, H, T]
        log_i = self.w_i(x).permute(0, 2, 1)  # [B, H, T]

        decay = torch.exp(log_f)  # [B, H, T]
        gate = torch.exp(log_i)  # [B, H, T]

        h_state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
                              device=x.device, dtype=x.dtype)
        n_state = torch.zeros(B, self.num_heads, self.head_dim,
                              device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(T):
            f_t = decay[:, :, t].unsqueeze(-1)
            i_t = gate[:, :, t].unsqueeze(-1)
            k_t = k[:, :, t, :]
            v_t = v[:, :, t, :]
            q_t = q[:, :, t, :]

            h_state = f_t.unsqueeze(-1) * h_state + i_t.unsqueeze(-1) * torch.einsum('bhd,bhe->bhde', v_t, k_t)
            n_state = f_t * n_state + i_t * k_t

            Cq = torch.einsum('bhde,bhe->bhd', h_state, q_t)
            nq = torch.einsum('bhd,bhd->bh', n_state, q_t).unsqueeze(-1).abs().clamp(min=1.0)
            outputs.append(Cq / nq)

        out = torch.stack(outputs, dim=2)  # [B, H, T, D]
        out = out.permute(0, 2, 1, 3).reshape(B, T, self.total_dim)
        return out * o
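
# Shape sanity check for MLSTMCell (illustrative only; mirrors how MLSTMBlock below
# instantiates the cell with dim=384, head_dim=dim//4, num_heads=4):
#   cell = MLSTMCell(input_dim=384, head_dim=96, num_heads=4)
#   y = cell(torch.randn(2, 196, 384))  # -> [2, 196, 384]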


class MLSTMBlock(nn.Module):
    def __init__(self, dim, conv_kernel=3, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.pre_proj = nn.Linear(dim, dim * 3)
        self.conv = nn.Conv2d(dim, dim, kernel_size=conv_kernel, padding=conv_kernel // 2, groups=dim)
        self.mlstm = MLSTMCell(dim, dim // 4, num_heads=4)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, h=None, w=None):
        B, T, D = x.shape
        residual = x
        x = self.norm(x)
        gate_b, gate_c, h_tilde = self.pre_proj(x).chunk(3, dim=-1)

        if h is not None and w is not None:
            h_2d = h_tilde.transpose(1, 2).view(B, D, h, w)
            h_2d = self.conv(h_2d)
            h_tilde = h_2d.view(B, D, T).transpose(1, 2)

        y = torch.sigmoid(gate_b) * h_tilde
        y = self.mlstm(y)
        y = torch.sigmoid(gate_c) * y
        return residual + self.dropout(self.out_proj(y))


class FFNBlock(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult * 2 / 3)
        self.norm = nn.LayerNorm(dim)
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(dim, hidden)
        self.w3 = nn.Linear(hidden, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        r = x
        x = self.norm(x)
        return r + self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))


class VisionXLSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embed = PatchEmbedding(config.img_size, config.patch_size, config.in_channels, config.dim)
        self.h = config.img_size // config.patch_size
        self.w = config.img_size // config.patch_size

        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for _ in range(config.depth):
            self.blocks.append(MLSTMBlock(config.dim, config.conv_kernel_size, config.dropout))
            self.ffns.append(FFNBlock(config.dim, dropout=config.dropout))
        self.final_norm = nn.LayerNorm(config.dim)

    def forward_features(self, pixel_values):
        x = self.patch_embed(pixel_values)
        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            if self.config.bidirectional and i % 2 == 1:
                x = x.flip(1)
                x = block(x, h=self.h, w=self.w)
                x = ffn(x)
                x = x.flip(1)
            else:
                x = block(x, h=self.h, w=self.w)
                x = ffn(x)
        return self.final_norm(x)


class VisionProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_dim = config.lm_dim * config.hidden_mult
        layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
        for _ in range(config.num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
        layers.append(nn.Linear(hidden_dim, config.lm_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
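
# End-to-end vision path with the default configs (illustrative):
#   pixel_values [B, 3, 224, 224] -> PatchEmbedding -> [B, 196, 384]
#   -> 24 x (MLSTMBlock + FFNBlock) -> VisionProjector -> [B, 196, 1024]
# These 196 projected tokens are later prepended to the text embeddings in ViLDLM.forward.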


# ============================================================
# 3. MDLM Scheduler & ViL-DLM Model
# ============================================================

class MDLMScheduler:
    def __init__(self, mask_token_id=151643):
        self.mask_token_id = mask_token_id

    def add_noise(self, input_ids, t):
        B, T = input_ids.shape
        mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
        mask_ratio = mask_ratio.unsqueeze(1).expand(B, T)
        mask = torch.rand(B, T, device=input_ids.device) < mask_ratio
        noisy_ids = input_ids.clone()
        noisy_ids[mask] = self.mask_token_id
        return noisy_ids, mask

    def sample_timesteps(self, batch_size, device):
        return torch.rand(batch_size, device=device)
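
# Masking schedule intuition (illustrative): with mask_ratio = 1 - cos(t * pi / 2),
# t = 0 masks ~0% of tokens and t = 1 masks ~100%; e.g. t = 0.5 gives
# 1 - cos(pi/4) ≈ 0.29, so roughly 29% of text tokens are replaced by mask_token_id.
# The MDLM loss below is computed only on these masked positions.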


class ViLDLM(nn.Module):
    def __init__(self, vil_config, proj_config, lm_path):
        super().__init__()
        self.vil_config = vil_config
        self.vision_encoder = VisionXLSTM(vil_config)
        self.projector = VisionProjector(proj_config)
        self.scheduler = MDLMScheduler()
        self.num_patches = vil_config.num_patches

        # Load diffusion LM
        print(f"Loading diffusion LM from {lm_path}...")
        self.lm = AutoModelForMaskedLM.from_pretrained(
            lm_path, trust_remote_code=True, dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True)
        lm_params = sum(p.numel() for p in self.lm.parameters())
        print(f"Loaded LM: {lm_params/1e6:.1f}M params")

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        B, T = input_ids.shape
        device = input_ids.device
        if labels is None:
            labels = input_ids.clone()

        # Diffusion: mask tokens
        t = self.scheduler.sample_timesteps(B, device)
        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)

        # Encode image
        vision_features = self.vision_encoder.forward_features(pixel_values)
        visual_tokens = self.projector(vision_features)

        # Get text embeddings
        text_embeds = self.lm.model.embed_tokens(noisy_ids)
        visual_tokens = visual_tokens.to(dtype=text_embeds.dtype)

        # Concat [vision | text]
        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
        vis_mask = torch.ones(B, self.num_patches, device=device, dtype=attention_mask.dtype)
        full_mask = torch.cat([vis_mask, attention_mask], dim=1)

        # Forward through LM
        outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_mask)
        text_logits = outputs.logits[:, self.num_patches:, :]

        # MDLM loss on masked positions only
        loss_mask = noise_mask.float()
        if loss_mask.sum() == 0:
            loss = torch.tensor(0.0, device=device, requires_grad=True)
        else:
            logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
            labels_flat = labels.reshape(-1)
            loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none').reshape(B, T)
            loss = (loss_flat * loss_mask).sum() / loss_mask.sum()

        return {'loss': loss, 'logits': text_logits, 'noise_mask': noise_mask, 't': t}

    def freeze_vision(self):
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def freeze_lm(self):
        for p in self.lm.parameters():
            p.requires_grad = False

    def unfreeze_all(self):
        for p in self.parameters():
            p.requires_grad = True

    def count_params(self):
        vil = sum(p.numel() for p in self.vision_encoder.parameters())
        proj = sum(p.numel() for p in self.projector.parameters())
        lm = sum(p.numel() for p in self.lm.parameters())
        train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'vil': vil, 'proj': proj, 'lm': lm, 'total': vil+proj+lm, 'trainable': train}
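
# Sequence layout reminder (illustrative): inputs_embeds = [196 visual tokens | T text tokens],
# so the LM sees num_patches + max_length positions per sample, while the loss only covers
# the masked text positions selected by MDLMScheduler.add_noise.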


# ============================================================
# 4. Dataset
# ============================================================

class LLaVAPretrainDataset(Dataset):
    def __init__(self, tokenizer, max_length=512, img_size=224, max_samples=None):
        print("Loading LLaVA-Pretrain dataset...")
        self.data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
        if max_samples:
            self.data = self.data.select(range(min(max_samples, len(self.data))))
        print(f"Loaded {len(self.data)} samples")
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.img_size = img_size
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Image
        try:
            img = sample['image']
            if isinstance(img, str):
                img = Image.open(img).convert('RGB')
            elif isinstance(img, dict) and 'bytes' in img:
                img = Image.open(BytesIO(img['bytes'])).convert('RGB')
            elif not isinstance(img, Image.Image):
                img = Image.new('RGB', (self.img_size, self.img_size), (128, 128, 128))
            else:
                img = img.convert('RGB')
            img = img.resize((self.img_size, self.img_size), Image.BICUBIC)
            arr = np.array(img).astype(np.float32) / 255.0
            pv = torch.from_numpy(arr).permute(2, 0, 1)
            pv = (pv - self.mean) / self.std
        except Exception:
            pv = torch.zeros(3, self.img_size, self.img_size)

        # Text from conversations
        text = ""
        if 'conversations' in sample:
            parts = []
            for turn in sample['conversations']:
                val = turn.get('value', '').replace('<image>\n', '').replace('<image>', '').strip()
                if val:
                    parts.append(val)
            text = ' '.join(parts)
        if not text:
            text = "Describe this image."

        tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length',
                                truncation=True, return_tensors='pt')

        return {
            'pixel_values': pv,
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'labels': tokens['input_ids'].squeeze(0).clone(),
        }
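
# Each item is a dict of fixed-size tensors (illustrative summary): pixel_values [3, 224, 224]
# normalized with ImageNet mean/std, plus input_ids / attention_mask / labels of length
# max_length (512 by default), so the default DataLoader collation can stack them directly.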


# ============================================================
# 5. Training Loop
# ============================================================

def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Download dLLM model
    print("Downloading dLLM Qwen3-0.6B diffusion model...")
    lm_path = snapshot_download('dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1')

    # Fix the modeling file (remove dllm import in __main__)
    modeling_file = os.path.join(lm_path, "modeling_qwen3.py")
    with open(modeling_file, 'r') as f:
        content = f.read()
    # Replace the __main__ block that imports dllm
    content = content.replace(
        'if __name__ == "__main__":\n import dllm',
        'if __name__ == "__main__":\n pass\n # import dllm'
    )
    # Fix attention_type compatibility
    content = content.replace(
        'attention_mask=causal_mask_mapping[decoder_layer.attention_type]',
        'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))'
    )
    with open(modeling_file, 'w') as f:
        f.write(content)
    print(f"Model downloaded to {lm_path}")

    # Build model
    vil_config = ViLConfig()
    proj_config = ProjConfig()
    model = ViLDLM(vil_config, proj_config, lm_path)

    # Stage setup
    if args.stage == 1:
        print("\n=== STAGE 1: Projector-only training ===")
        model.freeze_vision()
        model.freeze_lm()
    elif args.stage == 2:
        print("\n=== STAGE 2: Full finetune ===")
        model.unfreeze_all()

    params = model.count_params()
    print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M")
    print(f"  ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M")

    model = model.to(device)

    # Enable gradient checkpointing for LM
    if hasattr(model.lm, 'gradient_checkpointing_enable'):
        model.lm.gradient_checkpointing_enable()

    # Dataset
    dataset = LLaVAPretrainDataset(
        tokenizer=model.tokenizer,
        max_length=args.max_length,
        img_size=224,
        max_samples=args.max_samples,
    )

    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True,
    )

    # Optimizer with per-component LR
    param_groups = []
    if args.stage == 1:
        param_groups = [{'params': [p for p in model.projector.parameters() if p.requires_grad],
                         'lr': 1e-3}]
    else:
        param_groups = [
            {'params': [p for p in model.vision_encoder.parameters() if p.requires_grad], 'lr': 2e-6},
            {'params': [p for p in model.projector.parameters() if p.requires_grad], 'lr': 1e-5},
            {'params': [p for p in model.lm.parameters() if p.requires_grad], 'lr': 1e-5},
        ]
    param_groups = [g for g in param_groups if len(g['params']) > 0]

    optimizer = AdamW(param_groups, weight_decay=0.05, betas=(0.9, 0.999))
    total_steps = len(dataloader) * args.epochs // args.grad_accum
    scheduler = CosineAnnealingLR(optimizer, T_max=max(total_steps, 1), eta_min=1e-6)
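
    # Step accounting note (illustrative, with the argparse defaults): the effective batch size is
    # batch_size * grad_accum = 4 * 8 = 32, and the cosine schedule spans
    # total_steps = len(dataloader) * epochs // grad_accum optimizer steps.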

    # Trackio
    trackio.init(name=f"vil-dlm-stage{args.stage}")

    # Training loop
    global_step = 0
    best_loss = float('inf')

    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            pv = batch['pixel_values'].to(device)
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(pixel_values=pv, input_ids=ids, attention_mask=mask, labels=labels)
            loss = outputs['loss'] / args.grad_accum
            loss.backward()

            if (batch_idx + 1) % args.grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                actual_loss = loss.item() * args.grad_accum
                mask_ratio = outputs['noise_mask'].float().mean().item()
                lr = optimizer.param_groups[0]['lr']

                if global_step % 5 == 0:
                    print(f"[E{epoch}] Step {global_step}/{total_steps} | "
                          f"Loss: {actual_loss:.4f} | LR: {lr:.2e} | Mask: {mask_ratio:.1%}")

                trackio.log({
                    'train/loss': actual_loss,
                    'train/lr': lr,
                    'train/mask_ratio': mask_ratio,
                    'train/epoch': epoch,
                    'train/step': global_step,
                })

            epoch_loss += loss.item() * args.grad_accum
            num_batches += 1

        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f}\n")
        trackio.log({'train/epoch_loss': avg_loss, 'train/epoch': epoch})

        # Save checkpoint
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_dir = os.path.join(args.output_dir, f"stage{args.stage}_best")
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.vision_encoder.state_dict(), os.path.join(save_dir, "vision_encoder.pt"))
            torch.save(model.projector.state_dict(), os.path.join(save_dir, "projector.pt"))
            if args.stage >= 2:
                model.lm.save_pretrained(os.path.join(save_dir, "diffusion_lm"))
            print(f"Saved best checkpoint (loss={best_loss:.4f})")

    # Push to Hub
    print("\nPushing to Hub...")
    api = HfApi()
    repo_id = args.hub_model_id

    try:
        api.create_repo(repo_id, exist_ok=True, private=False)
    except Exception as e:
        print(f"Repo note: {e}")

    save_dir = os.path.join(args.output_dir, f"stage{args.stage}_best")

    # Save config + README
    config_dict = {
        'architecture': 'ViL-DLM',
        'components': {
            'vision_encoder': 'Vision-xLSTM-S (ViL-S)',
            'projector': '2-layer MLP',
            'diffusion_lm': 'dLLM Qwen3-0.6B MDLM',
        },
        'vil_dim': 384,
        'lm_dim': 1024,
        'num_patches': 196,
        'training_stage': args.stage,
        'best_loss': best_loss,
        'total_params_M': params['total'] / 1e6,
        'trainable_params_M': params['trainable'] / 1e6,
        'based_on': [
            'Vision-LSTM (arxiv:2406.04303)',
            'dLLM (arxiv:2602.22661)',
            'LLaDA-V (arxiv:2505.16933)',
            'LFM2 (arxiv:2511.23404)',
        ],
        'teacher': 'google/gemma-4-E2B-it (planned for stage 3)',
    }
    with open(os.path.join(save_dir, "model_config.json"), 'w') as f:
        json.dump(config_dict, f, indent=2)

    readme = f"""---
license: apache-2.0
tags:
- vision-language
- diffusion
- xlstm
- vision-lstm
- masked-diffusion
- mdlm
language: en
pipeline_tag: image-text-to-text
---

# ViL-DLM: Vision xLSTM Diffusion Language Model

**The first vision-language model combining Vision xLSTM with a diffusion language backbone.**

## Architecture

| Component | Model | Params |
|-----------|-------|--------|
| Vision Encoder | **Vision-xLSTM-S (ViL-S)** | ~57M |
| Projector | 2-layer MLP (GELU) | ~7M |
| Language Backbone | **dLLM Qwen3-0.6B (MDLM)** | ~596M |
| **Total** | | **~660M** |

### Why This Combination?

1. **ViL (Vision xLSTM)** — O(N) linear complexity vision encoder vs ViT's O(N²). Uses alternating bidirectional mLSTM blocks with exponential gating and Conv2D for spatial context. Based on [arxiv:2406.04303](https://arxiv.org/abs/2406.04303).

2. **Diffusion Language Model** — Non-autoregressive text generation via masked denoising. Bidirectional attention enables richer contextual understanding. Based on [dLLM/MDLM](https://arxiv.org/abs/2602.22661).

3. **Knowledge Distillation** (Stage 3) — Planned distillation from [Gemma 4 E2B](https://huggingface.co/google/gemma-4-E2B-it) using LFM2-style Decoupled Top-K distillation.

## Training Recipe

Inspired by LLaDA-V, LaViDa, LFM2, and Mistral/Pixtral:

| Stage | What's Trained | Dataset | LR |
|-------|---------------|---------|-----|
| 1 | Projector only | LLaVA-Pretrain (558K) | 1e-3 |
| 2 | Full model | The Cauldron (multimodal) | ViL:2e-6, Proj:1e-5, LM:1e-5 |
| 3 | + KD from Gemma 4 E2B | Mixed | + Top-K KD (α=0.5, T=2, K=32) |

**Current stage: {args.stage} | Best loss: {best_loss:.4f}**

## Novelty

This is (to our knowledge) the **first published model** combining:
- Vision xLSTM as a vision encoder in a VLM
- A discrete masked diffusion language model backbone
- Multi-stage training with knowledge distillation from an AR multimodal teacher

## References

- [Vision-LSTM](https://arxiv.org/abs/2406.04303) — Alkin et al., 2024
- [dLLM](https://arxiv.org/abs/2602.22661) — Berkeley, 2025
- [MDLM](https://arxiv.org/abs/2406.07524) — Kuleshov group, NeurIPS 2024
- [LLaDA-V](https://arxiv.org/abs/2505.16933) — GSAI-ML, 2025
- [LFM2](https://arxiv.org/abs/2511.23404) — Liquid AI, 2025
- [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) — Google, 2026
"""

    with open(os.path.join(save_dir, "README.md"), 'w') as f:
        f.write(readme)

    api.upload_folder(folder_path=save_dir, repo_id=repo_id,
                      commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})")
    print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}")
    print("Training complete!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./vil-dlm-output")
    parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B")
    args = parser.parse_args()

    train(args)