"""
ViL-DLM: Vision xLSTM + Diffusion Language Model
Architecture Configuration
"""

from dataclasses import dataclass, field
from typing import List


@dataclass
class ViLEncoderConfig:
    """Vision xLSTM (ViL) encoder configuration"""
    vision_backbone: str = "vil2-small"
    pretrained: bool = True
    img_size: int = 224
    patch_size: int = 16
    in_channels: int = 3
    dim: int = 384          # patch feature dim for vil-small / vil2-small
    depth: int = 12         # VisionLSTM2 block-pairs; v1 vil-small internally uses 24
    mlstm_dim_mult: int = 2  # mLSTM internal dim = 2 * dim
    conv_kernel_size: int = 3  # QK Conv2D kernel
    bidirectional: bool = True  # alternating scan directions
    dropout: float = 0.0
    
    @property
    def num_patches(self):
        return (self.img_size // self.patch_size) ** 2  # 196 for 224/16
    
    @property  
    def num_params_approx(self):
        # Rough estimate: each mLSTM block has ~4 * dim * (2*dim) params for QKV + gates
        per_block = 4 * self.dim * (self.mlstm_dim_mult * self.dim) + self.dim * self.dim * 4
        return self.depth * per_block
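

# Minimal sketch of the patchify step implied by the fields above, assuming
# torch; illustrative only, not the real ViL stem (which also adds position
# embeddings and the alternating-direction mLSTM blocks).
def patchify_sketch(images, cfg: ViLEncoderConfig):
    import torch.nn as nn

    proj = nn.Conv2d(cfg.in_channels, cfg.dim,
                     kernel_size=cfg.patch_size, stride=cfg.patch_size)
    x = proj(images)                     # (B, dim, 14, 14) for 224/16 inputs
    return x.flatten(2).transpose(1, 2)  # (B, num_patches, dim)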


@dataclass
class ProjectorConfig:
    """MLP projector: maps ViL features to LM embedding space"""
    vil_dim: int = 384         # ViL-S output dim
    lm_dim: int = 1024         # Qwen3-0.6B hidden_size
    hidden_mult: int = 2       # projector hidden = lm_dim * hidden_mult
    num_layers: int = 2        # 2-layer MLP (LaViDa/LLaDA-V standard)
    activation: str = "gelu"
    dropout: float = 0.0
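

# Minimal sketch of how ProjectorConfig might be consumed, assuming a
# PyTorch runtime; build_projector is illustrative, not the repo's actual
# builder. With the default num_layers=2 the loop never runs, giving
# Linear -> GELU -> Linear, the LaViDa/LLaDA-V-style 2-layer MLP.
def build_projector(cfg: ProjectorConfig):
    import torch.nn as nn  # local import keeps this config module torch-free

    act = nn.GELU if cfg.activation == "gelu" else nn.SiLU  # assumed mapping
    hidden = cfg.lm_dim * cfg.hidden_mult
    layers = [nn.Linear(cfg.vil_dim, hidden), act()]
    for _ in range(cfg.num_layers - 2):  # extra hidden layers, if any
        layers += [nn.Linear(hidden, hidden), act()]
    layers.append(nn.Linear(hidden, cfg.lm_dim))
    if cfg.dropout > 0:
        layers.append(nn.Dropout(cfg.dropout))
    return nn.Sequential(*layers)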


@dataclass
class DiffusionConfig:
    """Masked diffusion (MDLM) training configuration"""
    noise_schedule: str = "cosine"  # cosine schedule (MDLM default)
    mask_token_id: int = 151643     # Qwen3 pad/mask token
    num_diffusion_steps: int = 1000  # training steps
    inference_steps: int = 128       # sampling steps
    remasking: str = "low_confidence"  # remasking strategy
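

# Illustrative sketch of the forward (noising) step this config drives,
# assuming torch and an alpha_t = cos(pi * t / 2) cosine schedule; the exact
# schedule of the referenced checkpoint may differ. Each token is masked
# independently with probability 1 - alpha_t, and the diffusion loss is
# computed only on the masked positions.
def mask_tokens_sketch(input_ids, cfg: DiffusionConfig, t: float):
    import math

    import torch

    if cfg.noise_schedule == "cosine":
        p_mask = 1.0 - math.cos(math.pi * t / 2)
    else:  # assumed linear fallback
        p_mask = t
    mask = torch.rand(input_ids.shape, device=input_ids.device) < p_mask
    return input_ids.masked_fill(mask, cfg.mask_token_id), mask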


@dataclass
class DistillationConfig:
    """Knowledge distillation from Gemma 4 E2B teacher"""
    teacher_model_id: str = "google/gemma-4-E2B-it"
    teacher_quantize: bool = True  # 4-bit quantization for memory
    temperature: float = 2.0       # KD temperature
    alpha_kd: float = 0.5          # weight for KD loss vs diffusion loss
    alpha_vision_kd: float = 0.3   # weight for vision feature distillation
    kd_top_k: int = 8              # sparse cross-tokenizer candidate set size
    kd_positions_per_sample: int = 16
    teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"
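

# Sketch of how the weights above combine losses, assuming torch and that
# student/teacher logits were already gathered over the shared kd_top_k
# candidate set at kd_positions_per_sample positions; illustrative only.
# (alpha_vision_kd would weight a separate feature-matching term, omitted.)
def kd_loss_sketch(diffusion_loss, student_logits, teacher_logits,
                   cfg: DistillationConfig):
    import torch.nn.functional as F

    T = cfg.temperature
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # standard temperature-squared KD scaling
    return cfg.alpha_kd * kd + (1.0 - cfg.alpha_kd) * diffusion_loss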


@dataclass
class TrainingConfig:
    """Full training configuration"""
    # Model
    vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
    projector: ProjectorConfig = field(default_factory=ProjectorConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    distillation: DistillationConfig = field(default_factory=DistillationConfig)
    
    # Backbone
    diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"
    
    # Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes)
    learning_rate: float = 1e-4
    vil_learning_rate: float = 2e-6      # lower LR for vision encoder (LLaDA-V)
    projector_learning_rate: float = 1e-3  # higher LR for projector (LLaDA-V Stage 1)
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    lr_scheduler: str = "cosine"
    
    max_seq_len: int = 1024
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 8  # effective batch = 32
    num_epochs: int = 3
    
    bf16: bool = True
    gradient_checkpointing: bool = True
    
    # Data
    pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain"  # Stage 1: 558K
    finetune_dataset: str = "HuggingFaceM4/the_cauldron"  # Stage 2: rich multimodal
    finetune_dataset_configs: List[str] = field(default_factory=lambda: [
        "ai2d",
        "vqav2",
        "aokvqa",
        "textvqa",
        "docvqa",
        "chartqa",
        "textcaps",
        "screen2words",
    ])

    # Output
    output_dir: str = "./vil-dlm-output"
    hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
    push_to_hub: bool = False
    
    # Stages
    stage: str = "1"  # 1, 2, 3a, 3b
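

# Sketch of the three-way learning-rate split implied by the fields above,
# assuming a torch model exposing .vil_encoder / .projector / .lm submodules
# (attribute names are hypothetical); pass the list to the optimizer.
def param_groups_sketch(model, cfg: TrainingConfig):
    return [
        {"params": model.vil_encoder.parameters(), "lr": cfg.vil_learning_rate},
        {"params": model.projector.parameters(), "lr": cfg.projector_learning_rate},
        {"params": model.lm.parameters(), "lr": cfg.learning_rate},
    ]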


def get_config(stage: str = "1") -> TrainingConfig:
    config = TrainingConfig()
    config.stage = stage
    
    if stage == "1":
        # Stage 1: Train projector only (ViL frozen, LM frozen)
        config.learning_rate = 1e-3
        config.num_epochs = 1
        config.per_device_train_batch_size = 8
        config.gradient_accumulation_steps = 4
    elif stage == "2":
        # Stage 2: Full model finetune (ViL + projector + LM)
        config.learning_rate = 1e-5
        config.vil_learning_rate = 2e-6
        config.projector_learning_rate = 1e-5
        config.num_epochs = 3
    elif stage in {"3a", "3b"}:
        # Stage 3: sparse cross-tokenizer distillation with Gemma 4
        config.learning_rate = 1e-5
        config.num_epochs = 2
        config.distillation.alpha_kd = 0.5
    else:
        raise ValueError(f"Unknown stage {stage!r}; expected '1', '2', '3a', or '3b'")
    
    return config
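

# Usage sketch: resolve a stage and inspect the hyperparameters it sets.
if __name__ == "__main__":
    cfg = get_config(stage="2")
    print(f"stage={cfg.stage}  lr={cfg.learning_rate}  "
          f"projector_lr={cfg.projector_learning_rate}  "
          f"patches={cfg.vil_encoder.num_patches}")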