""" ViL-DLM: Vision xLSTM + Diffusion Language Model Architecture Configuration """ from dataclasses import dataclass, field from typing import Optional, List @dataclass class ViLEncoderConfig: """Vision xLSTM (ViL) encoder configuration""" vision_backbone: str = "vil2-small" pretrained: bool = True img_size: int = 224 patch_size: int = 16 in_channels: int = 3 dim: int = 384 # patch feature dim for vil-small / vil2-small depth: int = 12 # VisionLSTM2 block-pairs; v1 vil-small internally uses 24 mlstm_dim_mult: int = 2 # mLSTM internal dim = 2 * dim conv_kernel_size: int = 3 # QK Conv2D kernel bidirectional: bool = True # alternating scan directions dropout: float = 0.0 @property def num_patches(self): return (self.img_size // self.patch_size) ** 2 # 196 for 224/16 @property def num_params_approx(self): # Rough estimate: each mLSTM block has ~4 * dim * (2*dim) params for QKV + gates per_block = 4 * self.dim * (self.mlstm_dim_mult * self.dim) + self.dim * self.dim * 4 return self.depth * per_block @dataclass class ProjectorConfig: """MLP projector: maps ViL features to LM embedding space""" vil_dim: int = 384 # ViL-S output dim lm_dim: int = 1024 # Qwen3-0.6B hidden_size hidden_mult: int = 2 # projector hidden = lm_dim * hidden_mult num_layers: int = 2 # 2-layer MLP (LaViDa/LLaDA-V standard) activation: str = "gelu" dropout: float = 0.0 @dataclass class DiffusionConfig: """Masked diffusion (MDLM) training configuration""" noise_schedule: str = "cosine" # cosine schedule (MDLM default) mask_token_id: int = 151643 # Qwen3 pad/mask token num_diffusion_steps: int = 1000 # training steps inference_steps: int = 128 # sampling steps remasking: str = "low_confidence" # remasking strategy @dataclass class DistillationConfig: """Knowledge distillation from Gemma 4 E2B teacher""" teacher_model_id: str = "google/gemma-4-E2B-it" teacher_quantize: bool = True # 4-bit quantization for memory temperature: float = 2.0 # KD temperature alpha_kd: float = 0.5 # weight for KD loss vs diffusion loss alpha_vision_kd: float = 0.3 # weight for vision feature distillation kd_top_k: int = 8 # sparse cross-tokenizer candidate set size kd_positions_per_sample: int = 16 teacher_cache_dir: str = "./vil-dlm-output/teacher-cache" @dataclass class TrainingConfig: """Full training configuration""" # Model vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig) projector: ProjectorConfig = field(default_factory=ProjectorConfig) diffusion: DiffusionConfig = field(default_factory=DiffusionConfig) distillation: DistillationConfig = field(default_factory=DistillationConfig) # Backbone diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1" # Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes) learning_rate: float = 1e-4 vil_learning_rate: float = 2e-6 # lower LR for vision encoder (LLaDA-V) projector_learning_rate: float = 1e-3 # higher LR for projector (LLaDA-V Stage 1) weight_decay: float = 0.05 warmup_ratio: float = 0.1 lr_scheduler: str = "cosine" max_seq_len: int = 1024 per_device_train_batch_size: int = 4 gradient_accumulation_steps: int = 8 # effective batch = 32 num_epochs: int = 3 bf16: bool = True gradient_checkpointing: bool = True # Data pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain" # Stage 1: 558K finetune_dataset: str = "HuggingFaceM4/the_cauldron" # Stage 2: rich multimodal finetune_dataset_configs: List[str] = field(default_factory=lambda: [ "ai2d", "vqav2", "aokvqa", "textvqa", "docvqa", "chartqa", "textcaps", "screen2words", ]) # Output output_dir: str = "./vil-dlm-output" 

@dataclass
class DiffusionConfig:
    """Masked diffusion (MDLM) training configuration."""
    noise_schedule: str = "cosine"       # cosine masking schedule
    mask_token_id: int = 151643          # Qwen3 pad/mask token
    num_diffusion_steps: int = 1000      # training timesteps
    inference_steps: int = 128           # sampling steps
    remasking: str = "low_confidence"    # remasking strategy


@dataclass
class DistillationConfig:
    """Knowledge distillation from a Gemma 3n E2B teacher."""
    teacher_model_id: str = "google/gemma-3n-E2B-it"
    teacher_quantize: bool = True        # 4-bit quantization for memory
    temperature: float = 2.0             # KD temperature
    alpha_kd: float = 0.5                # weight for KD loss vs. diffusion loss
    alpha_vision_kd: float = 0.3         # weight for vision feature distillation
    kd_top_k: int = 8                    # sparse cross-tokenizer candidate set size
    kd_positions_per_sample: int = 16
    teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"


@dataclass
class TrainingConfig:
    """Full training configuration."""
    # Model
    vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
    projector: ProjectorConfig = field(default_factory=ProjectorConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    distillation: DistillationConfig = field(default_factory=DistillationConfig)

    # Backbone
    diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"

    # Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes)
    learning_rate: float = 1e-4
    vil_learning_rate: float = 2e-6        # lower LR for the vision encoder (LLaDA-V)
    projector_learning_rate: float = 1e-3  # higher LR for the projector (LLaDA-V Stage 1)
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    lr_scheduler: str = "cosine"
    max_seq_len: int = 1024
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 8   # effective batch = 4 * 8 = 32
    num_epochs: int = 3
    bf16: bool = True
    gradient_checkpointing: bool = True

    # Data
    pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain"   # Stage 1: 558K pairs
    finetune_dataset: str = "HuggingFaceM4/the_cauldron"  # Stage 2: rich multimodal
    finetune_dataset_configs: List[str] = field(default_factory=lambda: [
        "ai2d", "vqav2", "aokvqa", "textvqa",
        "docvqa", "chartqa", "textcaps", "screen2words",
    ])

    # Output
    output_dir: str = "./vil-dlm-output"
    hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
    push_to_hub: bool = False

    # Stages
    stage: str = "1"  # "1", "2", "3a", or "3b"


def get_config(stage: str = "1") -> TrainingConfig:
    config = TrainingConfig()
    config.stage = stage
    if stage == "1":
        # Stage 1: train the projector only (ViL frozen, LM frozen)
        config.learning_rate = 1e-3
        config.num_epochs = 1
        config.per_device_train_batch_size = 8
        config.gradient_accumulation_steps = 4
    elif stage == "2":
        # Stage 2: full-model finetune (ViL + projector + LM)
        config.learning_rate = 1e-5
        config.vil_learning_rate = 2e-6
        config.projector_learning_rate = 1e-5
        config.num_epochs = 3
    elif stage in {"3a", "3b"}:
        # Stage 3: sparse cross-tokenizer distillation with the Gemma 3n teacher
        config.learning_rate = 1e-5
        config.num_epochs = 2
        config.distillation.alpha_kd = 0.5
    else:
        raise ValueError(f"Unknown stage {stage!r}; expected '1', '2', '3a', or '3b'")
    return config
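
# Usage sketch (illustrative): sanity-check the staged configs and a few
# derived quantities. Nothing here is required by the training pipeline.
if __name__ == "__main__":
    for s in ("1", "2", "3a"):
        cfg = get_config(s)
        effective_batch = cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps
        print(
            f"stage={cfg.stage:>2}  lr={cfg.learning_rate:g}  "
            f"epochs={cfg.num_epochs}  effective_batch={effective_batch}  "
            f"patches={cfg.vil_encoder.num_patches}"
        )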