omar-ah committed
Commit 39b477f · verified · 1 Parent(s): 519f856

Upload vil_dlm_model.py

Files changed (1)
  1. code/vil_dlm_model.py +545 -0
code/vil_dlm_model.py ADDED
"""
ViL-DLM: Vision xLSTM Diffusion Language Model

Architecture:
    [Image] → ViL Encoder → MLP Projector → [Visual Tokens]
    [Visual Tokens] + [Text Tokens (masked)] → Bidirectional Diffusion LM → Denoised Tokens

Components:
    1. ViL (Vision xLSTM) - custom vision encoder with linear complexity
    2. MLP Projector - maps ViL features to LM embedding space
    3. Qwen3-0.6B Diffusion LM - bidirectional masked diffusion backbone (from dLLM)

Training:
    Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain
    Stage 2: Full finetune on multimodal instruction data
    Stage 3: + Knowledge distillation from Gemma 4 E2B teacher

Diffusion Process (MDLM):
    Forward: progressively mask tokens with [MASK] according to a cosine schedule
    Reverse: iteratively predict masked tokens using bidirectional attention
    Loss: cross-entropy on masked positions (averaged over masked tokens)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple
from transformers import AutoModelForMaskedLM, AutoTokenizer

from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig
from vision_xlstm import VisionXLSTM, VisionProjector


class MDLMScheduler:
    """
    Masked Diffusion Language Model noise scheduler.
    Cosine schedule for the masking probability.
    """
    def __init__(self, num_steps=1000, mask_token_id=151643):
        self.num_steps = num_steps
        self.mask_token_id = mask_token_id

    def get_mask_ratio(self, t):
        """Cosine schedule: returns alpha(t), the fraction of tokens kept (unmasked) at timestep t."""
        # t in [0, 1]: 0 = clean, 1 = fully masked
        return torch.cos(t * math.pi / 2)  # alpha(t): 1 at t=0, 0 at t=1

    def add_noise(self, input_ids, t):
        """
        Forward diffusion: mask tokens according to timestep t.

        Args:
            input_ids: [B, T] clean token ids
            t: [B] timestep in [0, 1]
        Returns:
            noisy_ids: [B, T] with some tokens replaced by mask
            mask: [B, T] boolean - True where tokens are masked
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Masking probability per sample: 1 - alpha(t), so higher t → more masking
        mask_ratio = 1.0 - self.get_mask_ratio(t)
        mask_ratio = mask_ratio.unsqueeze(1).expand(B, T)  # [B, T]

        # Sample mask: each token independently masked with probability mask_ratio
        rand = torch.rand(B, T, device=device)
        mask = rand < mask_ratio  # True = masked

        # Replace masked tokens
        noisy_ids = input_ids.clone()
        noisy_ids[mask] = self.mask_token_id

        return noisy_ids, mask

    def sample_timesteps(self, batch_size, device):
        """Sample random timesteps for training"""
        return torch.rand(batch_size, device=device)
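

# Illustrative numbers for the schedule above (a sketch, not executed; assumes the
# MDLMScheduler defaults). With alpha(t) = cos(t * pi / 2), the per-token masking
# probability used by add_noise is 1 - alpha(t):
#
#   t = 0.00 -> mask prob 0.000 (clean sequence)
#   t = 0.25 -> mask prob 1 - cos(pi/8)  ~ 0.076
#   t = 0.50 -> mask prob 1 - cos(pi/4)  ~ 0.293
#   t = 0.75 -> mask prob 1 - cos(3pi/8) ~ 0.617
#   t = 1.00 -> mask prob 1.000 (fully masked)
#
# e.g. scheduler = MDLMScheduler(); noisy_ids, mask = scheduler.add_noise(input_ids, t)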


class ViLDLM(nn.Module):
    """
    Vision xLSTM Diffusion Language Model.

    Combines:
    - ViL encoder for image understanding
    - MLP projector for modality alignment
    - Qwen3-0.6B diffusion backbone for masked denoising
    """

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

        # 1. Vision Encoder (ViL)
        self.vision_encoder = VisionXLSTM(config.vil_encoder)

        # 2. MLP Projector
        self.projector = VisionProjector(config.projector)

        # 3. Diffusion LM backbone (loaded from pretrained)
        self.lm = None  # Will be loaded separately
        self.tokenizer = None

        # 4. Diffusion scheduler
        self.scheduler = MDLMScheduler(
            num_steps=config.diffusion.num_diffusion_steps,
            mask_token_id=config.diffusion.mask_token_id
        )

        # 5. No separate embedding for the image placeholder:
        #    we use the LM's embedding layer directly.

    def load_diffusion_lm(self, local_path: Optional[str] = None):
        """Load the pretrained diffusion LM backbone"""
        model_path = local_path or self.config.diffusion_lm_id
        print(f"Loading diffusion LM from {model_path}...")
        self.lm = AutoModelForMaskedLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.config.bf16 else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print(f"Loaded diffusion LM: {sum(p.numel() for p in self.lm.parameters()) / 1e6:.1f}M params")
        return self

    def get_input_embeddings(self):
        """Get the LM's input embedding layer"""
        return self.lm.model.embed_tokens

    def prepare_multimodal_inputs(
        self,
        pixel_values: torch.Tensor,            # [B, C, H, W]
        input_ids: torch.Tensor,               # [B, T_text]
        attention_mask: torch.Tensor,          # [B, T_text]
        image_token_id: Optional[int] = None,  # token id marking where the image goes (currently unused)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare multimodal input embeddings by:
        1. Encoding the image with ViL
        2. Projecting to LM space
        3. Concatenating [visual_tokens, text_tokens]

        Returns:
            inputs_embeds: [B, T_vis + T_text, D]
            full_attention_mask: [B, T_vis + T_text]
        """
        B = pixel_values.shape[0]

        # Encode image
        with torch.set_grad_enabled(self.training):
            vision_features = self.vision_encoder.forward_features(pixel_values)
            # vision_features: [B, num_patches, vil_dim]

        # Project to LM space
        visual_tokens = self.projector(vision_features)
        # visual_tokens: [B, num_patches, lm_dim]

        # Get text embeddings
        text_embeds = self.get_input_embeddings()(input_ids)
        # text_embeds: [B, T_text, lm_dim]

        # Ensure matching dtype (ViL may be float32, LM may be bfloat16)
        target_dtype = text_embeds.dtype
        visual_tokens = visual_tokens.to(dtype=target_dtype)

        # Concatenate: [visual_tokens | text_tokens]
        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)

        # Build attention mask: all visual tokens are always visible
        num_vis = visual_tokens.shape[1]
        vis_mask = torch.ones(B, num_vis, device=attention_mask.device, dtype=attention_mask.dtype)
        full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1)

        return inputs_embeds, full_attention_mask
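
    # Shape sketch for prepare_multimodal_inputs (illustrative; the real patch count comes
    # from config.vil_encoder.num_patches and the hidden size from the loaded LM). Assuming,
    # for example, 196 visual patches, an LM hidden size of 1024, and 128 text tokens:
    #   pixel_values [B, 3, 224, 224] -> vision_features [B, 196, vil_dim]
    #   -> visual_tokens [B, 196, 1024], text_embeds [B, 128, 1024]
    #   -> inputs_embeds [B, 324, 1024], full_attention_mask [B, 324]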

    def forward(
        self,
        pixel_values: torch.Tensor,             # [B, C, H, W]
        input_ids: torch.Tensor,                # [B, T] clean text tokens
        attention_mask: torch.Tensor,           # [B, T]
        labels: Optional[torch.Tensor] = None,  # [B, T] for loss computation
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass with the MDLM diffusion loss.

        1. Sample a random timestep t
        2. Mask tokens according to t (forward diffusion)
        3. Encode image + masked text through the model
        4. Compute cross-entropy loss on masked positions
        """
        B, T = input_ids.shape
        device = input_ids.device

        if labels is None:
            labels = input_ids.clone()

        # Sample timesteps
        t = self.scheduler.sample_timesteps(B, device)

        # Forward diffusion: mask text tokens
        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)

        # Prepare multimodal inputs with noisy text
        inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
            pixel_values=pixel_values,
            input_ids=noisy_ids,
            attention_mask=attention_mask,
        )

        # Forward through diffusion LM
        outputs = self.lm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
        )

        # Get logits for the text portion only (skip visual token positions)
        num_vis = self.config.vil_encoder.num_patches
        text_logits = outputs.logits[:, num_vis:, :]  # [B, T, vocab_size]

        # MDLM objective: cross-entropy averaged over masked positions only.
        # Padding positions are excluded via the attention mask.
        loss_mask = noise_mask.float() * attention_mask.float()

        if loss_mask.sum() == 0:
            # Edge case: no masked tokens
            loss = torch.tensor(0.0, device=device, requires_grad=True)
        else:
            # Cross-entropy on masked positions
            logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
            labels_flat = labels.reshape(-1)
            loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none')
            loss_flat = loss_flat.reshape(B, T)

            # Apply mask: only count loss on masked tokens
            loss = (loss_flat * loss_mask).sum() / loss_mask.sum()

        return {
            'loss': loss,
            'logits': text_logits,
            'noise_mask': noise_mask,
            't': t,
        }

    def freeze_vision_encoder(self):
        """Freeze ViL encoder (Stage 1)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def unfreeze_vision_encoder(self):
        """Unfreeze ViL encoder (Stage 2+)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = True

    def freeze_lm(self):
        """Freeze diffusion LM backbone (Stage 1)"""
        for param in self.lm.parameters():
            param.requires_grad = False

    def unfreeze_lm(self):
        """Unfreeze diffusion LM backbone (Stage 2+)"""
        for param in self.lm.parameters():
            param.requires_grad = True

    def get_parameter_groups(self):
        """Get parameter groups with different learning rates"""
        groups = [
            {
                'params': [p for p in self.vision_encoder.parameters() if p.requires_grad],
                'lr': self.config.vil_learning_rate,
                'name': 'vision_encoder'
            },
            {
                'params': [p for p in self.projector.parameters() if p.requires_grad],
                'lr': self.config.projector_learning_rate,
                'name': 'projector'
            },
            {
                'params': [p for p in self.lm.parameters() if p.requires_grad],
                'lr': self.config.learning_rate,
                'name': 'diffusion_lm'
            },
        ]
        return [g for g in groups if len(g['params']) > 0]

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.Tensor,
        prompt_ids: Optional[torch.Tensor] = None,
        max_new_tokens: int = 128,
        num_steps: int = 64,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate text from an image using iterative masked diffusion denoising.

        Steps:
        1. Start with all-masked output tokens
        2. At each step, predict all tokens and unmask the most confident ones
        3. Repeat until all tokens are unmasked
        """
        self.eval()
        B = pixel_values.shape[0]
        device = pixel_values.device

        # Start with all masked tokens
        output_ids = torch.full(
            (B, max_new_tokens),
            self.scheduler.mask_token_id,
            device=device, dtype=torch.long
        )

        # If a prompt is provided, prepend it
        if prompt_ids is not None:
            full_ids = torch.cat([prompt_ids, output_ids], dim=1)
            prompt_len = prompt_ids.shape[1]
        else:
            full_ids = output_ids
            prompt_len = 0

        T_total = full_ids.shape[1]
        attention_mask = torch.ones(B, T_total, device=device)

        # Iterative denoising
        tokens_per_step = max(1, max_new_tokens // num_steps)

        for step in range(num_steps):
            # Get predictions
            inputs_embeds, full_attn = self.prepare_multimodal_inputs(
                pixel_values, full_ids, attention_mask
            )
            outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attn)

            num_vis = self.config.vil_encoder.num_patches
            logits = outputs.logits[:, num_vis:, :]  # text portion

            # Only update masked positions in the generation part
            gen_logits = logits[:, prompt_len:, :]  # [B, max_new_tokens, vocab]
            gen_ids = full_ids[:, prompt_len:]

            # Find masked positions
            is_masked = (gen_ids == self.scheduler.mask_token_id)

            if not is_masked.any():
                break

            # Get probabilities
            probs = F.softmax(gen_logits / temperature, dim=-1)
            predicted = probs.argmax(dim=-1)  # [B, max_new_tokens]

            # Confidence = max probability
            confidence = probs.max(dim=-1).values  # [B, max_new_tokens]
            # Exclude already-decoded positions from selection so they are never re-unmasked
            confidence[~is_masked] = float('-inf')

            # Unmask the top-k most confident masked tokens
            num_to_unmask = min(tokens_per_step, is_masked.sum().item())
            if num_to_unmask > 0:
                # Indices of the most confident masked positions (per sample)
                _, topk_idx = confidence.topk(num_to_unmask, dim=-1, largest=True)

                # Unmask these positions
                for b in range(B):
                    for idx in topk_idx[b]:
                        if is_masked[b, idx]:
                            full_ids[b, prompt_len + idx] = predicted[b, idx]

        return full_ids[:, prompt_len:]  # Return generated tokens only

    def count_parameters(self):
        """Count parameters by component"""
        vil_params = sum(p.numel() for p in self.vision_encoder.parameters())
        proj_params = sum(p.numel() for p in self.projector.parameters())
        lm_params = sum(p.numel() for p in self.lm.parameters()) if self.lm else 0

        total = vil_params + proj_params + lm_params
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'vision_encoder': vil_params,
            'projector': proj_params,
            'diffusion_lm': lm_params,
            'total': total,
            'trainable': trainable,
        }
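

# Illustrative inference sketch (not part of the module; assumes a TrainingConfig
# instance `config` and already-preprocessed `pixel_values` / `prompt_ids` tensors):
#
#   model = ViLDLM(config).load_diffusion_lm()
#   out_ids = model.generate(pixel_values, prompt_ids=prompt_ids,
#                            max_new_tokens=128, num_steps=64)
#   text = model.tokenizer.batch_decode(out_ids, skip_special_tokens=True)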


class ViLDLMWithDistillation(ViLDLM):
    """
    ViL-DLM with knowledge distillation from a Gemma 4 E2B teacher.

    Distillation losses:
    1. Response-level KD: KL(teacher_logits || student_logits) on the text output
    2. Vision feature KD: MSE(teacher_vision_features, projected_vil_features)

    Uses LFM2-style Decoupled Top-K distillation for efficiency.
    """

    def __init__(self, config: TrainingConfig):
        super().__init__(config)
        self.teacher = None
        self.teacher_processor = None
        self.kd_config = config.distillation

    def load_teacher(self):
        """Load Gemma 4 E2B as the teacher (optionally 4-bit quantized to save memory)"""
        from transformers import AutoModelForImageTextToText, AutoProcessor

        print(f"Loading teacher: {self.kd_config.teacher_model_id}...")

        if self.kd_config.teacher_quantize:
            from transformers import BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
            self.teacher = AutoModelForImageTextToText.from_pretrained(
                self.kd_config.teacher_model_id,
                quantization_config=bnb_config,
                device_map="auto",
            )
        else:
            self.teacher = AutoModelForImageTextToText.from_pretrained(
                self.kd_config.teacher_model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

        self.teacher_processor = AutoProcessor.from_pretrained(
            self.kd_config.teacher_model_id
        )

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

        print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")

    def compute_kd_loss(
        self,
        student_logits: torch.Tensor,  # [B, T, student_vocab]
        teacher_logits: torch.Tensor,  # [B, T, teacher_vocab]
        mask: torch.Tensor,            # [B, T] where to compute the loss
    ) -> torch.Tensor:
        """
        Decoupled Top-K KL divergence (LFM2 recipe).
        Only align on the teacher's top-K logits for efficiency.
        """
        T = self.kd_config.temperature
        K = self.kd_config.top_k_logits

        # Get top-K teacher predictions
        teacher_topk_vals, teacher_topk_idx = teacher_logits.topk(K, dim=-1)
        teacher_topk_probs = F.softmax(teacher_topk_vals / T, dim=-1)

        # Gather student logits at the teacher's top-K positions.
        # The vocab sizes differ (student: 151936, Qwen3; teacher: 262144, Gemma 4),
        # so only indices that are valid in the student vocab are used.
        valid_mask = teacher_topk_idx < student_logits.shape[-1]
        teacher_topk_idx_clamped = teacher_topk_idx.clamp(0, student_logits.shape[-1] - 1)

        student_topk_logits = torch.gather(student_logits, -1, teacher_topk_idx_clamped)
        student_topk_log_probs = F.log_softmax(student_topk_logits / T, dim=-1)

        # KL divergence on the top-K
        kl = F.kl_div(
            student_topk_log_probs,
            teacher_topk_probs,
            reduction='none'
        )

        # Apply the vocab-validity mask and the position mask
        kl = kl * valid_mask.float()
        kl = kl.sum(-1)  # sum over top-K

        if mask.sum() > 0:
            loss = (kl * mask.float()).sum() / mask.sum()
        else:
            loss = kl.mean()

        return loss * (T ** 2)  # scale by T² as is standard for KD
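
    # Worked form of the decoupled top-K objective above (notation only): with temperature T
    # and V_K the teacher's top-K vocabulary indices at a given position,
    #   p_k = softmax(teacher_logits[V_K] / T),   q_k = softmax(student_logits[V_K] / T)
    #   L_KD = T^2 * sum_k p_k * (log p_k - log q_k)
    # averaged over the positions selected by `mask` (the masked/denoised positions).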

    def forward_with_distillation(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        teacher_pixel_values: Optional[torch.Tensor] = None,  # may need different preprocessing
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass with both the diffusion loss and the distillation loss"""

        # Student forward (diffusion loss)
        student_outputs = self.forward(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        diffusion_loss = student_outputs['loss']

        # Teacher forward (no grad)
        if self.teacher is not None:
            with torch.no_grad():
                # Prepare teacher inputs
                teacher_inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                }
                if teacher_pixel_values is not None:
                    teacher_inputs['pixel_values'] = teacher_pixel_values

                teacher_outputs = self.teacher(**teacher_inputs)
                teacher_logits = teacher_outputs.logits

            # Compute the KD loss (outside no_grad so gradients reach the student)
            kd_loss = self.compute_kd_loss(
                student_logits=student_outputs['logits'],
                teacher_logits=teacher_logits,
                mask=student_outputs['noise_mask'],
            )
        else:
            kd_loss = torch.tensor(0.0, device=pixel_values.device)

        # Combined loss
        alpha = self.kd_config.alpha_kd
        total_loss = (1 - alpha) * diffusion_loss + alpha * kd_loss

        return {
            'loss': total_loss,
            'diffusion_loss': diffusion_loss,
            'kd_loss': kd_loss,
            'logits': student_outputs['logits'],
            'noise_mask': student_outputs['noise_mask'],
            't': student_outputs['t'],
        }
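

# Illustrative Stage-1 training step (a sketch, not part of the module; assumes a
# TrainingConfig `config` and a dataloader yielding pixel_values / input_ids /
# attention_mask batches; the batch field names below are placeholders):
#
#   model = ViLDLM(config).load_diffusion_lm()
#   model.freeze_vision_encoder()
#   model.freeze_lm()                      # Stage 1: only the projector is trained
#   optimizer = torch.optim.AdamW(model.get_parameter_groups())
#   for batch in dataloader:
#       out = model(pixel_values=batch["pixel_values"],
#                   input_ids=batch["input_ids"],
#                   attention_mask=batch["attention_mask"])
#       out["loss"].backward()
#       optimizer.step()
#       optimizer.zero_grad()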