Harahan committed
Commit 63f8615 · verified · 1 Parent(s): 9987140

Add model card

Files changed (1)
  1. README.md +179 -0
README.md ADDED
@@ -0,0 +1,179 @@
+ <div align="center">
+
+ <img width="70%" height="70%" alt="logo" src="https://github.com/user-attachments/assets/4d534e80-f8ec-4c0b-948f-730cc0311961" />
+
+ <h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>
+
+ <p><b>Reward-Tilted DMD &nbsp;·&nbsp; Ambient-Consistent Distillation &nbsp;·&nbsp; Hybrid Policy Gradient</b></p>
+
+ [![Paper](https://img.shields.io/badge/paper-arXiv-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](<TODO: arxiv link>)
+ [![Github](https://img.shields.io/badge/Harahan%2FRTDMD-000000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Harahan/RTDMD)
+ [![Hugging Face Collection](https://img.shields.io/badge/RTDMD_Collection-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/Harahan/rtdmd)
+
+ [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+ [![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](https://www.python.org/)
+
+ </div>
+
+ <div align="center">
+
+ [Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡
+
+ <sup>1</sup>The Hong Kong University of Science and Technology &nbsp;&nbsp;
+ <sup>2</sup>Tencent Hunyuan &nbsp;&nbsp;
+ <sup>3</sup>Westlake University
+
+ \* Equal contribution &nbsp;·&nbsp; † Work done during internship at Tencent Hunyuan &nbsp;·&nbsp; ‡ Corresponding author
+
+ </div>
+
+ ---
+
+ ## 📖 Abstract
+
+ We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
+ two-stage framework that unifies distribution-matching distillation with
+ reward-guided RL for few-step flow generators. Minimizing the KL divergence to
+ a *reward-tilted teacher distribution* decomposes naturally into a
+ **distribution-matching** term and a **reward-maximization** term. These are
+ instantiated, respectively, as **Ambient-Consistent DMD (AC-DMD)** for the cold
+ start and a **hybrid policy gradient** (SubGRPO + final-step reward
+ back-propagation) for the RL stage. With **4 NFE**, RTDMD sets a new state of
+ the art on SD3-M / SD3.5-M / FLUX.2 4B; the distilled FLUX.2 4B even
+ outperforms the full FLUX.2 9B teacher (50 NFE) on most rewards.
+
+ <table align="center">
+ <tr>
+ <td align="center" width="50%">
+ <img src="https://github.com/user-attachments/assets/cb1fb0da-d388-4846-9017-66bccebd0749" alt="RTDMD teaser" width="100%">
+ <br/>
+ <em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
+ </td>
+ <td align="center" width="50%">
+ <img src="https://github.com/user-attachments/assets/3d76d99a-6fe1-4059-8e68-9461ba067b01" alt="RTDMD comparison" width="100%">
+ <br/>
+ <em>Qualitative comparison for few-step diffusion models (4 NFE).</em>
+ </td>
+ </tr>
+ </table>
+
+ ---
+
+ ## Method Overview
+
+ <div align="center">
+ <img src="https://github.com/user-attachments/assets/61a64fca-a143-40ae-9e36-79c6fcb5b696" alt="RTDMD method overview" width="70%">
+ <br/>
+ <em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
+ </div>
+
+ For the generator $G_\theta$, the gradient of the reward-tilted KL objective decomposes as
+
+ $$
+ \nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
+ \underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
+ $$
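
One way to see where this decomposition comes from is to assume that the reward-tilted teacher takes the standard exponential-tilting form. The model card does not spell out the definition, so the form of $\tilde{p}_\psi$ and the normalizer $Z$ below are assumptions made for illustration:

$$
\tilde{p}_\psi(\mathbf{x}) = \frac{1}{Z}\, p_\psi(\mathbf{x}) \exp\big(\beta\, r(\mathbf{x})\big),
\qquad
Z = \mathbb{E}_{\mathbf{x} \sim p_\psi}\big[\exp\big(\beta\, r(\mathbf{x})\big)\big].
$$

Substituting $\log \tilde{p}_\psi = \log p_\psi + \beta r - \log Z$ into the KL divergence gives

$$
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
= \mathbb{E}_{\mathbf{x} \sim p_\theta}\big[\log p_\theta(\mathbf{x}) - \log p_\psi(\mathbf{x}) - \beta\, r(\mathbf{x})\big] + \log Z
= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\mathbf{x} \sim p_\theta}[r(\mathbf{x})] + \log Z.
$$

Under this assumption $\log Z$ does not depend on $\theta$, so it vanishes when taking $\nabla_\theta$, which leaves exactly the distribution-matching and reward-maximization terms.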
+
+ The two terms map directly to the two trainers exposed by the CLI:
+
+ | Stage | Trainer | Key knobs |
+ | --- | --- | --- |
+ | 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler `η = 0.9` |
+ | 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step BP + AC-DMD |
+
+ ---
+
+ ## 📦 Contents
+
+ This repository hosts the 4-NFE LoRA checkpoints distilled from
+ **Stable Diffusion 3.5 Medium** with [RTDMD](https://github.com/Harahan/RTDMD).
+
+ ```
+ .
+ ├── cold_start/
+ │   └── generator_ema.pt   # Stage-1 AC-DMD LoRA (4 NFE base)
+ └── rtdmd/
+     └── generator_ema.pt   # Stage-2 RTDMD LoRA (stacked on top of cold_start)
+ ```
+
+ Each `generator_ema.pt` is a `state_dict` saved with `torch.save` that contains
+ only LoRA adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two
+ adapters are designed to be **stacked**: the cold-start LoRA distills SD3.5-M
+ down to 4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
+ reward-tilted RL.
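
As a rough sanity check on the layout described above, the sketch below summarizes the LoRA keys in a `state_dict`-like mapping. It assumes the usual PEFT convention for linear layers (`lora_A` weights shaped `(rank, in_features)`, `lora_B` weights shaped `(out_features, rank)`) and the default scaling `alpha / rank`. The helper name `summarize_lora`, the `FakeTensor` stand-in, and the demo key names are hypothetical and not taken from the RTDMD code.

```python
from collections import namedtuple


def summarize_lora(state, alpha=64):
    """Summarize LoRA tensors in a mapping of parameter name -> object with `.shape`."""
    a_keys = [k for k in state if "lora_A" in k]
    b_keys = [k for k in state if "lora_B" in k]
    other = [k for k in state if "lora_A" not in k and "lora_B" not in k]
    # Assumed PEFT layout: lora_A is (rank, in_features), lora_B is (out_features, rank).
    ranks = {state[k].shape[0] for k in a_keys} | {state[k].shape[1] for k in b_keys}
    return {
        "num_lora_A": len(a_keys),
        "num_lora_B": len(b_keys),
        "non_lora_keys": other,
        "ranks": sorted(ranks),
        # Default (non-rsLoRA) scaling applied to the low-rank update.
        "scaling": {r: alpha / r for r in ranks},
    }


# Stand-in for real tensors so the sketch runs without downloading anything.
FakeTensor = namedtuple("FakeTensor", ["shape"])
demo_state = {
    # Hypothetical key names and layer width, for illustration only.
    "blocks.0.attn.to_q.lora_A.default.weight": FakeTensor((32, 128)),
    "blocks.0.attn.to_q.lora_B.default.weight": FakeTensor((128, 32)),
}
summary = summarize_lora(demo_state)
print(summary)
```

With a real checkpoint, the same function can be applied to the mapping returned by `torch.load(path, map_location="cpu")`. If the card's description holds, `ranks` should come back as `[32]` and `non_lora_keys` should be empty.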
+
+ ---
+
+ ## 🚀 Usage
+
+ ### Option 1: RTDMD inference CLI (recommended)
+
+ The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
+ run the CPS sampler for you:
+
+ ```bash
+ git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
+ pip install -r requirements.txt && pip install -e .
+
+ # Download the checkpoints from this repo
+ huggingface-cli download Harahan/SD35M-RTDMD --local-dir ./ckpts/sd35m
+
+ # Run 4-NFE inference (single GPU)
+ python inference.py configs/inference/sd35m.yaml \
+   --override lora_paths='["./ckpts/sd35m/cold_start/generator_ema.pt","./ckpts/sd35m/rtdmd/generator_ema.pt"]' \
+   --override eval_reward=false \
+   --prompt "a cute cat sitting on a windowsill"
+ ```
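
When scripting many prompts, quoting the `lora_paths` list by hand is easy to get wrong. The sketch below builds the same command as an argument list in Python. It assumes the CLI accepts exactly the flags shown in the shell example (`--override`, `--prompt`) and nothing more. The helper name `build_inference_command` is hypothetical and not part of the RTDMD repo.

```python
import json
import shlex


def build_inference_command(config, lora_paths, prompt, eval_reward=False):
    """Return the argv list for the inference call shown in the shell example."""
    # Compact JSON matches the list syntax used in the shell example.
    lora_arg = "lora_paths=" + json.dumps(list(lora_paths), separators=(",", ":"))
    return [
        "python", "inference.py", config,
        "--override", lora_arg,
        "--override", f"eval_reward={str(eval_reward).lower()}",
        "--prompt", prompt,
    ]


cmd = build_inference_command(
    "configs/inference/sd35m.yaml",
    [
        "./ckpts/sd35m/cold_start/generator_ema.pt",
        "./ckpts/sd35m/rtdmd/generator_ema.pt",
    ],
    "a cute cat sitting on a windowsill",
)
# Shell-quoted form, suitable for logging or pasting into a terminal.
print(shlex.join(cmd))
```

Run from the root of the cloned RTDMD repo, the list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting altogether.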
+
+ ### Option 2: Plain diffusers
+
+ ```python
+ import torch
+ from diffusers import StableDiffusion3Pipeline
+ from peft import LoraConfig
+ from huggingface_hub import hf_hub_download
+
+ base = "stabilityai/stable-diffusion-3.5-medium"
+ pipe = StableDiffusion3Pipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")
+
+ # Inject LoRA adapters with the rank/alpha used during training
+ TARGETS = [
+     "to_q", "to_k", "to_v", "to_out.0",
+     "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
+ ]
+ pipe.transformer.add_adapter(
+     LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
+ )
+
+ # Sequentially load cold-start then RTDMD weights into the same adapter
+ for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
+     path = hf_hub_download("Harahan/SD35M-RTDMD", ckpt)
+     state = torch.load(path, map_location="cpu", weights_only=False)
+     # strict=False: the checkpoint holds only LoRA keys, so base weights count as missing
+     pipe.transformer.load_state_dict(state, strict=False)
+
+ # 4-step sampling with the pipeline's default scheduler (not CPS; see the note below)
+ pipe(prompt="a cute cat sitting on a windowsill",
+      num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
+ ```
+
+ > **Note:** RTDMD is trained with the CPS (Coefficients-Preserving Sampling)
+ > sampler at `η = 0.9`. The default Flow-Matching Euler scheduler in diffusers
+ > still produces reasonable samples at 4 NFE, but the RTDMD inference CLI
+ > is the only entry point that reproduces the paper numbers exactly.
+
+ ---
+
+ ## 📄 Citation
+
+ ```bibtex
+ <TODO: add bibtex after arXiv release>
+ ```
+
+ ---
+
+ ## ⚖️ License
+
+ Apache 2.0, the same license as the upstream
+ [RTDMD](https://github.com/Harahan/RTDMD) repo. The base model
+ [`stabilityai/stable-diffusion-3.5-medium`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
+ is governed by its own license; please review and comply with it separately.