---
datasets:
- imagenet-1k
library_name: pytorch
license: mit
pipeline_tag: unconditional-image-generation
tags:
- diffusion
- flow-matching
- representation-learning
- dinov2
- imagenet
---

# RiT-XL: Vanilla Diffusion Transformers Suffice in Representation Space

This repository hosts the released **RiT-XL** checkpoint. It was trained for 800 epochs on ImageNet 256×256 in the feature space of a frozen DINOv2-Small encoder.

RiT (Representation Image Transformer) is a vanilla Diffusion Transformer that models distributions directly in high-dimensional representation spaces, as presented in the paper [RiT: Vanilla Diffusion Transformers Suffice in Representation Space](https://huggingface.co/papers/2605.21981).

[Code](https://github.com/lezhang7/RiT) · [Paper](https://huggingface.co/papers/2605.21981)

## Results on ImageNet 256×256

| Method                   | Encoder        | Params | FID ↓ (CFG=1)  | FID ↓ (CFG≈3.7) |
|--------------------------|---------------:|-------:|---------------:|----------------:|
| DiT-XL                   | SD-VAE         | 675M   | 9.62           | 2.27            |
| SiT-XL                   | SD-VAE         | 675M   | 8.61           | 2.06            |
| REPA-XL                  | SD-VAE         | 675M   | 5.78           | 1.29            |
| DDT-XL                   | SD-VAE         | 675M   | 6.27           | 1.26            |
| REG-XL                   | SD-VAE         | 675M   | 1.80           | 1.36            |
| RAE-XL                   | DINOv2-S       | 676M   | 1.87           | 1.41            |
| RAE-XL<sup>DH</sup>      | DINOv2-B       | 839M   | 1.51           | 1.16            |
| FAE-XL                   | FAE-DINOv2-G   | 675M   | 1.48           | 1.29            |
| **RiT-XL (ours)**        | **DINOv2-S**   | **676M** | **1.45**     | **1.14**        |

All FIDs are computed with 25 Heun steps and the time-shift schedule.

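The Heun sampler and time-shift schedule used for the FIDs above can be sketched in plain Python. This is a minimal illustration, not the code from lezhang7/RiT, and it rests on several assumptions.

- The shift has the Stable Diffusion 3 form t' = s·t / (1 + (s - 1)·t), where t = 1 is pure noise and t = 0 is clean data.
- The dimension-aware rule is s = sqrt(m / n), with data dimension m = tokens × channels and a base dimension n = 4096.
- There are 256 DINOv2-S patch tokens (a 16×16 grid) of width 384.
- The model is wrapped as a velocity function `velocity(x, t)`.

All function names are hypothetical. These choices were picked because they reproduce the s ≈ 4.9 listed under Model details, not because the card confirms them.

```python
import math
from typing import Callable, List, TypeVar

# x may be a float, a NumPy array, or a torch tensor: anything that supports
# addition and multiplication by a Python float.
X = TypeVar("X")


def dimension_aware_shift(num_tokens: int, channels: int, base_dim: int = 4096) -> float:
    """Assumed rule s = sqrt(m / n) with m = num_tokens * channels and n = base_dim."""
    return math.sqrt(num_tokens * channels / base_dim)


def shift_time(t: float, s: float) -> float:
    """Warp t in [0, 1]; s > 1 concentrates more of the steps near t = 1 (noise)."""
    return s * t / (1.0 + (s - 1.0) * t)


def shifted_schedule(num_steps: int, s: float) -> List[float]:
    """num_steps + 1 decreasing times, from 1.0 (pure noise) down to 0.0 (data)."""
    return [shift_time(1.0 - i / num_steps, s) for i in range(num_steps + 1)]


def heun_sample(velocity: Callable[[X, float], X], x_noise: X,
                num_steps: int = 25, s: float = 4.9) -> X:
    """Integrate dx/dt = velocity(x, t) from t = 1 to t = 0 with Heun's method.

    The final step (t_next == 0) falls back to Euler, so the model is never
    queried at t = 0, where turning an x-prediction into a velocity divides by t.
    """
    ts = shifted_schedule(num_steps, s)
    x = x_noise
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur  # negative: moving from noise toward data
        v_cur = velocity(x, t_cur)
        x_euler = x + dt * v_cur  # Euler predictor
        if t_next > 0.0:
            v_next = velocity(x_euler, t_next)
            x = x + dt * 0.5 * (v_cur + v_next)  # trapezoidal (Heun) corrector
        else:
            x = x_euler
    return x


s = dimension_aware_shift(num_tokens=16 * 16, channels=384)
print(f"s = {s:.3f}")  # s = 4.899, i.e. sqrt(24)
```

As a usage check, consider a constant velocity field v = 3 started from x = 5 at t = 1. It should land at x = 2 at t = 0, because Heun's method integrates a constant slope exactly. `heun_sample(lambda x, t: 3.0, 5.0)` returns 2.0 up to floating-point rounding. Counting the [CLS] token as a 257th token changes s only to about 4.91.
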
## Sample Usage

The full training/inference code is available at [lezhang7/RiT](https://github.com/lezhang7/RiT). To download the weights manually and load them in PyTorch:

```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint (cached locally after the first call)
ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")

# Load the full training checkpoint. weights_only=False unpickles arbitrary
# Python objects, so only use it on files from a source you trust.
state = torch.load(ckpt, map_location="cpu", weights_only=False)

# state['model'] holds the raw trainable weights; state['model_ema1'] and
# state['model_ema2'] hold two EMA copies with different decay rates.
# state['model_ema1'] (EMA decay 0.9999) is used for sampling by default.
model_weights = state['model_ema1']
```

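After loading, a quick sanity check is to count the entries in the selected weight dictionary. You can then compare the result against the 676M denoiser parameters reported under Model details.

The helpers below (`pick_weights`, `count_parameters`) are hypothetical and not part of the RiT repository. The count includes any persistent buffers stored in the dictionary, so a small deviation from 676M does not by itself indicate a problem.

```python
import math
from typing import Any, Mapping


def pick_weights(state: Mapping[str, Any], key: str = "model_ema1") -> Mapping[str, Any]:
    """Select one weight dictionary ('model', 'model_ema1' or 'model_ema2')."""
    if key not in state:
        raise KeyError(f"{key!r} not in checkpoint; available keys: {sorted(state)}")
    return state[key]


def count_parameters(state_dict: Mapping[str, Any]) -> int:
    """Total number of scalar entries across all tensors in a state dict.

    Works for torch tensors and NumPy arrays alike, since both expose `.shape`;
    a 0-d tensor contributes 1 because math.prod(()) == 1.
    """
    return sum(math.prod(tensor.shape) for tensor in state_dict.values())
```

Load `state` as in the snippet above. Then `count_parameters(pick_weights(state)) / 1e6` gives the size of the default EMA weights in millions.
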
To run the evaluation script, which automatically downloads this checkpoint and the matching RAE decoder:

```bash
git clone https://github.com/lezhang7/RiT.git
cd RiT
pip install -r requirements.txt
bash scripts/eval.sh  # CFG=3.7; expected FID ≈ 1.14 on ImageNet 256×256
```

## Model details

- **Architecture:** vanilla Diffusion Transformer with 28 layers, hidden size 1152, 16 attention heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class tokens, and joint [CLS]-patch modeling.
- **Encoder (frozen):** `facebook/dinov2-with-registers-small` (feature dimension d=384).
- **Decoder (frozen):** ViT-MAE-style decoder from [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections), variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
- **Parameters (denoiser only):** 676M.
- **Training:** 8×H200 GPUs, effective batch size 1536, AdamW with lr=5e-5, 800 epochs, x-prediction loss, dimension-aware time shift (s ≈ 4.9), [CLS] auxiliary loss weight λ=0.2.
- **Sampling defaults:** Heun sampler, 25 steps, time-shift schedule, CFG scale 3.7 applied only within the interval [0.1, 0.98], coupled-noise initialization for [CLS].

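Two sampling details from the list above can be made concrete in a small sketch: turning an x-prediction into a velocity, and applying classifier-free guidance only inside a time interval. The sketch makes the following assumptions.

- The path is linear: x_t = (1 - t)·x + t·ε.
- The model is wrapped as a hypothetical callable `predict_x(x_t, t, conditional)`.
- The interval [0.1, 0.98] is checked directly against the query time t.

The RiT repository may use a different time convention or guide a different quantity. For fixed x_t and t, the velocity (x_t - x̂)/t is affine in the x-prediction x̂. Guiding x̂ therefore gives the same result as guiding the velocity.

```python
from typing import Callable, Tuple, TypeVar

X = TypeVar("X")  # float, NumPy array, or torch tensor


def interval_cfg(
    predict_x: Callable[[X, float, bool], X],
    x_t: X,
    t: float,
    scale: float = 3.7,
    interval: Tuple[float, float] = (0.1, 0.98),
) -> X:
    """Classifier-free guidance on the x-prediction, active only for t inside `interval`."""
    x_cond = predict_x(x_t, t, True)
    lo, hi = interval
    if scale == 1.0 or not (lo <= t <= hi):
        return x_cond  # outside the interval: plain conditional prediction, one model call
    x_uncond = predict_x(x_t, t, False)
    return x_uncond + scale * (x_cond - x_uncond)


def x_to_velocity(x_t: X, x_pred: X, t: float) -> X:
    """Velocity of x_t = (1 - t) * x + t * eps, i.e. eps - x = (x_t - x) / t (requires t > 0)."""
    return (x_t - x_pred) / t
```

A velocity-based sampler would then query the model as `x_to_velocity(x, interval_cfg(predict_x, x, t), t)`. Skipping the unconditional pass outside the interval also saves one network evaluation per step there.
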
## Citation

```bibtex
@article{zhang2026rit,
  title   = {RiT: Vanilla Diffusion Transformers Suffice in Representation Space},
  author  = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
  journal = {arXiv preprint arXiv:2605.21981},
  year    = {2026}
}
```

## Acknowledgments

This release reuses the frozen DINOv2 encoder and ViT decoder pairing from [**RAE**](https://github.com/bytetriper/RAE). It adopts the modernized DiT block design and in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).