---
license: apache-2.0
language:
- en
base_model:
- black-forest-labs/FLUX.2-klein-4B
pipeline_tag: text-to-image
---
<div align="center">

<img width="70%" height="70%" alt="logo" src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/l1JM1Si5PDCgvJR5SSiqf.png" />

<h2> Reinforcing Few-step Generators via Reward-Tilted Distribution Matching </h2>

<p><b>Reward-Tilted DMD · Ambient-Consistent Distillation · Hybrid Policy Gradient</b></p>

[arXiv](https://arxiv.org/abs/2605.26108)
[GitHub](https://github.com/Harahan/RTDMD)
[Hugging Face](https://huggingface.co/collections/Harahan/rtdmd)

[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
[Python](https://www.python.org/)

</div>

<div align="center">

[Yushi Huang](https://harahan.github.io/)<sup>1, 2,</sup>\*<sup>†</sup>, [Xiangxin Zhou](https://zhouxiangxin1998.github.io/)<sup>2,</sup>\*, Ruoyu Wang<sup>2, 3,</sup>\*<sup>†</sup>, [Chi Zhang](https://icoz69.github.io/)<sup>3</sup>, [Jun Zhang](https://eejzhang.people.ust.hk/)<sup>1</sup>, [Tianyu Pang](https://p2333.github.io/)<sup>2,</sup>‡

<sup>1</sup>The Hong Kong University of Science and Technology
<sup>2</sup>Tencent Hunyuan
<sup>3</sup>Westlake University

\* Equal contribution · † Work done during internship at Tencent Hunyuan · ‡ Corresponding author

</div>

---

## Abstract

We propose **Reward-Tilted Distribution Matching Distillation (RTDMD)**, a
two-stage framework that unifies distribution-matching distillation with
reward-guided reinforcement learning (RL) for few-step flow generators.
Minimizing the KL divergence to a *reward-tilted teacher distribution*
decomposes naturally into a **distribution-matching** term and a
**reward-maximization** term. We instantiate this objective as
**Ambient-Consistent DMD (AC-DMD)** for the cold-start stage and as a **hybrid
policy gradient** (SubGRPO plus final-step reward back-propagation) for the RL
stage. With only **4 NFE**, RTDMD sets a new state of the art on SD3-M,
SD3.5-M, and FLUX.2 4B; the distilled FLUX.2 4B even outperforms the full
FLUX.2 9B teacher (50 NFE) on most reward metrics.

<table align="center">
<tr>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/MLr2YHfmAvKYQsfA50uOf.png" alt="RTDMD teaser" width="100%">
<br/>
<em>4-step samples from the RTDMD-distilled FLUX.2 4B (no classifier-free guidance).</em>
</td>
<td align="center" width="50%">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/UJnH2QqNCw4aJDFgwtfjP.png" alt="RTDMD comparison" width="100%">
<br/>
<em>Qualitative comparison of few-step diffusion models (4 NFE).</em>
</td>
</tr>
</table>

---

## Method Overview

<div align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64b500fdf460afaefc5c64b3/GSQ5Q9bF6SAiUqyFR4ZKs.png" alt="RTDMD method overview" width="70%">
<br/>
<em>RTDMD overview. <b>Det.</b> = deterministic final step, <b>Stoc.</b> = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).</em>
</div>

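Here the reward-tilted teacher is $\tilde{p}_\psi(\mathbf{x}) \propto p_\psi(\mathbf{x})\exp(\beta\, r(\mathbf{x}))$, where $p_\psi$ is the teacher distribution, $r$ is the reward, and $\beta > 0$ sets the tilt strength. For the generator $G_\theta$ with output distribution $p_\theta$, the gradient of the reward-tilted KL objective decomposes as

$$
\nabla_\theta D_{\text{KL}}(p_\theta \| \tilde{p}_\psi) =
\underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \| p_\psi)}_{\text{distribution matching}} - \beta\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}.
$$

The normalizing constant of $\tilde{p}_\psi$ does not depend on $\theta$, so it drops out of the gradient.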
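As a quick sanity check of this decomposition, the identity $D_{\text{KL}}(p_\theta\|\tilde p_\psi) = D_{\text{KL}}(p_\theta\|p_\psi) - \beta\,\mathbb{E}_{p_\theta}[r] + \log Z$, with $Z = \mathbb{E}_{p_\psi}[e^{\beta r}]$, can be verified numerically on small discrete distributions. This is a toy sketch, not part of the RTDMD code. The distributions, rewards, and $\beta$ below are arbitrary illustrative values.

```python
import math

def kl(p, q):
    """KL divergence D_KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def tilt(q, r, beta):
    """Reward-tilted distribution q(x) * exp(beta * r(x)) / Z, plus log Z."""
    w = [qi * math.exp(beta * ri) for qi, ri in zip(q, r)]
    z = sum(w)
    return [wi / z for wi in w], math.log(z)

p_theta = [0.5, 0.3, 0.2]   # toy generator distribution
p_psi = [0.2, 0.5, 0.3]     # toy teacher distribution
reward = [1.0, -0.5, 2.0]   # toy reward r(x)
beta = 0.7

p_tilde, log_z = tilt(p_psi, reward, beta)
expected_reward = sum(pi * ri for pi, ri in zip(p_theta, reward))

lhs = kl(p_theta, p_tilde)                                   # KL to the tilted teacher
rhs = kl(p_theta, p_psi) - beta * expected_reward + log_z    # decomposed form
print(f"{lhs:.6f} {rhs:.6f}")
```

Because $\log Z$ is constant in $\theta$, differentiating both sides gives the gradient decomposition above.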

The two training stages correspond to the two trainers exposed by the RTDMD repository's CLI:

| Stage | Trainer | Key components |
| --- | --- | --- |
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight `γ`, CPS sampler with `η = 0.9` |
| 2. RTDMD RL fine-tuning | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step reward back-propagation + AC-DMD |
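
SubGRPO is the policy-gradient component of stage 2. Its sub-interval details are not described here, so the sketch below shows only the standard GRPO ingredient it is named after: group-relative advantages. Several samples are drawn for the same prompt, and each reward is normalized by the group's mean and standard deviation. The function name and `eps` value are illustrative assumptions, not the RTDMD implementation.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-6):
    """Standard GRPO-style advantages for one prompt's group of samples.

    Each reward is centered on the group mean and scaled by the group
    standard deviation (population std here), so better-than-average
    samples receive positive advantages.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Rewards for four samples generated from the same prompt (toy values)
adv = group_relative_advantages([0.62, 0.70, 0.55, 0.81])
print([round(a, 3) for a in adv])
```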

---

## Contents

This repository hosts the 4-NFE LoRA checkpoints distilled from **FLUX.2-klein 4B**
with [RTDMD](https://github.com/Harahan/RTDMD).

```
.
├── cold_start/
│   └── generator_ema.pt   # Stage-1 AC-DMD LoRA (4-NFE base)
└── rtdmd/
    └── generator_ema.pt   # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```

Each `generator_ema.pt` is a `state_dict` saved with `torch.save` that contains only LoRA
adapter keys (`lora_A` / `lora_B`, rank **32**, alpha **64**). The two adapters
are designed to be **stacked**: the cold-start LoRA distills FLUX.2-klein 4B
down to 4 NFE, and the RTDMD LoRA further fine-tunes that distilled model with
reward-tilted RL.
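
Under the standard PEFT LoRA convention, each adapter adds a low-rank update $\Delta W = \frac{\alpha}{r} B A$ to a frozen weight $W$. With rank 32 and alpha 64, the scaling factor is therefore 2. Because these updates are additive, activating both adapters at once yields the same weight as first merging the cold-start LoRA into the base model and then applying the RTDMD LoRA. The toy sketch below checks this with tiny random matrices. It uses rank 2 instead of 32 for readability, and the shapes and helper names are illustrative assumptions.

```python
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add_scaled(w, delta, scale=1.0):
    """Return w + scale * delta, elementwise."""
    return [[wi + scale * di for wi, di in zip(rw, rd)] for rw, rd in zip(w, delta)]

def rand_matrix(rows, cols, rng):
    """Matrix of standard-normal entries."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def max_abs_diff(x, y):
    """Largest elementwise absolute difference between two matrices."""
    return max(abs(a - b) for rx, ry in zip(x, y) for a, b in zip(rx, ry))

rng = random.Random(0)
d_out, d_in = 4, 3
rank, alpha = 2, 4          # toy values with the same alpha / r ratio as r=32, alpha=64
scale = alpha / rank        # PEFT's default LoRA scaling factor: lora_alpha / r

W = rand_matrix(d_out, d_in, rng)                                     # frozen base weight
A1, B1 = rand_matrix(rank, d_in, rng), rand_matrix(d_out, rank, rng)  # "cold-start" adapter
A2, B2 = rand_matrix(rank, d_in, rng), rand_matrix(d_out, rank, rng)  # "RTDMD" adapter

# (a) Both adapters active at once: W + scale * (B1 A1 + B2 A2)
both_active = add_scaled(W, add_scaled(matmul(B1, A1), matmul(B2, A2)), scale)
# (b) Merge the cold-start delta into W first, then apply the RTDMD delta
stacked = add_scaled(add_scaled(W, matmul(B1, A1), scale), matmul(B2, A2), scale)
# (c) Overwriting one adapter's weights with the other keeps only the RTDMD delta
overwritten = add_scaled(W, matmul(B2, A2), scale)

print(max_abs_diff(both_active, stacked), max_abs_diff(overwritten, stacked))
```

Suppose the RTDMD checkpoint is a separate delta on top of the cold-start model, as the stacking description above suggests. In that case, (a) and (b) agree up to floating-point error, while (c) shows that loading both checkpoints into a single adapter in sequence would not stack them.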

---

## Usage

### Option 1: RTDMD inference CLI (recommended)

The simplest path is to clone the RTDMD repository and let its inference script
stack both LoRAs and run the CPS sampler for you:

```bash
git clone https://github.com/Harahan/RTDMD.git && cd RTDMD
pip install -r requirements.txt && pip install -e .

# Download this model repository
huggingface-cli download Harahan/FLUX2-4B-RTDMD --local-dir ./ckpts/flux2_4b

# Run 4-NFE inference (single GPU)
python inference.py configs/inference/flux2_4b.yaml \
  --override lora_paths='["./ckpts/flux2_4b/cold_start/generator_ema.pt","./ckpts/flux2_4b/rtdmd/generator_ema.pt"]' \
  --override eval_reward=false \
  --prompt "a cute cat sitting on a windowsill"
```

### Option 2: Plain diffusers

```python
import torch
from diffusers import Flux2KleinPipeline
from huggingface_hub import hf_hub_download
from peft import LoraConfig

base = "black-forest-labs/FLUX.2-klein-4B"
repo = "Harahan/FLUX2-4B-RTDMD"
pipe = Flux2KleinPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# LoRA target modules and the rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
    "to_qkv_mlp_proj",
] + [f"single_transformer_blocks.{i}.attn.to_out" for i in range(20)]

def load_lora(adapter_name: str, filename: str) -> None:
    """Inject a named LoRA adapter and load one checkpoint into it."""
    pipe.transformer.add_adapter(
        LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian"),
        adapter_name=adapter_name,
    )
    state = torch.load(hf_hub_download(repo, filename), map_location="cpu", weights_only=True)
    # Assumption: keys use PEFT's default adapter name (".lora_A.default.weight");
    # rename them so they land in this adapter instead of overwriting another one.
    state = {k.replace(".default.", f".{adapter_name}."): v for k, v in state.items()}
    result = pipe.transformer.load_state_dict(state, strict=False)
    if result.unexpected_keys:
        raise KeyError(f"Unmatched LoRA keys, e.g. {result.unexpected_keys[:3]}")

# Stack the adapters: base weights + cold-start delta + RTDMD delta
load_lora("cold_start", "cold_start/generator_ema.pt")
load_lora("rtdmd", "rtdmd/generator_ema.pt")
pipe.transformer.set_adapters(["cold_start", "rtdmd"], weights=[1.0, 1.0])

# 4-step sampling with the pipeline's default scheduler (not CPS; see the note below)
image = pipe(prompt="a cute cat sitting on a windowsill",
             num_inference_steps=4, guidance_scale=1.0).images[0]
image.save("out.png")
```

> **Note:** RTDMD is trained with the CPS (Coefficients-Preserving Sampling)
> scheduler at `η = 0.9`. The default Flow-Matching Euler scheduler used by the
> plain diffusers path still produces reasonable samples at 4 NFE, but only the
> RTDMD inference CLI reproduces the paper's reported numbers exactly.
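
For intuition about `η`, here is a minimal scalar sketch of a coefficient-preserving stochastic step. It is written from our reading of CPS rather than from the RTDMD code, so treat the exact parameterization as an assumption. It assumes the rectified-flow interpolation $\mathbf{x}_t = (1-\sigma_t)\mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}$ with a velocity prediction $\hat{\mathbf{v}} \approx \boldsymbol{\epsilon} - \mathbf{x}_0$. The next noise level $\sigma_{\text{next}}$ is split between the predicted noise and fresh Gaussian noise via $\cos(\eta\pi/2)$ and $\sin(\eta\pi/2)$, so the squared noise coefficients always sum to $\sigma_{\text{next}}^2$. The step reduces to a deterministic Euler update when `η = 0`, and it is deterministic whenever $\sigma_{\text{next}} = 0$, which is consistent with the deterministic final step in the overview figure. The function name `cps_style_step` is hypothetical.

```python
import math
import random

def cps_style_step(x_t, v_hat, sigma, sigma_next, eta, noise):
    """One coefficient-preserving stochastic step (scalar sketch).

    Assumes x_t = (1 - sigma) * x0 + sigma * eps and a velocity prediction
    v_hat that estimates eps - x0.
    """
    x0_hat = x_t - sigma * v_hat            # predicted clean sample
    eps_hat = x_t + (1.0 - sigma) * v_hat   # predicted noise
    angle = eta * math.pi / 2
    # Squared noise coefficients sum to sigma_next**2 for any eta
    return ((1.0 - sigma_next) * x0_hat
            + sigma_next * math.cos(angle) * eps_hat
            + sigma_next * math.sin(angle) * noise)

rng = random.Random(0)
x, v = 1.0, 0.5
# eta = 0 recovers the deterministic Euler update x + (sigma_next - sigma) * v
print(cps_style_step(x, v, sigma=0.8, sigma_next=0.4, eta=0.0, noise=rng.gauss(0, 1)))
# eta = 0.9 (the value used for RTDMD training) injects fresh noise at intermediate steps
print(cps_style_step(x, v, sigma=0.8, sigma_next=0.4, eta=0.9, noise=rng.gauss(0, 1)))
# sigma_next = 0 makes the final step deterministic: it returns x0_hat
print(cps_style_step(x, v, sigma=0.4, sigma_next=0.0, eta=0.9, noise=rng.gauss(0, 1)))
```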

---

## Citation

```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108},
}
```

---

## License

Apache 2.0, the same as the upstream
[RTDMD](https://github.com/Harahan/RTDMD) repository. The base model
[`black-forest-labs/FLUX.2-klein-4B`](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B)
is governed by its own license; please review and comply with it separately.