---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
- style-transfer
- visual-autoregressive
- VAR
- VQ-VAE
- GRPO
- reinforcement-learning
- image-generation
datasets:
- DiffSynth-Studio/OmniStyle
- DiffSynth-Studio/ImagePulse-StyleTransfer
---

# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling

Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales (1×1 → 16×16, L=680 tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
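
The exact attention module lives in the project repo; purely as orientation, a minimal single-head sketch of the idea reads as follows. The class name, feature dimension, and the `alpha` blending rule are illustrative assumptions, not the released implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlendedCrossAttentionSketch(nn.Module):
    """Sketch: style/content features act as queries over the target's AR history."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.q_style = nn.Linear(dim, dim)    # query projection for the style reference
        self.q_content = nn.Linear(dim, dim)  # query projection for the content reference
        self.k_hist = nn.Linear(dim, dim)     # key projection over the target token history
        self.v_hist = nn.Linear(dim, dim)     # value projection over the target token history
        self.out = nn.Linear(dim, dim)

    def forward(self, history, style_feat, content_feat, alpha: float = 0.5):
        # history:      (B, L_hist, dim) embeddings of tokens generated at coarser scales
        # style_feat:   (B, L_ref, dim)  encoded style reference
        # content_feat: (B, L_ref, dim)  encoded content reference
        k, v = self.k_hist(history), self.v_hist(history)
        style_ctx = F.scaled_dot_product_attention(self.q_style(style_feat), k, v)
        content_ctx = F.scaled_dot_product_attention(self.q_content(content_feat), k, v)
        # Blend the two attended streams; the real blending rule may differ.
        return self.out(alpha * style_ctx + (1 - alpha) * content_ctx)
```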

The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 – SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 – GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters (rank r=256) and Per-Action Normalization Weighting (PANW) to rebalance credit across the 256× token imbalance between coarse and fine scales (see the sketch below).
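
The exact PANW formula is given in the paper; as a rough sketch of the mechanics, the snippet below combines standard group-relative advantages with a per-scale weight that shrinks with the token count at that scale. The weighting form `n_k ** (-gamma)` and the function names are assumptions for illustration only:

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantages: normalize each rollout's reward within its group of G samples."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def panw_token_weights(patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), gamma: float = 0.7) -> torch.Tensor:
    """One weight per AR position, shrinking at token-dense fine scales (680 positions in total)."""
    n_k = torch.tensor([p * p for p in patch_nums], dtype=torch.float32)   # 1 ... 256 tokens per scale
    w_k = n_k ** (-gamma)                                                  # down-weight dense fine scales
    return torch.cat([w.repeat(int(n)) for w, n in zip(w_k, n_k)])         # shape (680,)

# Example: G = 16 rollouts for a single (content, style) pair.
rewards = torch.randn(16)                        # stand-in for -5.0 * DreamSim(generated, target)
adv = grpo_advantages(rewards)                   # (16,)
weights = panw_token_weights()                   # (680,)
per_token_credit = adv[:, None] * weights[None]  # (16, 680) credit assigned to each sampled token
```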

- **Code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Paper (arXiv):** https://arxiv.org/abs/2604.21052
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)

---

## Files in this repository

| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, C_vae=32, V=4096, ch=160). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | State dict with optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |

Both transformer checkpoints ship as plain state dicts under the `"model"` key: LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.
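
If you want to confirm that layout before building the full model, the files can be inspected directly. The snippet below assumes the `ckpt/` download directory used in the quick start:

```python
import torch

ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
print(list(ckpt.keys()))            # expect a "model" entry (plus any training metadata)
state_dict = ckpt["model"]
print(f"{len(state_dict)} tensors in the merged transformer state dict")
```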

---

## Quick start

```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```

Download the checkpoints:

```python
from huggingface_hub import hf_hub_download

REPO = "Senfier-LiqiJing/StyleVAR"
vae_path = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path = hf_hub_download(REPO, "StyleVAR_SFT.pth", local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth", local_dir="ckpt")
```

Or via the CLI:

```bash
hf download Senfier-LiqiJing/StyleVAR \
  vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
  --local-dir ckpt
```

Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):

```bash
python eval/infer_grpo.py \
  --vae_ckpt ckpt/vae_ch160v4096z32.pth \
  --sft_ckpt ckpt/StyleVAR-GRPO.pth \
  --sft_only \
  --num 8 --out grpo_infer_results.png
```

> `--sft_only` here means *"do not attach a LoRA adapter on top"*: the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.
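
For example, the SFT baseline uses the same invocation with the Stage 1 checkpoint (the output filename is arbitrary):

```bash
python eval/infer_grpo.py \
  --vae_ckpt ckpt/vae_ch160v4096z32.pth \
  --sft_ckpt ckpt/StyleVAR_SFT.pth \
  --sft_only \
  --num 8 --out sft_infer_results.png
```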

Minimal Python usage:

```python
import torch
from models import build_vae_stylevar

device = torch.device("cuda")
# Scale schedule: sum of squared patch sizes = 680 autoregressive positions.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

vae, model = build_vae_stylevar(
    device=device, patch_nums=patch_nums,
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    depth=20, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
    style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)

ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
model.eval()

# content, style: (B, 3, 256, 256) tensors in [-1, 1]; random placeholders shown here.
content = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
style = torch.rand(1, 3, 256, 256, device=device) * 2 - 1

with torch.no_grad():
    out = model.autoregressive_infer(
        B=content.size(0), style_img=style, content_img=content,
        top_k=900, top_p=0.96, g_seed=42,
    )
```
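
The return type of `autoregressive_infer` is defined in the repo; assuming it yields decoded images in `[0, 1]` (the usual VAR convention), they can be tiled and saved with torchvision:

```python
from torchvision.utils import save_image

# Assumption: `out` is a (B, 3, 256, 256) image batch in [0, 1]; rescale first if your
# build returns [-1, 1] instead.
save_image(out.clamp(0, 1), "stylevar_samples.png", nrow=4)
```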

---

## Training recipe

### Stage 1 – SFT

| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | 5×10⁻⁴ (epochs 1-6) → 1×10⁻⁴ (epochs 7-10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |
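
A minimal sketch of this regime, with the VQ-VAE frozen and the full transformer optimized (the optimizer choice for SFT is an assumption):

```python
import torch

def setup_stage1(vae: torch.nn.Module, model: torch.nn.Module, epoch: int) -> torch.optim.Optimizer:
    """Freeze the tokenizer, train the whole transformer; LR 5e-4 for epochs 1-6, then 1e-4."""
    for p in vae.parameters():
        p.requires_grad_(False)
    lr = 5e-4 if epoch <= 6 else 1e-4
    return torch.optim.AdamW(model.parameters(), lr=lr)
```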

### Stage 2 – GRPO

| Setting | Value |
|---|---|
| LoRA | rank r=256, α/r=2 on every attention (W_Q_target, W_QKV_cond, W_proj) and FFN linear; 131M trainable (18.2% of backbone) |
| Reward | R = −λ · DreamSim(x̂, x), λ = 5.0 (see the sketch below) |
| Group size G | 16 |
| Sampling | top-k=900, top-p=0.96 |
| Clip / KL / PANW | ε=0.2, β=0.1, γ=0.7 |
| Optimizer | AdamW, lr 1×10⁻⁵, wd 0.01, (β₁, β₂)=(0.9, 0.95), FP32 |
| Iterative merge | peak-triggered, τ_gain=0.05 / τ_patience=50 / 300-step cool-down |
| Emergency merge | raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, G=16 serial rollouts |
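
The reward row above can be read directly as code. The sketch below assumes the reference `dreamsim` package (https://github.com/ssundaram21/dreamsim); treat it as orientation rather than the exact training-loop implementation:

```python
import torch
from dreamsim import dreamsim

# DreamSim returns a perceptual distance; lower means more similar.
ds_model, ds_preprocess = dreamsim(pretrained=True, device="cuda")

def grpo_reward(generated: torch.Tensor, target: torch.Tensor, lam: float = 5.0) -> torch.Tensor:
    """R = -lam * DreamSim(x_hat, x), with lam = 5.0 as in the table above."""
    with torch.no_grad():
        dist = ds_model(generated, target)   # inputs already preprocessed for DreamSim
    return -lam * dist
```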

---

## Datasets

Both stages use a concatenation of two paired style-transfer datasets:

- **OmniStyle-150K** – 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer** – 137,886 triplets.

Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation and brightness augmentation to content images and random cropping to style images; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the G samples in a group.
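
Structurally, the split amounts to concatenating the two triplet datasets and carving off 5% for validation. A sketch with stand-in datasets (the real loaders live in the project repo):

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset, random_split

# Stand-ins with the reported triplet counts; each item would be a (content, style, target) triplet.
omnistyle = TensorDataset(torch.zeros(143_992, 1))
imagepulse = TensorDataset(torch.zeros(137_886, 1))

full = ConcatDataset([omnistyle, imagepulse])
n_train = int(0.95 * len(full))                       # 95/5 train/val split
train_set, val_set = random_split(full, [n_train, len(full) - n_train])
```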

---

## Results

Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher ↑ / lower ↓ is better. **Best** in bold, <ins>second best</ins> underlined.

| Dataset | Method | Style Loss ↓ | Content Loss ↓ | LPIPS ↓ | SSIM ↑ | DreamSim ↓ | CLIP Sim ↑ | Infer (s) ↓ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
| | StyleVAR (SFT) | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse** | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
| | StyleVAR (SFT) | <ins>0.0452</ins> | **180.7923** | <ins>0.5618</ins> | <ins>0.4282</ins> | <ins>0.3168</ins> | <ins>0.7903</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0387** | <ins>182.0954</ins> | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
| | StyleVAR (SFT) | <ins>0.0206</ins> | <ins>160.1233</ins> | <ins>0.7398</ins> | **0.2713** | <ins>0.6986</ins> | <ins>0.5308</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0199** | **157.5109** | **0.7286** | <ins>0.2677</ins> | **0.6793** | **0.5335** | 0.4031 |

Inference time measured on a single NVIDIA A100 (40GB).

---

## Limitations

- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images; at 600M parameters the model partially memorizes this limited content pool.
- **Human faces.** Facial topology is structurally more sensitive and perceptually more scrutinized than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.

---

## Intended use

Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.

---

## Citation

```bibtex
@article{jing2026stylevar,
  title   = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author  = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  journal = {arXiv preprint arXiv:2604.21052},
  year    = {2026},
  note    = {Duke University}
}
```

## Acknowledgments

Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).