omar-ah committed on
Commit 0d77b0a · 1 Parent(s): 60684e7

Implement stage-aware real-run training pipeline

Files changed (5)
  1. README.md +18 -10
  2. code/model_config.py +21 -9
  3. code/train_production.py +1075 -549
  4. code/vil_dlm_model.py +46 -72
  5. pyproject.toml +1 -0
README.md CHANGED
@@ -48,9 +48,9 @@ pipeline_tag: image-text-to-text
48
  - **Key change from AR**: replaces causal attention mask with bidirectional padding-only mask
49
  - Weighted cross-entropy loss on masked positions only (MDLM objective)
50
 
51
- ### Knowledge Distillation (Planned Stage 3)
52
  - Teacher: [Gemma 4 E2B](https://huggingface.co/google/gemma-4-E2B-it) (5.1B params, ~2B effective)
53
- - **Decoupled Top-K Distillation** (from [LFM2](https://arxiv.org/abs/2511.23404)): only align top-32 teacher logits
54
  - Temperature τ=2.0, α_KD=0.5 (50% diffusion loss + 50% KD loss)
55
 
56
  ## Training Recipe
@@ -61,7 +61,7 @@ Multi-stage training inspired by LLaDA-V, LaViDa, LFM2, and Mistral/Pixtral:
61
  |-------|-------------------|---------|---------------|--------|
62
  | 1 | Projector only (ViL & LM frozen) | [LLaVA-Pretrain](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain) (558K) | 1e-3 | 1-2 |
63
  | 2 | Full model (all components) | [The Cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) | ViL:2e-6, Proj:1e-5, LM:1e-5 | 3 |
64
- | 3 | + KD from Gemma 4 E2B | Mixed instruction data | + Top-K KD (α=0.5) | 2 |
65
 
66
  ### Efficiency Tricks Applied
67
  - **Per-component learning rates** (LLaDA-V recipe): vision encoder gets 5× lower LR
@@ -82,20 +82,28 @@ This is a genuinely **unexplored frontier** in the literature:
82
  ## Running Training
83
 
84
  ```bash
85
- # Stage 1: Projector alignment (2-4 hours on A10G)
86
- python train_production.py --stage 1 --epochs 2 --batch_size 4 --grad_accum 8
87
 
88
- # Stage 2: Full finetune (8-12 hours on A10G)
89
- python train_production.py --stage 2 --epochs 3 --batch_size 2 --grad_accum 16
90
 
91
- # Quick test (10 min, small subset)
92
- python train_production.py --stage 1 --epochs 1 --batch_size 2 --grad_accum 1 --max_samples 100
93
  ```
94
 
95
  ### Hardware Requirements
96
  - **Stage 1**: A10G (24GB) or T4 (16GB) — only projector gradients (~7M params)
97
  - **Stage 2**: A10G (24GB) recommended — full model gradients (~660M params)
98
- - **Stage 3**: A100 (80GB) recommended — teacher model (Gemma 4 E2B) + student
99
 
100
  ### Dependencies
101
  ```
 
48
  - **Key change from AR**: replaces causal attention mask with bidirectional padding-only mask
49
  - Weighted cross-entropy loss on masked positions only (MDLM objective; see the sketch below)
50
 
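To make the MDLM objective concrete, here is a minimal, self-contained sketch of the cosine masking step and the masked-positions-only cross-entropy, mirroring `MDLMScheduler.add_noise` and the loss in `ViLDLM.forward`; the shapes and the `mask_token_id` value are illustrative placeholders:

```python
import math
import torch
import torch.nn.functional as F

B, T, V = 2, 16, 1000      # illustrative batch, sequence, and vocab sizes
mask_token_id = 0          # placeholder; the model uses the tokenizer's pad token

input_ids = torch.randint(1, V, (B, T))
t = torch.rand(B)                                  # timestep ~ U(0, 1)
mask_ratio = 1.0 - torch.cos(t * math.pi / 2)      # cosine schedule, shape [B]
noise_mask = torch.rand(B, T) < mask_ratio.unsqueeze(1)
noisy_ids = input_ids.clone()
noisy_ids[noise_mask] = mask_token_id              # corrupt the sampled positions

logits = torch.randn(B, T, V)                      # stand-in for the model's output
per_token = F.cross_entropy(
    logits.reshape(-1, V), input_ids.reshape(-1), reduction="none"
).reshape(B, T)
# Weighted cross-entropy on masked positions only
loss = (per_token * noise_mask.float()).sum() / noise_mask.float().sum().clamp(min=1.0)
```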
51
+ ### Knowledge Distillation (Stage 3)
52
  - Teacher: [Gemma 4 E2B](https://huggingface.co/google/gemma-4-E2B-it) (5.1B params, ~2B effective)
53
+ - **Sparse cross-tokenizer distillation**: prepare a teacher-scored candidate bank in the student token space, then blend a sparse KL term with the diffusion loss (see the sketch below)
54
  - Temperature τ=2.0, α_KD=0.5 (50% diffusion loss + 50% KD loss)
55
 
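Below is a minimal sketch of how the blended Stage 3 loss could be assembled from a cached candidate bank of K teacher-scored token ids per supervised position. The function name and the gather-based sparse KL are illustrative assumptions, not the exact implementation:

```python
import torch
import torch.nn.functional as F

def sparse_kd_loss(
    student_logits: torch.Tensor,       # [B, T, V] student logits
    cand_ids: torch.Tensor,             # [B, T, K] candidate ids in the student vocab
    cand_teacher_logits: torch.Tensor,  # [B, T, K] cached teacher scores for those ids
    diffusion_loss: torch.Tensor,
    tau: float = 2.0,
    alpha_kd: float = 0.5,
) -> torch.Tensor:
    # Student log-probs, restricted to the teacher's candidate set
    student_lp = F.log_softmax(student_logits / tau, dim=-1).gather(-1, cand_ids)
    # Teacher distribution renormalized over the K candidates
    teacher_lp = F.log_softmax(cand_teacher_logits / tau, dim=-1)
    kd = F.kl_div(student_lp, teacher_lp, log_target=True, reduction="batchmean") * tau**2
    # alpha_kd = 0.5 gives the 50% diffusion / 50% KD blend
    return (1.0 - alpha_kd) * diffusion_loss + alpha_kd * kd
```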
56
  ## Training Recipe
 
61
  |-------|-------------------|---------|---------------|--------|
62
  | 1 | Projector only (ViL & LM frozen) | [LLaVA-Pretrain](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain) (558K) | 1e-3 | 1-2 |
63
  | 2 | Full model (all components) | [The Cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) | ViL:2e-6, Proj:1e-5, LM:1e-5 | 3 |
64
+ | 3 | + KD from Gemma 4 E2B | Stage 2 data mix + cached teacher bank | Sparse cross-tokenizer KD (α=0.5) | 2 |
65
 
66
  ### Efficiency Tricks Applied
67
  - **Per-component learning rates** (LLaDA-V recipe): vision encoder gets a 5× lower LR (see the sketch below)
 
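A minimal sketch of the Stage 2 optimizer setup, mirroring `get_optimizer` in `code/train_production.py` (assumes a constructed `ViLDLM` instance named `model`):

```python
from torch.optim import AdamW

param_groups = [
    # Vision encoder: 5x lower LR than the projector / LM (LLaDA-V recipe)
    {"params": model.vision_encoder.parameters(), "lr": 2e-6},
    {"params": model.projector.parameters(), "lr": 1e-5},
    {"params": model.lm.parameters(), "lr": 1e-5},
]
optimizer = AdamW(param_groups, weight_decay=0.05, betas=(0.9, 0.999))
```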
82
  ## Running Training
83
 
84
  ```bash
85
+ # Stage 1: projector-only alignment
86
+ python code/train_production.py --stage 1 --require_cuda --epochs 1 --batch_size 8 --grad_accum 4
87
 
88
+ # Stage 2: full-model finetune on the balanced Cauldron mix
89
+ python code/train_production.py --stage 2 --require_cuda --epochs 3 --batch_size 2 --grad_accum 16
90
 
91
+ # Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint
92
+ python code/train_production.py --stage 3a --require_cuda --resume_from ./vil-dlm-output/stage2_best --teacher_batch_size 2
93
+
94
+ # Stage 3b: sparse KD training from the cached teacher bank
95
+ python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --epochs 2 --batch_size 2 --grad_accum 16
96
+
97
+ # Cheap validation gate for any stage
98
+ python code/train_production.py --stage 1 --require_cuda --dry_run_batches 1 --max_samples 8
99
  ```
100
 
101
+ Training now saves checkpoints locally by default. Add `--push_to_hub` only when you want to publish artifacts.
102
+
103
  ### Hardware Requirements
104
  - **Stage 1**: A10G (24GB) or T4 (16GB) — only projector gradients (~7M params)
105
  - **Stage 2**: A10G (24GB) recommended — full model gradients (~660M params)
106
+ - **Stage 3**: H100 / A100 (80GB) recommended — Gemma 4 teacher bank prep + student distillation
107
 
108
  ### Dependencies
109
  ```
code/model_config.py CHANGED
@@ -62,7 +62,9 @@ class DistillationConfig:
62
  temperature: float = 2.0 # KD temperature
63
  alpha_kd: float = 0.5 # weight for KD loss vs diffusion loss
64
  alpha_vision_kd: float = 0.3 # weight for vision feature distillation
65
- top_k_logits: int = 32 # LFM2-style top-K distillation
66
 
67
 
68
  @dataclass
@@ -96,34 +98,44 @@ class TrainingConfig:
96
  # Data
97
  pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain" # Stage 1: 558K
98
  finetune_dataset: str = "HuggingFaceM4/the_cauldron" # Stage 2: rich multimodal
99
-
100
  # Output
101
  output_dir: str = "./vil-dlm-output"
102
  hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
103
- push_to_hub: bool = True
104
 
105
  # Stages
106
- stage: int = 1 # 1 = projector only, 2 = full finetune, 3 = + distillation
107
 
108
 
109
- def get_config(stage: int = 1) -> TrainingConfig:
110
  config = TrainingConfig()
111
  config.stage = stage
112
 
113
- if stage == 1:
114
  # Stage 1: Train projector only (ViL frozen, LM frozen)
115
  config.learning_rate = 1e-3
116
  config.num_epochs = 1
117
  config.per_device_train_batch_size = 8
118
  config.gradient_accumulation_steps = 4
119
- elif stage == 2:
120
  # Stage 2: Full model finetune (ViL + projector + LM)
121
  config.learning_rate = 1e-5
122
  config.vil_learning_rate = 2e-6
123
  config.projector_learning_rate = 1e-5
124
  config.num_epochs = 3
125
- elif stage == 3:
126
- # Stage 3: + Distillation from Gemma 4
127
  config.learning_rate = 1e-5
128
  config.num_epochs = 2
129
  config.distillation.alpha_kd = 0.5
 
62
  temperature: float = 2.0 # KD temperature
63
  alpha_kd: float = 0.5 # weight for KD loss vs diffusion loss
64
  alpha_vision_kd: float = 0.3 # weight for vision feature distillation
65
+ kd_top_k: int = 8 # sparse cross-tokenizer candidate set size
66
+ kd_positions_per_sample: int = 16
67
+ teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"
68
 
69
 
70
  @dataclass
 
98
  # Data
99
  pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain" # Stage 1: 558K
100
  finetune_dataset: str = "HuggingFaceM4/the_cauldron" # Stage 2: rich multimodal
101
+ finetune_dataset_configs: List[str] = field(default_factory=lambda: [
102
+ "ai2d",
103
+ "vqav2",
104
+ "a_okvqa",
105
+ "textvqa",
106
+ "docvqa",
107
+ "chartqa",
108
+ "textcaps",
109
+ "screen2words",
110
+ ])
111
+
112
  # Output
113
  output_dir: str = "./vil-dlm-output"
114
  hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
115
+ push_to_hub: bool = False
116
 
117
  # Stages
118
+ stage: str = "1" # 1, 2, 3a, 3b
119
 
120
 
121
+ def get_config(stage: str = "1") -> TrainingConfig:
122
  config = TrainingConfig()
123
  config.stage = stage
124
 
125
+ if stage == "1":
126
  # Stage 1: Train projector only (ViL frozen, LM frozen)
127
  config.learning_rate = 1e-3
128
  config.num_epochs = 1
129
  config.per_device_train_batch_size = 8
130
  config.gradient_accumulation_steps = 4
131
+ elif stage == "2":
132
  # Stage 2: Full model finetune (ViL + projector + LM)
133
  config.learning_rate = 1e-5
134
  config.vil_learning_rate = 2e-6
135
  config.projector_learning_rate = 1e-5
136
  config.num_epochs = 3
137
+ elif stage in {"3a", "3b"}:
138
+ # Stage 3: sparse cross-tokenizer distillation with Gemma 4
139
  config.learning_rate = 1e-5
140
  config.num_epochs = 2
141
  config.distillation.alpha_kd = 0.5
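
For reference, `get_config` is now keyed by stage strings rather than integers; a short usage sketch (the flat import path assumes the script runs from `code/`):

```python
from model_config import get_config

config = get_config("3b")                      # valid stages: "1", "2", "3a", "3b"
print(config.distillation.kd_top_k)            # 8 candidate tokens per position
print(config.distillation.teacher_cache_dir)   # ./vil-dlm-output/teacher-cache
print(config.finetune_dataset_configs[:3])     # ['ai2d', 'vqav2', 'a_okvqa']
```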
code/train_production.py CHANGED
@@ -1,42 +1,61 @@
1
  """
2
- ViL-DLM Production Training Script
3
- Runs on HF Jobs with GPU
4
 
5
- Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain
6
- Stage 2: Full finetune on multimodal instruction data
7
  """
8
 
9
- import os
10
- import sys
11
- import math
12
  import json
13
  import time
14
- import argparse
15
  from pathlib import Path
16
- from typing import Dict, Optional
17
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
- from torch.utils.data import Dataset, DataLoader
22
  from torch.optim import AdamW
23
  from torch.optim.lr_scheduler import CosineAnnealingLR
24
 
25
- import numpy as np
26
- from PIL import Image
27
- from io import BytesIO
28
- from datasets import load_dataset
29
- from transformers import AutoTokenizer, AutoModelForMaskedLM
30
- from huggingface_hub import HfApi, snapshot_download
31
- from vision_xlstm import VisionProjector as UpstreamVisionProjector, VisionXLSTM as UpstreamVisionXLSTM
32
 
33
- import trackio
34
 
35
- # ============================================================
36
- # 1. Model Config
37
- # ============================================================
38
 
39
- from dataclasses import dataclass, field
40
 
41
  @dataclass
42
  class ViLConfig:
@@ -50,9 +69,9 @@ class ViLConfig:
50
  conv_kernel_size: int = 3
51
  bidirectional: bool = True
52
  dropout: float = 0.0
53
-
54
  @property
55
- def num_patches(self):
56
  return (self.img_size // self.patch_size) ** 2
57
 
58
 
@@ -66,10 +85,10 @@ class ProjConfig:
66
 
67
 
68
  class _TrackioShim:
69
- def __init__(self):
70
  self.enabled = False
71
 
72
- def init(self, name: str, project: str = "vil-dlm"):
73
  try:
74
  trackio.init(name=name, project=project)
75
  self.enabled = True
@@ -77,7 +96,7 @@ class _TrackioShim:
77
  self.enabled = False
78
  print(f"Trackio disabled: {exc}")
79
 
80
- def log(self, payload: dict):
81
  if not self.enabled:
82
  return
83
  try:
@@ -86,610 +105,1097 @@ class _TrackioShim:
86
  self.enabled = False
87
  print(f"Trackio logging disabled after error: {exc}")
88
 
89
- # ============================================================
90
- # 2. Vision xLSTM Implementation
91
- # ============================================================
92
-
93
- class PatchEmbedding(nn.Module):
94
- def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=384):
95
- super().__init__()
96
- self.num_patches = (img_size // patch_size) ** 2
97
- self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
98
- self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
99
- nn.init.trunc_normal_(self.pos_embed, std=0.02)
100
-
101
- def forward(self, x):
102
- x = self.proj(x).flatten(2).transpose(1, 2)
103
- return x + self.pos_embed
104
-
105
-
106
- class MLSTMCell(nn.Module):
107
- """mLSTM with matrix memory and exponential gating"""
108
- def __init__(self, input_dim, head_dim, num_heads=4):
109
- super().__init__()
110
- self.head_dim = head_dim
111
- self.num_heads = num_heads
112
- self.total_dim = head_dim * num_heads
113
- self.scale = 1.0 / math.sqrt(head_dim)
114
-
115
- self.W_q = nn.Linear(input_dim, self.total_dim, bias=True)
116
- self.W_k = nn.Linear(input_dim, self.total_dim, bias=True)
117
- self.W_v = nn.Linear(input_dim, self.total_dim, bias=True)
118
- self.w_f = nn.Linear(input_dim, num_heads, bias=True)
119
- self.w_i = nn.Linear(input_dim, num_heads, bias=True)
120
- self.w_o = nn.Linear(input_dim, self.total_dim, bias=True)
121
-
122
- def forward(self, x):
123
- B, T, D = x.shape
124
-
125
- q = self.W_q(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
126
- k = (self.W_k(x) * self.scale).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
127
- v = self.W_v(x).view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
128
- o = torch.sigmoid(self.w_o(x))
129
-
130
- log_f = F.logsigmoid(self.w_f(x)).permute(0, 2, 1) # [B, H, T]
131
- log_i = self.w_i(x).permute(0, 2, 1) # [B, H, T]
132
-
133
- decay = torch.exp(log_f) # [B, H, T]
134
- gate = torch.exp(log_i) # [B, H, T]
135
-
136
- h_state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
137
- device=x.device, dtype=x.dtype)
138
- n_state = torch.zeros(B, self.num_heads, self.head_dim,
139
- device=x.device, dtype=x.dtype)
140
-
141
- outputs = []
142
- for t in range(T):
143
- f_t = decay[:, :, t].unsqueeze(-1)
144
- i_t = gate[:, :, t].unsqueeze(-1)
145
- k_t = k[:, :, t, :]
146
- v_t = v[:, :, t, :]
147
- q_t = q[:, :, t, :]
148
-
149
- h_state = f_t.unsqueeze(-1) * h_state + i_t.unsqueeze(-1) * torch.einsum('bhd,bhe->bhde', v_t, k_t)
150
- n_state = f_t * n_state + i_t * k_t
151
-
152
- Cq = torch.einsum('bhde,bhe->bhd', h_state, q_t)
153
- nq = torch.einsum('bhd,bhd->bh', n_state, q_t).unsqueeze(-1).abs().clamp(min=1.0)
154
- outputs.append(Cq / nq)
155
-
156
- out = torch.stack(outputs, dim=2) # [B, H, T, D]
157
- out = out.permute(0, 2, 1, 3).reshape(B, T, self.total_dim)
158
- return out * o
159
-
160
-
161
- class MLSTMBlock(nn.Module):
162
- def __init__(self, dim, conv_kernel=3, dropout=0.0):
163
- super().__init__()
164
- self.norm = nn.LayerNorm(dim)
165
- self.pre_proj = nn.Linear(dim, dim * 3)
166
- self.conv = nn.Conv2d(dim, dim, kernel_size=conv_kernel, padding=conv_kernel // 2, groups=dim)
167
- self.mlstm = MLSTMCell(dim, dim // 4, num_heads=4)
168
- self.out_proj = nn.Linear(dim, dim)
169
- self.dropout = nn.Dropout(dropout)
170
-
171
- def forward(self, x, h=None, w=None):
172
- B, T, D = x.shape
173
- residual = x
174
- x = self.norm(x)
175
- gate_b, gate_c, h_tilde = self.pre_proj(x).chunk(3, dim=-1)
176
-
177
- if h is not None and w is not None:
178
- h_2d = h_tilde.transpose(1, 2).view(B, D, h, w)
179
- h_2d = self.conv(h_2d)
180
- h_tilde = h_2d.view(B, D, T).transpose(1, 2)
181
-
182
- y = torch.sigmoid(gate_b) * h_tilde
183
- y = self.mlstm(y)
184
- y = torch.sigmoid(gate_c) * y
185
- return residual + self.dropout(self.out_proj(y))
186
-
187
-
188
- class FFNBlock(nn.Module):
189
- def __init__(self, dim, mult=4, dropout=0.0):
190
- super().__init__()
191
- hidden = int(dim * mult * 2 / 3)
192
- self.norm = nn.LayerNorm(dim)
193
- self.w1 = nn.Linear(dim, hidden)
194
- self.w2 = nn.Linear(dim, hidden)
195
- self.w3 = nn.Linear(hidden, dim)
196
- self.dropout = nn.Dropout(dropout)
197
-
198
- def forward(self, x):
199
- r = x
200
- x = self.norm(x)
201
- return r + self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))
202
-
203
-
204
- class VisionXLSTM(nn.Module):
205
- def __init__(self, config):
206
- super().__init__()
207
- self.config = config
208
- self.patch_embed = PatchEmbedding(config.img_size, config.patch_size, config.in_channels, config.dim)
209
- self.h = config.img_size // config.patch_size
210
- self.w = config.img_size // config.patch_size
211
-
212
- self.blocks = nn.ModuleList()
213
- self.ffns = nn.ModuleList()
214
- for _ in range(config.depth):
215
- self.blocks.append(MLSTMBlock(config.dim, config.conv_kernel_size, config.dropout))
216
- self.ffns.append(FFNBlock(config.dim, dropout=config.dropout))
217
- self.final_norm = nn.LayerNorm(config.dim)
218
-
219
- def forward_features(self, pixel_values):
220
- x = self.patch_embed(pixel_values)
221
- for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
222
- if self.config.bidirectional and i % 2 == 1:
223
- x = x.flip(1)
224
- x = block(x, h=self.h, w=self.w)
225
- x = ffn(x)
226
- x = x.flip(1)
227
- else:
228
- x = block(x, h=self.h, w=self.w)
229
- x = ffn(x)
230
- return self.final_norm(x)
231
-
232
-
233
- class VisionProjector(nn.Module):
234
- def __init__(self, config):
235
- super().__init__()
236
- hidden_dim = config.lm_dim * config.hidden_mult
237
- layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
238
- for _ in range(config.num_layers - 1):
239
- layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
240
- layers.append(nn.Linear(hidden_dim, config.lm_dim))
241
- self.mlp = nn.Sequential(*layers)
242
-
243
- def forward(self, x):
244
- return self.mlp(x)
245
-
246
-
247
- # ============================================================
248
- # 3. MDLM Scheduler & ViL-DLM Model
249
- # ============================================================
250
 
251
  class MDLMScheduler:
252
- def __init__(self, mask_token_id=151643):
253
  self.mask_token_id = mask_token_id
254
-
255
- def add_noise(self, input_ids, t):
256
- B, T = input_ids.shape
257
  mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
258
- mask_ratio = mask_ratio.unsqueeze(1).expand(B, T)
259
- mask = torch.rand(B, T, device=input_ids.device) < mask_ratio
260
  noisy_ids = input_ids.clone()
261
  noisy_ids[mask] = self.mask_token_id
262
  return noisy_ids, mask
263
-
264
- def sample_timesteps(self, batch_size, device):
265
  return torch.rand(batch_size, device=device)
266
 
267
 
268
  class ViLDLM(nn.Module):
269
- def __init__(self, vil_config, proj_config, lm_path):
270
  super().__init__()
271
  self.vil_config = vil_config
272
  self.vision_encoder = UpstreamVisionXLSTM(vil_config)
273
  self.projector = UpstreamVisionProjector(proj_config)
274
- self.scheduler = MDLMScheduler()
275
- self.num_patches = vil_config.num_patches
276
-
277
- # Load diffusion LM
278
- print(f"Loading diffusion LM from {lm_path}...")
279
  self.lm = AutoModelForMaskedLM.from_pretrained(
280
- lm_path, trust_remote_code=True, torch_dtype=torch.bfloat16
281
  )
282
  self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True)
283
- lm_params = sum(p.numel() for p in self.lm.parameters())
284
- print(f"Loaded LM: {lm_params/1e6:.1f}M params")
285
-
286
- def forward(self, pixel_values, input_ids, attention_mask, labels=None):
287
- B, T = input_ids.shape
288
- device = input_ids.device
289
- if labels is None:
290
- labels = input_ids.clone()
291
-
292
- # Diffusion: mask tokens
293
- t = self.scheduler.sample_timesteps(B, device)
294
- noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
295
-
296
- # Encode image
297
  vision_features = self.vision_encoder.forward_features(pixel_values)
298
  visual_tokens = self.projector(vision_features)
299
-
300
- # Get text embeddings
301
- text_embeds = self.lm.model.embed_tokens(noisy_ids)
302
  visual_tokens = visual_tokens.to(dtype=text_embeds.dtype)
303
-
304
- # Concat [vision | text]
305
  inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
306
- vis_mask = torch.ones(B, self.num_patches, device=device, dtype=attention_mask.dtype)
307
- full_mask = torch.cat([vis_mask, attention_mask], dim=1)
308
-
309
- # Forward through LM
310
- outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_mask)
311
- text_logits = outputs.logits[:, self.num_patches:, :]
312
-
313
- # MDLM loss on masked positions only
314
- loss_mask = noise_mask.float()
315
- if loss_mask.sum() == 0:
316
  loss = torch.tensor(0.0, device=device, requires_grad=True)
317
  else:
318
  logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
319
  labels_flat = labels.reshape(-1)
320
- loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none').reshape(B, T)
321
- loss = (loss_flat * loss_mask).sum() / loss_mask.sum()
322
-
323
- return {'loss': loss, 'logits': text_logits, 'noise_mask': noise_mask, 't': t}
324
-
325
- def freeze_vision(self):
326
- for p in self.vision_encoder.parameters():
327
- p.requires_grad = False
328
-
329
- def freeze_lm(self):
330
- for p in self.lm.parameters():
331
- p.requires_grad = False
332
-
333
- def unfreeze_all(self):
334
- for p in self.parameters():
335
- p.requires_grad = True
336
-
337
- def count_params(self):
338
  vil = sum(p.numel() for p in self.vision_encoder.parameters())
339
  proj = sum(p.numel() for p in self.projector.parameters())
340
  lm = sum(p.numel() for p in self.lm.parameters())
341
- train = sum(p.numel() for p in self.parameters() if p.requires_grad)
342
- return {'vil': vil, 'proj': proj, 'lm': lm, 'total': vil+proj+lm, 'trainable': train}
343
344
 
345
- # ============================================================
346
- # 4. Dataset
347
- # ============================================================
348
 
349
- class LLaVAPretrainDataset(Dataset):
350
- def __init__(self, tokenizer, max_length=512, img_size=224, max_samples=None):
351
- print("Loading LLaVA-Pretrain dataset...")
352
- self.dataset_root = None
353
- try:
354
- self.data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
355
- except Exception as exc:
356
- print(f"Primary dataset loader failed ({exc}). Falling back to direct JSON loading...")
357
- self.dataset_root = snapshot_download(
358
- "liuhaotian/LLaVA-Pretrain",
359
- repo_type="dataset",
360
- allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
361
- )
362
- json_path = os.path.join(self.dataset_root, "blip_laion_cc_sbu_558k.json")
363
- self.data = load_dataset("json", data_files={"train": json_path}, split="train")
364
- if max_samples:
365
- self.data = self.data.select(range(min(max_samples, len(self.data))))
366
- print(f"Loaded {len(self.data)} samples")
367
  self.tokenizer = tokenizer
368
  self.max_length = max_length
369
  self.img_size = img_size
370
- self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
371
- self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
372
-
373
- def __len__(self):
374
- return len(self.data)
375
-
376
- def __getitem__(self, idx):
377
- sample = self.data[idx]
378
-
379
- # Image
380
- try:
381
- img = sample['image']
382
- if isinstance(img, str):
383
- candidate_paths = [img]
384
- if self.dataset_root and not os.path.isabs(img):
385
- candidate_paths.extend([
386
- os.path.join(self.dataset_root, img),
387
- os.path.join(self.dataset_root, "images", img),
388
- ])
389
- image_path = next((path for path in candidate_paths if os.path.exists(path)), img)
390
- img = Image.open(image_path).convert('RGB')
391
- elif isinstance(img, dict) and 'bytes' in img:
392
- img = Image.open(BytesIO(img['bytes'])).convert('RGB')
393
- elif not isinstance(img, Image.Image):
394
- img = Image.new('RGB', (self.img_size, self.img_size), (128, 128, 128))
395
- else:
396
- img = img.convert('RGB')
397
- img = img.resize((self.img_size, self.img_size), Image.BICUBIC)
398
- arr = np.array(img).astype(np.float32) / 255.0
399
- pv = torch.from_numpy(arr).permute(2, 0, 1)
400
- pv = (pv - self.mean) / self.std
401
- except Exception:
402
- pv = torch.zeros(3, self.img_size, self.img_size)
403
-
404
- # Text from conversations
405
  text = ""
406
- if 'conversations' in sample:
407
  parts = []
408
- for turn in sample['conversations']:
409
- val = turn.get('value', '').replace('<image>\n', '').replace('<image>', '').strip()
410
  if val:
411
  parts.append(val)
412
- text = ' '.join(parts)
413
- elif sample.get('blip_caption'):
414
- text = sample['blip_caption'].strip()
415
  if not text:
416
  text = "Describe this image."
417
-
418
- tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length',
419
- truncation=True, return_tensors='pt')
420
-
421
  return {
422
- 'pixel_values': pv,
423
- 'input_ids': tokens['input_ids'].squeeze(0),
424
- 'attention_mask': tokens['attention_mask'].squeeze(0),
425
- 'labels': tokens['input_ids'].squeeze(0).clone(),
 
426
  }
427
 
428
 
429
- # ============================================================
430
- # 5. Training Loop
431
- # ============================================================
432
 
433
- def train(args):
434
- tracker = _TrackioShim()
435
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
436
  print(f"Device: {device}")
437
  if torch.cuda.is_available():
438
- print(f"GPU: {torch.cuda.get_device_name()}")
439
  print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
440
-
441
- # Download dLLM model
442
- print("Downloading dLLM Qwen3-0.6B diffusion model...")
443
- lm_path = snapshot_download('dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1')
444
-
445
- # Fix the modeling file (remove dllm import in __main__)
446
- modeling_file = os.path.join(lm_path, "modeling_qwen3.py")
447
- with open(modeling_file, 'r') as f:
448
- content = f.read()
449
- # Replace the __main__ block that imports dllm
450
- content = content.replace(
451
- 'if __name__ == "__main__":\n import dllm',
452
- 'if __name__ == "__main__":\n pass\n # import dllm'
453
  )
454
- # Fix attention_type compatibility
455
- content = content.replace(
456
- 'attention_mask=causal_mask_mapping[decoder_layer.attention_type]',
457
- 'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))'
458
  )
459
- with open(modeling_file, 'w') as f:
460
- f.write(content)
461
- print(f"Model downloaded to {lm_path}")
462
-
463
- # Build model
464
  vil_config = ViLConfig()
465
  proj_config = ProjConfig()
466
  model = ViLDLM(vil_config, proj_config, lm_path)
467
-
468
- # Stage setup
469
- if args.stage == 1:
470
- print("\n=== STAGE 1: Projector-only training ===")
471
- model.freeze_vision()
472
- model.freeze_lm()
473
- elif args.stage == 2:
474
- print("\n=== STAGE 2: Full finetune ===")
475
- model.unfreeze_all()
476
-
477
  params = model.count_params()
478
  print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M")
479
  print(f" ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M")
480
-
481
  model = model.to(device)
482
-
483
- # Enable gradient checkpointing for LM
484
- if hasattr(model.lm, 'gradient_checkpointing_enable'):
485
  model.lm.gradient_checkpointing_enable()
486
-
487
- # Dataset
488
- dataset = LLaVAPretrainDataset(
489
- tokenizer=model.tokenizer,
490
- max_length=args.max_length,
491
- img_size=224,
492
- max_samples=args.max_samples,
493
- )
494
-
495
- dataloader = DataLoader(
496
- dataset, batch_size=args.batch_size, shuffle=True,
497
- num_workers=4, pin_memory=True, drop_last=True,
498
  )
499
-
500
- # Optimizer with per-component LR
501
- param_groups = []
502
- if args.stage == 1:
503
- param_groups = [{'params': [p for p in model.projector.parameters() if p.requires_grad],
504
- 'lr': 1e-3}]
505
- else:
506
- param_groups = [
507
- {'params': [p for p in model.vision_encoder.parameters() if p.requires_grad], 'lr': 2e-6},
508
- {'params': [p for p in model.projector.parameters() if p.requires_grad], 'lr': 1e-5},
509
- {'params': [p for p in model.lm.parameters() if p.requires_grad], 'lr': 1e-5},
510
- ]
511
- param_groups = [g for g in param_groups if len(g['params']) > 0]
512
-
513
- optimizer = AdamW(param_groups, weight_decay=0.05, betas=(0.9, 0.999))
514
- total_steps = len(dataloader) * args.epochs // args.grad_accum
515
- scheduler = CosineAnnealingLR(optimizer, T_max=max(total_steps, 1), eta_min=1e-6)
516
-
517
- # Trackio
518
  tracker.init(name=f"vil-dlm-stage{args.stage}")
519
-
520
- # Training loop
 
521
  global_step = 0
522
- best_loss = float('inf')
523
-
524
  for epoch in range(args.epochs):
525
  model.train()
526
- epoch_loss = 0
527
  num_batches = 0
528
-
 
529
  for batch_idx, batch in enumerate(dataloader):
530
- pv = batch['pixel_values'].to(device)
531
- ids = batch['input_ids'].to(device)
532
- mask = batch['attention_mask'].to(device)
533
- labels = batch['labels'].to(device)
534
-
535
- outputs = model(pixel_values=pv, input_ids=ids, attention_mask=mask, labels=labels)
536
- loss = outputs['loss'] / args.grad_accum
537
  loss.backward()
538
-
539
  if (batch_idx + 1) % args.grad_accum == 0:
540
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
541
  optimizer.step()
542
  scheduler.step()
543
- optimizer.zero_grad()
544
  global_step += 1
545
-
546
- actual_loss = loss.item() * args.grad_accum
547
- mask_ratio = outputs['noise_mask'].float().mean().item()
548
- lr = optimizer.param_groups[0]['lr']
549
-
550
- if global_step % 5 == 0:
551
- print(f"[E{epoch}] Step {global_step}/{total_steps} | "
552
- f"Loss: {actual_loss:.4f} | LR: {lr:.2e} | Mask: {mask_ratio:.1%}")
553
-
554
- tracker.log({
555
- 'train/loss': actual_loss,
556
- 'train/lr': lr,
557
- 'train/mask_ratio': mask_ratio,
558
- 'train/epoch': epoch,
559
- 'train/step': global_step,
560
- })
561
-
562
- epoch_loss += loss.item() * args.grad_accum
563
- num_batches += 1
564
-
565
- avg_loss = epoch_loss / max(num_batches, 1)
566
- print(f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f}\n")
567
- tracker.log({'train/epoch_loss': avg_loss, 'train/epoch': epoch})
568
-
569
- # Save checkpoint
570
- if avg_loss < best_loss:
571
- best_loss = avg_loss
572
- save_dir = os.path.join(args.output_dir, f"stage{args.stage}_best")
573
- os.makedirs(save_dir, exist_ok=True)
574
- torch.save(model.vision_encoder.state_dict(), os.path.join(save_dir, "vision_encoder.pt"))
575
- torch.save(model.projector.state_dict(), os.path.join(save_dir, "projector.pt"))
576
- if args.stage >= 2:
577
- model.lm.save_pretrained(os.path.join(save_dir, "diffusion_lm"))
578
- print(f"Saved best checkpoint (loss={best_loss:.4f})")
579
-
580
- # Push to Hub
581
- print("\nPushing to Hub...")
582
- api = HfApi()
583
- repo_id = args.hub_model_id
584
-
585
- try:
586
- api.create_repo(repo_id, exist_ok=True, private=False)
587
- except Exception as e:
588
- print(f"Repo note: {e}")
589
-
590
- save_dir = os.path.join(args.output_dir, f"stage{args.stage}_best")
591
-
592
- # Save config + README
593
- config_dict = {
594
- 'architecture': 'ViL-DLM',
595
- 'components': {
596
- 'vision_encoder': 'Vision-xLSTM-S (ViL-S)',
597
- 'projector': '2-layer MLP',
598
- 'diffusion_lm': 'dLLM Qwen3-0.6B MDLM',
599
- },
600
- 'vil_dim': 384,
601
- 'lm_dim': 1024,
602
- 'num_patches': 196,
603
- 'training_stage': args.stage,
604
- 'best_loss': best_loss,
605
- 'total_params_M': params['total'] / 1e6,
606
- 'trainable_params_M': params['trainable'] / 1e6,
607
- 'based_on': [
608
- 'Vision-LSTM (arxiv:2406.04303)',
609
- 'dLLM (arxiv:2602.22661)',
610
- 'LLaDA-V (arxiv:2505.16933)',
611
- 'LFM2 (arxiv:2511.23404)',
612
- ],
613
- 'teacher': 'google/gemma-4-E2B-it (planned for stage 3)',
614
- }
615
- with open(os.path.join(save_dir, "model_config.json"), 'w') as f:
616
- json.dump(config_dict, f, indent=2)
617
-
618
- readme = f"""---
619
- license: apache-2.0
620
- tags:
621
- - vision-language
622
- - diffusion
623
- - xlstm
624
- - vision-lstm
625
- - masked-diffusion
626
- - mdlm
627
- language: en
628
- pipeline_tag: image-text-to-text
629
- ---
630
-
631
- # ViL-DLM: Vision xLSTM Diffusion Language Model
632
-
633
- **The first vision-language model combining Vision xLSTM with a diffusion language backbone.**
634
-
635
- ## Architecture
636
-
637
- | Component | Model | Params |
638
- |-----------|-------|--------|
639
- | Vision Encoder | **Vision-xLSTM-S (ViL-S)** | ~57M |
640
- | Projector | 2-layer MLP (GELU) | ~7M |
641
- | Language Backbone | **dLLM Qwen3-0.6B (MDLM)** | ~596M |
642
- | **Total** | | **~660M** |
643
-
644
- ### Why This Combination?
645
 
646
- 1. **ViL (Vision xLSTM)** — O(N) linear complexity vision encoder vs ViT's O(N²). Uses alternating bidirectional mLSTM blocks with exponential gating and Conv2D for spatial context. Based on [arxiv:2406.04303](https://arxiv.org/abs/2406.04303).
647
 
648
- 2. **Diffusion Language Model** — Non-autoregressive text generation via masked denoising. Bidirectional attention enables richer contextual understanding. Based on [dLLM/MDLM](https://arxiv.org/abs/2602.22661).
649
 
650
- 3. **Knowledge Distillation** (Stage 3) — Planned distillation from [Gemma 4 E2B](https://huggingface.co/google/gemma-4-E2B-it) using LFM2-style Decoupled Top-K distillation.
651
-
652
- ## Training Recipe
653
-
654
- Inspired by LLaDA-V, LaViDa, LFM2, and Mistral/Pixtral:
655
-
656
- | Stage | What's Trained | Dataset | LR |
657
- |-------|---------------|---------|-----|
658
- | 1 | Projector only | LLaVA-Pretrain (558K) | 1e-3 |
659
- | 2 | Full model | The Cauldron (multimodal) | ViL:2e-6, Proj:1e-5, LM:1e-5 |
660
- | 3 | + KD from Gemma 4 E2B | Mixed | + Top-K KD (α=0.5, T=2, K=32) |
661
 
662
- **Current stage: {args.stage} | Best loss: {best_loss:.4f}**
 
663
 
664
- ## Novelty
665
 
666
- This is (to our knowledge) the **first published model** combining:
667
- - Vision xLSTM as a vision encoder in a VLM
668
- - A discrete masked diffusion language model backbone
669
- - Multi-stage training with knowledge distillation from an AR multimodal teacher
670
 
671
- ## References
672
 
673
- - [Vision-LSTM](https://arxiv.org/abs/2406.04303) — Alkin et al., 2024
674
- - [dLLM](https://arxiv.org/abs/2602.22661) — Berkeley, 2025
675
- - [MDLM](https://arxiv.org/abs/2406.07524) — Kuleshov group, NeurIPS 2024
676
- - [LLaDA-V](https://arxiv.org/abs/2505.16933) — GSAI-ML, 2025
677
- - [LFM2](https://arxiv.org/abs/2511.23404) — Liquid AI, 2025
678
- - [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) — Google, 2026
679
- """
680
-
681
- with open(os.path.join(save_dir, "README.md"), 'w') as f:
682
- f.write(readme)
683
-
684
- api.upload_folder(folder_path=save_dir, repo_id=repo_id,
685
- commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})")
686
- print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}")
687
  print("Training complete!")
688
 
689
 
690
- if __name__ == "__main__":
691
  parser = argparse.ArgumentParser()
692
- parser.add_argument("--stage", type=int, default=1)
693
  parser.add_argument("--epochs", type=int, default=2)
694
  parser.add_argument("--batch_size", type=int, default=4)
695
  parser.add_argument("--grad_accum", type=int, default=8)
@@ -697,6 +1203,26 @@ if __name__ == "__main__":
697
  parser.add_argument("--max_samples", type=int, default=None)
698
  parser.add_argument("--output_dir", type=str, default="./vil-dlm-output")
699
  parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B")
700
  args = parser.parse_args()
701
-
702
- train(args)
 
 
1
  """
2
+ ViL-DLM production training script.
 
3
 
4
+ Stages:
5
+ 1 - projector-only alignment on LLaVA-Pretrain
6
+ 2 - full-model finetune on The Cauldron
7
+ 3a - offline teacher candidate-bank preparation with Gemma 4 E2B-it
8
+ 3b - sparse cross-tokenizer distillation training using cached teacher targets
9
  """
10
 
11
+ import argparse
12
+ import hashlib
 
13
  import json
14
+ import math
15
+ import os
16
  import time
17
+ from collections import defaultdict
18
+ from dataclasses import dataclass
19
+ from io import BytesIO
20
  from pathlib import Path
21
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
22
 
23
+ import numpy as np
24
  import torch
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
+ import trackio
28
+ from datasets import Dataset as HFDataset
29
+ from datasets import concatenate_datasets, load_dataset
30
+ from huggingface_hub import HfApi, snapshot_download
31
+ from PIL import Image
32
  from torch.optim import AdamW
33
  from torch.optim.lr_scheduler import CosineAnnealingLR
34
+ from torch.utils.data import DataLoader, Dataset
35
+ from transformers import (
36
+ AutoModelForImageTextToText,
37
+ AutoModelForMaskedLM,
38
+ AutoProcessor,
39
+ AutoTokenizer,
40
+ )
41
 
42
+ from vision_xlstm import (
43
+ VisionProjector as UpstreamVisionProjector,
44
+ VisionXLSTM as UpstreamVisionXLSTM,
45
+ )
46
 
 
47
 
48
+ DEFAULT_CAULDRON_CONFIGS = [
49
+ "ai2d",
50
+ "vqav2",
51
+ "a_okvqa",
52
+ "textvqa",
53
+ "docvqa",
54
+ "chartqa",
55
+ "textcaps",
56
+ "screen2words",
57
+ ]
58
 
 
59
 
60
  @dataclass
61
  class ViLConfig:
 
69
  conv_kernel_size: int = 3
70
  bidirectional: bool = True
71
  dropout: float = 0.0
72
+
73
  @property
74
+ def num_patches(self) -> int:
75
  return (self.img_size // self.patch_size) ** 2
76
 
77
 
 
85
 
86
 
87
  class _TrackioShim:
88
+ def __init__(self) -> None:
89
  self.enabled = False
90
 
91
+ def init(self, name: str, project: str = "vil-dlm") -> None:
92
  try:
93
  trackio.init(name=name, project=project)
94
  self.enabled = True
 
96
  self.enabled = False
97
  print(f"Trackio disabled: {exc}")
98
 
99
+ def log(self, payload: dict) -> None:
100
  if not self.enabled:
101
  return
102
  try:
 
105
  self.enabled = False
106
  print(f"Trackio logging disabled after error: {exc}")
107
 
108
 
109
  class MDLMScheduler:
110
+ def __init__(self, mask_token_id: int) -> None:
111
  self.mask_token_id = mask_token_id
112
+
113
+ def add_noise(self, input_ids: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
114
+ batch, length = input_ids.shape
115
  mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
116
+ mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length)
117
+ mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio
118
  noisy_ids = input_ids.clone()
119
  noisy_ids[mask] = self.mask_token_id
120
  return noisy_ids, mask
121
+
122
+ def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
123
  return torch.rand(batch_size, device=device)
124
 
125
 
126
  class ViLDLM(nn.Module):
127
+ def __init__(self, vil_config: ViLConfig, proj_config: ProjConfig, lm_path: str) -> None:
128
  super().__init__()
129
  self.vil_config = vil_config
130
  self.vision_encoder = UpstreamVisionXLSTM(vil_config)
131
  self.projector = UpstreamVisionProjector(proj_config)
132
  self.lm = AutoModelForMaskedLM.from_pretrained(
133
+ lm_path,
134
+ trust_remote_code=True,
135
+ torch_dtype=torch.bfloat16,
136
  )
137
  self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True)
138
+ if self.tokenizer.pad_token_id is None:
139
+ self.tokenizer.pad_token = self.tokenizer.eos_token
140
+ self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id)
141
+
142
+ @property
143
+ def num_patches(self) -> int:
144
+ return self.vil_config.num_patches
145
+
146
+ def prepare_multimodal_inputs(
147
+ self,
148
+ pixel_values: torch.Tensor,
149
+ input_ids: torch.Tensor,
150
+ attention_mask: torch.Tensor,
151
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
152
  vision_features = self.vision_encoder.forward_features(pixel_values)
153
  visual_tokens = self.projector(vision_features)
154
+ text_embeds = self.lm.model.embed_tokens(input_ids)
155
  visual_tokens = visual_tokens.to(dtype=text_embeds.dtype)
156
  inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
157
+ vis_mask = torch.ones(
158
+ pixel_values.shape[0],
159
+ self.num_patches,
160
+ device=attention_mask.device,
161
+ dtype=attention_mask.dtype,
162
+ )
163
+ full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
164
+ return inputs_embeds, full_attention_mask
165
+
166
+ def predict_clean_logits(
167
+ self,
168
+ pixel_values: torch.Tensor,
169
+ input_ids: torch.Tensor,
170
+ attention_mask: torch.Tensor,
171
+ ) -> torch.Tensor:
172
+ inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
173
+ pixel_values=pixel_values,
174
+ input_ids=input_ids,
175
+ attention_mask=attention_mask,
176
+ )
177
+ outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
178
+ return outputs.logits[:, self.num_patches :, :]
179
+
180
+ def forward(
181
+ self,
182
+ pixel_values: torch.Tensor,
183
+ input_ids: torch.Tensor,
184
+ attention_mask: torch.Tensor,
185
+ labels: Optional[torch.Tensor] = None,
186
+ loss_mask: Optional[torch.Tensor] = None,
187
+ ) -> Dict[str, torch.Tensor]:
188
+ batch_size, seq_len = input_ids.shape
189
+ device = input_ids.device
190
+ if labels is None:
191
+ labels = input_ids.clone()
192
+ if loss_mask is None:
193
+ loss_mask = attention_mask
194
+
195
+ t = self.scheduler.sample_timesteps(batch_size, device)
196
+ noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
197
+ inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
198
+ pixel_values=pixel_values,
199
+ input_ids=noisy_ids,
200
+ attention_mask=attention_mask,
201
+ )
202
+ outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
203
+ text_logits = outputs.logits[:, self.num_patches :, :]
204
+
205
+ active_mask = noise_mask.float() * loss_mask.float()
206
+ if active_mask.sum() == 0:
207
  loss = torch.tensor(0.0, device=device, requires_grad=True)
208
  else:
209
  logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
210
  labels_flat = labels.reshape(-1)
211
+ per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none").reshape(batch_size, seq_len)
212
+ loss = (per_token * active_mask).sum() / active_mask.sum()
213
+
214
+ return {
215
+ "loss": loss,
216
+ "logits": text_logits,
217
+ "noise_mask": noise_mask,
218
+ "t": t,
219
+ }
220
+
221
+ def freeze_vision(self) -> None:
222
+ for param in self.vision_encoder.parameters():
223
+ param.requires_grad = False
224
+
225
+ def freeze_lm(self) -> None:
226
+ for param in self.lm.parameters():
227
+ param.requires_grad = False
228
+
229
+ def unfreeze_all(self) -> None:
230
+ for param in self.parameters():
231
+ param.requires_grad = True
232
+
233
+ def count_params(self) -> Dict[str, int]:
234
  vil = sum(p.numel() for p in self.vision_encoder.parameters())
235
  proj = sum(p.numel() for p in self.projector.parameters())
236
  lm = sum(p.numel() for p in self.lm.parameters())
237
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
238
+ return {"vil": vil, "proj": proj, "lm": lm, "total": vil + proj + lm, "trainable": trainable}
239
 
240
+ def save_checkpoint(self, save_dir: Path, include_lm: bool) -> None:
241
+ save_dir.mkdir(parents=True, exist_ok=True)
242
+ torch.save(self.vision_encoder.state_dict(), save_dir / "vision_encoder.pt")
243
+ torch.save(self.projector.state_dict(), save_dir / "projector.pt")
244
+ if include_lm:
245
+ self.lm.save_pretrained(save_dir / "diffusion_lm")
246
+ self.tokenizer.save_pretrained(save_dir / "diffusion_lm")
247
 
248
+ def load_checkpoint(self, checkpoint_dir: Path, include_lm: bool) -> None:
249
+ vision_path = checkpoint_dir / "vision_encoder.pt"
250
+ projector_path = checkpoint_dir / "projector.pt"
251
+ if vision_path.exists():
252
+ self.vision_encoder.load_state_dict(torch.load(vision_path, map_location="cpu"))
253
+ if projector_path.exists():
254
+ self.projector.load_state_dict(torch.load(projector_path, map_location="cpu"))
255
+ if include_lm:
256
+ diffusion_dir = checkpoint_dir / "diffusion_lm"
257
+ if diffusion_dir.exists():
258
+ self.lm = AutoModelForMaskedLM.from_pretrained(
259
+ diffusion_dir,
260
+ trust_remote_code=True,
261
+ torch_dtype=torch.bfloat16,
262
+ )
263
+ self.tokenizer = AutoTokenizer.from_pretrained(diffusion_dir, trust_remote_code=True)
264
+ if self.tokenizer.pad_token_id is None:
265
+ self.tokenizer.pad_token = self.tokenizer.eos_token
266
+ self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id)
267
 
268
+
269
+ def ensure_hf_cache_root() -> None:
270
+ os.environ.setdefault("HF_HOME", "/teamspace/studios/this_studio/.cache/huggingface")
271
+
272
+
273
+ def patch_diffusion_modeling_file(lm_path: str) -> None:
274
+ modeling_file = os.path.join(lm_path, "modeling_qwen3.py")
275
+ with open(modeling_file, "r", encoding="utf-8") as handle:
276
+ content = handle.read()
277
+ content = content.replace(
278
+ 'if __name__ == "__main__":\n import dllm',
279
+ 'if __name__ == "__main__":\n pass\n # import dllm',
280
+ )
281
+ content = content.replace(
282
+ "attention_mask=causal_mask_mapping[decoder_layer.attention_type]",
283
+ 'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))',
284
+ )
285
+ with open(modeling_file, "w", encoding="utf-8") as handle:
286
+ handle.write(content)
287
+
288
+
289
+ def download_student_backbone() -> str:
290
+ print("Downloading dLLM Qwen3-0.6B diffusion model...")
291
+ lm_path = snapshot_download("dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1")
292
+ patch_diffusion_modeling_file(lm_path)
293
+ print(f"Model downloaded to {lm_path}")
294
+ return lm_path
295
+
296
+
297
+ def parse_dataset_configs(dataset_configs: Optional[str]) -> List[str]:
298
+ if dataset_configs:
299
+ return [item.strip() for item in dataset_configs.split(",") if item.strip()]
300
+ return list(DEFAULT_CAULDRON_CONFIGS)
301
+
302
+
303
+ def stable_text_hash(*parts: str) -> str:
304
+ joined = "\n".join(parts)
305
+ return hashlib.sha1(joined.encode("utf-8")).hexdigest()
306
+
307
+
308
+ def build_prompt_prefix(prompt_text: str) -> str:
309
+ return f"User: {prompt_text.strip()}\nAssistant:"
310
+
311
+
312
+ def tokenize_prompt_and_target(
313
+ tokenizer: AutoTokenizer,
314
+ prompt_text: str,
315
+ target_text: str,
316
+ max_length: int,
317
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
318
+ prefix_text = build_prompt_prefix(prompt_text)
319
+ prefix_ids = tokenizer(prefix_text, add_special_tokens=True)["input_ids"]
320
+ target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"]
321
+ if not target_ids:
322
+ target_ids = tokenizer(" " + "N/A", add_special_tokens=False)["input_ids"][:1]
323
+
324
+ max_prefix_len = max_length - 1
325
+ if len(prefix_ids) > max_prefix_len:
326
+ prefix_ids = prefix_ids[:max_prefix_len]
327
+
328
+ remaining = max_length - len(prefix_ids)
329
+ if remaining <= 0:
330
+ prefix_ids = prefix_ids[: max_length - 1]
331
+ remaining = 1
332
+ target_ids = target_ids[:remaining]
333
+ if not target_ids:
334
+ prefix_ids = prefix_ids[: max_length - 1]
335
+ target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"][:1]
336
+
337
+ input_ids = prefix_ids + target_ids
338
+ loss_mask = [0] * len(prefix_ids) + [1] * len(target_ids)
339
+ attention_mask = [1] * len(input_ids)
340
+ labels = list(input_ids)
341
+
342
+ pad_token_id = tokenizer.pad_token_id
343
+ if pad_token_id is None:
344
+ pad_token_id = tokenizer.eos_token_id
345
+
346
+ pad_len = max_length - len(input_ids)
347
+ if pad_len > 0:
348
+ input_ids = input_ids + [pad_token_id] * pad_len
349
+ attention_mask = attention_mask + [0] * pad_len
350
+ labels = labels + [pad_token_id] * pad_len
351
+ loss_mask = loss_mask + [0] * pad_len
352
+
353
+ return (
354
+ torch.tensor(input_ids, dtype=torch.long),
355
+ torch.tensor(attention_mask, dtype=torch.long),
356
+ torch.tensor(labels, dtype=torch.long),
357
+ torch.tensor(loss_mask, dtype=torch.float32),
358
+ )
359
+
360
+
361
+ def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
362
+ if isinstance(img, str):
363
+ img = Image.open(img).convert("RGB")
364
+ elif isinstance(img, dict) and "bytes" in img:
365
+ img = Image.open(BytesIO(img["bytes"])).convert("RGB")
366
+ elif isinstance(img, Image.Image):
367
+ img = img.convert("RGB")
368
+ else:
369
+ img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
370
+
371
+ pil_image = img
372
+ resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
373
+ arr = np.array(resized).astype(np.float32) / 255.0
374
+ tensor = torch.from_numpy(arr).permute(2, 0, 1)
375
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
376
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
377
+ tensor = (tensor - mean) / std
378
+ return tensor, pil_image
379
+
380
+
381
+ class NormalizedVisionLanguageDataset(Dataset):
382
+ def __init__(
383
+ self,
384
+ records: HFDataset,
385
+ tokenizer: AutoTokenizer,
386
+ max_length: int,
387
+ img_size: int,
388
+ ) -> None:
389
+ self.records = records
390
  self.tokenizer = tokenizer
391
  self.max_length = max_length
392
  self.img_size = img_size
393
+
394
+ def __len__(self) -> int:
395
+ return len(self.records)
396
+
397
+ def __getitem__(self, idx: int) -> Dict[str, object]:
398
+ sample = self.records[int(idx)]
399
+ pixel_values, pil_image = preprocess_image_for_student(sample["image"], self.img_size)
400
+ input_ids, attention_mask, labels, loss_mask = tokenize_prompt_and_target(
401
+ tokenizer=self.tokenizer,
402
+ prompt_text=sample["prompt_text"],
403
+ target_text=sample["target_text"],
404
+ max_length=self.max_length,
405
+ )
406
+ return {
407
+ "pixel_values": pixel_values,
408
+ "input_ids": input_ids,
409
+ "attention_mask": attention_mask,
410
+ "labels": labels,
411
+ "loss_mask": loss_mask,
412
+ "sample_id": sample["sample_id"],
413
+ "prompt_text": sample["prompt_text"],
414
+ "target_text": sample["target_text"],
415
+ "source_config": sample.get("source_config", "unknown"),
416
+ "pil_image": pil_image,
417
+ }
418
+
419
+
420
+ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
421
+ print("Loading LLaVA-Pretrain dataset...")
422
+ dataset_root = None
423
+ try:
424
+ data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
425
+ except Exception as exc:
426
+ print(f"Primary dataset loader failed ({exc}). Falling back to direct JSON loading...")
427
+ dataset_root = snapshot_download(
428
+ "liuhaotian/LLaVA-Pretrain",
429
+ repo_type="dataset",
430
+ allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
431
+ )
432
+ json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
433
+ data = load_dataset("json", data_files={"train": json_path}, split="train")
434
+ if max_samples:
435
+ data = data.select(range(min(max_samples, len(data))))
436
+
437
+ def normalize(sample: Dict[str, object], idx: int) -> Dict[str, object]:
438
  text = ""
439
+ if "conversations" in sample:
440
  parts = []
441
+ for turn in sample["conversations"]:
442
+ val = turn.get("value", "").replace("<image>\n", "").replace("<image>", "").strip()
443
  if val:
444
  parts.append(val)
445
+ text = " ".join(parts)
446
+ elif sample.get("blip_caption"):
447
+ text = sample["blip_caption"].strip()
448
  if not text:
449
  text = "Describe this image."
450
+
451
+ image_obj = sample.get("image")
452
+ if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
453
+ candidate_paths = [
454
+ image_obj,
455
+ os.path.join(dataset_root, image_obj),
456
+ os.path.join(dataset_root, "images", image_obj),
457
+ ]
458
+ image_obj = next((path for path in candidate_paths if os.path.exists(path)), image_obj)
459
+
460
  return {
461
+ "image": image_obj,
462
+ "prompt_text": "Describe this image.",
463
+ "target_text": text,
464
+ "sample_id": f"llava-pretrain:{sample.get('id', idx)}",
465
+ "source_config": "llava_pretrain",
466
  }
467
 
468
+ records = [normalize(data[i], i) for i in range(len(data))]
469
+ normalized = HFDataset.from_list(records)
470
+ print(f"Loaded {len(normalized)} LLaVA samples")
471
+ return normalized
472
 
 
 
474
+ def build_cauldron_records(configs: Sequence[str], max_samples: Optional[int]) -> Tuple[HFDataset, Dict[str, Dict[str, int]]]:
475
+ normalized_configs: List[HFDataset] = []
476
+ skip_stats: Dict[str, Dict[str, int]] = {}
477
+ per_config_limit = None
478
+ if max_samples:
479
+ per_config_limit = max(1, max_samples // max(len(configs), 1))
480
+
481
+ for config_name in configs:
482
+ print(f"Loading The Cauldron config: {config_name}")
483
+ ds = load_dataset("HuggingFaceM4/the_cauldron", config_name, split="train")
484
+ stats = defaultdict(int)
485
+
486
+ def explode(batch: Dict[str, List[object]], indices: List[int]) -> Dict[str, List[object]]:
487
+ output = {
488
+ "image": [],
489
+ "prompt_text": [],
490
+ "target_text": [],
491
+ "sample_id": [],
492
+ "source_config": [],
493
+ }
494
+ for local_idx, row_idx in enumerate(indices):
495
+ images = batch["images"][local_idx]
496
+ texts = batch["texts"][local_idx]
497
+ if not images or len(images) != 1:
498
+ stats["multi_or_missing_image"] += 1
499
+ continue
500
+ if not texts:
501
+ stats["missing_turns"] += 1
502
+ continue
503
+ for turn_idx, turn in enumerate(texts):
504
+ user_text = (turn.get("user") or "").strip()
505
+ assistant_text = (turn.get("assistant") or "").strip()
506
+ if not user_text or not assistant_text:
507
+ stats["missing_user_or_assistant"] += 1
508
+ continue
509
+ output["image"].append(images[0])
510
+ output["prompt_text"].append(user_text)
511
+ output["target_text"].append(assistant_text)
512
+ output["sample_id"].append(f"{config_name}:{row_idx}:{turn_idx}")
513
+ output["source_config"].append(config_name)
514
+ stats["kept"] += 1
515
+ return output
516
+
517
+ exploded = ds.map(
518
+ explode,
519
+ batched=True,
520
+ with_indices=True,
521
+ remove_columns=ds.column_names,
522
+ desc=f"Normalizing {config_name}",
523
+ )
524
+ if per_config_limit is not None:
525
+ exploded = exploded.select(range(min(per_config_limit, len(exploded))))
526
+ normalized_configs.append(exploded)
527
+ skip_stats[config_name] = dict(stats)
528
+ print(f"{config_name}: kept={stats['kept']} skipped={sum(v for k, v in stats.items() if k != 'kept')}")
529
+
530
+ if not normalized_configs:
531
+ raise RuntimeError("No valid The Cauldron configs were loaded.")
532
+
533
+ combined = concatenate_datasets(normalized_configs)
534
+ if max_samples:
535
+ combined = combined.select(range(min(max_samples, len(combined))))
536
+ print(f"Loaded {len(combined)} normalized The Cauldron samples")
537
+ return combined, skip_stats
538
+
539
+
540
+ def collate_vision_language(batch: List[Dict[str, object]]) -> Dict[str, object]:
541
+ return {
542
+ "pixel_values": torch.stack([sample["pixel_values"] for sample in batch]),
543
+ "input_ids": torch.stack([sample["input_ids"] for sample in batch]),
544
+ "attention_mask": torch.stack([sample["attention_mask"] for sample in batch]),
545
+ "labels": torch.stack([sample["labels"] for sample in batch]),
546
+ "loss_mask": torch.stack([sample["loss_mask"] for sample in batch]),
547
+ "sample_id": [sample["sample_id"] for sample in batch],
548
+ "prompt_text": [sample["prompt_text"] for sample in batch],
549
+ "target_text": [sample["target_text"] for sample in batch],
550
+ "source_config": [sample["source_config"] for sample in batch],
551
+ "pil_image": [sample["pil_image"] for sample in batch],
552
+ }
553
+
554
+
555
+ def create_stage_dataset(stage: str, tokenizer: AutoTokenizer, args: argparse.Namespace) -> Tuple[NormalizedVisionLanguageDataset, Dict[str, Dict[str, int]]]:
+     if stage == "1":
+         return NormalizedVisionLanguageDataset(
+             records=build_llava_records(args.max_samples),
+             tokenizer=tokenizer,
+             max_length=args.max_length,
+             img_size=224,
+         ), {}
+
+     configs = parse_dataset_configs(args.dataset_configs)
+     records, skip_stats = build_cauldron_records(configs, args.max_samples)
+     return NormalizedVisionLanguageDataset(
+         records=records,
+         tokenizer=tokenizer,
+         max_length=args.max_length,
+         img_size=224,
+     ), skip_stats
+
+
+ def build_dataloader(
+     dataset: Dataset,
+     batch_size: int,
+     shuffle: bool,
+     num_workers: int,
+     persistent_workers: bool,
+ ) -> DataLoader:
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=shuffle,
+         num_workers=num_workers,
+         pin_memory=torch.cuda.is_available(),
+         persistent_workers=persistent_workers and num_workers > 0,
+         drop_last=False,
+         collate_fn=collate_vision_language,
+     )
+
+
+ def print_device_info(device: torch.device) -> None:
      print(f"Device: {device}")
      if torch.cuda.is_available():
+         print(f"GPU: {torch.cuda.get_device_name(0)}")
          print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+         print(f"torch.version.cuda: {torch.version.cuda}")
+
+
+ def ensure_runtime_requirements(args: argparse.Namespace) -> None:
+     if args.require_cuda and not torch.cuda.is_available():
+         raise RuntimeError("CUDA is required for this run but torch.cuda.is_available() is False.")
+     if args.stage in {"2", "3a", "3b"} and not parse_dataset_configs(args.dataset_configs):
+         raise RuntimeError("Stage 2/3 requires at least one The Cauldron config.")
+     if args.stage in {"3a", "3b"} and not args.teacher_cache_dir:
+         raise RuntimeError("Stage 3 requires --teacher_cache_dir.")
+     if args.stage in {"3a", "3b"} and not args.resume_from:
+         raise RuntimeError("Stage 3 requires --resume_from pointing to a Stage 2 checkpoint.")
+     if args.stage == "3a":
+         try:
+             import bitsandbytes  # noqa: F401
+         except ImportError as exc:
+             raise RuntimeError("Stage 3a requires bitsandbytes in the active environment.") from exc
+
+
+ def maybe_resume_model(model: ViLDLM, args: argparse.Namespace) -> None:
+     if not args.resume_from:
+         return
+     checkpoint_dir = Path(args.resume_from)
+     if not checkpoint_dir.exists():
+         raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
+     include_lm = args.stage in {"2", "3a", "3b"}
+     print(f"Resuming from checkpoint: {checkpoint_dir}")
+     model.load_checkpoint(checkpoint_dir, include_lm=include_lm)
+
+
+ def get_optimizer(model: ViLDLM, stage: str) -> AdamW:
+     if stage == "1":
+         groups = [
+             {
+                 "params": [p for p in model.projector.parameters() if p.requires_grad],
+                 "lr": 1e-3,
+             }
+         ]
+     else:
+         groups = [
+             {
+                 "params": [p for p in model.vision_encoder.parameters() if p.requires_grad],
+                 "lr": 2e-6,
+             },
+             {
+                 "params": [p for p in model.projector.parameters() if p.requires_grad],
+                 "lr": 1e-5,
+             },
+             {
+                 "params": [p for p in model.lm.parameters() if p.requires_grad],
+                 "lr": 1e-5,
+             },
+         ]
+     groups = [group for group in groups if group["params"]]
+     return AdamW(groups, weight_decay=0.05, betas=(0.9, 0.999))
+
+
+ def setup_model_for_stage(model: ViLDLM, stage: str) -> None:
+     if stage == "1":
+         print("\n=== STAGE 1: Projector-only alignment ===")
+         model.freeze_vision()
+         model.freeze_lm()
+     elif stage in {"2", "3b"}:
+         label = "Full finetune" if stage == "2" else "Sparse KD finetune"
+         print(f"\n=== STAGE {stage.upper()}: {label} ===")
+         model.unfreeze_all()
+     elif stage == "3a":
+         print("\n=== STAGE 3A: Teacher candidate-bank preparation ===")
+         # Inference-only stage: freeze everything (no need to unfreeze first).
+         for param in model.parameters():
+             param.requires_grad = False
+     else:
+         raise ValueError(f"Unsupported stage: {stage}")
+
+
+ def compute_sparse_kd_loss(
+     student_logits: torch.Tensor,
+     noise_mask: torch.Tensor,
+     sample_ids: Sequence[str],
+     bank_map: Dict[str, List[Dict[str, object]]],
+     temperature: float,
+ ) -> Tuple[torch.Tensor, int]:
+     entries_used = 0
+     losses: List[torch.Tensor] = []
+     for batch_idx, sample_id in enumerate(sample_ids):
+         sample_entries = bank_map.get(sample_id, [])
+         for entry in sample_entries:
+             position = int(entry["position"])
+             if position >= student_logits.shape[1]:
+                 continue
+             if not bool(noise_mask[batch_idx, position].item()):
+                 continue
+             candidate_ids = torch.tensor(
+                 entry["candidate_token_ids"],
+                 device=student_logits.device,
+                 dtype=torch.long,
+             )
+             teacher_probs = torch.tensor(
+                 entry["teacher_probs"],
+                 device=student_logits.device,
+                 dtype=student_logits.dtype,
+             )
+             gathered = student_logits[batch_idx, position, candidate_ids]
+             student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
+             loss = F.kl_div(
+                 student_log_probs.unsqueeze(0),
+                 teacher_probs.unsqueeze(0),
+                 reduction="batchmean",
+             ) * (temperature ** 2)
+             losses.append(loss)
+             entries_used += 1
+
+     if not losses:
+         return torch.tensor(0.0, device=student_logits.device), 0
+     return torch.stack(losses).mean(), entries_used
+
+
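+ # Per bank entry at position p with candidate set C, the term above is
+ #     τ² · KL( teacher_probs(C) ‖ softmax(student_logits[b, p, C] / τ) ),
+ # and the batch loss is the mean over all entries that landed on
+ # currently-masked positions (where noise_mask is True).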
+ def compute_teacher_logprobs(
+     teacher: AutoModelForImageTextToText,
+     processor: AutoProcessor,
+     pil_image: Image.Image,
+     prompt_text: str,
+     candidate_texts: Sequence[str],
+     teacher_batch_size: int,
+ ) -> torch.Tensor:
+     prompt_messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": pil_image},
+                 {"type": "text", "text": prompt_text},
+             ],
+         }
+     ]
+     prompt_inputs = processor.apply_chat_template(
+         prompt_messages,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+         add_generation_prompt=True,
      )
+     prompt_len = prompt_inputs["input_ids"].shape[1]
+
+     teacher_device = next(teacher.parameters()).device
+     all_logprobs = []
+     for start in range(0, len(candidate_texts), max(teacher_batch_size, 1)):
+         batch_candidates = candidate_texts[start : start + max(teacher_batch_size, 1)]
+         conversations = []
+         for candidate_text in batch_candidates:
+             conversations.append(
+                 [
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "image", "image": pil_image},
+                             {"type": "text", "text": prompt_text},
+                         ],
+                     },
+                     {
+                         "role": "assistant",
+                         "content": [{"type": "text", "text": candidate_text}],
+                     },
+                 ]
+             )
+
+         batch_inputs = processor.apply_chat_template(
+             conversations,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt",
+             padding=True,
+             add_generation_prompt=False,
+         )
+         batch_inputs = {key: value.to(teacher_device) for key, value in batch_inputs.items()}
+         outputs = teacher(**batch_inputs)
+         logits = outputs.logits[:, :-1, :]
+         labels = batch_inputs["input_ids"][:, 1:].clone()
+         attention_mask = batch_inputs["attention_mask"]
+
+         seq_len = batch_inputs["input_ids"].shape[1]
+         for batch_idx in range(labels.shape[0]):
+             valid_len = int(attention_mask[batch_idx].sum().item())
+             left_pad = seq_len - valid_len
+             prefix_cut = left_pad + prompt_len - 1
+             if prefix_cut > 0:
+                 labels[batch_idx, :prefix_cut] = -100
+             labels[batch_idx, attention_mask[batch_idx, 1:] == 0] = -100
+
+         per_token = F.cross_entropy(
+             logits.reshape(-1, logits.shape[-1]),
+             labels.reshape(-1),
+             ignore_index=-100,
+             reduction="none",
+         ).reshape(labels.shape)
+         token_mask = (labels != -100).float()
+         all_logprobs.append(-(per_token * token_mask).sum(dim=-1).cpu())
+
+     return torch.cat(all_logprobs, dim=0)
+
+
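+ # The score returned above is the teacher's summed log-probability of the
+ # assistant turn: positions covered by the prompt (plus left padding) are
+ # masked to -100, so only candidate answer tokens contribute. The -1 in
+ # prefix_cut accounts for labels being shifted one step right of logits.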
+ def choose_distillation_positions(
+     clean_logits: torch.Tensor,
+     labels: torch.Tensor,
+     loss_mask: torch.Tensor,
+     max_positions: int,
+ ) -> List[int]:
+     valid_positions = torch.nonzero(loss_mask > 0, as_tuple=False).flatten()
+     if valid_positions.numel() == 0:
+         return []
+     probs = F.softmax(clean_logits[valid_positions], dim=-1)
+     gold = labels[valid_positions].unsqueeze(-1)
+     gold_probs = probs.gather(-1, gold).squeeze(-1)
+     _, ranked = torch.sort(gold_probs, descending=False)
+     selected = valid_positions[ranked][:max_positions]
+     return [int(pos.item()) for pos in selected]
+
+
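+ # Illustrative: if the gold-token probabilities at the maskable positions are
+ # [0.9, 0.1, 0.4, 0.7] and max_positions=2, the positions with probs 0.1 and
+ # 0.4 are chosen, so KD effort goes where the student is least confident.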
+ def build_candidate_ids(
+     logits_at_position: torch.Tensor,
+     gold_token_id: int,
+     top_k: int,
+ ) -> List[int]:
+     candidate_ids = logits_at_position.topk(max(top_k - 1, 1)).indices.tolist()
+     if gold_token_id not in candidate_ids:
+         candidate_ids.append(gold_token_id)
+     deduped = []
+     seen = set()
+     for token_id in candidate_ids:
+         if token_id in seen:
+             continue
+         deduped.append(token_id)
+         seen.add(token_id)
+     return deduped[:top_k]
+
+
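+ # Illustrative with top_k=4: the student's top-3 ids might be [42, 7, 99];
+ # if the gold id is 7 the list stays [42, 7, 99], and if it were 13 the list
+ # becomes [42, 7, 99, 13]. Dedup preserves order, then the list is cut to top_k.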
+ def decode_assistant_text(
+     tokenizer: AutoTokenizer,
+     full_ids: torch.Tensor,
+     attention_mask: torch.Tensor,
+     loss_mask: torch.Tensor,
+ ) -> str:
+     active = (attention_mask > 0) & (loss_mask > 0)
+     assistant_ids = full_ids[active].tolist()
+     return tokenizer.decode(assistant_ids, skip_special_tokens=True).strip()
+
+
+ def prepare_teacher_bank(
+     args: argparse.Namespace,
+     model: ViLDLM,
+     dataset: NormalizedVisionLanguageDataset,
+ ) -> None:
+     if args.dry_run_batches:
+         max_items = min(args.teacher_batch_size * args.dry_run_batches, len(dataset))
+     elif args.max_samples:
+         max_items = min(args.max_samples, len(dataset))
+     else:
+         max_items = len(dataset)
+
+     try:
+         from transformers import BitsAndBytesConfig
+     except ImportError as exc:
+         raise RuntimeError("bitsandbytes/transformers quantization support is required for Stage 3a.") from exc
+
+     print(f"Loading teacher: {args.teacher_model_id}")
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_quant_type="nf4",
      )
+     teacher = AutoModelForImageTextToText.from_pretrained(
+         args.teacher_model_id,
+         quantization_config=quantization_config,
+         device_map="auto",
+         attn_implementation="sdpa",
+     )
+     teacher.eval()
+     processor = AutoProcessor.from_pretrained(args.teacher_model_id, padding_side="left")
+
+     cache_dir = Path(args.teacher_cache_dir)
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     output_path = cache_dir / "candidate_bank.jsonl"
+     seen_keys = set()
+     if output_path.exists():
+         with open(output_path, "r", encoding="utf-8") as handle:
+             for line in handle:
+                 if not line.strip():
+                     continue
+                 record = json.loads(line)
+                 seen_keys.add((record["sample_id"], int(record["position"])))
+
+     dataloader = build_dataloader(
+         dataset=dataset,
+         batch_size=1,
+         shuffle=False,
+         num_workers=0,
+         persistent_workers=False,
+     )
+
+     prepared = 0
+     processed = 0
+     with torch.no_grad(), open(output_path, "a", encoding="utf-8") as writer:
+         for batch in dataloader:
+             if processed >= max_items:
+                 break  # enforce the dry-run / max_samples cap computed above
+             processed += 1
+             sample_id = batch["sample_id"][0]
+             prompt_text = batch["prompt_text"][0]
+             target_text = batch["target_text"][0]
+             pil_image = batch["pil_image"][0]
+             pixel_values = batch["pixel_values"].to(next(model.parameters()).device)
+             input_ids = batch["input_ids"].to(pixel_values.device)
+             attention_mask = batch["attention_mask"].to(pixel_values.device)
+             labels = batch["labels"].to(pixel_values.device)
+             loss_mask = batch["loss_mask"].to(pixel_values.device)
+
+             clean_logits = model.predict_clean_logits(pixel_values, input_ids, attention_mask)[0]
+             sample_labels = labels[0]
+             sample_loss_mask = loss_mask[0]
+             positions = choose_distillation_positions(
+                 clean_logits=clean_logits,
+                 labels=sample_labels,
+                 loss_mask=sample_loss_mask,
+                 max_positions=args.kd_positions_per_sample,
+             )
+
+             for position in positions:
+                 cache_key = (sample_id, position)
+                 if cache_key in seen_keys:
+                     continue
+                 gold_token_id = int(sample_labels[position].item())
+                 candidate_token_ids = build_candidate_ids(
+                     logits_at_position=clean_logits[position],
+                     gold_token_id=gold_token_id,
+                     top_k=args.kd_top_k,
+                 )
+                 candidate_texts: List[str] = []
+                 for candidate_id in candidate_token_ids:
+                     modified_ids = input_ids[0].clone()
+                     modified_ids[position] = candidate_id
+                     candidate_texts.append(
+                         decode_assistant_text(
+                             tokenizer=model.tokenizer,
+                             full_ids=modified_ids,
+                             attention_mask=attention_mask[0],
+                             loss_mask=loss_mask[0],
+                         )
+                     )
+                 teacher_logprobs = compute_teacher_logprobs(
+                     teacher=teacher,
+                     processor=processor,
+                     pil_image=pil_image,
+                     prompt_text=prompt_text,
+                     candidate_texts=candidate_texts,
+                     teacher_batch_size=args.teacher_batch_size,
+                 )
+                 teacher_probs = F.softmax(teacher_logprobs / 2.0, dim=-1).cpu().tolist()  # τ=2.0, as in Stage 3b
+                 record = {
+                     "sample_id": sample_id,
+                     "position": position,
+                     "candidate_token_ids": candidate_token_ids,
+                     "teacher_probs": teacher_probs,
+                     "gold_token_id": gold_token_id,
+                     "source_config": batch["source_config"][0],
+                     "text_hash": stable_text_hash(sample_id, prompt_text, target_text),
+                 }
+                 writer.write(json.dumps(record) + "\n")
+                 seen_keys.add(cache_key)
+                 prepared += 1
+                 if args.dry_run_batches and prepared >= args.kd_positions_per_sample * args.dry_run_batches:
+                     break
+                 if prepared and prepared % 50 == 0:
+                     print(f"Prepared {prepared} KD entries...")
+
+     print(f"Teacher bank written to {output_path} with {prepared} new entries")
+
+
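+ # One line of candidate_bank.jsonl (token ids and probs are hypothetical):
+ #   {"sample_id": "vqav2:17:0", "position": 41,
+ #    "candidate_token_ids": [318, 262, 373, 198, 11, 286, 257, 2940],
+ #    "teacher_probs": [0.41, 0.22, 0.14, 0.09, 0.06, 0.04, 0.03, 0.01],
+ #    "gold_token_id": 318, "source_config": "vqav2", "text_hash": "..."}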
+ def load_teacher_bank(cache_dir: str) -> Dict[str, List[Dict[str, object]]]:
+     bank_path = Path(cache_dir) / "candidate_bank.jsonl"
+     if not bank_path.exists():
+         raise FileNotFoundError(f"Teacher bank not found: {bank_path}")
+     bank_map: Dict[str, List[Dict[str, object]]] = defaultdict(list)
+     with open(bank_path, "r", encoding="utf-8") as handle:
+         for line in handle:
+             if not line.strip():
+                 continue
+             record = json.loads(line)
+             bank_map[record["sample_id"]].append(record)
+     print(f"Loaded teacher bank for {len(bank_map)} samples from {bank_path}")
+     return bank_map
+
+
+ def maybe_push_to_hub(
+     args: argparse.Namespace,
+     save_dir: Path,
+     params: Dict[str, int],
+     best_loss: float,
+ ) -> None:
+     if not args.push_to_hub:
+         print("Skipping Hub push (enable with --push_to_hub).")
+         return
+
+     print("\nPushing to Hub...")
+     api = HfApi()
+     repo_id = args.hub_model_id
+     try:
+         api.create_repo(repo_id, exist_ok=True, private=False)
+     except Exception as exc:
+         print(f"Repo note: {exc}")
+
+     config_dict = {
+         "architecture": "ViL-DLM",
+         "training_stage": args.stage,
+         "best_loss": best_loss,
+         "total_params_M": params["total"] / 1e6,
+         "trainable_params_M": params["trainable"] / 1e6,
+         "teacher": args.teacher_model_id,
+         "dataset_configs": parse_dataset_configs(args.dataset_configs) if args.stage in {"2", "3a", "3b"} else ["llava_pretrain"],
+     }
+     with open(save_dir / "model_config.json", "w", encoding="utf-8") as handle:
+         json.dump(config_dict, handle, indent=2)
+
+     api.upload_folder(
+         folder_path=str(save_dir),
+         repo_id=repo_id,
+         commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})",
+     )
+     print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}")
+
+
+ def run_training_stage(args: argparse.Namespace) -> None:
+     tracker = _TrackioShim()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print_device_info(device)
+     ensure_runtime_requirements(args)
+     lm_path = download_student_backbone()
+
      vil_config = ViLConfig()
      proj_config = ProjConfig()
      model = ViLDLM(vil_config, proj_config, lm_path)
+     setup_model_for_stage(model, args.stage)
+     maybe_resume_model(model, args)
+
      params = model.count_params()
      print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M")
      print(f" ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M")
+
      model = model.to(device)
+     if hasattr(model.lm, "gradient_checkpointing_enable"):
          model.lm.gradient_checkpointing_enable()
+
+     dataset, skip_stats = create_stage_dataset("1" if args.stage == "1" else "2", model.tokenizer, args)
+     if skip_stats:
+         print(f"Skip stats: {json.dumps(skip_stats)}")
+
+     if args.stage == "3a":
+         prepare_teacher_bank(args=args, model=model, dataset=dataset)
+         return
+
+     dataloader = build_dataloader(
+         dataset=dataset,
+         batch_size=args.batch_size,
+         shuffle=True,  # stage 3a already returned above; every training stage shuffles
+         num_workers=args.num_workers,
+         persistent_workers=args.persistent_workers,
      )
+
+     optimizer = get_optimizer(model, stage="1" if args.stage == "1" else "2")
+     total_steps = max(1, (len(dataloader) * args.epochs) // max(args.grad_accum, 1))
+     scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
      tracker.init(name=f"vil-dlm-stage{args.stage}")
+     teacher_bank = load_teacher_bank(args.teacher_cache_dir) if args.stage == "3b" else {}
+
+     best_loss = float("inf")
      global_step = 0
+     step_timer = time.time()
+
      for epoch in range(args.epochs):
          model.train()
+         epoch_loss = 0.0
+         epoch_kd_loss = 0.0
+         epoch_kd_entries = 0
          num_batches = 0
+
+         optimizer.zero_grad(set_to_none=True)
          for batch_idx, batch in enumerate(dataloader):
+             pixel_values = batch["pixel_values"].to(device)
+             input_ids = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             labels = batch["labels"].to(device)
+             loss_mask = batch["loss_mask"].to(device)
+
+             outputs = model(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 labels=labels,
+                 loss_mask=loss_mask,
+             )
+             diffusion_loss = outputs["loss"]
+             kd_loss = torch.tensor(0.0, device=device)
+             kd_entries = 0
+             total_loss = diffusion_loss
+             if args.stage == "3b":
+                 kd_loss, kd_entries = compute_sparse_kd_loss(
+                     student_logits=outputs["logits"],
+                     noise_mask=outputs["noise_mask"],
+                     sample_ids=batch["sample_id"],
+                     bank_map=teacher_bank,
+                     temperature=2.0,
+                 )
+                 alpha_kd = 0.5  # equal weighting of diffusion and KD terms
+                 total_loss = (1.0 - alpha_kd) * diffusion_loss + alpha_kd * kd_loss
+
+             loss = total_loss / args.grad_accum
              loss.backward()
+
              if (batch_idx + 1) % args.grad_accum == 0:
                  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                  optimizer.step()
                  scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
                  global_step += 1
+
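+             # The throughput logged below is per optimizer step; one step
+             # aggregates batch_size * grad_accum samples (4 * 8 = 32 with
+             # the parser defaults).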
+             actual_loss = float(total_loss.item())
+             actual_diffusion = float(diffusion_loss.item())
+             actual_kd = float(kd_loss.item()) if args.stage == "3b" else 0.0
+             elapsed = max(time.time() - step_timer, 1e-6)
+             samples_per_sec = (args.batch_size * args.grad_accum) / elapsed
+             step_timer = time.time()
+             gpu_mem_gb = 0.0
+             if torch.cuda.is_available():
+                 gpu_mem_gb = torch.cuda.max_memory_allocated(device) / 1e9
+
+             print(
+                 f"[E{epoch}] Step {global_step}/{total_steps} | "
+                 f"Loss: {actual_loss:.4f} | Diff: {actual_diffusion:.4f} | "
+                 f"KD: {actual_kd:.4f} | KD entries: {kd_entries} | "
+                 f"Samples/s: {samples_per_sec:.2f} | GPU mem: {gpu_mem_gb:.2f} GB"
+             )
+             tracker.log(
+                 {
+                     "train/loss": actual_loss,
+                     "train/diffusion_loss": actual_diffusion,
+                     "train/kd_loss": actual_kd,
+                     "train/kd_entries": kd_entries,
+                     "train/epoch": epoch,
+                     "train/step": global_step,
+                     "train/samples_per_sec": samples_per_sec,
+                     "train/gpu_mem_gb": gpu_mem_gb,
+                 }
+             )
+
+             epoch_loss += actual_loss  # reuse the floats computed above instead of calling .item() again
+             epoch_kd_loss += actual_kd
+             epoch_kd_entries += kd_entries
+             num_batches += 1
+
+             if args.dry_run_batches and num_batches >= args.dry_run_batches:
+                 break
+
+         remainder = num_batches % args.grad_accum
+         if remainder != 0:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad(set_to_none=True)
+             global_step += 1
+
+         avg_loss = epoch_loss / max(num_batches, 1)
+         avg_kd_loss = epoch_kd_loss / max(num_batches, 1)
+         print(f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f} | Average KD: {avg_kd_loss:.4f} | KD entries: {epoch_kd_entries}\n")
+         tracker.log(
+             {
+                 "train/epoch_loss": avg_loss,
+                 "train/epoch_kd_loss": avg_kd_loss,
+                 "train/epoch_kd_entries": epoch_kd_entries,
+                 "train/epoch": epoch,
+             }
+         )
+
+         if avg_loss < best_loss:
+             best_loss = avg_loss
+             save_dir = Path(args.output_dir) / f"stage{args.stage}_best"
+             include_lm = args.stage in {"2", "3b"}
+             model.save_checkpoint(save_dir, include_lm=include_lm)
+             training_state = {
+                 "stage": args.stage,
+                 "best_loss": best_loss,
+                 "args": vars(args),
+             }
+             with open(save_dir / "training_state.json", "w", encoding="utf-8") as handle:
+                 json.dump(training_state, handle, indent=2)
+             print(f"Saved best checkpoint (loss={best_loss:.4f})")
+
+     maybe_push_to_hub(
+         args=args,
+         save_dir=Path(args.output_dir) / f"stage{args.stage}_best",
+         params=params,
+         best_loss=best_loss,
+     )
      print("Training complete!")


+ def build_parser() -> argparse.ArgumentParser:
      parser = argparse.ArgumentParser()
+     parser.add_argument("--stage", type=str, default="1", choices=["1", "2", "3a", "3b"])
      parser.add_argument("--epochs", type=int, default=2)
      parser.add_argument("--batch_size", type=int, default=4)
  parser.add_argument("--grad_accum", type=int, default=8)
 
1203
  parser.add_argument("--max_samples", type=int, default=None)
1204
  parser.add_argument("--output_dir", type=str, default="./vil-dlm-output")
1205
  parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B")
1206
+ parser.add_argument("--push_to_hub", action="store_true")
1207
+ parser.add_argument("--require_cuda", action="store_true")
1208
+ parser.add_argument("--resume_from", type=str, default=None)
1209
+ parser.add_argument("--dataset_configs", type=str, default=",".join(DEFAULT_CAULDRON_CONFIGS))
1210
+ parser.add_argument("--num_workers", type=int, default=4)
1211
+ parser.add_argument("--persistent_workers", action="store_true")
1212
+ parser.add_argument("--dry_run_batches", type=int, default=0)
1213
+ parser.add_argument("--teacher_model_id", type=str, default="google/gemma-4-E2B-it")
1214
+ parser.add_argument("--teacher_cache_dir", type=str, default="./vil-dlm-output/teacher-cache")
1215
+ parser.add_argument("--prepare_teacher_bank", action="store_true")
1216
+ parser.add_argument("--teacher_batch_size", type=int, default=1)
1217
+ parser.add_argument("--kd_top_k", type=int, default=8)
1218
+ parser.add_argument("--kd_positions_per_sample", type=int, default=16)
1219
+ return parser
1220
+
1221
+
1222
+ if __name__ == "__main__":
1223
+ ensure_hf_cache_root()
1224
+ parser = build_parser()
1225
  args = parser.parse_args()
1226
+ if args.prepare_teacher_bank and args.stage != "3a":
1227
+ raise ValueError("--prepare_teacher_bank is only valid with --stage 3a")
1228
+ run_training_stage(args)
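+ # Illustrative Stage 3 invocations wired to the flags defined above:
+ #   python train_production.py --stage 3a --prepare_teacher_bank \
+ #     --resume_from ./vil-dlm-output/stage2_best --teacher_cache_dir ./vil-dlm-output/teacher-cache
+ #   python train_production.py --stage 3b \
+ #     --resume_from ./vil-dlm-output/stage2_best --teacher_cache_dir ./vil-dlm-output/teacher-cache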
code/vil_dlm_model.py CHANGED
@@ -393,11 +393,9 @@ class ViLDLMWithDistillation(ViLDLM):
      """
      ViL-DLM with knowledge distillation from Gemma 4 E2B teacher.

-     Distillation losses:
-     1. Response-level KD: KL(teacher_logits || student_logits) on text output
-     2. Vision feature KD: MSE(teacher_vision_features, projected_vil_features)
-
-     Uses LFM2-style Decoupled Top-K distillation for efficiency.
      """

      def __init__(self, config: TrainingConfig):
@@ -442,60 +440,54 @@ class ViLDLMWithDistillation(ViLDLM):

          print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")

-     def compute_kd_loss(
          self,
-         student_logits: torch.Tensor,  # [B, T, student_vocab]
-         teacher_logits: torch.Tensor,  # [B, T, teacher_vocab]
-         mask: torch.Tensor,  # [B, T] where to compute loss
      ) -> torch.Tensor:
-         """
-         Decoupled Top-K KL divergence (LFM2 recipe).
-         Only align on top-K teacher logits for efficiency.
-         """
-         T = self.kd_config.temperature
-         K = self.kd_config.top_k_logits
-
-         # Get top-K teacher predictions
-         teacher_topk_vals, teacher_topk_idx = teacher_logits.topk(K, dim=-1)
-         teacher_topk_probs = F.softmax(teacher_topk_vals / T, dim=-1)
-
-         # Gather student logits at teacher's top-K positions
-         # Need to handle vocab size mismatch between student and teacher
-         # Student vocab: 151936 (Qwen3), Teacher vocab: 262144 (Gemma4)
-         # Only use indices that are valid in student vocab
-         valid_mask = teacher_topk_idx < student_logits.shape[-1]
-         teacher_topk_idx_clamped = teacher_topk_idx.clamp(0, student_logits.shape[-1] - 1)
-
-         student_topk_logits = torch.gather(student_logits, -1, teacher_topk_idx_clamped)
-         student_topk_probs = F.softmax(student_topk_logits / T, dim=-1)
-
-         # KL divergence on top-K
-         kl = F.kl_div(
-             student_topk_probs.log(),
-             teacher_topk_probs,
-             reduction='none'
-         )
-
-         # Apply valid mask and position mask
-         kl = kl * valid_mask.float()
-         kl = kl.sum(-1)  # sum over top-K
-
-         if mask.sum() > 0:
-             loss = (kl * mask.float()).sum() / mask.sum()
-         else:
-             loss = kl.mean()
-
-         return loss * (T ** 2)  # scale by T² as is standard for KD

      def forward_with_distillation(
          self,
          pixel_values: torch.Tensor,
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor,
-         teacher_pixel_values: Optional[torch.Tensor] = None,  # may need different preprocessing
          labels: Optional[torch.Tensor] = None,
      ) -> Dict[str, torch.Tensor]:
-         """Forward with both diffusion loss and distillation loss"""

          # Student forward (diffusion loss)
          student_outputs = self.forward(
@@ -506,29 +498,11 @@ class ViLDLMWithDistillation(ViLDLM):
          )

          diffusion_loss = student_outputs['loss']
-
-         # Teacher forward (no grad)
-         if self.teacher is not None:
-             with torch.no_grad():
-                 # Prepare teacher inputs
-                 teacher_inputs = {
-                     'input_ids': input_ids,
-                     'attention_mask': attention_mask,
-                 }
-                 if teacher_pixel_values is not None:
-                     teacher_inputs['pixel_values'] = teacher_pixel_values
-
-                 teacher_outputs = self.teacher(**teacher_inputs)
-                 teacher_logits = teacher_outputs.logits
-
-                 # Compute KD loss
-                 kd_loss = self.compute_kd_loss(
-                     student_logits=student_outputs['logits'],
-                     teacher_logits=teacher_logits,
-                     mask=student_outputs['noise_mask'],
-                 )
-         else:
-             kd_loss = torch.tensor(0.0, device=pixel_values.device)

          # Combined loss
          alpha = self.kd_config.alpha_kd
      """
      ViL-DLM with knowledge distillation from Gemma 4 E2B teacher.

+     Real Stage 3 uses sparse cross-tokenizer KD targets that are
+     prepared offline with the teacher and cached in the student's
+     token space.
      """

      def __init__(self, config: TrainingConfig):

          print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")

+     def compute_sparse_kd_loss(
          self,
+         student_logits: torch.Tensor,
+         noise_mask: torch.Tensor,
+         kd_targets: Optional[list[dict[str, Any]]],
      ) -> torch.Tensor:
+         """Compute sparse KL in the student's token space."""
+         if not kd_targets:
+             return torch.tensor(0.0, device=student_logits.device)
+
+         temperature = self.kd_config.temperature
+         losses = []
+         for entry in kd_targets:
+             batch_idx = int(entry["batch_idx"])
+             position = int(entry["position"])
+             if position >= student_logits.shape[1]:
+                 continue
+             if not bool(noise_mask[batch_idx, position].item()):
+                 continue
+             candidate_token_ids = torch.tensor(
+                 entry["candidate_token_ids"],
+                 device=student_logits.device,
+                 dtype=torch.long,
+             )
+             teacher_probs = torch.tensor(
+                 entry["teacher_probs"],
+                 device=student_logits.device,
+                 dtype=student_logits.dtype,
+             )
+             gathered = student_logits[batch_idx, position, candidate_token_ids]
+             student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
+             losses.append(
+                 # unsqueeze so "batchmean" divides by 1, not by the candidate
+                 # count, matching the helper in train_production.py
+                 F.kl_div(student_log_probs.unsqueeze(0), teacher_probs.unsqueeze(0), reduction="batchmean") * (temperature ** 2)
+             )
+
+         if not losses:
+             return torch.tensor(0.0, device=student_logits.device)
+         return torch.stack(losses).mean()

      def forward_with_distillation(
          self,
          pixel_values: torch.Tensor,
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor,
          labels: Optional[torch.Tensor] = None,
+         kd_targets: Optional[list[dict[str, Any]]] = None,
      ) -> Dict[str, torch.Tensor]:
+         """Forward with diffusion loss plus sparse cached KD targets."""

          # Student forward (diffusion loss)
          student_outputs = self.forward(
          )

          diffusion_loss = student_outputs['loss']
+         kd_loss = self.compute_sparse_kd_loss(
+             student_logits=student_outputs["logits"],
+             noise_mask=student_outputs["noise_mask"],
+             kd_targets=kd_targets,
+         )

          # Combined loss
          alpha = self.kd_config.alpha_kd
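+     # Shape of one kd_targets entry consumed by compute_sparse_kd_loss above
+     # (values hypothetical):
+     #   {"batch_idx": 0, "position": 41,
+     #    "candidate_token_ids": [318, 262, 373],
+     #    "teacher_probs": [0.6, 0.3, 0.1]}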
pyproject.toml CHANGED
@@ -18,6 +18,7 @@ dev = [
      "datasets",
      "accelerate",
      "trackio",
+     "bitsandbytes>=0.45.0; platform_system == 'Linux'",
  ]

  [tool.uv]