Guarantee masked assistant tokens in diffusion training
- README.md +9 -2
- code/train_production.py +17 -3
README.md
CHANGED

````diff
@@ -82,16 +82,22 @@ This is a genuinely **unexplored frontier** in the literature:
 ## Running Training
 
 ```bash
+# CPU smoke: Stage 1 projector path
+python code/train_production.py --stage 1 --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 1 --dry_run_batches 1
+
+# CPU smoke: Stage 2 subset path
+python code/train_production.py --stage 2 --resume_from ./vil-dlm-output/stage1_best --dataset_configs ai2d,aokvqa --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 8 --dry_run_batches 1
+
 # Stage 1: projector-only alignment
 python code/train_production.py --stage 1 --require_cuda --epochs 1 --batch_size 8 --grad_accum 4
 
 # Stage 2: full-model finetune on the balanced Cauldron mix
 python code/train_production.py --stage 2 --require_cuda --epochs 3 --batch_size 2 --grad_accum 16
 
-# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint
+# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint (GPU only)
 python code/train_production.py --stage 3a --require_cuda --resume_from ./vil-dlm-output/stage2_best --teacher_batch_size 2
 
-# Stage 3b: sparse KD training from the cached teacher bank
+# Stage 3b: sparse KD training from the cached teacher bank (GPU only)
 python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --epochs 2 --batch_size 2 --grad_accum 16
 
 # Cheap validation gate for any stage
@@ -99,6 +105,7 @@ python code/train_production.py --stage 1 --require_cuda --dry_run_batches 1 --m
 ```
 
 Training now saves checkpoints locally by default. Add `--push_to_hub` only when you want to publish artifacts.
+CPU sessions should stop after the Stage 2 subset smoke test. Stage 3 requires a CUDA GPU because Gemma teacher-bank preparation uses quantized multimodal teacher inference.
 
 ### Hardware Requirements
 - **Stage 1**: A10G (24GB) or T4 (16GB) — only projector gradients (~7M params)
````
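The `--require_cuda` flag on the staged commands above is the guard behind this CPU/GPU split. A minimal sketch of such a gate, assuming an argparse-style CLI (illustrative only, not the repo's actual implementation):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--require_cuda",
    action="store_true",
    help="Fail fast instead of silently falling back to CPU.",
)
args = parser.parse_args()

# Hypothetical guard: abort before any model loading when a GPU was
# requested but none is visible.
if args.require_cuda and not torch.cuda.is_available():
    raise SystemExit("This stage needs a CUDA GPU; rerun on a GPU host or omit --require_cuda.")
```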
code/train_production.py
CHANGED

````diff
@@ -126,11 +126,24 @@ class MDLMScheduler:
     def __init__(self, mask_token_id: int) -> None:
         self.mask_token_id = mask_token_id
 
-    def add_noise(self, input_ids: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def add_noise(
+        self,
+        input_ids: torch.Tensor,
+        t: torch.Tensor,
+        eligible_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch, length = input_ids.shape
         mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
         mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length)
         mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio
+        if eligible_mask is not None:
+            eligible_mask = eligible_mask.bool()
+            mask = mask & eligible_mask
+            missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0)
+            for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten():
+                eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten()
+                chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)]
+                mask[batch_idx, chosen] = True
         noisy_ids = input_ids.clone()
         noisy_ids[mask] = self.mask_token_id
         return noisy_ids, mask
@@ -209,7 +222,8 @@ class ViLDLM(nn.Module):
         loss_mask = attention_mask
 
         t = self.scheduler.sample_timesteps(batch_size, device)
-        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
+        eligible_mask = (loss_mask > 0) & (attention_mask > 0)
+        noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t, eligible_mask=eligible_mask)
         inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
             pixel_values=pixel_values,
             input_ids=noisy_ids,
@@ -218,7 +232,7 @@ class ViLDLM(nn.Module):
         outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
         text_logits = outputs.logits[:, self.num_patches :, :]
 
-        active_mask = noise_mask.float() * loss_mask.float()
+        active_mask = noise_mask.float() * eligible_mask.float()
         if active_mask.sum() == 0:
             loss = torch.tensor(0.0, device=device, requires_grad=True)
         else:
````
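For a quick sanity check of the new behavior, here is a self-contained sketch that exercises the patched `add_noise` on a toy batch; the `MDLMScheduler` excerpt matches the diff above, while the toy tensors and assertions are illustrative additions:

```python
import math
from typing import Optional, Tuple

import torch


class MDLMScheduler:
    """Excerpt of the patched scheduler: cosine masking schedule plus the
    guarantee that any row with eligible tokens receives at least one mask."""

    def __init__(self, mask_token_id: int) -> None:
        self.mask_token_id = mask_token_id

    def add_noise(
        self,
        input_ids: torch.Tensor,
        t: torch.Tensor,
        eligible_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, length = input_ids.shape
        # Cosine schedule: ratio 0 at t=0, ~0.29 at t=0.5, 1 at t=1.
        mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
        mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length)
        mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio
        if eligible_mask is not None:
            eligible_mask = eligible_mask.bool()
            mask = mask & eligible_mask
            # Rows with eligible tokens but an empty draw get one forced mask.
            missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0)
            for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten():
                eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten()
                chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)]
                mask[batch_idx, chosen] = True
        noisy_ids = input_ids.clone()
        noisy_ids[mask] = self.mask_token_id
        return noisy_ids, mask


scheduler = MDLMScheduler(mask_token_id=0)
input_ids = torch.arange(1, 9).reshape(2, 4)            # toy batch, 2 sequences of 4 tokens
eligible = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])   # e.g. assistant-token positions
t = torch.full((2,), 0.01)                               # near t=0 the random draw is ~empty

noisy_ids, mask = scheduler.add_noise(input_ids, t, eligible_mask=eligible)
assert (mask.sum(dim=1) >= 1).all()  # every row still gets at least one mask
assert not mask[:, 0].any()          # never masks outside eligible positions
```

Even with `t` near 0, where the cosine schedule would mask almost nothing, every row that has eligible positions ends up with at least one mask token, so the denoising loss never trains on an empty target.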