Senfier-LiqiJing committed
Commit 9bfa7f8 · verified · 1 Parent(s): e1160c6

Update README.md

Files changed (1):
  1. README.md +202 -3

README.md CHANGED
---
license: apache-2.0
library_name: pytorch
pipeline_tag: image-to-image
tags:
- style-transfer
- visual-autoregressive
- VAR
- VQ-VAE
- GRPO
- reinforcement-learning
- image-generation
datasets:
- DiffSynth-Studio/OmniStyle
- DiffSynth-Studio/ImagePulse-StyleTransfer
---

# StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling

Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales ($1{\times}1 \to 16{\times}16$, $L=680$ tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
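
The token budget $L=680$ follows directly from the 10-scale schedule (the same `patch_nums` used in the Python example further down); a quick arithmetic check in plain Python:

```python
# Token count per scale and in total for the 10-scale StyleVAR/VAR schedule.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

tokens_per_scale = [pn * pn for pn in patch_nums]  # each scale k contributes pn_k^2 tokens
total_tokens = sum(tokens_per_scale)

print(tokens_per_scale)  # [1, 4, 9, 16, 25, 36, 64, 100, 169, 256]
print(total_tokens)      # 680
```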

The model is trained in two stages from a pretrained vanilla VAR checkpoint:
- **Stage 1 (SFT).** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
- **Stage 2 (GRPO).** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters ($r{=}256$) and Per-Action Normalization Weighting (PANW) to rebalance credit across the $256\times$ token imbalance between coarse and fine scales.
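
For reference, the GRPO reward defined in the training recipe below, $R = -\lambda \cdot \text{DreamSim}(\hat{x}, x)$ with $\lambda{=}5.0$, can be reproduced with the public `dreamsim` package. A minimal sketch, not the project's exact reward wrapper (preprocessing details may differ):

```python
# Perceptual reward sketch: R = -lambda * DreamSim(generated, target).
# Requires `pip install dreamsim`; the project's own reward code may differ in details.
import torch
from dreamsim import dreamsim
from PIL import Image

device = "cuda"
dreamsim_model, preprocess = dreamsim(pretrained=True, device=device)
LAMBDA = 5.0  # reward scale from the GRPO recipe

def style_reward(generated: Image.Image, target: Image.Image) -> float:
    """Negative DreamSim distance between one rollout and its ground-truth target."""
    gen = preprocess(generated).to(device)
    tgt = preprocess(target).to(device)
    with torch.no_grad():
        dist = dreamsim_model(gen, tgt)  # lower distance = perceptually closer
    return -LAMBDA * dist.item()
```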

- **Paper / code:** https://github.com/Senfier-LiqiJing/StyleVAR
- **Authors:** Liqi Jing, Dingming Zhang, Peinian Li, Lichen Zhu (Duke University)

---

## Files in this repository

| File | Purpose | Notes |
|---|---|---|
| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, $C_\text{vae}{=}32$, $V{=}4096$, $\text{ch}{=}160$). Shared between SFT and GRPO. |
| `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | Plain state dict; no LoRA, no optimizer state. Use this for the SFT baseline. |
| `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |

Both transformer checkpoints ship as plain state dicts under the `"model"` key: LoRA adapters have been baked in, so you can load them directly into a fresh StyleVAR without constructing any LoRA wrappers.
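
To verify a download, either checkpoint can be opened as an ordinary PyTorch file and inspected; a quick check that only assumes the `"model"` key described above:

```python
import torch

# Both transformer checkpoints are plain dicts with the weights stored under "model".
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
print(list(ckpt.keys()))            # expect "model" among the top-level keys

state_dict = ckpt["model"]
print(len(state_dict), "tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(f"{name}: {tuple(tensor.shape)}")  # peek at parameter names and shapes
```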

---

## Quick start

```bash
pip install torch torchvision huggingface_hub pillow
git clone https://github.com/Senfier-LiqiJing/StyleVAR.git
cd StyleVAR
```

Download the checkpoints:

```python
from huggingface_hub import hf_hub_download

REPO = "Senfier-LiqiJing/StyleVAR"
vae_path = hf_hub_download(REPO, "vae_ch160v4096z32.pth", local_dir="ckpt")
sft_path = hf_hub_download(REPO, "StyleVAR_SFT.pth", local_dir="ckpt")
grpo_path = hf_hub_download(REPO, "StyleVAR-GRPO.pth", local_dir="ckpt")
```

Or via the CLI:

```bash
hf download Senfier-LiqiJing/StyleVAR \
  vae_ch160v4096z32.pth StyleVAR_SFT.pth StyleVAR-GRPO.pth \
  --local-dir ckpt
```

Run inference with the GRPO checkpoint (uses `eval/infer_grpo.py` from the project repo):

```bash
python eval/infer_grpo.py \
  --vae_ckpt ckpt/vae_ch160v4096z32.pth \
  --sft_ckpt ckpt/StyleVAR-GRPO.pth \
  --sft_only \
  --num 8 --out grpo_infer_results.png
```

> `--sft_only` here means *"do not attach a LoRA adapter on top"*: the GRPO deltas are already merged into `StyleVAR-GRPO.pth`, so it loads like any plain checkpoint. Pass `StyleVAR_SFT.pth` instead to reproduce the SFT baseline.

Minimal Python usage:

```python
import torch
from models import build_vae_stylevar

device = torch.device("cuda")
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # 10-scale schedule, 680 tokens total

# Build the frozen VQ-VAE tokenizer and the depth-20 StyleVAR transformer.
vae, model = build_vae_stylevar(
    device=device, patch_nums=patch_nums,
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    depth=20, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,
    style_enc_dim=512,
)
vae.load_state_dict(torch.load("ckpt/vae_ch160v4096z32.pth", map_location="cpu"), strict=True)

# Transformer weights live under the "model" key; GRPO LoRA deltas are already merged.
ckpt = torch.load("ckpt/StyleVAR-GRPO.pth", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)
model.eval()

# content, style: (B, 3, 256, 256) tensors in [-1, 1]
with torch.no_grad():
    out = model.autoregressive_infer(
        B=content.size(0), style_img=style, content_img=content,
        top_k=900, top_p=0.96, g_seed=42,
    )
```
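
To write the result to disk, `out` can be saved with torchvision. The value range is an assumption here: upstream VAR's sampler returns images in $[0, 1]$, and this sketch assumes StyleVAR's `autoregressive_infer` does the same; rescale first if it returns $[-1, 1]$.

```python
from torchvision.utils import save_image

# Assumes `out` is a (B, 3, 256, 256) batch in [0, 1], as in upstream VAR's sampler.
# If the outputs are in [-1, 1], rescale with out = (out + 1) / 2 before saving.
save_image(out.clamp(0, 1), "stylevar_out.png", nrow=4)
```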

---

## Training recipe

### Stage 1: SFT

| Setting | Value |
|---|---|
| Backbone | VAR depth-20 transformer (~600M params) |
| Trainable | Full transformer (VQ-VAE frozen) |
| Epochs | 10 |
| Learning rate | $5\times10^{-4}$ (epochs 1-6) → $1\times10^{-4}$ (epochs 7-10) |
| Batch size | 128 (with gradient accumulation) |
| Hardware | 1× NVIDIA RTX 4090 (48GB) |
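
Reaching an effective batch of 128 on a single GPU implies gradient accumulation. A hypothetical sketch of that pattern, together with the stepwise learning-rate drop at epoch 7; `model`, `train_loader`, `compute_sft_loss`, and the micro-batch size are placeholders, not the project's actual training code:

```python
# Hypothetical Stage-1 schedule: effective batch 128 via gradient accumulation
# on one GPU, with the LR dropped from 5e-4 to 1e-4 after epoch 6.
import torch

MICRO_BATCH = 16                   # per-forward batch that fits in memory (assumed)
ACCUM_STEPS = 128 // MICRO_BATCH   # accumulate to the effective batch of 128

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

for epoch in range(1, 11):
    # Stepwise schedule from the recipe: 5e-4 for epochs 1-6, 1e-4 for epochs 7-10.
    lr = 5e-4 if epoch <= 6 else 1e-4
    for group in optimizer.param_groups:
        group["lr"] = lr

    optimizer.zero_grad()
    for i, (content, style, target) in enumerate(train_loader, start=1):
        loss = compute_sft_loss(model, content, style, target)
        (loss / ACCUM_STEPS).backward()   # scale so accumulated grads match batch 128
        if i % ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
```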

### Stage 2: GRPO

| Setting | Value |
|---|---|
| LoRA | $r{=}256$, $\alpha/r{=}2$ on every attention ($W_Q^\text{target}$, $W_{QKV}^\text{cond}$, $W_\text{proj}$) and FFN linear; 131M trainable (18.2% of backbone) |
| Reward | $R = -\lambda \cdot \text{DreamSim}(\hat{x}, x)$, $\lambda{=}5.0$ |
| Group size $G$ | 16 |
| Sampling | top-$k{=}900$, top-$p{=}0.96$ |
| Clip / KL / PANW | $\varepsilon{=}0.2$, $\beta{=}0.1$, $\gamma{=}0.7$ |
| Optimizer | AdamW, lr $1\times10^{-5}$, wd 0.01, $(\beta_1, \beta_2){=}(0.9, 0.95)$, FP32 |
| Iterative merge | Peak-triggered: $\tau_\text{gain}{=}0.05$, $\tau_\text{patience}{=}50$, 300-step cool-down |
| Emergency merge | Raw KL > 2.0, 50-step cool-down |
| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, $G{=}16$ serial rollouts |
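
The Stage-2 update combines the usual GRPO ingredients (group-normalized advantages, a PPO-style clipped ratio, a KL penalty toward the SFT reference) with the per-scale PANW weight. The exact PANW formula is not given in this card, so the $(p_k^2)^{-\gamma}$ weighting below is an illustrative assumption; everything else follows the standard GRPO recipe with the hyperparameters from the table:

```python
# Illustrative GRPO-style update over the 680-token multi-scale sequence.
# The PANW weight is an ASSUMED form (not the paper's exact formula): it shrinks
# per-token credit on fine scales so 256 tokens at 16x16 do not drown out 1 token at 1x1.
import torch

PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
GAMMA = 0.7        # PANW exponent from the recipe table
EPS_CLIP = 0.2     # PPO-style clipping epsilon
KL_BETA = 0.1      # KL penalty coefficient

# One weight per token, constant within a scale: w_k proportional to (p_k^2)^(-gamma).
panw = torch.cat([torch.full((p * p,), float(p * p) ** (-GAMMA)) for p in PATCH_NUMS])  # (680,)

def grpo_loss(logp_new, logp_old, logp_ref, rewards):
    """logp_*: (G, 680) per-token log-probs for a group of G rollouts; rewards: (G,)."""
    # Group-relative advantage: normalize rewards within the group of G samples.
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)   # (G,)
    adv = adv[:, None].expand_as(logp_new)                      # broadcast over tokens

    ratio = (logp_new - logp_old).exp()
    unclipped = ratio * adv
    clipped = ratio.clamp(1 - EPS_CLIP, 1 + EPS_CLIP) * adv
    policy_term = -torch.minimum(unclipped, clipped)

    # Simple KL penalty toward the frozen SFT reference policy.
    kl_term = KL_BETA * (logp_new - logp_ref)

    # PANW: rescale per-token credit before averaging over tokens and the group.
    per_token = (policy_term + kl_term) * panw.to(logp_new.device)
    return per_token.mean()
```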

---

## Datasets

Both stages use a concatenation of two paired style-transfer datasets:

- **OmniStyle-150K**: 143,992 (content, style, target) triplets.
- **ImagePulse-StyleTransfer**: 137,886 triplets.

Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation and brightness augmentation to the content image and random cropping to the style image; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the $G$ samples in a group.
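
A torchvision sketch of that asymmetric augmentation policy; the specific rotation range, brightness strength, and crop scale are illustrative assumptions, only the kind of transform comes from this card:

```python
# Illustrative SFT-time augmentation: rotation + brightness on the content image,
# random crop on the style image, nothing random at GRPO rollout time.
from torchvision import transforms

content_aug = transforms.Compose([
    transforms.RandomRotation(degrees=10),                 # assumed rotation range
    transforms.ColorJitter(brightness=0.2),                # assumed brightness strength
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),            # map to [-1, 1]
])

style_aug = transforms.Compose([
    transforms.RandomResizedCrop(256, scale=(0.6, 1.0)),   # assumed crop scale
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# GRPO rollouts use a deterministic pipeline so all G samples in a group see
# exactly the same conditioning images.
rollout_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
```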

---

## Results

Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows mark the preferred direction ($\uparrow$ higher is better, $\downarrow$ lower is better). **Best** in bold, <ins>second best</ins> underlined.

| Dataset | Method | Style Loss $\downarrow$ | Content Loss $\downarrow$ | LPIPS $\downarrow$ | SSIM $\uparrow$ | DreamSim $\downarrow$ | CLIP Sim $\uparrow$ | Infer (s) $\downarrow$ |
|---|---|---|---|---|---|---|---|---|
| **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
| | StyleVAR (SFT) | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0466** | **114.5686** | **0.4656** | **0.4024** | **0.2164** | **0.8740** | 0.4031 |
| **ImagePulse** | AdaIN (baseline) | 0.0735 | 223.4699 | 0.7802 | 0.1574 | 0.6958 | 0.5651 | **0.0029** |
| | StyleVAR (SFT) | <ins>0.0452</ins> | **180.7923** | <ins>0.5618</ins> | <ins>0.4282</ins> | <ins>0.3168</ins> | <ins>0.7903</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0387** | <ins>182.0954</ins> | **0.5572** | **0.4320** | **0.2979** | **0.8000** | 0.4031 |
| **COCO+WikiArt** | AdaIN (baseline) | 0.0282 | 171.0877 | 0.7688 | 0.1985 | 0.7536 | 0.5319 | **0.0027** |
| | StyleVAR (SFT) | <ins>0.0206</ins> | <ins>160.1233</ins> | <ins>0.7398</ins> | **0.2713** | <ins>0.6986</ins> | <ins>0.5308</ins> | 0.4031 |
| | StyleVAR (GRPO) | **0.0199** | **157.5109** | **0.7286** | <ins>0.2677</ins> | **0.6793** | **0.5335** | 0.4031 |

Inference time measured on a single NVIDIA A100 (40GB).

---

## Limitations

- **Generalization gap on internet images.** OmniStyle-150K's ~150K triplets come from only ~1,800 unique content images; at 600M parameters the model partially memorizes this limited content pool.
- **Human faces.** Facial topology is structurally more sensitive and perceptually more scrutinized than natural scenes; the model performs well on landscapes and architecture but struggles on faces.
- **Sampling cost.** 10-scale autoregressive decoding is ~128× slower than AdaIN.

---

## Intended use

Research on multi-scale visual autoregressive generation, reference-based style transfer, and GRPO-style reinforcement fine-tuning of visual policies. Not intended for deployment in production pipelines that render identifiable people or copyrighted artistic styles.

---

## Citation

```bibtex
@article{jing2026stylevar,
  title  = {StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling},
  author = {Jing, Liqi and Zhang, Dingming and Li, Peinian and Zhu, Lichen},
  year   = {2026},
  note   = {Duke University}
}
```

## Acknowledgments

Built on the [VAR](https://github.com/FoundationVision/VAR) framework; trained on [OmniStyle-150K](https://www.modelscope.cn/datasets/DiffSynth-Studio/OmniStyle) and [ImagePulse-StyleTransfer](https://www.modelscope.cn/datasets/DiffSynth-Studio/ImagePulse-StyleTransfer).