---
datasets:
- imagenet-1k
library_name: pytorch
license: mit
pipeline_tag: unconditional-image-generation
tags:
- diffusion
- flow-matching
- representation-learning
- dinov2
- imagenet
---
# RiT-XL: Vanilla Diffusion Transformers Suffice in Representation Space
This repository hosts the released **RiT-XL** checkpoint trained for 800 epochs on ImageNet 256×256 with frozen DINOv2-Small features.
RiT (Representation Image Transformer) is a vanilla Diffusion Transformer that effectively models distributions in high-dimensional representation spaces, as presented in the paper [RiT: Vanilla Diffusion Transformers Suffice in Representation Space](https://huggingface.co/papers/2605.21981).
[![GitHub](https://img.shields.io/badge/GitHub-lezhang7%2FRiT-181717.svg)](https://github.com/lezhang7/RiT)
[![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b.svg)](https://huggingface.co/papers/2605.21981)
## Results on ImageNet 256×256
| Method | Encoder | Params | FID ↓ (CFG=1) | FID ↓ (CFG≈3.7) |
|--------------------------|---------------:|-------:|---------------:|----------------:|
| DiT-XL | SD-VAE | 675M | 9.62 | 2.27 |
| SiT-XL | SD-VAE | 675M | 8.61 | 2.06 |
| REPA-XL | SD-VAE | 675M | 5.78 | 1.29 |
| DDT-XL | SD-VAE | 675M | 6.27 | 1.26 |
| REG-XL | SD-VAE | 675M | 1.80 | 1.36 |
| RAE-XL | DINOv2-S | 676M | 1.87 | 1.41 |
| RAE-XL<sup>DH</sup> | DINOv2-B | 839M | 1.51 | 1.16 |
| FAE-XL | FAE-DINOv2-G | 675M | 1.48 | 1.29 |
| **RiT-XL (ours)** | **DINOv2-S** | **676M** | **1.45** | **1.14** |
All FIDs use 25 Heun steps with the time-shift schedule.
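For readers unfamiliar with the sampler named above, here is a minimal sketch of Heun's second-order method on a time-shifted grid. It is not the RiT implementation. The shift map `t' = s*t / (1 + (s - 1)*t)` is the form used by Stable Diffusion 3 and is an assumption here, as are the hypothetical names `shift_time`, `make_grid`, and `heun_integrate`. A scalar toy ODE stands in for the learned velocity field.

```python
import math
from typing import Callable, List


def shift_time(t: float, s: float) -> float:
    """Assumed shift map: fixes 0 and 1, moves interior points toward 1 when s > 1."""
    return s * t / (1.0 + (s - 1.0) * t)


def make_grid(num_steps: int, s: float) -> List[float]:
    """Return num_steps + 1 shifted time points running from 0 to 1."""
    return [shift_time(i / num_steps, s) for i in range(num_steps + 1)]


def heun_integrate(
    velocity: Callable[[float, float], float], x: float, grid: List[float]
) -> float:
    """Integrate dx/dt = velocity(x, t) along `grid` with Heun's method."""
    for t0, t1 in zip(grid[:-1], grid[1:]):
        h = t1 - t0
        k1 = velocity(x, t0)          # slope at the start of the step
        x_euler = x + h * k1          # Euler predictor
        k2 = velocity(x_euler, t1)    # slope at the predicted endpoint
        x = x + 0.5 * h * (k1 + k2)   # trapezoidal corrector
    return x


# Toy check: dx/dt = -x has the exact solution x(1) = x(0) * exp(-1).
grid = make_grid(num_steps=25, s=4.9)
approx = heun_integrate(lambda x, t: -x, 1.0, grid)
print(f"Heun: {approx:.5f}  exact: {math.exp(-1.0):.5f}")
```

In this sketch each step evaluates the velocity twice, so 25 steps cost 50 evaluations. Whether the actual sampler runs along increasing or decreasing `t`, and which end corresponds to noise, is not stated in this card.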
## Sample Usage
The full training/inference code is available at [lezhang7/RiT](https://github.com/lezhang7/RiT). To download the weights manually and load them in PyTorch:
```python
import torch
from huggingface_hub import hf_hub_download
# Download the checkpoint
ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")
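# Load the full checkpoint dictionary on CPU.
# weights_only=False unpickles arbitrary Python objects, so only use it
# with checkpoints you trust.
state = torch.load(ckpt, map_location="cpu", weights_only=False)
# The checkpoint holds three parameter dictionaries:
#   state['model']      - the trainable (non-EMA) weights
#   state['model_ema1'] - EMA weights, decay 0.9999 (used for sampling by default)
#   state['model_ema2'] - EMA weights with the second decay rate
model_weights = state['model_ema1']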
```
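Because the checkpoint bundles several parameter dictionaries, a small helper can make the choice explicit and sanity-check its size. The following is a minimal sketch and not part of the RiT codebase: `select_weights` and `count_parameters` are hypothetical names, and the fallback order is an assumption. It relies only on the three keys listed above and on each tensor exposing `numel()`.

```python
from typing import Any, Mapping, Sequence, Tuple


def select_weights(
    state: Mapping[str, Any],
    preferred: Sequence[str] = ("model_ema1", "model_ema2", "model"),
) -> Tuple[str, Mapping[str, Any]]:
    """Return the first parameter dictionary found in `state`, in preference order."""
    for key in preferred:
        if key in state:
            return key, state[key]
    raise KeyError(f"none of {list(preferred)} found; available keys: {list(state)}")


def count_parameters(weights: Mapping[str, Any]) -> int:
    """Sum element counts over all tensor-like entries (anything with .numel())."""
    return sum(int(v.numel()) for v in weights.values() if hasattr(v, "numel"))
```

With the `state` loaded above, `name, weights = select_weights(state)` picks `model_ema1` when that key is present, and `count_parameters(weights)` can be compared against the 676M figure reported for the denoiser. Any buffers stored in the dictionary would be counted too.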
To run the evaluation script (which auto-pulls this checkpoint plus the matching RAE decoder):
```bash
git clone https://github.com/lezhang7/RiT.git
cd RiT
pip install -r requirements.txt
bash scripts/eval.sh # CFG=3.7, FID ~1.14 on ImageNet 256x256
```
## Model details
- **Architecture:** vanilla Diffusion Transformer with 28 layers, hidden size 1152, 16 heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class tokens, and joint [CLS]-patch modeling.
- **Encoder (frozen):** `facebook/dinov2-with-registers-small` (d=384).
- **Decoder (frozen):** ViT-MAE-style decoder from [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections), variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
- **Parameters (denoiser only):** 676M.
- **Training:** 8×H200 GPUs, effective batch size 1536, AdamW with lr=5e-5, 800 epochs, x-prediction loss, dimension-aware time shift (s ≈ 4.9), CLS auxiliary loss weight λ=0.2.
- **Sampling defaults:** Heun, 25 steps, time-shift schedule, CFG=3.7 in interval [0.1, 0.98], coupled-noise initialization for [CLS].
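
Two of the settings above lend themselves to a short worked example. This is a sketch under stated assumptions, not the released implementation. It assumes the shift factor follows the RAE-style rule `s = sqrt(m / n)` with base dimension `n = 4096` and `m = 256 tokens × 384 channels`. It also assumes the common guidance convention in which a scale of 1 means no guidance, with guidance applied only while `t` lies inside the interval. `dimension_aware_shift` and `guided_prediction` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple


def dimension_aware_shift(num_tokens: int, token_dim: int, base_dim: int = 4096) -> float:
    """Assumed rule: s = sqrt(m / n), where m = num_tokens * token_dim."""
    return math.sqrt(num_tokens * token_dim / base_dim)


def guided_prediction(
    cond: Sequence[float],
    uncond: Sequence[float],
    t: float,
    scale: float = 3.7,
    interval: Tuple[float, float] = (0.1, 0.98),
) -> List[float]:
    """Apply classifier-free guidance only when t falls inside `interval`."""
    lo, hi = interval
    if lo <= t <= hi:
        # Extrapolate from the unconditional toward the conditional prediction.
        return [u + scale * (c - u) for c, u in zip(cond, uncond)]
    # Outside the interval, fall back to the plain conditional prediction.
    return list(cond)


s = dimension_aware_shift(num_tokens=256, token_dim=384)
print(f"s = {s:.3f}")  # s = 4.899
```

Under these assumptions `s = sqrt(24) ≈ 4.9`, which agrees with the value quoted above; counting [CLS] as a 257th token gives nearly the same number. Whether the interval is measured before or after the time shift is not stated in this card.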
## Citation
```bibtex
@article{zhang2026rit,
title = {RiT: Vanilla Diffusion Transformers Suffice in Representation Space},
author = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
journal = {arXiv preprint arXiv:2605.21981},
year = {2026}
}
```
## Acknowledgments
This release reuses the frozen DINOv2 encoder and ViT decoder pairing from [**RAE**](https://github.com/bytetriper/RAE), and adopts the modernized DiT block design and in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).