---
base_model:
- black-forest-labs/FLUX.2-klein-4B
language:
- en
license: apache-2.0
pipeline_tag: text-to-image
library_name: diffusers
---

<div align="center">

<img width="70%" height="70%" alt="logo" src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/l1JM1Si5PDCgvJR5SSiqf.png" />

<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>

<p><b>Reward-Tilted DMD &nbsp;·&nbsp; Ambient-Consistent Distillation &nbsp;·&nbsp; Hybrid Policy Gradient</b></p>

[![Paper](https://img.shields.io/badge/paper-arXiv-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://huggingface.co/papers/2605.26108)
[![Github](https://img.shields.io/badge/Harahan%2FRTDMD-000000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Harahan/RTDMD)
[![Hugging Face Collection](https://img.shields.io/badge/RTDMD_Collection-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/Harahan/rtdmd)

[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](https://www.python.org/)

</div>

<div align="center">

[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡

<sup>1</sup>The Hong Kong University of Science and Technology &nbsp;&nbsp;
<sup>2</sup>Tencent Hunyuan &nbsp;&nbsp;
<sup>3</sup>Westlake University

\* Equal contribution &nbsp;·&nbsp; † Work done during internship at Tencent Hunyuan &nbsp;·&nbsp; ‡ Corresponding author

</div>

---

## 📖 Abstract

We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
two-stage framework that unifies distribution-matching distillation with
reward-guided RL for few-step flow generators. Minimizing the KL divergence to
a *reward-tilted teacher distribution* decomposes naturally into a
**distribution-matching** term and a **reward-maximization** term, instantiated
as **Ambient-Consistent DMD (AC-DMD)** for the cold start and a **hybrid policy
gradient** (SubGRPO + final-step reward back-propagation) for the RL stage.
With **4 NFE**, RTDMD reaches a new state of the art on SD3-M / SD3.5-M / FLUX.2 4B; the
distilled FLUX.2 4B even beats the full FLUX.2 9B teacher (50 NFE) on most
rewards.

More details can be found in the paper: [Reinforcing Few-step Generators via Reward-Tilted Distribution Matching](https://huggingface.co/papers/2605.26108).

<table align="center">
  <tr>
    <td align="center" width="50%">
      <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/MLr2YHfmAvKYQsfA50uOf.png" alt="RTDMD teaser" width="100%">
      <br/>
      <em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
    </td>
    <td align="center" width="50%">
      <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/UJnH2QqNCw4aJDFgwtfjP.png" alt="RTDMD comparison" width="100%">
      <br/>
      <em>Qualitative comparison for few-step diffusion models (4 NFE).</em>
    </td>
  </tr>
</table>

---

## 🍭 Method Overview

<div align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/GSQ5Q9bF6SAiUqyFR4ZKs.png" alt="RTDMD method overview" width="70%">
  <br/>
  <em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
</div>

For the generator $G_\theta$, the reward-tilted KL objective decomposes as

$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
$$

The two terms map directly to the two trainers exposed by the CLI:

| Stage | Trainer | Key knobs |
| --- | --- | --- |
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler `η = 0.9` |
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`)  | SubGRPO + final-step BP + AC-DMD |
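
The decomposition follows in one line under a standard exponential tilt. The form of $\tilde{p}_\psi$ written below, including the normalizer $Z$, is an assumption made for illustration; this card does not define it, so the paper remains the reference for the exact definition.

$$
\tilde{p}_\psi(\mathbf{x}) = \frac{1}{Z}\, p_\psi(\mathbf{x}) \exp\big(\beta\, r(\mathbf{x})\big)
\;\Longrightarrow\;
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)] + \log Z .
$$

Since $\log Z$ does not depend on $\theta$, taking the gradient with respect to $\theta$ leaves exactly the distribution-matching and reward-maximization terms.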

---

## 📦 Contents

This repository hosts the 4-NFE LoRA checkpoints distilled from **FLUX.2-klein 4B**
with [RTDMD](https://github.com/Harahan/RTDMD).

```
.
├── cold_start/
│   └── generator_ema.pt    # Stage-1 AC-DMD LoRA (4 NFE base)
└── rtdmd/
    └── generator_ema.pt    # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```

Each `generator_ema.pt` is a `torch.save`-d `state_dict` containing only LoRA
adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two adapters
are designed to be **stacked**: the cold-start LoRA distills FLUX.2-klein 4B
down to 4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
reward-tilted RL.
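
To sanity-check a downloaded checkpoint against the stated rank, a small helper along these lines can be used. It is a sketch, not part of the RTDMD codebase: `summarize_lora` is a hypothetical name, and it assumes the PEFT convention that `lora_A` weights have shape `(r, in_features)` and `lora_B` weights have shape `(out_features, r)`.

```python
from collections import Counter


def summarize_lora(state):
    """Summarize LoRA ranks in a state_dict-like mapping of name -> tensor.

    Only `.shape` is read, so any tensor-like object works.
    Assumes PEFT-style shapes: lora_A is (r, in), lora_B is (out, r).
    """
    ranks = Counter()
    other = []
    for name, tensor in state.items():
        if "lora_A" in name:
            ranks[tensor.shape[0]] += 1
        elif "lora_B" in name:
            ranks[tensor.shape[1]] += 1
        else:
            other.append(name)  # keys that are not LoRA adapter weights
    return {"ranks": dict(ranks), "non_lora_keys": other}
```

With PyTorch installed, calling `summarize_lora(torch.load(path, map_location="cpu"))` on either file should report a single rank of 32 and no non-LoRA keys if the checkpoint matches the description above. With alpha 64, the usual LoRA scaling factor `alpha / r` is 2.0.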

---

## 🚀 Usage

### Option 1: RTDMD inference CLI (recommended)

The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
run the CPS sampler for you:

```bash
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .

# Download this repo
huggingface-cli download Harahan/FLUX2-4B-RTDMD --local-dir ./ckpts/flux2_4b

# Run 4-NFE inference (single GPU)
python inference.py configs/inference/flux2_4b.yaml \
    --override lora_paths='["./ckpts/flux2_4b/cold_start/generator_ema.pt","./ckpts/flux2_4b/rtdmd/generator_ema.pt"]' \
    --override eval_reward=false \
    --prompt "a cute cat sitting on a windowsill"
```
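
When scripting the command above, the quoting of the `lora_paths` list is easy to get wrong. The sketch below builds the same argument vector in Python. `build_inference_argv` is a hypothetical helper, and it assumes only what the command above shows: that `--override` accepts `key=value` with a JSON-style list as the value.

```python
import json


def build_inference_argv(config, lora_paths, prompt):
    """Return an argv list mirroring the shell command shown above."""
    # Compact separators reproduce the list without spaces, as in the command.
    paths = json.dumps(list(lora_paths), separators=(",", ":"))
    return [
        "python", "inference.py", config,
        "--override", f"lora_paths={paths}",
        "--override", "eval_reward=false",
        "--prompt", prompt,
    ]


argv = build_inference_argv(
    "configs/inference/flux2_4b.yaml",
    [
        "./ckpts/flux2_4b/cold_start/generator_ema.pt",
        "./ckpts/flux2_4b/rtdmd/generator_ema.pt",
    ],
    "a cute cat sitting on a windowsill",
)
```

Passing `argv` to `subprocess.run(argv, check=True)` from the repository root avoids shell quoting altogether, since each element reaches the process as a single argument.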

### Option 2: Plain diffusers

```python
import torch
from diffusers import Flux2KleinPipeline
from huggingface_hub import hf_hub_download
from peft import LoraConfig

base = "black-forest-labs/FLUX.2-klein-4B"
pipe = Flux2KleinPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# Inject LoRA adapters with the rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
    "to_qkv_mlp_proj",
] + [f"single_transformer_blocks.{i}.attn.to_out" for i in range(20)]
pipe.transformer.add_adapter(
    LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
)

# Sequentially load cold-start then RTDMD weights into the same adapter
for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
    path = hf_hub_download("Harahan/FLUX2-4B-RTDMD", ckpt)
    state = torch.load(path, map_location="cpu", weights_only=False)
    pipe.transformer.load_state_dict(state, strict=False)

# 4-step sampling with the pipeline's default scheduler (see the note below on CPS)
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
```

> **Note:** RTDMD is trained on the CPS (Coefficients-Preserving Sampling)
> scheduler with `η = 0.9`. Using the default Flow-Matching Euler scheduler
> will still produce reasonable samples at 4 NFE, but the RTDMD inference CLI
> is the only entry point that reproduces the paper numbers exactly.
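
Because `load_state_dict(..., strict=False)` silently skips checkpoint keys that do not match the model, a mismatch in adapter naming would leave the adapter at its initial values without raising an error. A quick check on key names can catch this. The helper below is a sketch with a hypothetical name (`diff_keys`); it works on plain key collections, so it makes no assumption about how the checkpoints were exported.

```python
def diff_keys(model_keys, ckpt_keys):
    """Compare checkpoint key names against a model's state_dict keys."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    return {
        "matched": sorted(ckpt_keys & model_keys),
        # Present in the checkpoint but absent from the model: these would be
        # ignored under strict=False.
        "unexpected": sorted(ckpt_keys - model_keys),
    }
```

For example, `diff_keys(pipe.transformer.state_dict().keys(), state.keys())["unexpected"]` should be empty for each checkpoint. PyTorch reports the same information in the return value of `load_state_dict`, which exposes `missing_keys` and `unexpected_keys`. Keys that appear in both checkpoints take the values from whichever is loaded last.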

---

## 📄 Citation

```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching}, 
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108}, 
}
```