Senfier-LiqiJing committed on
Commit 9f6fa31 · verified · 1 Parent(s): 5e8078c

Update README.md

Files changed (1)
  1. README.md +15 -15
README.md CHANGED
@@ -17,11 +17,11 @@ datasets:
 
 # StyleVAR: Controllable Image Style Transfer via Visual Autoregressive Modeling
 
-Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales ($1{\times}1 \to 16{\times}16$, $L=680$ tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
+Reference-based image style transfer built on Visual Autoregressive Modeling (VAR). Given a content image and a style image, StyleVAR generates a stylized output autoregressively over 10 scales (1×1 → 16×16, L=680 tokens) using a **Blended Cross-Attention** mechanism: style and content features act as queries over the target's own autoregressive history, preserving content structure while absorbing style texture.
 
 The model is trained in two stages from a pretrained vanilla VAR checkpoint:
 - **Stage 1 — SFT.** Supervised fine-tuning on 267,710 paired (content, style, target) triplets.
-- **Stage 2 — GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters ($r{=}256$) and Per-Action Normalization Weighting (PANW) to rebalance credit across the $256\times$ token imbalance between coarse and fine scales.
+- **Stage 2 — GRPO.** Reinforcement fine-tuning with Group Relative Policy Optimization against a DreamSim-based perceptual reward, using LoRA adapters (rank r=256) and Per-Action Normalization Weighting (PANW) to rebalance credit across the 256× token imbalance between coarse and fine scales.
 
 - **Code:** https://github.com/Senfier-LiqiJing/StyleVAR
 - **Paper (OpenReview):** https://openreview.net/forum?id=UHW3PgLUsa
@@ -33,7 +33,7 @@ The model is trained in two stages from a pretrained vanilla VAR checkpoint:
 
 | File | Purpose | Notes |
 |---|---|---|
-| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, $C_\text{vae}{=}32$, $V{=}4096$, $\text{ch}{=}160$). Shared between SFT and GRPO. |
+| `vae_ch160v4096z32.pth` | Frozen multi-scale VQ-VAE tokenizer | Inherited from VAR (depth-16, C_vae=32, V=4096, ch=160). Shared between SFT and GRPO. |
 | `StyleVAR_SFT.pth` | Stage 1 supervised fine-tuning checkpoint | Plain state dict — no LoRA, no optimizer state. Use this for the SFT baseline. |
 | `StyleVAR-GRPO.pth` | Stage 2 GRPO-refined checkpoint | GRPO LoRA deltas already merged into the base weights. Drop-in replacement for `StyleVAR_SFT.pth`. |
 
@@ -122,7 +122,7 @@ with torch.no_grad():
 | Backbone | VAR depth-20 transformer (~600M params) |
 | Trainable | Full transformer (VQ-VAE frozen) |
 | Epochs | 10 |
-| Learning rate | $5\times10^{-4}$ (epochs 1-6) → $1\times10^{-4}$ (epochs 7-10) |
+| Learning rate | 5×10⁻⁴ (epochs 1-6) → 1×10⁻⁴ (epochs 7-10) |
 | Batch size | 128 (with gradient accumulation) |
 | Hardware | 1× NVIDIA RTX 4090 (48GB) |
 
@@ -130,15 +130,15 @@ with torch.no_grad():
 
 | Setting | Value |
 |---|---|
-| LoRA | $r{=}256$, $\alpha/r{=}2$ on every attention ($W_Q^\text{target}$, $W_{QKV}^\text{cond}$, $W_\text{proj}$) and FFN linear — 131M trainable (18.2% of backbone) |
-| Reward | $R = -\lambda \cdot \text{DreamSim}(\hat{x}, x)$, $\lambda{=}5.0$ |
-| Group size $G$ | 16 |
-| Sampling | top-$k{=}900$, top-$p{=}0.96$ |
-| Clip / KL / PANW | $\varepsilon{=}0.2$, $\beta{=}0.1$, $\gamma{=}0.7$ |
-| Optimizer | AdamW, lr $1\times10^{-5}$, wd 0.01, $(\beta_1, \beta_2){=}(0.9, 0.95)$, FP32 |
-| Iterative merge | peak-triggered, $\tau_\text{gain}{=}0.05$ / $\tau_\text{patience}{=}50$ / 300-step cool-down |
+| LoRA | rank r=256, α/r=2 on every attention (W_Q_target, W_QKV_cond, W_proj) and FFN linear — 131M trainable (18.2% of backbone) |
+| Reward | R = −λ · DreamSim(x̂, x), λ=5.0 |
+| Group size G | 16 |
+| Sampling | top-k=900, top-p=0.96 |
+| Clip / KL / PANW | ε=0.2, β=0.1, γ=0.7 |
+| Optimizer | AdamW, lr 1×10⁻⁵, wd 0.01, (β₁, β₂)=(0.9, 0.95), FP32 |
+| Iterative merge | peak-triggered, τ_gain=0.05 / τ_patience=50 / 300-step cool-down |
 | Emergency merge | raw KL > 2.0, 50-step cool-down |
-| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, $G{=}16$ serial rollouts |
+| Hardware | 1× NVIDIA RTX 4090 (48GB), physical batch 16, G=16 serial rollouts |
 
 ---
 
@@ -149,15 +149,15 @@ Both stages use a concatenation of two paired style-transfer datasets:
 - **OmniStyle-150K** — 143,992 (content, style, target) triplets.
 - **ImagePulse-StyleTransfer** — 137,886 triplets.
 
-Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation + brightness to content and random cropping to style; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the $G$ samples in a group.
+Combined into a single **267,710-sample** training set with a 95/5 train/val split. SFT applies rotation + brightness to content and random cropping to style; GRPO rollouts are performed without augmentation so that the conditioning signal is deterministic across the G samples in a group.
 
 ---
 
 ## Results
 
-Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher $\uparrow$ / lower $\downarrow$ is better. **Best** in bold, <ins>second best</ins> underlined.
+Evaluation on three benchmarks spanning in-, near-, and out-of-distribution regimes. Arrows: higher ↑ / lower ↓ is better. **Best** in bold, <ins>second best</ins> underlined.
 
-| Dataset | Method | Style Loss $\downarrow$ | Content Loss $\downarrow$ | LPIPS $\downarrow$ | SSIM $\uparrow$ | DreamSim $\downarrow$ | CLIP Sim $\uparrow$ | Infer (s) $\downarrow$ |
+| Dataset | Method | Style Loss ↓ | Content Loss ↓ | LPIPS ↓ | SSIM ↑ | DreamSim ↓ | CLIP Sim ↑ | Infer (s) ↓ |
 |---|---|---|---|---|---|---|---|---|
 | **OmniStyle** | AdaIN (baseline) | 0.0625 | 198.3449 | 0.7506 | 0.1421 | 0.6522 | 0.6555 | **0.0079** |
 | | StyleVAR (SFT) | <ins>0.0468</ins> | <ins>116.3569</ins> | <ins>0.4743</ins> | <ins>0.3975</ins> | <ins>0.2276</ins> | <ins>0.8704</ins> | 0.4031 |
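
The README text above quotes 10 scales (1×1 → 16×16) and L = 680 tokens. Under the standard vanilla-VAR scale schedule — an assumption here, since the intermediate resolutions are not listed in this diff — those numbers are consistent, as is the 256× coarse-vs-fine imbalance that Stage 2's PANW targets:

```python
# Token budget check, assuming the vanilla VAR 10-scale schedule
# (only the 1×1 and 16×16 endpoints and L = 680 are stated above).
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)   # assumed intermediate scales
tokens_per_scale = [p * p for p in patch_nums]    # tokens emitted at each scale
L = sum(tokens_per_scale)                         # 680 tokens in total
imbalance = tokens_per_scale[-1] // tokens_per_scale[0]  # 256× finest vs coarsest
```

If the repository uses a different intermediate schedule, only `patch_nums` changes; the endpoint scales and L = 680 pin it down fairly tightly.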
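
The LoRA row (r = 256, α/r = 2, deltas later merged into the base weights) follows the usual low-rank-update recipe. A toy NumPy sketch with illustrative dimensions — the module names and the 131M count come from the table above; the shapes and initialization below are assumptions:

```python
# Minimal LoRA sketch: y = W·x + (α/r)·B·(A·x), with B zero-initialized so
# the adapted model starts identical to the SFT policy. Toy sizes only;
# the README uses r = 256 with α/r = 2 on the attention and FFN linears.
import numpy as np

d, r, scale = 512, 64, 2.0                  # scale = α/r
rng = np.random.default_rng(0)
W = rng.standard_normal((d, d)) * 0.02      # frozen base weight
A = rng.standard_normal((r, d)) * 0.02      # trained down-projection
B = np.zeros((d, r))                        # zero-init up-projection

x = rng.standard_normal(d)
y0 = W @ x + scale * (B @ (A @ x))          # == W @ x while B is zero

B = rng.standard_normal((d, r)) * 0.02      # pretend some training happened
W_merged = W + scale * (B @ A)              # the merge the checkpoints ship with
```

Merging (`W + (α/r)·B·A`) is why `StyleVAR-GRPO.pth` loads as a plain state dict with no adapter machinery at inference time.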
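
The Stage-2 table fixes λ = 5.0, G = 16 and γ = 0.7 but not the surrounding formulas. Below is a hedged sketch of how a DreamSim reward typically feeds the standard GRPO group-relative advantage, plus a purely hypothetical PANW weight (the diff says only that PANW rebalances the 256× token imbalance; the exact formula is not given here):

```python
# Sketch: reward R = -λ·DreamSim(x̂, x) with λ = 5.0, normalized within a
# group of G = 16 rollouts of the same (content, style) pair.
import statistics

LAM, G, GAMMA = 5.0, 16, 0.7

dreamsim = [0.20 + 0.01 * i for i in range(G)]   # stand-in distances per rollout
rewards = [-LAM * d for d in dreamsim]           # lower distance -> higher reward

mu = statistics.fmean(rewards)
sigma = statistics.pstdev(rewards)
advantages = [(r - mu) / sigma for r in rewards] # group-relative advantage

# Hypothetical PANW reading: down-weight scales with many tokens so the
# 1-token coarse scale and the 256-token finest scale get comparable credit.
token_counts = [p * p for p in (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)]
panw = [n ** -GAMMA for n in token_counts]
```

Without some per-scale reweighting, the finest scale alone would contribute 256 of every 680 token-level policy-gradient terms, which is the imbalance the README's PANW row refers to.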
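
The sampling row (top-k = 900, top-p = 0.96) can be illustrated as a plain-Python filter over a V = 4096 codebook distribution. The real implementation presumably operates on torch logits per scale, so treat this only as a statement of the rule: keep the 900 highest logits, then the shortest prefix reaching 0.96 probability mass, and renormalize:

```python
import math

def top_k_top_p(logits, k=900, p=0.96):
    # sort indices by logit, keep the top k
    order = sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:k]
    m = logits[order[0]]                       # subtract max for stability
    exps = [math.exp(logits[i] - m) for i in order]
    z = sum(exps)
    kept, mass = [], 0.0
    for idx, e in zip(order, exps):            # nucleus: smallest prefix with mass >= p
        kept.append((idx, e / z))
        mass += e / z
        if mass >= p:
            break
    norm = sum(pr for _, pr in kept)
    return {idx: pr / norm for idx, pr in kept}  # renormalized distribution

V = 4096                                       # codebook size from the table
logits = [math.sin(0.37 * i) * 3.0 for i in range(V)]  # toy logits
dist = top_k_top_p(logits)
```

With k = 900 out of V = 4096, roughly the top fifth of the codebook survives the first cut before the nucleus threshold trims it further.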