""" ViL-DLM: Vision xLSTM Diffusion Language Model Architecture: [Image] → ViL Encoder → MLP Projector → [Visual Tokens] [Visual Tokens] + [Text Tokens (masked)] → Bidirectional Diffusion LM → Denoised Tokens Components: 1. ViL (Vision xLSTM) - custom vision encoder with linear complexity 2. MLP Projector - maps ViL features to LM embedding space 3. Qwen3-0.6B Diffusion LM - bidirectional masked diffusion backbone (from dLLM) Training: Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain Stage 2: Full finetune on multimodal instruction data Stage 3: + Knowledge distillation from Gemma 4 E2B teacher Diffusion Process (MDLM): Forward: progressively mask tokens with [MASK] according to cosine schedule Reverse: iteratively predict masked tokens using bidirectional attention Loss: weighted cross-entropy on masked positions """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, Any, Tuple from transformers import AutoModelForImageTextToText, AutoModelForMaskedLM, AutoTokenizer from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig from vision_xlstm import VisionXLSTM, VisionProjector class MDLMScheduler: """ Masked Diffusion Language Model noise scheduler. Cosine schedule for masking probability. """ def __init__(self, num_steps=1000, mask_token_id=151643): self.num_steps = num_steps self.mask_token_id = mask_token_id def get_mask_ratio(self, t): """Cosine masking schedule: ratio of tokens to mask at timestep t""" # t in [0, 1]: 0 = clean, 1 = fully masked return torch.cos(t * math.pi / 2) # mask_ratio decreases as t→0 def add_noise(self, input_ids, t): """ Forward diffusion: mask tokens according to timestep t. Args: input_ids: [B, T] clean token ids t: [B] timestep in [0, 1] Returns: noisy_ids: [B, T] with some tokens replaced by mask mask: [B, T] boolean - True where tokens are masked """ B, T = input_ids.shape device = input_ids.device # Get mask ratio for each sample mask_ratio = 1.0 - self.get_mask_ratio(t) # Higher t → more masking mask_ratio = mask_ratio.unsqueeze(1).expand(B, T) # [B, T] # Sample mask: each token independently masked with probability mask_ratio rand = torch.rand(B, T, device=device) mask = rand < mask_ratio # True = masked # Replace masked tokens noisy_ids = input_ids.clone() noisy_ids[mask] = self.mask_token_id return noisy_ids, mask def sample_timesteps(self, batch_size, device): """Sample random timesteps for training""" return torch.rand(batch_size, device=device) class ViLDLM(nn.Module): """ Vision xLSTM Diffusion Language Model. Combines: - ViL encoder for image understanding - MLP projector for modality alignment - Qwen3-0.6B diffusion backbone for masked denoising """ def __init__(self, config: TrainingConfig): super().__init__() self.config = config # 1. Vision Encoder (ViL) self.vision_encoder = VisionXLSTM(config.vil_encoder) # 2. MLP Projector self.projector = VisionProjector(config.projector) # 3. Diffusion LM backbone (loaded from pretrained) self.lm = None # Will be loaded separately self.tokenizer = None # 4. Diffusion scheduler self.scheduler = MDLMScheduler( num_steps=config.diffusion.num_diffusion_steps, mask_token_id=config.diffusion.mask_token_id ) # 5. Special token embedding for image placeholder # We'll use the LM's embedding layer directly def load_diffusion_lm(self, local_path: str = None): """Load the pretrained diffusion LM backbone""" model_path = local_path or self.config.diffusion_lm_id print(f"Loading diffusion LM from {model_path}...") self.lm = AutoModelForMaskedLM.from_pretrained( model_path, trust_remote_code=True, torch_dtype=torch.bfloat16 if self.config.bf16 else torch.float32, ) self.tokenizer = AutoTokenizer.from_pretrained( model_path, trust_remote_code=True, ) print(f"Loaded diffusion LM: {sum(p.numel() for p in self.lm.parameters()) / 1e6:.1f}M params") return self def get_input_embeddings(self): """Get the LM's input embedding layer""" return self.lm.model.embed_tokens def prepare_multimodal_inputs( self, pixel_values: torch.Tensor, # [B, C, H, W] input_ids: torch.Tensor, # [B, T_text] attention_mask: torch.Tensor, # [B, T_text] image_token_id: int = None, # token id marking where image goes ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare multimodal input embeddings by: 1. Encoding image with ViL 2. Projecting to LM space 3. Concatenating [visual_tokens, text_tokens] Returns: inputs_embeds: [B, T_vis + T_text, D] full_attention_mask: [B, T_vis + T_text] """ B = pixel_values.shape[0] # Encode image with torch.set_grad_enabled(self.training): vision_features = self.vision_encoder.forward_features(pixel_values) # vision_features: [B, num_patches, vil_dim] # Project to LM space visual_tokens = self.projector(vision_features) # visual_tokens: [B, num_patches, lm_dim] # Get text embeddings text_embeds = self.get_input_embeddings()(input_ids) # text_embeds: [B, T_text, lm_dim] # Ensure matching dtype (ViL may be float32, LM may be bfloat16) target_dtype = text_embeds.dtype visual_tokens = visual_tokens.to(dtype=target_dtype) # Concatenate: [visual_tokens | text_tokens] inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1) # Build attention mask: all visual tokens are always visible num_vis = visual_tokens.shape[1] vis_mask = torch.ones(B, num_vis, device=attention_mask.device, dtype=attention_mask.dtype) full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1) return inputs_embeds, full_attention_mask def forward( self, pixel_values: torch.Tensor, # [B, C, H, W] input_ids: torch.Tensor, # [B, T] clean text tokens attention_mask: torch.Tensor, # [B, T] labels: Optional[torch.Tensor] = None, # [B, T] for loss computation ) -> Dict[str, torch.Tensor]: """ Training forward pass with MDLM diffusion loss. 1. Sample random timestep t 2. Mask tokens according to t (forward diffusion) 3. Encode image + masked text through model 4. Compute cross-entropy loss on masked positions """ B, T = input_ids.shape device = input_ids.device if labels is None: labels = input_ids.clone() # Sample timesteps t = self.scheduler.sample_timesteps(B, device) # Forward diffusion: mask text tokens noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t) # Prepare multimodal inputs with noisy text inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( pixel_values=pixel_values, input_ids=noisy_ids, attention_mask=attention_mask, ) # Forward through diffusion LM outputs = self.lm( inputs_embeds=inputs_embeds, attention_mask=full_attention_mask, ) # Get logits for text portion only (skip visual token positions) num_vis = self.config.vil_encoder.num_patches text_logits = outputs.logits[:, num_vis:, :] # [B, T, vocab_size] # Compute loss only on masked positions (MDLM objective) # Weight by timestep: positions masked at higher t get higher weight loss_mask = noise_mask.float() if loss_mask.sum() == 0: # Edge case: no masked tokens loss = torch.tensor(0.0, device=device, requires_grad=True) else: # Cross-entropy on masked positions logits_flat = text_logits.reshape(-1, text_logits.shape[-1]) labels_flat = labels.reshape(-1) loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none') loss_flat = loss_flat.reshape(B, T) # Apply mask: only count loss on masked tokens loss = (loss_flat * loss_mask).sum() / loss_mask.sum() return { 'loss': loss, 'logits': text_logits, 'noise_mask': noise_mask, 't': t, } def freeze_vision_encoder(self): """Freeze ViL encoder (Stage 1)""" for param in self.vision_encoder.parameters(): param.requires_grad = False def unfreeze_vision_encoder(self): """Unfreeze ViL encoder (Stage 2+)""" for param in self.vision_encoder.parameters(): param.requires_grad = True def freeze_lm(self): """Freeze diffusion LM backbone (Stage 1)""" for param in self.lm.parameters(): param.requires_grad = False def unfreeze_lm(self): """Unfreeze diffusion LM backbone (Stage 2+)""" for param in self.lm.parameters(): param.requires_grad = True def get_parameter_groups(self): """Get parameter groups with different learning rates""" groups = [ { 'params': [p for p in self.vision_encoder.parameters() if p.requires_grad], 'lr': self.config.vil_learning_rate, 'name': 'vision_encoder' }, { 'params': [p for p in self.projector.parameters() if p.requires_grad], 'lr': self.config.projector_learning_rate, 'name': 'projector' }, { 'params': [p for p in self.lm.parameters() if p.requires_grad], 'lr': self.config.learning_rate, 'name': 'diffusion_lm' }, ] return [g for g in groups if len(g['params']) > 0] @torch.no_grad() def generate( self, pixel_values: torch.Tensor, prompt_ids: Optional[torch.Tensor] = None, max_new_tokens: int = 128, num_steps: int = 64, temperature: float = 1.0, ) -> torch.Tensor: """ Generate text from image using iterative masked diffusion denoising. Steps: 1. Start with all-masked output tokens 2. At each step, predict all tokens, unmask most confident ones 3. Repeat until all tokens are unmasked """ self.eval() B = pixel_values.shape[0] device = pixel_values.device # Start with all masked tokens output_ids = torch.full( (B, max_new_tokens), self.scheduler.mask_token_id, device=device, dtype=torch.long ) # If prompt provided, prepend it if prompt_ids is not None: full_ids = torch.cat([prompt_ids, output_ids], dim=1) prompt_len = prompt_ids.shape[1] else: full_ids = output_ids prompt_len = 0 T_total = full_ids.shape[1] attention_mask = torch.ones(B, T_total, device=device) # Iterative denoising tokens_per_step = max(1, max_new_tokens // num_steps) for step in range(num_steps): # Get predictions inputs_embeds, full_attn = self.prepare_multimodal_inputs( pixel_values, full_ids, attention_mask ) outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attn) num_vis = self.config.vil_encoder.num_patches logits = outputs.logits[:, num_vis:, :] # text portion # Only update masked positions in the generation part gen_logits = logits[:, prompt_len:, :] # [B, max_new_tokens, vocab] gen_ids = full_ids[:, prompt_len:] # Find masked positions is_masked = (gen_ids == self.scheduler.mask_token_id) if not is_masked.any(): break # Get probabilities probs = F.softmax(gen_logits / temperature, dim=-1) predicted = probs.argmax(dim=-1) # [B, max_new_tokens] # Confidence = max probability confidence = probs.max(dim=-1).values # [B, max_new_tokens] confidence[~is_masked] = float('inf') # don't re-unmask # Unmask top-k most confident tokens num_to_unmask = min(tokens_per_step, is_masked.sum().item()) if num_to_unmask > 0: # Get indices of most confident masked positions _, topk_idx = confidence.topk(num_to_unmask, dim=-1, largest=True) # Unmask these positions for b in range(B): for idx in topk_idx[b]: if is_masked[b, idx]: full_ids[b, prompt_len + idx] = predicted[b, idx] return full_ids[:, prompt_len:] # Return generated tokens only def count_parameters(self): """Count parameters by component""" vil_params = sum(p.numel() for p in self.vision_encoder.parameters()) proj_params = sum(p.numel() for p in self.projector.parameters()) lm_params = sum(p.numel() for p in self.lm.parameters()) if self.lm else 0 total = vil_params + proj_params + lm_params trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return { 'vision_encoder': vil_params, 'projector': proj_params, 'diffusion_lm': lm_params, 'total': total, 'trainable': trainable, } class ViLDLMWithDistillation(ViLDLM): """ ViL-DLM with knowledge distillation from Gemma 4 E2B teacher. Real Stage 3 uses sparse cross-tokenizer KD targets that are prepared offline with the teacher and cached in the student's token space. """ def __init__(self, config: TrainingConfig): super().__init__(config) self.teacher = None self.teacher_processor = None self.kd_config = config.distillation def load_teacher(self): """Load Gemma 4 E2B as teacher (quantized for memory)""" from transformers import AutoProcessor print(f"Loading teacher: {self.kd_config.teacher_model_id}...") if self.kd_config.teacher_quantize: from transformers import BitsAndBytesConfig bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", ) self.teacher = AutoModelForImageTextToText.from_pretrained( self.kd_config.teacher_model_id, quantization_config=bnb_config, device_map="auto", ) else: self.teacher = AutoModelForImageTextToText.from_pretrained( self.kd_config.teacher_model_id, torch_dtype=torch.bfloat16, device_map="auto", ) self.teacher_processor = AutoProcessor.from_pretrained( self.kd_config.teacher_model_id ) # Freeze teacher for param in self.teacher.parameters(): param.requires_grad = False self.teacher.eval() print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params") def compute_sparse_kd_loss( self, student_logits: torch.Tensor, noise_mask: torch.Tensor, kd_targets: Optional[list[dict[str, Any]]], ) -> torch.Tensor: """Compute sparse KL in the student's token space.""" if not kd_targets: return torch.tensor(0.0, device=student_logits.device) temperature = self.kd_config.temperature losses = [] for entry in kd_targets: batch_idx = int(entry["batch_idx"]) position = int(entry["position"]) if position >= student_logits.shape[1]: continue if not bool(noise_mask[batch_idx, position].item()): continue candidate_token_ids = torch.tensor( entry["candidate_token_ids"], device=student_logits.device, dtype=torch.long, ) teacher_probs = torch.tensor( entry["teacher_probs"], device=student_logits.device, dtype=student_logits.dtype, ) gathered = student_logits[batch_idx, position, candidate_token_ids] student_log_probs = F.log_softmax(gathered / temperature, dim=-1) losses.append( F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2) ) if not losses: return torch.tensor(0.0, device=student_logits.device) return torch.stack(losses).mean() def forward_with_distillation( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None, kd_targets: Optional[list[dict[str, Any]]] = None, ) -> Dict[str, torch.Tensor]: """Forward with diffusion loss plus sparse cached KD targets.""" # Student forward (diffusion loss) student_outputs = self.forward( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) diffusion_loss = student_outputs['loss'] kd_loss = self.compute_sparse_kd_loss( student_logits=student_outputs["logits"], noise_mask=student_outputs["noise_mask"], kd_targets=kd_targets, ) # Combined loss alpha = self.kd_config.alpha_kd total_loss = (1 - alpha) * diffusion_loss + alpha * kd_loss return { 'loss': total_loss, 'diffusion_loss': diffusion_loss, 'kd_loss': kd_loss, 'logits': student_outputs['logits'], 'noise_mask': student_outputs['noise_mask'], 't': student_outputs['t'], }