---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
- style-transfer
- visual-autoregressive
- VAR
- VQ-VAE
- GRPO
- reinforcement-learning
- image-generation
datasets:
- DiffSynth-Studio/OmniStyle
- DiffSynth-Studio/ImagePulse-StyleTransfer
---
# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling
Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales (1×1 → 16×16, L=680 tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
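To make the mechanism concrete, here is a minimal PyTorch sketch of the blended cross-attention idea under stated assumptions: module and parameter names (`q_style`, `q_content`, the sigmoid gate) are illustrative, not the repository's actual identifiers, and the real layer sits inside standard VAR transformer blocks.

```python
import torch
import torch.nn as nn

class BlendedCrossAttention(nn.Module):
    """Sketch: style and content features form queries; the target's own
    autoregressive token history supplies keys and values. The gating scheme
    here is an illustrative assumption, not the repo's exact implementation."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.q_style = nn.Linear(dim, dim)        # queries from style features
        self.q_content = nn.Linear(dim, dim)      # queries from content features
        self.gate = nn.Parameter(torch.zeros(1))  # learned blend weight

    def forward(self, style_feat, content_feat, ar_history):
        # style_feat, content_feat: (B, N, C); ar_history: (B, L_so_far, C)
        s_out, _ = self.attn(self.q_style(style_feat), ar_history, ar_history)
        c_out, _ = self.attn(self.q_content(content_feat), ar_history, ar_history)
        g = torch.sigmoid(self.gate)
        # texture signal from the style branch, structure from the content branch
        return g * s_out + (1 - g) * c_out
```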
The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 – SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 – GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters (rank r=256) and Per-Action Normalization Weighting (PANW) to rebalance credit across the 256× token imbalance between coarse and fine scales.
- **Code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Paper (arXiv):** https://arxiv.org/abs/2604.21052
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)
---
## Files in this repository
| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, C_vae=32, V=4096, ch=160). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | State dict with optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |
Both transformer checkpoints ship as plain state dicts under the `"model"` key – LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.
---
## Quick start
```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```
Download the checkpoints:
```python
from huggingface_hub import hf_hub_download
REPO = "Senfier-LiqiJing/StyleVAR"
vae_path = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path = hf_hub_download(REPO, "StyleVAR_SFT.pth", local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth", local_dir="ckpt")
```
Or via the CLI:
```bash
hf download Senfier-LiqiJing/StyleVAR \
vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
--local-dir ckpt
```
Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):
```bash
python eval/infer_grpo.py \
--vae_ckpt ckpt/vae_ch160v4096z32.pth \
--sft_ckpt ckpt/StyleVAR-GRPO.pth \
--sft_only \
--num 8 --out grpo_infer_results.png
```
> `--sft_only` here means *"do not attach a LoRA adapter on top"* – the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.
Minimal Python usage:
```python
import torch
from models import build_vae_stylevar
device = torch.device("cuda")
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
vae, model = build_vae_stylevar(
device=device, patch_nums=patch_nums,
V=4096, Cvae=32, ch=160, share_quant_resi=4,
depth=20, shared_aln=False, attn_l2_norm=True,
flash_if_available=True, fused_if_available=True,
init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
model.eval()
# content, style: (B, 3, 256, 256) tensors in [-1, 1]
with torch.no_grad():
out = model.autoregressive_infer(
B=content.size(0), style_img=style, content_img=content,
top_k=900, top_p=0.96, g_seed=42,
)
```
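The `content` and `style` tensors above must be supplied by the caller. A minimal sketch of that preprocessing, continuing from the snippet above (file names are placeholders; the normalization follows the `[-1, 1]` range noted in the comment):

```python
from PIL import Image
import torchvision.transforms as T

to_model = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),            # PIL image -> float tensor in [0, 1]
    T.Normalize(0.5, 0.5),   # [0, 1] -> [-1, 1]
])

content = to_model(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(device)
style = to_model(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(device)
```

If `out` comes back as an image tensor in `[0, 1]` (check the repo's inference script for the exact decode step), `torchvision.utils.save_image(out, "stylized.png")` writes it to disk.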
---
## Training recipe
### Stage 1 – SFT
| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | 5×10⁻⁴ (epochs 1–6) → 1×10⁻⁴ (epochs 7–10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |
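For reference, the Stage 1 objective is plain teacher-forced cross-entropy over the frozen VQ-VAE's token indices for the ground-truth target, as in vanilla VAR. A minimal sketch (function and variable names are illustrative):

```python
import torch.nn.functional as F

def sft_loss(logits, gt_idx):
    # logits: (B, 680, 4096) from the transformer, one prediction per token
    # across all 10 scales; gt_idx: (B, 680) long tensor of ground-truth VQ
    # indices for the target image, produced by the frozen multi-scale VQ-VAE.
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), gt_idx.reshape(-1))
```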
### Stage 2 – GRPO
| Setting | Value |
|---|---|
| LoRA | rank r=256, α/r=2 on every attention (W_Q_target, W_QKV_cond, W_proj) and FFN linear → 131M trainable (18.2% of backbone) |
| Reward | R = −λ · DreamSim(x̂, x), λ=5.0 |
| Group size G | 16 |
| Sampling | top-k=900, top-p=0.96 |
| Clip / KL / PANW | ε=0.2, β=0.1, γ=0.7 |
| Optimizer | AdamW, lr 1×10⁻⁵, wd 0.01, (β₁, β₂)=(0.9, 0.95), FP32 |
| Iterative merge | peak-triggered, τ_gain=0.05 / τ_patience=50 / 300-step cool-down |
| Emergency merge | raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, G=16 serial rollouts |
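The sketch below shows how the pieces in this table fit together: G rollouts per (content, style) pair are scored by the DreamSim reward, advantages are normalized within the group, and PANW reweights per-token credit so the 256 tokens of the finest scale do not drown out the single token of the coarsest. The PANW form shown (n_k^(−γ) with renormalization) is one plausible instantiation, not necessarily the paper's exact formula.

```python
import torch

patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
tokens_per_scale = torch.tensor([p * p for p in patch_nums], dtype=torch.float)  # 1 ... 256

def panw_weights(gamma: float = 0.7):
    # Plausible PANW: down-weight each scale in proportion to its token count
    # (n_k ** -gamma), then renormalize so weights average to 1 over 680 positions.
    w = tokens_per_scale ** (-gamma)                                  # (10,) per scale
    per_token = torch.repeat_interleave(w, tokens_per_scale.long())   # (680,)
    return per_token * per_token.numel() / per_token.sum()

def grpo_loss(logp_new, logp_old, rewards, kl, eps=0.2, beta=0.1):
    # logp_new / logp_old: (G, 680) token log-probs under current / rollout policy;
    # rewards: (G,) = -lambda * DreamSim(x_hat, x); kl: (G, 680) per-token KL to reference.
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)  # group-relative advantage
    ratio = (logp_new - logp_old).exp()
    w = panw_weights()                                         # (680,) PANW credit weights
    surr = torch.minimum(ratio * adv[:, None],
                         ratio.clamp(1 - eps, 1 + eps) * adv[:, None])  # clipped surrogate
    return -(w * surr).mean() + beta * (w * kl).mean()
```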
---
## Datasets
Both stages use a concatenation of two paired style-transfer datasets:
- **OmniStyle-150K** – 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer** – 137,886 triplets.
Combined, they form a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation and brightness jitter to content images and random crops to style images; GRPO rollouts use no augmentation, so the conditioning signal stays deterministic across the G samples in a group.
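A hedged sketch of the assembly (`TripletFolder` is a placeholder for the repo's actual dataset classes; zero tensors stand in for real image loading so the sketch stays runnable):

```python
import torch
from torch.utils.data import ConcatDataset, Dataset, random_split

class TripletFolder(Dataset):
    """Placeholder loader: each item is a (content, style, target) triplet."""
    def __init__(self, root: str, n: int):
        self.root, self.n = root, n
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        # real code would load three images here; zeros keep the sketch runnable
        return tuple(torch.zeros(3, 256, 256) for _ in range(3))

full = ConcatDataset([TripletFolder("data/omnistyle", 143_992),
                      TripletFolder("data/imagepulse", 137_886)])
n_val = int(0.05 * len(full))                      # 95/5 train/val split
train_set, val_set = random_split(full, [len(full) - n_val, n_val],
                                  generator=torch.Generator().manual_seed(0))
```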
---
## Results
Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher ↑ / lower ↓ is better. **Best** in bold, <ins>second best</ins> underlined.
| Dataset | Method | Style Loss ↓ | Content Loss ↓ | LPIPS ↓ | SSIM ↑ | DreamSim ↓ | CLIP Sim ↑ | Infer (s) ↓ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
| | StyleVAR (SFT) | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse** | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
| | StyleVAR (SFT) | <ins>0.0452</ins> | **180.7923** | <ins>0.5618</ins> | <ins>0.4282</ins> | <ins>0.3168</ins> | <ins>0.7903</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0387** | <ins>182.0954</ins> | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
| | StyleVAR (SFT) | <ins>0.0206</ins> | <ins>160.1233</ins> | <ins>0.7398</ins> | **0.2713** | <ins>0.6986</ins> | <ins>0.5308</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0199** | **157.5109** | **0.7286** | <ins>0.2677</ins> | **0.6793** | **0.5335** | 0.4031 |
Inference time measured on a single NVIDIA A100 (40GB).
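To reproduce the perceptual metrics, the `lpips` and `dreamsim` pip packages expose the entry points below. This is a hedged sketch: `out`/`target` are assumed to be (B, 3, H, W) tensors in [-1, 1] with `out_pil`/`target_pil` their PIL counterparts, and preprocessing details should be double-checked against the repo's `eval/` scripts.

```python
import lpips
import torch
from dreamsim import dreamsim

device = "cuda"
lpips_fn = lpips.LPIPS(net="alex").to(device)  # expects image tensors in [-1, 1]
ds_model, ds_preprocess = dreamsim(pretrained=True, device=device)

with torch.no_grad():
    lpips_dist = lpips_fn(out, target).mean()
    # dreamsim's bundled preprocess maps a PIL image to a model-ready tensor
    ds_dist = ds_model(ds_preprocess(out_pil).to(device),
                       ds_preprocess(target_pil).to(device))
```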
---
## Limitations
- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images; at 600M parameters the model partially memorizes this limited content pool.
- **Human faces.** Facial topology is structurally more sensitive and perceptually more scrutinized than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.
---
## Intended use
Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.
---
## Citation
```bibtex
@article{jing2026stylevar,
  title   = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author  = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  journal = {arXiv preprint arXiv:2604.21052},
  year    = {2026},
  note    = {Duke University}
}
```
## Acknowledgments
Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).