"""
ViL-DLM: Vision xLSTM + Diffusion Language Model
Architecture Configuration
"""
from dataclasses import dataclass, field
from typing import List
@dataclass
class ViLEncoderConfig:
"""Vision xLSTM (ViL) encoder configuration"""
vision_backbone: str = "vil2-small"
pretrained: bool = True
img_size: int = 224
patch_size: int = 16
in_channels: int = 3
dim: int = 384 # patch feature dim for vil-small / vil2-small
depth: int = 12 # VisionLSTM2 block-pairs; v1 vil-small internally uses 24
mlstm_dim_mult: int = 2 # mLSTM internal dim = 2 * dim
conv_kernel_size: int = 3 # QK Conv2D kernel
bidirectional: bool = True # alternating scan directions
dropout: float = 0.0
@property
def num_patches(self):
return (self.img_size // self.patch_size) ** 2 # 196 for 224/16
@property
def num_params_approx(self):
# Rough estimate: each mLSTM block has ~4 * dim * (2*dim) params for QKV + gates
per_block = 4 * self.dim * (self.mlstm_dim_mult * self.dim) + self.dim * self.dim * 4
return self.depth * per_block
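
# Worked check of the derived properties above (illustrative; num_params_approx
# is deliberately rough and ignores norms, biases, patch embedding, and head):
#   ViLEncoderConfig().num_patches       -> (224 // 16) ** 2 = 196
#   ViLEncoderConfig().num_params_approx -> 12 * (4*384*768 + 4*384*384) = 21,233,664 (~21.2M)
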
@dataclass
class ProjectorConfig:
"""MLP projector: maps ViL features to LM embedding space"""
vil_dim: int = 384 # ViL-S output dim
lm_dim: int = 1024 # Qwen3-0.6B hidden_size
hidden_mult: int = 2 # projector hidden = lm_dim * hidden_mult
num_layers: int = 2 # 2-layer MLP (LaViDa/LLaDA-V standard)
activation: str = "gelu"
dropout: float = 0.0
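
# Hedged sketch of the projector this config describes: a 2-layer GELU MLP from
# vil_dim to lm_dim with hidden width lm_dim * hidden_mult (LLaDA-V/LaViDa style).
# Illustrative only; torch is imported lazily so the config module itself stays
# dependency-free, and the activation is fixed to GELU for brevity.
def build_projector_sketch(cfg: ProjectorConfig):
    import torch.nn as nn  # assumption: torch is installed where this is called
    hidden = cfg.lm_dim * cfg.hidden_mult
    layers = [nn.Linear(cfg.vil_dim, hidden), nn.GELU()]
    for _ in range(cfg.num_layers - 2):  # extra hidden layers if num_layers > 2
        layers += [nn.Linear(hidden, hidden), nn.GELU()]
    layers.append(nn.Linear(hidden, cfg.lm_dim))
    if cfg.dropout > 0:
        layers.append(nn.Dropout(cfg.dropout))
    return nn.Sequential(*layers)
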
@dataclass
class DiffusionConfig:
"""Masked diffusion (MDLM) training configuration"""
    noise_schedule: str = "cosine" # cosine masking schedule
    mask_token_id: int = 151643 # Qwen3 <|endoftext|>/pad token, reused as [MASK]
num_diffusion_steps: int = 1000 # training steps
inference_steps: int = 128 # sampling steps
remasking: str = "low_confidence" # remasking strategy
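
# Hedged sketch of one common cosine masking schedule (MaskGIT-style); the exact
# schedule used by the MDLM backbone may differ. t is a normalized timestep in
# [0, 1]: no tokens are masked at t = 0 and all tokens are masked at t = 1.
def cosine_mask_prob_sketch(t: float) -> float:
    import math
    return 1.0 - math.cos(math.pi * t / 2.0)
# e.g. cosine_mask_prob_sketch(0.5) is about 0.29, rising steeply toward 1.0 as t -> 1.
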
@dataclass
class DistillationConfig:
"""Knowledge distillation from Gemma 4 E2B teacher"""
teacher_model_id: str = "google/gemma-4-E2B-it"
teacher_quantize: bool = True # 4-bit quantization for memory
temperature: float = 2.0 # KD temperature
alpha_kd: float = 0.5 # weight for KD loss vs diffusion loss
alpha_vision_kd: float = 0.3 # weight for vision feature distillation
kd_top_k: int = 8 # sparse cross-tokenizer candidate set size
kd_positions_per_sample: int = 16
teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"
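
# Hedged sketch of how the weights above might combine the objectives
# (illustrative; the real loss wiring lives in the training code). alpha_kd
# trades off the sparse top-k cross-tokenizer KD term against the diffusion
# loss, and alpha_vision_kd adds the vision-feature distillation term on top.
def combined_loss_sketch(diffusion_loss, kd_loss, vision_kd_loss,
                         cfg: DistillationConfig):
    return ((1.0 - cfg.alpha_kd) * diffusion_loss
            + cfg.alpha_kd * kd_loss
            + cfg.alpha_vision_kd * vision_kd_loss)
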
@dataclass
class TrainingConfig:
"""Full training configuration"""
# Model
vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
projector: ProjectorConfig = field(default_factory=ProjectorConfig)
diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
distillation: DistillationConfig = field(default_factory=DistillationConfig)
# Backbone
diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"
# Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes)
learning_rate: float = 1e-4
vil_learning_rate: float = 2e-6 # lower LR for vision encoder (LLaDA-V)
projector_learning_rate: float = 1e-3 # higher LR for projector (LLaDA-V Stage 1)
weight_decay: float = 0.05
warmup_ratio: float = 0.1
lr_scheduler: str = "cosine"
max_seq_len: int = 1024
per_device_train_batch_size: int = 4
gradient_accumulation_steps: int = 8 # effective batch = 32
num_epochs: int = 3
bf16: bool = True
gradient_checkpointing: bool = True
# Data
pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain" # Stage 1: 558K
finetune_dataset: str = "HuggingFaceM4/the_cauldron" # Stage 2: rich multimodal
finetune_dataset_configs: List[str] = field(default_factory=lambda: [
"ai2d",
"vqav2",
"aokvqa",
"textvqa",
"docvqa",
"chartqa",
"textcaps",
"screen2words",
])
# Output
output_dir: str = "./vil-dlm-output"
hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
push_to_hub: bool = False
# Stages
stage: str = "1" # 1, 2, 3a, 3b
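
# Hedged sketch of LLaDA-V-style optimizer parameter groups built from the three
# learning rates above. `vil_params`, `projector_params`, and `lm_params` are
# placeholders for the parameters of the corresponding modules; the real trainer
# may group or freeze parameters differently per stage.
def param_groups_sketch(cfg: TrainingConfig, vil_params, projector_params, lm_params):
    return [
        {"params": vil_params, "lr": cfg.vil_learning_rate},
        {"params": projector_params, "lr": cfg.projector_learning_rate},
        {"params": lm_params, "lr": cfg.learning_rate},
    ]
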
def get_config(stage: str = "1") -> TrainingConfig:
config = TrainingConfig()
config.stage = stage
if stage == "1":
# Stage 1: Train projector only (ViL frozen, LM frozen)
config.learning_rate = 1e-3
config.num_epochs = 1
config.per_device_train_batch_size = 8
config.gradient_accumulation_steps = 4
elif stage == "2":
# Stage 2: Full model finetune (ViL + projector + LM)
config.learning_rate = 1e-5
config.vil_learning_rate = 2e-6
config.projector_learning_rate = 1e-5
config.num_epochs = 3
elif stage in {"3a", "3b"}:
        # Stage 3: sparse cross-tokenizer distillation with Gemma 3n
config.learning_rate = 1e-5
config.num_epochs = 2
config.distillation.alpha_kd = 0.5
return config
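

# Minimal usage sketch: print the key hyperparameters for each training stage.
if __name__ == "__main__":
    for s in ("1", "2", "3a", "3b"):
        cfg = get_config(s)
        effective_batch = cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps
        print(f"stage {s}: lr={cfg.learning_rate}, epochs={cfg.num_epochs}, "
              f"effective_batch={effective_batch}")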