omar-ah committed
Commit 02b453d · 1 Parent(s): 8c5ef30

Guarantee masked assistant tokens in diffusion training

Files changed (2):
  1. README.md +9 -2
  2. code/train_production.py +17 -3
README.md CHANGED
````diff
@@ -82,16 +82,22 @@ This is a genuinely **unexplored frontier** in the literature:
 ## Running Training
 
 ```bash
+# CPU smoke: Stage 1 projector path
+python code/train_production.py --stage 1 --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 1 --dry_run_batches 1
+
+# CPU smoke: Stage 2 subset path
+python code/train_production.py --stage 2 --resume_from ./vil-dlm-output/stage1_best --dataset_configs ai2d,aokvqa --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 8 --dry_run_batches 1
+
 # Stage 1: projector-only alignment
 python code/train_production.py --stage 1 --require_cuda --epochs 1 --batch_size 8 --grad_accum 4
 
 # Stage 2: full-model finetune on the balanced Cauldron mix
 python code/train_production.py --stage 2 --require_cuda --epochs 3 --batch_size 2 --grad_accum 16
 
-# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint
+# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint (GPU only)
 python code/train_production.py --stage 3a --require_cuda --resume_from ./vil-dlm-output/stage2_best --teacher_batch_size 2
 
-# Stage 3b: sparse KD training from the cached teacher bank
+# Stage 3b: sparse KD training from the cached teacher bank (GPU only)
 python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --epochs 2 --batch_size 2 --grad_accum 16
 
 # Cheap validation gate for any stage
@@ -99,6 +105,7 @@ python code/train_production.py --stage 1 --require_cuda --dry_run_batches 1 --m
 ```
 
 Training now saves checkpoints locally by default. Add `--push_to_hub` only when you want to publish artifacts.
+CPU sessions should stop after the Stage 2 subset smoke test. Stage 3 requires a CUDA GPU because Gemma teacher-bank preparation uses quantized multimodal teacher inference.
 
 ### Hardware Requirements
 - **Stage 1**: A10G (24GB) or T4 (16GB) — only projector gradients (~7M params)
````
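
The `--require_cuda` flag in the GPU commands above is a hard gate rather than a preference. A minimal sketch of how such a gate and the `--dry_run_batches` validation pass could be wired up (illustrative only; the flag names come from the README commands, but this parsing code is not taken from the repo):

```python
# Illustrative sketch, not the repo's actual CLI code: shows how the
# --require_cuda hard gate and --dry_run_batches flags from the README
# commands could be enforced before training starts.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--require_cuda", action="store_true",
                    help="Abort instead of silently falling back to CPU.")
parser.add_argument("--dry_run_batches", type=int, default=0,
                    help="If > 0, stop after this many batches (cheap validation gate).")
args = parser.parse_args()

if args.require_cuda and not torch.cuda.is_available():
    # Stage 3 teacher-bank preparation is GPU-only; fail fast here rather
    # than partway through quantized teacher inference.
    raise SystemExit("--require_cuda was set but no CUDA device is available")
```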
code/train_production.py CHANGED
```diff
@@ -126,11 +126,24 @@ class MDLMScheduler:
     def __init__(self, mask_token_id: int) -> None:
         self.mask_token_id = mask_token_id
 
-    def add_noise(self, input_ids: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def add_noise(
+        self,
+        input_ids: torch.Tensor,
+        t: torch.Tensor,
+        eligible_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch, length = input_ids.shape
         mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
         mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length)
         mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio
+        if eligible_mask is not None:
+            eligible_mask = eligible_mask.bool()
+            mask = mask & eligible_mask
+            missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0)
+            for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten():
+                eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten()
+                chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)]
+                mask[batch_idx, chosen] = True
         noisy_ids = input_ids.clone()
         noisy_ids[mask] = self.mask_token_id
         return noisy_ids, mask
@@ -209,7 +222,8 @@ class ViLDLM(nn.Module):
         loss_mask = attention_mask
 
         t = self.scheduler.sample_timesteps(batch_size, device)
-        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
+        eligible_mask = (loss_mask > 0) & (attention_mask > 0)
+        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t, eligible_mask=eligible_mask)
         inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
             pixel_values=pixel_values,
             input_ids=noisy_ids,
@@ -218,7 +232,7 @@
         outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
         text_logits = outputs.logits[:, self.num_patches :, :]
 
-        active_mask = noise_mask.float() * loss_mask.float()
+        active_mask = noise_mask.float() * eligible_mask.float()
         if active_mask.sum() == 0:
             loss = torch.tensor(0.0, device=device, requires_grad=True)
         else:
```
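
The new fallback matters mainly at small timesteps: the cosine schedule gives `mask_ratio = 1 - cos(t * pi / 2)`, which is 0 at `t = 0` and 1 at `t = 1`, so a draw near `t = 0` can leave every Bernoulli sample false and, before this commit, yield a zero-loss example. A hypothetical smoke test of the guarantee (not part of the commit; it assumes `MDLMScheduler` is importable from `train_production.py`, e.g. when run from the `code/` directory):

```python
# Hypothetical check, not part of this commit: exercises the add_noise
# fallback at t = 0, where mask_ratio = 1 - cos(0) = 0 and the Bernoulli
# draw masks nothing on its own.
import torch
from train_production import MDLMScheduler  # assumed import path (run from code/)

scheduler = MDLMScheduler(mask_token_id=0)
input_ids = torch.arange(1, 11).unsqueeze(0)    # one sequence of 10 tokens
eligible = torch.zeros(1, 10, dtype=torch.bool)
eligible[0, 6:] = True                          # only the last 4 "assistant" tokens
t = torch.zeros(1)                              # t = 0 -> mask ratio of exactly 0

noisy_ids, mask = scheduler.add_noise(input_ids, t, eligible_mask=eligible)

assert mask.sum() >= 1                          # the fallback forced at least one mask
assert not mask[0, :6].any()                    # ineligible prompt tokens stay untouched
assert (noisy_ids[mask] == scheduler.mask_token_id).all()
```

One caveat on the forward-pass side: in the visible context `loss_mask = attention_mask`, so `eligible_mask = (loss_mask > 0) & (attention_mask > 0)` reduces to `attention_mask > 0` there; presumably other code paths narrow `loss_mask` to assistant-only tokens, which is where the guarantee named in the commit title applies.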