omar-ah committed
Commit 4fa9dd4 (verified)
1 Parent(s): 39b477f

Upload model_config.py

Files changed (1)
  1. code/model_config.py +129 -0
code/model_config.py ADDED
@@ -0,0 +1,129 @@
"""
ViL-DLM: Vision xLSTM + Diffusion Language Model
Architecture Configuration
"""

from dataclasses import dataclass, field
from typing import Optional, List


@dataclass
class ViLEncoderConfig:
    """Vision xLSTM (ViL) encoder configuration"""
    img_size: int = 224
    patch_size: int = 16
    in_channels: int = 3
    dim: int = 384               # ViL-S default (23M params)
    depth: int = 24              # Standard ViL depth
    mlstm_dim_mult: int = 2      # mLSTM internal dim = 2 * dim
    conv_kernel_size: int = 3    # QK Conv2D kernel
    bidirectional: bool = True   # alternating scan directions
    dropout: float = 0.0

    @property
    def num_patches(self):
        return (self.img_size // self.patch_size) ** 2  # 196 for 224/16

    @property
    def num_params_approx(self):
        # Rough estimate: each mLSTM block has ~4 * dim * (2*dim) params for QKV + gates
        per_block = 4 * self.dim * (self.mlstm_dim_mult * self.dim) + self.dim * self.dim * 4
        return self.depth * per_block


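# Illustrative sketch, not referenced elsewhere in this file: spells out the patch
# arithmetic implied by the defaults above.
def vil_output_shape_example(cfg: ViLEncoderConfig) -> tuple:
    """With the defaults, (224 // 16) ** 2 = 196 patch tokens of width dim=384,
    so the encoder hands the projector a (196, 384) feature sequence per image."""
    return (cfg.num_patches, cfg.dim)

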
@dataclass
class ProjectorConfig:
    """MLP projector: maps ViL features to LM embedding space"""
    vil_dim: int = 384      # ViL-S output dim
    lm_dim: int = 1024      # Qwen3-0.6B hidden_size
    hidden_mult: int = 2    # projector hidden = lm_dim * hidden_mult
    num_layers: int = 2     # 2-layer MLP (LaViDa/LLaDA-V standard)
    activation: str = "gelu"
    dropout: float = 0.0


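# Illustrative sketch only: one way ProjectorConfig could be realized as a module,
# assuming PyTorch and the default num_layers=2 / activation="gelu". This helper is
# hypothetical and not part of the training code.
def build_projector_example(cfg: ProjectorConfig):
    import torch.nn as nn  # local import keeps the config module framework-free
    hidden = cfg.lm_dim * cfg.hidden_mult
    return nn.Sequential(
        nn.Linear(cfg.vil_dim, hidden),
        nn.GELU(),
        nn.Dropout(cfg.dropout),
        nn.Linear(hidden, cfg.lm_dim),
    )

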
@dataclass
class DiffusionConfig:
    """Masked diffusion (MDLM) training configuration"""
    noise_schedule: str = "cosine"      # cosine schedule (MDLM default)
    mask_token_id: int = 151643         # Qwen3 pad/mask token
    num_diffusion_steps: int = 1000     # training steps
    inference_steps: int = 128          # sampling steps
    remasking: str = "low_confidence"   # remasking strategy


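# Illustrative sketch only: one common form of a cosine masking schedule, shown to
# make noise_schedule="cosine" concrete; the exact schedule used by the MDLM
# backbone may differ.
def cosine_mask_rate_example(t: float) -> float:
    """Fraction of tokens to mask at normalized time t in [0, 1]:
    0 at t=0 (clean sequence), rising to 1 at t=1 (fully masked)."""
    import math
    return 1.0 - math.cos(math.pi * t / 2.0)

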
@dataclass
class DistillationConfig:
    """Knowledge distillation from Gemma 4 E2B teacher"""
    teacher_model_id: str = "google/gemma-4-E2B-it"
    teacher_quantize: bool = True   # 4-bit quantization for memory
    temperature: float = 2.0        # KD temperature
    alpha_kd: float = 0.5           # weight for KD loss vs diffusion loss
    alpha_vision_kd: float = 0.3    # weight for vision feature distillation
    top_k_logits: int = 32          # LFM2-style top-K distillation


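# Illustrative sketch only: one plausible reading of how the alpha_* weights above
# combine the loss terms; the actual training loop may compose them differently.
def combined_loss_example(diffusion_loss: float, kd_loss: float,
                          vision_kd_loss: float, cfg: DistillationConfig) -> float:
    return ((1.0 - cfg.alpha_kd) * diffusion_loss
            + cfg.alpha_kd * kd_loss
            + cfg.alpha_vision_kd * vision_kd_loss)

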
@dataclass
class TrainingConfig:
    """Full training configuration"""
    # Model
    vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
    projector: ProjectorConfig = field(default_factory=ProjectorConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    distillation: DistillationConfig = field(default_factory=DistillationConfig)

    # Backbone
    diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"

    # Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes)
    learning_rate: float = 1e-4
    vil_learning_rate: float = 2e-6         # lower LR for vision encoder (LLaDA-V)
    projector_learning_rate: float = 1e-3   # higher LR for projector (LLaDA-V Stage 1)
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    lr_scheduler: str = "cosine"

    max_seq_len: int = 1024
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 8    # effective batch = 32
    num_epochs: int = 3

    bf16: bool = True
    gradient_checkpointing: bool = True

    # Data
    pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain"    # Stage 1: 558K
    finetune_dataset: str = "HuggingFaceM4/the_cauldron"   # Stage 2: rich multimodal

    # Output
    output_dir: str = "./vil-dlm-output"
    hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
    push_to_hub: bool = True

    # Stages
    stage: int = 1   # 1 = projector only, 2 = full finetune, 3 = + distillation


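# Illustrative sketch only: how the three learning rates above could be mapped to
# optimizer parameter groups. The attribute names model.vil_encoder / model.projector
# / model.lm are assumptions, not part of this repo.
def param_groups_example(model, cfg: TrainingConfig) -> list:
    return [
        {"params": model.vil_encoder.parameters(), "lr": cfg.vil_learning_rate},
        {"params": model.projector.parameters(), "lr": cfg.projector_learning_rate},
        {"params": model.lm.parameters(), "lr": cfg.learning_rate},
    ]

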
def get_config(stage: int = 1) -> TrainingConfig:
    config = TrainingConfig()
    config.stage = stage

    if stage == 1:
        # Stage 1: Train projector only (ViL frozen, LM frozen)
        config.learning_rate = 1e-3
        config.num_epochs = 1
        config.per_device_train_batch_size = 8
        config.gradient_accumulation_steps = 4
    elif stage == 2:
        # Stage 2: Full model finetune (ViL + projector + LM)
        config.learning_rate = 1e-5
        config.vil_learning_rate = 2e-6
        config.projector_learning_rate = 1e-5
        config.num_epochs = 3
    elif stage == 3:
        # Stage 3: + Distillation from Gemma 4
        config.learning_rate = 1e-5
        config.num_epochs = 2
        config.distillation.alpha_kd = 0.5

    return config
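

# Minimal usage sketch: print the stage-specific hyperparameters for a quick sanity
# check. The training scripts that consume this config are not part of this file.
if __name__ == "__main__":
    for s in (1, 2, 3):
        cfg = get_config(stage=s)
        print(f"stage={s} lr={cfg.learning_rate} epochs={cfg.num_epochs} "
              f"per_device_bs={cfg.per_device_train_batch_size}")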