---
license: apache-2.0
language:
- en
base_model:
- stabilityai/stable-diffusion-3.5-medium
pipeline_tag: text-to-image
---
<div align="center">
<img width="70%" height="70%" alt="logo" src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/l1JM1Si5PDCgvJR5SSiqf.png" />
<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>
<p><b>Reward-Tilted DMD · Ambient-Consistent Distillation · Hybrid Policy Gradient</b></p>
[arXiv](https://arxiv.org/abs/2605.26108)
[GitHub](https://github.com/Harahan/RTDMD)
[Hugging Face Collection](https://huggingface.co/collections/Harahan/rtdmd)
[License: Apache-2.0](https://opensource.org/licenses/Apache-2.0)
[Python](https://www.python.org/)
</div>
<div align="center">
[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡
<sup>1</sup>The Hong Kong University of Science and Technology
<sup>2</sup>Tencent Hunyuan
<sup>3</sup>Westlake University
\* Equal contribution · † Work done during internship at Tencent Hunyuan · ‡ Corresponding author
</div>
---
## Abstract
We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
two-stage framework that unifies distribution-matching distillation with
reward-guided RL for few-step flow generators. The KL divergence to a
*reward-tilted teacher distribution* decomposes naturally into a
**distribution-matching** term and a **reward-maximization** term, instantiated
respectively as **Ambient-Consistent DMD (AC-DMD)** for the cold start and a **hybrid policy
gradient** (SubGRPO + final-step reward back-propagation) for the RL stage.
With **4 NFE**, RTDMD achieves new state-of-the-art results on SD3-M / SD3.5-M / FLUX.2 4B; the
distilled FLUX.2 4B even beats the full FLUX.2 9B teacher (50 NFE) on most
reward metrics.
<table align="center">
<tr>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/MLr2YHfmAvKYQsfA50uOf.png" alt="RTDMD teaser" width="100%">
<br/>
<em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
</td>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/UJnH2QqNCw4aJDFgwtfjP.png" alt="RTDMD comparison" width="100%">
<br/>
<em>Qualitative comparison for few-step diffusion models (4 NFE).</em>
</td>
</tr>
</table>
---
## Method Overview
<div align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/GSQ5Q9bF6SAiUqyFR4ZKs.png" alt="RTDMD method overview" width="70%">
<br/>
<em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
</div>
For the generator $G_\theta$, the reward-tilted KL objective decomposes as
$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
$$
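The decomposition follows in one line if the tilted teacher is assumed to take the standard exponential-tilting form $\tilde{p}_\psi(\mathbf{x}) = p_\psi(\mathbf{x})\exp(\beta\, r(\mathbf{x}))/Z$, where the normalizer $Z$ does not depend on $\theta$. This form is an assumption made here for illustration; it is the one consistent with the sign and the factor $\beta$ in the gradient above:

$$
\begin{aligned}
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
&= \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}\big[\log p_\theta(\hat{\mathbf{x}}_0) - \log p_\psi(\hat{\mathbf{x}}_0) - \beta\, r(\hat{\mathbf{x}}_0) + \log Z\big] \\
&= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)] + \log Z.
\end{aligned}
$$

Taking $\nabla_\theta$ removes the constant $\log Z$ and leaves exactly the distribution-matching and reward-maximization terms.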
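The CLI exposes one trainer per stage: the cold start optimizes the distribution-matching term alone, and the RL fine-tune adds the reward-maximization term on top of AC-DMD: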
| Stage | Trainer | Key knobs |
| --- | --- | --- |
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler `η = 0.9` |
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step BP + AC-DMD |
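The reward-tilted decomposition that both trainers build on can be sanity-checked numerically. The sketch below uses made-up discrete distributions and assumes the exponential-tilting form $\tilde{p}_\psi \propto p_\psi \exp(\beta r)$, which this card does not state explicitly; `kl` and `tilt` are hypothetical helpers, not part of the RTDMD codebase. It confirms that the KL to the tilted teacher equals the distribution-matching KL minus $\beta$ times the expected reward, up to the $\theta$-independent constant $\log Z$.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def tilt(q, r, beta):
    """Hypothetical tilted teacher q_tilde(x) ∝ q(x) * exp(beta * r(x)); returns (q_tilde, log Z)."""
    w = [qi * math.exp(beta * ri) for qi, ri in zip(q, r)]
    z = sum(w)
    return [wi / z for wi in w], math.log(z)


p_theta = [0.5, 0.3, 0.2]  # toy few-step generator distribution
p_psi = [0.2, 0.5, 0.3]    # toy teacher distribution
reward = [1.0, 0.0, 2.0]   # toy reward for each outcome
beta = 0.7

p_tilde, log_z = tilt(p_psi, reward, beta)

# Left-hand side: KL to the reward-tilted teacher
lhs = kl(p_theta, p_tilde)

# Right-hand side: distribution matching - beta * expected reward + constant
expected_reward = sum(pi * ri for pi, ri in zip(p_theta, reward))
rhs = kl(p_theta, p_psi) - beta * expected_reward + log_z

print(f"KL to tilted teacher:         {lhs:.6f}")
print(f"DM term - beta*E[r] + log Z:  {rhs:.6f}")
```

Because `log_z` depends only on the teacher and the reward, it vanishes under $\nabla_\theta$, which is why the gradient splits cleanly into the two terms.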
---
## Contents
This repository hosts the 4-NFE LoRA checkpoints distilled from
**Stable Diffusion 3.5 Medium** with [RTDMD](https://github.com/Harahan/RTDMD).
```
.
├── cold_start/
│   └── generator_ema.pt   # Stage-1 AC-DMD LoRA (4 NFE base)
└── rtdmd/
    └── generator_ema.pt   # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```
Each `generator_ema.pt` is a `state_dict` saved with `torch.save` that contains only LoRA
adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two adapters
are designed to be **stacked**: the cold-start LoRA distills SD3.5-M down to
4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
reward-tilted RL.
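A toy illustration of what stacking means for a single linear layer, assuming the two adapters compose in the standard LoRA way (each adds its own scaled low-rank delta `(alpha / r) * B @ A` to the frozen weight). The matrix sizes and values are made up, and pure-Python lists stand in for tensors. Keeping both adapters active is equivalent to merging the cold-start delta into the base weight and applying the RTDMD adapter on top, whereas loading the second adapter's weights over the first silently drops the cold-start delta.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scale(a, s):
    return [[s * x for x in row] for row in a]


def rand(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def max_diff(u, v):
    """Largest absolute difference between two column vectors."""
    return max(abs(a[0] - b[0]) for a, b in zip(u, v))


rng = random.Random(0)
d_out, d_in, rank, alpha = 6, 5, 2, 4  # toy sizes; alpha / rank = 2, as with r=32, alpha=64
s = alpha / rank

W = rand(d_out, d_in, rng)                              # frozen base weight
A1, B1 = rand(rank, d_in, rng), rand(d_out, rank, rng)  # toy "cold-start" adapter
A2, B2 = rand(rank, d_in, rng), rand(d_out, rank, rng)  # toy "RTDMD" adapter
x = rand(d_in, 1, rng)                                  # input column vector


def lora_forward(weight, adapters, inp):
    """y = W x + sum_i s * B_i (A_i x): each active adapter adds its own low-rank path."""
    y = matmul(weight, inp)
    for A, B in adapters:
        y = add(y, scale(matmul(B, matmul(A, inp)), s))
    return y


# Both adapters active at once
stacked = lora_forward(W, [(A1, B1), (A2, B2)], x)

# Merge the cold-start delta into W first, then apply the RTDMD adapter on top
W_merged = add(W, scale(matmul(B1, A1), s))
merged_then_applied = lora_forward(W_merged, [(A2, B2)], x)

# Overwriting a single adapter with the second checkpoint keeps only the RTDMD delta
overwritten = lora_forward(W, [(A2, B2)], x)

print("stacked vs merged-then-applied:", max_diff(stacked, merged_then_applied))
print("stacked vs overwritten:        ", max_diff(stacked, overwritten))
```

The first difference is at floating-point round-off level; the second is not, which is why the two checkpoints should be loaded as separate adapters (or merged in order) rather than into one shared adapter.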
---
## Usage
### Option 1: RTDMD inference CLI (recommended)
The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
run the CPS sampler for you:
```bash
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .
# Download this repo
huggingface-cli download Harahan/SD35M-RTDMD --local-dir ./ckpts/sd35m
# Run 4-NFE inference (single GPU)
python inference.py configs/inference/sd35m.yaml \
--override lora_paths='["./ckpts/sd35m/cold_start/generator_ema.pt","./ckpts/sd35m/rtdmd/generator_ema.pt"]' \
--override eval_reward=false \
--prompt "a cute cat sitting on a windowsill"
```
### Option 2: Plain diffusers
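```python
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from peft import LoraConfig, set_peft_model_state_dict

base = "stabilityai/stable-diffusion-3.5-medium"
pipe = StableDiffusion3Pipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# LoRA target modules and rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
]

# The checkpoints are stacked, so each one gets its own named adapter.
# Loading both into a single adapter would overwrite the cold-start weights.
# Assumption: checkpoint keys follow PEFT's adapter-name-free layout
# (e.g. "...attn.to_q.lora_A.weight"); adjust the keys if they carry a prefix.
for name in ["cold_start", "rtdmd"]:
    pipe.transformer.add_adapter(
        LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian"),
        adapter_name=name,
    )
    path = hf_hub_download("Harahan/SD35M-RTDMD", f"{name}/generator_ema.pt")
    state = torch.load(path, map_location="cpu", weights_only=True)
    result = set_peft_model_state_dict(pipe.transformer, state, adapter_name=name)
    # Fail loudly instead of silently ignoring keys that match no layer
    assert not result.unexpected_keys, f"unmatched keys in {name}: {result.unexpected_keys[:5]}"

# add_adapter activates only the newest adapter; re-enable both so their deltas are summed
pipe.transformer.set_adapter(["cold_start", "rtdmd"])

# 4-step sampling without CFG, using the pipeline's default flow-matching Euler
# scheduler (not CPS; see the note below)
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
```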
> **Note:** RTDMD is trained on the CPS (Coefficients-Preserving Sampling)
> scheduler with `η = 0.9`. Using the default Flow-Matching Euler scheduler
> will still produce reasonable samples at 4 NFE, but the RTDMD inference CLI
> is the only entry point that reproduces the paper numbers exactly.
---
## Citation
```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
year={2026},
eprint={2605.26108},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2605.26108},
}
```
---
## License
Apache 2.0, the same as the upstream
[RTDMD](https://github.com/Harahan/RTDMD) repo. The base model
[`stabilityai/stable-diffusion-3.5-medium`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
is governed by its own license; please review and comply with it separately.