---
license: apache-2.0
tags:
- vision-language
- diffusion
- xlstm
- vision-lstm
- masked-diffusion
- mdlm
- multimodal
language: en
pipeline_tag: image-text-to-text
---
# ViL-DLM: Vision xLSTM Diffusion Language Model (~660M)
**The first vision-language model combining Vision xLSTM (ViL) with a discrete masked diffusion language model backbone.**
## Architecture
```
[Image] → ViL-S Encoder (57M) → MLP Projector (7M) → [196 Visual Tokens]
[Visual Tokens] + [Masked Text Tokens] → Bidirectional Diffusion LM (596M) → Denoised Text
```
| Component | Model | Params | Key Innovation |
|-----------|-------|--------|----------------|
| Vision Encoder | **Vision-xLSTM-S (ViL-S)** | ~57M | O(N) linear complexity, alternating bidirectional mLSTM with Conv2D |
| Projector | 2-layer MLP (GELU) | ~7M | Maps ViL features (384d) → LM space (1024d) |
| Language Backbone | **dLLM Qwen3-0.6B (MDLM)** | ~596M | Bidirectional masked diffusion, non-autoregressive |
| **Total** | | **~660M** | |
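The three components compose as a simple pipeline: encode the image, project the 196 visual tokens into the LM hidden space, and prepend them to the partially masked text embeddings before the diffusion LM denoises. The sketch below is illustrative only; the module and function names (`MLPProjector`, `vlm_forward`) and the `inputs_embeds` keyword signature are assumptions, while the dimensions (384 → 1024, 196 tokens) come from the table above.

```python
import torch
import torch.nn as nn

class MLPProjector(nn.Module):
    """2-layer GELU MLP mapping ViL features (384d) into the LM hidden space (1024d)."""
    def __init__(self, vision_dim: int = 384, lm_dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(vision_dim, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )

    def forward(self, vision_feats: torch.Tensor) -> torch.Tensor:
        # vision_feats: (B, 196, 384) patch features from the ViL-S encoder
        return self.net(vision_feats)  # (B, 196, 1024)

def vlm_forward(vil_encoder, projector, diffusion_lm, pixel_values, noisy_text_embeds):
    """Image -> 196 visual tokens -> prepend to masked text embeddings -> denoise."""
    vis_tokens = projector(vil_encoder(pixel_values))             # (B, 196, 1024)
    lm_input = torch.cat([vis_tokens, noisy_text_embeds], dim=1)  # (B, 196 + T, 1024)
    return diffusion_lm(inputs_embeds=lm_input)                   # logits for every position
```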
## How It Works
### Vision Encoder: Vision xLSTM (ViL-S)
- Based on [Vision-LSTM](https://arxiv.org/abs/2406.04303) (Alkin et al., 2024)
- Processes 224×224 images into 196 patch tokens (16×16 patches)
- Uses **mLSTM blocks** with matrix memory cells and exponential gating
- **Alternating bidirectional scanning**: odd blocks scan top-left→bottom-right, even blocks reverse
- **Conv2D for spatial QK context**: depthwise 3×3 convolution adds local spatial awareness
- **SwiGLU FFN** after each mLSTM block
- Linear O(N) complexity vs ViT's quadratic O(N²), which is critical for scaling to many visual tokens (a simplified block sketch follows below)
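The block sketch below only illustrates the alternating scan direction, the depthwise 3×3 convolution, and the SwiGLU FFN. It is a simplification: in the actual vision-lstm implementation the convolution is applied inside the mLSTM cell on the query/key paths, and the matrix-memory mLSTM with exponential gating is passed here as the placeholder `mlstm_layer`.

```python
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    """SwiGLU feed-forward applied after each mLSTM block."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w_down(nn.functional.silu(self.w_gate(x)) * self.w_up(x))

class ViLBlockSketch(nn.Module):
    """One simplified ViL block: depthwise 3x3 conv for local spatial context,
    an mLSTM scan over the patch sequence, then a SwiGLU FFN."""
    def __init__(self, dim: int, mlstm_layer: nn.Module, reverse: bool):
        super().__init__()
        self.reverse = reverse  # even-indexed blocks scan the sequence in reverse
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)  # depthwise 3x3
        self.mlstm = mlstm_layer  # placeholder for the matrix-memory mLSTM
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = SwiGLU(dim, 4 * dim)

    def forward(self, x, h: int = 14, w: int = 14):
        # x: (B, N, D) with N = h * w patch tokens (196 for a 224x224 image, 16x16 patches)
        B, N, D = x.shape
        spatial = x.transpose(1, 2).reshape(B, D, h, w)
        x = x + self.conv(spatial).flatten(2).transpose(1, 2)  # local spatial mixing
        seq = x.flip(1) if self.reverse else x                 # alternating scan direction
        seq = seq + self.mlstm(self.norm(seq))                 # linear-time recurrent mixing
        x = seq.flip(1) if self.reverse else seq
        return x + self.ffn(self.ffn_norm(x))
```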
### Diffusion Language Model: MDLM via dLLM
- Based on [dLLM](https://arxiv.org/abs/2602.22661) (Berkeley, 2025), which converts Qwen3-0.6B into a masked diffusion LM
- **Training**: Forward diffusion progressively masks tokens with cosine schedule → model predicts masked tokens
- **Inference**: Start with all-masked output → iteratively unmask most-confident tokens
- **Key change from AR**: replaces the causal attention mask with a bidirectional, padding-only mask
- Weighted cross-entropy loss on masked positions only (MDLM objective); see the training/inference sketch below
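A minimal sketch of both sides of the objective, assuming an HF-style `model(input_ids, attention_mask=...).logits` interface and a placeholder `MASK_ID`; the exact cosine schedule and loss weighting in `train_production.py` may differ.

```python
import math
import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder: use the actual [MASK] token id from the dLLM tokenizer

def mdlm_training_step(model, input_ids, attention_mask):
    """One MDLM step: mask tokens with a cosine schedule, predict them, weight the CE loss."""
    B, T = input_ids.shape
    t = torch.rand(B, device=input_ids.device)        # one diffusion timestep per sample
    mask_prob = 1.0 - torch.cos(0.5 * math.pi * t)    # cosine masking schedule (assumed form)
    is_masked = (torch.rand(B, T, device=input_ids.device) < mask_prob[:, None]) & attention_mask.bool()
    noisy_ids = torch.where(is_masked, torch.full_like(input_ids, MASK_ID), input_ids)
    logits = model(noisy_ids, attention_mask=attention_mask).logits              # (B, T, V)
    ce = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction="none")    # (B, T)
    # only masked positions contribute, reweighted by the masking probability (MDLM objective)
    weight = is_masked.float() / mask_prob[:, None].clamp(min=1e-4)
    return (ce * weight).sum() / is_masked.sum().clamp(min=1)

@torch.no_grad()
def mdlm_generate(model, prompt_ids, answer_len: int, steps: int = 16):
    """Start from an all-masked answer and iteratively commit the most confident predictions."""
    # prompt_ids: (1, P) token ids for the prompt
    masks = torch.full((prompt_ids.size(0), answer_len), MASK_ID,
                       device=prompt_ids.device, dtype=prompt_ids.dtype)
    ids = torch.cat([prompt_ids, masks], dim=1)
    for step in range(steps):
        still_masked = ids == MASK_ID
        if not still_masked.any():
            break
        probs, preds = model(ids).logits.softmax(-1).max(-1)
        k = max(1, int(still_masked.sum().item() / (steps - step)))  # unmask a fraction per step
        conf = torch.where(still_masked, probs, torch.full_like(probs, -1.0))
        top = conf.topk(k, dim=-1).indices
        ids.scatter_(1, top, preds.gather(1, top))
    return ids
```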
### Knowledge Distillation (Stage 3)
- Teacher: [Gemma 4 E2B](https://huggingface.co/google/gemma-4-E2B-it) (5.1B params, ~2B effective)
- **Sparse cross-tokenizer distillation**: a teacher-scored candidate bank is prepared in the student's token space, then a sparse KL term is blended with the diffusion loss
- Temperature τ=2.0, α_KD=0.5 (50% diffusion loss + 50% KD loss); see the loss sketch after this list
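A sketch of how such a sparse KD term can be blended with the diffusion loss is shown below. The tensor layout of the cached teacher bank, the renormalization over the K candidates, and the function names are assumptions; τ and α_KD use the values above.

```python
import torch
import torch.nn.functional as F

def sparse_kd_loss(student_logits, cand_ids, cand_teacher_logprobs, kd_mask, tau: float = 2.0):
    """Sparse KL against a cached teacher candidate bank (sketch; layout is an assumption).

    student_logits:        (B, T, V) student predictions
    cand_ids:              (B, T, K) top-K candidate ids, already mapped into the student vocab
    cand_teacher_logprobs: (B, T, K) cached teacher log-probs for those candidates
    kd_mask:               (B, T)    positions where the teacher bank has entries
    """
    student_logprobs = F.log_softmax(student_logits / tau, dim=-1)
    student_on_cands = student_logprobs.gather(-1, cand_ids)        # (B, T, K)
    teacher_probs = F.softmax(cand_teacher_logprobs / tau, dim=-1)  # renormalize over K candidates
    kl = (teacher_probs * (teacher_probs.clamp_min(1e-8).log() - student_on_cands)).sum(-1)
    return (kl * kd_mask).sum() / kd_mask.sum().clamp(min=1) * tau**2

def stage3_loss(diffusion_loss, kd_loss, alpha_kd: float = 0.5):
    """Blend: 50% MDLM diffusion loss + 50% sparse KD loss."""
    return (1 - alpha_kd) * diffusion_loss + alpha_kd * kd_loss
```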
## Training Recipe
Multi-stage training inspired by LLaDA-V, LaViDa, LFM2, and Mistral/Pixtral:
| Stage | Components Trained | Dataset | Learning Rate | Epochs |
|-------|-------------------|---------|---------------|--------|
| 1 | Projector only (ViL & LM frozen) | [LLaVA-Pretrain](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain) (558K) | 1e-3 | 1-2 |
| 2 | Full model (all components) | [The Cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) | ViL:2e-6, Proj:1e-5, LM:1e-5 | 3 |
| 3 | + KD from Gemma 4 E2B | Stage 2 data mix + cached teacher bank | Sparse cross-tokenizer KD (α=0.5) | 2 |
### Efficiency Tricks Applied
- **Per-component learning rates** (LLaDA-V recipe): vision encoder gets 5× lower LR
- **Gradient checkpointing** on LM backbone to reduce VRAM
- **Cosine LR schedule** with warmup
- **Gradient clipping** at 1.0
- AdamW optimizer with β=(0.9, 0.999), weight_decay=0.05 (see the optimizer sketch below)
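A minimal optimizer sketch with the Stage 2 values from the table above; the attribute names (`vision_encoder`, `projector`, `language_model`) are illustrative and do not necessarily match `train_production.py`.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

def build_optimizer(model, total_steps: int, warmup_steps: int = 100,
                    lr_vision: float = 2e-6, lr_proj: float = 1e-5, lr_lm: float = 1e-5):
    """Per-component learning rates (Stage 2 values), AdamW, and a cosine schedule with warmup."""
    param_groups = [
        {"params": model.vision_encoder.parameters(), "lr": lr_vision},  # 5x lower LR for ViL
        {"params": model.projector.parameters(),      "lr": lr_proj},
        {"params": model.language_model.parameters(), "lr": lr_lm},
    ]
    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999), weight_decay=0.05)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    return optimizer, scheduler

# Per step, clip gradients before the optimizer update:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```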
## Why This Combination Matters
This is a genuinely **unexplored frontier** in the literature:
1. **No published work** combines Vision xLSTM with a diffusion language model
2. ViL's **linear complexity** could be transformative for processing large numbers of visual tokens in multimodal diffusion models, where current Transformer-based approaches incur quadratic attention costs
3. The **bidirectional nature** of both ViL (alternating scan) and diffusion LM (full attention) creates natural synergy — both architectures process information non-autoregressively
4. **Distillation from Gemma 4 E2B** bridges the gap between the small student and a state-of-the-art multimodal teacher
## Running Training
```bash
# CPU smoke: Stage 1 projector path
python code/train_production.py --stage 1 --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 1 --dry_run_batches 1
# CPU smoke: Stage 2 subset path
python code/train_production.py --stage 2 --resume_from ./vil-dlm-output/stage1_best --dataset_configs ai2d,aokvqa --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 8 --dry_run_batches 1
# Stage 1: projector-only alignment
python code/train_production.py --stage 1 --require_cuda --epochs 1 --batch_size 8 --grad_accum 4
# Stage 2: full-model finetune on the balanced Cauldron mix
python code/train_production.py --stage 2 --require_cuda --epochs 3 --batch_size 2 --grad_accum 16
# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint (GPU only)
python code/train_production.py --stage 3a --require_cuda --resume_from ./vil-dlm-output/stage2_best --prepare_teacher_bank --teacher_batch_size 2 --kd_top_k 16 --kd_positions_per_sample 16 --kd_temperature 1.0
# Stage 3b: timestep-aware sparse KD training from the cached teacher bank (GPU only)
python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --epochs 2 --batch_size 2 --grad_accum 16 --alpha_kd 0.5 --kd_temperature 1.0
# Cheap validation gate for any stage
python code/train_production.py --stage 1 --require_cuda --dry_run_batches 1 --max_samples 8
# Cheap Stage 3 KD-consumption gate after Stage 3a wrote a teacher bank
python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --teacher_cache_dir ./vil-dlm-output/teacher-cache --epochs 1 --batch_size 1 --grad_accum 1 --dry_run_batches 1 --alpha_kd 0.5 --kd_temperature 1.0
```
Training now saves checkpoints locally by default. Add `--push_to_hub` only when you want to publish artifacts.
Stage 3b uses timestep-aware sparse KD by default; pass `--no-kd_timestep_weighting` only to reproduce the pilot's fixed-alpha KD behavior.
CPU sessions should stop after the Stage 2 subset smoke test. Stage 3 requires a CUDA GPU because Gemma 4 teacher-bank preparation uses quantized multimodal teacher inference.
### Hardware Requirements
- **Stage 1**: A10G (24GB) or T4 (16GB) — only projector gradients (~7M params)
- **Stage 2**: A10G (24GB) recommended — full model gradients (~660M params)
- **Stage 3**: H100 / A100 (80GB) recommended — Gemma 4 teacher bank prep + student distillation
### Dependencies
```
torch>=2.0
transformers>=5.5.0
datasets
trackio
accelerate
einops
pillow
huggingface_hub
```
## Pretrained Components Used
| Component | Source | License |
|-----------|--------|---------|
| Diffusion LM backbone | [dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1](https://huggingface.co/dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1) | Apache 2.0 |
| Teacher (Stage 3) | [google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it) | Apache 2.0 |
| Pretraining data | [LLaVA-Pretrain](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain) | CC BY 4.0 |
| Instruction data | [The Cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) | Mixed |
## Full Literature Context
This model was designed based on a comprehensive literature review covering:
- **Language Diffusion Models**: MDLM, LLaDA, Plaid, BD3-LM, EDLM, ADLM
- **Vision-Language Diffusion Models**: LaViDa, LLaDA-V, MMaDA, Muddit, LMFusion, DEEM
- **Vision xLSTM**: ViL (arxiv:2406.04303)
- **Knowledge Distillation for DLMs**: TCS distillation, DiffuGPT/DiffuLLaMA, dLLM
- **JEPA family**: I-JEPA, D-JEPA, VL-JEPA, LLM-JEPA
- **Efficiency recipes**: LFM2 (Liquid AI), Mistral/Pixtral, Minitron
## References
| Paper | arxiv | Role |
|-------|-------|------|
| Vision-LSTM (ViL) | [2406.04303](https://arxiv.org/abs/2406.04303) | Vision encoder architecture |
| dLLM | [2602.22661](https://arxiv.org/abs/2602.22661) | Diffusion LM backbone |
| MDLM | [2406.07524](https://arxiv.org/abs/2406.07524) | Core diffusion objective |
| LLaDA-V | [2505.16933](https://arxiv.org/abs/2505.16933) | VLM training recipe |
| LaViDa | — | Complementary masking, prefix KV cache |
| LFM2 | [2511.23404](https://arxiv.org/abs/2511.23404) | Top-K distillation recipe |
| Gemma 4 | [blog](https://huggingface.co/blog/gemma4) | Teacher model |
| Pixtral | [2410.07073](https://arxiv.org/abs/2410.07073) | RoPE-2D, variable-res ViT |
| Minitron | [2407.14679](https://arxiv.org/abs/2407.14679) | Prune+distill best practices |