"""
ViL-DLM production training script.
Stages:
1 - projector-only alignment on LLaVA-Pretrain
2 - full-model finetune on The Cauldron
3a - offline teacher candidate-bank preparation with Gemma 3n E2B-it
3b - sparse cross-tokenizer distillation training using cached teacher targets
"""
import argparse
import hashlib
import json
import math
import os
import time
import zipfile
from collections import defaultdict
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets, load_dataset
from datasets import Features, Image as HFImage
from datasets.features import Sequence as HFSequence
from huggingface_hub import HfApi, snapshot_download
from PIL import Image
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from transformers import (
AutoModelForImageTextToText,
AutoModelForMaskedLM,
AutoProcessor,
AutoTokenizer,
)
try:
import trackio
except Exception:
trackio = None
from vision_xlstm import (
VisionProjector as UpstreamVisionProjector,
VisionXLSTM as UpstreamVisionXLSTM,
)
DEFAULT_CAULDRON_CONFIGS = [
"ai2d",
"vqav2",
"aokvqa",
"textvqa",
"docvqa",
"chartqa",
"textcaps",
"screen2words",
]
DEFAULT_CAULDRON_GATE_CONFIGS = [
"ai2d",
"aokvqa",
]
@dataclass
class ViLConfig:
vision_backbone: str = "vil2-small"
pretrained: bool = True
img_size: int = 224
patch_size: int = 16
in_channels: int = 3
dim: int = 384
depth: int = 24
conv_kernel_size: int = 3
bidirectional: bool = True
dropout: float = 0.0
@property
def num_patches(self) -> int:
return (self.img_size // self.patch_size) ** 2
@dataclass
class ProjConfig:
vil_dim: int = 384
lm_dim: int = 1024
hidden_mult: int = 2
num_layers: int = 2
dropout: float = 0.0
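# With the defaults above, ViLConfig.num_patches = (224 // 16) ** 2 = 196 visual tokens,
# and the projector maps each 384-dim ViL feature into the LM's 1024-dim embedding space.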
class _TrackioShim:
def __init__(self) -> None:
self.enabled = False
def init(self, name: str, project: str = "vil-dlm") -> None:
if trackio is None:
print("Trackio disabled: package not installed in the active environment")
self.enabled = False
return
try:
trackio.init(name=name, project=project)
self.enabled = True
except Exception as exc:
self.enabled = False
print(f"Trackio disabled: {exc}")
def log(self, payload: dict) -> None:
if not self.enabled:
return
try:
trackio.log(payload)
except Exception as exc:
self.enabled = False
print(f"Trackio logging disabled after error: {exc}")
class MDLMScheduler:
def __init__(self, mask_token_id: int) -> None:
self.mask_token_id = mask_token_id
def add_noise(
self,
input_ids: torch.Tensor,
t: torch.Tensor,
eligible_mask: Optional[torch.Tensor] = None,
force_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch, length = input_ids.shape
mask_ratio = 1.0 - torch.cos(t * math.pi / 2)
mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length)
mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio
if eligible_mask is not None:
eligible_mask = eligible_mask.bool()
mask = mask & eligible_mask
if force_mask is not None:
mask = mask | (force_mask.bool() & eligible_mask)
missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0)
for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten():
eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten()
chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)]
mask[batch_idx, chosen] = True
elif force_mask is not None:
mask = mask | force_mask.bool()
noisy_ids = input_ids.clone()
noisy_ids[mask] = self.mask_token_id
return noisy_ids, mask
def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.rand(batch_size, device=device)
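# Noise schedule sketch: with t ~ U(0, 1) from sample_timesteps, the per-sample mask ratio is
# 1 - cos(pi * t / 2), e.g. t=0.0 -> 0.00, t=0.5 -> ~0.29, t=1.0 -> 1.00, so early timesteps
# mask few eligible tokens and late timesteps mask nearly all of them.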
class ViLDLM(nn.Module):
def __init__(self, vil_config: ViLConfig, proj_config: ProjConfig, lm_path: str) -> None:
super().__init__()
self.vil_config = vil_config
self.vision_encoder = UpstreamVisionXLSTM(vil_config)
self.projector = UpstreamVisionProjector(proj_config)
self.lm = AutoModelForMaskedLM.from_pretrained(
lm_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id)
@property
def num_patches(self) -> int:
return self.vil_config.num_patches
def prepare_multimodal_inputs(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
vision_features = self.vision_encoder.forward_features(pixel_values)
visual_tokens = self.projector(vision_features)
text_embeds = self.lm.model.embed_tokens(input_ids)
visual_tokens = visual_tokens.to(dtype=text_embeds.dtype)
inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
vis_mask = torch.ones(
pixel_values.shape[0],
self.num_patches,
device=attention_mask.device,
dtype=attention_mask.dtype,
)
full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
return inputs_embeds, full_attention_mask
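# Sequence layout produced above: [num_patches visual tokens | text tokens]. Downstream,
# predict_clean_logits and forward slice logits[:, num_patches:, :] so that loss masks and
# cached KD positions index into the text portion only.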
def predict_clean_logits(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
)
outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
return outputs.logits[:, self.num_patches :, :]
def forward(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
force_mask: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
batch_size, seq_len = input_ids.shape
device = input_ids.device
if labels is None:
labels = input_ids.clone()
if loss_mask is None:
loss_mask = attention_mask
t = self.scheduler.sample_timesteps(batch_size, device)
eligible_mask = (loss_mask > 0) & (attention_mask > 0)
noisy_ids, noise_mask = self.scheduler.add_noise(
input_ids,
t,
eligible_mask=eligible_mask,
force_mask=force_mask,
)
inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
pixel_values=pixel_values,
input_ids=noisy_ids,
attention_mask=attention_mask,
)
outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask)
text_logits = outputs.logits[:, self.num_patches :, :]
active_mask = noise_mask.float() * eligible_mask.float()
if active_mask.sum() == 0:
loss = torch.tensor(0.0, device=device, requires_grad=True)
else:
logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
labels_flat = labels.reshape(-1)
per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none").reshape(batch_size, seq_len)
loss = (per_token * active_mask).sum() / active_mask.sum()
return {
"loss": loss,
"logits": text_logits,
"noise_mask": noise_mask,
"t": t,
}
def freeze_vision(self) -> None:
for param in self.vision_encoder.parameters():
param.requires_grad = False
def freeze_lm(self) -> None:
for param in self.lm.parameters():
param.requires_grad = False
def unfreeze_all(self) -> None:
for param in self.parameters():
param.requires_grad = True
def count_params(self) -> Dict[str, int]:
vil = sum(p.numel() for p in self.vision_encoder.parameters())
proj = sum(p.numel() for p in self.projector.parameters())
lm = sum(p.numel() for p in self.lm.parameters())
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {"vil": vil, "proj": proj, "lm": lm, "total": vil + proj + lm, "trainable": trainable}
def save_checkpoint(self, save_dir: Path, include_lm: bool) -> None:
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.vision_encoder.state_dict(), save_dir / "vision_encoder.pt")
torch.save(self.projector.state_dict(), save_dir / "projector.pt")
if include_lm:
self.lm.save_pretrained(save_dir / "diffusion_lm")
self.tokenizer.save_pretrained(save_dir / "diffusion_lm")
def load_checkpoint(self, checkpoint_dir: Path, include_lm: bool) -> None:
vision_path = checkpoint_dir / "vision_encoder.pt"
projector_path = checkpoint_dir / "projector.pt"
if vision_path.exists():
self.vision_encoder.load_state_dict(torch.load(vision_path, map_location="cpu"))
if projector_path.exists():
self.projector.load_state_dict(torch.load(projector_path, map_location="cpu"))
if include_lm:
diffusion_dir = checkpoint_dir / "diffusion_lm"
if diffusion_dir.exists():
self.lm = AutoModelForMaskedLM.from_pretrained(
diffusion_dir,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
self.tokenizer = AutoTokenizer.from_pretrained(diffusion_dir, trust_remote_code=True)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id)
def ensure_hf_cache_root() -> None:
os.environ.setdefault("HF_HOME", "/teamspace/studios/this_studio/.cache/huggingface")
def patch_diffusion_modeling_file(lm_path: str) -> None:
modeling_file = os.path.join(lm_path, "modeling_qwen3.py")
with open(modeling_file, "r", encoding="utf-8") as handle:
content = handle.read()
content = content.replace(
'if __name__ == "__main__":\n import dllm',
'if __name__ == "__main__":\n pass\n # import dllm',
)
content = content.replace(
"attention_mask=causal_mask_mapping[decoder_layer.attention_type]",
'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))',
)
with open(modeling_file, "w", encoding="utf-8") as handle:
handle.write(content)
def download_student_backbone() -> str:
print("Downloading dLLM Qwen3-0.6B diffusion model...")
lm_path = snapshot_download("dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1")
patch_diffusion_modeling_file(lm_path)
print(f"Model downloaded to {lm_path}")
return lm_path
def parse_dataset_configs(dataset_configs: Optional[str]) -> List[str]:
if dataset_configs:
return [item.strip() for item in dataset_configs.split(",") if item.strip()]
return list(DEFAULT_CAULDRON_CONFIGS)
def resolve_cauldron_configs(args: argparse.Namespace) -> List[str]:
configs = parse_dataset_configs(args.dataset_configs)
default_config_string = ",".join(DEFAULT_CAULDRON_CONFIGS)
if args.dry_run_batches and args.dataset_configs == default_config_string:
print(
"Dry-run mode detected; using the cheap Stage 2 gate config set "
f"{DEFAULT_CAULDRON_GATE_CONFIGS} instead of the full production mix."
)
return list(DEFAULT_CAULDRON_GATE_CONFIGS)
return configs
def stable_text_hash(*parts: str) -> str:
joined = "\n".join(parts)
return hashlib.sha1(joined.encode("utf-8")).hexdigest()
def build_prompt_prefix(prompt_text: str) -> str:
return f"User: {prompt_text.strip()}\nAssistant:"
def tokenize_prompt_and_target(
tokenizer: AutoTokenizer,
prompt_text: str,
target_text: str,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
prefix_text = build_prompt_prefix(prompt_text)
prefix_ids = tokenizer(prefix_text, add_special_tokens=True)["input_ids"]
target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"]
if not target_ids:
target_ids = tokenizer(" " + "N/A", add_special_tokens=False)["input_ids"][:1]
max_prefix_len = max_length - 1
if len(prefix_ids) > max_prefix_len:
prefix_ids = prefix_ids[:max_prefix_len]
remaining = max_length - len(prefix_ids)
if remaining <= 0:
prefix_ids = prefix_ids[: max_length - 1]
remaining = 1
target_ids = target_ids[:remaining]
if not target_ids:
prefix_ids = prefix_ids[: max_length - 1]
target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"][:1]
input_ids = prefix_ids + target_ids
loss_mask = [0] * len(prefix_ids) + [1] * len(target_ids)
attention_mask = [1] * len(input_ids)
labels = list(input_ids)
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
pad_len = max_length - len(input_ids)
if pad_len > 0:
input_ids = input_ids + [pad_token_id] * pad_len
attention_mask = attention_mask + [0] * pad_len
labels = labels + [pad_token_id] * pad_len
loss_mask = loss_mask + [0] * pad_len
return (
torch.tensor(input_ids, dtype=torch.long),
torch.tensor(attention_mask, dtype=torch.long),
torch.tensor(labels, dtype=torch.long),
torch.tensor(loss_mask, dtype=torch.float32),
)
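# Resulting layout (sketch):
#   input_ids:      [prompt prefix ids][target ids][pad ...]
#   loss_mask:      [0 ...            ][1 ...     ][0 ...  ]
#   attention_mask: [1 ...            ][1 ...     ][0 ...  ]
# Only target positions (loss_mask == 1) are eligible for diffusion masking and the loss.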
def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
if isinstance(img, str):
img = Image.open(img).convert("RGB")
elif isinstance(img, dict) and "zip_path" in img and "member" in img:
with zipfile.ZipFile(img["zip_path"], "r") as archive:
with archive.open(img["member"], "r") as member_file:
img = Image.open(member_file).convert("RGB")
elif isinstance(img, dict) and img.get("bytes") is not None:
img = Image.open(BytesIO(img["bytes"])).convert("RGB")
elif isinstance(img, dict) and img.get("path") and os.path.exists(img["path"]):
img = Image.open(img["path"]).convert("RGB")
elif isinstance(img, Image.Image):
img = img.convert("RGB")
else:
raise ValueError(f"Unsupported image payload type: {type(img)!r}")
pil_image = img
resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
arr = np.array(resized).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
tensor = (tensor - mean) / std
return tensor, pil_image
def is_usable_image_payload(img: object) -> bool:
if isinstance(img, Image.Image):
return True
if isinstance(img, str):
return os.path.exists(img)
if isinstance(img, dict):
if img.get("zip_path") and img.get("member"):
return os.path.exists(img["zip_path"])
if img.get("bytes") is not None:
return True
if img.get("path"):
return os.path.exists(img["path"])
return False
class NormalizedVisionLanguageDataset(Dataset):
def __init__(
self,
records: HFDataset,
tokenizer: AutoTokenizer,
max_length: int,
img_size: int,
) -> None:
self.records = records
self.tokenizer = tokenizer
self.max_length = max_length
self.img_size = img_size
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, idx: int) -> Dict[str, object]:
sample = self.records[int(idx)]
pixel_values, pil_image = preprocess_image_for_student(sample["image"], self.img_size)
input_ids, attention_mask, labels, loss_mask = tokenize_prompt_and_target(
tokenizer=self.tokenizer,
prompt_text=sample["prompt_text"],
target_text=sample["target_text"],
max_length=self.max_length,
)
return {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"loss_mask": loss_mask,
"sample_id": sample["sample_id"],
"prompt_text": sample["prompt_text"],
"target_text": sample["target_text"],
"source_config": sample.get("source_config", "unknown"),
"pil_image": pil_image,
}
def build_llava_records(max_samples: Optional[int]) -> HFDataset:
print("Loading LLaVA-Pretrain dataset...")
dataset_root = None
images_zip_path = None
zip_members = None
try:
data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
except Exception as exc:
print(f"Primary dataset loader failed ({exc}). Falling back to direct JSON loading...")
dataset_root = snapshot_download(
"liuhaotian/LLaVA-Pretrain",
repo_type="dataset",
allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
)
json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
images_zip_path = os.path.join(dataset_root, "images.zip")
if os.path.exists(images_zip_path):
with zipfile.ZipFile(images_zip_path, "r") as archive:
zip_members = set(archive.namelist())
data = load_dataset("json", data_files={"train": json_path}, split="train")
if max_samples:
data = data.select(range(min(max_samples, len(data))))
stats = defaultdict(int)
def normalize(sample: Dict[str, object], idx: int) -> Optional[Dict[str, object]]:
text = ""
if "conversations" in sample:
parts = []
for turn in sample["conversations"]:
val = turn.get("value", "").replace("<image>\n", "").replace("<image>", "").strip()
if val:
parts.append(val)
text = " ".join(parts)
elif sample.get("blip_caption"):
text = sample["blip_caption"].strip()
if not text:
text = "Describe this image."
image_obj = sample.get("image")
if image_obj is None:
stats["missing_image_ref"] += 1
return None
if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
candidate_paths = [
image_obj,
os.path.join(dataset_root, image_obj),
os.path.join(dataset_root, "images", image_obj),
]
resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
if resolved_path:
image_obj = resolved_path
elif images_zip_path and os.path.exists(images_zip_path) and zip_members:
member_name = None
if image_obj in zip_members:
member_name = image_obj
elif f"images/{image_obj}" in zip_members:
member_name = f"images/{image_obj}"
if member_name is None:
stats["missing_backing_image"] += 1
return None
image_obj = {
"zip_path": images_zip_path,
"member": member_name,
}
else:
stats["missing_backing_image"] += 1
return None
stats["kept"] += 1
return {
"image": image_obj,
"prompt_text": "Describe this image.",
"target_text": text,
"sample_id": f"llava-pretrain:{sample.get('id', idx)}",
"source_config": "llava_pretrain",
}
records = [record for i in range(len(data)) if (record := normalize(data[i], i)) is not None]
normalized = HFDataset.from_list(records)
print(
f"Loaded {len(normalized)} LLaVA samples "
f"(kept={stats['kept']}, missing_image_ref={stats['missing_image_ref']}, "
f"missing_backing_image={stats['missing_backing_image']})"
)
return normalized
def disable_image_decoding(feature: object) -> object:
if isinstance(feature, HFImage):
return HFImage(decode=False)
if isinstance(feature, HFSequence):
return HFSequence(feature=disable_image_decoding(feature.feature), length=feature.length)
if isinstance(feature, Features):
return Features({key: disable_image_decoding(value) for key, value in feature.items()})
if isinstance(feature, dict):
return {key: disable_image_decoding(value) for key, value in feature.items()}
if isinstance(feature, list):
return [disable_image_decoding(value) for value in feature]
return feature
def build_cauldron_records(
configs: Sequence[str],
max_samples: Optional[int],
raw_row_limit: Optional[int] = None,
) -> Tuple[HFDataset, Dict[str, Dict[str, int]]]:
normalized_configs: List[HFDataset] = []
skip_stats: Dict[str, Dict[str, int]] = {}
per_config_limit = None
if max_samples:
per_config_limit = max(1, max_samples // max(len(configs), 1))
for config_name in configs:
print(f"Loading The Cauldron config: {config_name}")
ds = load_dataset("HuggingFaceM4/the_cauldron", config_name, split="train")
if raw_row_limit is not None:
ds = ds.select(range(min(raw_row_limit, len(ds))))
if "images" in ds.features:
ds = ds.cast_column("images", disable_image_decoding(ds.features["images"]))
if "image" in ds.features:
ds = ds.cast_column("image", disable_image_decoding(ds.features["image"]))
stats = defaultdict(int)
def explode(batch: Dict[str, List[object]], indices: List[int]) -> Dict[str, List[object]]:
output = {
"image": [],
"prompt_text": [],
"target_text": [],
"sample_id": [],
"source_config": [],
}
batch_images = batch.get("images")
batch_single_image = batch.get("image")
batch_texts = batch.get("texts") or batch.get("conversations")
for local_idx, row_idx in enumerate(indices):
if batch_images is not None:
images = batch_images[local_idx]
elif batch_single_image is not None:
images = batch_single_image[local_idx]
else:
stats["missing_image_column"] += 1
continue
if batch_texts is None:
stats["missing_text_column"] += 1
continue
texts = batch_texts[local_idx]
if images is None:
images = []
elif not isinstance(images, list):
images = [images]
if texts is None:
texts = []
elif isinstance(texts, dict):
texts = [texts]
if not images or len(images) != 1:
stats["multi_or_missing_image"] += 1
continue
image_payload = images[0]
if not is_usable_image_payload(image_payload):
stats["unusable_image_ref"] += 1
continue
if not texts:
stats["missing_turns"] += 1
continue
for turn_idx, turn in enumerate(texts):
if not isinstance(turn, dict):
stats["unsupported_turn_type"] += 1
continue
user_text = (
turn.get("user")
or turn.get("question")
or turn.get("prompt")
or turn.get("input")
or ""
).strip()
assistant_text = (
turn.get("assistant")
or turn.get("answer")
or turn.get("response")
or turn.get("output")
or ""
).strip()
if not user_text or not assistant_text:
stats["missing_user_or_assistant"] += 1
continue
output["image"].append(image_payload)
output["prompt_text"].append(user_text)
output["target_text"].append(assistant_text)
output["sample_id"].append(f"{config_name}:{row_idx}:{turn_idx}")
output["source_config"].append(config_name)
stats["kept"] += 1
return output
exploded = ds.map(
explode,
batched=True,
with_indices=True,
remove_columns=ds.column_names,
desc=f"Normalizing {config_name}",
)
if per_config_limit is not None:
exploded = exploded.select(range(min(per_config_limit, len(exploded))))
normalized_configs.append(exploded)
stats["kept"] = len(exploded)
skip_stats[config_name] = dict(stats)
print(f"{config_name}: kept={stats['kept']} skipped={sum(v for k, v in stats.items() if k != 'kept')}")
if not normalized_configs:
raise RuntimeError("No valid The Cauldron configs were loaded.")
combined = concatenate_datasets(normalized_configs)
if max_samples:
combined = combined.select(range(min(max_samples, len(combined))))
print(f"Loaded {len(combined)} normalized The Cauldron samples")
return combined, skip_stats
def collate_vision_language(batch: List[Dict[str, object]]) -> Dict[str, object]:
return {
"pixel_values": torch.stack([sample["pixel_values"] for sample in batch]),
"input_ids": torch.stack([sample["input_ids"] for sample in batch]),
"attention_mask": torch.stack([sample["attention_mask"] for sample in batch]),
"labels": torch.stack([sample["labels"] for sample in batch]),
"loss_mask": torch.stack([sample["loss_mask"] for sample in batch]),
"sample_id": [sample["sample_id"] for sample in batch],
"prompt_text": [sample["prompt_text"] for sample in batch],
"target_text": [sample["target_text"] for sample in batch],
"source_config": [sample["source_config"] for sample in batch],
"pil_image": [sample["pil_image"] for sample in batch],
}
def create_stage_dataset(stage: str, tokenizer: AutoTokenizer, args: argparse.Namespace) -> Tuple[NormalizedVisionLanguageDataset, Dict[str, Dict[str, int]]]:
if stage == "1":
return NormalizedVisionLanguageDataset(
records=build_llava_records(args.max_samples),
tokenizer=tokenizer,
max_length=args.max_length,
img_size=224,
), {}
configs = resolve_cauldron_configs(args)
raw_row_limit = None
if args.dry_run_batches:
raw_row_limit = max(32, args.batch_size * args.dry_run_batches * 8)
records, skip_stats = build_cauldron_records(configs, args.max_samples, raw_row_limit=raw_row_limit)
return NormalizedVisionLanguageDataset(
records=records,
tokenizer=tokenizer,
max_length=args.max_length,
img_size=224,
), skip_stats
def build_dataloader(
dataset: Dataset,
batch_size: int,
shuffle: bool,
num_workers: int,
persistent_workers: bool,
) -> DataLoader:
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
persistent_workers=persistent_workers and num_workers > 0,
drop_last=False,
collate_fn=collate_vision_language,
)
def print_device_info(device: torch.device) -> None:
print(f"Device: {device}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"torch.version.cuda: {torch.version.cuda}")
def ensure_runtime_requirements(args: argparse.Namespace) -> None:
if args.require_cuda and not torch.cuda.is_available():
raise RuntimeError("CUDA is required for this run but torch.cuda.is_available() is False.")
if args.stage in {"2", "3a", "3b"} and not parse_dataset_configs(args.dataset_configs):
raise RuntimeError("Stage 2/3 requires at least one The Cauldron config.")
if args.stage in {"3a", "3b"} and not args.teacher_cache_dir:
raise RuntimeError("Stage 3 requires --teacher_cache_dir.")
if args.stage in {"3a", "3b"} and not args.resume_from:
raise RuntimeError("Stage 3 requires --resume_from pointing to a Stage 2 checkpoint.")
if args.stage == "3a":
try:
import bitsandbytes # noqa: F401
except ImportError as exc:
raise RuntimeError("Stage 3a requires bitsandbytes in the active environment.") from exc
def maybe_resume_model(model: ViLDLM, args: argparse.Namespace) -> None:
if not args.resume_from:
return
checkpoint_dir = Path(args.resume_from)
if not checkpoint_dir.exists():
raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
include_lm = args.stage in {"2", "3a", "3b"}
print(f"Resuming from checkpoint: {checkpoint_dir}")
model.load_checkpoint(checkpoint_dir, include_lm=include_lm)
def get_optimizer(model: ViLDLM, stage: str) -> AdamW:
if stage == "1":
groups = [
{
"params": [p for p in model.projector.parameters() if p.requires_grad],
"lr": 1e-3,
}
]
else:
groups = [
{
"params": [p for p in model.vision_encoder.parameters() if p.requires_grad],
"lr": 2e-6,
},
{
"params": [p for p in model.projector.parameters() if p.requires_grad],
"lr": 1e-5,
},
{
"params": [p for p in model.lm.parameters() if p.requires_grad],
"lr": 1e-5,
},
]
groups = [group for group in groups if group["params"]]
return AdamW(groups, weight_decay=0.05, betas=(0.9, 0.999))
def setup_model_for_stage(model: ViLDLM, stage: str) -> None:
if stage == "1":
print("\n=== STAGE 1: Projector-only alignment ===")
model.freeze_vision()
model.freeze_lm()
elif stage in {"2", "3b"}:
label = "Full finetune" if stage == "2" else "Sparse KD finetune"
print(f"\n=== STAGE {stage.upper()}: {label} ===")
model.unfreeze_all()
elif stage == "3a":
print("\n=== STAGE 3A: Teacher candidate-bank preparation ===")
model.unfreeze_all()
for param in model.parameters():
param.requires_grad = False
else:
raise ValueError(f"Unsupported stage: {stage}")
def compute_sparse_kd_loss(
student_logits: torch.Tensor,
noise_mask: torch.Tensor,
timesteps: torch.Tensor,
sample_ids: Sequence[str],
bank_map: Dict[str, List[Dict[str, object]]],
temperature: float,
) -> Tuple[torch.Tensor, Dict[str, object]]:
entries_used = 0
losses: List[torch.Tensor] = []
mask_probs: List[torch.Tensor] = []
mask_probability = 1.0 - torch.cos(timesteps * math.pi / 2)
for batch_idx, sample_id in enumerate(sample_ids):
sample_entries = bank_map.get(sample_id, [])
for entry in sample_entries:
position = int(entry["position"])
if position >= student_logits.shape[1]:
continue
if not bool(noise_mask[batch_idx, position].item()):
continue
candidate_ids = torch.tensor(
entry["candidate_token_ids"],
device=student_logits.device,
dtype=torch.long,
)
teacher_probs = torch.tensor(
entry["teacher_probs"],
device=student_logits.device,
dtype=student_logits.dtype,
)
gathered = student_logits[batch_idx, position, candidate_ids]
student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
loss = F.kl_div(
student_log_probs.unsqueeze(0),
teacher_probs.unsqueeze(0),
reduction="batchmean",
) * (temperature ** 2)
losses.append(loss)
mask_probs.append(mask_probability[batch_idx])
entries_used += 1
if not losses:
zero = torch.tensor(0.0, device=student_logits.device)
return zero, {
"entries": 0,
"loss_variance": zero,
"mean_mask_prob": zero,
}
loss_tensor = torch.stack(losses)
mask_prob_tensor = torch.stack(mask_probs)
return loss_tensor.mean(), {
"entries": entries_used,
"loss_variance": loss_tensor.var(unbiased=False),
"mean_mask_prob": mask_prob_tensor.mean(),
}
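# Sparse KD sketch: for each cached (sample_id, position) entry whose position was actually
# masked this step, the student logits are restricted to the teacher's candidate token ids,
# log-softmaxed at temperature T, and scored with KL(teacher || student) * T^2; entry losses
# are averaged. The returned mean mask probability, 1 - cos(pi * t / 2), lets the caller scale
# alpha_kd per step when --kd_timestep_weighting is enabled.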
def build_kd_force_mask(
sample_ids: Sequence[str],
bank_map: Dict[str, List[Dict[str, object]]],
seq_len: int,
device: torch.device,
) -> torch.Tensor:
force_mask = torch.zeros((len(sample_ids), seq_len), device=device, dtype=torch.bool)
for batch_idx, sample_id in enumerate(sample_ids):
for entry in bank_map.get(sample_id, []):
position = int(entry["position"])
if 0 <= position < seq_len:
force_mask[batch_idx, position] = True
return force_mask
def compute_teacher_logprobs(
teacher: AutoModelForImageTextToText,
processor: AutoProcessor,
pil_image: Image.Image,
prompt_text: str,
candidate_texts: Sequence[str],
teacher_batch_size: int,
) -> torch.Tensor:
prompt_messages = [
{
"role": "user",
"content": [
{"type": "image", "image": pil_image},
{"type": "text", "text": prompt_text},
],
}
]
prompt_inputs = processor.apply_chat_template(
prompt_messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
prompt_len = prompt_inputs["input_ids"].shape[1]
teacher_device = next(teacher.parameters()).device
all_logprobs = []
for start in range(0, len(candidate_texts), max(teacher_batch_size, 1)):
batch_candidates = candidate_texts[start : start + max(teacher_batch_size, 1)]
conversations = []
for candidate_text in batch_candidates:
conversations.append(
[
{
"role": "user",
"content": [
{"type": "image", "image": pil_image},
{"type": "text", "text": prompt_text},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": candidate_text}],
},
]
)
batch_inputs = processor.apply_chat_template(
conversations,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=False,
)
batch_inputs = {key: value.to(teacher_device) for key, value in batch_inputs.items()}
outputs = teacher(**batch_inputs)
logits = outputs.logits[:, :-1, :]
labels = batch_inputs["input_ids"][:, 1:].clone()
attention_mask = batch_inputs["attention_mask"]
seq_len = batch_inputs["input_ids"].shape[1]
for batch_idx in range(labels.shape[0]):
valid_len = int(attention_mask[batch_idx].sum().item())
left_pad = seq_len - valid_len
prefix_cut = left_pad + prompt_len - 1
if prefix_cut > 0:
labels[batch_idx, :prefix_cut] = -100
labels[batch_idx, attention_mask[batch_idx, 1:] == 0] = -100
per_token = F.cross_entropy(
logits.reshape(-1, logits.shape[-1]),
labels.reshape(-1),
ignore_index=-100,
reduction="none",
).reshape(labels.shape)
token_mask = (labels != -100).float()
all_logprobs.append(-(per_token * token_mask).sum(dim=-1).cpu())
return torch.cat(all_logprobs, dim=0)
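# Teacher scoring sketch: each candidate completion is rendered as a full user+assistant chat;
# its score is the summed log-probability of the assistant tokens under the teacher, with prompt
# and (left-)padding positions masked out via ignore_index=-100.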
def choose_distillation_positions(
clean_logits: torch.Tensor,
labels: torch.Tensor,
loss_mask: torch.Tensor,
max_positions: int,
) -> List[int]:
valid_positions = torch.nonzero(loss_mask > 0, as_tuple=False).flatten()
if valid_positions.numel() == 0:
return []
probs = F.softmax(clean_logits[valid_positions], dim=-1)
gold = labels[valid_positions].unsqueeze(-1)
gold_probs = probs.gather(-1, gold).squeeze(-1)
_, ranked = torch.sort(gold_probs, descending=False)
selected = valid_positions[ranked][:max_positions]
return [int(pos.item()) for pos in selected]
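# Position selection: among loss-eligible positions, pick those where the clean-context student
# assigns the lowest probability to the gold token, i.e. the hardest positions.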
def build_candidate_ids(
logits_at_position: torch.Tensor,
gold_token_id: int,
top_k: int,
) -> List[int]:
candidate_ids = logits_at_position.topk(max(top_k - 1, 1)).indices.tolist()
if gold_token_id not in candidate_ids:
candidate_ids.append(gold_token_id)
deduped = []
seen = set()
for token_id in candidate_ids:
if token_id in seen:
continue
deduped.append(token_id)
seen.add(token_id)
return deduped[:top_k]
def decode_assistant_text(
tokenizer: AutoTokenizer,
full_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> str:
active = (attention_mask > 0) & (loss_mask > 0)
assistant_ids = full_ids[active].tolist()
return tokenizer.decode(assistant_ids, skip_special_tokens=True).strip()
def prepare_teacher_bank(
args: argparse.Namespace,
model: ViLDLM,
dataset: NormalizedVisionLanguageDataset,
) -> None:
if args.dry_run_batches:
max_items = min(args.teacher_batch_size * args.dry_run_batches, len(dataset))
elif args.max_samples:
max_items = min(args.max_samples, len(dataset))
else:
max_items = len(dataset)
try:
from transformers import BitsAndBytesConfig
except ImportError as exc:
raise RuntimeError("bitsandbytes/transformers quantization support is required for Stage 3a.") from exc
print(f"Loading teacher: {args.teacher_model_id}")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
)
teacher = AutoModelForImageTextToText.from_pretrained(
args.teacher_model_id,
quantization_config=quantization_config,
device_map="auto",
attn_implementation="sdpa",
)
teacher.eval()
processor = AutoProcessor.from_pretrained(args.teacher_model_id, padding_side="left")
cache_dir = Path(args.teacher_cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
output_path = cache_dir / "candidate_bank.jsonl"
seen_keys = set()
if output_path.exists():
with open(output_path, "r", encoding="utf-8") as handle:
for line in handle:
if not line.strip():
continue
record = json.loads(line)
seen_keys.add((record["sample_id"], int(record["position"])))
dataloader = build_dataloader(
dataset=dataset,
batch_size=1,
shuffle=False,
num_workers=0,
persistent_workers=False,
)
prepared = 0
processed_samples = 0
teacher_entropies: List[float] = []
with torch.no_grad(), open(output_path, "a", encoding="utf-8") as writer:
for batch in dataloader:
if processed_samples >= max_items:
break
processed_samples += 1
sample_id = batch["sample_id"][0]
prompt_text = batch["prompt_text"][0]
target_text = batch["target_text"][0]
pil_image = batch["pil_image"][0]
pixel_values = batch["pixel_values"].to(next(model.parameters()).device)
input_ids = batch["input_ids"].to(pixel_values.device)
attention_mask = batch["attention_mask"].to(pixel_values.device)
labels = batch["labels"].to(pixel_values.device)
loss_mask = batch["loss_mask"].to(pixel_values.device)
clean_logits = model.predict_clean_logits(pixel_values, input_ids, attention_mask)[0]
sample_labels = labels[0]
sample_loss_mask = loss_mask[0]
positions = choose_distillation_positions(
clean_logits=clean_logits,
labels=sample_labels,
loss_mask=sample_loss_mask,
max_positions=args.kd_positions_per_sample,
)
for position in positions:
cache_key = (sample_id, position)
if cache_key in seen_keys:
continue
gold_token_id = int(sample_labels[position].item())
candidate_token_ids = build_candidate_ids(
logits_at_position=clean_logits[position],
gold_token_id=gold_token_id,
top_k=args.kd_top_k,
)
candidate_texts: List[str] = []
for candidate_id in candidate_token_ids:
modified_ids = input_ids[0].clone()
modified_ids[position] = candidate_id
candidate_texts.append(
decode_assistant_text(
tokenizer=model.tokenizer,
full_ids=modified_ids,
attention_mask=attention_mask[0],
loss_mask=loss_mask[0],
)
)
teacher_logprobs = compute_teacher_logprobs(
teacher=teacher,
processor=processor,
pil_image=pil_image,
prompt_text=prompt_text,
candidate_texts=candidate_texts,
teacher_batch_size=args.teacher_batch_size,
)
teacher_probs_tensor = F.softmax(teacher_logprobs / args.kd_temperature, dim=-1)
teacher_entropy = float(
-(teacher_probs_tensor * teacher_probs_tensor.clamp_min(1e-12).log()).sum().item()
)
teacher_probs = teacher_probs_tensor.cpu().tolist()
record = {
"sample_id": sample_id,
"position": position,
"candidate_token_ids": candidate_token_ids,
"teacher_probs": teacher_probs,
"gold_token_id": gold_token_id,
"temperature": args.kd_temperature,
"teacher_entropy": teacher_entropy,
"source_config": batch["source_config"][0],
"text_hash": stable_text_hash(sample_id, prompt_text, target_text),
}
writer.write(json.dumps(record) + "\n")
seen_keys.add(cache_key)
prepared += 1
teacher_entropies.append(teacher_entropy)
if args.dry_run_batches and prepared >= args.kd_positions_per_sample * args.dry_run_batches:
break
if prepared and prepared % 50 == 0:
print(f"Prepared {prepared} KD entries...")
print(f"Teacher bank written to {output_path} with {prepared} new entries")
if teacher_entropies:
entropy_array = np.array(teacher_entropies, dtype=np.float32)
print(
"Teacher entropy: "
f"mean={float(entropy_array.mean()):.4f}, "
f"min={float(entropy_array.min()):.4f}, "
f"max={float(entropy_array.max()):.4f}"
)
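# Each line of candidate_bank.jsonl is one KD entry; an illustrative record (values made up):
# {"sample_id": "ai2d:12:0", "position": 37, "candidate_token_ids": [15, 902, ...],
#  "teacher_probs": [0.62, 0.21, ...], "gold_token_id": 902, "temperature": 1.0,
#  "teacher_entropy": 1.34, "source_config": "ai2d", "text_hash": "<sha1 of id+prompt+target>"}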
def load_teacher_bank(cache_dir: str) -> Dict[str, List[Dict[str, object]]]:
bank_path = Path(cache_dir) / "candidate_bank.jsonl"
if not bank_path.exists():
raise FileNotFoundError(f"Teacher bank not found: {bank_path}")
bank_map: Dict[str, List[Dict[str, object]]] = defaultdict(list)
with open(bank_path, "r", encoding="utf-8") as handle:
for line in handle:
if not line.strip():
continue
record = json.loads(line)
bank_map[record["sample_id"]].append(record)
print(f"Loaded teacher bank for {len(bank_map)} samples from {bank_path}")
return bank_map
def maybe_push_to_hub(
args: argparse.Namespace,
save_dir: Path,
params: Dict[str, int],
best_loss: float,
) -> None:
if not args.push_to_hub:
print("Skipping Hub push (enable with --push_to_hub).")
return
print("\nPushing to Hub...")
api = HfApi()
repo_id = args.hub_model_id
try:
api.create_repo(repo_id, exist_ok=True, private=False)
except Exception as exc:
print(f"Repo note: {exc}")
config_dict = {
"architecture": "ViL-DLM",
"training_stage": args.stage,
"best_loss": best_loss,
"total_params_M": params["total"] / 1e6,
"trainable_params_M": params["trainable"] / 1e6,
"teacher": args.teacher_model_id,
"dataset_configs": parse_dataset_configs(args.dataset_configs) if args.stage in {"2", "3a", "3b"} else ["llava_pretrain"],
}
with open(save_dir / "model_config.json", "w", encoding="utf-8") as handle:
json.dump(config_dict, handle, indent=2)
api.upload_folder(
folder_path=str(save_dir),
repo_id=repo_id,
commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})",
)
print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}")
def run_training_stage(args: argparse.Namespace) -> None:
tracker = _TrackioShim()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_device_info(device)
ensure_runtime_requirements(args)
lm_path = download_student_backbone()
vil_config = ViLConfig()
proj_config = ProjConfig()
model = ViLDLM(vil_config, proj_config, lm_path)
setup_model_for_stage(model, args.stage)
maybe_resume_model(model, args)
params = model.count_params()
print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M")
print(f" ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M")
model = model.to(device)
if hasattr(model.lm, "gradient_checkpointing_enable"):
model.lm.gradient_checkpointing_enable()
dataset, skip_stats = create_stage_dataset("1" if args.stage == "1" else "2", model.tokenizer, args)
if skip_stats:
print(f"Skip stats: {json.dumps(skip_stats)}")
if args.stage == "3a":
prepare_teacher_bank(args=args, model=model, dataset=dataset)
return
teacher_bank = load_teacher_bank(args.teacher_cache_dir) if args.stage == "3b" else {}
dataloader = build_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=args.stage != "3a" and not (args.stage == "3b" and args.dry_run_batches),
num_workers=args.num_workers,
persistent_workers=args.persistent_workers,
)
optimizer = get_optimizer(model, stage="1" if args.stage == "1" else "2")
total_steps = max(1, (len(dataloader) * args.epochs) // max(args.grad_accum, 1))
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
tracker.init(name=f"vil-dlm-stage{args.stage}")
best_loss = float("inf")
global_step = 0
step_timer = time.time()
for epoch in range(args.epochs):
model.train()
epoch_loss = 0.0
epoch_kd_loss = 0.0
epoch_kd_entries = 0
epoch_effective_alpha = 0.0
epoch_kd_mask_prob = 0.0
epoch_kd_loss_variance = 0.0
num_batches = 0
optimizer.zero_grad(set_to_none=True)
for batch_idx, batch in enumerate(dataloader):
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
loss_mask = batch["loss_mask"].to(device)
force_mask = None
if args.stage == "3b":
force_mask = build_kd_force_mask(
sample_ids=batch["sample_id"],
bank_map=teacher_bank,
seq_len=input_ids.shape[1],
device=device,
)
outputs = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_mask=loss_mask,
force_mask=force_mask,
)
diffusion_loss = outputs["loss"]
kd_loss = torch.tensor(0.0, device=device)
kd_entries = 0
kd_loss_variance = torch.tensor(0.0, device=device)
mean_kd_mask_prob = torch.tensor(0.0, device=device)
effective_alpha_kd = torch.tensor(0.0, device=device)
total_loss = diffusion_loss
if args.stage == "3b":
kd_loss, kd_metrics = compute_sparse_kd_loss(
student_logits=outputs["logits"],
noise_mask=outputs["noise_mask"],
timesteps=outputs["t"],
sample_ids=batch["sample_id"],
bank_map=teacher_bank,
temperature=args.kd_temperature,
)
kd_entries = int(kd_metrics["entries"])
kd_loss_variance = kd_metrics["loss_variance"]
mean_kd_mask_prob = kd_metrics["mean_mask_prob"]
if kd_entries > 0:
if args.kd_timestep_weighting:
effective_alpha_kd = args.alpha_kd * mean_kd_mask_prob
else:
effective_alpha_kd = torch.tensor(args.alpha_kd, device=device)
total_loss = (1.0 - effective_alpha_kd) * diffusion_loss + effective_alpha_kd * kd_loss
loss = total_loss / args.grad_accum
loss.backward()
if (batch_idx + 1) % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=True)
global_step += 1
actual_loss = float(total_loss.item())
actual_diffusion = float(diffusion_loss.item())
actual_kd = float(kd_loss.item()) if args.stage == "3b" else 0.0
actual_kd_variance = float(kd_loss_variance.item()) if args.stage == "3b" else 0.0
actual_kd_mask_prob = float(mean_kd_mask_prob.item()) if args.stage == "3b" else 0.0
actual_effective_alpha = float(effective_alpha_kd.item()) if args.stage == "3b" else 0.0
elapsed = max(time.time() - step_timer, 1e-6)
samples_per_sec = (args.batch_size * args.grad_accum) / elapsed
step_timer = time.time()
gpu_mem_gb = 0.0
if torch.cuda.is_available():
gpu_mem_gb = torch.cuda.max_memory_allocated(device) / 1e9
print(
f"[E{epoch}] Step {global_step}/{total_steps} | "
f"Loss: {actual_loss:.4f} | Diff: {actual_diffusion:.4f} | "
f"KD: {actual_kd:.4f} | KD entries: {kd_entries} | "
f"KD var: {actual_kd_variance:.4f} | KD mask_p: {actual_kd_mask_prob:.4f} | "
f"alpha_kd: {actual_effective_alpha:.4f} | "
f"Samples/s: {samples_per_sec:.2f} | GPU mem: {gpu_mem_gb:.2f} GB"
)
tracker.log(
{
"train/loss": actual_loss,
"train/diffusion_loss": actual_diffusion,
"train/kd_loss": actual_kd,
"train/kd_entries": kd_entries,
"train/kd_loss_variance": actual_kd_variance,
"train/mean_kd_mask_prob": actual_kd_mask_prob,
"train/effective_alpha_kd": actual_effective_alpha,
"train/epoch": epoch,
"train/step": global_step,
"train/samples_per_sec": samples_per_sec,
"train/gpu_mem_gb": gpu_mem_gb,
}
)
epoch_loss += float(total_loss.item())
epoch_kd_loss += float(kd_loss.item())
epoch_kd_entries += kd_entries
epoch_effective_alpha += float(effective_alpha_kd.item())
epoch_kd_mask_prob += float(mean_kd_mask_prob.item())
epoch_kd_loss_variance += float(kd_loss_variance.item())
num_batches += 1
if args.dry_run_batches and num_batches >= args.dry_run_batches:
break
remainder = num_batches % args.grad_accum
if remainder != 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=True)
global_step += 1
avg_loss = epoch_loss / max(num_batches, 1)
avg_kd_loss = epoch_kd_loss / max(num_batches, 1)
avg_effective_alpha = epoch_effective_alpha / max(num_batches, 1)
avg_kd_mask_prob = epoch_kd_mask_prob / max(num_batches, 1)
avg_kd_loss_variance = epoch_kd_loss_variance / max(num_batches, 1)
print(
f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f} | Average KD: {avg_kd_loss:.4f} | "
f"KD entries: {epoch_kd_entries} | Avg alpha_kd: {avg_effective_alpha:.4f} | "
f"Avg KD mask_p: {avg_kd_mask_prob:.4f} | Avg KD var: {avg_kd_loss_variance:.4f}\n"
)
tracker.log(
{
"train/epoch_loss": avg_loss,
"train/epoch_kd_loss": avg_kd_loss,
"train/epoch_kd_entries": epoch_kd_entries,
"train/epoch_effective_alpha_kd": avg_effective_alpha,
"train/epoch_mean_kd_mask_prob": avg_kd_mask_prob,
"train/epoch_kd_loss_variance": avg_kd_loss_variance,
"train/epoch": epoch,
}
)
if avg_loss < best_loss:
best_loss = avg_loss
save_dir = Path(args.output_dir) / f"stage{args.stage}_best"
include_lm = args.stage in {"2", "3b"}
model.save_checkpoint(save_dir, include_lm=include_lm)
training_state = {
"stage": args.stage,
"best_loss": best_loss,
"args": vars(args),
}
with open(save_dir / "training_state.json", "w", encoding="utf-8") as handle:
json.dump(training_state, handle, indent=2)
print(f"Saved best checkpoint (loss={best_loss:.4f})")
maybe_push_to_hub(
args=args,
save_dir=Path(args.output_dir) / f"stage{args.stage}_best",
params=params,
best_loss=best_loss,
)
print("Training complete!")
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--stage", type=str, default="1", choices=["1", "2", "3a", "3b"])
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--grad_accum", type=int, default=8)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--max_samples", type=int, default=None)
parser.add_argument("--output_dir", type=str, default="./vil-dlm-output")
parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B")
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--require_cuda", action="store_true")
parser.add_argument("--resume_from", type=str, default=None)
parser.add_argument("--dataset_configs", type=str, default=",".join(DEFAULT_CAULDRON_CONFIGS))
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--persistent_workers", action="store_true")
parser.add_argument("--dry_run_batches", type=int, default=0)
parser.add_argument("--teacher_model_id", type=str, default="google/gemma-4-E2B-it")
parser.add_argument("--teacher_cache_dir", type=str, default="./vil-dlm-output/teacher-cache")
parser.add_argument("--prepare_teacher_bank", action="store_true")
parser.add_argument("--teacher_batch_size", type=int, default=1)
parser.add_argument("--alpha_kd", type=float, default=0.5)
parser.add_argument("--kd_temperature", type=float, default=1.0)
parser.add_argument("--kd_timestep_weighting", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--kd_top_k", type=int, default=16)
parser.add_argument("--kd_positions_per_sample", type=int, default=16)
return parser
if __name__ == "__main__":
ensure_hf_cache_root()
parser = build_parser()
args = parser.parse_args()
if args.prepare_teacher_bank and args.stage != "3a":
raise ValueError("--prepare_teacher_bank is only valid with --stage 3a")
if args.kd_temperature <= 0:
raise ValueError("--kd_temperature must be > 0")
if not 0.0 <= args.alpha_kd <= 1.0:
raise ValueError("--alpha_kd must be between 0 and 1")
run_training_stage(args)