---
license: apache-2.0
language:
- en
base_model:
- black-forest-labs/FLUX.2-klein-4B
pipeline_tag: text-to-image
---
<div align="center">
<img width="70%" height="70%" alt="logo" src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/l1JM1Si5PDCgvJR5SSiqf.png" />
<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>
<p><b>Reward-Tilted DMD &nbsp;·&nbsp; Ambient-Consistent Distillation &nbsp;·&nbsp; Hybrid Policy Gradient</b></p>
[![Paper](https://img.shields.io/badge/paper-arXiv-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2605.26108)
[![Github](https://img.shields.io/badge/Harahan%2FRTDMD-000000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Harahan/RTDMD)
[![Hugging Face Collection](https://img.shields.io/badge/RTDMD_Collection-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/Harahan/rtdmd)
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](https://www.python.org/)
</div>
<div align="center">
[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡
<sup>1</sup>The Hong Kong University of Science and Technology &nbsp;&nbsp;
<sup>2</sup>Tencent Hunyuan &nbsp;&nbsp;
<sup>3</sup>Westlake University
\* Equal contribution &nbsp;·&nbsp; † Work done during internship at Tencent Hunyuan &nbsp;·&nbsp; ‡ Corresponding author
</div>
---
## 📖 Abstract
We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
two-stage framework that unifies distribution-matching distillation with
reward-guided RL for few-step flow generators. Minimizing the KL divergence to
a *reward-tilted teacher distribution* decomposes naturally into a
**distribution-matching** term and a **reward-maximization** term, instantiated
as **Ambient-Consistent DMD (AC-DMD)** for the cold start and as a **hybrid policy
gradient** (SubGRPO + final-step reward back-propagation) for the RL stage.
With only **4 NFE**, RTDMD achieves new state-of-the-art results on SD3-M,
SD3.5-M, and FLUX.2 4B; the distilled FLUX.2 4B even outperforms the full
FLUX.2 9B teacher (50 NFE) on most reward metrics.
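SubGRPO itself is only named here. As a hedged point of reference, the sketch below shows the standard GRPO-style group-relative advantage that its name suggests it builds on: several samples are generated for one prompt, each is scored by the reward, and the rewards are normalized within that group, so above-average samples receive positive advantages. The helper name is hypothetical, and whatever SubGRPO changes on top of this is not shown.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """GRPO-style advantages: normalize rewards within one prompt's sample group.

    Hypothetical helper for illustration; not taken from the RTDMD code base.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std
    # eps keeps the division finite when every sample in the group has the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


if __name__ == "__main__":
    # Rewards for four samples generated from the same prompt
    print(group_relative_advantages([0.2, 0.5, 0.5, 0.8]))
```

Normalizing within the group removes the need for a learned value baseline: the group mean plays that role.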
<table align="center">
<tr>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/MLr2YHfmAvKYQsfA50uOf.png" alt="RTDMD teaser" width="100%">
<br/>
<em>4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
</td>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/UJnH2QqNCw4aJDFgwtfjP.png" alt="RTDMD comparison" width="100%">
<br/>
<em>Qualitative comparison of few-step diffusion models (4 NFE).</em>
</td>
</tr>
</table>
---
## 🍭 Method Overview
<div align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/GSQ5Q9bF6SAiUqyFR4ZKs.png" alt="RTDMD method overview" width="70%">
<br/>
<em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
</div>
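For the generator $G_\theta$, whose samples follow $p_\theta$, let $p_\psi$ be the teacher distribution, $\tilde{p}_\psi$ its reward-tilted version, $r$ the reward, and $\beta$ the tilt strength. The gradient of the reward-tilted KL objective then decomposes as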
$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
$$
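As a worked step, assume the tilted teacher takes the standard exponential-tilting form $\tilde{p}_\psi(\mathbf{x}) = p_\psi(\mathbf{x})\, e^{\beta r(\mathbf{x})} / Z$ with a normalizer $Z$ that does not depend on $\theta$ (this is the form that produces the coefficient $\beta$ above). Expanding the KL divergence then gives

$$
\begin{aligned}
D_{\text{KL}}(p_\theta \| \tilde{p}_\psi)
&= \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}\big[\log p_\theta(\hat{\mathbf{x}}_0) - \log p_\psi(\hat{\mathbf{x}}_0) - \beta\, r(\hat{\mathbf{x}}_0) + \log Z\big] \\
&= D_{\text{KL}}(p_\theta \| p_\psi) - \beta\, \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)] + \log Z .
\end{aligned}
$$

Since $\log Z$ is constant in $\theta$, applying $\nabla_\theta$ leaves exactly the distribution-matching and reward-maximization terms of the decomposition.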
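Training proceeds in two stages, each exposed as a trainer in the RTDMD CLI. The cold start optimizes the distribution-matching term alone, and the RL stage adds the reward-maximization term on top of AC-DMD: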
| Stage | Trainer | Key knobs |
| --- | --- | --- |
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler `η = 0.9` |
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step BP + AC-DMD |
---
## 📦 Contents
This repository hosts the 4-NFE LoRA checkpoints distilled from **FLUX.2-klein 4B**
with [RTDMD](https://github.com/Harahan/RTDMD).
```
.
├── cold_start/
│   └── generator_ema.pt   # Stage-1 AC-DMD LoRA (4-NFE base)
└── rtdmd/
    └── generator_ema.pt   # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```
Each `generator_ema.pt` is a `state_dict` saved with `torch.save` that contains
only LoRA adapter weights (`lora_A` / `lora_B`, rank **32**, alpha **64**). The
two adapters are designed to be **stacked**: the cold-start LoRA distills
FLUX.2-klein 4B down to 4 NFE, and the RTDMD LoRA further fine-tunes that
distilled model with reward-tilted RL.
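To check what a downloaded checkpoint contains before wiring it into a pipeline, a small helper can count the LoRA pairs and recover the rank together with the effective scale $\alpha / r = 64 / 32 = 2$ that multiplies each low-rank update $\Delta W = (\alpha / r)\, BA$. This is a sketch that assumes the PEFT layout, where `lora_A` weights have shape `(r, in_features)` and `lora_B` weights have shape `(out_features, r)`; the helper name is hypothetical. It only needs values with a `.shape`, so it works directly on the tensors returned by `torch.load`.

```python
def summarize_lora(state, lora_alpha=64):
    """Summarize a LoRA-only state_dict (hypothetical helper; PEFT key layout assumed)."""
    # Matches both "...lora_A.weight" and "...lora_A.<adapter_name>.weight"
    a_keys = sorted(k for k in state if ".lora_A." in k)
    b_keys = sorted(k for k in state if ".lora_B." in k)
    if len(a_keys) != len(b_keys):
        raise ValueError(f"unpaired adapters: {len(a_keys)} lora_A vs {len(b_keys)} lora_B")
    # lora_A: (r, in_features) -> rank is dim 0; lora_B: (out_features, r) -> rank is dim 1
    ranks = {tuple(state[k].shape)[0] for k in a_keys}
    ranks |= {tuple(state[k].shape)[1] for k in b_keys}
    if len(ranks) != 1:
        raise ValueError(f"inconsistent ranks: {sorted(ranks)}")
    (rank,) = ranks
    return {"num_adapted_layers": len(a_keys), "rank": rank, "scale": lora_alpha / rank}
```

For example, `summarize_lora(torch.load(path, map_location="cpu", weights_only=False))` on either `generator_ema.pt` should report rank 32 and scale 2.0 if the checkpoint follows the settings stated above.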
---
## 🚀 Usage
### Option 1: RTDMD inference CLI (recommended)
The simplest path is to clone the RTDMD repo and let it stack both LoRAs and
run the CPS sampler for you:
```bash
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .
# Download this repo
huggingface-cli download Harahan/FLUX2-4B-RTDMD --local-dir ./ckpts/flux2_4b
# Run 4-NFE inference (single GPU)
python inference.py configs/inference/flux2_4b.yaml \
--override lora_paths='["./ckpts/flux2_4b/cold_start/generator_ema.pt","./ckpts/flux2_4b/rtdmd/generator_ema.pt"]' \
--override eval_reward=false \
--prompt "a cute cat sitting on a windowsill"
```
### Option 2: Plain diffusers
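```python
import re

import torch
from diffusers import Flux2KleinPipeline
from huggingface_hub import hf_hub_download
from peft import LoraConfig

base = "black-forest-labs/FLUX.2-klein-4B"
pipe = Flux2KleinPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# LoRA target modules used during training (rank 32, alpha 64)
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
    "to_qkv_mlp_proj",
] + [f"single_transformer_blocks.{i}.attn.to_out" for i in range(20)]


def rename_adapter(state, name):
    # Map "...lora_A.weight" or "...lora_A.<old_name>.weight" to "...lora_A.<name>.weight"
    pattern = r"\.(lora_[AB])(?:\.[^.]+)?\.weight$"
    return {re.sub(pattern, rf".\1.{name}.weight", k): v for k, v in state.items()}


# Stack the two LoRAs as separate adapters. Loading both into ONE adapter would
# overwrite the cold-start weights; keeping both active applies
# W + dW_cold_start + dW_rtdmd, the same as merging cold_start and adding rtdmd on top.
for name in ["cold_start", "rtdmd"]:
    pipe.transformer.add_adapter(
        LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian"),
        adapter_name=name,
    )
    path = hf_hub_download("Harahan/FLUX2-4B-RTDMD", f"{name}/generator_ema.pt")
    state = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False because the checkpoint holds only LoRA weights, not the base model
    result = pipe.transformer.load_state_dict(rename_adapter(state, name), strict=False)
    assert not result.unexpected_keys, f"unmatched keys, e.g. {result.unexpected_keys[:3]}"
pipe.transformer.set_adapter(["cold_start", "rtdmd"])  # activate both adapters

# 4-step sampling with the pipeline's default scheduler (not CPS; see the note below)
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
```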
> **Note:** RTDMD is trained with the CPS (Coefficients-Preserving Sampling)
> sampler at `η = 0.9`. The default flow-matching Euler scheduler still
> produces reasonable samples at 4 NFE, but only the RTDMD inference CLI
> reproduces the paper's numbers exactly.
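For intuition about what `η` controls, the sketch below implements one stochastic flow-matching step that keeps the noise coefficient at the target level $\sigma_{t'}$, which is how we understand coefficients-preserving sampling; it is an illustrative assumption, not the RTDMD sampler. It assumes the rectified-flow convention $\mathbf{x}_t = (1 - \sigma_t)\mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}$ with a velocity prediction $\mathbf{v} \approx \boldsymbol{\epsilon} - \mathbf{x}_0$, and splits the noise part between the predicted noise (weight $\cos(\eta\pi/2)$) and fresh Gaussian noise supplied by the caller (weight $\sin(\eta\pi/2)$). The function name is hypothetical; it works elementwise on floats, NumPy arrays, or torch tensors.

```python
import math


def cps_like_step(x_t, v_pred, sigma_t, sigma_next, eta, noise):
    """One coefficients-preserving stochastic step (illustrative sketch, not the RTDMD code).

    Assumes x_t = (1 - sigma_t) * x0 + sigma_t * eps and v_pred ~ eps - x0.
    `noise` is a fresh standard-Gaussian sample with the same shape as x_t.
    """
    x0_hat = x_t - sigma_t * v_pred           # predicted clean sample
    eps_hat = x_t + (1.0 - sigma_t) * v_pred  # predicted noise
    # Split the noise budget sigma_next between predicted and fresh noise;
    # cos^2 + sin^2 = 1 keeps the total noise coefficient at sigma_next.
    keep = math.cos(eta * math.pi / 2)
    fresh = math.sin(eta * math.pi / 2)
    return (1.0 - sigma_next) * x0_hat + sigma_next * (keep * eps_hat + fresh * noise)
```

With `eta = 0` the step reduces to the deterministic Euler update `x_t + (sigma_next - sigma_t) * v_pred`, and with `sigma_next = 0` it returns `x0_hat` directly, which matches the deterministic final step labeled **Det.** in the overview figure.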
---
## 📄 Citation
```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
year={2026},
eprint={2605.26108},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2605.26108},
}
```
---
## ⚖️ License
Apache 2.0, the same license as the upstream
[RTDMD](https://github.com/Harahan/RTDMD) repo. The base model
[`black-forest-labs/FLUX.2-klein-4B`](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B)
is governed by its own license; please review and comply with it separately.