Add model card
Browse files
README.md
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<img width="70%" height="70%" alt="logo" src="https://github.com/user-attachments/assets/4d534e80-f8ec-4c0b-948f-730cc0311961" />
|
| 4 |
+
|
| 5 |
+
<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>
|
| 6 |
+
|
| 7 |
+
<p><b>Reward-Tilted DMD Β· Ambient-Consistent Distillation Β· Hybrid Policy Gradient</b></p>
|
| 8 |
+
|
| 9 |
+
[](<TODO: arxiv link>)
|
| 10 |
+
[](https://github.com/Harahan/RTDMD)
|
| 11 |
+
[](https://huggingface.co/collections/Harahan/rtdmd)
|
| 12 |
+
|
| 13 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 14 |
+
[](https://www.python.org/)
|
| 15 |
+
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
<div align="center">
|
| 19 |
+
|
| 20 |
+
[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>β </sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>β </sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>β‘
|
| 21 |
+
|
| 22 |
+
<sup>1</sup>The Hong Kong University of Science and Technology
|
| 23 |
+
<sup>2</sup>Tencent Hunyuan
|
| 24 |
+
<sup>3</sup>Westlake University
|
| 25 |
+
|
| 26 |
+
\* Equal contribution Β· β Work done during internship at Tencent Hunyuan Β· β‘ Corresponding author
|
| 27 |
+
|
| 28 |
+
</div>
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## π Abstract
|
| 33 |
+
|
| 34 |
+
We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
|
| 35 |
+
two-stage framework that unifies distribution-matching distillation with
|
| 36 |
+
reward-guided RL for few-step flow generators. Minimizing the KL divergence to
|
| 37 |
+
a *reward-tilted teacher distribution* decomposes naturally into a
|
| 38 |
+
**distribution-matching** term and a **reward-maximization** term β instantiated
|
| 39 |
+
as **Ambient-Consistent DMD (AC-DMD)** for the cold start and a **hybrid policy
|
| 40 |
+
gradient** (SubGRPO + final-step reward back-propagation) for the RL stage.
|
| 41 |
+
With **4 NFE** RTDMD reaches new SOTA on SD3-M / SD3.5-M / FLUX.2 4B; the
|
| 42 |
+
distilled FLUX.2 4B even beats the full FLUX.2 9B teacher (50 NFE) on most
|
| 43 |
+
rewards.
|
| 44 |
+
|
| 45 |
+
<table align="center">
|
| 46 |
+
<tr>
|
| 47 |
+
<td align="center" width="50%">
|
| 48 |
+
<img src="https://github.com/user-attachments/assets/cb1fb0da-d388-4846-9017-66bccebd0749" alt="RTDMD teaser" width="100%">
|
| 49 |
+
<br/>
|
| 50 |
+
<em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
|
| 51 |
+
</td>
|
| 52 |
+
<td align="center" width="50%">
|
| 53 |
+
<img src="https://github.com/user-attachments/assets/3d76d99a-6fe1-4059-8e68-9461ba067b01" alt="RTDMD comparison" width="100%">
|
| 54 |
+
<br/>
|
| 55 |
+
<em>Qualitative comparison for few-step diffusion models (4 NFE).</em>
|
| 56 |
+
</td>
|
| 57 |
+
</tr>
|
| 58 |
+
</table>
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## π Method Overview
|
| 63 |
+
|
| 64 |
+
<div align="center">
|
| 65 |
+
<img src="https://github.com/user-attachments/assets/61a64fca-a143-40ae-9e36-79c6fcb5b696" alt="RTDMD method overview" width="70%">
|
| 66 |
+
<br/>
|
| 67 |
+
<em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
|
| 68 |
+
</div>
|
| 69 |
+
|
| 70 |
+
For the generator $G_\theta$, the reward-tilted KL objective decomposes as
|
| 71 |
+
|
| 72 |
+
$$
|
| 73 |
+
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
|
| 74 |
+
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
|
| 75 |
+
$$
|
| 76 |
+
|
| 77 |
+
The two terms map directly to the two trainers exposed by the CLI:
|
| 78 |
+
|
| 79 |
+
| Stage | Trainer | Key knobs |
|
| 80 |
+
| --- | --- | --- |
|
| 81 |
+
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `Ξ³`, CPS sampler `Ξ· = 0.9` |
|
| 82 |
+
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step BP + AC-DMD |
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## π¦ Contents
|
| 87 |
+
|
| 88 |
+
This repository hosts the 4-NFE LoRA checkpoints distilled from
|
| 89 |
+
**Stable Diffusion 3.5 Medium** with [RTDMD](https://github.com/Harahan/RTDMD).
|
| 90 |
+
|
| 91 |
+
```
|
| 92 |
+
.
|
| 93 |
+
βββ cold_start/
|
| 94 |
+
β βββ generator_ema.pt # Stage-1 AC-DMD LoRA (4 NFE base)
|
| 95 |
+
βββ rtdmd/
|
| 96 |
+
βββ generator_ema.pt # Stage-2 RTDMD LoRA (stacked on top of cold_start)
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
Each `generator_ema.pt` is a `torch.save`-d `state_dict` containing only LoRA
|
| 100 |
+
adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two adapters
|
| 101 |
+
are designed to be **stacked**: the cold-start LoRA distills SD3.5-M down to
|
| 102 |
+
4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
|
| 103 |
+
reward-tilted RL.
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
## π Usage
|
| 108 |
+
|
| 109 |
+
### Option 1 β RTDMD inference CLI (recommended)
|
| 110 |
+
|
| 111 |
+
The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
|
| 112 |
+
run the CPS sampler for you:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
|
| 116 |
+
pip install -r requirements.txt && pip install -e .
|
| 117 |
+
|
| 118 |
+
# Download this repo
|
| 119 |
+
huggingface-cli download Harahan/SD35M-RTDMD --local-dir ./ckpts/sd35m
|
| 120 |
+
|
| 121 |
+
# Run 4-NFE inference (single GPU)
|
| 122 |
+
python inference.py configs/inference/sd35m.yaml \
|
| 123 |
+
--override lora_paths='["./ckpts/sd35m/cold_start/generator_ema.pt","./ckpts/sd35m/rtdmd/generator_ema.pt"]' \
|
| 124 |
+
--override eval_reward=false \
|
| 125 |
+
--prompt "a cute cat sitting on a windowsill"
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Option 2 β Plain diffusers
|
| 129 |
+
|
| 130 |
+
```python
|
| 131 |
+
import torch
|
| 132 |
+
from diffusers import StableDiffusion3Pipeline
|
| 133 |
+
from peft import LoraConfig
|
| 134 |
+
from huggingface_hub import hf_hub_download
|
| 135 |
+
|
| 136 |
+
base = "stabilityai/stable-diffusion-3.5-medium"
|
| 137 |
+
pipe = StableDiffusion3Pipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")
|
| 138 |
+
|
| 139 |
+
# Inject LoRA adapters with the rank/alpha used during training
|
| 140 |
+
TARGETS = [
|
| 141 |
+
"to_q", "to_k", "to_v", "to_out.0",
|
| 142 |
+
"add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
|
| 143 |
+
]
|
| 144 |
+
pipe.transformer.add_adapter(
|
| 145 |
+
LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Sequentially load cold-start then RTDMD weights into the same adapter
|
| 149 |
+
for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
|
| 150 |
+
path = hf_hub_download("Harahan/SD35M-RTDMD", ckpt)
|
| 151 |
+
state = torch.load(path, map_location="cpu", weights_only=False)
|
| 152 |
+
pipe.transformer.load_state_dict(state, strict=False)
|
| 153 |
+
|
| 154 |
+
# 4-step CPS sampling
|
| 155 |
+
pipe(prompt="a cute cat sitting on a windowsill",
|
| 156 |
+
num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
> **Note:** RTDMD is trained on the CPS (Coefficients-Preserving Sampling)
|
| 160 |
+
> scheduler with `Ξ· = 0.9`. Using the default Flow-Matching Euler scheduler
|
| 161 |
+
> will still produce reasonable samples at 4 NFE, but the RTDMD inference CLI
|
| 162 |
+
> is the only entry point that reproduces the paper numbers exactly.
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
## π Citation
|
| 167 |
+
|
| 168 |
+
```bibtex
|
| 169 |
+
<TODO: add bibtex after arXiv release>
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## βοΈ License
|
| 175 |
+
|
| 176 |
+
Apache 2.0 β same as the upstream
|
| 177 |
+
[RTDMD](https://github.com/Harahan/RTDMD) repo. The base model
|
| 178 |
+
[`stabilityai/stable-diffusion-3.5-medium`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
|
| 179 |
+
is governed by its own license; please review and comply with it separately.
|