File size: 19,321 Bytes
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b05eb6
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b05eb6
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d77b0a
 
 
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b05eb6
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d77b0a
39b477f
0d77b0a
 
 
39b477f
0d77b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39b477f
 
 
 
 
 
 
0d77b0a
39b477f
0d77b0a
39b477f
 
 
 
 
 
 
 
 
 
0d77b0a
 
 
 
 
39b477f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
"""
ViL-DLM: Vision xLSTM Diffusion Language Model

Architecture:
  [Image] → ViL Encoder → MLP Projector → [Visual Tokens]
  [Visual Tokens] + [Text Tokens (masked)] → Bidirectional Diffusion LM → Denoised Tokens

Components:
  1. ViL (Vision xLSTM) - custom vision encoder with linear complexity
  2. MLP Projector - maps ViL features to LM embedding space  
  3. Qwen3-0.6B Diffusion LM - bidirectional masked diffusion backbone (from dLLM)
  
Training:
  Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain
  Stage 2: Full finetune on multimodal instruction data
  Stage 3: + Knowledge distillation from Gemma 4 E2B teacher

Diffusion Process (MDLM):
  Forward: progressively mask tokens with [MASK] according to cosine schedule
  Reverse: iteratively predict masked tokens using bidirectional attention
  Loss: cross-entropy averaged over the masked text positions
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple
from transformers import AutoModelForImageTextToText, AutoModelForMaskedLM, AutoTokenizer

from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig
from vision_xlstm import VisionXLSTM, VisionProjector


class MDLMScheduler:
    """
    Noise scheduler for masked diffusion language modelling (MDLM).

    A timestep ``t`` in [0, 1] is turned into a per-token masking probability
    of ``1 - cos(t * pi / 2)``: 0 at ``t = 0`` (clean) and 1 at ``t = 1``
    (everything masked).

    Attributes:
        num_steps: Number of diffusion steps. Only stored; no method of this
            class reads it.
        mask_token_id: Token id written into masked positions.
    """
    def __init__(self, num_steps=1000, mask_token_id=151643):
        self.num_steps = num_steps
        self.mask_token_id = mask_token_id

    def get_mask_ratio(self, t):
        """
        Return ``cos(t * pi / 2)`` for timestep(s) ``t``.

        NOTE: despite its name this is the expected fraction of tokens that
        stay UNMASKED (1 at t=0, 0 at t=1). ``add_noise`` uses
        ``1 - get_mask_ratio(t)`` as the masking probability. The name and
        return value are kept for backward compatibility.

        Args:
            t: tensor, or Python number, with timestep(s) in [0, 1]
        Returns:
            tensor with the same shape as ``t``
        """
        # as_tensor lets callers pass a plain float as well as a tensor.
        t = torch.as_tensor(t)
        return torch.cos(t * math.pi / 2)

    def add_noise(self, input_ids, t, maskable=None):
        """
        Forward diffusion: mask tokens according to timestep t.

        Args:
            input_ids: [B, T] clean token ids (not modified)
            t: [B] timestep in [0, 1]; a scalar is broadcast over the batch
            maskable: optional [B, T] boolean-like tensor. Positions that are
                False/0 are never masked (e.g. padding). Default: every
                position may be masked.
        Returns:
            noisy_ids: [B, T] with some tokens replaced by mask
            mask: [B, T] boolean - True where tokens are masked
        """
        B, T = input_ids.shape
        device = input_ids.device

        t = torch.as_tensor(t, device=device)
        if t.dim() == 0:
            # Scalar timestep: share it across the whole batch.
            t = t.expand(B)

        # Per-sample masking probability; higher t -> more masking.
        mask_prob = (1.0 - self.get_mask_ratio(t)).unsqueeze(1)  # [B, 1]

        # Each token is masked independently with probability mask_prob.
        mask = torch.rand(B, T, device=device) < mask_prob
        if maskable is not None:
            mask = mask & maskable.to(device=device).bool()

        # masked_fill returns a new tensor, so the caller's ids stay intact.
        noisy_ids = input_ids.masked_fill(mask, self.mask_token_id)

        return noisy_ids, mask

    def sample_timesteps(self, batch_size, device):
        """Sample one uniform random timestep in [0, 1) per batch element."""
        return torch.rand(batch_size, device=device)


class ViLDLM(nn.Module):
    """
    Vision xLSTM Diffusion Language Model.
    
    Combines:
    - ViL encoder for image understanding
    - MLP projector for modality alignment
    - Qwen3-0.6B diffusion backbone for masked denoising
    """
    
    def __init__(self, config: TrainingConfig):
        """
        Build the trainable vision branch and the noise scheduler from *config*.

        The diffusion LM and its tokenizer are not created here; both
        attributes start as ``None`` and are filled by ``load_diffusion_lm``.
        """
        super().__init__()
        self.config = config

        # Image branch: ViL encoder, then the projector into the LM space.
        self.vision_encoder = VisionXLSTM(config.vil_encoder)
        self.projector = VisionProjector(config.projector)

        # Placeholders for the pretrained backbone (loaded separately).
        self.lm, self.tokenizer = None, None

        # Masking schedule used by the training forward pass and by generate().
        diffusion_cfg = config.diffusion
        self.scheduler = MDLMScheduler(
            num_steps=diffusion_cfg.num_diffusion_steps,
            mask_token_id=diffusion_cfg.mask_token_id,
        )

        # No dedicated image-placeholder embedding is created: text tokens
        # are embedded with the LM's own embedding layer.
        
    def load_diffusion_lm(self, local_path: str = None):
        """
        Load the pretrained diffusion LM and its tokenizer.

        Args:
            local_path: checkpoint location; when empty/None the id from
                ``config.diffusion_lm_id`` is used instead.
        Returns:
            self, so the call can be chained.
        """
        model_path = local_path if local_path else self.config.diffusion_lm_id
        print(f"Loading diffusion LM from {model_path}...")

        # NOTE: trust_remote_code executes Python shipped with the checkpoint.
        hub_kwargs = {"trust_remote_code": True}
        weights_dtype = torch.bfloat16 if self.config.bf16 else torch.float32

        self.lm = AutoModelForMaskedLM.from_pretrained(
            model_path, torch_dtype=weights_dtype, **hub_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **hub_kwargs)

        print(f"Loaded diffusion LM: {sum(p.numel() for p in self.lm.parameters()) / 1e6:.1f}M params")
        return self
    
    def get_input_embeddings(self):
        """Get the LM's input embedding layer"""
        return self.lm.model.embed_tokens
    
    def prepare_multimodal_inputs(
        self,
        pixel_values: torch.Tensor,      # [B, C, H, W]
        input_ids: torch.Tensor,          # [B, T_text]
        attention_mask: torch.Tensor,     # [B, T_text]
        image_token_id: int = None,       # token id marking where image goes
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare multimodal input embeddings by:
        1. Encoding image with ViL
        2. Projecting to LM space
        3. Concatenating [visual_tokens, text_tokens]
        
        Returns:
            inputs_embeds: [B, T_vis + T_text, D]
            full_attention_mask: [B, T_vis + T_text]
        """
        B = pixel_values.shape[0]
        
        # Encode image
        with torch.set_grad_enabled(self.training):
            vision_features = self.vision_encoder.forward_features(pixel_values)
            # vision_features: [B, num_patches, vil_dim]
        
        # Project to LM space
        visual_tokens = self.projector(vision_features)
        # visual_tokens: [B, num_patches, lm_dim]
        
        # Get text embeddings
        text_embeds = self.get_input_embeddings()(input_ids)
        # text_embeds: [B, T_text, lm_dim]
        
        # Ensure matching dtype (ViL may be float32, LM may be bfloat16)
        target_dtype = text_embeds.dtype
        visual_tokens = visual_tokens.to(dtype=target_dtype)
        
        # Concatenate: [visual_tokens | text_tokens]
        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
        
        # Build attention mask: all visual tokens are always visible
        num_vis = visual_tokens.shape[1]
        vis_mask = torch.ones(B, num_vis, device=attention_mask.device, dtype=attention_mask.dtype)
        full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
        
        return inputs_embeds, full_attention_mask
    
    def forward(
        self,
        pixel_values: torch.Tensor,      # [B, C, H, W]
        input_ids: torch.Tensor,          # [B, T] clean text tokens
        attention_mask: torch.Tensor,     # [B, T]
        labels: Optional[torch.Tensor] = None,  # [B, T] for loss computation
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass with MDLM diffusion loss.
        
        1. Sample random timestep t
        2. Mask tokens according to t (forward diffusion)
        3. Encode image + masked text through model
        4. Compute cross-entropy loss on masked positions
        """
        B, T = input_ids.shape
        device = input_ids.device
        
        if labels is None:
            labels = input_ids.clone()
        
        # Sample timesteps
        t = self.scheduler.sample_timesteps(B, device)
        
        # Forward diffusion: mask text tokens
        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
        
        # Prepare multimodal inputs with noisy text
        inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
            pixel_values=pixel_values,
            input_ids=noisy_ids,
            attention_mask=attention_mask,
        )
        
        # Forward through diffusion LM
        outputs = self.lm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
        )
        
        # Get logits for text portion only (skip visual token positions)
        num_vis = self.config.vil_encoder.num_patches
        text_logits = outputs.logits[:, num_vis:, :]  # [B, T, vocab_size]
        
        # Compute loss only on masked positions (MDLM objective)
        # Weight by timestep: positions masked at higher t get higher weight
        loss_mask = noise_mask.float()
        
        if loss_mask.sum() == 0:
            # Edge case: no masked tokens
            loss = torch.tensor(0.0, device=device, requires_grad=True)
        else:
            # Cross-entropy on masked positions
            logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
            labels_flat = labels.reshape(-1)
            loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none')
            loss_flat = loss_flat.reshape(B, T)
            
            # Apply mask: only count loss on masked tokens
            loss = (loss_flat * loss_mask).sum() / loss_mask.sum()
        
        return {
            'loss': loss,
            'logits': text_logits,
            'noise_mask': noise_mask,
            't': t,
        }
    
    def freeze_vision_encoder(self):
        """Freeze ViL encoder (Stage 1)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
    
    def unfreeze_vision_encoder(self):
        """Unfreeze ViL encoder (Stage 2+)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = True
    
    def freeze_lm(self):
        """Freeze diffusion LM backbone (Stage 1)"""
        for param in self.lm.parameters():
            param.requires_grad = False
    
    def unfreeze_lm(self):
        """Unfreeze diffusion LM backbone (Stage 2+)"""
        for param in self.lm.parameters():
            param.requires_grad = True
    
    def get_parameter_groups(self):
        """Get parameter groups with different learning rates"""
        groups = [
            {
                'params': [p for p in self.vision_encoder.parameters() if p.requires_grad],
                'lr': self.config.vil_learning_rate,
                'name': 'vision_encoder'
            },
            {
                'params': [p for p in self.projector.parameters() if p.requires_grad],
                'lr': self.config.projector_learning_rate,
                'name': 'projector'
            },
            {
                'params': [p for p in self.lm.parameters() if p.requires_grad],
                'lr': self.config.learning_rate,
                'name': 'diffusion_lm'
            },
        ]
        return [g for g in groups if len(g['params']) > 0]
    
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.Tensor,
        prompt_ids: Optional[torch.Tensor] = None,
        max_new_tokens: int = 128,
        num_steps: int = 64,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate text from image using iterative masked diffusion denoising.
        
        Steps:
        1. Start with all-masked output tokens
        2. At each step, predict all tokens, unmask most confident ones
        3. Repeat until all tokens are unmasked
        """
        self.eval()
        B = pixel_values.shape[0]
        device = pixel_values.device
        
        # Start with all masked tokens
        output_ids = torch.full(
            (B, max_new_tokens), 
            self.scheduler.mask_token_id, 
            device=device, dtype=torch.long
        )
        
        # If prompt provided, prepend it
        if prompt_ids is not None:
            full_ids = torch.cat([prompt_ids, output_ids], dim=1)
            prompt_len = prompt_ids.shape[1]
        else:
            full_ids = output_ids
            prompt_len = 0
        
        T_total = full_ids.shape[1]
        attention_mask = torch.ones(B, T_total, device=device)
        
        # Iterative denoising
        tokens_per_step = max(1, max_new_tokens // num_steps)
        
        for step in range(num_steps):
            # Get predictions
            inputs_embeds, full_attn = self.prepare_multimodal_inputs(
                pixel_values, full_ids, attention_mask
            )
            outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attn)
            
            num_vis = self.config.vil_encoder.num_patches
            logits = outputs.logits[:, num_vis:, :]  # text portion
            
            # Only update masked positions in the generation part
            gen_logits = logits[:, prompt_len:, :]  # [B, max_new_tokens, vocab]
            gen_ids = full_ids[:, prompt_len:]
            
            # Find masked positions
            is_masked = (gen_ids == self.scheduler.mask_token_id)
            
            if not is_masked.any():
                break
            
            # Get probabilities
            probs = F.softmax(gen_logits / temperature, dim=-1)
            predicted = probs.argmax(dim=-1)  # [B, max_new_tokens]
            
            # Confidence = max probability
            confidence = probs.max(dim=-1).values  # [B, max_new_tokens]
            confidence[~is_masked] = float('inf')  # don't re-unmask
            
            # Unmask top-k most confident tokens
            num_to_unmask = min(tokens_per_step, is_masked.sum().item())
            if num_to_unmask > 0:
                # Get indices of most confident masked positions
                _, topk_idx = confidence.topk(num_to_unmask, dim=-1, largest=True)
                
                # Unmask these positions
                for b in range(B):
                    for idx in topk_idx[b]:
                        if is_masked[b, idx]:
                            full_ids[b, prompt_len + idx] = predicted[b, idx]
        
        return full_ids[:, prompt_len:]  # Return generated tokens only
    
    def count_parameters(self):
        """Count parameters by component"""
        vil_params = sum(p.numel() for p in self.vision_encoder.parameters())
        proj_params = sum(p.numel() for p in self.projector.parameters())
        lm_params = sum(p.numel() for p in self.lm.parameters()) if self.lm else 0
        
        total = vil_params + proj_params + lm_params
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        return {
            'vision_encoder': vil_params,
            'projector': proj_params,
            'diffusion_lm': lm_params,
            'total': total,
            'trainable': trainable,
        }


class ViLDLMWithDistillation(ViLDLM):
    """
    ViL-DLM with knowledge distillation from Gemma 4 E2B teacher.
    
    Real Stage 3 uses sparse cross-tokenizer KD targets that are
    prepared offline with the teacher and cached in the student's
    token space.
    """
    
    def __init__(self, config: TrainingConfig):
        """
        Set up the student model and keep the distillation settings.

        The teacher and its processor start as ``None``; ``load_teacher``
        fills them in.
        """
        super().__init__(config)
        self.kd_config = config.distillation
        # Teacher side is loaded on demand.
        self.teacher, self.teacher_processor = None, None
    
    def load_teacher(self):
        """
        Load the teacher model and its processor, then freeze the teacher.

        With ``kd_config.teacher_quantize`` set, the weights are loaded in
        4-bit NF4 (bfloat16 compute) through bitsandbytes; otherwise they are
        loaded in bfloat16. Both paths use ``device_map="auto"``.
        """
        from transformers import AutoProcessor

        print(f"Loading teacher: {self.kd_config.teacher_model_id}...")

        # Options shared by both loading paths.
        load_kwargs = {"device_map": "auto"}
        if self.kd_config.teacher_quantize:
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:
            load_kwargs["torch_dtype"] = torch.bfloat16

        # NOTE(review): if the teacher is an nn.Module, assigning it here
        # registers it as a submodule, so a later `self.train()` would switch
        # it back to train mode - confirm callers re-apply eval() if needed.
        self.teacher = AutoModelForImageTextToText.from_pretrained(
            self.kd_config.teacher_model_id,
            **load_kwargs,
        )
        self.teacher_processor = AutoProcessor.from_pretrained(
            self.kd_config.teacher_model_id
        )

        # The teacher only provides targets: no gradients, eval mode.
        for weight in self.teacher.parameters():
            weight.requires_grad_(False)
        self.teacher.eval()

        print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")
    
    def compute_sparse_kd_loss(
        self,
        student_logits: torch.Tensor,
        noise_mask: torch.Tensor,
        kd_targets: Optional[list[dict[str, Any]]],
    ) -> torch.Tensor:
        """Compute sparse KL in the student's token space."""
        if not kd_targets:
            return torch.tensor(0.0, device=student_logits.device)

        temperature = self.kd_config.temperature
        losses = []
        for entry in kd_targets:
            batch_idx = int(entry["batch_idx"])
            position = int(entry["position"])
            if position >= student_logits.shape[1]:
                continue
            if not bool(noise_mask[batch_idx, position].item()):
                continue
            candidate_token_ids = torch.tensor(
                entry["candidate_token_ids"],
                device=student_logits.device,
                dtype=torch.long,
            )
            teacher_probs = torch.tensor(
                entry["teacher_probs"],
                device=student_logits.device,
                dtype=student_logits.dtype,
            )
            gathered = student_logits[batch_idx, position, candidate_token_ids]
            student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
            losses.append(
                F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)
            )

        if not losses:
            return torch.tensor(0.0, device=student_logits.device)
        return torch.stack(losses).mean()
    
    def forward_with_distillation(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        kd_targets: Optional[list[dict[str, Any]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward with diffusion loss plus sparse cached KD targets."""
        
        # Student forward (diffusion loss)
        student_outputs = self.forward(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        
        diffusion_loss = student_outputs['loss']
        kd_loss = self.compute_sparse_kd_loss(
            student_logits=student_outputs["logits"],
            noise_mask=student_outputs["noise_mask"],
            kd_targets=kd_targets,
        )
        
        # Combined loss
        alpha = self.kd_config.alpha_kd
        total_loss = (1 - alpha) * diffusion_loss + alpha * kd_loss
        
        return {
            'loss': total_loss,
            'diffusion_loss': diffusion_loss,
            'kd_loss': kd_loss,
            'logits': student_outputs['logits'],
            'noise_mask': student_outputs['noise_mask'],
            't': student_outputs['t'],
        }