---
datasets:
- imagenet-1k
library_name: pytorch
license: mit
pipeline_tag: unconditional-image-generation
tags:
- diffusion
- flow-matching
- representation-learning
- dinov2
- imagenet
---

# RiT-XL: Vanilla Diffusion Transformers Suffice in Representation Space

This repository hosts the released **RiT-XL** checkpoint, trained for 800 epochs on ImageNet 256×256 on features from a frozen DINOv2-Small encoder.

RiT (Representation Image Transformer) is a vanilla Diffusion Transformer that effectively models distributions in high-dimensional representation spaces, as presented in the paper [RiT: Vanilla Diffusion Transformers Suffice in Representation Space](https://huggingface.co/papers/2605.21981).

[![GitHub](https://img.shields.io/badge/GitHub-lezhang7%2FRiT-181717.svg)](https://github.com/lezhang7/RiT)
[![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b.svg)](https://huggingface.co/papers/2605.21981)

## Results on ImageNet 256×256

| Method                   | Encoder        | Params |  FID ↓ (CFG=1) | FID ↓ (CFG≈3.7) |
|--------------------------|---------------:|-------:|---------------:|----------------:|
| DiT-XL                   | SD-VAE         |  675M  | 9.62           | 2.27            |
| SiT-XL                   | SD-VAE         |  675M  | 8.61           | 2.06            |
| REPA-XL                  | SD-VAE         |  675M  | 5.78           | 1.29            |
| DDT-XL                   | SD-VAE         |  675M  | 6.27           | 1.26            |
| REG-XL                   | SD-VAE         |  675M  | 1.80           | 1.36            |
| RAE-XL                   | DINOv2-S       |  676M  | 1.87           | 1.41            |
| RAE-XL<sup>DH</sup>      | DINOv2-B       |  839M  | 1.51           | 1.16            |
| FAE-XL                   | FAE-DINOv2-G   |  675M  | 1.48           | 1.29            |
| **RiT-XL (ours)**        | **DINOv2-S**   | **676M** | **1.45**     | **1.14**        |

All FIDs use 25 Heun steps with the time-shift schedule.
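
As a rough illustration of what "25 Heun steps with the time-shift schedule" involves, the sketch below integrates a generic ODE `dx/dt = velocity(x, t)` with Heun's method over a warped time grid. It is not the repository's sampler: the function names, the warp `t -> s*t / (1 + (s - 1)*t)`, and the 0-to-1 integration direction are assumptions made for this example, and the toy velocity field stands in for the trained network.

```python
import math


def shifted_grid(num_steps, shift):
    """Uniform grid on [0, 1], warped by t -> s*t / (1 + (s - 1)*t).

    The warp is an assumed form of the time shift, not taken from the RiT code.
    """
    uniform = [i / num_steps for i in range(num_steps + 1)]
    return [shift * t / (1.0 + (shift - 1.0) * t) for t in uniform]


def heun_sample(velocity, x0, grid):
    """Integrate dx/dt = velocity(x, t) along `grid` with Heun's method."""
    x = x0
    for t0, t1 in zip(grid[:-1], grid[1:]):
        dt = t1 - t0
        k1 = velocity(x, t0)
        x_pred = x + dt * k1              # Euler predictor
        k2 = velocity(x_pred, t1)         # slope at the predicted endpoint
        x = x + 0.5 * dt * (k1 + k2)      # trapezoidal corrector
    return x


# Toy check on dx/dt = -x, whose exact solution at t = 1 is exp(-1).
grid = shifted_grid(num_steps=25, shift=4.9)
approx = heun_sample(lambda x, t: -x, 1.0, grid)
error = abs(approx - math.exp(-1.0))
```

Each step costs two velocity evaluations, so a plain 25-step Heun run like this one makes 50 calls to the velocity function.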

## Sample Usage

The full training/inference code is available at [lezhang7/RiT](https://github.com/lezhang7/RiT). To download the weights manually and load them in PyTorch:

```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub (cached locally after the first call)
ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")

# Load the full checkpoint. weights_only=False unpickles arbitrary Python
# objects, so only use it with checkpoints you trust.
state = torch.load(ckpt, map_location="cpu", weights_only=False)

# state['model'] holds the raw trained weights; state['model_ema1'] and
# state['model_ema2'] hold two EMA copies with different decay rates.
# state['model_ema1'] uses EMA decay 0.9999 and is the default for sampling.
model_weights = state['model_ema1']
```
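
If you are unsure which of these entries a given checkpoint file contains, a small helper can pick the first one available. This is a minimal sketch: `select_weights` is a hypothetical name rather than part of the RiT codebase, and the preference order simply mirrors the default described above.

```python
def select_weights(state, preferred=("model_ema1", "model_ema2", "model")):
    """Return (key, weights) for the first preferred entry found in `state`.

    `state` is the dictionary returned by torch.load; the key names follow
    the checkpoint layout described in this model card.
    """
    for key in preferred:
        if key in state:
            return key, state[key]
    raise KeyError(f"none of {preferred} found; available keys: {list(state)}")


# Example with a stand-in dictionary instead of a real checkpoint.
dummy_state = {"model": {"w": 1}, "model_ema2": {"w": 2}}
key, weights = select_weights(dummy_state)
```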

To run the evaluation script, which automatically downloads this checkpoint and the matching RAE decoder:

```bash
git clone https://github.com/lezhang7/RiT.git
cd RiT
pip install -r requirements.txt
bash scripts/eval.sh        # CFG=3.7, FID ~1.14 on ImageNet 256x256
```

## Model details

- **Architecture:** vanilla Diffusion Transformer with 28 layers, hidden size 1152, 16 heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class tokens, and joint [CLS]-patch modeling.
- **Encoder (frozen):** `facebook/dinov2-with-registers-small` (d=384).
- **Decoder (frozen):** ViT-MAE-style decoder from [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections), variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
- **Parameters (denoiser only):** 676M.
- **Training:** 8×H200, effective batch size 1536, AdamW with lr=5e-5, 800 epochs, x-prediction loss, dimension-aware time shift (s ≈ 4.9), CLS auxiliary loss weight λ=0.2.
- **Sampling defaults:** Heun, 25 steps, time-shift schedule, CFG=3.7 in interval [0.1, 0.98], coupled-noise initialization for [CLS].
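
Two of the settings above can be made concrete with a little arithmetic. The sketch below is illustrative only and rests on assumptions this card does not state: the shift factor is computed as `sqrt(tokens * dim / base_dim)` with a 16×16 patch grid (256 tokens) and a base dimension of 4096, a convention borrowed from earlier work on resolution-dependent time shifting, and guidance is taken to fall back to the conditional prediction outside the interval. The function names are hypothetical.

```python
import math


def shift_factor(num_tokens, token_dim, base_dim=4096):
    """Assumed dimension-aware shift: sqrt(total latent dimension / base_dim)."""
    return math.sqrt(num_tokens * token_dim / base_dim)


def guided_velocity(v_cond, v_uncond, t, scale=3.7, interval=(0.1, 0.98)):
    """Classifier-free guidance applied only while t lies inside `interval`.

    With scale=1 the result equals the conditional prediction, matching the
    "CFG=1" column of the results table. Outside the interval this sketch
    returns the conditional prediction unchanged (an assumption).
    """
    lo, hi = interval
    if lo <= t <= hi:
        return [u + scale * (c - u) for c, u in zip(v_cond, v_uncond)]
    return list(v_cond)


# 256 tokens of width 384 (DINOv2-S) give sqrt(24), about 4.899.
s = shift_factor(num_tokens=256, token_dim=384)

inside = guided_velocity([1.0], [0.0], t=0.5)    # guidance active
outside = guided_velocity([1.0], [0.0], t=0.99)  # guidance off
```

Under these assumptions the computed factor agrees with the s ≈ 4.9 quoted above, which suggests, but does not confirm, that the training recipe uses a similar rule.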

## Citation

```bibtex
@article{zhang2026rit,
  title   = {RiT: Vanilla Diffusion Transformers Suffice in Representation Space},
  author  = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
  journal = {arXiv preprint arXiv:2605.21981},
  year    = {2026}
}
```

## Acknowledgments

This release reuses the frozen DINOv2 encoder + ViT decoder pairing from [**RAE**](https://github.com/bytetriper/RAE) and adopts the modernized DiT block design + in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).