""" ViL-DLM production training script. Stages: 1 - projector-only alignment on LLaVA-Pretrain 2 - full-model finetune on The Cauldron 3a - offline teacher candidate-bank preparation with Gemma 4 E2B-it 3b - sparse cross-tokenizer distillation training using cached teacher targets """ import argparse import hashlib import json import math import os import time import zipfile from collections import defaultdict from dataclasses import dataclass from io import BytesIO from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset from datasets import Features, Image as HFImage from datasets.features import Sequence as HFSequence from huggingface_hub import HfApi, snapshot_download from PIL import Image from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader, Dataset from transformers import ( AutoModelForImageTextToText, AutoModelForMaskedLM, AutoProcessor, AutoTokenizer, ) try: import trackio except Exception: trackio = None from vision_xlstm import ( VisionProjector as UpstreamVisionProjector, VisionXLSTM as UpstreamVisionXLSTM, ) DEFAULT_CAULDRON_CONFIGS = [ "ai2d", "vqav2", "aokvqa", "textvqa", "docvqa", "chartqa", "textcaps", "screen2words", ] DEFAULT_CAULDRON_GATE_CONFIGS = [ "ai2d", "aokvqa", ] @dataclass class ViLConfig: vision_backbone: str = "vil2-small" pretrained: bool = True img_size: int = 224 patch_size: int = 16 in_channels: int = 3 dim: int = 384 depth: int = 24 conv_kernel_size: int = 3 bidirectional: bool = True dropout: float = 0.0 @property def num_patches(self) -> int: return (self.img_size // self.patch_size) ** 2 @dataclass class ProjConfig: vil_dim: int = 384 lm_dim: int = 1024 hidden_mult: int = 2 num_layers: int = 2 dropout: float = 0.0 class _TrackioShim: def __init__(self) -> None: self.enabled = False def init(self, name: str, project: str = "vil-dlm") -> None: if trackio is None: print("Trackio disabled: package not installed in the active environment") self.enabled = False return try: trackio.init(name=name, project=project) self.enabled = True except Exception as exc: self.enabled = False print(f"Trackio disabled: {exc}") def log(self, payload: dict) -> None: if not self.enabled: return try: trackio.log(payload) except Exception as exc: self.enabled = False print(f"Trackio logging disabled after error: {exc}") class MDLMScheduler: def __init__(self, mask_token_id: int) -> None: self.mask_token_id = mask_token_id def add_noise( self, input_ids: torch.Tensor, t: torch.Tensor, eligible_mask: Optional[torch.Tensor] = None, force_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: batch, length = input_ids.shape mask_ratio = 1.0 - torch.cos(t * math.pi / 2) mask_ratio = mask_ratio.unsqueeze(1).expand(batch, length) mask = torch.rand(batch, length, device=input_ids.device) < mask_ratio if eligible_mask is not None: eligible_mask = eligible_mask.bool() mask = mask & eligible_mask if force_mask is not None: mask = mask | (force_mask.bool() & eligible_mask) missing_mask = (mask.sum(dim=1) == 0) & (eligible_mask.sum(dim=1) > 0) for batch_idx in torch.nonzero(missing_mask, as_tuple=False).flatten(): eligible_positions = torch.nonzero(eligible_mask[batch_idx], as_tuple=False).flatten() chosen = eligible_positions[torch.randint(eligible_positions.numel(), (1,), device=input_ids.device)] mask[batch_idx, chosen] = True elif force_mask is not None: mask = mask | force_mask.bool() noisy_ids = input_ids.clone() noisy_ids[mask] = self.mask_token_id return noisy_ids, mask def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: return torch.rand(batch_size, device=device) class ViLDLM(nn.Module): def __init__(self, vil_config: ViLConfig, proj_config: ProjConfig, lm_path: str) -> None: super().__init__() self.vil_config = vil_config self.vision_encoder = UpstreamVisionXLSTM(vil_config) self.projector = UpstreamVisionProjector(proj_config) self.lm = AutoModelForMaskedLM.from_pretrained( lm_path, trust_remote_code=True, torch_dtype=torch.bfloat16, ) self.tokenizer = AutoTokenizer.from_pretrained(lm_path, trust_remote_code=True) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id) @property def num_patches(self) -> int: return self.vil_config.num_patches def prepare_multimodal_inputs( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: vision_features = self.vision_encoder.forward_features(pixel_values) visual_tokens = self.projector(vision_features) text_embeds = self.lm.model.embed_tokens(input_ids) visual_tokens = visual_tokens.to(dtype=text_embeds.dtype) inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1) vis_mask = torch.ones( pixel_values.shape[0], self.num_patches, device=attention_mask.device, dtype=attention_mask.dtype, ) full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1) return inputs_embeds, full_attention_mask def predict_clean_logits( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, ) outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask) return outputs.logits[:, self.num_patches :, :] def forward( self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None, loss_mask: Optional[torch.Tensor] = None, force_mask: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: batch_size, seq_len = input_ids.shape device = input_ids.device if labels is None: labels = input_ids.clone() if loss_mask is None: loss_mask = attention_mask t = self.scheduler.sample_timesteps(batch_size, device) eligible_mask = (loss_mask > 0) & (attention_mask > 0) noisy_ids, noise_mask = self.scheduler.add_noise( input_ids, t, eligible_mask=eligible_mask, force_mask=force_mask, ) inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( pixel_values=pixel_values, input_ids=noisy_ids, attention_mask=attention_mask, ) outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask) text_logits = outputs.logits[:, self.num_patches :, :] active_mask = noise_mask.float() * eligible_mask.float() if active_mask.sum() == 0: loss = torch.tensor(0.0, device=device, requires_grad=True) else: logits_flat = text_logits.reshape(-1, text_logits.shape[-1]) labels_flat = labels.reshape(-1) per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none").reshape(batch_size, seq_len) loss = (per_token * active_mask).sum() / active_mask.sum() return { "loss": loss, "logits": text_logits, "noise_mask": noise_mask, "t": t, } def freeze_vision(self) -> None: for param in self.vision_encoder.parameters(): param.requires_grad = False def freeze_lm(self) -> None: for param in self.lm.parameters(): param.requires_grad = False def unfreeze_all(self) -> None: for param in self.parameters(): param.requires_grad = True def count_params(self) -> Dict[str, int]: vil = sum(p.numel() for p in self.vision_encoder.parameters()) proj = sum(p.numel() for p in self.projector.parameters()) lm = sum(p.numel() for p in self.lm.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return {"vil": vil, "proj": proj, "lm": lm, "total": vil + proj + lm, "trainable": trainable} def save_checkpoint(self, save_dir: Path, include_lm: bool) -> None: save_dir.mkdir(parents=True, exist_ok=True) torch.save(self.vision_encoder.state_dict(), save_dir / "vision_encoder.pt") torch.save(self.projector.state_dict(), save_dir / "projector.pt") if include_lm: self.lm.save_pretrained(save_dir / "diffusion_lm") self.tokenizer.save_pretrained(save_dir / "diffusion_lm") def load_checkpoint(self, checkpoint_dir: Path, include_lm: bool) -> None: vision_path = checkpoint_dir / "vision_encoder.pt" projector_path = checkpoint_dir / "projector.pt" if vision_path.exists(): self.vision_encoder.load_state_dict(torch.load(vision_path, map_location="cpu")) if projector_path.exists(): self.projector.load_state_dict(torch.load(projector_path, map_location="cpu")) if include_lm: diffusion_dir = checkpoint_dir / "diffusion_lm" if diffusion_dir.exists(): self.lm = AutoModelForMaskedLM.from_pretrained( diffusion_dir, trust_remote_code=True, torch_dtype=torch.bfloat16, ) self.tokenizer = AutoTokenizer.from_pretrained(diffusion_dir, trust_remote_code=True) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.scheduler = MDLMScheduler(mask_token_id=self.tokenizer.pad_token_id) def ensure_hf_cache_root() -> None: os.environ.setdefault("HF_HOME", "/teamspace/studios/this_studio/.cache/huggingface") def patch_diffusion_modeling_file(lm_path: str) -> None: modeling_file = os.path.join(lm_path, "modeling_qwen3.py") with open(modeling_file, "r", encoding="utf-8") as handle: content = handle.read() content = content.replace( 'if __name__ == "__main__":\n import dllm', 'if __name__ == "__main__":\n pass\n # import dllm', ) content = content.replace( "attention_mask=causal_mask_mapping[decoder_layer.attention_type]", 'attention_mask=causal_mask_mapping.get(getattr(decoder_layer, "attention_type", "full_attention"), causal_mask_mapping.get("full_attention"))', ) with open(modeling_file, "w", encoding="utf-8") as handle: handle.write(content) def download_student_backbone() -> str: print("Downloading dLLM Qwen3-0.6B diffusion model...") lm_path = snapshot_download("dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1") patch_diffusion_modeling_file(lm_path) print(f"Model downloaded to {lm_path}") return lm_path def parse_dataset_configs(dataset_configs: Optional[str]) -> List[str]: if dataset_configs: return [item.strip() for item in dataset_configs.split(",") if item.strip()] return list(DEFAULT_CAULDRON_CONFIGS) def resolve_cauldron_configs(args: argparse.Namespace) -> List[str]: configs = parse_dataset_configs(args.dataset_configs) default_config_string = ",".join(DEFAULT_CAULDRON_CONFIGS) if args.dry_run_batches and args.dataset_configs == default_config_string: print( "Dry-run mode detected; using the cheap Stage 2 gate config set " f"{DEFAULT_CAULDRON_GATE_CONFIGS} instead of the full production mix." ) return list(DEFAULT_CAULDRON_GATE_CONFIGS) return configs def stable_text_hash(*parts: str) -> str: joined = "\n".join(parts) return hashlib.sha1(joined.encode("utf-8")).hexdigest() def build_prompt_prefix(prompt_text: str) -> str: return f"User: {prompt_text.strip()}\nAssistant:" def tokenize_prompt_and_target( tokenizer: AutoTokenizer, prompt_text: str, target_text: str, max_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: prefix_text = build_prompt_prefix(prompt_text) prefix_ids = tokenizer(prefix_text, add_special_tokens=True)["input_ids"] target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"] if not target_ids: target_ids = tokenizer(" " + "N/A", add_special_tokens=False)["input_ids"][:1] max_prefix_len = max_length - 1 if len(prefix_ids) > max_prefix_len: prefix_ids = prefix_ids[:max_prefix_len] remaining = max_length - len(prefix_ids) if remaining <= 0: prefix_ids = prefix_ids[: max_length - 1] remaining = 1 target_ids = target_ids[:remaining] if not target_ids: prefix_ids = prefix_ids[: max_length - 1] target_ids = tokenizer(" " + target_text.strip(), add_special_tokens=False)["input_ids"][:1] input_ids = prefix_ids + target_ids loss_mask = [0] * len(prefix_ids) + [1] * len(target_ids) attention_mask = [1] * len(input_ids) labels = list(input_ids) pad_token_id = tokenizer.pad_token_id if pad_token_id is None: pad_token_id = tokenizer.eos_token_id pad_len = max_length - len(input_ids) if pad_len > 0: input_ids = input_ids + [pad_token_id] * pad_len attention_mask = attention_mask + [0] * pad_len labels = labels + [pad_token_id] * pad_len loss_mask = loss_mask + [0] * pad_len return ( torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.long), torch.tensor(labels, dtype=torch.long), torch.tensor(loss_mask, dtype=torch.float32), ) def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]: if isinstance(img, str): img = Image.open(img).convert("RGB") elif isinstance(img, dict) and "zip_path" in img and "member" in img: with zipfile.ZipFile(img["zip_path"], "r") as archive: with archive.open(img["member"], "r") as member_file: img = Image.open(member_file).convert("RGB") elif isinstance(img, dict) and img.get("bytes") is not None: img = Image.open(BytesIO(img["bytes"])).convert("RGB") elif isinstance(img, dict) and img.get("path") and os.path.exists(img["path"]): img = Image.open(img["path"]).convert("RGB") elif isinstance(img, Image.Image): img = img.convert("RGB") else: raise ValueError(f"Unsupported image payload type: {type(img)!r}") pil_image = img resized = pil_image.resize((img_size, img_size), Image.BICUBIC) arr = np.array(resized).astype(np.float32) / 255.0 tensor = torch.from_numpy(arr).permute(2, 0, 1) mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) tensor = (tensor - mean) / std return tensor, pil_image def is_usable_image_payload(img: object) -> bool: if isinstance(img, Image.Image): return True if isinstance(img, str): return os.path.exists(img) if isinstance(img, dict): if img.get("zip_path") and img.get("member"): return os.path.exists(img["zip_path"]) if img.get("bytes") is not None: return True if img.get("path"): return os.path.exists(img["path"]) return False class NormalizedVisionLanguageDataset(Dataset): def __init__( self, records: HFDataset, tokenizer: AutoTokenizer, max_length: int, img_size: int, ) -> None: self.records = records self.tokenizer = tokenizer self.max_length = max_length self.img_size = img_size def __len__(self) -> int: return len(self.records) def __getitem__(self, idx: int) -> Dict[str, object]: sample = self.records[int(idx)] pixel_values, pil_image = preprocess_image_for_student(sample["image"], self.img_size) input_ids, attention_mask, labels, loss_mask = tokenize_prompt_and_target( tokenizer=self.tokenizer, prompt_text=sample["prompt_text"], target_text=sample["target_text"], max_length=self.max_length, ) return { "pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "loss_mask": loss_mask, "sample_id": sample["sample_id"], "prompt_text": sample["prompt_text"], "target_text": sample["target_text"], "source_config": sample.get("source_config", "unknown"), "pil_image": pil_image, } def build_llava_records(max_samples: Optional[int]) -> HFDataset: print("Loading LLaVA-Pretrain dataset...") dataset_root = None images_zip_path = None zip_members = None try: data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train") except Exception as exc: print(f"Primary dataset loader failed ({exc}). Falling back to direct JSON loading...") dataset_root = snapshot_download( "liuhaotian/LLaVA-Pretrain", repo_type="dataset", allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"], ) json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json") images_zip_path = os.path.join(dataset_root, "images.zip") if os.path.exists(images_zip_path): with zipfile.ZipFile(images_zip_path, "r") as archive: zip_members = set(archive.namelist()) data = load_dataset("json", data_files={"train": json_path}, split="train") if max_samples: data = data.select(range(min(max_samples, len(data)))) stats = defaultdict(int) def normalize(sample: Dict[str, object], idx: int) -> Optional[Dict[str, object]]: text = "" if "conversations" in sample: parts = [] for turn in sample["conversations"]: val = turn.get("value", "").replace("\n", "").replace("", "").strip() if val: parts.append(val) text = " ".join(parts) elif sample.get("blip_caption"): text = sample["blip_caption"].strip() if not text: text = "Describe this image." image_obj = sample.get("image") if image_obj is None: stats["missing_image_ref"] += 1 return None if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj): candidate_paths = [ image_obj, os.path.join(dataset_root, image_obj), os.path.join(dataset_root, "images", image_obj), ] resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None) if resolved_path: image_obj = resolved_path elif images_zip_path and os.path.exists(images_zip_path) and zip_members: member_name = None if image_obj in zip_members: member_name = image_obj elif f"images/{image_obj}" in zip_members: member_name = f"images/{image_obj}" if member_name is None: stats["missing_backing_image"] += 1 return None image_obj = { "zip_path": images_zip_path, "member": member_name, } else: stats["missing_backing_image"] += 1 return None stats["kept"] += 1 return { "image": image_obj, "prompt_text": "Describe this image.", "target_text": text, "sample_id": f"llava-pretrain:{sample.get('id', idx)}", "source_config": "llava_pretrain", } records = [record for i in range(len(data)) if (record := normalize(data[i], i)) is not None] normalized = HFDataset.from_list(records) print( f"Loaded {len(normalized)} LLaVA samples " f"(kept={stats['kept']}, missing_image_ref={stats['missing_image_ref']}, " f"missing_backing_image={stats['missing_backing_image']})" ) return normalized def disable_image_decoding(feature: object) -> object: if isinstance(feature, HFImage): return HFImage(decode=False) if isinstance(feature, HFSequence): return HFSequence(feature=disable_image_decoding(feature.feature), length=feature.length) if isinstance(feature, Features): return Features({key: disable_image_decoding(value) for key, value in feature.items()}) if isinstance(feature, dict): return {key: disable_image_decoding(value) for key, value in feature.items()} if isinstance(feature, list): return [disable_image_decoding(value) for value in feature] return feature def build_cauldron_records( configs: Sequence[str], max_samples: Optional[int], raw_row_limit: Optional[int] = None, ) -> Tuple[HFDataset, Dict[str, Dict[str, int]]]: normalized_configs: List[HFDataset] = [] skip_stats: Dict[str, Dict[str, int]] = {} per_config_limit = None if max_samples: per_config_limit = max(1, max_samples // max(len(configs), 1)) for config_name in configs: print(f"Loading The Cauldron config: {config_name}") ds = load_dataset("HuggingFaceM4/the_cauldron", config_name, split="train") if raw_row_limit is not None: ds = ds.select(range(min(raw_row_limit, len(ds)))) if "images" in ds.features: ds = ds.cast_column("images", disable_image_decoding(ds.features["images"])) if "image" in ds.features: ds = ds.cast_column("image", disable_image_decoding(ds.features["image"])) stats = defaultdict(int) def explode(batch: Dict[str, List[object]], indices: List[int]) -> Dict[str, List[object]]: output = { "image": [], "prompt_text": [], "target_text": [], "sample_id": [], "source_config": [], } batch_images = batch.get("images") batch_single_image = batch.get("image") batch_texts = batch.get("texts") or batch.get("conversations") for local_idx, row_idx in enumerate(indices): if batch_images is not None: images = batch_images[local_idx] elif batch_single_image is not None: images = batch_single_image[local_idx] else: stats["missing_image_column"] += 1 continue if batch_texts is None: stats["missing_text_column"] += 1 continue texts = batch_texts[local_idx] if images is None: images = [] elif not isinstance(images, list): images = [images] if texts is None: texts = [] elif isinstance(texts, dict): texts = [texts] if not images or len(images) != 1: stats["multi_or_missing_image"] += 1 continue image_payload = images[0] if not is_usable_image_payload(image_payload): stats["unusable_image_ref"] += 1 continue if not texts: stats["missing_turns"] += 1 continue for turn_idx, turn in enumerate(texts): if not isinstance(turn, dict): stats["unsupported_turn_type"] += 1 continue user_text = ( turn.get("user") or turn.get("question") or turn.get("prompt") or turn.get("input") or "" ).strip() assistant_text = ( turn.get("assistant") or turn.get("answer") or turn.get("response") or turn.get("output") or "" ).strip() if not user_text or not assistant_text: stats["missing_user_or_assistant"] += 1 continue output["image"].append(image_payload) output["prompt_text"].append(user_text) output["target_text"].append(assistant_text) output["sample_id"].append(f"{config_name}:{row_idx}:{turn_idx}") output["source_config"].append(config_name) stats["kept"] += 1 return output exploded = ds.map( explode, batched=True, with_indices=True, remove_columns=ds.column_names, desc=f"Normalizing {config_name}", ) if per_config_limit is not None: exploded = exploded.select(range(min(per_config_limit, len(exploded)))) normalized_configs.append(exploded) stats["kept"] = len(exploded) skip_stats[config_name] = dict(stats) print(f"{config_name}: kept={stats['kept']} skipped={sum(v for k, v in stats.items() if k != 'kept')}") if not normalized_configs: raise RuntimeError("No valid The Cauldron configs were loaded.") combined = concatenate_datasets(normalized_configs) if max_samples: combined = combined.select(range(min(max_samples, len(combined)))) print(f"Loaded {len(combined)} normalized The Cauldron samples") return combined, skip_stats def collate_vision_language(batch: List[Dict[str, object]]) -> Dict[str, object]: return { "pixel_values": torch.stack([sample["pixel_values"] for sample in batch]), "input_ids": torch.stack([sample["input_ids"] for sample in batch]), "attention_mask": torch.stack([sample["attention_mask"] for sample in batch]), "labels": torch.stack([sample["labels"] for sample in batch]), "loss_mask": torch.stack([sample["loss_mask"] for sample in batch]), "sample_id": [sample["sample_id"] for sample in batch], "prompt_text": [sample["prompt_text"] for sample in batch], "target_text": [sample["target_text"] for sample in batch], "source_config": [sample["source_config"] for sample in batch], "pil_image": [sample["pil_image"] for sample in batch], } def create_stage_dataset(stage: str, tokenizer: AutoTokenizer, args: argparse.Namespace) -> Tuple[NormalizedVisionLanguageDataset, Dict[str, Dict[str, int]]]: if stage == "1": return NormalizedVisionLanguageDataset( records=build_llava_records(args.max_samples), tokenizer=tokenizer, max_length=args.max_length, img_size=224, ), {} configs = resolve_cauldron_configs(args) raw_row_limit = None if args.dry_run_batches: raw_row_limit = max(32, args.batch_size * args.dry_run_batches * 8) records, skip_stats = build_cauldron_records(configs, args.max_samples, raw_row_limit=raw_row_limit) return NormalizedVisionLanguageDataset( records=records, tokenizer=tokenizer, max_length=args.max_length, img_size=224, ), skip_stats def build_dataloader( dataset: Dataset, batch_size: int, shuffle: bool, num_workers: int, persistent_workers: bool, ) -> DataLoader: return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=persistent_workers and num_workers > 0, drop_last=False, collate_fn=collate_vision_language, ) def print_device_info(device: torch.device) -> None: print(f"Device: {device}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") print(f"torch.version.cuda: {torch.version.cuda}") def ensure_runtime_requirements(args: argparse.Namespace) -> None: if args.require_cuda and not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this run but torch.cuda.is_available() is False.") if args.stage in {"2", "3a", "3b"} and not parse_dataset_configs(args.dataset_configs): raise RuntimeError("Stage 2/3 requires at least one The Cauldron config.") if args.stage in {"3a", "3b"} and not args.teacher_cache_dir: raise RuntimeError("Stage 3 requires --teacher_cache_dir.") if args.stage in {"3a", "3b"} and not args.resume_from: raise RuntimeError("Stage 3 requires --resume_from pointing to a Stage 2 checkpoint.") if args.stage == "3a": try: import bitsandbytes # noqa: F401 except ImportError as exc: raise RuntimeError("Stage 3a requires bitsandbytes in the active environment.") from exc def maybe_resume_model(model: ViLDLM, args: argparse.Namespace) -> None: if not args.resume_from: return checkpoint_dir = Path(args.resume_from) if not checkpoint_dir.exists(): raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") include_lm = args.stage in {"2", "3a", "3b"} print(f"Resuming from checkpoint: {checkpoint_dir}") model.load_checkpoint(checkpoint_dir, include_lm=include_lm) def get_optimizer(model: ViLDLM, stage: str) -> AdamW: if stage == "1": groups = [ { "params": [p for p in model.projector.parameters() if p.requires_grad], "lr": 1e-3, } ] else: groups = [ { "params": [p for p in model.vision_encoder.parameters() if p.requires_grad], "lr": 2e-6, }, { "params": [p for p in model.projector.parameters() if p.requires_grad], "lr": 1e-5, }, { "params": [p for p in model.lm.parameters() if p.requires_grad], "lr": 1e-5, }, ] groups = [group for group in groups if group["params"]] return AdamW(groups, weight_decay=0.05, betas=(0.9, 0.999)) def setup_model_for_stage(model: ViLDLM, stage: str) -> None: if stage == "1": print("\n=== STAGE 1: Projector-only alignment ===") model.freeze_vision() model.freeze_lm() elif stage in {"2", "3b"}: label = "Full finetune" if stage == "2" else "Sparse KD finetune" print(f"\n=== STAGE {stage.upper()}: {label} ===") model.unfreeze_all() elif stage == "3a": print("\n=== STAGE 3A: Teacher candidate-bank preparation ===") model.unfreeze_all() for param in model.parameters(): param.requires_grad = False else: raise ValueError(f"Unsupported stage: {stage}") def compute_sparse_kd_loss( student_logits: torch.Tensor, noise_mask: torch.Tensor, timesteps: torch.Tensor, sample_ids: Sequence[str], bank_map: Dict[str, List[Dict[str, object]]], temperature: float, ) -> Tuple[torch.Tensor, Dict[str, object]]: entries_used = 0 losses: List[torch.Tensor] = [] mask_probs: List[torch.Tensor] = [] mask_probability = 1.0 - torch.cos(timesteps * math.pi / 2) for batch_idx, sample_id in enumerate(sample_ids): sample_entries = bank_map.get(sample_id, []) for entry in sample_entries: position = int(entry["position"]) if position >= student_logits.shape[1]: continue if not bool(noise_mask[batch_idx, position].item()): continue candidate_ids = torch.tensor( entry["candidate_token_ids"], device=student_logits.device, dtype=torch.long, ) teacher_probs = torch.tensor( entry["teacher_probs"], device=student_logits.device, dtype=student_logits.dtype, ) gathered = student_logits[batch_idx, position, candidate_ids] student_log_probs = F.log_softmax(gathered / temperature, dim=-1) loss = F.kl_div( student_log_probs.unsqueeze(0), teacher_probs.unsqueeze(0), reduction="batchmean", ) * (temperature ** 2) losses.append(loss) mask_probs.append(mask_probability[batch_idx]) entries_used += 1 if not losses: zero = torch.tensor(0.0, device=student_logits.device) return zero, { "entries": 0, "loss_variance": zero, "mean_mask_prob": zero, } loss_tensor = torch.stack(losses) mask_prob_tensor = torch.stack(mask_probs) return loss_tensor.mean(), { "entries": entries_used, "loss_variance": loss_tensor.var(unbiased=False), "mean_mask_prob": mask_prob_tensor.mean(), } def build_kd_force_mask( sample_ids: Sequence[str], bank_map: Dict[str, List[Dict[str, object]]], seq_len: int, device: torch.device, ) -> torch.Tensor: force_mask = torch.zeros((len(sample_ids), seq_len), device=device, dtype=torch.bool) for batch_idx, sample_id in enumerate(sample_ids): for entry in bank_map.get(sample_id, []): position = int(entry["position"]) if 0 <= position < seq_len: force_mask[batch_idx, position] = True return force_mask def compute_teacher_logprobs( teacher: AutoModelForImageTextToText, processor: AutoProcessor, pil_image: Image.Image, prompt_text: str, candidate_texts: Sequence[str], teacher_batch_size: int, ) -> torch.Tensor: prompt_messages = [ { "role": "user", "content": [ {"type": "image", "image": pil_image}, {"type": "text", "text": prompt_text}, ], } ] prompt_inputs = processor.apply_chat_template( prompt_messages, tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True, ) prompt_len = prompt_inputs["input_ids"].shape[1] teacher_device = next(teacher.parameters()).device all_logprobs = [] for start in range(0, len(candidate_texts), max(teacher_batch_size, 1)): batch_candidates = candidate_texts[start : start + max(teacher_batch_size, 1)] conversations = [] for candidate_text in batch_candidates: conversations.append( [ { "role": "user", "content": [ {"type": "image", "image": pil_image}, {"type": "text", "text": prompt_text}, ], }, { "role": "assistant", "content": [{"type": "text", "text": candidate_text}], }, ] ) batch_inputs = processor.apply_chat_template( conversations, tokenize=True, return_dict=True, return_tensors="pt", padding=True, add_generation_prompt=False, ) batch_inputs = {key: value.to(teacher_device) for key, value in batch_inputs.items()} outputs = teacher(**batch_inputs) logits = outputs.logits[:, :-1, :] labels = batch_inputs["input_ids"][:, 1:].clone() attention_mask = batch_inputs["attention_mask"] seq_len = batch_inputs["input_ids"].shape[1] for batch_idx in range(labels.shape[0]): valid_len = int(attention_mask[batch_idx].sum().item()) left_pad = seq_len - valid_len prefix_cut = left_pad + prompt_len - 1 if prefix_cut > 0: labels[batch_idx, :prefix_cut] = -100 labels[batch_idx, attention_mask[batch_idx, 1:] == 0] = -100 per_token = F.cross_entropy( logits.reshape(-1, logits.shape[-1]), labels.reshape(-1), ignore_index=-100, reduction="none", ).reshape(labels.shape) token_mask = (labels != -100).float() all_logprobs.append(-(per_token * token_mask).sum(dim=-1).cpu()) return torch.cat(all_logprobs, dim=0) def choose_distillation_positions( clean_logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor, max_positions: int, ) -> List[int]: valid_positions = torch.nonzero(loss_mask > 0, as_tuple=False).flatten() if valid_positions.numel() == 0: return [] probs = F.softmax(clean_logits[valid_positions], dim=-1) gold = labels[valid_positions].unsqueeze(-1) gold_probs = probs.gather(-1, gold).squeeze(-1) _, ranked = torch.sort(gold_probs, descending=False) selected = valid_positions[ranked][:max_positions] return [int(pos.item()) for pos in selected] def build_candidate_ids( logits_at_position: torch.Tensor, gold_token_id: int, top_k: int, ) -> List[int]: candidate_ids = logits_at_position.topk(max(top_k - 1, 1)).indices.tolist() if gold_token_id not in candidate_ids: candidate_ids.append(gold_token_id) deduped = [] seen = set() for token_id in candidate_ids: if token_id in seen: continue deduped.append(token_id) seen.add(token_id) return deduped[:top_k] def decode_assistant_text( tokenizer: AutoTokenizer, full_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, ) -> str: active = (attention_mask > 0) & (loss_mask > 0) assistant_ids = full_ids[active].tolist() return tokenizer.decode(assistant_ids, skip_special_tokens=True).strip() def prepare_teacher_bank( args: argparse.Namespace, model: ViLDLM, dataset: NormalizedVisionLanguageDataset, ) -> None: if args.dry_run_batches: max_items = min(args.teacher_batch_size * args.dry_run_batches, len(dataset)) elif args.max_samples: max_items = min(args.max_samples, len(dataset)) else: max_items = len(dataset) try: from transformers import BitsAndBytesConfig except ImportError as exc: raise RuntimeError("bitsandbytes/transformers quantization support is required for Stage 3a.") from exc print(f"Loading teacher: {args.teacher_model_id}") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", ) teacher = AutoModelForImageTextToText.from_pretrained( args.teacher_model_id, quantization_config=quantization_config, device_map="auto", attn_implementation="sdpa", ) teacher.eval() processor = AutoProcessor.from_pretrained(args.teacher_model_id, padding_side="left") cache_dir = Path(args.teacher_cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) output_path = cache_dir / "candidate_bank.jsonl" seen_keys = set() if output_path.exists(): with open(output_path, "r", encoding="utf-8") as handle: for line in handle: if not line.strip(): continue record = json.loads(line) seen_keys.add((record["sample_id"], int(record["position"]))) dataloader = build_dataloader( dataset=dataset, batch_size=1, shuffle=False, num_workers=0, persistent_workers=False, ) prepared = 0 teacher_entropies: List[float] = [] with torch.no_grad(), open(output_path, "a", encoding="utf-8") as writer: for batch in dataloader: sample_id = batch["sample_id"][0] prompt_text = batch["prompt_text"][0] target_text = batch["target_text"][0] pil_image = batch["pil_image"][0] pixel_values = batch["pixel_values"].to(next(model.parameters()).device) input_ids = batch["input_ids"].to(pixel_values.device) attention_mask = batch["attention_mask"].to(pixel_values.device) labels = batch["labels"].to(pixel_values.device) loss_mask = batch["loss_mask"].to(pixel_values.device) clean_logits = model.predict_clean_logits(pixel_values, input_ids, attention_mask)[0] sample_labels = labels[0] sample_loss_mask = loss_mask[0] positions = choose_distillation_positions( clean_logits=clean_logits, labels=sample_labels, loss_mask=sample_loss_mask, max_positions=args.kd_positions_per_sample, ) for position in positions: cache_key = (sample_id, position) if cache_key in seen_keys: continue gold_token_id = int(sample_labels[position].item()) candidate_token_ids = build_candidate_ids( logits_at_position=clean_logits[position], gold_token_id=gold_token_id, top_k=args.kd_top_k, ) candidate_texts: List[str] = [] for candidate_id in candidate_token_ids: modified_ids = input_ids[0].clone() modified_ids[position] = candidate_id candidate_texts.append( decode_assistant_text( tokenizer=model.tokenizer, full_ids=modified_ids, attention_mask=attention_mask[0], loss_mask=loss_mask[0], ) ) teacher_logprobs = compute_teacher_logprobs( teacher=teacher, processor=processor, pil_image=pil_image, prompt_text=prompt_text, candidate_texts=candidate_texts, teacher_batch_size=args.teacher_batch_size, ) teacher_probs_tensor = F.softmax(teacher_logprobs / args.kd_temperature, dim=-1) teacher_entropy = float( -(teacher_probs_tensor * teacher_probs_tensor.clamp_min(1e-12).log()).sum().item() ) teacher_probs = teacher_probs_tensor.cpu().tolist() record = { "sample_id": sample_id, "position": position, "candidate_token_ids": candidate_token_ids, "teacher_probs": teacher_probs, "gold_token_id": gold_token_id, "temperature": args.kd_temperature, "teacher_entropy": teacher_entropy, "source_config": batch["source_config"][0], "text_hash": stable_text_hash(sample_id, prompt_text, target_text), } writer.write(json.dumps(record) + "\n") seen_keys.add(cache_key) prepared += 1 teacher_entropies.append(teacher_entropy) if args.dry_run_batches and prepared >= args.kd_positions_per_sample * args.dry_run_batches: break if prepared and prepared % 50 == 0: print(f"Prepared {prepared} KD entries...") print(f"Teacher bank written to {output_path} with {prepared} new entries") if teacher_entropies: entropy_array = np.array(teacher_entropies, dtype=np.float32) print( "Teacher entropy: " f"mean={float(entropy_array.mean()):.4f}, " f"min={float(entropy_array.min()):.4f}, " f"max={float(entropy_array.max()):.4f}" ) def load_teacher_bank(cache_dir: str) -> Dict[str, List[Dict[str, object]]]: bank_path = Path(cache_dir) / "candidate_bank.jsonl" if not bank_path.exists(): raise FileNotFoundError(f"Teacher bank not found: {bank_path}") bank_map: Dict[str, List[Dict[str, object]]] = defaultdict(list) with open(bank_path, "r", encoding="utf-8") as handle: for line in handle: if not line.strip(): continue record = json.loads(line) bank_map[record["sample_id"]].append(record) print(f"Loaded teacher bank for {len(bank_map)} samples from {bank_path}") return bank_map def maybe_push_to_hub( args: argparse.Namespace, save_dir: Path, params: Dict[str, int], best_loss: float, ) -> None: if not args.push_to_hub: print("Skipping Hub push (enable with --push_to_hub).") return print("\nPushing to Hub...") api = HfApi() repo_id = args.hub_model_id try: api.create_repo(repo_id, exist_ok=True, private=False) except Exception as exc: print(f"Repo note: {exc}") config_dict = { "architecture": "ViL-DLM", "training_stage": args.stage, "best_loss": best_loss, "total_params_M": params["total"] / 1e6, "trainable_params_M": params["trainable"] / 1e6, "teacher": args.teacher_model_id, "dataset_configs": parse_dataset_configs(args.dataset_configs) if args.stage in {"2", "3a", "3b"} else ["llava_pretrain"], } with open(save_dir / "model_config.json", "w", encoding="utf-8") as handle: json.dump(config_dict, handle, indent=2) api.upload_folder( folder_path=str(save_dir), repo_id=repo_id, commit_message=f"Stage {args.stage} training (loss={best_loss:.4f})", ) print(f"\n✅ Model pushed to https://huggingface.co/{repo_id}") def run_training_stage(args: argparse.Namespace) -> None: tracker = _TrackioShim() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print_device_info(device) ensure_runtime_requirements(args) lm_path = download_student_backbone() vil_config = ViLConfig() proj_config = ProjConfig() model = ViLDLM(vil_config, proj_config, lm_path) setup_model_for_stage(model, args.stage) maybe_resume_model(model, args) params = model.count_params() print(f"Parameters: Total={params['total']/1e6:.1f}M, Trainable={params['trainable']/1e6:.1f}M") print(f" ViL: {params['vil']/1e6:.1f}M, Proj: {params['proj']/1e6:.1f}M, LM: {params['lm']/1e6:.1f}M") model = model.to(device) if hasattr(model.lm, "gradient_checkpointing_enable"): model.lm.gradient_checkpointing_enable() dataset, skip_stats = create_stage_dataset("1" if args.stage == "1" else "2", model.tokenizer, args) if skip_stats: print(f"Skip stats: {json.dumps(skip_stats)}") if args.stage == "3a": prepare_teacher_bank(args=args, model=model, dataset=dataset) return teacher_bank = load_teacher_bank(args.teacher_cache_dir) if args.stage == "3b" else {} dataloader = build_dataloader( dataset=dataset, batch_size=args.batch_size, shuffle=args.stage != "3a" and not (args.stage == "3b" and args.dry_run_batches), num_workers=args.num_workers, persistent_workers=args.persistent_workers, ) optimizer = get_optimizer(model, stage="1" if args.stage == "1" else "2") total_steps = max(1, (len(dataloader) * args.epochs) // max(args.grad_accum, 1)) scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6) tracker.init(name=f"vil-dlm-stage{args.stage}") best_loss = float("inf") global_step = 0 step_timer = time.time() for epoch in range(args.epochs): model.train() epoch_loss = 0.0 epoch_kd_loss = 0.0 epoch_kd_entries = 0 epoch_effective_alpha = 0.0 epoch_kd_mask_prob = 0.0 epoch_kd_loss_variance = 0.0 num_batches = 0 optimizer.zero_grad(set_to_none=True) for batch_idx, batch in enumerate(dataloader): pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) loss_mask = batch["loss_mask"].to(device) force_mask = None if args.stage == "3b": force_mask = build_kd_force_mask( sample_ids=batch["sample_id"], bank_map=teacher_bank, seq_len=input_ids.shape[1], device=device, ) outputs = model( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels, loss_mask=loss_mask, force_mask=force_mask, ) diffusion_loss = outputs["loss"] kd_loss = torch.tensor(0.0, device=device) kd_entries = 0 kd_loss_variance = torch.tensor(0.0, device=device) mean_kd_mask_prob = torch.tensor(0.0, device=device) effective_alpha_kd = torch.tensor(0.0, device=device) total_loss = diffusion_loss if args.stage == "3b": kd_loss, kd_metrics = compute_sparse_kd_loss( student_logits=outputs["logits"], noise_mask=outputs["noise_mask"], timesteps=outputs["t"], sample_ids=batch["sample_id"], bank_map=teacher_bank, temperature=args.kd_temperature, ) kd_entries = int(kd_metrics["entries"]) kd_loss_variance = kd_metrics["loss_variance"] mean_kd_mask_prob = kd_metrics["mean_mask_prob"] if kd_entries > 0: if args.kd_timestep_weighting: effective_alpha_kd = args.alpha_kd * mean_kd_mask_prob else: effective_alpha_kd = torch.tensor(args.alpha_kd, device=device) total_loss = (1.0 - effective_alpha_kd) * diffusion_loss + effective_alpha_kd * kd_loss loss = total_loss / args.grad_accum loss.backward() if (batch_idx + 1) % args.grad_accum == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad(set_to_none=True) global_step += 1 actual_loss = float(total_loss.item()) actual_diffusion = float(diffusion_loss.item()) actual_kd = float(kd_loss.item()) if args.stage == "3b" else 0.0 actual_kd_variance = float(kd_loss_variance.item()) if args.stage == "3b" else 0.0 actual_kd_mask_prob = float(mean_kd_mask_prob.item()) if args.stage == "3b" else 0.0 actual_effective_alpha = float(effective_alpha_kd.item()) if args.stage == "3b" else 0.0 elapsed = max(time.time() - step_timer, 1e-6) samples_per_sec = (args.batch_size * args.grad_accum) / elapsed step_timer = time.time() gpu_mem_gb = 0.0 if torch.cuda.is_available(): gpu_mem_gb = torch.cuda.max_memory_allocated(device) / 1e9 print( f"[E{epoch}] Step {global_step}/{total_steps} | " f"Loss: {actual_loss:.4f} | Diff: {actual_diffusion:.4f} | " f"KD: {actual_kd:.4f} | KD entries: {kd_entries} | " f"KD var: {actual_kd_variance:.4f} | KD mask_p: {actual_kd_mask_prob:.4f} | " f"alpha_kd: {actual_effective_alpha:.4f} | " f"Samples/s: {samples_per_sec:.2f} | GPU mem: {gpu_mem_gb:.2f} GB" ) tracker.log( { "train/loss": actual_loss, "train/diffusion_loss": actual_diffusion, "train/kd_loss": actual_kd, "train/kd_entries": kd_entries, "train/kd_loss_variance": actual_kd_variance, "train/mean_kd_mask_prob": actual_kd_mask_prob, "train/effective_alpha_kd": actual_effective_alpha, "train/epoch": epoch, "train/step": global_step, "train/samples_per_sec": samples_per_sec, "train/gpu_mem_gb": gpu_mem_gb, } ) epoch_loss += float(total_loss.item()) epoch_kd_loss += float(kd_loss.item()) epoch_kd_entries += kd_entries epoch_effective_alpha += float(effective_alpha_kd.item()) epoch_kd_mask_prob += float(mean_kd_mask_prob.item()) epoch_kd_loss_variance += float(kd_loss_variance.item()) num_batches += 1 if args.dry_run_batches and num_batches >= args.dry_run_batches: break remainder = num_batches % args.grad_accum if remainder != 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad(set_to_none=True) global_step += 1 avg_loss = epoch_loss / max(num_batches, 1) avg_kd_loss = epoch_kd_loss / max(num_batches, 1) avg_effective_alpha = epoch_effective_alpha / max(num_batches, 1) avg_kd_mask_prob = epoch_kd_mask_prob / max(num_batches, 1) avg_kd_loss_variance = epoch_kd_loss_variance / max(num_batches, 1) print( f"\n[Epoch {epoch}] Average Loss: {avg_loss:.4f} | Average KD: {avg_kd_loss:.4f} | " f"KD entries: {epoch_kd_entries} | Avg alpha_kd: {avg_effective_alpha:.4f} | " f"Avg KD mask_p: {avg_kd_mask_prob:.4f} | Avg KD var: {avg_kd_loss_variance:.4f}\n" ) tracker.log( { "train/epoch_loss": avg_loss, "train/epoch_kd_loss": avg_kd_loss, "train/epoch_kd_entries": epoch_kd_entries, "train/epoch_effective_alpha_kd": avg_effective_alpha, "train/epoch_mean_kd_mask_prob": avg_kd_mask_prob, "train/epoch_kd_loss_variance": avg_kd_loss_variance, "train/epoch": epoch, } ) if avg_loss < best_loss: best_loss = avg_loss save_dir = Path(args.output_dir) / f"stage{args.stage}_best" include_lm = args.stage in {"2", "3b"} model.save_checkpoint(save_dir, include_lm=include_lm) training_state = { "stage": args.stage, "best_loss": best_loss, "args": vars(args), } with open(save_dir / "training_state.json", "w", encoding="utf-8") as handle: json.dump(training_state, handle, indent=2) print(f"Saved best checkpoint (loss={best_loss:.4f})") maybe_push_to_hub( args=args, save_dir=Path(args.output_dir) / f"stage{args.stage}_best", params=params, best_loss=best_loss, ) print("Training complete!") def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--stage", type=str, default="1", choices=["1", "2", "3a", "3b"]) parser.add_argument("--epochs", type=int, default=2) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--grad_accum", type=int, default=8) parser.add_argument("--max_length", type=int, default=512) parser.add_argument("--max_samples", type=int, default=None) parser.add_argument("--output_dir", type=str, default="./vil-dlm-output") parser.add_argument("--hub_model_id", type=str, default="omar-ah/ViL-DLM-0.6B") parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--require_cuda", action="store_true") parser.add_argument("--resume_from", type=str, default=None) parser.add_argument("--dataset_configs", type=str, default=",".join(DEFAULT_CAULDRON_CONFIGS)) parser.add_argument("--num_workers", type=int, default=4) parser.add_argument("--persistent_workers", action="store_true") parser.add_argument("--dry_run_batches", type=int, default=0) parser.add_argument("--teacher_model_id", type=str, default="google/gemma-4-E2B-it") parser.add_argument("--teacher_cache_dir", type=str, default="./vil-dlm-output/teacher-cache") parser.add_argument("--prepare_teacher_bank", action="store_true") parser.add_argument("--teacher_batch_size", type=int, default=1) parser.add_argument("--alpha_kd", type=float, default=0.5) parser.add_argument("--kd_temperature", type=float, default=1.0) parser.add_argument("--kd_timestep_weighting", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--kd_top_k", type=int, default=16) parser.add_argument("--kd_positions_per_sample", type=int, default=16) return parser if __name__ == "__main__": ensure_hf_cache_root() parser = build_parser() args = parser.parse_args() if args.prepare_teacher_bank and args.stage != "3a": raise ValueError("--prepare_teacher_bank is only valid with --stage 3a") if args.kd_temperature <= 0: raise ValueError("--kd_temperature must be > 0") if not 0.0 <= args.alpha_kd <= 1.0: raise ValueError("--alpha_kd must be between 0 and 1") run_training_stage(args)