---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
- style-transfer
- visual-autoregressive
- VAR
- VQ-VAE
- GRPO
- reinforcement-learning
- image-generation
datasets:
- DiffSynth-Studio/OmniStyle
- DiffSynth-Studio/ImagePulse-StyleTransfer
---

# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling

StyleVAR performs reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, it generates the stylized output autoregressively over 10 scales ($1{\times}1 \to 16{\times}16$, $L=680$ tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.

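To make that mechanism concrete, below is a minimal, self-contained sketch of a blended attention block in the same spirit: projections of the style and content features are mixed into the queries that attend over the target's own token history. The module and argument names (`BlendedCrossAttention`, pooled `style_feat` / `content_feat`) are illustrative assumptions, not the actual StyleVAR implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlendedCrossAttention(nn.Module):
    """Illustrative sketch only: style/content features contribute to the queries
    that attend over the target's autoregressive history."""
    def __init__(self, dim: int, cond_dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.q_target = nn.Linear(dim, dim)        # queries from the target stream
        self.q_style = nn.Linear(cond_dim, dim)    # style features -> query offset
        self.q_content = nn.Linear(cond_dim, dim)  # content features -> query offset
        self.kv = nn.Linear(dim, 2 * dim)          # keys/values from the AR history
        self.proj = nn.Linear(dim, dim)

    def forward(self, target, history, style_feat, content_feat):
        # target:  (B, L, dim)  tokens of the scale currently being generated
        # history: (B, T, dim)  previously generated multi-scale tokens
        # style_feat / content_feat: (B, 1, cond_dim) pooled conditioning features
        B, L, _ = target.shape
        q = self.q_target(target) + self.q_style(style_feat) + self.q_content(content_feat)
        k, v = self.kv(history).chunk(2, dim=-1)

        def split(x):  # (B, S, dim) -> (B, heads, S, head_dim)
            return x.view(B, x.shape[1], self.num_heads, -1).transpose(1, 2)

        out = F.scaled_dot_product_attention(split(q), split(k), split(v))
        out = out.transpose(1, 2).reshape(B, L, -1)
        return self.proj(out)
```
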
The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 – SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 – GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters ($r{=}256$) and Per-Action Normalization Weighting (PANW) to rebalance credit across the $256\times$ token imbalance between coarse and fine scales.

- **Paper / code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)

---

## Files in this repository

| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, $C_\text{vae}{=}32$, $V{=}4096$, $\text{ch}{=}160$). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | Plain state dict: no LoRA, no optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |

Both transformer checkpoints ship as plain state dicts under the `"model"` key; the LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.

---

## Quick start

```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```

Download the checkpoints:

```python
from huggingface_hub import hf_hub_download

REPO = "Senfier-LiqiJing/StyleVAR"
vae_path = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path = hf_hub_download(REPO, "StyleVAR_SFT.pth", local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth", local_dir="ckpt")
```

Or via the CLI:

```bash
hf download Senfier-LiqiJing/StyleVAR \
  vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
  --local-dir ckpt
```

Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):

```bash
python eval/infer_grpo.py \
  --vae_ckpt ckpt/vae_ch160v4096z32.pth \
  --sft_ckpt ckpt/StyleVAR-GRPO.pth \
  --sft_only \
  --num 8 --out grpo_infer_results.png
```

> `--sft_only` here means *"do not attach a LoRA adapter on top"*: the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.

Minimal Python usage:

```python
import torch
from models import build_vae_stylevar

device = torch.device("cuda")
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # 10-scale schedule, 1x1 -> 16x16

# Build the frozen VQ-VAE tokenizer and the StyleVAR transformer.
vae, model = build_vae_stylevar(
    device=device, patch_nums=patch_nums,
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    depth=20, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
    style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)

# Transformer weights live under the "model" key; LoRA deltas are already merged.
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
model.eval()

# content, style: (B, 3, 256, 256) tensors in [-1, 1]
with torch.no_grad():
    out = model.autoregressive_infer(
        B=content.size(0), style_img=style, content_img=content,
        top_k=900, top_p=0.96, g_seed=42,
    )
```

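The snippet above assumes you already have `content` and `style` batches. Continuing it, one straightforward way to produce `(B, 3, 256, 256)` tensors in `[-1, 1]` from image files is plain torchvision preprocessing; the resize/crop choices below are illustrative and not necessarily what the repository's own loaders use.

```python
from PIL import Image
import torchvision.transforms as T

to_model = T.Compose([
    T.Resize(256),
    T.CenterCrop(256),
    T.ToTensor(),                                # [0, 1]
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # -> [-1, 1]
])

content = to_model(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(device)
style = to_model(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(device)
```
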
---

## Training recipe

### Stage 1 – SFT

| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | $5\times10^{-4}$ (epochs 1–6) → $1\times10^{-4}$ (epochs 7–10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |

### Stage 2 – GRPO

| Setting | Value |
|---|---|
| LoRA | $r{=}256$, $\alpha/r{=}2$ on every attention ($W_Q^\text{target}$, $W_{QKV}^\text{cond}$, $W_\text{proj}$) and FFN linear; 131M trainable parameters (18.2% of the backbone) |
| Reward | $R = -\lambda \cdot \text{DreamSim}(\hat{x}, x)$, $\lambda{=}5.0$ |
| Group size $G$ | 16 |
| Sampling | top-$k{=}900$, top-$p{=}0.96$ |
| Clip / KL / PANW | $\varepsilon{=}0.2$, $\beta{=}0.1$, $\gamma{=}0.7$ |
| Optimizer | AdamW, lr $1\times10^{-5}$, wd 0.01, $(\beta_1, \beta_2){=}(0.9, 0.95)$, FP32 |
| Iterative merge | peak-triggered, $\tau_\text{gain}{=}0.05$ / $\tau_\text{patience}{=}50$ / 300-step cool-down |
| Emergency merge | raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, $G{=}16$ serial rollouts |

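To show how the hyperparameters in the table fit together, here is a simplified sketch of one GRPO loss computation with the stated reward, clipping, and KL terms. The PANW weighting shown (each token weighted by its scale's token count raised to $-\gamma$, so coarse scales keep meaningful credit) is only one plausible reading of the description above, and the function and variable names are ours; treat it as an illustration, not the repository's implementation.

```python
import torch

def grpo_panw_loss(logp_new, logp_old, logp_ref, dreamsim_dist,
                   scale_sizes=(1, 4, 9, 16, 25, 36, 64, 100, 169, 256),
                   lam=5.0, eps=0.2, beta=0.1, gamma=0.7):
    """logp_*: (G, L) per-token log-probs for G rollouts of one (content, style) pair,
    L = 680 tokens over 10 scales; dreamsim_dist: (G,) DreamSim distances to the target."""
    rewards = -lam * dreamsim_dist                             # R = -lambda * DreamSim
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)  # group-relative advantage

    # Assumed PANW weighting: damp fine scales so the 256 tokens of the 16x16 map
    # do not drown out the single 1x1 token (normalized to mean 1 over the sequence).
    w = torch.cat([torch.full((n,), float(n) ** (-gamma)) for n in scale_sizes])
    w = w / w.mean()

    # PPO-style clipped surrogate with the rollout-level advantage broadcast per token.
    ratio = (logp_new - logp_old).exp()
    surr = torch.minimum(ratio * adv[:, None], ratio.clamp(1 - eps, 1 + eps) * adv[:, None])
    policy_loss = -(w * surr).mean()

    # KL penalty against the frozen SFT reference policy (k3 estimator).
    log_r = logp_ref - logp_new
    kl = (log_r.exp() - log_r - 1).mean()
    return policy_loss + beta * kl
```
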
---

## Datasets

Both stages use a concatenation of two paired style-transfer datasets:

- **OmniStyle-150K** – 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer** – 137,886 triplets.

Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation + brightness augmentation to content images and random cropping to style images; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the $G$ samples in a group.

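For reference, an augmentation pipeline in the spirit of the SFT setup might look like the following; the rotation, brightness, and crop magnitudes are placeholders, since this card only names the augmentation types.

```python
import torchvision.transforms as T

# Placeholder magnitudes: the card states only "rotation + brightness" for content
# and "random cropping" for style; the exact values live in the training code.
content_aug = T.Compose([
    T.RandomRotation(degrees=10),
    T.ColorJitter(brightness=0.2),
    T.Resize(256), T.CenterCrop(256),
    T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3),
])
style_aug = T.Compose([
    T.RandomResizedCrop(256, scale=(0.8, 1.0)),
    T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3),
])
```
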
---

## Results

Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher $\uparrow$ / lower $\downarrow$ is better. **Best** in bold, <ins>second best</ins> underlined.

| Dataset | Method | Style Loss $\downarrow$ | Content Loss $\downarrow$ | LPIPS $\downarrow$ | SSIM $\uparrow$ | DreamSim $\downarrow$ | CLIP Sim $\uparrow$ | Infer (s) $\downarrow$ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
| | StyleVAR (SFT) | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse** | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
| | StyleVAR (SFT) | <ins>0.0452</ins> | **180.7923** | <ins>0.5618</ins> | <ins>0.4282</ins> | <ins>0.3168</ins> | <ins>0.7903</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0387** | <ins>182.0954</ins> | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
| | StyleVAR (SFT) | <ins>0.0206</ins> | <ins>160.1233</ins> | <ins>0.7398</ins> | **0.2713** | <ins>0.6986</ins> | <ins>0.5308</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0199** | **157.5109** | **0.7286** | <ins>0.2677</ins> | **0.6793** | **0.5335** | 0.4031 |

Inference time measured on a single NVIDIA A100 (40GB).

---

## Limitations

- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images; at 600M parameters the model partially memorizes this limited content pool.
- **Human faces.** Facial topology is structurally more sensitive and perceptually more scrutinized than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.

---

## Intended use

Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.

---

## Citation

```bibtex
@article{jing2026stylevar,
  title  = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  year   = {2026},
  note   = {Duke University}
}
```

## Acknowledgments

Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).