---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
- style-transfer
- visual-autoregressive
- VAR
- VQ-VAE
- GRPO
- reinforcement-learning
- image-generation
datasets:
- DiffSynth-Studio/OmniStyle
- DiffSynth-Studio/ImagePulse-StyleTransfer
---
# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling
Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales (1×1 → 16×16, L=680 tokens). A **Blended Cross-Attention** mechanism injects style and content features into the target's autoregressive context, preserving the content's structure while absorbing the style's texture.
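For reference, the 680-token budget follows directly from the per-scale patch grids (the same `patch_nums` tuple used in the loading snippet below):

```python
# L = total tokens over the 10 scales, coarse (1×1) to fine (16×16)
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
L = sum(p * p for p in patch_nums)
print(L)  # 680
```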
The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 — SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 — GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters (rank r=256) and Per-Action Normalization Weighting (PANW) to rebalance credit across the 256× token imbalance between coarse and fine scales.
- **Code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Paper (arXiv):** https://arxiv.org/abs/2604.21052
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)
---
## Files in this repository
| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, C_vae=32, V=4096, ch=160). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | State dict with optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |
Both transformer checkpoints ship as plain state dicts under the `"model"` key — LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.
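A quick way to confirm a download is intact (the path assumes the `ckpt/` layout used in the quick start below):

```python
import torch

# Both files store the transformer weights under the "model" key
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
assert "model" in ckpt
print(f'{sum(p.numel() for p in ckpt["model"].values()) / 1e6:.0f}M parameters')
```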
---
## Quick start
```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```
Download the checkpoints:
```python
from huggingface_hub import hf_hub_download
REPO = "Senfier-LiqiJing/StyleVAR"
vae_path = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path = hf_hub_download(REPO, "StyleVAR_SFT.pth", local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth", local_dir="ckpt")
```
Or via the CLI:
```bash
hf download Senfier-LiqiJing/StyleVAR \
vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
--local-dir ckpt
```
Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):
```bash
python eval/infer_grpo.py \
--vae_ckpt ckpt/vae_ch160v4096z32.pth \
--sft_ckpt ckpt/StyleVAR-GRPO.pth \
--sft_only \
--num 8 --out grpo_infer_results.png
```
> `--sft_only` here means *"do not attach a LoRA adapter on top"* — the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.
Minimal Python usage:
```python
import torch
from PIL import Image
from torchvision import transforms

from models import build_vae_stylevar  # from the StyleVAR project repo
device = torch.device("cuda")
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
vae, model = build_vae_stylevar(
device=device, patch_nums=patch_nums,
V=4096, Cvae=32, ch=160, share_quant_resi=4,
depth=20, shared_aln=False, attn_l2_norm=True,
flash_if_available=True, fused_if_available=True,
init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
model.eval()
# Preprocess content and style to (B, 3, 256, 256) tensors in [-1, 1]
# ("content.jpg" / "style.jpg" are placeholder paths)
to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
content = to_tensor(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(device)
style = to_tensor(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
out = model.autoregressive_infer(
B=content.size(0), style_img=style, content_img=content,
top_k=900, top_p=0.96, g_seed=42,
)
```
---
## Training recipe
### Stage 1 — SFT
| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | 5×10⁻⁴ (epochs 1-6) → 1×10⁻⁴ (epochs 7-10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |
### Stage 2 — GRPO
| Setting | Value |
|---|---|
| LoRA | rank r=256, α/r=2 on every attention (W_Q_target, W_QKV_cond, W_proj) and FFN linear — 131M trainable (18.2% of backbone) |
| Reward | R = −λ · DreamSim(x̂, x), λ=5.0 |
| Group size G | 16 |
| Sampling | top-k=900, top-p=0.96 |
| Clip / KL / PANW | ε=0.2, β=0.1, γ=0.7 |
| Optimizer | AdamW, lr 1×10⁻⁵, wd 0.01, (β₁, β₂)=(0.9, 0.95), FP32 |
| Iterative merge | peak-triggered, τ_gain=0.05 / τ_patience=50 / 300-step cool-down |
| Emergency merge | raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, G=16 serial rollouts |
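For intuition, here is a minimal sketch of the group-relative advantage under the reward above, together with one *plausible* reading of PANW; the exact weighting used in the paper may differ, and the DreamSim scores below are random stand-ins:

```python
import torch

G, lam, gamma = 16, 5.0, 0.7                        # group size, reward scale, PANW exponent
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

# Group-relative advantage: normalize R = -lam * DreamSim(x_hat, x) within a group of G rollouts
dreamsim = torch.rand(G)                            # stand-in for per-rollout DreamSim distances
rewards = -lam * dreamsim
advantage = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

# PANW (assumed form): weight each token by n_k**(-gamma), n_k = tokens at scale k,
# so the 256-token finest scale does not drown out the 1-token coarsest scale
n_k = torch.tensor([p * p for p in patch_nums], dtype=torch.float32)
w_k = n_k ** (-gamma)
per_token_w = torch.cat([w.expand(int(n)) for w, n in zip(w_k, n_k)])  # shape (680,)
per_token_w = per_token_w / per_token_w.mean()      # keep the mean loss scale unchanged
```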
---
## Datasets
Both stages use a concatenation of two paired style-transfer datasets:
- **OmniStyle-150K** — 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer** — 137,886 triplets.
The combined pool is split 95/5 into train/val, giving the **267,710-sample** training set used in both stages. During SFT, content images receive rotation and brightness augmentation and style images receive random crops (a sketch follows below); GRPO rollouts use no augmentation, so the conditioning signal is deterministic across the G samples in a group.
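A hedged torchvision sketch of those SFT-time augmentations; the magnitudes (angle, jitter strength, crop scale) are illustrative assumptions, not the paper's values:

```python
from torchvision import transforms

# Content images: rotation + brightness jitter (magnitudes are assumptions)
content_aug = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2),
])
# Style images: random crop back to the 256×256 input resolution
style_aug = transforms.RandomResizedCrop(256, scale=(0.6, 1.0))
# GRPO rollouts apply none of the above, so all G samples see identical conditioning
```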
---
## Results
Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows indicate whether higher ↑ or lower ↓ is better; **best** per column in bold.
| Dataset | Method | Style Loss ↓ | Content Loss ↓ | LPIPS ↓ | SSIM ↑ | DreamSim ↓ | CLIP Sim ↑ | Infer (s) ↓ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
| | StyleVAR (SFT) | 0.0468 | 116.3569 | 0.4743 | 0.3975 | 0.2276 | 0.8704 | 0.4031 |
| | StyleVAR (GRPO) | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse** | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
| | StyleVAR (SFT) | 0.0452 | **180.7923** | 0.5618 | 0.4282 | 0.3168 | 0.7903 | 0.4031 |
| | StyleVAR (GRPO) | **0.0387** | 182.0954 | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
| | StyleVAR (SFT) | 0.0206 | 160.1233 | 0.7398 | **0.2713** | 0.6986 | 0.5308 | 0.4031 |
| | StyleVAR (GRPO) | **0.0199** | **157.5109** | **0.7286** | 0.2677 | **0.6793** | **0.5335** | 0.4031 |
Inference time measured on a single NVIDIA A100 (40GB).
---
## Limitations
- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images, so the 600M-parameter model partially memorizes this limited content pool and transfers imperfectly to arbitrary internet images.
- **Human faces.** Facial topology is structurally more sensitive and perceptually more scrutinized than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.
---
## Intended use
Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.
---
## Citation
```bibtex
@misc{jing2026stylevar,
  title         = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author        = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  year          = {2026},
  eprint        = {2604.21052},
  archivePrefix = {arXiv},
}
```
## Acknowledgments
Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).