| """ |
| ViL-DLM production training script. |
| |
| Stages: |
| 1 - projector-only alignment on LLaVA-Pretrain |
| 2 - full-model finetune on The Cauldron |
3a - offline teacher candidate-bank preparation with Gemma 3n E2B-it
| 3b - sparse cross-tokenizer distillation training using cached teacher targets |
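
Example invocations (a sketch; the filename train_vil_dlm.py is assumed, flags as
defined by build_parser below):

    python train_vil_dlm.py --stage 1 --epochs 2 --batch_size 4
    python train_vil_dlm.py --stage 2 --resume_from ./vil-dlm-output/stage1_best
    python train_vil_dlm.py --stage 3a --prepare_teacher_bank \
        --resume_from ./vil-dlm-output/stage2_best
    python train_vil_dlm.py --stage 3b --resume_from ./vil-dlm-output/stage2_best \
        --teacher_cache_dir ./vil-dlm-output/teacher-cache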
| """ |
|
|
| import argparse |
| import hashlib |
| import json |
| import math |
| import os |
| import time |
| import zipfile |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from io import BytesIO |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from datasets import Dataset as HFDataset |
| from datasets import concatenate_datasets, load_dataset |
| from datasets import Features, Image as HFImage |
| from datasets.features import Sequence as HFSequence |
| from huggingface_hub import HfApi, snapshot_download |
| from PIL import Image |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from torch.utils.data import DataLoader, Dataset |
| from transformers import ( |
| AutoModelForImageTextToText, |
| AutoModelForMaskedLM, |
| AutoProcessor, |
| AutoTokenizer, |
| ) |
|
|
| try: |
| import trackio |
| except Exception: |
| trackio = None |
|
|
| from vision_xlstm import ( |
| VisionProjector as UpstreamVisionProjector, |
| VisionXLSTM as UpstreamVisionXLSTM, |
| ) |
|
|
|
|
| DEFAULT_CAULDRON_CONFIGS = [ |
| "ai2d", |
| "vqav2", |
| "aokvqa", |
| "textvqa", |
| "docvqa", |
| "chartqa", |
| "textcaps", |
| "screen2words", |
| ] |
|
|
| DEFAULT_CAULDRON_GATE_CONFIGS = [ |
| "ai2d", |
| "aokvqa", |
| ] |
|
|
|
|
| @dataclass |
| class ViLConfig: |
| vision_backbone: str = "vil2-small" |
| pretrained: bool = True |
| img_size: int = 224 |
| patch_size: int = 16 |
| in_channels: int = 3 |
| dim: int = 384 |
| depth: int = 24 |
| conv_kernel_size: int = 3 |
| bidirectional: bool = True |
| dropout: float = 0.0 |
|
|
| @property |
| def num_patches(self) -> int: |
| return (self.img_size // self.patch_size) ** 2 |
|
|
|
|
| @dataclass |
| class ProjConfig: |
| vil_dim: int = 384 |
| lm_dim: int = 1024 |
| hidden_mult: int = 2 |
| num_layers: int = 2 |
| dropout: float = 0.0 |
|
|
|
|
| class _TrackioShim: |
| def __init__(self) -> None: |
| self.enabled = False |
|
|
| def init(self, name: str, project: str = "vil-dlm") -> None: |
| if trackio is None: |
| print("Trackio disabled: package not installed in the active environment") |
| self.enabled = False |
| return |
| try: |
| trackio.init(name=name, project=project) |
| self.enabled = True |
| except Exception as exc: |
| self.enabled = False |
| print(f"Trackio disabled: {exc}") |
|
|
| def log(self, payload: dict) -> None: |
| if not self.enabled: |
| return |
| try: |
| trackio.log(payload) |
| except Exception as exc: |
| self.enabled = False |
| print(f"Trackio logging disabled after error: {exc}") |
|
|
|
|
| class MDLMScheduler: |
| def __init__(self, mask_token_id: int) -> None: |
| self.mask_token_id = mask_token_id |
|
|
| def add_noise( |
| self, |
| input_ids: torch.Tensor, |
| t: torch.Tensor, |
| eligible_mask: Optional[torch.Tensor] = None, |
| force_mask: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| batch, length = input_ids.shape |
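        # Cosine masking schedule: P(mask) = 1 - cos(t * pi / 2) rises from 0 at t=0
        # to 1 at t=1, matching the MDLM absorbing-noise process.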
| mask_ratio = 1.0 - torch.cos(t * math.pi / 2) |
| mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length) |
| mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio |
| if eligible_mask is not None: |
| eligible_mask = eligible_mask.bool() |
| mask = mask & eligible_mask |
| if force_mask is not None: |
| mask = mask | (force_mask.bool() & eligible_mask) |
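            # Guarantee at least one masked position for every sample that has any
            # eligible tokens, so the loss never averages over an empty set.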
| missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0) |
| for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten(): |
| eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten() |
| chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)] |
| mask[batch_idx, chosen] = True |
| elif force_mask is not None: |
| mask = mask | force_mask.bool() |
| noisy_ids = input_ids.clone() |
| noisy_ids[mask] = self.mask_token_id |
| return noisy_ids, mask |
|
|
| def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: |
| return torch.rand(batch_size, device=device) |
|
|
|
|
| class ViLDLM(nn.Module): |
| def __init__(self, vil_config: ViLConfig, proj_config: ProjConfig, lm_path: str) -> None: |
| super().__init__() |
| self.vil_config = vil_config |
| self.vision_encoder = UpstreamVisionXLSTM(vil_config) |
| self.projector = UpstreamVisionProjector(proj_config) |
| self.lm = AutoModelForMaskedLM.from_pretrained( |
| lm_path, |
| trust_remote_code=True, |
| torch_dtype=torch.bfloat16, |
| ) |
| self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True) |
| if self.tokenizer.pad_token_id is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
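        # Assumption: this MDLM checkpoint absorbs tokens into the pad id. If the
        # dllm tokenizer exposes a dedicated tokenizer.mask_token_id, that would be
        # the safer mask token to pass here.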
| self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id) |
|
|
| @property |
| def num_patches(self) -> int: |
| return self.vil_config.num_patches |
|
|
| def prepare_multimodal_inputs( |
| self, |
| pixel_values: torch.Tensor, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
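        # Prepend projected ViL patch tokens to the text embeddings; the LM then
        # attends over the concatenated visual+text sequence under an extended mask.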
| vision_features = self.vision_encoder.forward_features(pixel_values) |
| visual_tokens = self.projector(vision_features) |
        text_embeds = self.lm.get_input_embeddings()(input_ids)
| visual_tokens = visual_tokens.to(dtype=text_embeds.dtype) |
| inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1) |
| vis_mask = torch.ones( |
| pixel_values.shape[0], |
| self.num_patches, |
| device=attention_mask.device, |
| dtype=attention_mask.dtype, |
| ) |
| full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1) |
| return inputs_embeds, full_attention_mask |
|
|
| def predict_clean_logits( |
| self, |
| pixel_values: torch.Tensor, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| ) -> torch.Tensor: |
| inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( |
| pixel_values=pixel_values, |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| ) |
| outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask) |
| return outputs.logits[:, self.num_patches :, :] |
|
|
| def forward( |
| self, |
| pixel_values: torch.Tensor, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| labels: Optional[torch.Tensor] = None, |
| loss_mask: Optional[torch.Tensor] = None, |
| force_mask: Optional[torch.Tensor] = None, |
| ) -> Dict[str, torch.Tensor]: |
| batch_size, seq_len = input_ids.shape |
| device = input_ids.device |
| if labels is None: |
| labels = input_ids.clone() |
| if loss_mask is None: |
| loss_mask = attention_mask |
|
|
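        # MDLM objective: corrupt answer tokens with the absorbing mask, run the full
        # multimodal sequence, and score cross-entropy only on masked positions.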
| t = self.scheduler.sample_timesteps(batch_size, device) |
| eligible_mask = (loss_mask > 0) & (attention_mask > 0) |
| noisy_ids, noise_mask = self.scheduler.add_noise( |
| input_ids, |
| t, |
| eligible_mask=eligible_mask, |
| force_mask=force_mask, |
| ) |
| inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( |
| pixel_values=pixel_values, |
| input_ids=noisy_ids, |
| attention_mask=attention_mask, |
| ) |
| outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask) |
| text_logits = outputs.logits[:, self.num_patches :, :] |
|
|
| active_mask = noise_mask.float() * eligible_mask.float() |
| if active_mask.sum() == 0: |
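            # Nothing was masked; return a zero loss that still participates in
            # autograd so the training loop stays uniform.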
| loss = torch.tensor(0.0, device=device, requires_grad=True) |
| else: |
| logits_flat = text_logits.reshape(-1, text_logits.shape[-1]) |
| labels_flat = labels.reshape(-1) |
| per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none").reshape(batch_size, seq_len) |
| loss = (per_token * active_mask).sum() / active_mask.sum() |
|
|
| return { |
| "loss": loss, |
| "logits": text_logits, |
| "noise_mask": noise_mask, |
| "t": t, |
| } |
|
|
| def freeze_vision(self) -> None: |
| for param in self.vision_encoder.parameters(): |
| param.requires_grad = False |
|
|
| def freeze_lm(self) -> None: |
| for param in self.lm.parameters(): |
| param.requires_grad = False |
|
|
| def unfreeze_all(self) -> None: |
| for param in self.parameters(): |
| param.requires_grad = True |
|
|
| def count_params(self) -> Dict[str, int]: |
| vil = sum(p.numel() for p in self.vision_encoder.parameters()) |
| proj = sum(p.numel() for p in self.projector.parameters()) |
| lm = sum(p.numel() for p in self.lm.parameters()) |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| return {"vil": vil, "proj": proj, "lm": lm, "total": vil + proj + lm, "trainable": trainable} |
|
|
| def save_checkpoint(self, save_dir: Path, include_lm: bool) -> None: |
| save_dir.mkdir(parents=True, exist_ok=True) |
| torch.save(self.vision_encoder.state_dict(), save_dir / "vision_encoder.pt") |
| torch.save(self.projector.state_dict(), save_dir / "projector.pt") |
| if include_lm: |
| self.lm.save_pretrained(save_dir / "diffusion_lm") |
| self.tokenizer.save_pretrained(save_dir / "diffusion_lm") |
|
|
| def load_checkpoint(self, checkpoint_dir: Path, include_lm: bool) -> None: |
| vision_path = checkpoint_dir / "vision_encoder.pt" |
| projector_path = checkpoint_dir / "projector.pt" |
| if vision_path.exists(): |
| self.vision_encoder.load_state_dict(torch.load(vision_path, map_location="cpu")) |
| if projector_path.exists(): |
| self.projector.load_state_dict(torch.load(projector_path, map_location="cpu")) |
| if include_lm: |
| diffusion_dir = checkpoint_dir / "diffusion_lm" |
| if diffusion_dir.exists(): |
| self.lm = AutoModelForMaskedLM.from_pretrained( |
| diffusion_dir, |
| trust_remote_code=True, |
| torch_dtype=torch.bfloat16, |
| ) |
| self.tokenizer = AutoTokenizer.from_pretrained(diffusion_dir, trust_remote_code=True) |
| if self.tokenizer.pad_token_id is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id) |
|
|
|
|
| def ensure_hf_cache_root() -> None: |
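    # Respect an existing HF_HOME; otherwise default to the studio's persistent disk.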
| os.environ.setdefault("HF_HOME", "/teamspace/studios/this_studio/.cache/huggingface") |
|
|
|
|
| def patch_diffusion_modeling_file(lm_path: str) -> None: |
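    # Two in-place patches so the snapshot loads standalone: neutralize a
    # __main__-only `import dllm`, and let the attention-mask lookup fall back to
    # full attention for layers without an `attention_type` attribute.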
| modeling_file = os.path.join(lm_path, "modeling_qwen3.py") |
| with open(modeling_file, "r", encoding="utf-8") as handle: |
| content = handle.read() |
| content = content.replace( |
| 'if __name__ == "__main__":\n import dllm', |
| 'if __name__ == "__main__":\n pass\n # import dllm', |
| ) |
| content = content.replace( |
| "attention_mask=causal_mask_mapping[decoder_layer.attention_type]", |
| 'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))', |
| ) |
| with open(modeling_file, "w", encoding="utf-8") as handle: |
| handle.write(content) |
|
|
|
|
| def download_student_backbone() -> str: |
| print("Downloading dLLM Qwen3-0.6B diffusion model...") |
| lm_path = snapshot_download("dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1") |
| patch_diffusion_modeling_file(lm_path) |
| print(f"Model downloaded to {lm_path}") |
| return lm_path |
|
|
|
|
| def parse_dataset_configs(dataset_configs: Optional[str]) -> List[str]: |
| if dataset_configs: |
| return [item.strip() for item in dataset_configs.split(",") if item.strip()] |
| return list(DEFAULT_CAULDRON_CONFIGS) |
|
|
|
|
| def resolve_cauldron_configs(args: argparse.Namespace) -> List[str]: |
| configs = parse_dataset_configs(args.dataset_configs) |
| default_config_string = ",".join(DEFAULT_CAULDRON_CONFIGS) |
| if args.dry_run_batches and args.dataset_configs == default_config_string: |
| print( |
| "Dry-run mode detected; using the cheap Stage 2 gate config set " |
| f"{DEFAULT_CAULDRON_GATE_CONFIGS} instead of the full production mix." |
| ) |
| return list(DEFAULT_CAULDRON_GATE_CONFIGS) |
| return configs |
|
|
|
|
| def stable_text_hash(*parts: str) -> str: |
| joined = "\n".join(parts) |
| return hashlib.sha1(joined.encode("utf-8")).hexdigest() |
|
|
|
|
| def build_prompt_prefix(prompt_text: str) -> str: |
| return f"User: {prompt_text.strip()}\nAssistant:" |
|
|
|
|
| def tokenize_prompt_and_target( |
| tokenizer: AutoTokenizer, |
| prompt_text: str, |
| target_text: str, |
| max_length: int, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
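    # Layout: [prompt prefix][answer tokens][padding]. loss_mask marks the answer
    # span, the only positions eligible for diffusion masking and loss.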
| prefix_text = build_prompt_prefix(prompt_text) |
| prefix_ids = tokenizer(prefix_text, add_special_tokens=True)["input_ids"] |
| target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"] |
| if not target_ids: |
        target_ids = tokenizer(" N/A", add_special_tokens=False)["input_ids"][:1]
|
|
| max_prefix_len = max_length - 1 |
| if len(prefix_ids) > max_prefix_len: |
| prefix_ids = prefix_ids[:max_prefix_len] |
|
|
| remaining = max_length - len(prefix_ids) |
| if remaining <= 0: |
| prefix_ids = prefix_ids[: max_length - 1] |
| remaining = 1 |
| target_ids = target_ids[:remaining] |
| if not target_ids: |
| prefix_ids = prefix_ids[: max_length - 1] |
| target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"][:1] |
|
|
| input_ids = prefix_ids + target_ids |
| loss_mask = [0] * len(prefix_ids) + [1] * len(target_ids) |
| attention_mask = [1] * len(input_ids) |
| labels = list(input_ids) |
|
|
| pad_token_id = tokenizer.pad_token_id |
| if pad_token_id is None: |
| pad_token_id = tokenizer.eos_token_id |
|
|
| pad_len = max_length - len(input_ids) |
| if pad_len > 0: |
| input_ids = input_ids + [pad_token_id] * pad_len |
| attention_mask = attention_mask + [0] * pad_len |
| labels = labels + [pad_token_id] * pad_len |
| loss_mask = loss_mask + [0] * pad_len |
|
|
| return ( |
| torch.tensor(input_ids, dtype=torch.long), |
| torch.tensor(attention_mask, dtype=torch.long), |
| torch.tensor(labels, dtype=torch.long), |
| torch.tensor(loss_mask, dtype=torch.float32), |
| ) |
|
|
|
|
| def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]: |
| if isinstance(img, str): |
| img = Image.open(img).convert("RGB") |
| elif isinstance(img, dict) and "zip_path" in img and "member" in img: |
| with zipfile.ZipFile(img["zip_path"], "r") as archive: |
| with archive.open(img["member"], "r") as member_file: |
| img = Image.open(member_file).convert("RGB") |
| elif isinstance(img, dict) and img.get("bytes") is not None: |
| img = Image.open(BytesIO(img["bytes"])).convert("RGB") |
| elif isinstance(img, dict) and img.get("path") and os.path.exists(img["path"]): |
| img = Image.open(img["path"]).convert("RGB") |
| elif isinstance(img, Image.Image): |
| img = img.convert("RGB") |
| else: |
| raise ValueError(f"Unsupported image payload type: {type(img)!r}") |
|
|
| pil_image = img |
| resized = pil_image.resize((img_size, img_size), Image.BICUBIC) |
| arr = np.array(resized).astype(np.float32) / 255.0 |
| tensor = torch.from_numpy(arr).permute(2, 0, 1) |
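    # ImageNet channel statistics, assumed to match the ViL backbone's pretraining.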
| mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) |
| std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) |
| tensor = (tensor - mean) / std |
| return tensor, pil_image |
|
|
|
|
| def is_usable_image_payload(img: object) -> bool: |
| if isinstance(img, Image.Image): |
| return True |
| if isinstance(img, str): |
| return os.path.exists(img) |
| if isinstance(img, dict): |
| if img.get("zip_path") and img.get("member"): |
| return os.path.exists(img["zip_path"]) |
| if img.get("bytes") is not None: |
| return True |
| if img.get("path"): |
| return os.path.exists(img["path"]) |
| return False |
|
|
|
|
| class NormalizedVisionLanguageDataset(Dataset): |
| def __init__( |
| self, |
| records: HFDataset, |
| tokenizer: AutoTokenizer, |
| max_length: int, |
| img_size: int, |
| ) -> None: |
| self.records = records |
| self.tokenizer = tokenizer |
| self.max_length = max_length |
| self.img_size = img_size |
|
|
| def __len__(self) -> int: |
| return len(self.records) |
|
|
| def __getitem__(self, idx: int) -> Dict[str, object]: |
| sample = self.records[int(idx)] |
| pixel_values, pil_image = preprocess_image_for_student(sample["image"], self.img_size) |
| input_ids, attention_mask, labels, loss_mask = tokenize_prompt_and_target( |
| tokenizer=self.tokenizer, |
| prompt_text=sample["prompt_text"], |
| target_text=sample["target_text"], |
| max_length=self.max_length, |
| ) |
| return { |
| "pixel_values": pixel_values, |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "labels": labels, |
| "loss_mask": loss_mask, |
| "sample_id": sample["sample_id"], |
| "prompt_text": sample["prompt_text"], |
| "target_text": sample["target_text"], |
| "source_config": sample.get("source_config", "unknown"), |
| "pil_image": pil_image, |
| } |
|
|
|
|
| def build_llava_records(max_samples: Optional[int]) -> HFDataset: |
| print("Loading LLaVA-Pretrain dataset...") |
| dataset_root = None |
| images_zip_path = None |
| zip_members = None |
| try: |
| data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train") |
| except Exception as exc: |
| print(f"Primary dataset loader failed ({exc}). Falling back to direct JSON loading...") |
| dataset_root = snapshot_download( |
| "liuhaotian/LLaVA-Pretrain", |
| repo_type="dataset", |
| allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"], |
| ) |
| json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json") |
| images_zip_path = os.path.join(dataset_root, "images.zip") |
| if os.path.exists(images_zip_path): |
| with zipfile.ZipFile(images_zip_path, "r") as archive: |
| zip_members = set(archive.namelist()) |
| data = load_dataset("json", data_files={"train": json_path}, split="train") |
| if max_samples: |
| data = data.select(range(min(max_samples, len(data)))) |
|
|
| stats = defaultdict(int) |
|
|
| def normalize(sample: Dict[str, object], idx: int) -> Optional[Dict[str, object]]: |
| text = "" |
| if "conversations" in sample: |
| parts = [] |
| for turn in sample["conversations"]: |
| val = turn.get("value", "").replace("<image>\n", "").replace("<image>", "").strip() |
| if val: |
| parts.append(val) |
| text = " ".join(parts) |
| elif sample.get("blip_caption"): |
| text = sample["blip_caption"].strip() |
| if not text: |
| text = "Describe this image." |
|
|
| image_obj = sample.get("image") |
| if image_obj is None: |
| stats["missing_image_ref"] += 1 |
| return None |
| if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj): |
| candidate_paths = [ |
| image_obj, |
| os.path.join(dataset_root, image_obj), |
| os.path.join(dataset_root, "images", image_obj), |
| ] |
| resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None) |
| if resolved_path: |
| image_obj = resolved_path |
| elif images_zip_path and os.path.exists(images_zip_path) and zip_members: |
| member_name = None |
| if image_obj in zip_members: |
| member_name = image_obj |
| elif f"images/{image_obj}" in zip_members: |
| member_name = f"images/{image_obj}" |
| if member_name is None: |
| stats["missing_backing_image"] += 1 |
| return None |
| image_obj = { |
| "zip_path": images_zip_path, |
| "member": member_name, |
| } |
| else: |
| stats["missing_backing_image"] += 1 |
| return None |
|
|
| stats["kept"] += 1 |
| return { |
| "image": image_obj, |
| "prompt_text": "Describe this image.", |
| "target_text": text, |
| "sample_id": f"llava-pretrain:{sample.get('id', idx)}", |
| "source_config": "llava_pretrain", |
| } |
|
|
| records = [record for i in range(len(data)) if (record := normalize(data[i], i)) is not None] |
| normalized = HFDataset.from_list(records) |
| print( |
| f"Loaded {len(normalized)} LLaVA samples " |
| f"(kept={stats['kept']}, missing_image_ref={stats['missing_image_ref']}, " |
| f"missing_backing_image={stats['missing_backing_image']})" |
| ) |
| return normalized |
|
|
|
|
| def disable_image_decoding(feature: object) -> object: |
| if isinstance(feature, HFImage): |
| return HFImage(decode=False) |
| if isinstance(feature, HFSequence): |
| return HFSequence(feature=disable_image_decoding(feature.feature), length=feature.length) |
| if isinstance(feature, Features): |
| return Features({key: disable_image_decoding(value) for key, value in feature.items()}) |
| if isinstance(feature, dict): |
| return {key: disable_image_decoding(value) for key, value in feature.items()} |
| if isinstance(feature, list): |
| return [disable_image_decoding(value) for value in feature] |
| return feature |
|
|
|
|
| def build_cauldron_records( |
| configs: Sequence[str], |
| max_samples: Optional[int], |
| raw_row_limit: Optional[int] = None, |
| ) -> Tuple[HFDataset, Dict[str, Dict[str, int]]]: |
| normalized_configs: List[HFDataset] = [] |
| skip_stats: Dict[str, Dict[str, int]] = {} |
| per_config_limit = None |
| if max_samples: |
| per_config_limit = max(1, max_samples // max(len(configs), 1)) |
|
|
| for config_name in configs: |
| print(f"Loading The Cauldron config: {config_name}") |
| ds = load_dataset("HuggingFaceM4/the_cauldron", config_name, split="train") |
| if raw_row_limit is not None: |
| ds = ds.select(range(min(raw_row_limit, len(ds)))) |
| if "images" in ds.features: |
| ds = ds.cast_column("images", disable_image_decoding(ds.features["images"])) |
| if "image" in ds.features: |
| ds = ds.cast_column("image", disable_image_decoding(ds.features["image"])) |
| stats = defaultdict(int) |
|
|
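        # Each Cauldron row carries one image plus a list of QA turns; flatten every
        # turn into its own (image, prompt, target) sample.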
| def explode(batch: Dict[str, List[object]], indices: List[int]) -> Dict[str, List[object]]: |
| output = { |
| "image": [], |
| "prompt_text": [], |
| "target_text": [], |
| "sample_id": [], |
| "source_config": [], |
| } |
| batch_images = batch.get("images") |
| batch_single_image = batch.get("image") |
| batch_texts = batch.get("texts") or batch.get("conversations") |
| for local_idx, row_idx in enumerate(indices): |
| if batch_images is not None: |
| images = batch_images[local_idx] |
| elif batch_single_image is not None: |
| images = batch_single_image[local_idx] |
| else: |
| stats["missing_image_column"] += 1 |
| continue |
|
|
| if batch_texts is None: |
| stats["missing_text_column"] += 1 |
| continue |
| texts = batch_texts[local_idx] |
|
|
| if images is None: |
| images = [] |
| elif not isinstance(images, list): |
| images = [images] |
| if texts is None: |
| texts = [] |
| elif isinstance(texts, dict): |
| texts = [texts] |
|
|
| if not images or len(images) != 1: |
| stats["multi_or_missing_image"] += 1 |
| continue |
| image_payload = images[0] |
| if not is_usable_image_payload(image_payload): |
| stats["unusable_image_ref"] += 1 |
| continue |
| if not texts: |
| stats["missing_turns"] += 1 |
| continue |
| for turn_idx, turn in enumerate(texts): |
| if not isinstance(turn, dict): |
| stats["unsupported_turn_type"] += 1 |
| continue |
| user_text = ( |
| turn.get("user") |
| or turn.get("question") |
| or turn.get("prompt") |
| or turn.get("input") |
| or "" |
| ).strip() |
| assistant_text = ( |
| turn.get("assistant") |
| or turn.get("answer") |
| or turn.get("response") |
| or turn.get("output") |
| or "" |
| ).strip() |
| if not user_text or not assistant_text: |
| stats["missing_user_or_assistant"] += 1 |
| continue |
| output["image"].append(image_payload) |
| output["prompt_text"].append(user_text) |
| output["target_text"].append(assistant_text) |
| output["sample_id"].append(f"{config_name}:{row_idx}:{turn_idx}") |
| output["source_config"].append(config_name) |
| stats["kept"] += 1 |
| return output |
|
|
| exploded = ds.map( |
| explode, |
| batched=True, |
| with_indices=True, |
| remove_columns=ds.column_names, |
| desc=f"Normalizing {config_name}", |
| ) |
| if per_config_limit is not None: |
| exploded = exploded.select(range(min(per_config_limit, len(exploded)))) |
| normalized_configs.append(exploded) |
| stats["kept"] = len(exploded) |
| skip_stats[config_name] = dict(stats) |
| print(f"{config_name}: kept={stats['kept']} skipped={sum(v for k, v in stats.items() if k != 'kept')}") |
|
|
| if not normalized_configs: |
| raise RuntimeError("No valid The Cauldron configs were loaded.") |
|
|
| combined = concatenate_datasets(normalized_configs) |
| if max_samples: |
| combined = combined.select(range(min(max_samples, len(combined)))) |
| print(f"Loaded {len(combined)} normalized The Cauldron samples") |
| return combined, skip_stats |
|
|
|
|
| def collate_vision_language(batch: List[Dict[str, object]]) -> Dict[str, object]: |
| return { |
| "pixel_values": torch.stack([sample["pixel_values"] for sample in batch]), |
| "input_ids": torch.stack([sample["input_ids"] for sample in batch]), |
| "attention_mask": torch.stack([sample["attention_mask"] for sample in batch]), |
| "labels": torch.stack([sample["labels"] for sample in batch]), |
| "loss_mask": torch.stack([sample["loss_mask"] for sample in batch]), |
| "sample_id": [sample["sample_id"] for sample in batch], |
| "prompt_text": [sample["prompt_text"] for sample in batch], |
| "target_text": [sample["target_text"] for sample in batch], |
| "source_config": [sample["source_config"] for sample in batch], |
| "pil_image": [sample["pil_image"] for sample in batch], |
| } |
|
|
|
|
| def create_stage_dataset(stage: str, tokenizer: AutoTokenizer, args: argparse.Namespace) -> Tuple[NormalizedVisionLanguageDataset, Dict[str, Dict[str, int]]]: |
| if stage == "1": |
| return NormalizedVisionLanguageDataset( |
| records=build_llava_records(args.max_samples), |
| tokenizer=tokenizer, |
| max_length=args.max_length, |
| img_size=224, |
| ), {} |
|
|
| configs = resolve_cauldron_configs(args) |
| raw_row_limit = None |
| if args.dry_run_batches: |
| raw_row_limit = max(32, args.batch_size * args.dry_run_batches * 8) |
| records, skip_stats = build_cauldron_records(configs, args.max_samples, raw_row_limit=raw_row_limit) |
| return NormalizedVisionLanguageDataset( |
| records=records, |
| tokenizer=tokenizer, |
| max_length=args.max_length, |
| img_size=224, |
| ), skip_stats |
|
|
|
|
| def build_dataloader( |
| dataset: Dataset, |
| batch_size: int, |
| shuffle: bool, |
| num_workers: int, |
| persistent_workers: bool, |
| ) -> DataLoader: |
| return DataLoader( |
| dataset, |
| batch_size=batch_size, |
| shuffle=shuffle, |
| num_workers=num_workers, |
| pin_memory=torch.cuda.is_available(), |
| persistent_workers=persistent_workers and num_workers > 0, |
| drop_last=False, |
| collate_fn=collate_vision_language, |
| ) |
|
|
|
|
| def print_device_info(device: torch.device) -> None: |
| print(f"Device: {device}") |
| if torch.cuda.is_available(): |
| print(f"GPU: {torch.cuda.get_device_name(0)}") |
| print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") |
| print(f"torch.version.cuda: {torch.version.cuda}") |
|
|
|
|
| def ensure_runtime_requirements(args: argparse.Namespace) -> None: |
| if args.require_cuda and not torch.cuda.is_available(): |
| raise RuntimeError("CUDA is required for this run but torch.cuda.is_available() is False.") |
| if args.stage in {"2", "3a", "3b"} and not parse_dataset_configs(args.dataset_configs): |
| raise RuntimeError("Stage 2/3 requires at least one The Cauldron config.") |
| if args.stage in {"3a", "3b"} and not args.teacher_cache_dir: |
| raise RuntimeError("Stage 3 requires --teacher_cache_dir.") |
| if args.stage in {"3a", "3b"} and not args.resume_from: |
| raise RuntimeError("Stage 3 requires --resume_from pointing to a Stage 2 checkpoint.") |
| if args.stage == "3a": |
| try: |
| import bitsandbytes |
| except ImportError as exc: |
| raise RuntimeError("Stage 3a requires bitsandbytes in the active environment.") from exc |
|
|
|
|
| def maybe_resume_model(model: ViLDLM, args: argparse.Namespace) -> None: |
| if not args.resume_from: |
| return |
| checkpoint_dir = Path(args.resume_from) |
| if not checkpoint_dir.exists(): |
| raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") |
| include_lm = args.stage in {"2", "3a", "3b"} |
| print(f"Resuming from checkpoint: {checkpoint_dir}") |
| model.load_checkpoint(checkpoint_dir, include_lm=include_lm) |
|
|
|
|
| def get_optimizer(model: ViLDLM, stage: str) -> AdamW: |
| if stage == "1": |
| groups = [ |
| { |
| "params": [p for p in model.projector.parameters() if p.requires_grad], |
| "lr": 1e-3, |
| } |
| ] |
| else: |
| groups = [ |
| { |
| "params": [p for p in model.vision_encoder.parameters() if p.requires_grad], |
| "lr": 2e-6, |
| }, |
| { |
| "params": [p for p in model.projector.parameters() if p.requires_grad], |
| "lr": 1e-5, |
| }, |
| { |
| "params": [p for p in model.lm.parameters() if p.requires_grad], |
| "lr": 1e-5, |
| }, |
| ] |
| groups = [group for group in groups if group["params"]] |
| return AdamW(groups, weight_decay=0.05, betas=(0.9, 0.999)) |
|
|
|
|
| def setup_model_for_stage(model: ViLDLM, stage: str) -> None: |
| if stage == "1": |
| print("\n=== STAGE 1: Projector-only alignment ===") |
| model.freeze_vision() |
| model.freeze_lm() |
| elif stage in {"2", "3b"}: |
| label = "Full finetune" if stage == "2" else "Sparse KD finetune" |
| print(f"\n=== STAGE {stage.upper()}: {label} ===") |
| model.unfreeze_all() |
| elif stage == "3a": |
| print("\n=== STAGE 3A: Teacher candidate-bank preparation ===") |
| model.unfreeze_all() |
| for param in model.parameters(): |
| param.requires_grad = False |
| else: |
| raise ValueError(f"Unsupported stage: {stage}") |
|
|
|
|
| def compute_sparse_kd_loss( |
| student_logits: torch.Tensor, |
| noise_mask: torch.Tensor, |
| timesteps: torch.Tensor, |
| sample_ids: Sequence[str], |
| bank_map: Dict[str, List[Dict[str, object]]], |
| temperature: float, |
| ) -> Tuple[torch.Tensor, Dict[str, object]]: |
| entries_used = 0 |
| losses: List[torch.Tensor] = [] |
| mask_probs: List[torch.Tensor] = [] |
| mask_probability = 1.0 - torch.cos(timesteps * math.pi / 2) |
| for batch_idx, sample_id in enumerate(sample_ids): |
| sample_entries = bank_map.get(sample_id, []) |
| for entry in sample_entries: |
| position = int(entry["position"]) |
| if position >= student_logits.shape[1]: |
| continue |
| if not bool(noise_mask[batch_idx, position].item()): |
| continue |
| candidate_ids = torch.tensor( |
| entry["candidate_token_ids"], |
| device=student_logits.device, |
| dtype=torch.long, |
| ) |
| teacher_probs = torch.tensor( |
| entry["teacher_probs"], |
| device=student_logits.device, |
| dtype=student_logits.dtype, |
| ) |
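            # Restrict the student to the cached candidate set and match the teacher's
            # preferences over it; temperature**2 is the standard KD gradient rescale.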
| gathered = student_logits[batch_idx, position, candidate_ids] |
| student_log_probs = F.log_softmax(gathered / temperature, dim=-1) |
| loss = F.kl_div( |
| student_log_probs.unsqueeze(0), |
| teacher_probs.unsqueeze(0), |
| reduction="batchmean", |
| ) * (temperature ** 2) |
| losses.append(loss) |
| mask_probs.append(mask_probability[batch_idx]) |
| entries_used += 1 |
|
|
| if not losses: |
| zero = torch.tensor(0.0, device=student_logits.device) |
| return zero, { |
| "entries": 0, |
| "loss_variance": zero, |
| "mean_mask_prob": zero, |
| } |
|
|
| loss_tensor = torch.stack(losses) |
| mask_prob_tensor = torch.stack(mask_probs) |
| return loss_tensor.mean(), { |
| "entries": entries_used, |
| "loss_variance": loss_tensor.var(unbiased=False), |
| "mean_mask_prob": mask_prob_tensor.mean(), |
| } |
|
|
|
|
| def build_kd_force_mask( |
| sample_ids: Sequence[str], |
| bank_map: Dict[str, List[Dict[str, object]]], |
| seq_len: int, |
| device: torch.device, |
| ) -> torch.Tensor: |
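    # Force-mask every cached KD position so the student must predict, not copy,
    # the tokens the teacher scored.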
| force_mask = torch.zeros((len(sample_ids), seq_len), device=device, dtype=torch.bool) |
| for batch_idx, sample_id in enumerate(sample_ids): |
| for entry in bank_map.get(sample_id, []): |
| position = int(entry["position"]) |
| if 0 <= position < seq_len: |
| force_mask[batch_idx, position] = True |
| return force_mask |
|
|
|
|
| def compute_teacher_logprobs( |
| teacher: AutoModelForImageTextToText, |
| processor: AutoProcessor, |
| pil_image: Image.Image, |
| prompt_text: str, |
| candidate_texts: Sequence[str], |
| teacher_batch_size: int, |
| ) -> torch.Tensor: |
| prompt_messages = [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": pil_image}, |
| {"type": "text", "text": prompt_text}, |
| ], |
| } |
| ] |
| prompt_inputs = processor.apply_chat_template( |
| prompt_messages, |
| tokenize=True, |
| return_dict=True, |
| return_tensors="pt", |
| add_generation_prompt=True, |
| ) |
| prompt_len = prompt_inputs["input_ids"].shape[1] |
|
|
| teacher_device = next(teacher.parameters()).device |
| all_logprobs = [] |
| for start in range(0, len(candidate_texts), max(teacher_batch_size, 1)): |
| batch_candidates = candidate_texts[start : start + max(teacher_batch_size, 1)] |
| conversations = [] |
| for candidate_text in batch_candidates: |
| conversations.append( |
| [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": pil_image}, |
| {"type": "text", "text": prompt_text}, |
| ], |
| }, |
| { |
| "role": "assistant", |
| "content": [{"type": "text", "text": candidate_text}], |
| }, |
| ] |
| ) |
|
|
| batch_inputs = processor.apply_chat_template( |
| conversations, |
| tokenize=True, |
| return_dict=True, |
| return_tensors="pt", |
| padding=True, |
| add_generation_prompt=False, |
| ) |
| batch_inputs = {key: value.to(teacher_device) for key, value in batch_inputs.items()} |
| outputs = teacher(**batch_inputs) |
| logits = outputs.logits[:, :-1, :] |
| labels = batch_inputs["input_ids"][:, 1:].clone() |
| attention_mask = batch_inputs["attention_mask"] |
|
|
| seq_len = batch_inputs["input_ids"].shape[1] |
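        # With left padding, real tokens occupy the last valid_len positions, and
        # shifted label i predicts input_ids[i + 1]; everything before
        # left_pad + prompt_len - 1 belongs to the prompt and is ignored.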
| for batch_idx in range(labels.shape[0]): |
| valid_len = int(attention_mask[batch_idx].sum().item()) |
| left_pad = seq_len - valid_len |
| prefix_cut = left_pad + prompt_len - 1 |
| if prefix_cut > 0: |
| labels[batch_idx, :prefix_cut] = -100 |
| labels[batch_idx, attention_mask[batch_idx, 1:] == 0] = -100 |
|
|
| per_token = F.cross_entropy( |
| logits.reshape(-1, logits.shape[-1]), |
| labels.reshape(-1), |
| ignore_index=-100, |
| reduction="none", |
| ).reshape(labels.shape) |
| token_mask = (labels != -100).float() |
| all_logprobs.append(-(per_token * token_mask).sum(dim=-1).cpu()) |
|
|
| return torch.cat(all_logprobs, dim=0) |
|
|
|
|
| def choose_distillation_positions( |
| clean_logits: torch.Tensor, |
| labels: torch.Tensor, |
| loss_mask: torch.Tensor, |
| max_positions: int, |
| ) -> List[int]: |
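    # Rank answer positions by the student's confidence in the gold token and keep
    # the hardest ones, spending the teacher budget where the student struggles.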
| valid_positions = torch.nonzero(loss_mask > 0, as_tuple=False).flatten() |
| if valid_positions.numel() == 0: |
| return [] |
| probs = F.softmax(clean_logits[valid_positions], dim=-1) |
| gold = labels[valid_positions].unsqueeze(-1) |
| gold_probs = probs.gather(-1, gold).squeeze(-1) |
| _, ranked = torch.sort(gold_probs, descending=False) |
| selected = valid_positions[ranked][:max_positions] |
| return [int(pos.item()) for pos in selected] |
|
|
|
|
| def build_candidate_ids( |
| logits_at_position: torch.Tensor, |
| gold_token_id: int, |
| top_k: int, |
| ) -> List[int]: |
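    # Student's top-(k-1) tokens plus the gold token, deduplicated in rank order.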
| candidate_ids = logits_at_position.topk(max(top_k - 1, 1)).indices.tolist() |
| if gold_token_id not in candidate_ids: |
| candidate_ids.append(gold_token_id) |
| deduped = [] |
| seen = set() |
| for token_id in candidate_ids: |
| if token_id in seen: |
| continue |
| deduped.append(token_id) |
| seen.add(token_id) |
| return deduped[:top_k] |
|
|
|
|
| def decode_assistant_text( |
| tokenizer: AutoTokenizer, |
| full_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| loss_mask: torch.Tensor, |
| ) -> str: |
| active = (attention_mask > 0) & (loss_mask > 0) |
| assistant_ids = full_ids[active].tolist() |
| return tokenizer.decode(assistant_ids, skip_special_tokens=True).strip() |
|
|
|
|
| def prepare_teacher_bank( |
| args: argparse.Namespace, |
| model: ViLDLM, |
| dataset: NormalizedVisionLanguageDataset, |
| ) -> None: |
| if args.dry_run_batches: |
| max_items = min(args.teacher_batch_size * args.dry_run_batches, len(dataset)) |
| elif args.max_samples: |
| max_items = min(args.max_samples, len(dataset)) |
| else: |
| max_items = len(dataset) |
|
|
| try: |
| from transformers import BitsAndBytesConfig |
| except ImportError as exc: |
| raise RuntimeError("bitsandbytes/transformers quantization support is required for Stage 3a.") from exc |
|
|
| print(f"Loading teacher: {args.teacher_model_id}") |
| quantization_config = BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_compute_dtype=torch.bfloat16, |
| bnb_4bit_quant_type="nf4", |
| ) |
| teacher = AutoModelForImageTextToText.from_pretrained( |
| args.teacher_model_id, |
| quantization_config=quantization_config, |
| device_map="auto", |
| attn_implementation="sdpa", |
| ) |
| teacher.eval() |
| processor = AutoProcessor.from_pretrained(args.teacher_model_id, padding_side="left") |
|
|
| cache_dir = Path(args.teacher_cache_dir) |
| cache_dir.mkdir(parents=True, exist_ok=True) |
| output_path = cache_dir / "candidate_bank.jsonl" |
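    # Append-only JSONL bank: re-runs skip (sample_id, position) pairs already
    # cached, so Stage 3a can resume after an interruption.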
| seen_keys = set() |
| if output_path.exists(): |
| with open(output_path, "r", encoding="utf-8") as handle: |
| for line in handle: |
| if not line.strip(): |
| continue |
| record = json.loads(line) |
| seen_keys.add((record["sample_id"], int(record["position"]))) |
|
|
| dataloader = build_dataloader( |
| dataset=dataset, |
| batch_size=1, |
| shuffle=False, |
| num_workers=0, |
| persistent_workers=False, |
| ) |
|
|
    prepared = 0
    processed_samples = 0
    teacher_entropies: List[float] = []
    with torch.no_grad(), open(output_path, "a", encoding="utf-8") as writer:
        for batch in dataloader:
            # Honor the max_items budget derived from dry-run/max-samples settings.
            if processed_samples >= max_items:
                break
            processed_samples += 1
            sample_id = batch["sample_id"][0]
| prompt_text = batch["prompt_text"][0] |
| target_text = batch["target_text"][0] |
| pil_image = batch["pil_image"][0] |
| pixel_values = batch["pixel_values"].to(next(model.parameters()).device) |
| input_ids = batch["input_ids"].to(pixel_values.device) |
| attention_mask = batch["attention_mask"].to(pixel_values.device) |
| labels = batch["labels"].to(pixel_values.device) |
| loss_mask = batch["loss_mask"].to(pixel_values.device) |
|
|
| clean_logits = model.predict_clean_logits(pixel_values, input_ids, attention_mask)[0] |
| sample_labels = labels[0] |
| sample_loss_mask = loss_mask[0] |
| positions = choose_distillation_positions( |
| clean_logits=clean_logits, |
| labels=sample_labels, |
| loss_mask=sample_loss_mask, |
| max_positions=args.kd_positions_per_sample, |
| ) |
|
|
| for position in positions: |
| cache_key = (sample_id, position) |
| if cache_key in seen_keys: |
| continue |
| gold_token_id = int(sample_labels[position].item()) |
| candidate_token_ids = build_candidate_ids( |
| logits_at_position=clean_logits[position], |
| gold_token_id=gold_token_id, |
| top_k=args.kd_top_k, |
| ) |
| candidate_texts: List[str] = [] |
| for candidate_id in candidate_token_ids: |
| modified_ids = input_ids[0].clone() |
| modified_ids[position] = candidate_id |
| candidate_texts.append( |
| decode_assistant_text( |
| tokenizer=model.tokenizer, |
| full_ids=modified_ids, |
| attention_mask=attention_mask[0], |
| loss_mask=loss_mask[0], |
| ) |
| ) |
| teacher_logprobs = compute_teacher_logprobs( |
| teacher=teacher, |
| processor=processor, |
| pil_image=pil_image, |
| prompt_text=prompt_text, |
| candidate_texts=candidate_texts, |
| teacher_batch_size=args.teacher_batch_size, |
| ) |
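                # Teacher preference over candidates: softmax of full-sequence
                # log-likelihoods, tempered by the KD temperature.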
| teacher_probs_tensor = F.softmax(teacher_logprobs / args.kd_temperature, dim=-1) |
| teacher_entropy = float( |
| -(teacher_probs_tensor * teacher_probs_tensor.clamp_min(1e-12).log()).sum().item() |
| ) |
| teacher_probs = teacher_probs_tensor.cpu().tolist() |
| record = { |
| "sample_id": sample_id, |
| "position": position, |
| "candidate_token_ids": candidate_token_ids, |
| "teacher_probs": teacher_probs, |
| "gold_token_id": gold_token_id, |
| "temperature": args.kd_temperature, |
| "teacher_entropy": teacher_entropy, |
| "source_config": batch["source_config"][0], |
| "text_hash": stable_text_hash(sample_id, prompt_text, target_text), |
| } |
| writer.write(json.dumps(record) + "\n") |
| seen_keys.add(cache_key) |
| prepared += 1 |
| teacher_entropies.append(teacher_entropy) |
| if args.dry_run_batches and prepared >= args.kd_positions_per_sample * args.dry_run_batches: |
| break |
| if prepared and prepared % 50 == 0: |
| print(f"Prepared {prepared} KD entries...") |
|
|
| print(f"Teacher bank written to {output_path} with {prepared} new entries") |
| if teacher_entropies: |
| entropy_array = np.array(teacher_entropies, dtype=np.float32) |
| print( |
| "Teacher entropy: " |
| f"mean={float(entropy_array.mean()):.4f}, " |
| f"min={float(entropy_array.min()):.4f}, " |
| f"max={float(entropy_array.max()):.4f}" |
| ) |
|
|
|
|
| def load_teacher_bank(cache_dir: str) -> Dict[str, List[Dict[str, object]]]: |
| bank_path = Path(cache_dir) / "candidate_bank.jsonl" |
| if not bank_path.exists(): |
| raise FileNotFoundError(f"Teacher bank not found: {bank_path}") |
| bank_map: Dict[str, List[Dict[str, object]]] = defaultdict(list) |
| with open(bank_path, "r", encoding="utf-8") as handle: |
| for line in handle: |
| if not line.strip(): |
| continue |
| record = json.loads(line) |
| bank_map[record["sample_id"]].append(record) |
| print(f"Loaded teacher bank for {len(bank_map)} samples from {bank_path}") |
| return bank_map |
|
|
|
|
| def maybe_push_to_hub( |
| args: argparse.Namespace, |
| save_dir: Path, |
| params: Dict[str, int], |
| best_loss: float, |
| ) -> None: |
| if not args.push_to_hub: |
| print("Skipping Hub push (enable with --push_to_hub).") |
| return |
|
|
| print("\nPushing to Hub...") |
| api = HfApi() |
| repo_id = args.hub_model_id |
| try: |
| api.create_repo(repo_id, exist_ok=True, private=False) |
| except Exception as exc: |
| print(f"Repo note: {exc}") |
|
|
| config_dict = { |
| "architecture": "ViL-DLM", |
| "training_stage": args.stage, |
| "best_loss": best_loss, |
| "total_params_M": params["total"] / 1e6, |
| "trainable_params_M": params["trainable"] / 1e6, |
| "teacher": args.teacher_model_id, |
| "dataset_configs": parse_dataset_configs(args.dataset_configs) if args.stage in {"2", "3a", "3b"} else ["llava_pretrain"], |
| } |
| with open(save_dir / "model_config.json", "w", encoding="utf-8") as handle: |
| json.dump(config_dict, handle, indent=2) |
|
|
| api.upload_folder( |
| folder_path=str(save_dir), |
| repo_id=repo_id, |
| commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})", |
| ) |
| print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}") |
|
|
|
|
| def run_training_stage(args: argparse.Namespace) -> None: |
| tracker = _TrackioShim() |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print_device_info(device) |
| ensure_runtime_requirements(args) |
| lm_path = download_student_backbone() |
|
|
| vil_config = ViLConfig() |
| proj_config = ProjConfig() |
| model = ViLDLM(vil_config, proj_config, lm_path) |
| setup_model_for_stage(model, args.stage) |
| maybe_resume_model(model, args) |
|
|
| params = model.count_params() |
| print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M") |
| print(f" ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M") |
|
|
| model = model.to(device) |
| if hasattr(model.lm, "gradient_checkpointing_enable"): |
| model.lm.gradient_checkpointing_enable() |
|
|
| dataset, skip_stats = create_stage_dataset("1" if args.stage == "1" else "2", model.tokenizer, args) |
| if skip_stats: |
| print(f"Skip stats: {json.dumps(skip_stats)}") |
|
|
| if args.stage == "3a": |
| prepare_teacher_bank(args=args, model=model, dataset=dataset) |
| return |
|
|
| teacher_bank = load_teacher_bank(args.teacher_cache_dir) if args.stage == "3b" else {} |
| dataloader = build_dataloader( |
| dataset=dataset, |
| batch_size=args.batch_size, |
| shuffle=args.stage != "3a" and not (args.stage == "3b" and args.dry_run_batches), |
| num_workers=args.num_workers, |
| persistent_workers=args.persistent_workers, |
| ) |
|
|
| optimizer = get_optimizer(model, stage="1" if args.stage == "1" else "2") |
| total_steps = max(1, (len(dataloader) * args.epochs) // max(args.grad_accum, 1)) |
| scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6) |
| tracker.init(name=f"vil-dlm-stage{args.stage}") |
|
|
| best_loss = float("inf") |
| global_step = 0 |
| step_timer = time.time() |
|
|
| for epoch in range(args.epochs): |
| model.train() |
| epoch_loss = 0.0 |
| epoch_kd_loss = 0.0 |
| epoch_kd_entries = 0 |
| epoch_effective_alpha = 0.0 |
| epoch_kd_mask_prob = 0.0 |
| epoch_kd_loss_variance = 0.0 |
| num_batches = 0 |
|
|
| optimizer.zero_grad(set_to_none=True) |
| for batch_idx, batch in enumerate(dataloader): |
| pixel_values = batch["pixel_values"].to(device) |
| input_ids = batch["input_ids"].to(device) |
| attention_mask = batch["attention_mask"].to(device) |
| labels = batch["labels"].to(device) |
| loss_mask = batch["loss_mask"].to(device) |
| force_mask = None |
| if args.stage == "3b": |
| force_mask = build_kd_force_mask( |
| sample_ids=batch["sample_id"], |
| bank_map=teacher_bank, |
| seq_len=input_ids.shape[1], |
| device=device, |
| ) |
|
|
| outputs = model( |
| pixel_values=pixel_values, |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| labels=labels, |
| loss_mask=loss_mask, |
| force_mask=force_mask, |
| ) |
| diffusion_loss = outputs["loss"] |
| kd_loss = torch.tensor(0.0, device=device) |
| kd_entries = 0 |
| kd_loss_variance = torch.tensor(0.0, device=device) |
| mean_kd_mask_prob = torch.tensor(0.0, device=device) |
| effective_alpha_kd = torch.tensor(0.0, device=device) |
| total_loss = diffusion_loss |
| if args.stage == "3b": |
| kd_loss, kd_metrics = compute_sparse_kd_loss( |
| student_logits=outputs["logits"], |
| noise_mask=outputs["noise_mask"], |
| timesteps=outputs["t"], |
| sample_ids=batch["sample_id"], |
| bank_map=teacher_bank, |
| temperature=args.kd_temperature, |
| ) |
| kd_entries = int(kd_metrics["entries"]) |
| kd_loss_variance = kd_metrics["loss_variance"] |
| mean_kd_mask_prob = kd_metrics["mean_mask_prob"] |
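                # Optionally scale the KD weight by the batch's mean mask probability
                # so heavily-masked (high-noise) steps lean more on the teacher.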
| if kd_entries > 0: |
| if args.kd_timestep_weighting: |
| effective_alpha_kd = args.alpha_kd * mean_kd_mask_prob |
| else: |
| effective_alpha_kd = torch.tensor(args.alpha_kd, device=device) |
| total_loss = (1.0 - effective_alpha_kd) * diffusion_loss + effective_alpha_kd * kd_loss |
|
|
| loss = total_loss / args.grad_accum |
| loss.backward() |
|
|
| if (batch_idx + 1) % args.grad_accum == 0: |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad(set_to_none=True) |
| global_step += 1 |
|
|
| actual_loss = float(total_loss.item()) |
| actual_diffusion = float(diffusion_loss.item()) |
| actual_kd = float(kd_loss.item()) if args.stage == "3b" else 0.0 |
| actual_kd_variance = float(kd_loss_variance.item()) if args.stage == "3b" else 0.0 |
| actual_kd_mask_prob = float(mean_kd_mask_prob.item()) if args.stage == "3b" else 0.0 |
| actual_effective_alpha = float(effective_alpha_kd.item()) if args.stage == "3b" else 0.0 |
| elapsed = max(time.time() - step_timer, 1e-6) |
| samples_per_sec = (args.batch_size * args.grad_accum) / elapsed |
| step_timer = time.time() |
| gpu_mem_gb = 0.0 |
| if torch.cuda.is_available(): |
| gpu_mem_gb = torch.cuda.max_memory_allocated(device) / 1e9 |
|
|
| print( |
| f"[E{epoch}] Step {global_step}/{total_steps} | " |
| f"Loss: {actual_loss:.4f} | Diff: {actual_diffusion:.4f} | " |
| f"KD: {actual_kd:.4f} | KD entries: {kd_entries} | " |
| f"KD var: {actual_kd_variance:.4f} | KD mask_p: {actual_kd_mask_prob:.4f} | " |
| f"alpha_kd: {actual_effective_alpha:.4f} | " |
| f"Samples/s: {samples_per_sec:.2f} | GPU mem: {gpu_mem_gb:.2f} GB" |
| ) |
| tracker.log( |
| { |
| "train/loss": actual_loss, |
| "train/diffusion_loss": actual_diffusion, |
| "train/kd_loss": actual_kd, |
| "train/kd_entries": kd_entries, |
| "train/kd_loss_variance": actual_kd_variance, |
| "train/mean_kd_mask_prob": actual_kd_mask_prob, |
| "train/effective_alpha_kd": actual_effective_alpha, |
| "train/epoch": epoch, |
| "train/step": global_step, |
| "train/samples_per_sec": samples_per_sec, |
| "train/gpu_mem_gb": gpu_mem_gb, |
| } |
| ) |
|
|
| epoch_loss += float(total_loss.item()) |
| epoch_kd_loss += float(kd_loss.item()) |
| epoch_kd_entries += kd_entries |
| epoch_effective_alpha += float(effective_alpha_kd.item()) |
| epoch_kd_mask_prob += float(mean_kd_mask_prob.item()) |
| epoch_kd_loss_variance += float(kd_loss_variance.item()) |
| num_batches += 1 |
|
|
| if args.dry_run_batches and num_batches >= args.dry_run_batches: |
| break |
|
|
| remainder = num_batches % args.grad_accum |
| if remainder != 0: |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad(set_to_none=True) |
| global_step += 1 |
|
|
| avg_loss = epoch_loss / max(num_batches, 1) |
| avg_kd_loss = epoch_kd_loss / max(num_batches, 1) |
| avg_effective_alpha = epoch_effective_alpha / max(num_batches, 1) |
| avg_kd_mask_prob = epoch_kd_mask_prob / max(num_batches, 1) |
| avg_kd_loss_variance = epoch_kd_loss_variance / max(num_batches, 1) |
| print( |
| f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f} | Average KD: {avg_kd_loss:.4f} | " |
| f"KD entries: {epoch_kd_entries} | Avg alpha_kd: {avg_effective_alpha:.4f} | " |
| f"Avg KD mask_p: {avg_kd_mask_prob:.4f} | Avg KD var: {avg_kd_loss_variance:.4f}\n" |
| ) |
| tracker.log( |
| { |
| "train/epoch_loss": avg_loss, |
| "train/epoch_kd_loss": avg_kd_loss, |
| "train/epoch_kd_entries": epoch_kd_entries, |
| "train/epoch_effective_alpha_kd": avg_effective_alpha, |
| "train/epoch_mean_kd_mask_prob": avg_kd_mask_prob, |
| "train/epoch_kd_loss_variance": avg_kd_loss_variance, |
| "train/epoch": epoch, |
| } |
| ) |
|
|
| if avg_loss < best_loss: |
| best_loss = avg_loss |
| save_dir = Path(args.output_dir) / f"stage{args.stage}_best" |
| include_lm = args.stage in {"2", "3b"} |
| model.save_checkpoint(save_dir, include_lm=include_lm) |
| training_state = { |
| "stage": args.stage, |
| "best_loss": best_loss, |
| "args": vars(args), |
| } |
| with open(save_dir / "training_state.json", "w", encoding="utf-8") as handle: |
| json.dump(training_state, handle, indent=2) |
| print(f"Saved best checkpoint (loss={best_loss:.4f})") |
|
|
| maybe_push_to_hub( |
| args=args, |
| save_dir=Path(args.output_dir) / f"stage{args.stage}_best", |
| params=params, |
| best_loss=best_loss, |
| ) |
| print("Training complete!") |
|
|
|
|
| def build_parser() -> argparse.ArgumentParser: |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--stage", type=str, default="1", choices=["1", "2", "3a", "3b"]) |
| parser.add_argument("--epochs", type=int, default=2) |
| parser.add_argument("--batch_size", type=int, default=4) |
| parser.add_argument("--grad_accum", type=int, default=8) |
| parser.add_argument("--max_length", type=int, default=512) |
| parser.add_argument("--max_samples", type=int, default=None) |
| parser.add_argument("--output_dir", type=str, default="./vil-dlm-output") |
| parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B") |
| parser.add_argument("--push_to_hub", action="store_true") |
| parser.add_argument("--require_cuda", action="store_true") |
| parser.add_argument("--resume_from", type=str, default=None) |
| parser.add_argument("--dataset_configs", type=str, default=",".join(DEFAULT_CAULDRON_CONFIGS)) |
| parser.add_argument("--num_workers", type=int, default=4) |
| parser.add_argument("--persistent_workers", action="store_true") |
| parser.add_argument("--dry_run_batches", type=int, default=0) |
| parser.add_argument("--teacher_model_id", type=str, default="google/gemma-4-E2B-it") |
| parser.add_argument("--teacher_cache_dir", type=str, default="./vil-dlm-output/teacher-cache") |
| parser.add_argument("--prepare_teacher_bank", action="store_true") |
| parser.add_argument("--teacher_batch_size", type=int, default=1) |
| parser.add_argument("--alpha_kd", type=float, default=0.5) |
| parser.add_argument("--kd_temperature", type=float, default=1.0) |
| parser.add_argument("--kd_timestep_weighting", action=argparse.BooleanOptionalAction, default=True) |
| parser.add_argument("--kd_top_k", type=int, default=16) |
| parser.add_argument("--kd_positions_per_sample", type=int, default=16) |
| return parser |
|
|
|
|
| if __name__ == "__main__": |
| ensure_hf_cache_root() |
| parser = build_parser() |
| args = parser.parse_args() |
| if args.prepare_teacher_bank and args.stage != "3a": |
| raise ValueError("--prepare_teacher_bank is only valid with --stage 3a") |
| if args.kd_temperature <= 0: |
| raise ValueError("--kd_temperature must be > 0") |
| if not 0.0 <= args.alpha_kd <= 1.0: |
| raise ValueError("--alpha_kd must be between 0 and 1") |
| run_training_stage(args) |
|
|