---
license: apache-2.0
tags:
  - vision-language
  - diffusion
  - xlstm
  - vision-lstm
  - masked-diffusion
  - mdlm
  - multimodal
language: en
pipeline_tag: image-text-to-text
---

# ViL-DLM: Vision xLSTM Diffusion Language Model (~660M)

The first vision-language model combining Vision xLSTM (ViL) with a discrete masked diffusion language model backbone.

## Architecture

```
[Image] → ViL-S Encoder (57M) → MLP Projector (7M) → [196 Visual Tokens]
[Visual Tokens] + [Masked Text Tokens] → Bidirectional Diffusion LM (596M) → Denoised Text
```
| Component | Model | Params | Key Innovation |
|---|---|---|---|
| Vision Encoder | Vision-xLSTM-S (ViL-S) | ~57M | O(N) linear complexity, alternating bidirectional mLSTM with Conv2D |
| Projector | 2-layer MLP (GELU) | ~7M | Maps ViL features (384d) → LM space (1024d) |
| Language Backbone | dLLM Qwen3-0.6B (MDLM) | ~596M | Bidirectional masked diffusion, non-autoregressive |
| **Total** | | **~660M** | |
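
The projector is the glue between the two pretrained components. A minimal sketch of the 2-layer GELU MLP from the table above; the class name and the choice of hidden width (here, the LM width) are illustrative assumptions, not the repo's exact module:

```python
import torch.nn as nn

class VisualProjector(nn.Module):
    """Sketch of the 2-layer MLP projector: ViL-S patch features (384-d)
    -> LM embedding space (1024-d), with a GELU in between.
    Hidden width is an assumption."""

    def __init__(self, vision_dim: int = 384, lm_dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(vision_dim, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )

    def forward(self, patch_tokens):
        # (batch, 196, 384) -> (batch, 196, 1024)
        return self.net(patch_tokens)
```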

## How It Works

### Vision Encoder: Vision xLSTM (ViL-S)

- Based on Vision-LSTM (Alkin et al., 2024)
- Processes 224×224 images into 196 patch tokens (16×16 patches)
- Uses mLSTM blocks with matrix memory cells and exponential gating
- Alternating bidirectional scanning: odd blocks scan top-left→bottom-right, even blocks reverse (see the sketch below)
- Conv2D for spatial QK context: a depthwise 3×3 convolution adds local spatial awareness
- SwiGLU FFN after each mLSTM block
- Linear O(N) complexity vs. ViT's quadratic O(N²), which is critical for scaling to many visual tokens
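
To make the alternating scan concrete, here is a toy sketch. A plain `nn.LSTM` stands in for the real mLSTM block (matrix memory, exponential gating, conv QK, SwiGLU FFN); only the direction-flipping logic is the point:

```python
import torch
import torch.nn as nn

class AlternatingScanStack(nn.Module):
    """Toy illustration of ViL's alternating bidirectional scan.
    Blocks 1, 3, 5, ... read the flattened patches top-left -> bottom-right;
    blocks 2, 4, 6, ... read them in reverse. nn.LSTM is a stand-in for
    the real mLSTM block, NOT the actual ViL layer."""

    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.LSTM(dim, dim, batch_first=True) for _ in range(depth)
        )

    def forward(self, x):  # x: (batch, 196, dim) flattened patch tokens
        for i, block in enumerate(self.blocks):
            reverse = i % 2 == 1          # every second block scans backwards
            if reverse:
                x = torch.flip(x, dims=[1])
            x, _ = block(x)               # O(N) recurrent pass over the sequence
            if reverse:
                x = torch.flip(x, dims=[1])  # restore the original patch order
        return x
```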

### Diffusion Language Model: MDLM via dLLM

- Based on dLLM (Berkeley, 2025), which converts Qwen3-0.6B into a diffusion LM
- Training: forward diffusion progressively masks tokens with a cosine schedule; the model learns to predict the masked tokens
- Inference: start from a fully masked output and iteratively unmask the most-confident tokens
- Key change from the AR original: the causal attention mask is replaced with a bidirectional, padding-only mask
- Weighted cross-entropy loss on masked positions only (the MDLM objective); both phases are sketched below
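
A compact sketch of both phases, assuming a model callable as `model(visual_tokens=..., input_ids=...) -> logits`. The mask-token id, the exact cosine-schedule form, and the simplified loss weighting are assumptions rather than the repo's exact code:

```python
import math
import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder; the real mask-token id comes from the tokenizer

def mdlm_training_step(model, input_ids, visual_tokens):
    """Forward diffusion: mask text tokens at a cosine-scheduled rate,
    then score cross-entropy on the masked positions only."""
    t = torch.rand(input_ids.size(0), 1, device=input_ids.device)  # timestep per sample
    mask_prob = 1.0 - torch.cos(0.5 * math.pi * t)                 # assumed schedule form
    is_masked = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob
    corrupted = torch.where(is_masked, torch.full_like(input_ids, MASK_ID), input_ids)

    logits = model(visual_tokens=visual_tokens, input_ids=corrupted)
    # MDLM weights this CE by a schedule-dependent factor; omitted for brevity.
    return F.cross_entropy(logits[is_masked], input_ids[is_masked])

@torch.no_grad()
def mdlm_generate(model, visual_tokens, length, steps=16):
    """Reverse process: start fully masked, repeatedly commit the
    most-confident predictions until nothing is masked."""
    ids = torch.full((1, length), MASK_ID)
    per_step = max(1, length // steps)
    for _ in range(steps):
        logits = model(visual_tokens=visual_tokens, input_ids=ids)
        conf, pred = logits.softmax(-1).max(-1)
        conf = conf.masked_fill(ids != MASK_ID, -1.0)  # only still-masked slots compete
        top = conf.topk(per_step, dim=-1).indices
        ids.scatter_(1, top, pred.gather(1, top))
    return torch.where(ids == MASK_ID, pred, ids)      # commit any leftovers
```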

### Knowledge Distillation (Stage 3)

- Teacher: Gemma 4 E2B (5.1B params, ~2B effective)
- Sparse cross-tokenizer distillation: prepare a teacher-scored candidate bank in the student token space, then blend a sparse KL term with the diffusion loss (see the sketch below)
- Temperature τ=2.0, α_KD=0.5 (50% diffusion loss + 50% KD loss)
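
A sketch of the blended loss, assuming the cached bank stores, per distilled position, the top-K candidate ids (already mapped into the student vocabulary) with their teacher scores; the tensor layout and names are assumptions:

```python
import torch
import torch.nn.functional as F

def sparse_kd_loss(student_logits, cand_ids, cand_teacher_logits,
                   diffusion_loss, alpha_kd=0.5, tau=2.0):
    """Blend sparse-KL distillation with the diffusion loss.
    student_logits:      (P, V) student logits at the P distilled positions
    cand_ids:            (P, K) teacher candidates in student-token space
    cand_teacher_logits: (P, K) cached teacher scores for those candidates"""
    # Restrict the student to the teacher's candidate set; renormalize at tau.
    student_logp = F.log_softmax(
        torch.gather(student_logits, -1, cand_ids) / tau, dim=-1)
    teacher_p = F.softmax(cand_teacher_logits / tau, dim=-1)
    kd = F.kl_div(student_logp, teacher_p, reduction="batchmean") * tau ** 2
    return alpha_kd * kd + (1.0 - alpha_kd) * diffusion_loss
```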

## Training Recipe

Multi-stage training inspired by LLaDA-V, LaViDa, LFM2, and Mistral/Pixtral:

| Stage | Components Trained | Dataset | Learning Rate | Epochs |
|---|---|---|---|---|
| 1 | Projector only (ViL & LM frozen) | LLaVA-Pretrain (558K) | 1e-3 | 1-2 |
| 2 | Full model (all components) | The Cauldron | ViL: 2e-6, Proj: 1e-5, LM: 1e-5 | 3 |
| 3 | + KD from Gemma 4 E2B | Stage 2 data mix + cached teacher bank | Sparse cross-tokenizer KD (α=0.5) | 2 |

### Efficiency Tricks Applied

- Per-component learning rates (LLaDA-V recipe): the vision encoder gets a 5× lower LR (see the sketch below)
- Gradient checkpointing on the LM backbone to reduce VRAM
- Cosine LR schedule with warmup
- Gradient clipping at 1.0
- AdamW optimizer with β=(0.9, 0.999), weight_decay=0.05
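
These tricks combine into an optimizer setup along the following lines. The attribute names (`model.vision_encoder`, `model.projector`, `model.lm`) are assumptions about the module layout:

```python
import torch

def build_optimizer(model, lr_vision=2e-6, lr_proj=1e-5, lr_lm=1e-5):
    """Per-component LRs (Stage 2 values from the table above) with the
    AdamW settings listed in this section."""
    groups = [
        {"params": model.vision_encoder.parameters(), "lr": lr_vision},  # 5x lower
        {"params": model.projector.parameters(), "lr": lr_proj},
        {"params": model.lm.parameters(), "lr": lr_lm},
    ]
    return torch.optim.AdamW(groups, betas=(0.9, 0.999), weight_decay=0.05)

# Each step, clip before optimizer.step():
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# For the schedule, transformers.get_cosine_schedule_with_warmup is one option.
```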

## Why This Combination Matters

This is a genuinely unexplored frontier in the literature:

  1. No published work combines Vision xLSTM with a diffusion language model
  2. ViL's linear complexity could be transformative for processing large numbers of visual tokens in multimodal diffusion models, where current Transformer-based approaches incur quadratic attention costs
  3. The bidirectional nature of both ViL (alternating scan) and diffusion LM (full attention) creates natural synergy — both architectures process information non-autoregressively
  4. Distillation from Gemma 4 E2B bridges the gap between the small student and a state-of-the-art multimodal teacher

## Running Training

```bash
# CPU smoke: Stage 1 projector path
python code/train_production.py --stage 1 --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 1 --dry_run_batches 1

# CPU smoke: Stage 2 subset path
python code/train_production.py --stage 2 --resume_from ./vil-dlm-output/stage1_best --dataset_configs ai2d,aokvqa --epochs 1 --batch_size 1 --grad_accum 1 --num_workers 0 --max_samples 8 --dry_run_batches 1

# Stage 1: projector-only alignment
python code/train_production.py --stage 1 --require_cuda --epochs 1 --batch_size 8 --grad_accum 4

# Stage 2: full-model finetune on the balanced Cauldron mix
python code/train_production.py --stage 2 --require_cuda --epochs 3 --batch_size 2 --grad_accum 16

# Stage 3a: build the Gemma teacher candidate bank from a Stage 2 checkpoint (GPU only)
python code/train_production.py --stage 3a --require_cuda --resume_from ./vil-dlm-output/stage2_best --prepare_teacher_bank --teacher_batch_size 2 --kd_top_k 16 --kd_positions_per_sample 16 --kd_temperature 1.0

# Stage 3b: timestep-aware sparse KD training from the cached teacher bank (GPU only)
python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --epochs 2 --batch_size 2 --grad_accum 16 --alpha_kd 0.5 --kd_temperature 1.0

# Cheap validation gate for any stage
python code/train_production.py --stage 1 --require_cuda --dry_run_batches 1 --max_samples 8

# Cheap Stage 3 KD-consumption gate after Stage 3a wrote a teacher bank
python code/train_production.py --stage 3b --require_cuda --resume_from ./vil-dlm-output/stage2_best --teacher_cache_dir ./vil-dlm-output/teacher-cache --epochs 1 --batch_size 1 --grad_accum 1 --dry_run_batches 1 --alpha_kd 0.5 --kd_temperature 1.0
```

Training saves checkpoints locally by default; add `--push_to_hub` only when you want to publish artifacts. Stage 3b uses timestep-aware sparse KD by default; pass `--no-kd_timestep_weighting` only to reproduce the pilot's fixed-alpha KD behavior. CPU sessions should stop after the Stage 2 subset smoke test. Stage 3 requires a CUDA GPU because Gemma 4 teacher-bank preparation uses quantized multimodal teacher inference.
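
For intuition only: timestep-aware weighting makes the KD coefficient a function of the diffusion timestep rather than a fixed α. The actual schedule is defined in `code/train_production.py`; the linear ramp below is a purely hypothetical illustration:

```python
def kd_weight(t: float, alpha_kd: float = 0.5) -> float:
    """HYPOTHETICAL ramp: lean on the teacher more at high mask ratios
    (large t), where the student sees little context. Not the repo's schedule."""
    return alpha_kd * t  # t in [0, 1]
```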

## Hardware Requirements

- Stage 1: A10G (24GB) or T4 (16GB); only projector gradients (~7M params)
- Stage 2: A10G (24GB) recommended; full model gradients (~660M params)
- Stage 3: H100 / A100 (80GB) recommended; Gemma 4 teacher bank prep + student distillation

## Dependencies

```
torch>=2.0
transformers>=5.5.0
datasets
trackio
accelerate
einops
pillow
huggingface_hub
```

## Pretrained Components Used

| Component | Source | License |
|---|---|---|
| Diffusion LM backbone | dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1 | Apache 2.0 |
| Teacher (Stage 3) | google/gemma-4-E2B-it | Apache 2.0 |
| Pretraining data | LLaVA-Pretrain | CC BY 4.0 |
| Instruction data | The Cauldron | Mixed |

## Full Literature Context

This model was designed based on a comprehensive literature review covering:

- Language Diffusion Models: MDLM, LLaDA, Plaid, BD3-LM, EDLM, ADLM
- Vision-Language Diffusion Models: LaViDa, LLaDA-V, MMaDA, Muddit, LMFusion, DEEM
- Vision xLSTM: ViL (arXiv:2406.04303)
- Knowledge Distillation for DLMs: TCS distillation, DiffuGPT/DiffuLLaMA, dLLM
- JEPA family: I-JEPA, D-JEPA, VL-JEPA, LLM-JEPA
- Efficiency recipes: LFM2 (Liquid AI), Mistral/Pixtral, Minitron

## References

| Paper | arXiv | Role |
|---|---|---|
| Vision-LSTM (ViL) | 2406.04303 | Vision encoder architecture |
| dLLM | 2602.22661 | Diffusion LM backbone |
| MDLM | 2406.07524 | Core diffusion objective |
| LLaDA-V | 2505.16933 | VLM training recipe |
| LaViDa | — | Complementary masking, prefix KV cache |
| LFM2 | 2511.23404 | Top-K distillation recipe |
| Gemma 4 | (blog) | Teacher model |
| Pixtral | 2410.07073 | RoPE-2D, variable-res ViT |
| Minitron | 2407.14679 | Prune+distill best practices |