# UniDG-RFT-LoRA
LoRA weights for UniDG (Universal Defect Generation), trained via Consistency-RFT with Flow-GRPO and dual reward models on the UDG dataset (300K quadruplets).
[Paper] [Code] [UniDG-SFT-LoRA]
## Overview
UniDG is a universal defect generation foundation model that transfers defects from a reference image to a target region via Defect-Context Editing and MM-DiT multimodal attention, without per-category fine-tuning. This checkpoint is the Consistency-RFT variant, further refined from UniDG-SFT using Flow-GRPO with dual reward models (Defect-Und-Reward & Defect-Recog-Reward) for improved defect fidelity and consistency.
| Variant | Training | Focus |
|---|---|---|
| UniDG-SFT | Diversity-SFT with complementary sampling | Diverse defect patterns |
| UniDG-RFT (this) | Consistency-RFT with Flow-GRPO + dual rewards | Consistent & faithful defects |
## Important: Usage Difference from UniDG-SFT-LoRA

The UniDG-RFT-LoRA weights are stored in PEFT format (`adapter_model.safetensors` + `adapter_config.json`), which differs from UniDG-SFT-LoRA (which uses `pytorch_lora_weights.safetensors`). This means:

- UniDG-SFT-LoRA can be loaded directly via the `lora_weights_path` parameter in `ImageUniDG`.
- UniDG-RFT-LoRA must first be merged into the base SFT model using the provided `combine_peft_weights.py` script. After merging, the resulting model can be loaded directly without any additional LoRA loading step.
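The two formats can be told apart by their filenames alone. A minimal sketch of such a check (the helper `detect_lora_format` is hypothetical, not part of the UniDG codebase):

```python
from pathlib import Path

def detect_lora_format(repo_dir: str) -> str:
    """Guess a LoRA repo's format from the files it contains."""
    repo = Path(repo_dir)
    if (repo / "pytorch_lora_weights.safetensors").is_file():
        # Diffusers-style weights (UniDG-SFT-LoRA): loadable via lora_weights_path
        return "diffusers"
    if (repo / "adapter_model.safetensors").is_file() and (repo / "adapter_config.json").is_file():
        # PEFT-style weights (UniDG-RFT-LoRA): must be merged first
        return "peft"
    return "unknown"
```

Running this on a downloaded repository tells you which of the two loading paths above applies.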
## Repository Contents

| File | Description |
|---|---|
| `adapter_model.safetensors` | PEFT LoRA weights (Consistency-RFT) |
| `adapter_config.json` | LoRA configuration (rank=64, alpha=128) |
| `combine_peft_weights.py` | Script to merge the RFT LoRA into the base SFT model |
## Step-by-Step Usage

### Prerequisites

- FLUX.1-Fill-dev (inpainting backbone)
- FLUX.1-Redux-dev (reference conditioning)
- UniDG-SFT-LoRA (base SFT model — the RFT LoRA is fine-tuned on top of this)
- UniDG code (inference framework)
- Python dependencies: `diffusers`, `peft`, `torch`
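Before starting, it can help to verify that the required packages are importable. A small stdlib-only sketch (package names taken from the list above):

```python
import importlib.util

# Packages required by the merge script and the UniDG inference code
required = ["diffusers", "peft", "torch"]
missing = [name for name in required if importlib.util.find_spec(name) is None]

if missing:
    print(f"Missing packages: {', '.join(missing)} -- install them before proceeding")
else:
    print("All required packages are importable")
```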
### Step 1: Prepare the Base SFT Model
First, you need a base FLUX.1-Fill-dev model with UniDG-SFT-LoRA weights already merged in. If you haven't done this, you can prepare it by loading the SFT model and saving the merged weights:
```python
from diffusers import FluxFillPipeline
import torch

# Load the base FLUX.1-Fill-dev pipeline
pipe = FluxFillPipeline.from_pretrained(
    "path/to/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)

# Load the SFT LoRA weights
pipe.load_lora_weights("path/to/UniDG-SFT-LoRA-Release/pytorch_lora_weights.safetensors")

# Fuse the LoRA into the base weights so the saved checkpoint is self-contained
pipe.fuse_lora()
pipe.unload_lora_weights()

# Save the merged SFT model as the base for RFT merging
pipe.save_pretrained("path/to/FLUX.1-Fill-dev-UDG-SFT", safe_serialization=True, max_shard_size="5GB")
```
### Step 2: Merge RFT LoRA into the Base SFT Model
Use the provided `combine_peft_weights.py` script to merge the RFT LoRA weights into the base SFT model:
```shell
python combine_peft_weights.py \
    --base_model_path path/to/FLUX.1-Fill-dev-UDG-SFT \
    --lora_weights_path path/to/UniDG-RFT-LoRA-Release \
    --output_path path/to/FLUX.1-Fill-dev-UDG-RFT \
    --save_full_pipeline
```
Parameters:

- `--base_model_path`: Path to the base SFT model (from Step 1)
- `--lora_weights_path`: Path to this RFT LoRA repository (containing `adapter_model.safetensors` and `adapter_config.json`)
- `--output_path`: Output path for the merged model
- `--save_full_pipeline`: Save the full pipeline (including VAE, text encoder, etc.) so you can load it directly later
- `--dtype`: Data type, default `bfloat16`
- `--device`: Device for loading, default `cpu` (recommended to avoid OOM)
Tip: Use `--device cpu` (the default) to save GPU memory during the merge. The merge only needs to run once.
### Step 3: Use the Merged Model with UniDG
After merging, the model can be used directly with the UniDG inference code — no additional LoRA loading is needed:
```python
from unidg import ImageUniDG
from PIL import Image
import torch

# Load the merged RFT model; set lora_weights_path="" since the LoRA is already merged
model = ImageUniDG(
    flux_model_path="path/to/FLUX.1-Fill-dev-UDG-RFT",
    redux_model_path="path/to/FLUX.1-Redux-dev",
    lora_weights_path="",  # no additional LoRA needed
    device="cuda:0",
    dtype=torch.bfloat16,
)

result, mask = model.process_images(
    target_image=Image.open("target.jpg"),
    reference_image=Image.open("reference.jpg"),
    reference_mask=Image.open("reference_mask.png"),
    target_mask=Image.open("target_mask.png"),
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=42,
)
result.save("result.png")
```
## Quick Reference: SFT vs RFT Usage
| | UniDG-SFT | UniDG-RFT |
|---|---|---|
| Weight format | `pytorch_lora_weights.safetensors` | `adapter_model.safetensors` + `adapter_config.json` |
| Merge required? | No | Yes (with SFT base model) |
| `lora_weights_path` | Path to SFT weights | `""` (empty, after merge) |
| `flux_model_path` | `path/to/FLUX.1-Fill-dev` | `path/to/merged-RFT-model` |
| Load time | LoRA loaded on-the-fly | Pre-merged, no LoRA overhead |
## LoRA Configuration
| Parameter | Value |
|---|---|
| PEFT type | LORA |
| Rank (r) | 64 |
| Alpha | 128 |
| Dropout | 0.0 |
| Target modules | `ff.net.0.proj`, `ff.net.2`, `ff_context.net.0.proj`, `proj_mlp`, `attn.to_q`, `attn.to_v`, `attn.to_add_out`, `attn.add_k_proj`, `attn.add_v_proj`, `ff_context.net.2`, `attn.add_q_proj`, `attn.to_out.0`, `attn.to_k` |
| Base model | FLUX.1-Fill-dev + UniDG-SFT-LoRA |
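From the configuration above, the effective LoRA scaling is alpha / r = 128 / 64 = 2. A toy sketch of what merging such an adapter does to a single weight matrix (tiny shapes and made-up values in pure Python; the real merge applies this to every target module listed above):

```python
# Toy LoRA merge: W_merged = W + (alpha / r) * (B @ A)
# Real shapes are (out, in), (out, r), (r, in); here we use 2x2, 2x1, 1x2.
r, alpha = 64, 128
scaling = alpha / r  # 2.0 for this checkpoint

W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight
B = [[0.5], [0.25]]            # low-rank factor B (out x rank)
A = [[0.1, 0.2]]               # low-rank factor A (rank x in)

# W + scaling * B @ A, written out elementwise
W_merged = [
    [W[i][j] + scaling * sum(B[i][k] * A[k][j] for k in range(len(A)))
     for j in range(2)]
    for i in range(2)
]
print(W_merged)
```

After this fold-in, the merged matrix behaves like an ordinary dense weight, which is why the merged RFT model needs no separate adapter at inference time.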
## Citation
```bibtex
@article{fan2026unidg,
  title={Large-Scale Universal Defect Generation: Foundation Models and Datasets},
  author={Fan, Yuanting and Liu, Jun and Gao, Bin-Bin and Chen, Xiaochen and Lin, Yuhuan and Dai, Zhewei and Zhan, Jiawei and Wang, Chengjie},
  journal={arXiv preprint arXiv:2604.08915},
  year={2026}
}
```