---
base_model:
  - stabilityai/stable-diffusion-3.5-medium
language:
  - en
license: apache-2.0
pipeline_tag: text-to-image
library_name: diffusers
tags:
  - lora
  - flow-matching
  - distillation
  - stable-diffusion
  - rtdmd
---

Reinforcing Few-step Generators via Reward-Tilted Distribution Matching

Reward-Tilted DMD · Ambient-Consistent Distillation · Hybrid Policy Gradient



Yushi Huang<sup>1,2,*†</sup>, Xiangxin Zhou<sup>2,*</sup>, Ruoyu Wang<sup>2,3,*†</sup>, Chi Zhang<sup>3</sup>, Jun Zhang<sup>1</sup>, Tianyu Pang<sup>2,‡</sup>

<sup>1</sup>The Hong Kong University of Science and Technology    <sup>2</sup>Tencent Hunyuan    <sup>3</sup>Westlake University

* Equal contribution · † Work done during internship at Tencent Hunyuan · ‡ Corresponding author


📖 Abstract

This repository contains 4-NFE (four function evaluations, i.e. four sampling steps) LoRA checkpoints distilled from Stable Diffusion 3.5 Medium with the framework proposed in the paper Reinforcing Few-step Generators via Reward-Tilted Distribution Matching.

We propose Reward-Tilted Distribution Matching Distillation (RTDMD), a two-stage framework that unifies distribution-matching distillation with reward-guided reinforcement learning (RL) for few-step flow generators. Minimizing the KL divergence to a reward-tilted teacher distribution decomposes naturally into a distribution-matching term and a reward-maximization term. The framework is instantiated as Ambient-Consistent DMD (AC-DMD) for the cold-start stage and as a hybrid policy gradient (SubGRPO plus final-step reward back-propagation) for the RL stage.
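To make the decomposition concrete, here is a short derivation. It assumes the reward-tilted teacher takes the common exponential-tilting form with a temperature β, and uses the reverse KL that DMD-style objectives minimize; the exact form and weighting used in the paper may differ. With generator $p_\theta$, teacher $p_{\mathrm{T}}$, and reward $r$:

$$
\begin{aligned}
p^\star(x) &= \frac{1}{Z}\,p_{\mathrm{T}}(x)\,e^{r(x)/\beta}, \qquad Z = \int p_{\mathrm{T}}(x)\,e^{r(x)/\beta}\,dx, \\
\mathrm{KL}\big(p_\theta \,\|\, p^\star\big)
  &= \mathbb{E}_{x\sim p_\theta}\!\left[\log p_\theta(x) - \log p_{\mathrm{T}}(x) - \frac{r(x)}{\beta} + \log Z\right] \\
  &= \underbrace{\mathrm{KL}\big(p_\theta \,\|\, p_{\mathrm{T}}\big)}_{\text{distribution matching}}
     \;-\; \frac{1}{\beta}\,\underbrace{\mathbb{E}_{x\sim p_\theta}\big[r(x)\big]}_{\text{reward maximization}}
     \;+\; \log Z .
\end{aligned}
$$

Because $\log Z$ does not depend on $\theta$, minimizing the tilted KL is equivalent to minimizing the KL to the teacher while maximizing the expected reward scaled by $1/\beta$.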


🍭 Method Overview

RTDMD overview. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).

📦 Contents

The checkpoints are organized by training stage:

```text
.
├── cold_start/
│   └── generator_ema.pt    # Stage-1 AC-DMD LoRA (4 NFE base)
└── rtdmd/
    └── generator_ema.pt    # Stage-2 RTDMD LoRA (stacked on top of cold_start)
```

Each generator_ema.pt is a state_dict containing LoRA adapter keys (rank 32, alpha 64). The two adapters are designed to be stacked: the cold-start LoRA distills the model down to 4 NFE, and the RTDMD LoRA further fine-tunes it with reward-tilted RL.
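To check that a downloaded checkpoint matches the rank and alpha stated above, a small stdlib-only helper can pair up the LoRA A/B matrices of each module and read the rank off their shapes. This is a sketch that assumes PEFT-style key names (`<module>.lora_A[.<adapter>].weight`, with `lora_A` of shape `(r, in_features)` and `lora_B` of shape `(out_features, r)`); `summarize_lora_state` is a hypothetical helper, not part of the released code. The returned scaling is alpha / r, which for rank 32 and alpha 64 is 2.0.

```python
import re
from collections import defaultdict
from typing import Dict, Tuple

# Assumed PEFT-style key layout (the adapter-name segment is optional), e.g.
#   "...attn.to_q.lora_A.default.weight" -> shape (r, in_features)
#   "...attn.to_q.lora_B.default.weight" -> shape (out_features, r)
_LORA_KEY = re.compile(r"^(?P<module>.+)\.lora_(?P<side>[AB])(?:\.[^.]+)?\.weight$")


def summarize_lora_state(
    shapes: Dict[str, Tuple[int, ...]], alpha: float = 64.0
) -> Tuple[Dict[str, int], Dict[int, float]]:
    """Return ({module: rank}, {rank: alpha / rank}) from a LoRA state dict's shapes."""
    pairs: Dict[str, Dict[str, Tuple[int, ...]]] = defaultdict(dict)
    for key, shape in shapes.items():
        m = _LORA_KEY.match(key)
        if m is None:
            continue  # not a LoRA A/B weight (e.g. a base weight or a bias)
        pairs[m.group("module")][m.group("side")] = tuple(shape)

    ranks: Dict[str, int] = {}
    for module, sides in pairs.items():
        if set(sides) != {"A", "B"}:
            raise ValueError(f"{module}: incomplete LoRA pair {sorted(sides)}")
        # lora_A maps in_features -> r, lora_B maps r -> out_features
        rank_a, rank_b = sides["A"][0], sides["B"][-1]
        if rank_a != rank_b:
            raise ValueError(f"{module}: rank mismatch ({rank_a} vs {rank_b})")
        ranks[module] = rank_a

    # PEFT-style LoRA scales the low-rank update B @ A by alpha / r
    scaling = {r: alpha / r for r in sorted(set(ranks.values()))}
    return ranks, scaling
```

For a real checkpoint, pass the shapes from the `state` dict loaded in the diffusers example, e.g. `summarize_lora_state({k: tuple(v.shape) for k, v in state.items()})`.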


🚀 Usage

Option 1: RTDMD inference CLI (recommended)

For exact reproduction of the paper numbers, please use the official RTDMD repository.

Option 2: Plain diffusers

You can use these LoRAs with the diffusers library (the snippet also requires peft) as follows:

```python
import warnings

import torch
from diffusers import StableDiffusion3Pipeline
from peft import LoraConfig
from huggingface_hub import hf_hub_download

base = "stabilityai/stable-diffusion-3.5-medium"
pipe = StableDiffusion3Pipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda")

# Inject LoRA adapters with the rank/alpha used during training
TARGETS = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
]
pipe.transformer.add_adapter(
    LoraConfig(r=32, lora_alpha=64, target_modules=TARGETS, init_lora_weights="gaussian")
)

# Load cold-start weights, then RTDMD weights, into the same adapter
# (keys present in both checkpoints end up with the RTDMD values)
for ckpt in ["cold_start/generator_ema.pt", "rtdmd/generator_ema.pt"]:
    path = hf_hub_download("Harahan/SD35M-RTDMD", ckpt)
    state = torch.load(path, map_location="cpu", weights_only=False)
    result = pipe.transformer.load_state_dict(state, strict=False)
    # strict=False silently skips keys that match no parameter; surface them
    if result.unexpected_keys:
        warnings.warn(f"{ckpt}: {len(result.unexpected_keys)} keys were not loaded")

# 4-step sampling; guidance_scale=1.0 disables classifier-free guidance
# Note: RTDMD is trained with the CPS scheduler (η = 0.9);
# the default Flow-Matching Euler scheduler still produces reasonable samples.
pipe(prompt="a cute cat sitting on a windowsill",
     num_inference_steps=4, guidance_scale=1.0).images[0].save("out.png")
```
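The scheduler note refers to the default flow-matching Euler sampler. As a rough, stdlib-only illustration of what four Euler steps do, the sketch below builds an evenly spaced noise schedule, applies an SD3-style timestep shift σ' = sσ / (1 + (s - 1)σ), and integrates x ← x + (σ_next - σ)·v. The shift value of 3.0 is an assumption here (check the base model's scheduler config), diffusers' exact sigma spacing may differ, and this is not the CPS scheduler used for training. `flow_sigmas`, `euler_sample`, and the toy velocity field are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def flow_sigmas(num_steps: int, shift: float = 3.0) -> List[float]:
    """Evenly spaced noise levels from 1 down to 0, warped by an SD3-style shift."""
    base = [1.0 - i / num_steps for i in range(num_steps + 1)]
    # shift > 1 places more steps at high noise levels; sigma = 0 and 1 stay fixed
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]


def euler_sample(velocity: Callable[[Vector, float], Vector],
                 x: Vector, sigmas: List[float]) -> Vector:
    """Integrate dx/dsigma = v(x, sigma) from sigmas[0] down to sigmas[-1]."""
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity(x, s)
        # s_next < s, so each step moves x against v, toward the data end (sigma = 0)
        x = [xi + (s_next - s) * vi for xi, vi in zip(x, v)]
    return x


if __name__ == "__main__":
    # Toy check: for a single data point d and the interpolation
    # x_sigma = (1 - sigma) * d + sigma * noise, the exact velocity is
    # (x - d) / sigma, so Euler follows the straight path and lands on d.
    d = [0.5, -1.0]
    noise = [1.2, 0.3]
    sigmas = flow_sigmas(4)
    out = euler_sample(lambda x, s: [(xi - di) / s for xi, di in zip(x, d)], noise, sigmas)
    print([round(s, 3) for s in sigmas], [round(o, 6) for o in out])
```

The toy field has perfectly straight trajectories, which is why a handful of Euler steps recovers the target exactly; a real model's trajectories are only approximately straight, which is what few-step distillation tries to exploit.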

📄 Citation

```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108},
}
```

⚖️ License

Apache 2.0. The base model stabilityai/stable-diffusion-3.5-medium is governed by its own license; please review and comply with it separately.