le723z committed on
Commit
376029f
·
verified ·
1 Parent(s): 2523ec2

Add model card

  1. README.md +116 -1
README.md CHANGED
@@ -1,3 +1,118 @@
  ---
- license: apache-2.0
+ license: mit
+ library_name: pytorch
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusion
+ - flow-matching
+ - representation-learning
+ - dinov2
+ - imagenet
+ datasets:
+ - imagenet-1k
  ---
+
+ # RiT-XL: Vanilla Diffusion Transformers Are Enough in Representation Space
+
+ This repository hosts the released **RiT-XL** checkpoint, taken at epoch 740 of an
+ 800-epoch training run on ImageNet 256×256 with frozen DINOv2-Small features.
+
+ [![GitHub](https://img.shields.io/badge/GitHub-lezhang7%2FRiT-181717.svg)](https://github.com/lezhang7/RiT)
+ [![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b.svg)](https://arxiv.org/)
+
+ ## Results on ImageNet 256×256
+
+ | Method | Encoder | Params | FID ↓ (CFG=1) | FID ↓ (CFG≈3.7) |
+ |--------------------------|---------------:|-------:|---------------:|----------------:|
+ | DiT-XL | SD-VAE | 675M | 9.62 | 2.27 |
+ | SiT-XL | SD-VAE | 675M | 8.61 | 2.06 |
+ | REPA-XL | SD-VAE | 675M | 5.78 | 1.29 |
+ | DDT-XL | SD-VAE | 675M | 6.27 | 1.26 |
+ | REG-XL | SD-VAE | 675M | 1.80 | 1.36 |
+ | RAE-XL | DINOv2-S | 676M | 1.87 | 1.41 |
+ | RAE-XL<sup>DH</sup> | DINOv2-B | 839M | 1.51 | 1.16 |
+ | FAE-XL | FAE-DINOv2-G | 675M | 1.48 | 1.29 |
+ | **RiT-XL (ours)** | **DINOv2-S** | **676M** | **1.45** | **1.14** |
+
+ All FIDs use 25 Heun steps with the time-shift schedule.
+
+ Few-step generation (no distillation, no consistency training):
+
+ | Heun steps | 5 | 10 | 25 | 50 |
+ |---------------:|-----:|-----:|-----:|-----:|
+ | FID (CFG=1.0) | 2.44 | 1.59 | 1.47 | 1.46 |
+ | FID (CFG=3.7) | 1.99 | 1.27 | 1.15 | 1.15 |
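
For readers unfamiliar with the sampler named in the table, the following is a minimal sketch of Heun's method (a second-order predictor-corrector ODE integrator) on a toy problem. It is not the repository's sampler: `heun_integrate`, `uniform_grid`, and the toy velocity field are hypothetical names chosen for illustration, and the uniform time grid stands in for the time-shift schedule.

```python
import math

def heun_integrate(velocity, x0, ts):
    """Integrate dx/dt = velocity(x, t) along the time grid ts with Heun's method."""
    x = x0
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur
        v_cur = velocity(x, t_cur)            # slope at the start of the step
        x_pred = x + dt * v_cur               # Euler predictor
        v_next = velocity(x_pred, t_next)     # slope at the predicted endpoint
        x = x + 0.5 * dt * (v_cur + v_next)   # trapezoidal corrector
    return x

def uniform_grid(n_steps):
    """Evenly spaced times from 0 to 1 (illustrative; not the time-shift schedule)."""
    return [i / n_steps for i in range(n_steps + 1)]

# Toy ODE dx/dt = -x, whose exact solution at t = 1 is exp(-1).
errors = {
    n: abs(heun_integrate(lambda x, t: -x, 1.0, uniform_grid(n)) - math.exp(-1))
    for n in (5, 10, 25, 50)
}
```

Each Heun step evaluates the velocity twice, so a 25-step run of this sketch makes 50 evaluations. On the toy problem the error shrinks roughly quadratically with the step size, which is the generic behavior of a second-order method and says nothing by itself about FID.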
+
+ ## Quick start
+
+ The full training/inference code lives at
+ [**lezhang7/RiT**](https://github.com/lezhang7/RiT). The eval script auto-pulls
+ this checkpoint plus the matching RAE decoder on first run:
+
+ ```bash
+ git clone https://github.com/lezhang7/RiT.git
+ cd RiT
+ pip install -r requirements.txt
+ bash scripts/eval.sh # CFG=3.7, FID ~1.14 on ImageNet 256x256
+ ```
+
+ To download just the weights manually:
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")
+ # weights_only=False unpickles arbitrary Python objects (such as the `args`
+ # namespace stored in this file), so only load checkpoints you trust.
+ state = torch.load(ckpt, map_location="cpu", weights_only=False)
+ # state["model"] holds the trained weights; state["model_ema1"] and
+ # state["model_ema2"] hold the two EMA copies.
+ ```
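
Once the checkpoint is loaded, one set of weights has to be chosen. The sketch below shows one way to do that, preferring `model_ema1` because the card names it as the sampling default. `select_weights` is a hypothetical helper, not part of the repository, and `fake_state` is a stand-in for the dictionary returned by `torch.load`.

```python
def select_weights(state, prefer=("model_ema1", "model_ema2", "model")):
    """Return (key, state_dict) for the first preferred key found in the checkpoint."""
    for key in prefer:
        if key in state:
            return key, state[key]
    raise KeyError(f"none of {prefer} found; available keys: {sorted(state)}")

# Stand-in for the real checkpoint dictionary (values are placeholders, not tensors).
fake_state = {
    "model": {"w": [0.0]},
    "model_ema1": {"w": [0.1]},
    "model_ema2": {"w": [0.2]},
    "epoch": 740,
}

key, weights = select_weights(fake_state)
```

With the real checkpoint, the returned dictionary would be passed to `load_state_dict` on a model built by the repository's own code, which is not reproduced here.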
+
+ ## Checkpoint contents
+
+ `checkpoint-last.pth` is a PyTorch checkpoint saved after 740 training
+ epochs (the released model used for the paper's headline numbers). Top-level
+ keys:
+
+ - `model`: main parameters of the `Denoiser` (RiT-XL backbone).
+ - `model_ema1`: EMA with decay 0.9999 (used for sampling by default).
+ - `model_ema2`: EMA with decay 0.9996 (tracked but unused at inference).
+ - `optimizer`: AdamW state for resuming training.
+ - `epoch`: `740`.
+ - `args`: argparse namespace from the original training run (it records the
+ legacy model name `JiT-RAE-XL/16`; the architecture matches the released
+ `RiT-XL/16`).
+
+ Loading reads only `model` and `model_ema*`, so the legacy `args` field can be
+ ignored; `eval.sh` constructs the model from its CLI flags.
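
To give a feel for what the two EMA decays mean, here is a small worked example. It assumes the textbook update `ema <- decay * ema + (1 - decay) * param` applied once per optimizer step; the card does not say how often the averages are updated or whether a warmup is used, so treat this as an illustration only. `ema_update` and `horizon_steps` are hypothetical names.

```python
def ema_update(ema, param, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * param."""
    return decay * ema + (1.0 - decay) * param

def horizon_steps(decay):
    """Rough number of recent steps the average spans: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)

h1 = horizon_steps(0.9999)  # on the order of 10,000 steps
h2 = horizon_steps(0.9996)  # on the order of 2,500 steps

# Toy scalar "parameter" that jumps from 0 to 1 and stays there.
# The slower average (decay 0.9999) lags further behind after 2,500 steps.
ema1 = ema2 = 0.0
for _ in range(2500):
    ema1 = ema_update(ema1, 1.0, 0.9999)
    ema2 = ema_update(ema2, 1.0, 0.9996)
```

Under this assumption the 0.9999 average smooths over about four times as many steps as the 0.9996 one, which is the usual trade-off between stability and responsiveness.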
+
+ ## Model details
+
+ - **Architecture:** vanilla Diffusion Transformer with 28 layers, hidden size 1152,
+ 16 heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class
+ tokens, and joint [CLS]-patch modeling.
+ - **Encoder (frozen):** `facebook/dinov2-with-registers-small` (d=384).
+ - **Decoder (frozen):** ViT-MAE-style decoder from
+ [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections),
+ variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
+ - **Parameters (denoiser only):** 676M.
+ - **Training:** 8×H200, effective batch size 1536, AdamW with lr=5e-5, 800 epochs
+ (this checkpoint: epoch 740), x-prediction loss, dimension-aware time shift
+ (s ≈ 4.9), CLS auxiliary loss weight λ=0.2.
+ - **Sampling defaults:** Heun, 25 steps, time-shift schedule, CFG=3.7 in
+ interval [0.1, 0.98], coupled-noise initialization for [CLS].
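
Two of the settings above can be illustrated with a short sketch. Everything here is an assumption rather than the repository's implementation. For the time shift, one rule that reproduces s ≈ 4.9 is s = sqrt(N·d / 4096) with N = 256 tokens of width d = 384, combined with the common remapping t' = s·t / (1 + (s − 1)·t); the rule, the token count, and the base dimension 4096 are not stated in this card. For guidance, the sketch applies CFG only when t lies inside the interval and falls back to the conditional prediction otherwise; the time convention and inclusive endpoints are also assumptions. All function names are hypothetical.

```python
import math

def shift_factor(num_tokens, token_dim, base_dim=4096):
    """Assumed dimension-aware shift: sqrt(total latent dimension / base dimension)."""
    return math.sqrt(num_tokens * token_dim / base_dim)

def shift_time(t, s):
    """Remap t in [0, 1] with shift s; endpoints 0 and 1 stay fixed."""
    return s * t / (1.0 + (s - 1.0) * t)

def guided_velocity(v_cond, v_uncond, t, scale=3.7, interval=(0.1, 0.98)):
    """Classifier-free guidance applied only while t is inside the interval."""
    lo, hi = interval
    if lo <= t <= hi:
        return v_uncond + scale * (v_cond - v_uncond)
    return v_cond  # outside the interval: plain conditional prediction

s = shift_factor(256, 384)       # sqrt(24), close to the card's s ≈ 4.9
t_mid = shift_time(0.5, s)       # midpoint of the grid after shifting

inside = guided_velocity(2.0, 1.0, t=0.5)    # guidance active
outside = guided_velocity(2.0, 1.0, t=0.05)  # guidance inactive
```

With s above 1, the remapping moves the midpoint of a uniform grid well past 0.5, so more of the 25 steps are spent near one end of the trajectory; which end corresponds to high noise depends on the repository's time convention.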
+
+ ## Citation
+
+ ```bibtex
+ @article{zhang2025rit,
+ title = {RiT: Vanilla Diffusion Transformers Are Enough in Representation Space},
+ author = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
+ year = {2025}
+ }
+ ```
+
+ ## Acknowledgments
+
+ This release reuses the frozen DINOv2 encoder + ViT decoder pairing from
+ [**RAE**](https://github.com/bytetriper/RAE) and adopts the modernized DiT
+ block design + in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).