---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
  - style-transfer
  - visual-autoregressive
  - VAR
  - VQ-VAE
  - GRPO
  - reinforcement-learning
  - image-generation
datasets:
  - DiffSynth-Studio/OmniStyle
  - DiffSynth-Studio/ImagePulse-StyleTransfer
---

# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling

Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales (1×1 → 16×16, L=680 tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
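
As a rough illustration of that mechanism, the sketch below shows one way such a blended cross-attention layer could be wired up. It is not the repository's actual module: the class name, argument names, and blending scheme are assumptions for exposition only.

```python
import torch
import torch.nn as nn

# Illustrative sketch only: names and the query-blending scheme are assumptions;
# see the StyleVAR repository for the real Blended Cross-Attention implementation.
class BlendedCrossAttentionSketch(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.q_style = nn.Linear(dim, dim)    # queries derived from style features
        self.q_content = nn.Linear(dim, dim)  # queries derived from content features
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, target_history, style_feat, content_feat, blend: float = 0.5):
        # Style/content features act as the queries; the target's own autoregressive
        # token history supplies keys and values, so the layer reads structure from
        # the partially generated image while the conditions steer what is attended to.
        q = blend * self.q_style(style_feat) + (1.0 - blend) * self.q_content(content_feat)
        out, _ = self.attn(q, target_history, target_history)
        return out
```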

The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 – SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 – GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters (rank r=256) and Per-Action Normalization Weighting (PANW) to rebalance credit across the 256× token imbalance between coarse and fine scales.

- **Code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Paper (arXiv):** https://arxiv.org/abs/2604.21052
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)

---

## Files in this repository

| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, C_vae=32, V=4096, ch=160). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | State dict with optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |

Both transformer checkpoints ship as plain state dicts under the `"model"` key: LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.
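
For example, a minimal check of that layout (checkpoint path as downloaded in the Quick start below; any keys besides `"model"` are checkpoint-specific):

```python
import torch

ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
print(list(ckpt.keys()))        # expect a "model" entry holding the merged weights
state_dict = ckpt["model"]      # plain state dict; no LoRA wrapper classes required
print(len(state_dict), "tensors")
```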

---

## Quick start

```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```

Download the checkpoints:

```python
from huggingface_hub import hf_hub_download

REPO = "Senfier-LiqiJing/StyleVAR"
vae_path  = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path  = hf_hub_download(REPO, "StyleVAR_SFT.pth",      local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth",     local_dir="ckpt")
```

Or via the CLI:

```bash
hf download Senfier-LiqiJing/StyleVAR \
  vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
  --local-dir ckpt
```

Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):

```bash
python eval/infer_grpo.py \
  --vae_ckpt  ckpt/vae_ch160v4096z32.pth \
  --sft_ckpt  ckpt/StyleVAR-GRPO.pth \
  --sft_only \
  --num 8 --out grpo_infer_results.png
```

> `--sft_only` here means *"do not attach a LoRA adapter on top"*: the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.

Minimal Python usage:

```python
import torch
from PIL import Image
from torchvision import transforms

from models import build_vae_stylevar  # provided by the StyleVAR repo

device = torch.device("cuda")
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

vae, model = build_vae_stylevar(
    device=device, patch_nums=patch_nums,
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    depth=20, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
    style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)

ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
vae.eval()
model.eval()

# Prepare content and style as (B, 3, 256, 256) tensors in [-1, 1];
# "content.jpg" / "style.jpg" are placeholders for your own images.
to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
content = to_tensor(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(device)
style   = to_tensor(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    out = model.autoregressive_infer(
        B=content.size(0), style_img=style, content_img=content,
        top_k=900, top_p=0.96, g_seed=42,
    )
```

---

## Training recipe

### Stage 1 – SFT

| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | 5×10⁻⁴ (epochs 1-6) → 1×10⁻⁴ (epochs 7-10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |

### Stage 2 – GRPO

| Setting | Value |
|---|---|
| LoRA | rank r=256, α/r=2 on every attention projection (W_Q_target, W_QKV_cond, W_proj) and FFN linear; 131M trainable (18.2% of backbone) |
| Reward | R = −λ · DreamSim(x̂, x), λ=5.0 |
| Group size G | 16 |
| Sampling | top-k=900, top-p=0.96 |
| Clip / KL / PANW | ε=0.2, β=0.1, γ=0.7 |
| Optimizer | AdamW, lr 1×10⁻⁵, wd 0.01, (β₁, β₂)=(0.9, 0.95), FP32 |
| Iterative merge | peak-triggered, τ_gain=0.05 / τ_patience=50 / 300-step cool-down |
| Emergency merge | raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, G=16 serial rollouts |
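
Read together, the reward, group size, and PANW rows suggest the following shape of computation. This is an illustrative sketch only: the reward term follows the table above, but the group normalization details and the exact PANW formula are assumptions, not the reference implementation.

```python
import torch

# Illustrative sketch; the PANW weighting formula and its normalization are assumptions.
PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)   # 10 scales, 680 tokens total
LAMBDA, GAMMA = 5.0, 0.7                         # reward scale, PANW exponent

def group_relative_advantages(dreamsim_dist: torch.Tensor) -> torch.Tensor:
    """dreamsim_dist: (G,) DreamSim distances for the G rollouts of one group."""
    rewards = -LAMBDA * dreamsim_dist                        # R = -lambda * DreamSim
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

def panw_token_weights() -> torch.Tensor:
    """One weight per token, down-weighting dense scales so the single 1x1 token
    is not drowned out by the 256 tokens of the final 16x16 scale."""
    counts = torch.tensor([p * p for p in PATCH_NUMS], dtype=torch.float)  # 1 .. 256
    per_scale = counts.pow(-GAMMA)
    weights = torch.cat([w.repeat(int(n)) for w, n in zip(per_scale, counts)])
    return weights * weights.numel() / weights.sum()         # keep the mean weight at 1

# One plausible use: multiply the per-token clipped policy-gradient terms (epsilon=0.2)
# and KL penalty (beta=0.1) by these weights before averaging over the 680 tokens.
```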

---

## Datasets

Both stages use a concatenation of two paired style-transfer datasets:

- **OmniStyle-150K**: 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer**: 137,886 triplets.

Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation + brightness to content and random cropping to style; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the G samples in a group.
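
As a hypothetical illustration of that augmentation split (the specific rotation, brightness, and crop parameters are assumptions, not the repository's values):

```python
from torchvision import transforms

# Hypothetical SFT-time augmentation matching the description above;
# exact degrees, brightness range, and crop scale are assumptions.
content_aug = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # [0, 1] -> [-1, 1]
])
style_aug = transforms.Compose([
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
# GRPO rollouts skip both pipelines entirely, so every rollout in a group
# conditions on identical content/style images.
```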

---

## Results

Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher ↑ / lower ↓ is better. **Best** in bold, <ins>second best</ins> underlined.

| Dataset | Method | Style Loss ↓ | Content Loss ↓ | LPIPS ↓ | SSIM ↑ | DreamSim ↓ | CLIP Sim ↑ | Infer (s) ↓ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle**    | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
|                  | StyleVAR (SFT)   | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
|                  | StyleVAR (GRPO)  | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse**   | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
|                  | StyleVAR (SFT)   | <ins>0.0452</ins> | **180.7923** | <ins>0.5618</ins> | <ins>0.4282</ins> | <ins>0.3168</ins> | <ins>0.7903</ins> | 0.4031 |
|                  | StyleVAR (GRPO)  | **0.0387** | <ins>182.0954</ins> | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
|                  | StyleVAR (SFT)   | <ins>0.0206</ins> | <ins>160.1233</ins> | <ins>0.7398</ins> | **0.2713** | <ins>0.6986</ins> | <ins>0.5308</ins> | 0.4031 |
|                  | StyleVAR (GRPO)  | **0.0199** | **157.5109** | **0.7286** | <ins>0.2677</ins> | **0.6793** | **0.5335** | 0.4031 |

Inference time measured on a single NVIDIA A100 (40GB).

---

## Limitations

- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images; at 600M parameters the model partially memorizes this limited content pool.
- **Human faces.** Facial structure is more sensitive to distortion and more closely scrutinized by viewers than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.

---

## Intended use

Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.

---

## Citation

```bibtex
@article{jing2026stylevar,
  title   = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author  = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  journal = {arXiv preprint arXiv:2604.21052},
  year    = {2026},
  note    = {Duke University}
}
```

## Acknowledgments

Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).