---
license: apache-2.0
language:
- en
base_model:
- stabilityai/stable-diffusion-3.5-medium
pipeline_tag: text-to-image
---
<div align="center">

<img width="70%" height="70%" alt="logo" src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/l1JM1Si5PDCgvJR5SSiqf.png" />

<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>

<p><b>Reward-Tilted DMD &nbsp;·&nbsp; Ambient-Consistent Distillation &nbsp;·&nbsp; Hybrid Policy Gradient</b></p>

[![Paper](https://img.shields.io/badge/paper-arXiv-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2605.26108)
[![Github](https://img.shields.io/badge/Harahan%2FRTDMD-000000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Harahan/RTDMD)
[![Hugging Face Collection](https://img.shields.io/badge/RTDMD_Collection-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/Harahan/rtdmd)

[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](https://www.python.org/)

</div>

<div align="center">

[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡

<sup>1</sup>The Hong Kong University of Science and Technology &nbsp;&nbsp;
<sup>2</sup>Tencent Hunyuan &nbsp;&nbsp;
<sup>3</sup>Westlake University

\* Equal contribution &nbsp;·&nbsp; † Work done during internship at Tencent Hunyuan &nbsp;·&nbsp; ‡ Corresponding author

</div>

---

## 📖 Abstract

We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
two-stage framework that unifies distribution-matching distillation with
reward-guided RL for few-step flow generators. Minimizing the KL divergence to
a *reward-tilted teacher distribution* decomposes naturally into a
**distribution-matching** term and a **reward-maximization** term; these are
instantiated as **Ambient-Consistent DMD (AC-DMD)** for the cold start and a
**hybrid policy gradient** (SubGRPO + final-step reward back-propagation) for
the RL stage. With **4 NFE**, RTDMD sets a new state of the art on SD3-M /
SD3.5-M / FLUX.2 4B; the distilled FLUX.2 4B even beats the full FLUX.2 9B
teacher (50 NFE) on most rewards.

<table align="center">
  <tr>
    <td align="center" width="50%">
      <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/MLr2YHfmAvKYQsfA50uOf.png" alt="RTDMD teaser" width="100%">
      <br/>
      <em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
    </td>
    <td align="center" width="50%">
      <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/UJnH2QqNCw4aJDFgwtfjP.png" alt="RTDMD comparison" width="100%">
      <br/>
      <em>Qualitative comparison for few-step diffusion models (4 NFE).</em>
    </td>
  </tr>
</table>

---

## 🍭 Method Overview


<div align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/GSQ5Q9bF6SAiUqyFR4ZKs.png" alt="RTDMD method overview" width="70%">
  <br/>
  <em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
</div>

For the generator $G_\theta$, the reward-tilted KL objective decomposes as

$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
$$
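
One way to see where the two terms come from is sketched below. It assumes the tilted teacher takes the usual exponential-tilting form with a normalizer $Z$ that does not depend on $\theta$; that form is an assumption here, since this card does not state the definition.

$$
\tilde{p}_\psi(\mathbf{x}) = \frac{1}{Z}\, p_\psi(\mathbf{x}) \exp\big(\beta\, r(\mathbf{x})\big)
\;\;\Longrightarrow\;\;
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)] + \log Z.
$$

Under that assumption, taking $\nabla_\theta$ removes the constant $\log Z$ and leaves the two-term gradient above.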

The two terms map directly to the two trainers exposed by the CLI:

| Stage | Trainer | Key knobs |
| --- | --- | --- |
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler `η = 0.9` |
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`)  | SubGRPO + final-step BP + AC-DMD |

---

## 📦 Contents

This repository hosts the 4-NFE LoRA checkpoints distilled from
**Stable Diffusion 3.5 Medium** with [RTDMD](https://github.com/Harahan/RTDMD).

```
.
├── cold_start/
│   └── generator_ema.pt    # Stage-1 AC-DMD LoRA (4 NFE base)
└── rtdmd/
    └── generator_ema.pt    # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```

Each `generator_ema.pt` is a `torch.save`-d `state_dict` containing only LoRA
adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two adapters
are designed to be **stacked**: the cold-start LoRA distills SD3.5-M down to
4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
reward-tilted RL.
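
A quick sanity check on a downloaded checkpoint can be sketched as follows. It assumes PEFT-style tensors, where `lora_A` weights have shape `(rank, in_features)` and `lora_B` weights have shape `(out_features, rank)`. The helper name `summarize_lora_shapes`, the key names, and the feature size in the demo are hypothetical, not taken from the checkpoints.

```python
from collections import Counter

LORA_ALPHA = 64  # alpha stated in this card


def summarize_lora_shapes(shapes):
    """Count LoRA keys and collect the ranks seen in a {key: shape} mapping.

    Assumes PEFT-style layout: lora_A is (rank, in_features),
    lora_B is (out_features, rank).
    """
    ranks = Counter()
    n_a = n_b = n_other = 0
    for key, shape in shapes.items():
        if "lora_A" in key:
            n_a += 1
            ranks[shape[0]] += 1
        elif "lora_B" in key:
            n_b += 1
            ranks[shape[1]] += 1
        else:
            n_other += 1
    return {"lora_A": n_a, "lora_B": n_b, "other": n_other, "ranks": sorted(ranks)}


# Synthetic stand-in for a real state_dict; key names and sizes are hypothetical.
demo = {
    "blocks.0.attn.to_q.lora_A.default.weight": (32, 1536),
    "blocks.0.attn.to_q.lora_B.default.weight": (1536, 32),
}
summary = summarize_lora_shapes(demo)
scale = LORA_ALPHA / summary["ranks"][0]  # standard LoRA scaling alpha / rank
print(summary, scale)
```

For a real file, the mapping could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`. With rank 32 and alpha 64, the standard LoRA scaling factor `alpha / rank` is 2.0.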

---

## 🚀 Usage

### Option 1: RTDMD inference CLI (recommended)

The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
run the CPS sampler for you:

```bash
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .

# Download this repo
huggingface-cli download Harahan/SD35M-RTDMD --local-dir ./ckpts/sd35m

# Run 4-NFE inference (single GPU)
python inference.py configs/inference/sd35m.yaml \
    --override lora_paths='["./ckpts/sd35m/cold_start/generator_ema.pt","./ckpts/sd35m/rtdmd/generator_ema.pt"]' \
    --override eval_reward=false \
    --prompt "a cute cat sitting on a windowsill"
```
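
The nested quoting in the `lora_paths` override is easy to get wrong by hand. As a minimal sketch, the same argument can be built from Python with `json.dumps` and passed as a list so that no shell quoting is involved. The helper name `lora_paths_override` is hypothetical; the flags and paths are the ones shown above.

```python
import json
import shlex


def lora_paths_override(paths):
    """Return the lora_paths override as compact JSON, matching the form above."""
    return "lora_paths=" + json.dumps(list(paths), separators=(",", ":"))


paths = [
    "./ckpts/sd35m/cold_start/generator_ema.pt",  # Stage 1: AC-DMD cold start
    "./ckpts/sd35m/rtdmd/generator_ema.pt",       # Stage 2: RTDMD
]
override = lora_paths_override(paths)

# Argument list mirroring the command above; run it from the RTDMD repo root.
cmd = [
    "python", "inference.py", "configs/inference/sd35m.yaml",
    "--override", override,
    "--override", "eval_reward=false",
    "--prompt", "a cute cat sitting on a windowsill",
]
print(shlex.join(cmd))  # shell-quoted form, for inspection only
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would launch the run without going through a shell, which also makes it straightforward to loop over prompts.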

### Option 2: Plain diffusers

```python
import torch
from diffusers import StableDiffusion3Pipeline
from peft import LoraConfig
from huggingface_hub import hf_hub_download

base = "stabilityai/stable-diffusion-3.5-medium"
pipe = StableDiffusion3Pipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# Inject LoRA adapters with the rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
]
pipe.transformer.add_adapter(
    LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
)

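# Load the cold-start weights, then the RTDMD weights, into the same adapter.
# strict=False is needed because each checkpoint holds only LoRA keys.
# load_state_dict overwrites tensors key by key, so for any key present in
# both checkpoints the RTDMD values are the ones that remain.
for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
    path = hf_hub_download("Harahan/SD35M-RTDMD", ckpt)
    state = torch.load(path, map_location="cpu", weights_only=False)
    result = pipe.transformer.load_state_dict(state, strict=False)
    # Under strict=False, keys that match no module are skipped silently; report them.
    if result.unexpected_keys:
        print(f"{ckpt}: {len(result.unexpected_keys)} keys did not match the adapter")

# 4-step sampling with the pipeline's default scheduler (not CPS; see the note below)
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")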
```

> **Note:** RTDMD is trained on the CPS (Coefficients-Preserving Sampling)
> scheduler with `η = 0.9`. Using the default Flow-Matching Euler scheduler
> will still produce reasonable samples at 4 NFE, but the RTDMD inference CLI
> is the only entry point that reproduces the paper numbers exactly.

---

## 📄 Citation

```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching}, 
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108}, 
}
```

---

## ⚖️ License

Apache 2.0, the same as the upstream
[RTDMD](https://github.com/Harahan/RTDMD) repo. The base model
[`stabilityai/stable-diffusion-3.5-medium`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
is governed by its own license; please review and comply with it separately.