le723z nielsr HF Staff committed on
Commit 3910531 · 1 Parent(s): 376029f

Improve model card: add paper link and sample usage (#1)


- Improve model card: add paper link and sample usage (7eb882328f860d690f9c50438625886f63167539)


Co-authored-by: Niels Rogge <nielsr@users.noreply.huggingface.co>

Files changed (1)
  1. README.md +40 -67
README.md CHANGED
@@ -1,24 +1,25 @@
  ---
- license: mit
  library_name: pytorch
  pipeline_tag: unconditional-image-generation
  tags:
- - diffusion
- - flow-matching
- - representation-learning
- - dinov2
- - imagenet
- datasets:
- - imagenet-1k
  ---

- # RiT-XL: Vanilla Diffusion Transformers Are Enough in Representation Space

- This repository hosts the released **RiT-XL** checkpoint trained for 800 epochs
- on ImageNet 256×256 with frozen DINOv2-Small features.

  [![GitHub](https://img.shields.io/badge/GitHub-lezhang7%2FRiT-181717.svg)](https://github.com/lezhang7/RiT)
- [![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b.svg)](https://arxiv.org/)

  ## Results on ImageNet 256×256

@@ -36,83 +37,55 @@ on ImageNet 256×256 with frozen DINOv2-Small features.

  All FIDs use 25 Heun steps with the time-shift schedule.

- Few-step generation (no distillation, no consistency training):
-
- | Heun steps | 5 | 10 | 25 | 50 |
- |---------------:|-----:|-----:|-----:|-----:|
- | FID (CFG=1.0) | 2.44 | 1.59 | 1.47 | 1.46 |
- | FID (CFG=3.7) | 1.99 | 1.27 | 1.15 | 1.15 |
-
- ## Quick start
-
- The full training/inference code lives at
- [**lezhang7/RiT**](https://github.com/lezhang7/RiT). The eval script auto-pulls
- this checkpoint plus the matching RAE decoder on first run:
-
- ```bash
- git clone https://github.com/lezhang7/RiT.git
- cd RiT
- pip install -r requirements.txt
- bash scripts/eval.sh # CFG=3.7, FID ~1.14 on ImageNet 256x256
- ```

- To download just the weights manually:

  ```python
  from huggingface_hub import hf_hub_download
  ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")
- import torch
  state = torch.load(ckpt, map_location="cpu", weights_only=False)
  # state['model'] / state['model_ema1'] / state['model_ema2'] are the
  # trainable + two EMA-decay parameter dictionaries.
  ```

- ## Checkpoint contents
-
- `checkpoint-last.pth` is a PyTorch checkpoint produced after 740 training
- epochs (the released model used for the paper's headline numbers). Top-level
- keys:

- - `model` — main parameters of the `Denoiser` (RiT-XL backbone).
- - `model_ema1` — EMA decay 0.9999 (used for sampling by default).
- - `model_ema2` — EMA decay 0.9996 (tracked but unused at inference).
- - `optimizer` — AdamW state for resuming training.
- - `epoch` — `740`.
- - `args` — argparse namespace from the original training run (legacy
- `JiT-RAE-XL/16` model name; the architecture matches the released
- `RiT-XL/16`).
-
- Loading uses only `model` / `model_ema*`, so the legacy `args` field does not
- matter — `eval.sh` constructs the model from the CLI flags.

  ## Model details

- - **Architecture:** vanilla Diffusion Transformer — 28 layers, hidden 1152,
- 16 heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class
- tokens, joint [CLS]-patch modeling.
  - **Encoder (frozen):** `facebook/dinov2-with-registers-small` (d=384).
- - **Decoder (frozen):** ViT-MAE-style decoder from
- [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections),
- variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
  - **Parameters (denoiser only):** 676M.
- - **Training:** 8×H200, batch 1536 effective, AdamW lr=5e-5, 800 epochs (this
- ckpt: epoch 740), x-prediction loss, dimension-aware time shift
- (s ≈ 4.9), CLS auxiliary loss weight λ=0.2.
- - **Sampling defaults:** Heun, 25 steps, time-shift schedule, CFG=3.7 in
- interval [0.1, 0.98], coupled-noise initialization for [CLS].

  ## Citation

  ```bibtex
- @article{zhang2025rit,
- title = {RiT: Vanilla Diffusion Transformers Are Enough in Representation Space},
- author = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
- year = {2025}
  }
  ```

  ## Acknowledgments

- This release reuses the frozen DINOv2 encoder + ViT decoder pairing from
- [**RAE**](https://github.com/bytetriper/RAE) and adopts the modernized DiT
- block design + in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).
 
@@ -1,24 +1,25 @@
  ---
+ datasets:
+ - imagenet-1k
  library_name: pytorch
+ license: mit
  pipeline_tag: unconditional-image-generation
  tags:
+ - diffusion
+ - flow-matching
+ - representation-learning
+ - dinov2
+ - imagenet
  ---

+ # RiT-XL: Vanilla Diffusion Transformers Suffice in Representation Space

+ This repository hosts the released **RiT-XL** checkpoint trained for 800 epochs on ImageNet 256×256 with frozen DINOv2-Small features.
+
+ RiT (Representation Image Transformer) is a vanilla Diffusion Transformer that effectively models distributions in high-dimensional representation spaces, as presented in the paper [RiT: Vanilla Diffusion Transformers Suffice in Representation Space](https://huggingface.co/papers/2605.21981).

  [![GitHub](https://img.shields.io/badge/GitHub-lezhang7%2FRiT-181717.svg)](https://github.com/lezhang7/RiT)
+ [![Paper](https://img.shields.io/badge/Paper-arXiv-b31b1b.svg)](https://huggingface.co/papers/2605.21981)

  ## Results on ImageNet 256×256

 
@@ -36,83 +37,55 @@

  All FIDs use 25 Heun steps with the time-shift schedule.
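
The line above refers to Heun steps. As a reminder of what a single Heun step computes, here is a minimal, generic sketch of Heun's second-order method on a toy ODE. The names `heun_integrate` and `toy_velocity` are hypothetical, the time grid is uniform rather than time-shifted, and nothing here is taken from the RiT code base.

```python
import math
from typing import Callable


def heun_integrate(
    velocity: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    num_steps: int,
) -> float:
    """Integrate dx/dt = velocity(x, t) from t0 to t1 with Heun's method."""
    x, t = x0, t0
    h = (t1 - t0) / num_steps
    for _ in range(num_steps):
        k1 = velocity(x, t)              # slope at the start of the step
        x_pred = x + h * k1              # Euler predictor
        k2 = velocity(x_pred, t + h)     # slope at the predicted endpoint
        x = x + 0.5 * h * (k1 + k2)      # trapezoidal corrector
        t += h
    return x


def toy_velocity(x: float, t: float) -> float:
    """Toy stand-in for a learned velocity field (assumption): dx/dt = -x."""
    return -x


# 25 steps, matching the step count quoted above; the exact answer is exp(-1).
approx = heun_integrate(toy_velocity, x0=1.0, t0=0.0, t1=1.0, num_steps=25)
error = abs(approx - math.exp(-1.0))
```

Each Heun step evaluates the velocity twice, so a sampler built this way would make roughly two network calls per step.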

+ ## Sample Usage

+ The full training/inference code is available at [lezhang7/RiT](https://github.com/lezhang7/RiT). To download the weights manually and load them in PyTorch:

  ```python
+ import torch
  from huggingface_hub import hf_hub_download
+
+ # Download the checkpoint
  ckpt = hf_hub_download(repo_id="le723z/RiT", filename="checkpoint-last.pth")
+
+ # Load the state dictionary
  state = torch.load(ckpt, map_location="cpu", weights_only=False)
+
  # state['model'] / state['model_ema1'] / state['model_ema2'] are the
  # trainable + two EMA-decay parameter dictionaries.
+ # state['model_ema1'] is the EMA decay 0.9999 (used for sampling by default).
+ model_weights = state['model_ema1']
  ```
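
The snippet above indexes `state['model_ema1']` directly. A slightly more defensive variant might fall back to the other parameter dictionaries when a key is missing. The sketch below uses the key names from the snippet; the helper `pick_weights` and the toy `fake_state` dictionary are hypothetical and stand in for the real checkpoint so the example runs without a download.

```python
from typing import Any, Mapping, Sequence, Tuple


def pick_weights(
    state: Mapping[str, Any],
    preferred: Sequence[str] = ("model_ema1", "model_ema2", "model"),
) -> Tuple[str, Any]:
    """Return (key, weights) for the first preferred key present in `state`."""
    for key in preferred:
        if key in state:
            return key, state[key]
    raise KeyError(f"none of {list(preferred)} found; keys are {list(state)}")


# Toy stand-in for the dictionary returned by torch.load (assumption).
fake_state = {"model": {"w": 0.0}, "model_ema1": {"w": 1.0}}
chosen_key, model_weights = pick_weights(fake_state)
```

With the real checkpoint, the same call would be `pick_weights(state)`.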

+ To run the evaluation script (which auto-pulls this checkpoint plus the matching RAE decoder):

+ ```bash
+ git clone https://github.com/lezhang7/RiT.git
+ cd RiT
+ pip install -r requirements.txt
+ bash scripts/eval.sh # CFG=3.7, FID ~1.14 on ImageNet 256x256
+ ```

  ## Model details

+ - **Architecture:** vanilla Diffusion Transformer — 28 layers, hidden 1152, 16 heads, SwiGLU FFN, RMSNorm, QK-norm, 2D VisionRoPE, 32 in-context class tokens, joint [CLS]-patch modeling.
  - **Encoder (frozen):** `facebook/dinov2-with-registers-small` (d=384).
+ - **Decoder (frozen):** ViT-MAE-style decoder from [nyu-visionx/RAE-collections](https://huggingface.co/nyu-visionx/RAE-collections), variant `decoders/dinov2/wReg_small/ViTXL_n08/model.pt`.
  - **Parameters (denoiser only):** 676M.
+ - **Training:** 8×H200, batch 1536 effective, AdamW lr=5e-5, 800 epochs, x-prediction loss, dimension-aware time shift (s ≈ 4.9), CLS auxiliary loss weight λ=0.2.
+ - **Sampling defaults:** Heun, 25 steps, time-shift schedule, CFG=3.7 in interval [0.1, 0.98], coupled-noise initialization for [CLS].
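
The sampling defaults above apply classifier-free guidance only inside a time interval. The sketch below illustrates that idea under two assumptions that this model card does not confirm: guidance uses the common form `uncond + scale * (cond - uncond)`, and outside the interval the sampler falls back to the plain conditional prediction. The helper `guided_prediction` is hypothetical and works on plain lists instead of tensors.

```python
from typing import List, Sequence, Tuple


def guided_prediction(
    cond: Sequence[float],
    uncond: Sequence[float],
    t: float,
    scale: float = 3.7,
    interval: Tuple[float, float] = (0.1, 0.98),
) -> List[float]:
    """Apply classifier-free guidance only when t lies inside `interval`."""
    low, high = interval
    if low <= t <= high:
        # Extrapolate from the unconditional toward the conditional output.
        return [u + scale * (c - u) for c, u in zip(cond, uncond)]
    # Outside the interval: no guidance (assumed fallback).
    return list(cond)


inside = guided_prediction([1.0], [0.0], t=0.5)
outside = guided_prediction([1.0], [0.0], t=0.05)
```

Restricting guidance to an interval leaves the first and last portions of the trajectory unguided; which end corresponds to noise depends on the sampler's time convention, which is not stated here.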
 
 
 

  ## Citation

  ```bibtex
+ @article{zhang2026rit,
+ title = {RiT: Vanilla Diffusion Transformers Suffice in Representation Space},
+ author = {Zhang, Le and Mang, Ning and Agrawal, Aishwarya},
+ journal = {arXiv preprint arXiv:2605.21981},
+ year = {2026}
  }
  ```

  ## Acknowledgments

+ This release reuses the frozen DINOv2 encoder + ViT decoder pairing from [**RAE**](https://github.com/bytetriper/RAE) and adopts the modernized DiT block design + in-context class tokens from [**JiT**](https://github.com/LTH14/JiT).