APGASU commited on
Commit
7bef20f
·
verified ·
1 Parent(s): a328d4d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. LICENSE +21 -0
  3. README.md +177 -12
  4. TRAIN.md +110 -0
  5. app.py +471 -0
  6. assets/example_1.png +3 -0
  7. assets/generated_images.png +3 -0
  8. assets/reconstructed.png +3 -0
  9. assets/teaser.png +3 -0
  10. configs/training/VibeToken_small.yaml +102 -0
  11. configs/vibetoken_ll.yaml +40 -0
  12. configs/vibetoken_sl.yaml +40 -0
  13. data/__init__.py +1 -0
  14. data/convert_imagenet_to_wds.py +56 -0
  15. data/webdataset_reader.py +518 -0
  16. evaluator/__init__.py +1 -0
  17. evaluator/evaluator.py +230 -0
  18. evaluator/inception.py +215 -0
  19. examples/batch_inference.py +241 -0
  20. examples/encode_decode.py +172 -0
  21. generate.py +240 -0
  22. generator/__init__.py +4 -0
  23. modeling/__init__.py +0 -0
  24. modeling/modules/__init__.py +6 -0
  25. modeling/modules/base_model.py +124 -0
  26. modeling/modules/blocks.py +617 -0
  27. modeling/modules/discriminator.py +124 -0
  28. modeling/modules/ema_model.py +241 -0
  29. modeling/modules/encoder_decoder.py +1142 -0
  30. modeling/modules/fuzzy_embedding.py +70 -0
  31. modeling/modules/losses.py +339 -0
  32. modeling/modules/lpips.py +181 -0
  33. modeling/modules/maskgit_vqgan.py +346 -0
  34. modeling/modules/perceptual_loss.py +101 -0
  35. modeling/quantizer/__init__.py +3 -0
  36. modeling/quantizer/dist.py +302 -0
  37. modeling/quantizer/mvq.py +159 -0
  38. modeling/quantizer/quantizer.py +158 -0
  39. modeling/quantizer/softvq.py +170 -0
  40. modeling/vibetoken_model.py +219 -0
  41. reconstruct.py +148 -0
  42. requirements.txt +26 -0
  43. scripts/train_vibetoken.py +223 -0
  44. setup.sh +20 -0
  45. train_tokenvibe.sh +14 -0
  46. train_vibetoken.sh +14 -0
  47. utils/__init__.py +0 -0
  48. utils/logger.py +69 -0
  49. utils/lr_schedulers.py +129 -0
  50. utils/misc.py +342 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/example_1.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/generated_images.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/reconstructed.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Maitreya Patel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: VibeToken
3
- emoji: 🦀
4
- colorFrom: blue
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 6.6.0
8
- python_version: '3.12'
9
- app_file: app.py
10
- pinned: false
11
- license: mit
12
- ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [CVPR 2026] VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations
2
+
3
+ <p align="center">
4
+ <img src="assets/teaser.png" alt="VibeToken Teaser" width="100%">
5
+ </p>
6
+
7
+ <p align="center">
8
+ <b>CVPR 2026</b> &nbsp;|&nbsp;
9
+ <a href="#">Paper</a> &nbsp;|&nbsp;
10
+ <a href="#">Project Page</a> &nbsp;|&nbsp;
11
+ <a href="#-checkpoints">Checkpoints</a>
12
+ </p>
13
+
14
+ <p align="center">
15
+ <img src="https://img.shields.io/badge/CVPR-2026-blue" alt="CVPR 2026">
16
+ <img src="https://img.shields.io/badge/arXiv-TODO-b31b1b" alt="arXiv">
17
+ <img src="https://img.shields.io/badge/License-MIT-green" alt="License">
18
+ <a href="https://huggingface.co/mpatel57/VibeToken"><img src="https://img.shields.io/badge/🤗-Model-yellow" alt="HuggingFace"></a>
19
+ </p>
20
+
21
  ---
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ We introduce an efficient, resolution-agnostic autoregressive (AR) image synthesis approach that generalizes to **arbitrary resolutions and aspect ratios**, narrowing the gap to diffusion models at scale. At its core is **VibeToken**, a novel resolution-agnostic 1D Transformer-based image tokenizer that encodes images into a dynamic, user-controllable sequence of 32--256 tokens, achieving state-of-the-art efficiency and performance trade-off. Building on VibeToken, we present **VibeToken-Gen**, a class-conditioned AR generator with out-of-the-box support for arbitrary resolutions while requiring significantly fewer compute resources.
24
+
25
+ ### 🔥 Highlights
26
+
27
+ | | |
28
+ |---|---|
29
+ | 🎯 **1024×1024 in just 64 tokens** | Achieves **3.94 gFID** vs. 5.87 gFID for diffusion-based SOTA (1,024 tokens) |
30
+ | ⚡ **Constant 179G FLOPs** | 63× more efficient than LlamaGen (11T FLOPs at 1024×1024) |
31
+ | 🌐 **Resolution-agnostic** | Supports arbitrary resolutions and aspect ratios out of the box |
32
+ | 🎛️ **Dynamic token count** | User-controllable 32--256 tokens per image |
33
+ | 🔍 **Native super-resolution** | Supports image super-resolution out of the box |
34
+
35
+
36
+ ## 📰 News
37
+
38
+ - **[Feb 2026]** 🎉 VibeToken is accepted at **CVPR 2026**!
39
+ - **[Feb 2026]** Training scripts released.
40
+ - **[Feb 2026]** Inference code and checkpoints released.
41
+
42
+
43
+ ## 🚀 Quick Start
44
+
45
+ ```bash
46
+ # 1. Clone and setup
47
+ git clone https://github.com/<your-org>/VibeToken.git
48
+ cd VibeToken
49
+ uv venv --python=3.11.6
50
+ source .venv/bin/activate
51
+ uv pip install -r requirements.txt
52
+
53
+ # 2. Download a checkpoint (see Checkpoints section below)
54
+ mkdir -p checkpoints
55
+ wget https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin -O ./checkpoints/VibeToken_LL.bin
56
+
57
+ # 3. Reconstruct an image
58
+ python reconstruct.py --auto \
59
+ --config configs/vibetoken_ll.yaml \
60
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
61
+ --image ./assets/example_1.png \
62
+ --output ./assets/reconstructed.png
63
+ ```
64
+
65
+
66
+ ## 📦 Checkpoints
67
+
68
+ All checkpoints are hosted on [Hugging Face](https://huggingface.co/mpatel57/VibeToken).
69
+
70
+ #### Reconstruction Checkpoints
71
+
72
+ | Name | Resolution | rFID (256 tokens) | rFID (64 tokens) | Download |
73
+ |------|:----------:|:-----------------:|:----------------:|----------|
74
+ | VibeToken-LL | 1024×1024 | 3.76 | 4.12 | [VibeToken_LL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin) |
75
+ | VibeToken-LL | 256×256 | 5.12 | 0.90 | same as above |
76
+ | VibeToken-SL | 1024×1024 | 4.25 | 2.41 | [VibeToken_SL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_SL.bin) |
77
+ | VibeToken-SL | 256×256 | 5.44 | 0.40 | same as above |
78
+
79
+ #### Generation Checkpoints
80
+
81
+ | Name | Training Resolution(s) | Tokens | Best gFID | Download |
82
+ |------|:----------------------:|:------:|:---------:|----------|
83
+ | VibeToken-Gen-B | 256×256 | 65 | 7.62 | [VibeTokenGen-b-fixed65_dynamic_1500k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-b-fixed65_dynamic_1500k.pt) |
84
+ | VibeToken-Gen-B | 1024×1024 | 65 | 7.37 | same as above |
85
+ | VibeToken-Gen-XXL | 256×256 | 65 | 3.62 | [VibeTokenGen-xxl-dynamic-65_750k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-xxl-dynamic-65_750k.pt) |
86
+ | VibeToken-Gen-XXL | 1024×1024 | 65 | **3.54** | same as above |
87
+
88
+
89
+ ## 🛠️ Setup
90
+
91
+ ```bash
92
+ uv venv --python=3.11.6
93
+ source .venv/bin/activate
94
+ uv pip install -r requirements.txt
95
+ ```
96
+
97
+ > **Tip:** If you don't have `uv`, install it via `pip install uv` or see [uv docs](https://github.com/astral-sh/uv). Alternatively, use `python -m venv .venv && pip install -r requirements.txt`.
98
+
99
+
100
+ ## 🖼️ VibeToken Reconstruction
101
+
102
+ Download the VibeToken-LL checkpoint (see [Checkpoints](#-checkpoints)), then:
103
+
104
+ ```bash
105
+ # Auto mode (recommended) -- automatically determines optimal patch sizes
106
+ python reconstruct.py --auto \
107
+ --config configs/vibetoken_ll.yaml \
108
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
109
+ --image ./assets/example_1.png \
110
+ --output ./assets/reconstructed.png
111
+
112
+ # Manual mode -- specify patch sizes explicitly
113
+ python reconstruct.py \
114
+ --config configs/vibetoken_ll.yaml \
115
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
116
+ --image ./assets/example_1.png \
117
+ --output ./assets/reconstructed.png \
118
+ --encoder_patch_size 16 \
119
+ --decoder_patch_size 16
120
+ ```
121
+
122
+ > **Note:** For best performance, the input image resolution should be a multiple of 32. Images with other resolutions are automatically rescaled to the nearest multiple of 32.
123
+
124
+
125
+ ## 🎨 VibeToken-Gen: ImageNet-1k Generation
126
+
127
+ Download both the VibeToken-LL and VibeToken-Gen-XXL checkpoints (see [Checkpoints](#-checkpoints)), then:
128
+
129
+ ```bash
130
+ python generate.py \
131
+ --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
132
+ --gpt-model GPT-XXL --num-output-layer 4 \
133
+ --num-codebooks 8 --codebook-size 32768 \
134
+ --image-size 256 --cfg-scale 4.0 --top-k 500 --temperature 1.0 \
135
+ --class-dropout-prob 0.1 \
136
+ --extra-layers "QKV" \
137
+ --latent-size 65 \
138
+ --config ./configs/vibetoken_ll.yaml \
139
+ --vq-ckpt ./checkpoints/VibeToken_LL.bin \
140
+ --sample-dir ./assets/ \
141
+ --skip-folder-creation \
142
+ --compile \
143
+ --decoder-patch-size 32,32 \
144
+ --target-resolution 1024,1024 \
145
+ --llamagen-target-resolution 256,256 \
146
+ --precision bf16 \
147
+ --global-seed 156464151
148
+ ```
149
+
150
+ The `--target-resolution` controls the tokenizer output resolution, while `--llamagen-target-resolution` controls the generator's internal resolution (max 512×512; for higher resolutions, the tokenizer handles upscaling).
151
+
152
+
153
+ ## 🏋️ Training
154
+
155
+ To train the VibeToken tokenizer from scratch, please refer to [TRAIN.md](TRAIN.md) for detailed instructions.
156
+
157
+
158
+ ## 🙏 Acknowledgement
159
+
160
+ We would like to acknowledge the following repositories that inspired our work and upon which we directly build:
161
+ [1d-tokenizer](https://github.com/bytedance/1d-tokenizer),
162
+ [LlamaGen](https://github.com/FoundationVision/LlamaGen), and
163
+ [UniTok](https://github.com/FoundationVision/UniTok).
164
+
165
+
166
+ ## 📝 Citation
167
+
168
+ If you find VibeToken useful in your research, please consider citing:
169
+
170
+ ```bibtex
171
+ @inproceedings{vibetoken2026,
172
+ title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
173
+ author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
174
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
175
+ year = {2026}
176
+ }
177
+ ```
178
+
179
+ If you have any questions, feel free to open an issue or reach out!
TRAIN.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Instructions
2
+
3
+ ## VibeToken MVQ Tokenizer
4
+
5
+ This repository contains the training code for our tokenizer.
6
+ We provide the example config [VibeToken-Small](configs/training/VibeToken_small.yaml) that trains the small encoder/decoder architecture with 32-64 tokens.
7
+
8
+ ### Data Preparation
9
+
10
+ All data paths are controlled by the `DATA_DIR` environment variable. Set it once to point to your preferred storage location:
11
+
12
+ ```bash
13
+ export DATA_DIR=/path/to/your/storage # defaults to ./data if unset
14
+ ```
15
+
16
+ Download ImageNet-1k and convert to WebDataset format:
17
+
18
+ ```bash
19
+ source .venv/bin/activate
20
+
21
+ # Option 1: Use the setup script (recommended)
22
+ bash setup.sh
23
+
24
+ # Option 2: Run steps manually
25
+ export HF_HUB_ENABLE_HF_TRANSFER=1
26
+ huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"
27
+ python data/convert_imagenet_to_wds.py \
28
+ --input_dir "${DATA_DIR}/imagenet-1k" \
29
+ --output_dir "${DATA_DIR}/imagenet_wds"
30
+ ```
31
+
32
+ After preparation, update the shard paths in your training config to match your `DATA_DIR`:
33
+
34
+ ```yaml
35
+ dataset:
36
+ params:
37
+ train_shards_path_or_url: "<DATA_DIR>/imagenet_wds/imagenet-train-{000001..000128}.tar"
38
+ eval_shards_path_or_url: "<DATA_DIR>/imagenet_wds/imagenet-val-{000001..000004}.tar"
39
+ ```
40
+
41
+ ### Launch Training
42
+
43
+ Start training on 1 node with 8 GPUs:
44
+
45
+ ```bash
46
+ source .venv/bin/activate
47
+ bash train_vibetoken.sh
48
+ ```
49
+
50
+ ### Config Reference
51
+
52
+ Below are the important hyperparameters that control training.
53
+
54
+ ```yaml
55
+ model:
56
+ vq_model:
57
+ vit_enc_model_size: "small" # this can be small/base/large
58
+ vit_dec_model_size: "small" # this can be small/base/large
59
+ num_latent_tokens: 64 # in paper we set this to 256
60
+
61
+ losses:
62
+ discriminator_start: 100_000 # set based on convergence, in paper we set this to 250_000
63
+
64
+ dataset:
65
+ params:
66
+ pretokenization: True # keep this true if using the current setup
67
+ train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
68
+ eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
69
+ preprocessing:
70
+ resize_shorter_edge: 512 # maximum size during pretraining but can be any value
71
+ crop_size: 512 # maximum size during pretraining but can be any value
72
+ min_tokens: 32 # minimum number of tokens to generate
73
+ max_tokens: 64 # maximum number of tokens to generate
74
+
75
+ training:
76
+ gradient_accumulation_steps: 1 # increase for LL model that does not fit on single node
77
+ per_gpu_batch_size: 32 # decrease to 16 for LL model; during GAN training this is halved
78
+ max_train_steps: 400_000 # in paper we train up to 650_000; model may diverge after 600_000
79
+ num_generated_images: 2 # for validation
80
+ variable_resolution: # any-to-any resolution training
81
+ any2any: True
82
+ dim:
83
+ - [256, 256]
84
+ - [512, 512]
85
+ - [384, 256]
86
+ - [256, 384]
87
+ - [512, 384]
88
+ - [384, 512]
89
+ ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1] # probability per resolution; must sum to 1.0
90
+
91
+
92
+ # Remove patch mixture parameters unless the model does not fit in memory.
93
+ # This will slow down training and may hurt performance.
94
+ # We do not use this in our normal setup.
95
+ model:
96
+ vq_model:
97
+ encoder:
98
+ patch_mixture_start_layer: 2
99
+ patch_mixture_end_layer: 22
100
+ decoder:
101
+ patch_mixture_start_layer: 2
102
+ patch_mixture_end_layer: 22
103
+ ```
104
+
105
+
106
+ <!-- ### Reproduced Results on Small Baseline
107
+
108
+ > Note: Our released checkpoints were trained with a different codebase, so results may differ slightly (+/-).
109
+
110
+ Below we report the performance on the above training script on the small baseline. This baseline is not reported in the paper but achieves competitive performance as expected. -->
app.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VibeToken-Gen Gradio Demo
3
+ Class-conditional ImageNet generation with dynamic resolution support.
4
+ """
5
+ import spaces
6
+
7
+ import os
8
+ import random
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+ torch.set_float32_matmul_precision("high")
17
+ torch.set_grad_enabled(False)
18
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
19
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
20
+
21
+ from huggingface_hub import hf_hub_download
22
+ from PIL import Image
23
+
24
+ from vibetokengen.generate import generate
25
+ from vibetokengen.model import GPT_models
26
+ from vibetoken import VibeTokenTokenizer
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Configuration
30
+ # ---------------------------------------------------------------------------
31
+
32
+ HF_REPO = "mpatel57/VibeToken"
33
+ USE_XXL = os.environ.get("VIBETOKEN_XXL", "0") == "1"
34
+
35
+ if USE_XXL:
36
+ GPT_MODEL_NAME = "GPT-XXL"
37
+ GPT_CKPT_FILENAME = "VibeTokenGen-xxl-dynamic-65_750k.pt"
38
+ NUM_OUTPUT_LAYER = 4
39
+ EXTRA_LAYERS = "QKV"
40
+ else:
41
+ GPT_MODEL_NAME = "GPT-B"
42
+ GPT_CKPT_FILENAME = "VibeTokenGen-b-fixed65_dynamic_1500k.pt"
43
+ NUM_OUTPUT_LAYER = 4
44
+ EXTRA_LAYERS = "QKV"
45
+
46
+ VQ_CKPT_FILENAME = "VibeToken_LL.bin"
47
+ CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "vibetoken_ll.yaml")
48
+
49
+ CODEBOOK_SIZE = 32768
50
+ NUM_CODEBOOKS = 8
51
+ LATENT_SIZE = 65
52
+ NUM_CLASSES = 1000
53
+ CLS_TOKEN_NUM = 1
54
+ CLASS_DROPOUT_PROB = 0.1
55
+ CAPPING = 50.0
56
+
57
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
58
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
59
+ COMPILE = os.environ.get("VIBETOKEN_NO_COMPILE", "0") != "1" and DEVICE == "cuda"
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # ImageNet class labels (curated popular subset)
63
+ # ---------------------------------------------------------------------------
64
+
65
+ IMAGENET_CLASSES = {
66
+ "Golden Retriever": 207,
67
+ "Labrador Retriever": 208,
68
+ "German Shepherd": 235,
69
+ "Siberian Husky": 250,
70
+ "Pembroke Corgi": 263,
71
+ "Tabby Cat": 281,
72
+ "Persian Cat": 283,
73
+ "Siamese Cat": 284,
74
+ "Tiger": 292,
75
+ "Lion": 291,
76
+ "Cheetah": 293,
77
+ "Brown Bear": 294,
78
+ "Giant Panda": 388,
79
+ "Red Fox": 277,
80
+ "Arctic Fox": 279,
81
+ "Timber Wolf": 269,
82
+ "Bald Eagle": 22,
83
+ "Macaw": 88,
84
+ "Flamingo": 130,
85
+ "Peacock": 84,
86
+ "Goldfish": 1,
87
+ "Great White Shark": 2,
88
+ "Jellyfish": 107,
89
+ "Monarch Butterfly": 323,
90
+ "Ladybug": 301,
91
+ "Snail": 113,
92
+ "Red Sports Car": 817,
93
+ "School Bus": 779,
94
+ "Steam Locomotive": 820,
95
+ "Sailboat": 914,
96
+ "Space Shuttle": 812,
97
+ "Castle": 483,
98
+ "Church": 497,
99
+ "Lighthouse": 437,
100
+ "Volcano": 980,
101
+ "Lakeside": 975,
102
+ "Cliff": 972,
103
+ "Coral Reef": 973,
104
+ "Valley": 979,
105
+ "Seashore": 978,
106
+ "Mushroom": 947,
107
+ "Broccoli": 937,
108
+ "Pizza": 963,
109
+ "Ice Cream": 928,
110
+ "Cheeseburger": 933,
111
+ "Espresso": 967,
112
+ "Acoustic Guitar": 402,
113
+ "Grand Piano": 579,
114
+ "Violin": 889,
115
+ "Balloon": 417,
116
+ }
117
+
118
+ GENERATOR_RESOLUTION_PRESETS = {
119
+ "256 × 256": (256, 256),
120
+ "384 × 256": (384, 256),
121
+ "256 × 384": (256, 384),
122
+ "384 × 384": (384, 384),
123
+ "512 × 256": (512, 256),
124
+ "256 × 512": (256, 512),
125
+ "512 × 512": (512, 512),
126
+ }
127
+
128
+ OUTPUT_RESOLUTION_PRESETS = {
129
+ "Same as generator": None,
130
+ "256 × 256": (256, 256),
131
+ "384 × 384": (384, 384),
132
+ "512 × 512": (512, 512),
133
+ "768 × 768": (768, 768),
134
+ "1024 × 1024": (1024, 1024),
135
+ "512 × 256 (2:1)": (512, 256),
136
+ "256 × 512 (1:2)": (256, 512),
137
+ "768 × 512 (3:2)": (768, 512),
138
+ "512 × 768 (2:3)": (512, 768),
139
+ "1024 × 512 (2:1)": (1024, 512),
140
+ "512 × 1024 (1:2)": (512, 1024),
141
+ }
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # Model loading
145
+ # ---------------------------------------------------------------------------
146
+
147
+ vq_model = None
148
+ gpt_model = None
149
+
150
+
151
def download_checkpoint(filename: str) -> str:
    """Resolve *filename* from the HF Hub repo and return its local cache path."""
    local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
    return local_path
153
+
154
+
155
+ def _make_res_tensors(gen_h: int, gen_w: int, multiplier: int):
156
+ """Create normalized resolution tensors for the GPT generator."""
157
+ th = torch.tensor(gen_h / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1)
158
+ tw = torch.tensor(gen_w / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1)
159
+ return th, tw
160
+
161
+
162
def _warmup(model):
    """Run a throwaway generation to trigger torch.compile and warm CUDA caches."""
    print("Warming up (first call triggers compilation, may take ~30-60s)...")
    # One dummy class-conditional sample at 256x256 is enough to compile.
    cond = torch.tensor([0], device=DEVICE)
    res_h, res_w = _make_res_tensors(256, 256, multiplier=2)
    with torch.inference_mode():
        generate(
            model,
            cond,
            LATENT_SIZE,
            NUM_CODEBOOKS,
            cfg_scale=4.0,
            cfg_interval=-1,
            target_h=res_h,
            target_w=res_w,
            temperature=1.0,
            top_k=500,
            top_p=1.0,
            sample_logits=True,
        )
    if DEVICE == "cuda":
        # Wait for queued kernels to finish before declaring warmup done.
        torch.cuda.synchronize()
    print("Warmup complete — subsequent generations will be fast.")
177
+
178
+
179
def load_models():
    """Download checkpoints and initialize the module-level models.

    Populates the globals ``vq_model`` (VibeToken tokenizer) and
    ``gpt_model`` (AR generator). When running on CUDA with compilation
    enabled, the generator is additionally compiled and warmed up.
    """
    global vq_model, gpt_model

    print("Downloading checkpoints (if needed)...")
    vq_path = download_checkpoint(VQ_CKPT_FILENAME)
    gpt_path = download_checkpoint(GPT_CKPT_FILENAME)

    print(f"Loading VibeToken tokenizer from {vq_path}...")
    vq_model = VibeTokenTokenizer.from_config(
        CONFIG_PATH, vq_path, device=DEVICE, dtype=DTYPE,
    )
    print("VibeToken tokenizer loaded.")

    print(f"Loading {GPT_MODEL_NAME} from {gpt_path}...")
    gpt_model = GPT_models[GPT_MODEL_NAME](
        vocab_size=CODEBOOK_SIZE,
        block_size=LATENT_SIZE,
        num_classes=NUM_CLASSES,
        cls_token_num=CLS_TOKEN_NUM,
        model_type="c2i",
        num_codebooks=NUM_CODEBOOKS,
        n_output_layer=NUM_OUTPUT_LAYER,
        class_dropout_prob=CLASS_DROPOUT_PROB,
        extra_layers=EXTRA_LAYERS,
        capping=CAPPING,
    ).to(device=DEVICE, dtype=DTYPE)

    # NOTE(review): weights_only=False unpickles arbitrary objects. The
    # checkpoint is fetched from our own HF repo, so it is trusted input;
    # do not point this at arbitrary user-supplied files.
    checkpoint = torch.load(gpt_path, map_location="cpu", weights_only=False)
    # Checkpoints from different training setups nest the state dict under
    # different keys; fall back to the raw object when none match.
    weights = checkpoint
    for key in ("model", "module", "state_dict"):
        if key in checkpoint:
            weights = checkpoint[key]
            break
    gpt_model.load_state_dict(weights, strict=True)
    gpt_model.eval()
    del checkpoint  # release the CPU copy before (optional) compilation
    print(f"{GPT_MODEL_NAME} loaded.")

    if COMPILE:
        print("Compiling GPT model with torch.compile (max-autotune)...")
        gpt_model = torch.compile(gpt_model, mode="max-autotune", fullgraph=True)
        _warmup(gpt_model)
    else:
        # Compilation is on by default on CUDA; it is skipped only when CUDA
        # is unavailable or VIBETOKEN_NO_COMPILE=1 (the old message wrongly
        # suggested it had to be enabled with =0).
        print("Skipping torch.compile (requires CUDA; set VIBETOKEN_NO_COMPILE=1 to opt out).")
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # Decoder patch-size heuristic
230
+ # ---------------------------------------------------------------------------
231
+
232
def auto_decoder_patch_size(h: int, w: int) -> tuple[int, int]:
    """Heuristically pick a square decoder patch size for an h×w output.

    Larger outputs use coarser (bigger) patches: <=256 -> 8, <=512 -> 16,
    anything bigger -> 32, judged by the longer side.
    """
    longest = max(h, w)
    for limit, side in ((256, 8), (512, 16)):
        if longest <= limit:
            return (side, side)
    return (32, 32)
241
+
242
+
243
+ # ---------------------------------------------------------------------------
244
+ # Generation
245
+ # ---------------------------------------------------------------------------
246
+
247
@torch.inference_mode()
@spaces.GPU(duration=90)
def generate_image(
    class_name: str,
    class_id: int,
    gen_resolution_preset: str,
    out_resolution_preset: str,
    decoder_ps_choice: str,
    cfg_scale: float,
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    randomize_seed: bool,
):
    """Sample one class-conditional image and decode it at the requested resolution.

    Returns the generated PIL image and the seed that was actually used.
    """
    # The Space UI can fire before load_models() has finished.
    if vq_model is None or gpt_model is None:
        raise gr.Error("Models are still loading. Please wait a moment and try again.")

    if randomize_seed:
        seed = random.randint(0, 2**31 - 1)

    # Seed every RNG we touch so a reported seed reproduces the image.
    torch.manual_seed(seed)
    np.random.seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Resolve the ImageNet class id, clamping custom input to a valid range.
    if class_name and class_name != "Custom (enter ID below)":
        chosen_id = IMAGENET_CLASSES[class_name]
    else:
        chosen_id = int(class_id)
    chosen_id = max(0, min(chosen_id, NUM_CLASSES - 1))

    gen_h, gen_w = GENERATOR_RESOLUTION_PRESETS[gen_resolution_preset]

    # Output resolution defaults to the generator resolution.
    out_res = OUTPUT_RESOLUTION_PRESETS[out_resolution_preset]
    out_h, out_w = (gen_h, gen_w) if out_res is None else out_res

    # Decoder patch size: heuristic, or a user-forced square size.
    if decoder_ps_choice == "Auto":
        dec_ps = auto_decoder_patch_size(out_h, out_w)
    else:
        forced = int(decoder_ps_choice)
        dec_ps = (forced, forced)

    # CFG doubles the batch (conditional + unconditional) inside the sampler.
    multiplier = 2 if cfg_scale > 1.0 else 1

    c_indices = torch.tensor([chosen_id], device=DEVICE)
    th, tw = _make_res_tensors(gen_h, gen_w, multiplier)

    index_sample = generate(
        gpt_model,
        c_indices,
        LATENT_SIZE,
        NUM_CODEBOOKS,
        cfg_scale=cfg_scale,
        cfg_interval=-1,
        target_h=th,
        target_w=tw,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        sample_logits=True,
    )

    # The tokenizer expects an extra axis before decoding.
    decoded = vq_model.decode(
        index_sample.unsqueeze(2),
        height=out_h,
        width=out_w,
        patch_size=dec_ps,
    )
    decoded = torch.clamp(decoded, 0, 1)

    # CHW float in [0, 1] -> HWC uint8 -> PIL image.
    arr = (decoded[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype("uint8")
    pil_img = Image.fromarray(arr)

    return pil_img, seed
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # Gradio UI
330
+ # ---------------------------------------------------------------------------
331
+
332
+ HEADER_MD = """
333
+ # VibeToken-Gen: Dynamic Resolution Image Generation
334
+
335
+ <p style="margin-top:4px;">
336
+ <b>Maitreya Patel, Jingtao Li, Weiming Zhuang, Yezhou Yang, Lingjuan Lyu</b>
337
+ &nbsp;|&nbsp;
338
+ </p>
339
+ <h3>CVPR 2026 (Main Conference)</h3>
340
+
341
+ <p>
342
+ <a href="https://huggingface.co/mpatel57/VibeToken" target="_blank">🤗 Model</a> &nbsp;|&nbsp;
343
+ <a href="https://github.com/patel-maitreya/VibeToken" target="_blank">💻 GitHub</a>
344
+ </p>
345
+
346
+ Generate ImageNet class-conditional images at **arbitrary resolutions** using only **65 tokens**.
347
+ VibeToken-Gen maintains a constant **179G FLOPs** regardless of output resolution.
348
+ """
349
+
350
+ CITATION_MD = """
351
+ ### Citation
352
+ ```bibtex
353
+ @inproceedings{vibetoken2026,
354
+ title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
355
+ author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
356
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
357
+ year = {2026}
358
+ }
359
+ ```
360
+ """
361
+
362
# Dropdown entries: a "custom id" escape hatch plus the curated class names.
cls_options = ["Custom (enter ID below)"] + sorted(IMAGENET_CLASSES.keys())

with gr.Blocks(
    title="VibeToken-Gen Demo",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        # ---- Left column: generation controls ----
        with gr.Column(scale=1):
            cls_dd = gr.Dropdown(
                label="ImageNet Class",
                choices=cls_options,
                value="Golden Retriever",
                info="Pick a class or choose 'Custom' to enter an ID manually.",
            )
            custom_id_num = gr.Number(
                label="Custom Class ID (0–999)",
                value=207,
                minimum=0,
                maximum=999,
                step=1,
                visible=False,
            )
            gen_res_dd = gr.Dropdown(
                label="Generator Resolution",
                choices=list(GENERATOR_RESOLUTION_PRESETS.keys()),
                value="256 × 256",
                info="Internal resolution for the AR generator (max 512×512).",
            )
            out_res_dd = gr.Dropdown(
                label="Output Resolution (Decoder)",
                choices=list(OUTPUT_RESOLUTION_PRESETS.keys()),
                value="Same as generator",
                info="Final image resolution. Set higher for super-resolution (e.g. generate at 256, decode at 1024).",
            )
            dec_ps_dd = gr.Dropdown(
                label="Decoder Patch Size",
                choices=["Auto", "8", "16", "32"],
                value="Auto",
                info="'Auto' selects based on output resolution. Larger = faster but coarser.",
            )

            with gr.Accordion("Advanced Sampling Parameters", open=False):
                cfg_sl = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0, maximum=20.0, value=4.0, step=0.5,
                    info="Classifier-free guidance strength.",
                )
                temp_sl = gr.Slider(
                    label="Temperature",
                    minimum=0.1, maximum=2.0, value=1.0, step=0.05,
                )
                topk_sl = gr.Slider(
                    label="Top-k",
                    minimum=0, maximum=2000, value=500, step=10,
                    info="0 disables top-k filtering.",
                )
                topp_sl = gr.Slider(
                    label="Top-p",
                    minimum=0.0, maximum=1.0, value=1.0, step=0.05,
                    info="1.0 disables nucleus sampling.",
                )
                seed_num = gr.Number(
                    label="Seed", value=0, minimum=0, maximum=2**31 - 1, step=1,
                )
                random_seed_cb = gr.Checkbox(label="Randomize seed", value=True)

            run_btn = gr.Button("Generate", variant="primary", size="lg")

        # ---- Right column: generated output ----
        with gr.Column(scale=2):
            result_img = gr.Image(label="Generated Image", type="pil", height=512)
            seed_out = gr.Number(label="Seed used", interactive=False)

    # Reveal the manual class-ID field only for the "Custom" entry.
    def _show_custom_id(selection):
        return gr.update(visible=(selection == "Custom (enter ID below)"))

    cls_dd.change(
        fn=_show_custom_id,
        inputs=[cls_dd],
        outputs=[custom_id_num],
    )

    # Input order must mirror generate_image's parameter order.
    run_btn.click(
        fn=generate_image,
        inputs=[
            cls_dd,
            custom_id_num,
            gen_res_dd,
            out_res_dd,
            dec_ps_dd,
            cfg_sl,
            temp_sl,
            topk_sl,
            topp_sl,
            seed_num,
            random_seed_cb,
        ],
        outputs=[result_img, seed_out],
    )

    gr.Markdown(CITATION_MD)


if __name__ == "__main__":
    load_models()
    demo.launch()
assets/example_1.png ADDED

Git LFS Details

  • SHA256: da07f6cd58181e54c6b4bbbc0458d99f48946da19a27ea02e9a3920bfb2b5d15
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
assets/generated_images.png ADDED

Git LFS Details

  • SHA256: e1f5cdef18942f6331460be883485ce5141c7d9c6db7e1cd7596422a57b5cba7
  • Pointer size: 133 Bytes
  • Size of remote file: 11.1 MB
assets/reconstructed.png ADDED

Git LFS Details

  • SHA256: 86b57a2d196f6af62b37979b15a7dcda8de1097669ccb297d9213b32098d5873
  • Pointer size: 131 Bytes
  • Size of remote file: 344 kB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 46bd1ff58c18d17b2808ba9445d6beab0ae9098be21b316e9ed730d553b607fb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.86 MB
configs/training/VibeToken_small.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment:
2
+ project: "VibeToken_mvq_tiny_main"
3
+ name: "VibeToken_mvq_tiny_main"
4
+ output_dir: "wandb/VibeToken_mvq_tiny_main"
5
+ max_train_examples: 1_281_167
6
+ save_every: 10_000
7
+ eval_every: 10_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+
13
+ model:
14
+ sub_model_type: "vibetoken"
15
+ train_with_attention: True
16
+ eval_with_attention: True
17
+ vq_model:
18
+ # encoder: # patch mixture is not supported
19
+ # patch_mixture_start_layer: 2
20
+ # patch_mixture_end_layer: 22
21
+ # decoder: # patch mixture is not supported
22
+ # patch_mixture_start_layer: 2
23
+ # patch_mixture_end_layer: 22
24
+ quantize_mode: mvq
25
+ codebook_size: 32768 # 32768 / 8 = 4096
26
+ token_size: 256 # 256 / 8 = 32
27
+ use_l2_norm: False
28
+ commitment_cost: 0.25
29
+ clustering_vq: False
30
+ num_codebooks: 8
31
+ # vit arch
32
+ vit_enc_model_size: "small"
33
+ vit_dec_model_size: "small"
34
+ vit_enc_patch_size: 32
35
+ vit_dec_patch_size: 32
36
+ num_latent_tokens: 64
37
+ finetune_decoder: False
38
+ is_legacy: False
39
+
40
+ losses:
41
+ discriminator_start: 100_000
42
+ quantizer_weight: 1.0
43
+ discriminator_factor: 1.0
44
+ discriminator_weight: 0.1
45
+ perceptual_loss: "lpips-convnext_s-1.0-0.1"
46
+ perceptual_weight: 1.1
47
+ reconstruction_loss: "l2"
48
+ reconstruction_weight: 1.0
49
+ lecam_regularization_weight: 0.001
50
+
51
+ dataset:
52
+ params:
53
+ pretokenization: True
54
+ train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
55
+ eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
56
+ num_workers_per_gpu: 12
57
+ preprocessing:
58
+ resize_shorter_edge: 512
59
+ crop_size: 512
60
+ random_crop: True
61
+ random_flip: True
62
+ res_ratio_filtering: True
63
+ min_tokens: 32
64
+ max_tokens: 64
65
+
66
+ optimizer:
67
+ name: adamw
68
+ params:
69
+ learning_rate: 1e-4
70
+ discriminator_learning_rate: 1e-4
71
+ beta1: 0.9
72
+ beta2: 0.999
73
+ weight_decay: 1e-4
74
+
75
+ lr_scheduler:
76
+ scheduler: "cosine"
77
+ params:
78
+ learning_rate: ${optimizer.params.learning_rate}
79
+ warmup_steps: 10_000
80
+ end_lr: 1e-5
81
+
82
+ training:
83
+ gradient_accumulation_steps: 1
84
+ per_gpu_batch_size: 32
85
+ mixed_precision: "fp16"
86
+ enable_tf32: True
87
+ enable_wandb: True
88
+ use_ema: True
89
+ seed: 42
90
+ max_train_steps: 400_000
91
+ num_generated_images: 2
92
+ max_grad_norm: 1.0
93
+ variable_resolution:
94
+ any2any: True
95
+ dim:
96
+ - [256, 256]
97
+ - [512, 512]
98
+ - [384, 256]
99
+ - [256, 384]
100
+ - [512, 384]
101
+ - [384, 512]
102
+ ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1]
configs/vibetoken_ll.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VibeToken Large-Large Configuration
2
+ # Large encoder + Large decoder for highest quality
3
+ #
4
+ # Usage:
5
+ # from vibetoken import VibeTokenTokenizer
6
+ # tokenizer = VibeTokenTokenizer.from_config(
7
+ # "configs/vibetoken_ll.yaml",
8
+ # "path/to/checkpoint.bin"
9
+ # )
10
+
11
+ model:
12
+ sub_model_type: "vibetoken"
13
+ vq_model:
14
+ # Quantization settings
15
+ quantize_mode: mvq
16
+ codebook_size: 32768 # 32768 / 8 = 4096 per codebook
17
+ token_size: 256 # 256 / 8 = 32 per codebook
18
+ num_codebooks: 8
19
+ use_l2_norm: false
20
+ commitment_cost: 0.25
21
+
22
+ # Encoder architecture
23
+ vit_enc_model_size: "large"
24
+ vit_enc_patch_size: 32
25
+
26
+ # Decoder architecture
27
+ vit_dec_model_size: "large"
28
+ vit_dec_patch_size: 32
29
+
30
+ # Latent tokens
31
+ num_latent_tokens: 256
32
+
33
+ # Mode flags
34
+ is_legacy: false
35
+ finetune_decoder: false
36
+
37
+ # Dataset preprocessing defaults (for reference)
38
+ dataset:
39
+ preprocessing:
40
+ crop_size: 512
configs/vibetoken_sl.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VibeToken Small-Large Configuration
2
+ # Small encoder + Large decoder for faster encoding
3
+ #
4
+ # Usage:
5
+ # from vibetoken import VibeTokenTokenizer
6
+ # tokenizer = VibeTokenTokenizer.from_config(
7
+ # "configs/vibetoken_sl.yaml",
8
+ # "path/to/checkpoint.bin"
9
+ # )
10
+
11
+ model:
12
+ sub_model_type: "vibetoken"
13
+ vq_model:
14
+ # Quantization settings
15
+ quantize_mode: mvq
16
+ codebook_size: 32768 # 32768 / 8 = 4096 per codebook
17
+ token_size: 256 # 256 / 8 = 32 per codebook
18
+ num_codebooks: 8
19
+ use_l2_norm: false
20
+ commitment_cost: 0.25
21
+
22
+ # Encoder architecture (Small for faster encoding)
23
+ vit_enc_model_size: "small"
24
+ vit_enc_patch_size: 32
25
+
26
+ # Decoder architecture (Large for quality)
27
+ vit_dec_model_size: "large"
28
+ vit_dec_patch_size: 32
29
+
30
+ # Latent tokens
31
+ num_latent_tokens: 256
32
+
33
+ # Mode flags
34
+ is_legacy: false
35
+ finetune_decoder: false
36
+
37
+ # Dataset preprocessing defaults (for reference)
38
+ dataset:
39
+ preprocessing:
40
+ crop_size: 512
data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset
data/convert_imagenet_to_wds.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py
2
+
3
+ import argparse
4
+ import os
5
+ import sys
6
+ import time
7
+
8
+ import webdataset as wds
9
+ from datasets import load_dataset
10
+
11
+
12
def convert_imagenet_to_wds(input_dir, output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
    """Convert a HuggingFace-format ImageNet dataset into WebDataset tar shards.

    Args:
        input_dir: Path (or hub id) of the dataset loadable via `datasets.load_dataset`.
        output_dir: Directory that receives `imagenet-{train,val}-%06d.tar` shards.
        max_train_samples_per_shard: Maximum number of training samples per shard.
        max_val_samples_per_shard: Maximum number of validation samples per shard.

    Raises:
        AssertionError: If shards already exist in `output_dir` (refuses to overwrite).
    """
    # Refuse to clobber an existing conversion.
    assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar"))
    assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar"))

    def _write_split(split, shard_pattern, max_samples_per_shard):
        """Stream one dataset split into sharded tars; return (num_examples, seconds)."""
        output = wds.ShardWriter(os.path.join(output_dir, shard_pattern), maxcount=max_samples_per_shard)
        dataset = load_dataset(input_dir, split=split)
        start = time.time()
        count = 0  # defined even for an empty split (the original printed an undefined loop var)
        for i, example in enumerate(dataset):
            if i % max_samples_per_shard == 0:
                print(i, file=sys.stderr)  # progress marker once per shard
            img, label = example["image"], example["label"]
            # Force RGB so grayscale/CMYK JPEGs decode consistently downstream.
            output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
            count += 1
        output.close()
        return count, time.time() - start

    n, secs = _write_split("train", "imagenet-train-%06d.tar", max_train_samples_per_shard)
    print(f"Wrote {n} train examples in {secs // 3600} hours.")

    n, secs = _write_split("validation", "imagenet-val-%06d.tar", max_val_samples_per_shard)
    print(f"Wrote {n} val examples in {secs // 60} min.")
+
42
+
43
if __name__ == "__main__":
    # Command-line entry point: parse paths and shard sizes, then run the conversion.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_dir", type=str, required=True,
                            help="Path to the ImageNet-1k dataset (HuggingFace format).")
    arg_parser.add_argument("--output_dir", type=str, required=True,
                            help="Path to the output directory for WebDataset shards.")
    arg_parser.add_argument("--max_train_samples_per_shard", type=int, default=10000,
                            help="Maximum number of training samples per shard.")
    arg_parser.add_argument("--max_val_samples_per_shard", type=int, default=10000,
                            help="Maximum number of validation samples per shard.")
    cli_args = arg_parser.parse_args()

    # Create the destination directory up front so ShardWriter can open files there.
    os.makedirs(cli_args.output_dir, exist_ok=True)
    convert_imagenet_to_wds(
        cli_args.input_dir,
        cli_args.output_dir,
        cli_args.max_train_samples_per_shard,
        cli_args.max_val_samples_per_shard,
    )
data/webdataset_reader.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data loader using webdataset.
2
+
3
+ Reference:
4
+ https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
5
+ https://github.com/huggingface/open-muse/blob/main/training/data.py
6
+ """
7
+
8
+ import math
9
+ from typing import List, Union, Text
10
+ import webdataset as wds
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils.data import default_collate
14
+ from torchvision import transforms
15
+ from torch.utils.data import Dataset
16
+ import linecache
17
+ import json
18
+ from PIL import Image
19
+ import random
20
+ import cv2
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+
24
+ Image.MAX_IMAGE_PIXELS = None
25
+
26
+
27
def load_json(sample):
    """Decode a webdataset sample's raw `json` bytes into a dict, in place, and return it."""
    raw_bytes = sample['json']
    sample['json'] = json.loads(raw_bytes.decode('utf-8'))
    return sample
30
+
31
+
32
def filter_keys(key_set):
    """Return a mapper that keeps only the dict entries whose key is in `key_set`."""
    def _keep(dictionary):
        kept = {}
        for key, value in dictionary.items():
            if key in key_set:
                kept[key] = value
        return kept

    return _keep
37
+
38
+
39
def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0):
    """Build a `select` predicate dropping extreme aspect ratios and small images.

    A sample passes when height/width lies within [min_ratio, max_ratio] and the
    longer side is at least `min_res` pixels, both read from the sample's `json`
    metadata (`original_height` / `original_width`).
    """
    def _predicate(sample):
        meta = sample['json']
        height, width = meta['original_height'], meta['original_width']
        aspect = height / width
        longest = max(height, width)
        return min_ratio <= aspect <= max_ratio and longest >= min_res
    return _predicate
47
+
48
def calculate_laplacian_variance(image):
    """Calculate the variance of the Laplacian, a standard sharpness/blur measure."""
    arr = np.array(image)
    # Collapse RGB to a single intensity channel; sharpness is defined on luminance.
    if arr.ndim == 3:
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    else:
        gray = arr

    # Variance of the second-derivative response: low values indicate blur.
    response = cv2.Laplacian(gray, cv2.CV_64F)
    return response.var()
62
+
63
+ # Add this function to map Laplacian values to token lengths
64
def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128):
    """Map a Laplacian sharpness value to a token budget via a tanh squashing curve.

    At the mean Laplacian value the result sits at the midpoint of
    [min_tokens, max_tokens]; values above the mean scale toward max_tokens,
    values below toward min_tokens. Returns `mean_tokens` when `std` is
    non-positive (degenerate distribution).
    """
    if std <= 0:
        # Cannot standardize: fall back to the average budget.
        return mean_tokens

    z_score = (laplacian_value - mean) / std
    # tanh squashes 2*z into (-1, 1); rescaled to (0, 1) so z=0 lands at 0.5.
    position = 0.5 * (1 + math.tanh(2.0 * z_score))
    budget = min_tokens + position * (max_tokens - min_tokens)
    return int(round(budget))
87
+
88
+ # Add this function to map Laplacian values to token lengths
89
def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128):
    """Map a Laplacian sharpness value to a token budget with a piecewise-linear curve.

    0 maps to `min_tokens`, `mean` maps to `mean_tokens`, and values above the
    mean scale linearly toward `max_tokens`; the result is clamped to
    [min_tokens, max_tokens]. Returns `mean_tokens` when `std` is non-positive.

    Args:
        laplacian_value: Sharpness score (variance of the Laplacian).
        mean: Reference sharpness that maps to `mean_tokens`.
        std: Spread guard; non-positive values short-circuit to `mean_tokens`.
        min_tokens: Lower bound of the returned budget.
        max_tokens: Upper bound of the returned budget.
        mean_tokens: Budget assigned at `laplacian_value == mean`.

    Returns:
        An int token length in [min_tokens, max_tokens].
    """
    # Prevent division by zero and handle edge cases.
    if std <= 0:
        return mean_tokens

    # (Removed a dead `normalized = laplacian_value / mean` computation whose
    # result was never used.)
    if laplacian_value <= mean:
        # Linear interpolation between min_tokens and mean_tokens.
        ratio = laplacian_value / mean
        token_length = min_tokens + (mean_tokens - min_tokens) * ratio
    else:
        # Linear interpolation between mean_tokens and max_tokens.
        ratio = (laplacian_value - mean) / mean  # how far past the mean
        token_length = mean_tokens + (max_tokens - mean_tokens) * ratio

    # Clamp to the valid range.
    token_length = max(min_tokens, min(max_tokens, token_length))
    return int(round(token_length))
116
+
117
def get_laplacian_attention_mask(sample):
    """Attach a sharpness-driven attention mask and the raw Laplacian variance to a sample."""
    # Shallow-copy so the caller's dict is left untouched.
    out = dict(sample)

    # The image's sharpness score decides how many latent tokens it receives.
    variance = calculate_laplacian_variance(out["image"])
    num_tokens = get_dynamic_length(variance)

    # Binary mask over a fixed budget of 128 slots; positions up to num_tokens are active.
    mask = torch.zeros((128,), dtype=torch.float32)
    mask[:num_tokens + 1] = 1.0

    out["laplacian_var"] = variance
    out["attention_mask"] = mask
    return out
135
+
136
def get_uniform_attention_mask(min_tokens=32, max_tokens=128):
    """Build a mapper that attaches a uniformly random-length attention mask to a sample dict."""
    def _attach(dictionary):
        # Draw a token budget uniformly from [min_tokens, max_tokens] inclusive.
        num_tokens = int(torch.randint(min_tokens, max_tokens + 1, (1,)).item())

        # Binary mask of size max_tokens; positions up to num_tokens are active.
        mask = torch.zeros((max_tokens,), dtype=torch.float32)
        mask[:num_tokens + 1] = 1.0

        dictionary["attention_mask"] = mask
        return dictionary

    return _attach
150
+
151
def process_recap_text(p):
    """Build a mapper that, with probability `p`, turns `recap_txt` into a cleaned `text` field.

    Strips a leading "The image <verb>" boilerplate opener, re-capitalizes the
    remainder, and stores it utf-8 encoded under `text`. Samples without a
    `recap_txt` key (or losing the coin flip) pass through unchanged.
    """
    def _rewrite(dictionary):
        if "recap_txt" in dictionary:
            if random.random() < p:
                prefixes = ["The image " + verb for verb in ['depicts', "displays", 'showcases', 'features', 'shows']]
                text = dictionary["recap_txt"].decode("utf-8").strip()
                for prefix in prefixes:
                    if text.startswith(prefix):
                        # Drop the boilerplate opener and promote the next word to sentence case.
                        text = text[len(prefix):].strip()
                        text = text[0].upper() + text[1:] if text else ""
                        break

                dictionary["text"] = text.encode("utf-8")
        return dictionary

    return _rewrite
171
+
172
+
173
def identity(x):
    """No-op mapper: return the input unchanged (pipeline placeholder)."""
    return x
175
+
176
+
177
class ImageTransform:
    """Bundles the train and eval torchvision transform pipelines used by the data readers."""

    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: List[float] = [0., 0., 0.],
                 normalize_std: List[float] = [1., 1., 1.]):
        """Initializes the WebDatasetReader with specified augmentation parameters.

        Args:
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
        """
        interpolation = transforms.InterpolationMode.BICUBIC

        # Train: resize -> (random|center) crop -> optional flip -> tensor -> normalize.
        # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] keep images in [0, 1];
        # normalize_mean = [0.5]*3 and normalize_std = [0.5]*3 map them into [-1, 1].
        crop = transforms.RandomCrop(crop_size) if random_crop else transforms.CenterCrop(crop_size)
        train_steps = [
            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True),
            crop,
        ]
        if random_flip:
            train_steps.append(transforms.RandomHorizontalFlip())
        train_steps += [
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ]
        self.train_transform = transforms.Compose(train_steps)

        # Eval: always resize to crop_size so the results can be compared against
        # reference numbers on ImageNet etc., then a deterministic center crop.
        self.eval_transform = transforms.Compose([
            transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ])
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
227
+
228
+
229
class SimpleImageDataset:
    # Webdataset-backed reader producing train/eval dataloaders whose batches
    # carry an image, a class or text label, and a uniformly sampled
    # attention mask per sample.
    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int = 12,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop = True,
        random_flip = True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        dataset_with_class_label: bool = True,
        dataset_with_text_label: bool = False,
        res_ratio_filtering = False,
        min_tokens = 32,
        max_tokens = 128,
    ):
        """Initializes the WebDatasetReader class.

        Args:
            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU.
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
            dataset_with_class_label: A boolean, whether samples carry an integer `cls` label (ImageNet-style).
            dataset_with_text_label: A boolean, whether samples carry a `txt` caption instead of a class label.
            res_ratio_filtering: A boolean, whether to drop text-label samples with extreme
                aspect ratios or low resolution (read from the sample's json metadata).
            min_tokens: An integer, lower bound for the sampled attention-mask length.
            max_tokens: An integer, upper bound (and mask size) for the sampled attention-mask length.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        if dataset_with_class_label:
            # ImageNet-style shards: decode the image, rename keys, attach a
            # random-length attention mask, then apply train-time augmentations.
            train_processing_pipeline = [
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    class_id="cls",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys(set(["image", "class_id", "filename"]))),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    class_id=lambda x: int(x),
                    attention_mask=lambda x: x,
                    handler=wds.warn_and_continue,
                ),
            ]
        elif dataset_with_text_label:
            # Caption shards: parse the json metadata first (needed by the optional
            # resolution/aspect-ratio filter), then decode image + text and transform.
            train_processing_pipeline = [
                wds.map(load_json),
                wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity),
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    text="txt",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys(set(["image", "text", "__key__"]))),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    attention_mask=lambda x: x,
                    handler=wds.warn_and_continue,
                ),
            ]
        else:
            raise NotImplementedError

        # Eval pipeline always uses the deterministic transforms (resize + center crop).
        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                class_id="cls",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys(set(["image", "class_id", "filename"]))),
            wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
            wds.map_dict(
                image=transform.eval_transform,
                class_id=lambda x: int(x),
                # laplacian_var=lambda x: x,
                attention_mask=lambda x: x,
                handler=wds.warn_and_continue,
            ),
        ]

        # Create train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000,
                        initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        num_batches = math.ceil(num_train_examples / global_batch_size)
        # NOTE(review): num_batches is immediately recomputed from the per-worker
        # count below, so totals round up to a multiple of num_workers_per_gpu.
        num_worker_batches = math.ceil(num_train_examples /
                                       (global_batch_size * num_workers_per_gpu))
        num_batches = num_worker_batches * num_workers_per_gpu
        num_samples = num_batches * global_batch_size

        # Each worker is iterating over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader (non-resampled shards, partial final batch kept).
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )

    @property
    def train_dataset(self):
        # Underlying train DataPipeline (useful for manual iteration in tests).
        return self._train_dataset

    @property
    def train_dataloader(self):
        # WebLoader with .num_batches / .num_samples attached in __init__.
        return self._train_dataloader

    @property
    def eval_dataset(self):
        # Underlying eval DataPipeline.
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        # Eval WebLoader over the full validation set.
        return self._eval_dataloader
389
+
390
+
391
class PretoeknizedDataSetJSONL(Dataset):
    """Map-style dataset over a JSONL file of pre-tokenized samples.

    Each line is a JSON object with at least `class_id` (int) and `tokens`
    (list of numbers). Lines are read lazily through `linecache`, so the file
    is never fully loaded into memory.
    """

    def __init__(self, data_path):
        """Count the lines of `data_path` and prime the linecache.

        Args:
            data_path: Path to the JSONL file.
        """
        super().__init__()
        self.jsonl_file = data_path
        # Count lines with a context manager so the handle is closed promptly
        # (the original `open()` was never closed and leaked the handle).
        with open(self.jsonl_file) as f:
            self.num_lines = sum(1 for _ in f)
        # Ensure the file is cached
        linecache.checkcache(self.jsonl_file)
        print("Number of data:", self.num_lines)

    def __len__(self):
        """Number of samples (lines) in the file."""
        return self.num_lines

    def __getitem__(self, idx):
        """Return `(class_id, tokens)` tensors for 0-based index `idx`.

        linecache line numbers are 1-based, hence the `idx + 1`.
        """
        line = linecache.getline(self.jsonl_file, idx + 1).strip()
        data = json.loads(line)
        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
407
+
408
+
409
class PretokenizedWebDataset(SimpleImageDataset):
    """Webdataset reader for pre-tokenized text-to-image shards.

    Train samples yield float16 `tokens` (decoded from `token.npy`) plus the
    caption `text`; eval samples are raw images run through the eval transform.
    """

    def __init__ (
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop = True,
        random_flip = True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        process_recap = False,
        use_recap_prob = 0.95,
    ):
        """Initializes the PretokenizedWebDataset class.

        Text-to-image datasets are pretokenized with careful filtering (Tab. 7 in Supp.) to speed up the training

        Args mirror SimpleImageDataset; additionally:
            process_recap: Whether to clean re-captioned text via `process_recap_text`.
            use_recap_prob: Per-sample probability of applying the recaption cleanup.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        def decode_npy(x):
            # Raw bytes -> float16 vector of pretokenized latents.
            arr = np.frombuffer(x, dtype=np.float16)
            ret = torch.tensor(arr)
            return ret

        def decode_text(x):
            ret = x.decode("utf-8")
            return ret

        train_processing_pipeline = [
            wds.rename(
                tokens="token.npy",
                text="txt",
                handler=wds.warn_and_continue,
            ),
            # BUGFIX: the original expression was
            #   wds.map(process_recap_text(use_recap_prob) if process_recap else wds.map(identity))
            # which, when process_recap is False, mapped every sample through a
            # wds.map pipeline-stage object instead of passing it through unchanged.
            # The conditional must select between two stages, not nest them.
            wds.map(process_recap_text(use_recap_prob)) if process_recap else wds.map(identity),
            wds.map(filter_keys(set(["tokens", "text", "aes_score", "__key__"]))),
            wds.map_dict(
                tokens=decode_npy,
                text=decode_text,
                handler=wds.warn_and_continue,
            ),
        ]

        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
            wds.rename(
                image="jpg;png;jpeg;webp",
                handler=wds.warn_and_continue,
            ),
            wds.map_dict(
                image=transform.eval_transform,
                handler=wds.warn_and_continue,
            ),
        ]


        # Create train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000,
                        initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        num_batches = math.ceil(num_train_examples / global_batch_size)
        num_worker_batches = math.ceil(num_train_examples /
                                       (global_batch_size * num_workers_per_gpu))
        num_batches = num_worker_batches * num_workers_per_gpu
        num_samples = num_batches * global_batch_size

        # Each worker is iterating over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
evaluator/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .evaluator import VQGANEvaluator
evaluator/evaluator.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluator for reconstruction results."""
2
+
3
+ import warnings
4
+
5
+ from typing import Sequence, Optional, Mapping, Text
6
+ import numpy as np
7
+ from scipy import linalg
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ from .inception import get_inception_model
12
+
13
+
14
def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
    """Computes the covariance matrix from accumulated moments.

    Args:
        sigma: A torch.Tensor, sum of outer products of input features.
        total: A torch.Tensor, sum of all input features.
        num_examples: An integer, number of examples accumulated.

    Returns:
        A torch.Tensor, the (unbiased) covariance; zeros when num_examples is 0.
    """
    if num_examples == 0:
        # No data accumulated yet: covariance is undefined, report zeros.
        return torch.zeros_like(sigma)

    # Outer product of the feature sums, scaled by n, centers the second moment.
    mean_outer = torch.outer(total, total) / num_examples

    # Unbiased estimator: divide the centered second moment by (n - 1).
    return (sigma - mean_outer) / (num_examples - 1)
31
+
32
+
33
class VQGANEvaluator:
    """Streaming evaluator for VQGAN-style tokenizers.

    Accumulates sufficient statistics batch-by-batch via `update()` and
    computes the final metrics in `result()`, so arbitrarily large
    evaluation sets never need to be held in memory at once.
    """

    def __init__(
        self,
        device,
        enable_rfid: bool = True,
        enable_inception_score: bool = True,
        enable_codebook_usage_measure: bool = False,
        enable_codebook_entropy_measure: bool = False,
        num_codebook_entries: int = 1024
    ):
        """Initializes VQGAN Evaluator.

        Args:
            device: The device to use for evaluation.
            enable_rfid: A boolean, whether enabling rFID score.
            enable_inception_score: A boolean, whether enabling Inception Score.
            enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
            enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
            num_codebook_entries: An integer, the number of codebook entries.
        """
        self._device = device

        self._enable_rfid = enable_rfid
        self._enable_inception_score = enable_inception_score
        self._enable_codebook_usage_measure = enable_codebook_usage_measure
        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
        self._num_codebook_entries = num_codebook_entries

        # Variables related to Inception score and rFID.
        self._inception_model = None
        self._is_num_features = 0
        self._rfid_num_features = 0
        if self._enable_inception_score or self._enable_rfid:
            # The InceptionV3 extractor produces 2048-d pool features (used
            # for rFID) and 1008-d unbiased logits (used for Inception Score);
            # see get_inception_model() in evaluator/inception.py.
            self._rfid_num_features = 2048
            self._is_num_features = 1008
            self._inception_model = get_inception_model().to(self._device)
            self._inception_model.eval()
        self._is_eps = 1e-16    # numerical floor inside log() for IS
        self._rfid_eps = 1e-6   # fallback scaling when sqrtm is non-finite

        self.reset_metrics()

    def reset_metrics(self):
        """Resets all metrics."""
        self._num_examples = 0
        self._num_updates = 0

        # Running sums for Inception Score: marginal class probabilities and
        # per-sample sum(p * log p). All accumulators are float64 on device.
        self._is_prob_total = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )
        self._is_total_kl_d = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )
        # Running sums for rFID: first moments and outer-product (second
        # moment) sums; covariances are recovered later via get_covariance().
        self._rfid_real_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_real_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )
        self._rfid_fake_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_fake_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )

        # Codebook statistics: distinct indices seen (usage) and per-entry
        # hit counts (entropy).
        self._set_of_codebook_indices = set()
        self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)

    def update(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        codebook_indices: Optional[torch.Tensor] = None
    ):
        """Updates the metrics with the given images.

        Args:
            real_images: A torch.Tensor, the real images.
            fake_images: A torch.Tensor, the fake images.
            codebook_indices: A torch.Tensor, the indices of the codebooks for each image.

        Raises:
            ValueError: If the fake images is not in RGB (3 channel).
            ValueError: If the fake and real images have different shape.
        """

        batch_size = real_images.shape[0]
        # NOTE(review): `dim` is computed but never used below — confirm
        # whether it is leftover from an earlier version.
        dim = tuple(range(1, real_images.ndim))
        self._num_examples += batch_size
        self._num_updates += 1

        if self._enable_inception_score or self._enable_rfid:
            # Quantize to uint8 as a real image.
            # NOTE(review): assumes images are floats in [0, 1] — values are
            # truncated (not rounded) by the uint8 cast; confirm at call sites.
            fake_inception_images = (fake_images * 255).to(torch.uint8)
            features_fake = self._inception_model(fake_inception_images)
            inception_logits_fake = features_fake["logits_unbiased"]
            inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)

            if self._enable_inception_score:
                probabiliies_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)

                log_prob = torch.log(inception_probabilities_fake + self._is_eps)
                if log_prob.dtype != inception_probabilities_fake.dtype:
                    log_prob = log_prob.to(inception_probabilities_fake)
                # Per-class sum over the batch of p * log p; combined with the
                # marginal in result() to form the KL term of the IS.
                kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)

                self._is_prob_total += probabiliies_sum
                self._is_total_kl_d += kl_sum

            if self._enable_rfid:
                real_inception_images = (real_images * 255).to(torch.uint8)
                features_real = self._inception_model(real_inception_images)
                if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or
                        features_real['2048'].shape[1] != features_fake['2048'].shape[1]):
                    raise ValueError(f"Number of features should be equal for real and fake.")

                # Accumulate mean and E[x xᵀ] sums one sample at a time; the
                # covariance correction happens in get_covariance().
                for f_real, f_fake in zip(features_real['2048'], features_fake['2048']):
                    self._rfid_real_total += f_real
                    self._rfid_fake_total += f_fake

                    self._rfid_real_sigma += torch.outer(f_real, f_real)
                    self._rfid_fake_sigma += torch.outer(f_fake, f_fake)

        if self._enable_codebook_usage_measure:
            self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist())

        if self._enable_codebook_entropy_measure:
            entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
            self._codebook_frequencies.index_add_(0, entries.int(), counts.double())


    def result(self) -> Mapping[Text, torch.Tensor]:
        """Returns the evaluation result.

        Returns:
            A dict with any of the keys "InceptionScore", "rFID",
            "CodebookUsage", "CodebookEntropy" depending on the enabled flags.

        Raises:
            ValueError: if no examples have been accumulated, or if the
                sqrtm of the covariance product has a large imaginary part.
        """
        eval_score = {}

        if self._num_examples < 1:
            raise ValueError("No examples to evaluate.")

        if self._enable_inception_score:
            # IS = exp( E_x[ KL(p(y|x) || p(y)) ] ), computed from the
            # accumulated sums: KL term minus the marginal entropy term.
            mean_probs = self._is_prob_total / self._num_examples
            log_mean_probs = torch.log(mean_probs + self._is_eps)
            if log_mean_probs.dtype != self._is_prob_total.dtype:
                log_mean_probs = log_mean_probs.to(self._is_prob_total)
            excess_entropy = self._is_prob_total * log_mean_probs
            avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples

            inception_score = torch.exp(avg_kl_d).item()
            eval_score["InceptionScore"] = inception_score

        if self._enable_rfid:
            # Fréchet distance between Gaussians fitted to real/fake features:
            # ||mu_r - mu_f||² + Tr(Σ_r + Σ_f - 2 (Σ_r Σ_f)^{1/2}).
            mu_real = self._rfid_real_total / self._num_examples
            mu_fake = self._rfid_fake_total / self._num_examples
            sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
            sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)

            mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
            sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()

            diff = mu_real - mu_fake

            # Product might be almost singular.
            covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
            # Numerical error might give slight imaginary component.
            if np.iscomplexobj(covmean):
                if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                    m = np.max(np.abs(covmean.imag))
                    raise ValueError("Imaginary component {}".format(m))
                covmean = covmean.real

            tr_covmean = np.trace(covmean)

            # Fallback when sqrtm produced non-finite values (same recipe as
            # the pytorch-fid reference implementation).
            if not np.isfinite(covmean).all():
                tr_covmean = np.sum(np.sqrt((
                    (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
                    / (self._rfid_eps * self._rfid_eps)
                ))

            rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
                         - 2 * tr_covmean
                         )
            if torch.isnan(torch.tensor(rfid)) or torch.isinf(torch.tensor(rfid)):
                warnings.warn("The product of covariance of train and test features is out of bounds.")

            eval_score["rFID"] = rfid

        if self._enable_codebook_usage_measure:
            # Fraction of codebook entries hit at least once.
            usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
            eval_score["CodebookUsage"] = usage

        if self._enable_codebook_entropy_measure:
            # Shannon entropy (bits) of the empirical codebook distribution.
            probs = self._codebook_frequencies / self._codebook_frequencies.sum()
            entropy = (-torch.log2(probs + 1e-8) * probs).sum()
            eval_score["CodebookEntropy"] = entropy

        return eval_score
evaluator/inception.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inception model for FID evaluation.
2
+
3
+ Reference:
4
+ https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from torch_fidelity.feature_extractor_base import FeatureExtractorBase
10
+ from torch_fidelity.helpers import vassert
11
+ from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
12
+ from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
13
+
14
+ try:
15
+ from torchvision.models.utils import load_state_dict_from_url
16
+ except ImportError:
17
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
18
+
19
+
20
+ # Note: Compared shasum and models should be the same.
21
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
22
+
23
class FeatureExtractorInceptionV3(FeatureExtractorBase):
    """Pretrained InceptionV3 feature extractor (pytorch-fid weights).

    Layer structure, weight loading, and the TF1-style bilinear resize are
    kept exactly as in the reference implementation so FID/IS numbers stay
    comparable to published results. Do not reorder operations.
    """

    # All inputs are resized to this spatial size before the network.
    INPUT_IMAGE_SIZE = 299

    def __init__(
        self,
        name,
        features_list,
        **kwargs,
    ):
        """
        InceptionV3 feature extractor for 2D RGB 24bit images.

        Args:

            name (str): Unique name of the feature extractor, must be the same as used in
                :func:`register_feature_extractor`.

            features_list (list): A list of the requested feature names, which will be produced for each input. This
                feature extractor provides the following features:

                - '64'
                - '192'
                - '768'
                - '2048'
                - 'logits_unbiased'
                - 'logits'

        """
        super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
        # Run the whole network in float64 for reproducible metric statistics.
        self.feature_extractor_internal_dtype = torch.float64

        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)

        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)

        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE_1(1280)
        self.Mixed_7c = InceptionE_2(2048)
        self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))

        # 1008-way classifier head (TF-slim Inception class count, not 1000).
        self.fc = torch.nn.Linear(2048, 1008)

        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
        #state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
        self.load_state_dict(state_dict)

        # Frozen, eval-mode, float64 — this module is inference-only.
        self.to(self.feature_extractor_internal_dtype)
        self.requires_grad_(False)
        self.eval()

    def forward(self, x):
        """Extracts the requested features from a batch of uint8 RGB images.

        Args:
            x: uint8 tensor of shape (B, 3, H, W); any H/W is accepted and
               resized to 299x299 internally.

        Returns:
            Dict mapping each requested feature name to its tensor; returns
            early once all requested features have been produced.
        """
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
        vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
        features = {}
        remaining_features = self.features_list.copy()

        x = x.to(self.feature_extractor_internal_dtype)
        # N x 3 x ? x ?

        # TF1-compatible bilinear resize — required for bit-exact parity with
        # the original TensorFlow FID implementation.
        x = interpolate_bilinear_2d_like_tensorflow1x(
            x,
            size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
            align_corners=False,
        )
        # N x 3 x 299 x 299

        # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device)  # really happening in graph
        x = (x - 128) / 128  # but this gives bit-exact output _of this step_ too
        # N x 3 x 299 x 299

        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.MaxPool_1(x)
        # N x 64 x 73 x 73

        if '64' in remaining_features:
            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('64')
            if len(remaining_features) == 0:
                return features

        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.MaxPool_2(x)
        # N x 192 x 35 x 35

        if '192' in remaining_features:
            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('192')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17

        if '768' in remaining_features:
            # NOTE(review): '768' is downcast to float32 while the other
            # features stay float64 — matches upstream, but worth confirming
            # if this feature is ever used for metrics here.
            features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            remaining_features.remove('768')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        x = self.AvgPool(x)
        # N x 2048 x 1 x 1

        x = torch.flatten(x, 1)
        # N x 2048

        if '2048' in remaining_features:
            features['2048'] = x
            remaining_features.remove('2048')
            if len(remaining_features) == 0:
                return features

        if 'logits_unbiased' in remaining_features:
            # Logits without the bias term ("unbiased"), as used by the
            # reference Inception Score computation.
            x = x.mm(self.fc.weight.T)
            # N x 1008 (num_classes)
            features['logits_unbiased'] = x
            remaining_features.remove('logits_unbiased')
            if len(remaining_features) == 0:
                return features

            x = x + self.fc.bias.unsqueeze(0)
        else:
            x = self.fc(x)
            # N x 1008 (num_classes)

        features['logits'] = x
        return features

    @staticmethod
    def get_provided_features_list():
        return '64', '192', '768', '2048', 'logits_unbiased', 'logits'

    @staticmethod
    def get_default_feature_layer_for_metric(metric):
        # Default feature layer per torch-fidelity metric name.
        return {
            'isc': 'logits_unbiased',
            'fid': '2048',
            'kid': '2048',
            'prc': '2048',
        }[metric]

    @staticmethod
    def can_be_compiled():
        return True

    @staticmethod
    def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
212
+
213
def get_inception_model():
    """Build the pretrained InceptionV3 extractor used by VQGANEvaluator.

    Requests exactly the two feature heads the evaluator consumes:
    '2048' pool features (rFID) and 'logits_unbiased' (Inception Score).
    """
    return FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
examples/batch_inference.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Batch inference example for VibeToken.
3
+
4
+ Demonstrates how to process multiple images efficiently in batches.
5
+
6
+ Usage:
7
+ # Auto mode (recommended)
8
+ python examples/batch_inference.py --auto \
9
+ --config configs/vibetoken_ll.yaml \
10
+ --checkpoint path/to/checkpoint.bin \
11
+ --input_dir path/to/images/ \
12
+ --output_dir path/to/output/ \
13
+ --batch_size 4
14
+
15
+ # Manual mode
16
+ python examples/batch_inference.py \
17
+ --config configs/vibetoken_ll.yaml \
18
+ --checkpoint path/to/checkpoint.bin \
19
+ --input_dir path/to/images/ \
20
+ --output_dir path/to/output/ \
21
+ --batch_size 4 \
22
+ --resolution 512 \
23
+ --encoder_patch_size 16,32 \
24
+ --decoder_patch_size 16
25
+ """
26
+
27
+ import argparse
28
+ import time
29
+ from pathlib import Path
30
+
31
+ import torch
32
+ from PIL import Image
33
+ import numpy as np
34
+
35
+ import sys
36
+ sys.path.insert(0, str(Path(__file__).parent.parent))
37
+
38
+ from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
39
+
40
+
41
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: None, a single integer string (e.g. '16'), or two
            comma-separated integers 'H,W' (e.g. '16,32'). Whitespace
            around each number is tolerated (e.g. '16, 32').

    Returns:
        None if value is None, an int for a single value, or an
        (int, int) tuple for a comma-separated pair.

    Raises:
        ValueError: if a comma-separated value does not contain exactly
            two integers. The previous version silently dropped extra
            components ('16,32,64' -> (16, 32)) and crashed on a trailing
            comma ('16,' -> int('')); both now fail with a clear message.
    """
    if value is None:
        return None
    if ',' in value:
        parts = [part.strip() for part in value.split(',') if part.strip()]
        if len(parts) != 2:
            raise ValueError(
                f"Patch size must be a single int or 'H,W' with exactly two integers, got: {value!r}"
            )
        return (int(parts[0]), int(parts[1]))
    return int(value)
49
+
50
+
51
def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple:
    """Load an image from disk and preprocess it for the tokenizer.

    Args:
        path: Path to the image file.
        target_size: Optional (width, height) to resize to. Only used in
            manual mode; ignored when auto_mode is True.
        auto_mode: When True, delegate cropping and patch-size selection
            to auto_preprocess_image.

    Returns:
        A (image, patch_size, info) tuple. `image` is an HxWx3 numpy
        array; `patch_size` and `info` come from auto_preprocess_image in
        auto mode and are both None otherwise.
    """
    loaded = Image.open(path).convert("RGB")

    if auto_mode:
        # Centralized helper picks both the crop and the patch size.
        cropped, chosen_patch, meta = auto_preprocess_image(loaded, verbose=False)
        return np.array(cropped), chosen_patch, meta

    # Manual mode: optional resize, then force dimensions divisible by 32.
    if target_size:
        loaded = loaded.resize(target_size, Image.LANCZOS)
    loaded = center_crop_to_multiple(loaded, multiple=32)
    return np.array(loaded), None, None
75
+
76
+
77
def main():
    """Reconstruct every image in --input_dir with VibeToken and save results.

    Two modes:
      * auto (--auto): each image is preprocessed individually (crop + patch
        size chosen per image), so images are processed one at a time.
      * manual: images are resized to --resolution, center-cropped to a
        multiple of 32, and reconstructed in batches of --batch_size.
    Prints per-batch timing and an overall throughput summary.
    """
    parser = argparse.ArgumentParser(description="VibeToken batch inference example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings per image")

    # Manual mode options
    parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    args = parser.parse_args()

    # Parse patch sizes (None means "let the tokenizer use its default").
    encoder_patch_size = parse_patch_size(args.encoder_patch_size)
    decoder_patch_size = parse_patch_size(args.decoder_patch_size)

    # Check CUDA — degrade gracefully to CPU instead of crashing.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )

    if args.auto:
        print("Running in AUTO MODE - optimal settings determined per image")
    else:
        print(f"Running in MANUAL MODE - resolution: {args.resolution}")
        if encoder_patch_size:
            print(f"  Encoder patch size: {encoder_patch_size}")
        if decoder_patch_size:
            print(f"  Decoder patch size: {decoder_patch_size}")

    # Find all images (non-recursive; only these extensions are considered).
    input_dir = Path(args.input_dir)
    image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions]
    print(f"Found {len(image_paths)} images")

    if not image_paths:
        print("No images found!")
        return

    # Process in batches. Timing covers reconstruction only, not disk I/O.
    target_size = (args.resolution, args.resolution) if not args.auto else None
    total_time = 0
    num_processed = 0

    if args.auto:
        # AUTO MODE: Process images one by one since each may have different sizes
        for i, path in enumerate(image_paths):
            try:
                img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True)
                batch_array = img_array[np.newaxis, ...]  # Add batch dim

                start_time = time.time()

                # Reconstruct with auto-determined patch size.
                # info['cropped_size'] is (W, H) — note the swap below.
                height, width = info['cropped_size'][1], info['cropped_size'][0]
                reconstructed = tokenizer.reconstruct(
                    batch_array,
                    encode_patch_size=patch_size,
                    decode_patch_size=patch_size,
                    target_height=height,
                    target_width=width,
                )

                # Synchronize so the timing reflects actual GPU work.
                if args.device == "cuda":
                    torch.cuda.synchronize()

                batch_time = time.time() - start_time
                total_time += batch_time
                num_processed += 1

                # Save output
                output_images = tokenizer.to_pil(reconstructed)
                output_path = output_dir / f"{path.stem}_recon.png"
                output_images[0].save(output_path)

                print(f"[{i+1}/{len(image_paths)}] {path.name}: "
                      f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, "
                      f"patch_size={patch_size}, {batch_time:.2f}s")

            except Exception as e:
                # Best-effort batch job: report and keep going on bad images.
                print(f"Error processing {path}: {e}")
                continue
    else:
        # MANUAL MODE: Batch processing with uniform size
        for batch_start in range(0, len(image_paths), args.batch_size):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]
            batch_names = [p.stem for p in batch_paths]

            # Load batch, skipping images that fail to load.
            batch_images = []
            for path in batch_paths:
                try:
                    img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False)
                    batch_images.append(img_array)
                except Exception as e:
                    print(f"Error loading {path}: {e}")
                    continue

            if not batch_images:
                continue

            # Stack into batch tensor (all images share the same size here).
            batch_array = np.stack(batch_images, axis=0)

            # Measure time
            start_time = time.time()

            # Reconstruct
            reconstructed = tokenizer.reconstruct(
                batch_array,
                encode_patch_size=encoder_patch_size,
                decode_patch_size=decoder_patch_size,
                target_height=args.resolution,
                target_width=args.resolution,
            )

            # Synchronize if GPU
            if args.device == "cuda":
                torch.cuda.synchronize()

            batch_time = time.time() - start_time
            total_time += batch_time
            num_processed += len(batch_images)

            # Save outputs
            # NOTE(review): if some images in the batch failed to load,
            # batch_names may not line up with batch_images — names of
            # skipped files can be attached to later images; verify.
            output_images = tokenizer.to_pil(reconstructed)
            for name, img in zip(batch_names[:len(output_images)], output_images):
                output_path = output_dir / f"{name}_recon.png"
                img.save(output_path)

            print(f"Processed batch {batch_start // args.batch_size + 1}: "
                  f"{len(batch_images)} images in {batch_time:.2f}s "
                  f"({len(batch_images) / batch_time:.2f} img/s)")

    # Summary
    if num_processed > 0:
        print(f"\nTotal: {num_processed} images in {total_time:.2f}s")
        print(f"Average: {num_processed / total_time:.2f} images/sec")
        print(f"Per image: {total_time / num_processed * 1000:.1f}ms")


if __name__ == "__main__":
    main()
examples/encode_decode.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Basic encode-decode example for VibeToken.
3
+
4
+ Demonstrates how to:
5
+ 1. Load the tokenizer from config and checkpoint
6
+ 2. Encode an image to discrete tokens
7
+ 3. Decode tokens back to an image
8
+ 4. Save the reconstructed image
9
+
10
+ Usage:
11
+ # Auto mode (recommended)
12
+ python examples/encode_decode.py --auto \
13
+ --config configs/vibetoken_ll.yaml \
14
+ --checkpoint path/to/checkpoint.bin \
15
+ --image path/to/image.jpg \
16
+ --output reconstructed.png
17
+
18
+ # Manual mode
19
+ python examples/encode_decode.py \
20
+ --config configs/vibetoken_ll.yaml \
21
+ --checkpoint path/to/checkpoint.bin \
22
+ --image path/to/image.jpg \
23
+ --output reconstructed.png \
24
+ --encoder_patch_size 16,32 \
25
+ --decoder_patch_size 16
26
+ """
27
+
28
+ import argparse
29
+ from pathlib import Path
30
+
31
+ import torch
32
+ from PIL import Image
33
+
34
+ import sys
35
+ sys.path.insert(0, str(Path(__file__).parent.parent))
36
+
37
+ from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
38
+
39
+
40
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: None, a single integer string (e.g. '16'), or two
            comma-separated integers 'H,W' (e.g. '16,32'). Whitespace
            around each number is tolerated (e.g. '16, 32').

    Returns:
        None if value is None, an int for a single value, or an
        (int, int) tuple for a comma-separated pair.

    Raises:
        ValueError: if a comma-separated value does not contain exactly
            two integers. The previous version silently dropped extra
            components ('16,32,64' -> (16, 32)) and crashed on a trailing
            comma ('16,' -> int('')); both now fail with a clear message.
    """
    if value is None:
        return None
    if ',' in value:
        parts = [part.strip() for part in value.split(',') if part.strip()]
        if len(parts) != 2:
            raise ValueError(
                f"Patch size must be a single int or 'H,W' with exactly two integers, got: {value!r}"
            )
        return (int(parts[0]), int(parts[1]))
    return int(value)
48
+
49
+
50
def main():
    """Encode one image to discrete tokens, decode it back, and report PSNR.

    In auto mode (--auto), cropping and patch size are chosen by
    auto_preprocess_image; otherwise the image is center-cropped to a
    multiple of 32 and the user-supplied patch sizes/resolution are used.
    The reconstruction is saved to --output.
    """
    parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings")

    parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
    parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")

    args = parser.parse_args()

    # Check if CUDA is available — fall back to CPU instead of crashing.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )
    print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
          f"num_latent_tokens={tokenizer.num_latent_tokens}")

    # Load image
    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size  # (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE - use centralized auto_preprocess_image.
        # NOTE(review): --height/--width/--num_tokens and the explicit
        # patch-size flags are ignored in this branch — confirm intended.
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        # info['cropped_size'] is (W, H) — note the swap.
        height, width = info['cropped_size'][1], info['cropped_size'][0]
        print("=================\n")

        # Encode to tokens
        print("Encoding image to tokens...")
        print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
        print(f"Token shape: {tokens.shape}")

        # Decode back to image
        print(f"Decoding tokens to image ({width}x{height})...")
        print(f"  Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    else:
        # MANUAL MODE
        # Parse patch sizes (None lets the tokenizer use its default).
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Always center crop to ensure dimensions divisible by 32
        image = center_crop_to_multiple(image, multiple=32)
        cropped_size = image.size  # (W, H)
        if cropped_size != original_size:
            print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")

        # Encode to tokens
        print("Encoding image to tokens...")
        if encoder_patch_size:
            print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
        print(f"Token shape: {tokens.shape}")

        # Token layout differs by quantizer: mvq adds a codebook dimension.
        if tokenizer.model.quantize_mode == "mvq":
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Num codebooks: {tokens.shape[1]}")
            print(f"  - Sequence length: {tokens.shape[2]}")
        else:
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Sequence length: {tokens.shape[1]}")

        # Decode back to image (use cropped size as default)
        height = args.height or cropped_size[1]
        width = args.width or cropped_size[0]
        print(f"Decoding tokens to image ({width}x{height})...")
        if decoder_patch_size:
            print(f"  Using decoder patch size: {decoder_patch_size}")

        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Convert to PIL and save
    output_images = tokenizer.to_pil(reconstructed)
    output_path = Path(args.output)
    output_images[0].save(output_path)
    print(f"Saved reconstructed image to {output_path}")

    # Compute PSNR (compare with cropped image); only possible when the
    # reconstruction matches the preprocessed input's size. A perfect
    # reconstruction (mse == 0) prints nothing — PSNR would be infinite.
    import numpy as np
    original_np = np.array(image).astype(np.float32)
    recon_np = np.array(output_images[0]).astype(np.float32)
    if original_np.shape == recon_np.shape:
        mse = np.mean((original_np - recon_np) ** 2)
        if mse > 0:
            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
            print(f"PSNR: {psnr:.2f} dB")


if __name__ == "__main__":
    main()
generate.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
3
+
4
+ """Example run:
5
+ python generate.py \
6
+ --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
7
+ --gpt-model GPT-XXL --num-output-layer 4 \
8
+ --num-codebooks 8 --codebook-size 32768 \
9
+ --image-size 256 --cfg-scale 2.0 --top-k 0 --temperature 1.0 \
10
+ --class-dropout-prob 0.1 \
11
+ --extra-layers "QKV" \
12
+ --latent-size 65 \
13
+ --config ./configs/vibetoken_ll.yaml \
14
+ --vq-ckpt ./checkpoints/VibeToken_LL.bin \
15
+ --sample-dir ./assets/ \
16
+ --skip-folder-creation \
17
+ --compile \
18
+ --decoder-patch-size 16,16 \
19
+ --target-resolution 1024,1024 \
20
+ --llamagen-target-resolution 256,256 \
21
+ --precision bf16
22
+ """
23
+
24
+ import torch
25
+
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+ torch.set_float32_matmul_precision('high')
29
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
30
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
31
+ import torch.nn.functional as F
32
+ import torch.distributed as dist
33
+
34
+ from tqdm import tqdm
35
+ import os
36
+ from PIL import Image
37
+ import numpy as np
38
+ import math
39
+ import argparse
40
+ import sys
41
+ from omegaconf import OmegaConf
42
+
43
+ from vibetokengen.model import GPT_models
44
+ from vibetokengen.generate import generate
45
+
46
+ from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
47
+
48
+
49
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Build a single ``.npz`` file from a folder of ``.png`` samples.

    Args:
        sample_dir: Directory containing images named ``000000.png`` ...
            ``{num-1:06d}.png``.
        num: Number of samples to collect (default 50k, the usual FID set size).

    Returns:
        Path of the written archive, ``f"{sample_dir}.npz"`` (stored under
        key ``arr_0`` as a uint8 array of shape ``(num, H, W, 3)``).
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Fix: use a context manager so each file handle is released promptly;
        # the original kept up to `num` PIL images open, which can exhaust
        # file descriptors for large sample sets.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            samples.append(np.asarray(sample_pil).astype(np.uint8))
    samples = np.stack(samples)
    # Sanity check: expected count and 3 (RGB) channels; H/W are whatever the
    # images happen to be (all must match for np.stack to succeed).
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
64
+
65
+
66
def main(args):
    """Sample a fixed grid of class-conditional images.

    Loads the VibeToken tokenizer and a GPT checkpoint, autoregressively
    samples token indices for 8 hard-coded ImageNet classes, decodes them to
    pixels, and saves a 2x4 image grid as ``generated_images.png``.
    """
    # Setup PyTorch:
    # NOTE(review): the message is self-contradictory ("requires at least one
    # GPU" vs "supports CPU-only usage") — as written, a GPU is mandatory.
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)  # inference only; no autograd graph needed

    # Set global seed for reproducibility
    torch.manual_seed(args.global_seed)
    np.random.seed(args.global_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.global_seed)
        torch.cuda.manual_seed_all(args.global_seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Map the CLI precision flag to a torch dtype ('none' means full fp32).
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]

    # Load VibeToken model (image tokenizer used only for decoding here)
    vq_model = VibeTokenTokenizer.from_config(
        args.config,
        args.vq_ckpt,
        device=device,
        dtype=precision,
    )
    print(f"VibeToken image tokenizer is loaded")

    # create and load gpt model
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=args.latent_size,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        num_codebooks=args.num_codebooks,
        n_output_layer=args.num_output_layer,
        class_dropout_prob=args.class_dropout_prob,
        extra_layers=args.extra_layers,
        capping=args.capping,
    ).to(device=device, dtype=precision)
    print(f"GPT model is loaded")

    # Checkpoints from different training stacks nest the weights differently;
    # probe the known wrapper keys in order.
    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu", weights_only=False)
    if args.from_fsdp: # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint: # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint: # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight, maybe add --from-fsdp to run command")
    gpt_model.load_state_dict(model_weight, strict=True)
    gpt_model.eval()
    del checkpoint  # free the CPU copy before sampling

    print(f"GPT model weights are loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no model compile")

    print(f"GPT model is compiled")

    # Create folder to save samples:
    model_string_name = args.gpt_model.replace("/", "-")
    if args.from_fsdp:
        # FSDP checkpoints live one directory deeper; use the parent dir name.
        ckpt_string_name = args.gpt_ckpt.split('/')[-2]
    else:
        ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    folder_name = f"{model_string_name}-{ckpt_string_name}-target-resolution-{args.target_resolution}-llamagen-target-resolution-{args.llamagen_target_resolution}-vibetoken-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    if args.skip_folder_creation:
        sample_folder_dir = args.sample_dir
    else:
        sample_folder_dir = f"{args.sample_dir}/{folder_name}"

    os.makedirs(sample_folder_dir, exist_ok=True)
    print(f"Saving .png samples at {sample_folder_dir}")

    # CFG doubles the effective batch (conditional + unconditional passes).
    multiplier = 2 if args.cfg_scale > 1.0 else 1

    # Use fixed class labels (8 hand-picked ImageNet classes)
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    c_indices = torch.tensor(class_labels, device=device)
    n = len(class_labels)
    nrow = 4 # 2 rows x 4 columns for 8 images

    # target_h/target_w are the requested resolution normalized by 1536
    # (presumably the maximum training resolution — TODO confirm), repeated
    # per sample in the (possibly CFG-doubled) batch.
    index_sample = generate(
        gpt_model, c_indices, args.latent_size, args.num_codebooks,
        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
        target_h=torch.tensor(args.llamagen_target_resolution[0]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
        target_w=torch.tensor(args.llamagen_target_resolution[1]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
        temperature=args.temperature, top_k=args.top_k,
        top_p=args.top_p, sample_logits=True,
    )

    # Use VibeToken decode_tokens method
    # VibeToken expects tokens in shape (batch_size, seq_len, 1)
    index_sample = index_sample.unsqueeze(2)
    samples = vq_model.decode(
        index_sample,
        height=args.target_resolution[0],
        width=args.target_resolution[1],
        patch_size=args.decoder_patch_size
    )

    # VibeToken output is in [0, 1] range, clamp and convert to uint8
    samples = torch.clamp(samples, 0, 1)

    # Create a grid of images (2 rows x 4 columns)
    from torchvision.utils import make_grid
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)

    # Convert to PIL and save (CHW float [0,1] -> HWC uint8)
    grid_np = (grid.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype('uint8')
    Image.fromarray(grid_np).save(f"{sample_folder_dir}/generated_images.png")
    print(f"Saved grid of {n} images to {sample_folder_dir}/generated_images.png")
+
190
+
191
+ if __name__ == "__main__":
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
194
+ parser.add_argument("--gpt-ckpt", type=str, default=None)
195
+ parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i",
196
+ help="class-conditional or text-conditional")
197
+ parser.add_argument("--from-fsdp", action='store_true')
198
+ parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
199
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
200
+ parser.add_argument("--compile", action='store_true', default=True)
201
+ # parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
202
+ parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
203
+ parser.add_argument("--config", type=str, required=True, help="Path to VibeToken config file")
204
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
205
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
206
+ parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
207
+ parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
208
+ parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
209
+ parser.add_argument("--num-classes", type=int, default=1000)
210
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
211
+ parser.add_argument("--cfg-interval", type=float, default=-1)
212
+ parser.add_argument("--sample-dir", type=str, default="samples")
213
+ parser.add_argument("--per-proc-batch-size", type=int, default=32)
214
+ parser.add_argument("--num-fid-samples", type=int, default=50000)
215
+ parser.add_argument("--global-seed", type=int, default=0) # not used
216
+ parser.add_argument("--top-k", type=int, default=500, help="top-k value to sample with")
217
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
218
+ parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
219
+ parser.add_argument("--num-codebooks", type=int, default=1)
220
+ parser.add_argument("--num-output-layer", type=int, default=1)
221
+ parser.add_argument("--class-dropout-prob", type=float, default=0.1)
222
+ parser.add_argument("--extra-layers", type=str, choices=['QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip'], default=None,
223
+ help="Type of extra layers to add: QK (query-key), QKV (query-key-value), FC (fully connected), cap (caption), clip (clip), QK_cap (query-key-caption), QKV_cap (query-key-value-caption), QK_clip (query-key-clip), QKV_clip (query-key-value-clip), QK_FC_cap (query-key-fully-connected-caption), QKV_FC_cap (query-key-value-fully-connected-caption), QK_FC_clip (query-key-fully-connected-clip), QKV_FC_clip (query-key-value-fully-connected-clip)")
224
+ parser.add_argument("--capping", type=float, default=50.0, help="Capping for attention softmax")
225
+
226
+ # VibeToken dynamic
227
+ parser.add_argument("--decoder-patch-size", type=str, default="8,8", help="Decoder patch size as 'width,height'")
228
+ parser.add_argument("--target-resolution", type=str, default="256,256", help="Target resolution as 'width,height'")
229
+ parser.add_argument("--llamagen-target-resolution", type=str, default="256,256", help="LlamaGen target resolution as 'width,height'")
230
+
231
+ parser.add_argument("--latent-size", type=int, default=16, help="Latent size")
232
+ parser.add_argument("--skip-folder-creation", action='store_true', default=False, help="skip folder creation")
233
+
234
+ args = parser.parse_args()
235
+
236
+ args.decoder_patch_size = tuple(map(int, args.decoder_patch_size.split(",")))
237
+ args.target_resolution = tuple(map(int, args.target_resolution.split(",")))
238
+ args.llamagen_target_resolution = tuple(map(int, args.llamagen_target_resolution.split(",")))
239
+
240
+ main(args)
generator/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Generator module placeholder for VibeToken-Gen integration."""
2
+
3
+ # Future: Add GPT-based generator for image synthesis
4
+ # from .gpt import VibeTokenGenerator
modeling/__init__.py ADDED
File without changes
modeling/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .base_model import BaseModel
2
+ from .ema_model import EMAModel
3
+ from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, ReconstructionLoss_Single_Stage
4
+ from .blocks import TiTokEncoder, TiTokDecoder, TATiTokDecoder, UViTBlock
5
+ from .maskgit_vqgan import Decoder as Pixel_Decoder
6
+ from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
modeling/modules/base_model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class implementation for models.
2
+
3
+ Reference:
4
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
5
+ """
6
+ import os
7
+ from typing import Union, Callable, Dict, Optional
8
+
9
+ import torch
10
+
11
+
12
class BaseModel(torch.nn.Module):
    """Lightweight save/load/param-count mixin for torch models.

    Reference:
        https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
    """

    def __init__(self):
        super().__init__()

    def save_pretrained_weight(
        self,
        save_directory: Union[str, os.PathLike],
        save_function: Callable = None,
        state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Save the model weights to ``save_directory/pytorch_model.bin``.

        Args:
            save_directory: Target directory (created if missing). If it names
                an existing *file*, nothing is saved.
            save_function: Override for ``torch.save`` (useful e.g. on TPUs).
            state_dict: Explicit state dict to write; defaults to this
                model's own ``state_dict()``.
        """
        if os.path.isfile(save_directory):
            print(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        writer = torch.save if save_function is None else save_function
        os.makedirs(save_directory, exist_ok=True)

        if state_dict is None:
            state_dict = self.state_dict()

        target = os.path.join(save_directory, "pytorch_model.bin")
        writer(state_dict, target)
        print(f"Model weights saved in {target}")

    def load_pretrained_weight(
        self,
        pretrained_model_path: Union[str, os.PathLike],
        strict_loading: bool = True,
        torch_dtype: Optional[torch.dtype] = None
    ):
        r"""Load weights from a file, or from ``<dir>/pytorch_model.bin``.

        Leaves the model in eval mode; call ``model.train()`` afterwards if
        you intend to continue training.

        Args:
            pretrained_model_path: Checkpoint file, or a directory containing
                ``pytorch_model.bin``.
            strict_loading: Passed through to ``load_state_dict``.
            torch_dtype: Optional dtype to cast the model to after loading.

        Raises:
            ValueError: If the path does not resolve to a weights file, or if
                ``torch_dtype`` is not a ``torch.dtype``.
        """
        # Resolve a directory argument to the conventional weights filename.
        model_file = pretrained_model_path
        if os.path.isdir(model_file):
            model_file = os.path.join(model_file, "pytorch_model.bin")
        if not os.path.isfile(model_file):
            raise ValueError(f"{model_file} does not exist")

        checkpoint = torch.load(model_file, map_location="cpu")
        msg = self.load_state_dict(checkpoint, strict=strict_loading)
        print(f"loading weight from {model_file}, msg: {msg}")

        if torch_dtype is not None:
            if not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
                )
            self.to(torch_dtype)

        # Deactivate dropout etc. by default after loading.
        self.eval()

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """Count the module's parameters.

        Args:
            only_trainable: Count only parameters with ``requires_grad``.
            exclude_embeddings: Skip the ``weight`` of every
                ``torch.nn.Embedding`` submodule.

        Returns:
            Total number of (selected) parameter elements.
        """
        def _counted(p: torch.nn.Parameter) -> bool:
            # When only_trainable is False every parameter counts.
            return p.requires_grad or not only_trainable

        if not exclude_embeddings:
            return sum(p.numel() for p in self.parameters() if _counted(p))

        embedding_weights = {
            f"{name}.weight"
            for name, module in self.named_modules()
            if isinstance(module, torch.nn.Embedding)
        }
        return sum(
            p.numel()
            for name, p in self.named_parameters()
            if name not in embedding_weights and _counted(p)
        )
124
+
modeling/modules/blocks.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer building blocks.
2
+
3
+ Reference:
4
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
5
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.checkpoint import checkpoint
12
+ from collections import OrderedDict
13
+ import einops
14
+ from einops.layers.torch import Rearrange
15
+
16
+
17
def modulate(x, shift, scale):
    """FiLM-style affine modulation: scale ``x`` by ``1 + scale``, then shift."""
    scaled = x * (scale + 1)
    return scaled + shift
19
+
20
+
21
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention plus an optional MLP,
    each applied as a residual branch.

    Passing ``mlp_ratio <= 0`` omits the feed-forward sub-block entirely.
    Expects sequence-first input ``(L, N, D)`` as used by
    ``nn.MultiheadAttention``'s default layout.
    """

    def __init__(
        self,
        d_model,
        n_head,
        mlp_ratio = 4.0,
        act_layer = nn.GELU,
        norm_layer = nn.LayerNorm
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio
        # The FFN is optional; keep the ratio around so forward() can skip it.
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden = int(d_model * mlp_ratio)
            self.mlp = nn.Sequential(OrderedDict([
                ("c_fc", nn.Linear(d_model, hidden)),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(hidden, d_model))
            ]))

    def attention(
        self,
        x: torch.Tensor
    ):
        # Self-attention over x; attention weights are not materialized.
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(
        self,
        x: torch.Tensor,
    ):
        x = x + self.attention(x=self.ln_1(x))
        if self.mlp_ratio > 0:
            x = x + self.mlp(self.ln_2(x))
        return x
60
+
61
# Select the attention backend once at import time: PyTorch >= 2.0 fused SDPA
# ("flash") when available, otherwise xformers, otherwise a pure-PyTorch
# ("math") fallback implemented in Attention.forward below.
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except ImportError:
        # Fix: the original bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; only a failed xformers import should trigger the fallback.
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')
71
+
72
+
73
class Attention(nn.Module):
    """Multi-head self-attention with a backend chosen by the module-level
    ``ATTENTION_MODE`` ('flash' = fused SDPA, 'xformers', or 'math').

    Input/output shape: ``(B, L, C)`` with ``C = num_heads * head_dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scale is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, L, C = x.shape

        # Single projection produces Q, K, V stacked along the channel dim.
        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            # NOTE(review): `.float()` forces fp32 here, so the flash path
            # returns float32 even for bf16/fp16 inputs — confirm intended.
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # Fix: the original `raise NotImplemented` raised a TypeError
            # (NotImplemented is a constant, not an exception class).
            raise NotImplementedError(f"Unknown attention mode {ATTENTION_MODE}")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
111
+
112
+
113
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero out whole samples and rescale the survivors.

    One Bernoulli draw per sample (first dim), broadcast over all remaining
    dims; surviving samples are divided by the keep probability so the
    expected value is unchanged. Identity when ``drop_prob == 0`` or when
    not training.

    See https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for the naming discussion ("drop path" vs "drop connect").
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # Per-sample mask shape: (B, 1, 1, ...) matching x's rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * mask
131
+
132
+
133
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth).

    Active only while the module is in training mode; identity during eval.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
142
+
143
+
144
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> drop -> fc2 -> drop.

    Hidden and output widths default to the input width; a single Dropout
    module is shared by both positions.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
161
+
162
+
163
class UViTBlock(nn.Module):
    """U-ViT transformer block: pre-norm attention and MLP residual branches,
    an optional long-skip fusion (concat + linear), and optional gradient
    checkpointing of the whole block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        # Fuses the U-Net-style long skip connection when `skip=True`.
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if self.use_checkpoint:
            # Trade compute for memory: recompute activations in backward.
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        return self._forward(x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
191
+
192
+
193
+ def _expand_token(token, batch_size: int):
194
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
195
+
196
+
197
class TiTokEncoder(nn.Module):
    """TiTok ViT encoder: compresses an image into `num_latent_tokens`
    1D latent tokens by appending learnable latent queries to the patch
    sequence and reading them back after the transformer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_enc_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_enc_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size

        # VAE mode emits twice the channels (mean and std halves).
        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
            self.token_size = self.token_size * 2 # needs to split into mean and std

        self.is_legacy = config.model.vq_model.get("is_legacy", True)

        # Width / depth / heads determined by the named model size.
        self.width = {
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Non-overlapping patch embedding (stride == kernel == patch_size).
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=3, # noqa placeholder comment removed
            kernel_size=self.patch_size, stride=self.patch_size, bias=True) if False else nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True)

        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        # Separate positional table for the appended latent queries.
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)
        # 1x1 conv projecting transformer width down to token_size channels.
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

    def forward(self, pixel_values, latent_tokens):
        batch_size = pixel_values.shape[0]
        x = pixel_values
        x = self.patch_embed(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width]


        # Append the learnable latent queries after the image tokens.
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2) # LND -> NLD

        # Keep only the latent-query positions (drop cls + image tokens).
        latent_tokens = x[:, 1+self.grid_size**2:]
        latent_tokens = self.ln_post(latent_tokens)
        # fake 2D shape
        if self.is_legacy:
            # NOTE(review): this reshape of a (B, N, width) tensor into
            # (B, width, N, 1) reinterprets memory rather than transposing —
            # kept deliberately for checkpoint compatibility (legacy path).
            latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
        else:
            # Fix legacy problem.
            latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
        latent_tokens = self.conv_out(latent_tokens)
        # Output layout: (B, token_size, 1, num_latent_tokens).
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
        return latent_tokens
280
+
281
+
282
class TiTokDecoder(nn.Module):
    """TiTok ViT decoder: reconstructs an image (or, in legacy mode, a
    feature map for a pixel decoder) from quantized 1D latent tokens by
    attending mask tokens over them.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_dec_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_dec_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.is_legacy = config.model.vq_model.get("is_legacy", True)
        # Width / depth / heads determined by the named model size.
        self.width = {
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Projects quantized tokens from token_size up to transformer width.
        self.decoder_embed = nn.Linear(
            self.token_size, self.width, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        # add mask token and query pos embed
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            # Legacy head emits a 1024-channel feature map, consumed by a
            # separate pixel decoder elsewhere in the codebase.
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Directly predicting RGB pixels
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
                Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                    p1 = self.patch_size, p2 = self.patch_size),)
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def forward(self, z_quantized):
        # Expects the encoder's output layout: (B, token_size, 1, num_latent_tokens).
        N, C, H, W = z_quantized.shape
        assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        # One mask token per output grid position, plus the class token.
        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                 mask_tokens], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
        # Slice the positional table so shorter latent sequences also work.
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2) # LND -> NLD
        x = x[:, 1:1+self.grid_size**2] # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
        x = self.ffn(x.contiguous())
        x = self.conv_out(x)
        return x
369
+
370
+
371
class TATiTokDecoder(TiTokDecoder):
    """TiTok decoder variant conditioned on text-guidance embeddings.

    Projects the text embeddings into the transformer width and appends them
    to the token sequence so every attention block can attend to the text.
    """

    def __init__(self, config):
        super().__init__(config)
        scale = self.width ** -0.5
        self.text_context_length = config.model.vq_model.get("text_context_length", 77)
        self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(
            scale * torch.randn(self.text_context_length, self.width))

    def forward(self, z_quantized, text_guidance):
        """Decode quantized latents with text conditioning.

        Args:
            z_quantized: Tensor of shape (N, C, 1, num_latent_tokens).
            text_guidance: Text embeddings of shape
                (N, text_context_length, text_embed_dim).
        """
        batch, channels, height, width = z_quantized.shape
        assert height == 1 and width == self.num_latent_tokens, f"{height}, {width}, {self.num_latent_tokens}"
        latents = z_quantized.reshape(batch, channels * height, width).permute(0, 2, 1)  # NLD
        latents = self.decoder_embed(latents)

        seq_len = latents.shape[1]

        # Mask tokens act as queries for the image patches; cls token first.
        patch_queries = self.mask_token.repeat(batch, self.grid_size ** 2, 1).to(latents.dtype)
        patch_queries = torch.cat(
            [_expand_token(self.class_embedding, batch).to(patch_queries.dtype), patch_queries],
            dim=1)
        patch_queries = patch_queries + self.positional_embedding.to(patch_queries.dtype)
        latents = latents + self.latent_token_positional_embedding[:seq_len]
        tokens = torch.cat([patch_queries, latents], dim=1)

        # Append projected text tokens at the end of the sequence; they are
        # attended to but dropped before the output head below.
        text_tokens = self.text_guidance_proj(text_guidance)
        text_tokens = text_tokens + self.text_guidance_positional_embedding
        tokens = torch.cat([tokens, text_tokens], dim=1)

        tokens = self.ln_pre(tokens)
        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            tokens = block(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD
        # Keep only the patch positions (drop cls, latents, and text tokens).
        tokens = tokens[:, 1:1 + self.grid_size ** 2]
        tokens = self.ln_post(tokens)
        feat = tokens.permute(0, 2, 1).reshape(batch, self.width, self.grid_size, self.grid_size)
        feat = self.ffn(feat.contiguous())
        return self.conv_out(feat)
411
+
412
+
413
class WeightTiedLMHead(nn.Module):
    """Language-model head whose projection is tied to an embedding table.

    Only the first ``target_codebook_size`` rows of the embedding matrix are
    used, so the head can score a sub-vocabulary of a larger shared table.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        # Reference, not a copy: gradient updates to the embedding table are
        # shared with this head.
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        """Project hidden states onto the tied sub-vocabulary.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            Logits of shape (batch_size, seq_len, target_codebook_size).
        """
        tied_rows = self.weight[:self.target_codebook_size]  # (V, D)
        logits = x @ tied_rows.t()
        return logits
426
+
427
+
428
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps into vector representations."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        # Odd target dims get a zero column so the output width is exact.
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        """Map a 1-D batch of timesteps to (N, hidden_size) embeddings."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
466
+
467
+
468
class ResBlock(nn.Module):
    """Residual MLP block with adaptive LayerNorm (adaLN) conditioning.

    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """Apply the conditioned residual branch.

        The conditioning ``y`` produces a shift/scale for the normalized
        input and a gate for the residual addition.
        """
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = modulate(self.in_ln(x), shift, scale)
        hidden = self.mlp(hidden)
        return x + gate * hidden
498
+
499
+
500
class FinalLayer(nn.Module):
    """The final layer adopted from DiT: adaLN-modulated norm + linear."""

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        """Modulate the normalized input by conditioning ``c``, then project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
518
+
519
+
520
class SimpleMLPAdaLN(nn.Module):
    """The MLP for Diffusion Loss: a stack of adaLN residual blocks.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then zero the adaLN/output layers so each
        block starts as the identity (standard DiT initialization)."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Timestep-embedding MLP gets a small normal init instead.
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers.
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        h = self.input_proj(x)
        # Timestep and condition embeddings are summed into one signal.
        cond = self.time_embed(t) + self.cond_embed(c)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                h = checkpoint(block, h, cond)
        else:
            for block in self.res_blocks:
                h = block(h, cond)

        return self.final_layer(h, cond)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free-guidance forward.

        The first half of the batch is the conditional copy and the second
        half the unconditional one; only the first ``in_channels`` output
        channels (the eps prediction) are guided.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided_eps, guided_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
modeling/modules/discriminator.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Discriminator implementation."""
2
+ import functools
3
+ import math
4
+ from typing import Tuple
5
+
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .maskgit_vqgan import Conv2dSame
12
+
13
+
14
class BlurBlock(torch.nn.Module):
    """Anti-aliasing blur + 2x downsample (StyleGAN-style blur pool).

    A fixed separable kernel (normalized to sum to 1) is applied depthwise
    with stride 2 and "same"-style padding.
    """

    def __init__(self,
                 kernel: Tuple[int] = (1, 3, 3, 1)
                 ):
        super().__init__()

        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        kernel = kernel[None, :] * kernel[:, None]
        kernel /= kernel.sum()
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        self.register_buffer("kernel", kernel)

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total padding for a "same"-sized output given input i, kernel k, stride s."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ic, ih, iw = x.size()[-3:]
        # Bug fix: the padding was previously computed with a hard-coded
        # kernel size of 4, which is wrong for the 3- and 5-tap kernels this
        # class also accepts. Use the actual kernel size (identical result
        # for the default 4-tap kernel).
        k = self.kernel.shape[-1]
        pad_h = self.calc_same_pad(i=ih, k=k, s=2)
        pad_w = self.calc_same_pad(i=iw, k=k, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        weight = self.kernel.expand(ic, -1, -1, -1)

        out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
        return out
+ return out
40
+
41
+
42
class NLayerDiscriminator(torch.nn.Module):
    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4
    ):
        """ Initializes the NLayerDiscriminator.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"

        # Channel multiplier per stage: 1, 1, 2, 4, ... (doubling each stage).
        channel_mults = (1,) + tuple(2 ** stage for stage in range(num_stages))
        activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)

        self.block_in = torch.nn.Sequential(
            Conv2dSame(
                num_channels,
                hidden_channels,
                kernel_size=5
            ),
            activation(),
        )

        blur_kernels = {
            3: (1, 2, 1),
            4: (1, 3, 3, 1),
            5: (1, 4, 6, 4, 1),
        }

        stage_blocks = []
        for stage in range(num_stages):
            c_in = hidden_channels * channel_mults[stage]
            c_out = hidden_channels * channel_mults[stage + 1]
            if blur_resample:
                downsample = BlurBlock(blur_kernels[blur_kernel_size])
            else:
                downsample = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            stage_blocks.append(torch.nn.Sequential(
                Conv2dSame(c_in, c_out, kernel_size=3),
                downsample,
                torch.nn.GroupNorm(32, c_out),
                activation(),
            ))

        self.blocks = torch.nn.ModuleList(stage_blocks)

        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))

        final_channels = hidden_channels * channel_mults[-1]
        self.to_logits = torch.nn.Sequential(
            Conv2dSame(final_channels, final_channels, 1),
            activation(),
            Conv2dSame(final_channels, 1, kernel_size=5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The output tensor.
        """
        hidden = self.block_in(x)
        for stage_block in self.blocks:
            hidden = stage_block(hidden)

        hidden = self.pool(hidden)

        return self.to_logits(hidden)
modeling/modules/ema_model.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EMA (Exponential Moving Average) model.
2
+
3
+ Reference:
4
+ https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8
5
+ """
6
+
7
+
8
+ import copy
9
+ from typing import Any, Iterable, Optional, Union
10
+
11
+ import torch
12
+
13
+
14
class EMAModel:
    """Exponential Moving Average of models weights."""
    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        update_every: int = 1,
        current_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        **model_config_kwargs
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            update_every (int): The number of steps between each EMA update.
            current_step (int): The current training step.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.

        notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        parameters = list(parameters)
        # Detached copies of the tracked parameters; these hold the EMA weights.
        self.shadow_params = [p.clone().detach() for p in parameters]
        # Filled by store(); holds a temporary CPU snapshot for restore().
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.update_every = update_every
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = current_step
        self.cur_decay_value = None  # set in `step()`

        # Kept so save_pretrained() can rebuild a model of the right class.
        self.model_cls = model_cls
        self.model_config_kwargs = model_config_kwargs

    @classmethod
    def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel":
        # Builds a fresh model, loads the checkpoint into it, and seeds the
        # EMA shadow parameters from it.
        model = model_cls(**model_config_kwargs)
        model.load_pretrained_weight(checkpoint)

        ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs)
        return ema_model

    def save_pretrained(self, path):
        # Materializes the EMA weights into a fresh model and saves that model.
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config_kwargs is None:
            raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.")

        model = self.model_cls(**self.model_config_kwargs)
        self.copy_to(model.parameters())
        model.save_pretrained_weight(path)

    def set_step(self, optimization_step: int):
        # Manually override the internal step counter (e.g. when resuming).
        self.optimization_step = optimization_step

    def get_decay(self, optimization_step: int) -> float:
        """Computes the decay factor for the exponential moving average."""
        step = max(0, optimization_step - self.update_after_step - 1)

        # Before updates begin, the shadow copies the live weights exactly.
        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # Make sure decay is not smaller than min_decay.
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        # Advance the step counter and, every `update_every` steps, blend the
        # live parameters into the shadow copies.
        parameters = list(parameters)

        self.optimization_step += 1

        if (self.optimization_step - 1) % self.update_every != 0:
            return

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                # shadow = decay * shadow + (1 - decay) * param, in place.
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                # Non-trainable params (buffers passed in) are copied verbatim.
                s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """Copies current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Moves internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Args:
            Save the current parameters for restoring later.
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        # Snapshot goes to CPU so GPU memory is not doubled during validation.
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""Restores the parameters stored with the `store` method. Useful to validate
        the model with EMA parameters without affecting the original optimization process.
        Store the parameters before the `copy_to()` method. After validation (or
        model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # Deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
+ raise ValueError("shadow_params must all be Tensors")
modeling/modules/encoder_decoder.py ADDED
@@ -0,0 +1,1142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Encoder and decoder building blocks for VibeToken.
2
+
3
+ Reference:
4
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
5
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
6
+ """
7
+
8
+ import random
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.utils.checkpoint import checkpoint
13
+ from collections import OrderedDict
14
+ import einops
15
+ from einops.layers.torch import Rearrange
16
+ from typing import Optional, Sequence, Tuple, Union
17
+ from modeling.modules.fuzzy_embedding import FuzzyEmbedding
18
+ import collections.abc
19
+ from itertools import repeat
20
+ from typing import Any
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from torch import vmap
25
+ from torch import Tensor
26
+
27
def to_2tuple(x: Any) -> Tuple:
    """Return *x* as a 2-tuple.

    Non-string iterables are materialized as-is; anything else (including
    strings) is duplicated into a pair.
    """
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    return (x, x)
31
+
32
class PatchMixture():
    """Magnitude-biased token masking/routing utility (MAE-style).

    `get_mask` draws a keep/mask split per sample where ``l1_reg``
    interpolates between uniform random selection (0.0) and selection biased
    by per-token L1 magnitude (1.0). `start_route` gathers the kept tokens;
    `end_route` scatters them back into the original order.
    """

    def __init__(self, seed=42):
        # NOTE(review): seed is stored but never used to seed an RNG here;
        # masking relies on torch's global RNG state.
        self.seed = seed

    def get_mask(self, x, mask_ratio=0.0, l1_reg=0.0, inverse=False):
        """Compute a mask over ``x`` of shape (B, N, D).

        Returns a dict with boolean ``mask`` (True = masked) and the
        shuffle/restore index tensors used by `start_route`/`end_route`.
        """
        batch_size, num_patches, _ = x.shape
        device = x.device
        num_mask = int(num_patches * mask_ratio)
        num_keep = num_patches - num_mask
        # Per-token L1 magnitude, min-max normalized per sample.
        token_magnitudes = x.abs().sum(dim=-1)
        min_mags = token_magnitudes.min(dim=1, keepdim=True)[0]
        max_mags = token_magnitudes.max(dim=1, keepdim=True)[0]
        token_magnitudes = (token_magnitudes - min_mags) / (max_mags - min_mags + 1e-8)
        if inverse:
            adjusted_magnitudes = 1.0 - token_magnitudes
        else:
            adjusted_magnitudes = token_magnitudes
        # Low noise -> kept. l1_reg blends randomness with the magnitude bias.
        noise_random = torch.rand(batch_size, num_patches, device=device)
        noise = (1.0 - l1_reg) * noise_random + l1_reg * adjusted_magnitudes
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_keep]
        ids_mask = ids_shuffle[:, num_keep:]
        mask = torch.ones((batch_size, num_patches), device=device, dtype=torch.bool)
        mask.scatter_(1, ids_keep, False)
        return {
            'mask': mask,
            'ids_keep': ids_keep,
            'ids_mask': ids_mask,
            'ids_shuffle': ids_shuffle,
            'ids_restore': ids_restore
        }

    def start_route(self, x, mask_info):
        """Gather the kept tokens of ``x`` in shuffled order -> (B, num_keep, D)."""
        ids_shuffle = mask_info['ids_shuffle']
        num_keep = mask_info['ids_keep'].size(1)
        x_shuffled = x.gather(1, ids_shuffle.unsqueeze(-1).expand(-1, -1, x.size(2)))
        masked_x = x_shuffled[:, :num_keep, :]
        return masked_x

    def end_route(self, masked_x, mask_info, original_x=None, mask_token=0.0):
        """Scatter kept tokens back to their original positions.

        Masked slots are filled from ``original_x`` when given (exact
        round-trip), otherwise with the scalar ``mask_token``.
        """
        batch_size, num_patches = mask_info['mask'].shape
        num_keep = masked_x.size(1)
        dim = masked_x.size(2)
        device = masked_x.device
        ids_restore = mask_info['ids_restore']
        # Bug fix: allocate with masked_x's dtype so half/bfloat16 inputs are
        # not silently cast through a default-float32 buffer.
        x_unshuffled = torch.empty((batch_size, num_patches, dim),
                                   device=device, dtype=masked_x.dtype)
        x_unshuffled[:, :num_keep, :] = masked_x
        if original_x is not None:
            x_shuffled = original_x.gather(1, mask_info['ids_shuffle'].unsqueeze(-1).expand(-1, -1, dim))
            x_unshuffled[:, num_keep:, :] = x_shuffled[:, num_keep:, :]
        else:
            x_unshuffled[:, num_keep:, :].fill_(mask_token)
        x_unmasked = x_unshuffled.gather(1, ids_restore.unsqueeze(-1).expand(-1, -1, dim))
        return x_unmasked
89
+
90
class ResizableBlur(nn.Module):
    """
    Single-parameter anti-aliasing layer.

    Holds one learnable base kernel (max_kernel_size x max_kernel_size) that is
    bilinearly resized and re-normalised on the fly, then applied as a
    depth-wise strided convolution to downsample from ``input_size`` to
    ``target_size``. A final bilinear interpolation guarantees the exact
    target size when the strided conv cannot hit it precisely.
    """
    def __init__(self, channels: int,
                 max_kernel_size: int = 9,
                 init_type: str = "gaussian"):
        super().__init__()
        self.C = channels
        K = max_kernel_size  # e.g. 9 for 4x downsampling
        assert K % 2 == 1, "kernel must be odd"

        # ----- initialise the largest kernel ---------------------------------
        if init_type == "gaussian":
            # 2-D separable Gaussian, sigma ~= K/6
            ax = torch.arange(-(K // 2), K // 2 + 1)
            g1d = torch.exp(-0.5 * (ax / (K / 6.0)) ** 2)
            g2d = torch.outer(g1d, g1d)
            kernel = g2d / g2d.sum()
        elif init_type == "lanczos":
            a = K // 2  # window size parameter
            x = torch.arange(-a, a + 1).float()
            sinc = lambda t: torch.where(t == 0, torch.ones_like(t), torch.sin(torch.pi * t) / (torch.pi * t))
            k1d = sinc(x) * sinc(x / a)
            k2d = torch.outer(k1d, k1d)
            kernel = k2d / k2d.sum()
        else:
            raise ValueError("unknown init_type")

        # learnable base kernel (shape 1x1xKxK)
        self.weight = nn.Parameter(kernel.unsqueeze(0).unsqueeze(0))

    # ------------------------------------------------------------------------
    @staticmethod
    def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor:
        """
        Bilinearly interpolate weight (B,C,H,W) to target k_size x k_size,
        then L1-normalise over spatial dims so the kernel sums to 1.
        """
        if weight.shape[-1] != k_size:
            weight = F.interpolate(weight, size=(k_size, k_size),
                                   mode="bilinear", align_corners=True)
        weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8)
        return weight

    # ------------------------------------------------------------------------
    def forward(self, x: torch.Tensor, input_size, target_size) -> torch.Tensor:
        """Blur and downsample ``x`` from ``input_size`` to ``target_size``.

        Args:
            x: (N, C, H, W) input tensor with C == self.C channels.
            input_size: (input_h, input_w) of the incoming feature map.
            target_size: (target_h, target_w) desired output resolution.

        Returns:
            Tensor of shape (N, C, target_h, target_w).
        """
        # Unpack input and target dimensions
        input_h, input_w = input_size
        target_h, target_w = target_size

        # Calculate scale factors for height and width
        scale_h = input_h / target_h
        scale_w = input_w / target_w

        # Determine kernel size based on scale factors.
        # Larger scale factors need larger kernels for better anti-aliasing.
        k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3)))
        k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3)))

        # Make sure kernel sizes are odd
        k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1
        k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1

        # A single square kernel is used, sized for the more aggressive axis.
        k_size = max(k_size_h, k_size_w)

        # Stride approximates the per-axis downsampling factor.
        stride_h = max(1, round(scale_h))
        stride_w = max(1, round(scale_w))
        # BUGFIX: padding must match the actual (square) kernel of size k_size.
        # Previously padding used k_size_h//2 and k_size_w//2 while the kernel
        # itself was resized to max(k_size_h, k_size_w); for anisotropic scales
        # the conv output shrank and the final interpolate silently compensated.
        pad = k_size // 2

        # Resize the learnable kernel, re-normalise, and expand depth-wise.
        k = self._resize_and_normalise(self.weight, k_size)  # (1,1,k,k)
        k = k.repeat(self.C, 1, 1, 1)  # depth-wise: one kernel per channel

        # Apply the depth-wise strided blur convolution.
        result = F.conv2d(x, weight=k, stride=(stride_h, stride_w),
                          padding=(pad, pad), groups=self.C)

        # If the convolution didn't land exactly on the target size, finish
        # with a bilinear interpolation for the fine adjustment.
        if result.shape[2:] != target_size:
            result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True)

        return result
177
+
178
def modulate(x, shift, scale):
    """Apply an AdaLN-style affine modulation: ``x * (1 + scale) + shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
180
+
181
+
182
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention plus an optional MLP,
    each wrapped in a residual connection.

    Passing ``mlp_ratio <= 0`` disables the feed-forward sub-layer entirely
    (neither ``ln_2`` nor ``mlp`` is created in that case).
    """

    def __init__(
            self,
            d_model,
            n_head,
            mlp_ratio=4.0,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio
        # The FFN is optional: only build it for a positive ratio.
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden = int(d_model * mlp_ratio)
            self.mlp = nn.Sequential(OrderedDict([
                ("c_fc", nn.Linear(d_model, hidden)),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(hidden, d_model)),
            ]))

    def attention(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None
    ):
        """Self-attention with query = key = value = x; weights are discarded."""
        out, _ = self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)
        return out

    def forward(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None
    ):
        # Residual attention branch (pre-norm).
        x = x + self.attention(x=self.ln_1(x), attention_mask=attention_mask)
        # Residual MLP branch, only if the FFN was built.
        if self.mlp_ratio > 0:
            x = x + self.mlp(self.ln_2(x))
        return x
223
+
224
# Select the fastest attention backend available at import time:
#   1. PyTorch's fused scaled_dot_product_attention ("flash") when present,
#   2. otherwise xformers' memory-efficient attention,
#   3. otherwise a plain PyTorch matmul-softmax fallback ("math").
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    # BUGFIX: was a bare `except:`, which also swallows KeyboardInterrupt and
    # SystemExit; only a missing/broken xformers install should trigger the
    # math fallback.
    except ImportError:
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')
234
+
235
+
236
class Attention(nn.Module):
    """Multi-head self-attention with a backend chosen by the module-level
    ``ATTENTION_MODE`` flag ('flash', 'xformers', or 'math').

    All three paths compute the same attention; only the kernel differs.
    Note: dropout (``attn_drop``) is only applied on the 'math' path.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default softmax scaling is 1/sqrt(head_dim) unless overridden.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """x: (B, L, C) -> (B, L, C)."""
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            # NOTE(review): the .float() cast forces fp32 inside SDPA (and an
            # fp32 output fed to self.proj); presumably a numerical-stability
            # workaround -- confirm before running under autocast/fp16.
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # BUGFIX: `raise NotImplemented` raises a TypeError at runtime
            # (NotImplemented is a constant, not an exception class); raise the
            # proper exception instead.
            raise NotImplementedError(f"unknown attention mode {ATTENTION_MODE!r}")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
274
+
275
+
276
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Per-sample stochastic depth ("drop path") for residual branches.

    During training each sample in the batch is dropped (zeroed) with
    probability ``drop_prob``; surviving samples are rescaled by
    ``1 / (1 - drop_prob)`` so the expected activation is unchanged. At eval
    time, or when ``drop_prob == 0``, the input tensor is returned unchanged.

    Historically conflated with DropConnect, which is a different technique;
    see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for the naming discussion.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim
    # (works for any tensor rank, not just 4-D conv activations).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize: 1 with prob keep_prob, else 0
    return x.div(keep_prob) * mask
294
+
295
+
296
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth).

    The module's own ``training`` flag gates whether any dropping happens.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; identity in eval mode.
        return drop_path(x, self.drop_prob, self.training)
305
+
306
+
307
class Mlp(nn.Module):
    """Transformer feed-forward block: linear -> activation -> linear, with a
    shared dropout applied after the activation and after the output layer.

    ``hidden_features`` and ``out_features`` both default to ``in_features``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # First projection + nonlinearity + dropout.
        h = self.drop(self.act(self.fc1(x)))
        # Output projection + dropout.
        return self.drop(self.fc2(h))
324
+
325
+
326
class UViTBlock(nn.Module):
    """U-ViT transformer block: pre-norm attention and MLP branches with
    stochastic depth, plus an optional U-Net-style long skip connection fused
    by a linear layer. Gradient checkpointing is available via
    ``use_checkpoint``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)
        # Long skip: concatenate the incoming skip features and project back
        # to `dim` (only built when `skip=True`).
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if self.use_checkpoint:
            # Trade compute for memory by recomputing activations on backward.
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        return self._forward(x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
354
+
355
+
356
+ def _expand_token(token, batch_size: int):
357
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
358
+
359
+
360
class ResolutionEncoder(nn.Module):
    """ViT encoder that compresses an image into a short sequence of 1-D latent tokens.

    Two resolution-flexibility mechanisms are combined:
      * FlexiViT-style patch embedding: the base conv kernel is resized via a
        pseudo-inverse resize matrix so the same weights can patchify at
        several patch sizes (see ``resize_patch_embed``).
      * An optional PatchMixture stage that drops a random subset of patch
        tokens between ``patch_mixture_start_layer`` and
        ``patch_mixture_end_layer`` to reduce compute at large grids.

    Output is a (B, token_size, 1, num_latent_tokens) tensor of latent codes.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_enc_patch_size
        self.model_size = config.model.vq_model.vit_enc_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
        # Default of 100 exceeds any depth below, i.e. PatchMixture routing is
        # effectively disabled unless explicitly configured.
        self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
        self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)

        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
            self.token_size = self.token_size * 2  # needs to split into mean and std

        self.is_legacy = config.model.vq_model.get("is_legacy", True)

        self.width = {
            "tiny": 256,
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Base patchify conv; its kernel may be resized at runtime by
        # apply_flexivit_patch_embed for other patch sizes.
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True)

        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))

        # Grid-size-aware positional embedding (supports variable grids).
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)

        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)

        self.patch_mixture = PatchMixture()

        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))

        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
        # Cache of pseudo-inverse resize matrices, keyed by target patch size.
        # NOTE(review): grows unboundedly if many distinct sizes are used.
        self.pinvs = {}

    def apply_flexivit_patch_embed(self, x, target_patch_size):
        """Patchify ``x`` with the base conv, resizing its kernel on the fly
        when ``target_patch_size`` differs from the training patch size."""
        patch_size = to_2tuple(target_patch_size)

        # Resize conv weights only when necessary.
        if patch_size == to_2tuple(self.patch_size):
            weight = self.patch_embed.weight
        else:
            weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)

        # Apply conv with (possibly resized) weights; stride == patch size.
        x = F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)
        return x

    def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Bilinearly resize a 2-D tensor to ``shape`` (no antialiasing)."""
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode="bilinear",
            antialias=False,
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self, old_shape: Tuple[int, int], new_shape: Tuple[int, int], device=None
    ) -> Tensor:
        """Build the pseudo-inverse of the bilinear resize operator, expressed
        as a dense matrix, by resizing each basis vector of the old shape.
        This is the FlexiViT "PI-resize" construction."""
        # Use the device from patch_embed weights if available
        if device is None and hasattr(self, 'patch_embed'):
            device = self.patch_embed.weight.device

        mat = []
        for i in range(np.prod(old_shape)):
            basis_vec = torch.zeros(old_shape, device=device)  # one-hot basis image
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]):
        """Resize patch_embed to target resolution via pseudo-inverse resizing"""
        # Return original kernel if no resize is necessary
        if to_2tuple(self.patch_size) == new_patch_size:
            return patch_embed

        # Calculate pseudo-inverse of resize matrix (cached per target size).
        if new_patch_size not in self.pinvs:
            self.pinvs[new_patch_size] = self._calculate_pinv(
                to_2tuple(self.patch_size), new_patch_size, device=patch_embed.device
            )
        pinv = self.pinvs[new_patch_size]

        def resample_patch_embed(patch_embed: Tensor):
            # Resample a single (kh, kw) kernel slice; pinv is computed in
            # fp32, so round-trip the dtype around the matmul.
            h, w = new_patch_size
            original_dtype = patch_embed.dtype
            patch_embed_float = patch_embed.float()
            resampled_kernel = pinv @ patch_embed_float.reshape(-1)
            resampled_kernel = resampled_kernel.to(original_dtype)
            return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)

        # vmap over the (out_channels, in_channels) leading dims.
        v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)

        return v_resample_patch_embed(patch_embed)

    def get_attention_mask(self, target_shape, attention_mask):
        """Expand a per-key boolean mask into an additive float mask of shape
        (B * num_heads, S, S) for nn.MultiheadAttention, prepending an
        all-ones block of ``target_shape`` for the non-key (e.g. image/class)
        positions. Masked positions become -inf."""
        # Create mask for mask_tokens (all True since we want to attend to all mask tokens)
        mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
        # Combine with input attention mask
        attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
        sequence_length = attention_mask.shape[1]

        # Broadcast the key mask over heads and query positions.
        # NOTE(review): despite the original "causal" wording, this is a plain
        # key-padding-style mask, not a causal (lower-triangular) one.
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            self.num_heads,
            sequence_length,
            sequence_length
        )

        # Reshape to [B*num_heads, S, S]
        attention_mask = attention_mask.reshape(
            -1, sequence_length, sequence_length
        )

        # Convert boolean mask to float
        attention_mask = attention_mask.float()

        # Convert mask values: True -> 0.0, False -> -inf
        attention_mask = attention_mask.masked_fill(
            ~attention_mask.bool(),
            float('-inf')
        )
        return attention_mask

    def forward(self, pixel_values, latent_tokens, attention_mask=None, encode_patch_size=None, train=True):
        """Encode pixels into latent tokens.

        Args:
            pixel_values: (B, 3, H, W) input images.
            latent_tokens: (num_latent_tokens, width) learnable queries.
            attention_mask: optional per-key boolean mask (see get_attention_mask).
            encode_patch_size: fixed patch size; when None one is sampled.
            train: enables patch-size sampling and PatchMixture token dropping.

        Returns:
            (B, token_size, 1, num_latent_tokens) latent code tensor.
        """
        batch_size, _, H, W = pixel_values.shape
        x = pixel_values

        # Apply dynamic patch embedding
        # Determine patch size dynamically based on image resolution
        # Base patch size (32) is for 512x512 images
        # Scale proportionally for other resolutions to maintain ~256 tokens
        base_resolution = 512

        if encode_patch_size is None:
            base_patch_size = random.choice([16, 32])
            target_patch_size = min(int(min(H, W) / base_resolution * base_patch_size), 32)  # we force it to be at most 32 otherwise we lose information
        else:
            target_patch_size = encode_patch_size

        if isinstance(target_patch_size, int):
            target_patch_size = (target_patch_size, target_patch_size)

        x = self.apply_flexivit_patch_embed(x, target_patch_size)

        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)

        # Grid dimensions of the patchified image.
        grid_height = H // target_patch_size[0]
        grid_width = W // target_patch_size[1]

        # PatchMixture only drops tokens for large grids, and only in training.
        mask_ratio = 0.0
        if grid_height*grid_width > 256 and train:
            mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()

        num_latent_tokens = latent_tokens.shape[0]
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent_tokens]

        x = x + self.positional_embedding(grid_height, grid_width, train=train, dtype=x.dtype)

        # apply attention_mask before concatenating x and latent_tokens
        if attention_mask is not None:
            key_attention_mask = attention_mask.clone()
            attention_mask = self.get_attention_mask((batch_size, x.shape[1]), key_attention_mask)
            full_seq_attention_mask = attention_mask.clone()
        else:
            key_attention_mask = None
            full_seq_attention_mask = None

        # Concatenate x and latent_tokens first
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for i in range(self.num_layers):
            if i == self.patch_mixture_start_layer:
                # Enter the reduced-token span: keep only a subset of the
                # patch tokens (positions 1..grid_h*grid_w are patches; 0 is
                # the class token, the tail is the latent tokens).
                x = x.permute(1, 0, 2)
                x_D_last = x[:, 1:grid_height*grid_width+1].clone()
                mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio)
                new_x = self.patch_mixture.start_route(x, mask_info)
                x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1)
                x = x.permute(1, 0, 2)
                # The sequence just shrank, so the mask must be rebuilt.
                if key_attention_mask is not None:
                    attention_mask = self.get_attention_mask((batch_size, 1+new_x.shape[1]), key_attention_mask)
                else:
                    attention_mask = None

            x = self.transformer[i](x, attention_mask=attention_mask)

            if i == self.patch_mixture_end_layer:
                # Leave the reduced-token span: restore the dropped patch
                # tokens from the pre-drop snapshot (x_D_last).
                x = x.permute(1, 0, 2)
                new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
                x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
                x = x.permute(1, 0, 2)
                if full_seq_attention_mask is not None:
                    attention_mask = full_seq_attention_mask.clone()
                else:
                    attention_mask = None

        x = x.permute(1, 0, 2)  # LND -> NLD

        # Only the latent-token tail is kept; class/patch tokens are discarded.
        latent_tokens = x[:, 1+grid_height*grid_width:]
        latent_tokens = self.ln_post(latent_tokens)

        # fake 2D shape so a 1x1 conv can project width -> token_size
        if self.is_legacy:
            # NOTE(review): this reshape reinterprets (tokens, width) memory as
            # (width, tokens) without a transpose -- preserved as a legacy
            # checkpoint quirk; confirm before reusing elsewhere.
            latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent_tokens, 1)
        else:
            # Fix legacy problem.
            latent_tokens = latent_tokens.reshape(batch_size, num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent_tokens)
        return latent_tokens
608
+
609
# Keep the original TiTokEncoder as a legacy class
class TiTokEncoder(ResolutionEncoder):
    """Legacy alias: TiTokEncoder is now identical to ResolutionEncoder.

    Kept so existing configs/checkpoints that reference the old class name
    keep working.
    """
    pass
613
+
614
class ResolutionDecoder(nn.Module):
    """ViT decoder that reconstructs RGB pixels from 1-D latent tokens.

    A grid of learnable mask tokens (one per output patch) attends to the
    latent tokens through a plain transformer; the patch grid is then folded
    back into pixels, anti-alias-resized to the requested output resolution,
    and refined by a final 3x3 conv. As in ResolutionEncoder, an optional
    PatchMixture stage drops a subset of mask tokens for a span of layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_dec_patch_size
        self.model_size = config.model.vq_model.vit_dec_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
        # Default of 100 exceeds any depth below -> PatchMixture disabled
        # unless explicitly configured.
        self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
        self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)

        self.is_legacy = config.model.vq_model.get("is_legacy", True)
        self.width = {
            "tiny": 256,
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Project quantized latent codes (token_size) up to transformer width.
        self.decoder_embed = nn.Linear(
            self.token_size, self.width, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))

        # Grid-size-aware positional embedding (supports variable grids).
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)

        # add mask token and query pos embed
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)

        self.patch_mixture = PatchMixture()

        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            raise NotImplementedError("Legacy mode is not implemented for ResolutionDecoder")
        else:
            # Directly predicting RGB pixels: ffn emits p*p*3 channels per
            # grid cell, rearrange folds them back into pixel space, then an
            # anti-aliased resize hits the exact target resolution.
            self.ffn = nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True)
            self.rearrange = Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                                       p1=self.patch_size, p2=self.patch_size)
            self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def get_attention_mask(self, target_shape, attention_mask):
        """Expand a per-key boolean mask into an additive float mask of shape
        (B * num_heads, S, S) for nn.MultiheadAttention, prepending an
        all-ones block of ``target_shape`` for the mask-token positions.
        Masked positions become -inf."""
        # Create mask for mask_tokens (all True since we want to attend to all mask tokens)
        mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
        # Combine with input attention mask
        attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
        sequence_length = attention_mask.shape[1]

        # Broadcast the key mask over heads and query positions.
        # NOTE(review): despite the original "causal" wording, this is a plain
        # key-padding-style mask, not a causal (lower-triangular) one.
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            self.num_heads,
            sequence_length,
            sequence_length
        )

        # Reshape to [B*num_heads, S, S]
        attention_mask = attention_mask.reshape(
            -1, sequence_length, sequence_length
        )

        # Convert boolean mask to float
        attention_mask = attention_mask.float()

        # Convert mask values: True -> 0.0, False -> -inf
        attention_mask = attention_mask.masked_fill(
            ~attention_mask.bool(),
            float('-inf')
        )
        return attention_mask

    def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode latent codes into an RGB image.

        Args:
            z_quantized: (N, C, H, W) latent tensor; flattened to (N, W, C*H)
                token sequence (C*H must equal token_size).
            attention_mask: optional per-key boolean mask over latent tokens.
            height, width: output resolution; defaults to the training crop size.
            decode_patch_size: fixed output patch size; when None one is
                sampled so the grid has roughly 256-1024 cells.
            train: enables patch-size sampling and PatchMixture token dropping.

        Returns:
            (N, 3, height, width) reconstructed image.
        """
        N, C, H, W = z_quantized.shape
        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1)  # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        # Choose the output patch size (and hence the decoder grid).
        if decode_patch_size is None:
            # Calculate total area and determine appropriate patch size
            total_pixels = height * width

            # Target patch counts between 256 and 1024
            min_patches = 256
            max_patches = 1024

            # Calculate possible patch sizes that would give us patch counts in our target range
            possible_patch_sizes = []
            for patch_size in [8, 16, 32]:
                grid_h = height // patch_size
                grid_w = width // patch_size
                total_patches = grid_h * grid_w
                if min_patches <= total_patches <= max_patches:
                    possible_patch_sizes.append(patch_size)

            if not possible_patch_sizes:
                # If no patch size gives us the desired range, pick the one closest to our target range
                patch_counts = []
                for patch_size in [8, 16, 32]:
                    grid_h = height // patch_size
                    grid_w = width // patch_size
                    patch_counts.append((patch_size, grid_h * grid_w))

                # Sort by how close the patch count is to our target range
                patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches)))
                possible_patch_sizes = [patch_counts[0][0]]

            selected_patch_size = random.choice(possible_patch_sizes)
        else:
            selected_patch_size = decode_patch_size

        if isinstance(selected_patch_size, int):
            selected_patch_size = (selected_patch_size, selected_patch_size)

        grid_height = height // selected_patch_size[0]
        grid_width = width // selected_patch_size[1]

        # if grid_height*grid_width>1024 and train:
        #     grid_height = 32
        #     grid_width = 32

        # PatchMixture only drops tokens for large grids, and only in training.
        mask_ratio = 0.0
        if grid_height*grid_width > 256 and train:
            mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()

        # One mask token per output patch, plus a class token in front.
        mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                 mask_tokens], dim=1)

        mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype)

        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        if attention_mask is not None:
            key_attention_mask = attention_mask.clone()
            attention_mask = self.get_attention_mask((batchsize, 1+grid_height*grid_width), key_attention_mask)
            full_seq_attention_mask = attention_mask.clone()
        else:
            key_attention_mask = None
            full_seq_attention_mask = None

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for i in range(self.num_layers):
            if i == self.patch_mixture_start_layer:
                # Enter the reduced-token span: keep only a subset of the mask
                # tokens (position 0 is the class token, the tail holds the
                # latent tokens).
                x = x.permute(1, 0, 2)
                x_D_last = x[:, 1:grid_height*grid_width+1].clone()
                mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio)
                new_x = self.patch_mixture.start_route(x, mask_info)
                x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1)
                x = x.permute(1, 0, 2)
                # The sequence just shrank, so the mask must be rebuilt.
                if key_attention_mask is not None:
                    attention_mask = self.get_attention_mask((batchsize, 1+new_x.shape[1]), key_attention_mask)
                else:
                    attention_mask = None

            x = self.transformer[i](x, attention_mask=attention_mask)

            if i == self.patch_mixture_end_layer:
                # Leave the reduced-token span: restore the dropped tokens
                # from the pre-drop snapshot (x_D_last).
                x = x.permute(1, 0, 2)
                new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
                x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
                x = x.permute(1, 0, 2)
                if full_seq_attention_mask is not None:
                    attention_mask = full_seq_attention_mask.clone()
                else:
                    attention_mask = None

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = x[:, 1:1+grid_height*grid_width]  # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
        # Fold grid features back into pixel space, then resize to the exact
        # requested output resolution with the learnable anti-aliasing blur.
        x = self.ffn(x.contiguous())
        x = self.rearrange(x)
        _, _, org_h, org_w = x.shape
        x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
        x = self.conv_out(x)

        return x
826
+
827
# Keep the original TiTokDecoder as a legacy class that inherits from ResolutionDecoder
class TiTokDecoder(ResolutionDecoder):
    """Legacy TiTokDecoder - now inherits from ResolutionDecoder for backward compatibility.

    Makes a best-effort shallow copy of ``config`` with the PatchMixture layers
    disabled, builds the parent decoder, and then swaps in the legacy output
    head so old checkpoints keep loading and decoding at the fixed training
    resolution and patch size.
    """

    def __init__(self, config):
        # Best-effort shallow copy so we can disable advanced features without
        # mutating the caller's config object.
        # NOTE(review): `type(config)()` assumes the config class has a
        # zero-argument constructor -- confirm for the config type in use.
        config_copy = type(config)()
        for attr in dir(config):
            if not attr.startswith('__'):
                try:
                    setattr(config_copy, attr, getattr(config, attr))
                except Exception:
                    # BUGFIX: was a bare `except:`, which also swallows
                    # KeyboardInterrupt/SystemExit. The copy stays best-effort
                    # (read-only attributes are simply skipped), but control-
                    # flow exceptions now propagate.
                    pass

        # Disable patch mixture for legacy mode (layer index -1 never matches).
        if hasattr(config_copy.model.vq_model, 'patch_mixture_start_layer'):
            config_copy.model.vq_model.patch_mixture_start_layer = -1
        if hasattr(config_copy.model.vq_model, 'patch_mixture_end_layer'):
            config_copy.model.vq_model.patch_mixture_end_layer = -1

        super().__init__(config_copy)

        # Override grid_size for legacy compatibility (fixed square grid).
        self.grid_size = self.image_size // self.patch_size

        # Replace ResolutionDecoder's advanced final layers with legacy ones if needed.
        # NOTE(review): ResolutionDecoder.__init__ raises NotImplementedError
        # when is_legacy is True, so this branch is currently unreachable via
        # super().__init__ -- confirm intended behavior for legacy configs.
        if self.is_legacy:
            # Legacy head: predicts 1024-dim per-patch codes instead of pixels.
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Use simpler final layers for backward compatibility: fold the
            # patch grid straight back into RGB pixels.
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
                Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                          p1=self.patch_size, p2=self.patch_size),)
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode latents at the fixed legacy resolution and patch size."""
        # Legacy compatibility: use fixed crop size if height/width not provided.
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        # Force decode_patch_size to be the original patch_size for legacy compatibility.
        if decode_patch_size is None:
            decode_patch_size = self.patch_size

        # Use the parent's forward method but with legacy parameters.
        return super().forward(z_quantized, attention_mask, height, width, decode_patch_size, train)
881
+
882
+
883
class TATiTokDecoder(ResolutionDecoder):
    """Text-aligned TiTok decoder: a ResolutionDecoder that additionally
    conditions on text-guidance embeddings appended to the token sequence.

    Decodes at the fixed training resolution/patch size (no dynamic patch
    sampling and no PatchMixture routing in this forward path).
    """

    def __init__(self, config):
        super().__init__(config)
        scale = self.width ** -0.5
        self.text_context_length = config.model.vq_model.get("text_context_length", 77)
        self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
        # Project text embeddings to decoder width; they get their own
        # positional embedding since they live outside the image grid.
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width))

        # Add grid_size for backward compatibility
        self.grid_size = self.image_size // self.patch_size

    def forward(self, z_quantized, text_guidance, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode latents into pixels, conditioned on text guidance.

        Args:
            z_quantized: (N, C, H, W) latent tensor (C*H == token_size).
            text_guidance: (N, text_context_length, text_embed_dim) embeddings.
            height, width, decode_patch_size: default to the fixed legacy values.
        """
        N, C, H, W = z_quantized.shape
        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1)  # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        # Use fixed grid size for backward compatibility
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size
        if decode_patch_size is None:
            decode_patch_size = self.patch_size

        grid_height = height // decode_patch_size
        grid_width = width // decode_patch_size

        # One mask token per output patch, preceded by the class token.
        mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                 mask_tokens], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype)
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        # Append projected text tokens so the transformer can attend to them.
        text_guidance = self.text_guidance_proj(text_guidance)
        text_guidance = text_guidance + self.text_guidance_positional_embedding
        x = torch.cat([x, text_guidance], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = x[:, 1:1+grid_height*grid_width]  # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
        x = self.ffn(x.contiguous())
        # BUGFIX: the inherited (non-legacy) head is ffn -> rearrange -> conv_out.
        # The rearrange step was missing here, so conv_out (a 3-in-channel conv)
        # received patch_size*patch_size*3 channels and crashed on a channel
        # mismatch. Fold the patch dimension back into pixel space first. The
        # parent's down_scale blur is intentionally skipped: this path decodes
        # at the native grid resolution, so no resizing is needed.
        x = self.rearrange(x)
        x = self.conv_out(x)
        return x
936
+
937
+
938
class WeightTiedLMHead(nn.Module):
    """LM head whose projection matrix is tied to an embedding table.

    Only the first ``target_codebook_size`` rows of the shared embedding
    matrix are used, so logits cover that (possibly smaller) vocabulary.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        # Tie directly to the embedding parameter; no copy is made.
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        """Project hidden states to vocabulary logits.

        Args:
            x: [batch_size, seq_len, embed_dim] hidden states.

        Returns:
            [batch_size, seq_len, target_codebook_size] logits.
        """
        tied_rows = self.weight[:self.target_codebook_size]  # [V, D]
        return x @ tied_rows.t()
951
+
952
+
953
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features and an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N (possibly fractional) timestep indices.
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, dim) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        out = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dim: pad a zero column so the output has exactly `dim` features.
            out = torch.cat([out, torch.zeros_like(out[:, :1])], dim=-1)
        return out

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
991
+
992
+
993
class ResBlock(nn.Module):
    """Adaptive-LayerNorm residual MLP block.

    The conditioning vector produces shift/scale/gate terms (adaLN) that
    modulate the normalized input before a two-layer MLP; the gated MLP
    output is added back to the input.

    :param channels: the number of input (and output) channels.
    """

    def __init__(
        self,
        channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = modulate(self.in_ln(x), shift, scale)
        return x + gate * self.mlp(hidden)
1023
+
1024
+
1025
class FinalLayer(nn.Module):
    """The final adaLN projection layer, adopted from DiT."""

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
1043
+
1044
+
1045
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        res_blocks = []
        for i in range(num_res_blocks):
            res_blocks.append(ResBlock(
                model_channels,
            ))

        self.res_blocks = nn.ModuleList(res_blocks)
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        # Blanket Xavier init for every Linear first; the specific overrides
        # below intentionally run AFTER it — order matters.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers so each res block starts as identity.
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers: the freshly built model predicts exactly zero.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        # Single conditioning vector: timestep embedding + projected condition.
        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Trade compute for memory: re-run blocks during backward.
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        # Classifier-free guidance: the batch is [cond_half; uncond_half] and
        # both halves share the same noisy input (first half is duplicated).
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        # Only the first `in_channels` outputs (the eps prediction) are guided;
        # any remainder (e.g. learned variance) passes through unchanged.
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
modeling/modules/fuzzy_embedding.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import einops
5
+ import math
6
+
7
class FuzzyEmbedding(nn.Module):
    """Resolution-agnostic 2D positional embedding sampled from a learnable grid.

    A learnable table of ``grid_size`` positions (stored flat, interpreted as a
    sqrt(grid_size) x sqrt(grid_size) square grid) is bilinearly resampled with
    ``F.grid_sample`` to any target ``grid_height x grid_width``. Optionally a
    tiny uniform jitter ("fuzz") is added to the sampling coordinates during
    training as regularization.

    Args:
        grid_size: number of entries in the base table; must be a perfect
            square so it can be viewed as a square grid.
        scale: std-dev scale of the randomly initialized embeddings.
        width: embedding dimension.
        apply_fuzzy: if True, jitter the sampling coordinates at train time.
    """

    def __init__(self, grid_size, scale, width, apply_fuzzy=False):
        super(FuzzyEmbedding, self).__init__()
        # Generalized from the original hard-coded `grid_size == 1024`: the
        # implementation only ever uses int(sqrt(grid_size)), so any perfect
        # square works (1024 remains valid; the module's own demo used 256).
        side = int(math.isqrt(grid_size))
        assert side * side == grid_size, "grid_size must be a perfect square"

        self.grid_size = grid_size
        self.side = side
        self.scale = scale
        self.width = width
        self.apply_fuzzy = apply_fuzzy
        # grid_size is the minimum possible token count; grid_sample then
        # yields the embedding for any target resolution.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(grid_size, width))

        self.class_positional_embedding = nn.Parameter(
            scale * torch.randn(1, width))

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, grid_height, grid_width, train=True, dtype=torch.float32):
        """Return (1 + grid_height*grid_width, width) embeddings, class token first.

        Note: grid_height and grid_width must each be >= 2 (normalization
        divides by size - 1).
        """
        device = self.positional_embedding.device
        # indexing="ij" matches the previously implicit torch default.
        meshx, meshy = torch.meshgrid(
            torch.arange(grid_height, device=device),
            torch.arange(grid_width, device=device),
            indexing="ij",
        )
        meshx = meshx.to(dtype)
        meshy = meshy.to(dtype)

        # Normalize coordinates to the [-1, 1] range expected by grid_sample.
        meshx = 2 * (meshx / (grid_height - 1)) - 1
        meshy = 2 * (meshy / (grid_width - 1)) - 1

        if self.apply_fuzzy and train:
            # Uniform jitter in [-0.0004, 0.0004] on each coordinate.
            meshx = meshx + torch.rand_like(meshx) * 0.0008 - 0.0004
            meshy = meshy + torch.rand_like(meshy) * 0.0008 - 0.0004

        grid = torch.stack((meshy, meshx), 2).unsqueeze(0)  # (1, H, W, 2), (x, y) order

        # (grid_size, width) -> (1, width, side, side); einops-free equivalent
        # of rearrange(pe, "(h w) d -> d h w", h=side, w=side).
        positional_embedding = self.positional_embedding.view(self.side, self.side, self.width)
        positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0).to(dtype)

        sampled = F.grid_sample(positional_embedding, grid, align_corners=False).to(dtype)
        # (1, width, H, W) -> (H*W, width); equivalent of
        # rearrange(x, "b d h w -> b (h w) d").squeeze(0).
        sampled = sampled.squeeze(0).permute(1, 2, 0).reshape(grid_height * grid_width, self.width)

        return torch.cat([self.class_positional_embedding, sampled], dim=0)
63
+
64
+
65
+ if __name__ == "__main__":
66
+ fuzzy_embedding = FuzzyEmbedding(256, 1.0, 1024)
67
+ grid_height = 16
68
+ grid_width = 32
69
+ fuzzy_embedding = fuzzy_embedding(grid_height, grid_width, dtype=torch.bfloat16)
70
+ print(fuzzy_embedding.shape)
modeling/modules/losses.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training loss implementation.
2
+
3
+ Ref:
4
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
5
+ """
6
+ from typing import Mapping, Text, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from torch.cuda.amp import autocast
13
+
14
+ from modeling.modules.blocks import SimpleMLPAdaLN
15
+ from .perceptual_loss import PerceptualLoss
16
+ from .discriminator import NLayerDiscriminator
17
+
18
+
19
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Hinge loss for discrminator.

    This function is borrowed from
    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
29
+
30
+
31
def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor
) -> torch.Tensor:
    """Computes the LeCam loss for the given average real and fake logits.

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.

    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss.
    """
    # Penalize current logits that drift past the opposing EMA anchors.
    real_push = F.relu(logits_real_mean - ema_logits_fake_mean)
    fake_push = F.relu(ema_logits_real_mean - logits_fake_mean)
    return real_push.pow(2).mean() + fake_push.pow(2).mean()
51
+
52
+
53
class ReconstructionLoss_Stage1(torch.nn.Module):
    """Stage-1 loss: cross-entropy on target codes plus weighted quantizer loss."""

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.quantizer_weight = config.losses.quantizer_weight
        # Stage 1 reconstructs tokens of a fixed 1024-entry codebook.
        self.target_codebook_size = 1024

    def forward(self,
                target_codes: torch.Tensor,
                reconstructions: torch.Tensor,
                quantizer_loss: torch.Tensor,
                ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        return self._forward_generator(target_codes, reconstructions, quantizer_loss)

    def _forward_generator(self,
                           target_codes: torch.Tensor,
                           reconstructions: torch.Tensor,
                           quantizer_loss: Mapping[Text, torch.Tensor],
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        logits = reconstructions.contiguous()
        batch_size = logits.shape[0]
        # Mean cross-entropy over all token positions.
        reconstruction_loss = F.cross_entropy(
            logits.view(batch_size, self.target_codebook_size, -1),
            target_codes.view(batch_size, -1),
            reduction="mean",
        )
        weighted_quantizer = self.quantizer_weight * quantizer_loss["quantizer_loss"]
        total_loss = reconstruction_loss + weighted_quantizer

        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            quantizer_loss=weighted_quantizer.detach(),
            commitment_loss=quantizer_loss["commitment_loss"].detach(),
            codebook_loss=quantizer_loss["codebook_loss"].detach(),
        )

        return total_loss, loss_dict
92
+
93
+
94
class ReconstructionLoss_Stage2(torch.nn.Module):
    """Stage-2 tokenizer loss: pixel reconstruction + perceptual + GAN (+ LeCam).

    Operates in two modes: "generator" trains the tokenizer against a frozen
    discriminator; "discriminator" trains the discriminator with hinge loss
    and optional LeCam regularization.
    """

    def __init__(
        self,
        config
    ):
        """Initializes the losses module.

        Args:
            config: A dictionary, the configuration for the model and everything else.
        """
        super().__init__()
        loss_config = config.losses
        self.discriminator = NLayerDiscriminator()

        self.reconstruction_loss = loss_config.reconstruction_loss
        self.reconstruction_weight = loss_config.reconstruction_weight
        self.quantizer_weight = loss_config.quantizer_weight
        # Frozen perceptual network; kept in eval mode and never trained.
        self.perceptual_loss = PerceptualLoss(
            loss_config.perceptual_loss).eval()
        self.perceptual_weight = loss_config.perceptual_weight
        # Adversarial terms stay disabled until this global step is reached.
        self.discriminator_iter_start = loss_config.discriminator_start

        self.discriminator_factor = loss_config.discriminator_factor
        self.discriminator_weight = loss_config.discriminator_weight
        self.lecam_regularization_weight = loss_config.lecam_regularization_weight
        self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999)
        if self.lecam_regularization_weight > 0.0:
            # Running EMAs of mean real/fake logits, consumed by the LeCam term.
            self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
            self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))

        self.config = config

    @autocast(enabled=False)
    def forward(self,
                inputs: torch.Tensor,
                reconstructions: torch.Tensor,
                extra_result_dict: Mapping[Text, torch.Tensor],
                global_step: int,
                mode: str = "generator",
                ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        # Both inputs and reconstructions are in range [0, 1].
        inputs = inputs.float()
        reconstructions = reconstructions.float()

        if mode == "generator":
            return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
        elif mode == "discriminator":
            return self._forward_discriminator(inputs, reconstructions, global_step)
        else:
            raise ValueError(f"Unsupported mode {mode}")

    def should_discriminator_be_trained(self, global_step : int):
        # The GAN game only starts after a reconstruction-only warm-up.
        return global_step >= self.discriminator_iter_start

    def _forward_generator(self,
                           inputs: torch.Tensor,
                           reconstructions: torch.Tensor,
                           extra_result_dict: Mapping[Text, torch.Tensor],
                           global_step: int
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step."""
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        elif self.reconstruction_loss == "l2":
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        else:
            raise ValueError(f"Unsuppored reconstruction_loss {self.reconstruction_loss}")
        reconstruction_loss *= self.reconstruction_weight

        # Compute perceptual loss.
        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # Compute discriminator loss.
        generator_loss = torch.zeros((), device=inputs.device)
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Disable discriminator gradients so only the generator is updated
            # through the adversarial term.
            for param in self.discriminator.parameters():
                param.requires_grad = False
            logits_fake = self.discriminator(reconstructions)
            # Non-saturating generator objective: raise the fake logits.
            generator_loss = -torch.mean(logits_fake)

        d_weight *= self.discriminator_weight

        # Compute quantizer loss.
        quantizer_loss = extra_result_dict["quantizer_loss"]
        total_loss = (
            reconstruction_loss
            + self.perceptual_weight * perceptual_loss
            + self.quantizer_weight * quantizer_loss
            + d_weight * discriminator_factor * generator_loss
        )
        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
            quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
            weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
            discriminator_factor=torch.tensor(discriminator_factor),
            commitment_loss=extra_result_dict["commitment_loss"].detach(),
            codebook_loss=extra_result_dict["codebook_loss"].detach(),
            d_weight=d_weight,
            gan_loss=generator_loss.detach(),
        )

        return total_loss, loss_dict

    def _forward_discriminator(self,
                               inputs: torch.Tensor,
                               reconstructions: torch.Tensor,
                               global_step: int,
                               ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Discrminator training step."""
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        loss_dict = {}
        # Turn the gradients on (they were disabled during the generator step).
        for param in self.discriminator.parameters():
            param.requires_grad = True

        real_images = inputs.detach().requires_grad_(True)
        logits_real = self.discriminator(real_images)
        logits_fake = self.discriminator(reconstructions.detach())

        discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)

        # optional lecam regularization
        lecam_loss = torch.zeros((), device=inputs.device)
        if self.lecam_regularization_weight > 0.0:
            lecam_loss = compute_lecam_loss(
                torch.mean(logits_real),
                torch.mean(logits_fake),
                self.ema_real_logits_mean,
                self.ema_fake_logits_mean
            ) * self.lecam_regularization_weight

            # Update the EMA buffers in place with the detached batch means.
            self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
            self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)

        discriminator_loss += lecam_loss

        loss_dict = dict(
            discriminator_loss=discriminator_loss.detach(),
            logits_real=logits_real.detach().mean(),
            logits_fake=logits_fake.detach().mean(),
            lecam_loss=lecam_loss.detach(),
        )
        return discriminator_loss, loss_dict
244
+
245
+
246
class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
    """Single-stage variant of the Stage-2 loss.

    Supports quantized modes ("vq"/"mvq"/"softvq"), which reuse the quantizer
    loss, and a "vae" mode that replaces it with a KL term.
    """

    def __init__(
        self,
        config
    ):
        super().__init__(config)
        loss_config = config.losses
        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")

        if self.quantize_mode == "vae":
            self.kl_weight = loss_config.get("kl_weight", 1e-6)
            logvar_init = loss_config.get("logvar_init", 0.0)
            # Fixed (non-trainable) log-variance used to rescale the
            # reconstruction term in "vae" mode.
            self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init, requires_grad=False)

    def _forward_generator(self,
                           inputs: torch.Tensor,
                           reconstructions: torch.Tensor,
                           extra_result_dict: Mapping[Text, torch.Tensor],
                           global_step: int
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step."""
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        elif self.reconstruction_loss == "l2":
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        else:
            raise ValueError(f"Unsuppored reconstruction_loss {self.reconstruction_loss}")
        reconstruction_loss *= self.reconstruction_weight

        # Compute perceptual loss.
        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # Compute discriminator loss.
        generator_loss = torch.zeros((), device=inputs.device)
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Disable discriminator gradients for the generator update.
            for param in self.discriminator.parameters():
                param.requires_grad = False
            logits_fake = self.discriminator(reconstructions)
            generator_loss = -torch.mean(logits_fake)

        d_weight *= self.discriminator_weight

        if self.quantize_mode in ["vq", "mvq", "softvq"]:
            # Compute quantizer loss.
            quantizer_loss = extra_result_dict["quantizer_loss"]
            total_loss = (
                reconstruction_loss
                + self.perceptual_weight * perceptual_loss
                + self.quantizer_weight * quantizer_loss
                + d_weight * discriminator_factor * generator_loss
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
                quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
                weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                commitment_loss=extra_result_dict["commitment_loss"].detach(),
                codebook_loss=extra_result_dict["codebook_loss"].detach(),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        elif self.quantize_mode == "vae":
            # Compute kl loss.
            # NOTE(review): in "vae" mode extra_result_dict is expected to be a
            # posterior object exposing .kl() (DiagonalGaussian-style), not a
            # plain dict — confirm against the caller.
            reconstruction_loss = reconstruction_loss / torch.exp(self.logvar)
            posteriors = extra_result_dict
            kl_loss = posteriors.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
            total_loss = (
                reconstruction_loss
                + self.perceptual_weight * perceptual_loss
                + self.kl_weight * kl_loss
                + d_weight * discriminator_factor * generator_loss
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
                kl_loss=(self.kl_weight * kl_loss).detach(),
                weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        else:
            raise NotImplementedError

        return total_loss, loss_dict
modeling/modules/lpips.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LPIPS perceptual loss.
2
+
3
+ Reference:
4
+ https://github.com/richzhang/PerceptualSimilarity/
5
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py
6
+ https://github.com/CompVis/taming-transformers/blob/master/taming/util.py
7
+ """
8
+
9
+ import os
10
+ import hashlib
11
+ import requests
12
+ from collections import namedtuple
13
+ from tqdm import tqdm
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from torchvision import models
19
+
20
# Channel-wise shift/scale applied by ScalingLayer so that [-1, 1] inputs match
# the statistics the pretrained LPIPS-VGG weights were trained with.
_LPIPS_MEAN = [-0.030, -0.088, -0.188]
_LPIPS_STD = [0.458, 0.448, 0.450]


# Download URL, local filename, and MD5 checksum for the pretrained LPIPS
# linear-layer weights (keyed by model name).
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
35
+
36
+
37
def download(url, local_path, chunk_size=1024):
    """Stream `url` to `local_path` with a progress bar, creating parent dirs.

    Args:
        url: source URL to fetch.
        local_path: destination file path.
        chunk_size: streaming chunk size in bytes.
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Fix: advance by the bytes actually received — the last
                        # chunk is usually smaller than chunk_size, so updating
                        # by chunk_size overshoots the reported total.
                        pbar.update(len(data))
47
+
48
+
49
def md5_hash(path):
    """Return the hex MD5 digest of the file at `path`."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
53
+
54
+
55
def get_ckpt_path(name, root, check=False):
    """Return the local checkpoint path for `name`, downloading it if needed.

    Args:
        name: key into URL_MAP / CKPT_MAP / MD5_MAP.
        root: directory under which the checkpoint file lives.
        check: if True, also re-download when the existing file's MD5 mismatches.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    needs_download = (not os.path.exists(path)) or (check and md5_hash(path) != MD5_MAP[name])
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        # Verify integrity of the freshly downloaded file.
        assert md5 == MD5_MAP[name], md5
    return path
64
+
65
+
66
class LPIPS(nn.Module):
    # Learned perceptual metric.
    def __init__(self, use_dropout=True):
        """Build the LPIPS metric: frozen VGG16 features + learned 1x1 stage weights."""
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512] # vg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One learned 1x1 conv per VGG stage, weighting the feature differences.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_pretrained()
        # The whole metric is frozen; it is only ever used as a loss.
        for param in self.parameters():
            param.requires_grad = False

    def load_pretrained(self):
        workspace = os.environ.get('WORKSPACE', '')
        # NOTE(review): the "root" passed here already ends in ".pth", so
        # get_ckpt_path stores the weights at <...>/vgg_lpips.pth/vgg.pth
        # (a directory named *.pth) — looks unintended; confirm the intended
        # cache location.
        VGG_PATH = get_ckpt_path("vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True)
        self.load_state_dict(torch.load(VGG_PATH, map_location=torch.device("cpu"), weights_only=True), strict=False)

    def forward(self, input, target):
        # Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1].
        # However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed.
        input = input * 2. - 1.
        target = target * 2. - 1.
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            # Unit-normalize each stage's features along channels, then take
            # the per-stage squared difference.
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        # Weight each stage with its learned 1x1 conv, average spatially, sum stages.
        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val
105
+
106
+
107
class ScalingLayer(nn.Module):
    """Shift/scale inputs by the channel statistics expected by the LPIPS VGG."""

    def __init__(self):
        super().__init__()
        shift = torch.Tensor(_LPIPS_MEAN)[None, :, None, None]
        scale = torch.Tensor(_LPIPS_STD)[None, :, None, None]
        # Buffers (not parameters): move with the module but are never trained.
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        return (inp - self.shift) / self.scale
115
+
116
+
117
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv (with optional dropout)."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        layers = []
        if use_dropout:
            layers.append(nn.Dropout())
        # 1x1 conv acts as a per-pixel linear map over channels; no bias, to
        # match the pretrained LPIPS weights.
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)
133
+
134
+
135
class vgg16(torch.nn.Module):
    """Frozen VGG16 feature extractor exposing five relu activation stages.

    Note: IMAGENET1K_V1 weights are always loaded; the `pretrained` flag is
    accepted for interface compatibility but not consulted (as before).
    """

    # Boundaries of the five stages within torchvision's vgg16 `features`.
    _STAGE_BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.N_slices = 5
        # Build slice1..slice5, keeping the original global child indices as
        # submodule names so pretrained state_dict keys still match.
        for stage in range(self.N_slices):
            block = torch.nn.Sequential()
            lo, hi = self._STAGE_BOUNDS[stage], self._STAGE_BOUNDS[stage + 1]
            for idx in range(lo, hi):
                block.add_module(str(idx), features[idx])
            setattr(self, "slice{}".format(stage + 1), block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for stage in range(1, self.N_slices + 1):
            h = getattr(self, "slice{}".format(stage))(h)
            activations.append(h)
        vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
        return vgg_outputs(*activations)
173
+
174
+
175
def normalize_tensor(x, eps=1e-10):
    """L2-normalize `x` along the channel dimension (dim=1)."""
    l2 = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (l2 + eps)
178
+
179
+
180
def spatial_average(x, keepdim=True):
    """Average over the spatial dimensions (H, W) of an NCHW tensor."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
modeling/modules/maskgit_vqgan.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MaskGIT-VQGAN tokenizer.
2
+
3
+ Reference:
4
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py
5
+ """
6
+
7
+ r"""MaskGIT Tokenizer based on VQGAN.
8
+
9
+ This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841]
10
+ with several modifications. The non-local layers are removed from VQGAN for
11
+ faster speed.
12
+ """
13
+
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+
21
# Conv2d variant that reproduces TensorFlow-style "SAME" padding: the output
# spatial size is ceil(input / stride), with asymmetric padding when needed.
class Conv2dSame(nn.Conv2d):
    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding along one dimension for input length i, kernel k, stride s, dilation d."""
        effective_kernel = (k - 1) * d + 1
        out_len = math.ceil(i / s)
        return max((out_len - 1) * s + effective_kernel - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[-2], x.shape[-1]

        total_h = self.calc_same_pad(i=height, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        total_w = self.calc_same_pad(i=width, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if total_h or total_w:
            # Extra pixel (odd total) goes on the right/bottom, as in TF.
            left, top = total_w // 2, total_h // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return super().forward(x)
35
+
36
+
37
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> SiLU -> Conv) twice.

    NOTE(review): when channel counts differ, `nin_shortcut` is declared with
    `out_channels_` input channels and is applied to the post-conv
    `hidden_states` rather than the original input. This mirrors the open-muse
    reference implementation byte-for-byte (needed for checkpoint
    compatibility) even though it is unusual for a residual block — confirm
    against pretrained weights before changing it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Effective output width: fall back to in_channels when not specified.
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False)

        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False)

        if self.in_channels != self.out_channels_:
            self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels_:
            # See class-level note: the shortcut conv runs on the transformed
            # features, matching the reference implementation.
            residual = self.nin_shortcut(hidden_states)

        return hidden_states + residual
75
+
76
+
77
class DownsamplingBlock(nn.Module):
    """Stack of ResnetBlocks followed by a 2x average-pool downsample.

    The deepest level (block_idx == num_resolutions - 1) skips the pooling.
    `config` must provide: hidden_channels, channel_mult, num_res_blocks,
    num_resolutions, dropout.
    """

    def __init__(self, config, block_idx: int):
        super().__init__()

        self.config = config
        self.block_idx = block_idx

        # The previous level's channel multiplier determines this level's input
        # width; level 0 reads the stem output (multiplier 1).
        in_channel_mult = (1,) + tuple(self.config.channel_mult)
        block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            # Only the first res block changes the channel count.
            block_in = block_out
        self.block = res_blocks

        self.downsample = self.block_idx != self.config.num_resolutions - 1

    def forward(self, hidden_states):
        for res_block in self.block:
            hidden_states = res_block(hidden_states)

        if self.downsample:
            hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2)

        return hidden_states
104
+
105
+
106
class UpsamplingBlock(nn.Module):
    """Stack of ResnetBlocks followed by nearest-neighbor 2x upsampling + conv.

    Level 0 (the highest resolution) skips the upsample. `config` must provide:
    hidden_channels, channel_mult, num_res_blocks, num_resolutions.
    """

    def __init__(self, config, block_idx: int):
        super().__init__()

        self.config = config
        self.block_idx = block_idx

        # Input width comes from the next-deeper level, except at the deepest
        # level, which reads the decoder middle stack directly.
        if self.block_idx == self.config.num_resolutions - 1:
            block_in = self.config.hidden_channels * self.config.channel_mult[-1]
        else:
            block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]

        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = []
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            block_in = block_out
        self.block = nn.ModuleList(res_blocks)

        self.add_upsample = self.block_idx != 0
        if self.add_upsample:
            # 3x3 conv after nearest upsampling to smooth interpolation artifacts.
            self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)

    def forward(self, hidden_states):
        for res_block in self.block:
            hidden_states = res_block(hidden_states)

        if self.add_upsample:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            hidden_states = self.upsample_conv(hidden_states)

        return hidden_states
139
+
140
+
141
class Encoder(nn.Module):
    """VQGAN encoder: stem conv -> downsampling pyramid -> res-block middle -> 1x1 projection to z_channels.

    `config` must provide: num_channels, hidden_channels, channel_mult,
    num_res_blocks, num_resolutions, dropout, z_channels.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # downsampling
        self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False)

        downsample_blocks = []
        for i_level in range(self.config.num_resolutions):
            downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level))
        self.down = nn.ModuleList(downsample_blocks)

        # middle: extra res blocks at the lowest resolution
        mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
        res_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout))
        self.mid = res_blocks

        # end: normalize + project to the latent channel count
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
        self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1)

    def forward(self, pixel_values):
        """Map pixel_values (B, num_channels, H, W) to latents (B, z_channels, h, w)."""
        # downsampling
        hidden_states = self.conv_in(pixel_values)
        for block in self.down:
            hidden_states = block(hidden_states)

        # middle
        for block in self.mid:
            hidden_states = block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
179
+
180
+
181
class Decoder(nn.Module):
    """VQGAN decoder: latent conv -> res-block middle -> upsampling pyramid -> output conv.

    `config` must provide: hidden_channels, channel_mult, num_resolutions,
    num_res_blocks, resolution, z_channels, num_channels, dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # compute in_channel_mult, block_in and curr_res at lowest res
        block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
        # Expected latent shape at the bottleneck (batch dim written as 1).
        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3)

        # middle
        res_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout))
        self.mid = res_blocks

        # upsampling
        upsample_blocks = []
        for i_level in reversed(range(self.config.num_resolutions)):
            upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level))
        self.up = nn.ModuleList(list(reversed(upsample_blocks)))  # reverse to get consistent order

        # end
        block_out = self.config.hidden_channels * self.config.channel_mult[0]
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
        self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3)

    def forward(self, hidden_states):
        """Map latents (B, z_channels, h, w) back to images (B, num_channels, H, W)."""
        # z to block_in
        hidden_states = self.conv_in(hidden_states)

        # middle
        for block in self.mid:
            hidden_states = block(hidden_states)

        # upsampling: self.up is stored shallow-to-deep, so iterate it in
        # reverse to run deepest level first.
        for block in reversed(self.up):
            hidden_states = block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
230
+
231
+
232
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    Discretization bottleneck part of the VQ-VAE.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        r"""
        Args:
            num_embeddings: number of vectors in the quantized space.
            embedding_dim: dimensionality of the tensors in the quantized space.
                Inputs to the modules must be in this format as well.
            commitment_cost: scalar which controls the weighting of the loss terms
                (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
        """
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Standard VQ-VAE uniform init, scaled by codebook size.
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, hidden_states, return_loss=False):
        """
        Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the
        closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()

        distances = self.compute_distances(hidden_states)
        # Nearest codebook entry per flattened position, as a one-hot matrix.
        min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

        # compute loss for embedding
        loss = None
        if return_loss:
            # Codebook loss + beta-weighted commitment loss (VQ-VAE eq. 3/4).
            loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
                (z_q - hidden_states.detach()) ** 2
            )
            # preserve gradients (straight-through estimator)
            z_q = hidden_states + (z_q - hidden_states).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, min_encoding_indices, loss

    def compute_distances(self, hidden_states):
        """Squared L2 distance from every flattened input vector to every codebook entry.

        Returns a (B*H*W, num_embeddings) tensor.
        """
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
        emb_weights = self.embedding.weight.t()

        inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
        codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
        # addmm computes (z^2 + e^2) - 2 * z @ e in a single fused call.
        distances = torch.addmm(
            inputs_norm_sq + codebook_t_norm_sq,
            hidden_states_flattended,
            emb_weights,
            alpha=-2.0,
        )
        return distances

    def get_codebook_entry(self, indices):
        """Look up codebook vectors and return them as (B, C, H, W) feature maps.

        Accepts (batch, num_tokens) — assumed to be a square token grid — or
        (batch, height, width) index tensors.
        """
        # indices are expected to be of shape (batch, num_tokens)
        # get quantized latent vectors
        if len(indices.shape) == 2:
            batch, num_tokens = indices.shape
            z_q = self.embedding(indices)
            # NOTE(review): assumes num_tokens is a perfect square — verify callers.
            z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
        elif len(indices.shape) == 3:
            batch, height, width = indices.shape
            indices = indices.view(batch, -1)
            z_q = self.embedding(indices)
            z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2)
        else:
            print(indices.shape)
            raise NotImplementedError
        return z_q

    # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
    def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
        """Return softmax code probabilities and (sampled or argmin) hard codes."""
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()  # (batch, height, width, channel)
        distances = self.compute_distances(hidden_states)  # (batch * height * width, num_embeddings)

        soft_code = F.softmax(-distances / temp, dim=-1)  # (batch * height * width, num_embeddings)
        if stochastic:
            code = torch.multinomial(soft_code, 1)  # (batch * height * width, 1)
        else:
            code = distances.argmin(dim=-1)  # (batch * height * width)

        code = code.reshape(hidden_states.shape[0], -1)  # (batch, height * width)
        batch, num_tokens = code.shape
        soft_code = soft_code.reshape(batch, num_tokens, -1)  # (batch, height * width, num_embeddings)
        return soft_code, code

    def get_code(self, hidden_states):
        """Nearest-codebook indices for (B, C, H, W) latents, shaped (B, H*W)."""
        # reshape z -> (batch, height, width, channel)
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
        distances = self.compute_distances(hidden_states)
        indices = torch.argmin(distances, axis=1).unsqueeze(1)
        indices = indices.reshape(hidden_states.shape[0], -1)
        return indices
modeling/modules/perceptual_loss.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Perceptual loss module using LPIPS and ConvNeXt-S."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from torchvision import models
7
+ from .lpips import LPIPS
8
+
9
+ _IMAGENET_MEAN = [0.485, 0.456, 0.406]
10
+ _IMAGENET_STD = [0.229, 0.224, 0.225]
11
+
12
+
13
class PerceptualLoss(torch.nn.Module):
    """Frozen perceptual loss combining LPIPS and/or ConvNeXt-S feature distance."""

    def __init__(self, model_name: str = "convnext_s"):
        """Initializes the PerceptualLoss class.

        Args:
            model_name: A string, the name of the perceptual loss model to use.

        Raise:
            ValueError: If the model_name does not contain "lpips" or "convnext_s".
        """
        super().__init__()
        if ("lpips" not in model_name) and (
            "convnext_s" not in model_name):
            raise ValueError(f"Unsupported Perceptual Loss model name {model_name}")
        self.lpips = None
        self.convnext = None
        self.loss_weight_lpips = None
        self.loss_weight_convnext = None

        # Parsing the model name. We support name formatted in
        # "lpips-convnext_s-{float_number}-{float_number}", where the
        # {float_number} refers to the loss weight for each component.
        # E.g., lpips-convnext_s-1.0-2.0 refers to compute the perceptual loss
        # using both the convnext_s and lpips, and average the final loss with
        # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0).
        if "lpips" in model_name:
            self.lpips = LPIPS().eval()

        if "convnext_s" in model_name:
            self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()

        if "lpips" in model_name and "convnext_s" in model_name:
            # Weights only parse when BOTH models are requested; single-model
            # names fall back to unweighted averaging below.
            loss_config = model_name.split('-')[-2:]
            self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1])
            print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")

        # ImageNet normalization constants for the ConvNeXt branch.
        self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None])
        self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None])

        # The loss networks are fixed feature extractors; never train them.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Computes the perceptual loss.

        Args:
            input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1].
            target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1].

        Returns:
            A scalar tensor, the perceptual loss.
        """
        # Always in eval mode.
        self.eval()
        loss = 0.
        num_losses = 0.
        lpips_loss = 0.
        convnext_loss = 0.
        # Computes LPIPS loss, if available.
        if self.lpips is not None:
            lpips_loss = self.lpips(input, target)
            if self.loss_weight_lpips is None:
                loss += lpips_loss
                num_losses += 1
            else:
                num_losses += self.loss_weight_lpips
                loss += self.loss_weight_lpips * lpips_loss

        if self.convnext is not None:
            # Computes ConvNeXt-s loss, if available.
            # ConvNeXt expects 224x224 ImageNet-normalized inputs.
            input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True)
            target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True)
            pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std)
            pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std)
            convnext_loss = torch.nn.functional.mse_loss(
                pred_input,
                pred_target,
                reduction="mean")

            if self.loss_weight_convnext is None:
                num_losses += 1
                loss += convnext_loss
            else:
                num_losses += self.loss_weight_convnext
                loss += self.loss_weight_convnext * convnext_loss

        # weighted avg.
        loss = loss / num_losses
        return loss
modeling/quantizer/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .quantizer import VectorQuantizer, DiagonalGaussianDistribution
2
+ from .mvq import VectorQuantizerMVQ
3
+ from .softvq import SoftVectorQuantizer
modeling/quantizer/dist.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import pytz
9
+ import torch
10
+ import torch.distributed as tdist
11
+ import torch.multiprocessing as mp
12
+
13
# Module-level distributed state, filled in by __initialize(); the defaults
# describe a single, non-distributed process.
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__rank_str_zfill = '0'  # rank pre-formatted as a zero-padded string
__initialized = False  # True only after a successful init_process_group
16
+
17
+
18
def initialized():
    """Return True once torch.distributed has been set up by __initialize()."""
    return __initialized
20
+
21
+
22
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    """Set up torch.distributed from the RANK environment variable.

    Falls back gracefully: CPU-only when CUDA is missing, single-GPU when RANK
    is not set. On success, fills the module-level rank/world-size globals.
    """
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        # Not launched by a distributed launcher: pin to one GPU and bail out.
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))

    global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    # Allocating a tiny tensor resolves the concrete CUDA device object.
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
53
+
54
+
55
def get_rank():
    """Global rank of this process (0 when not distributed)."""
    return __rank
57
+
58
+
59
def get_rank_str_zfill():
    """Global rank as a zero-padded string (width = digits of world size)."""
    return __rank_str_zfill
61
+
62
+
63
def get_local_rank():
    """Rank of this process on its own machine (0 when not distributed)."""
    return __local_rank
65
+
66
+
67
def get_world_size():
    """Total number of distributed processes (1 when not distributed)."""
    return __world_size
69
+
70
+
71
def get_device():
    """Device selected for this process ('cpu' until CUDA setup runs)."""
    return __device
73
+
74
+
75
def set_gpu_id(gpu_id: int):
    """Manually bind this process to a CUDA device and refresh __device.

    Accepts an int or a numeric string; None is a no-op.
    """
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        # Allocating a tiny tensor resolves the concrete device object.
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError
83
+
84
+
85
def is_master():
    """True on the global rank-0 process."""
    return __rank == 0
87
+
88
+
89
def is_local_master():
    """True on each machine's local rank-0 process."""
    return __local_rank == 0
91
+
92
+
93
def new_group(ranks: List[int]):
    """Create a torch.distributed group over `ranks`; None when not distributed."""
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None
97
+
98
+
99
def new_local_machine_group():
    """Return the subgroup of ranks on this machine; None when not distributed."""
    if __initialized:
        cur_subgroup, subgroups = tdist.new_subgroups()
        return cur_subgroup
    return None
104
+
105
+
106
def barrier():
    """Synchronize all ranks; no-op when not distributed."""
    if __initialized:
        tdist.barrier()
109
+
110
+
111
def allreduce(t: torch.Tensor, async_op=False):
    """All-reduce `t` in place across ranks (default sum).

    CPU tensors are staged through a CUDA copy because NCCL requires CUDA.
    Returns the async work handle when async_op=True, else None.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            # Copy the reduced result back into the caller's CPU tensor.
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None
121
+
122
+
123
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather equally-shaped `t` from every rank.

    Returns one tensor concatenated along dim 0 when cat=True, else a list
    with one tensor per rank. When not distributed, the input is returned
    (wrapped or concatenated) unchanged.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
134
+
135
+
136
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """All-gather tensors whose first (batch) dimension differs across ranks.

    Sizes are exchanged first; each tensor is padded to the max batch size so
    a plain all_gather works, then each gathered tensor is truncated back.
    Trailing dimensions must match across ranks.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        # Exchange per-rank shapes so everyone knows the true batch sizes.
        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            # Pad with uninitialized rows; they are sliced off again below.
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
161
+
162
+
163
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast `t` from src_rank to all ranks, in place.

    CPU tensors are staged through a CUDA copy (NCCL requirement).
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
171
+
172
+
173
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect one scalar per rank; return a tensor (fmt=None) or formatted strings."""
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    # Each rank fills only its own slot; the sum all-reduce merges the slots.
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]
183
+
184
+
185
def master_only(func):
    """Decorator: run `func` only on the global master; other ranks get None.

    Pass force=True at call time to run on any rank. All ranks meet at a
    barrier afterwards so the master's work is finished before anyone proceeds.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
196
+
197
+
198
def local_master_only(func):
    """Decorator: run `func` only on each machine's local master; others get None.

    Pass force=True at call time to run on any rank; a barrier follows either way.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
209
+
210
+
211
def for_visualize(func):
    """Decorator: run `func` only on the master rank, without a trailing barrier.

    Intended for visualization/logging helpers that other ranks can skip freely.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper
221
+
222
+
223
def finalize():
    """Tear down the process group if it was initialized."""
    if __initialized:
        tdist.destroy_process_group()
226
+
227
+
228
def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
    """High-level entry point: initialize distributed mode and wire up logging.

    Initializes the process group, silences print on non-master ranks, and
    tees stdout/stderr of the (local or global) master into backup files under
    `local_out_path`.
    """
    try:
        __initialize(fork=False, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError as e:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise e

    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(is_local_master())
    # Only the designated master(s) mirror their streams to backup files.
    if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
240
+
241
+
242
def _change_builtin_print(is_master):
    """Patch builtins.print: non-master ranks stay silent, master output is timestamped.

    The replacement understands three extra kwargs: force=True prints on any
    rank, clean=True skips the timestamp/caller prefix, deeper=True attributes
    the call site one frame higher.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print
    # If print is no longer the raw builtin (already patched, or wrapped by a
    # tool), skip re-patching to avoid double prefixes.
    if type(builtin_print) != type(open):
        return

    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                # Prefix with timestamp + the caller's file:line for traceability.
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
                builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)

    __builtin__.print = prt
265
+
266
+
267
class BackupStreamToFile(object):
    """Tee stdout or stderr to a backup file while still echoing to the terminal.

    Instances are meant to replace sys.stdout/sys.stderr; close() (also called
    from __del__) restores the original stream.
    """

    def __init__(self, local_output_dir, for_stdout=True):
        self.for_stdout = for_stdout
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
        existing = os.path.exists(fname)
        # Append mode: restarting a job keeps the previous log contents.
        self.file_stream = open(fname, 'a')
        if existing:
            # Visually separate runs in the appended log file.
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True

    def write(self, message):
        # Duplicate every write to both the terminal and the backup file.
        self.terminal_stream.write(message)
        self.file_stream.write(message)

    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()

    def close(self):
        """Close the backup file and restore the original stream (idempotent)."""
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()

    def __del__(self):
        self.close()
modeling/quantizer/mvq.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List, Tuple
3
+ from torch.nn import functional as F
4
+ from torch import distributed as tdist, nn as nn
5
+
6
+ from .quantizer import VectorQuantizer
7
+
8
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularizer over soft codebook assignments.

    Encourages each latent to commit to one code (low per-sample entropy)
    while spreading usage across the codebook (high average-assignment
    entropy). Returns per_sample_entropy - codebook_entropy as a scalar.
    """
    # Squared L2 distance: ||z||^2 + ||e||^2 - 2 z.e, shape (N, vocab_size).
    sq_latent = latent_embed.square().sum(dim=1, keepdim=True)
    sq_code = codebook_embed.square().sum(dim=1, keepdim=False)
    dist_sq = torch.addmm(sq_latent + sq_code, latent_embed, codebook_embed.T, beta=1, alpha=-2)

    logits = -(dist_sq.float() * inv_entropy_tau)
    prob = logits.softmax(dim=-1)
    log_prob = logits.log_softmax(dim=-1)

    # Mean entropy of each latent's assignment distribution.
    per_sample_entropy = (-prob * log_prob).sum(dim=-1).mean()

    # Entropy of the batch-averaged assignment distribution.
    avg_prob = prob.mean(dim=0)
    codebook_entropy = -(avg_prob * torch.log(avg_prob + 1e-7)).sum()

    return per_sample_entropy - codebook_entropy
22
+
23
+
24
class NormalizedEmbedding(nn.Embedding):
    """Embedding whose rows are L2-normalized at lookup time.

    The raw weight is stored unnormalized; every lookup (and
    get_norm_weight) sees unit-norm rows.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def forward(self, idx):
        unit_rows = F.normalize(self.weight, dim=1)
        return F.embedding(
            idx, unit_rows, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse,
        )

    def get_norm_weight(self):
        """The full codebook with every row scaled to unit L2 norm."""
        return F.normalize(self.weight, dim=1)
37
+
38
+
39
class ResConv(nn.Conv2d):
    """Residual convolutional mixer: out = (1 - r) * x + r * conv(x).

    A negative `quant_resi` selects a 3x3 kernel (with SAME padding via
    padding=1); otherwise a 1x1 kernel is used. The mixing ratio r is
    |quant_resi| in both cases.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 1 if quant_resi >= 0 else 3
        super().__init__(
            in_channels=embed_dim, out_channels=embed_dim,
            kernel_size=kernel, stride=1, padding=kernel // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        return (1 - ratio) * h_BChw + ratio * super().forward(h_BChw)
47
+
48
class VectorQuantizerMVQ(nn.Module):
    """Multi-codebook (product) vector quantizer.

    The feature channel dimension is split into `num_codebooks` equal chunks;
    each chunk is quantized by an independent VectorQuantizer with
    codebook_size // num_codebooks entries of size token_size // num_codebooks.

    NOTE(review): this class calls `f_to_idx`, `init_vocab` and a `codebook`
    attribute on the sub-quantizers; the visible part of quantizer.py exposes
    `embedding` and no such methods — verify the full VectorQuantizer API
    before relying on those paths.
    """

    def __init__(
        self,
        codebook_size,
        token_size,
        commitment_cost=0.25,
        use_l2_norm=False,
        # entropy_temp=0.01, # we do not use this
        clustering_vq=False,
        num_codebooks=16
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebooks = nn.ModuleList()
        for _ in range(num_codebooks):
            # Each sub-codebook covers an equal slice of codes and channels.
            codebook = VectorQuantizer(
                codebook_size=codebook_size // num_codebooks,
                token_size=token_size // num_codebooks,
                commitment_cost=commitment_cost,
                use_l2_norm=use_l2_norm,
                clustering_vq=clustering_vq,
            )
            self.codebooks.append(codebook)

    def init_vocab(self, eini: float):
        # Forward the vocabulary (re)initialization to every sub-codebook.
        for codebook in self.codebooks:
            codebook.init_vocab(eini)

    def f_to_idx(self, features):
        """Quantize channel-chunked features to indices, shape (N, num_codebooks, ...)."""
        indices = []
        chunk_size = features.shape[-1] // self.num_codebooks
        splited_features = features.split(chunk_size, dim=-1)
        for i, codebook in enumerate(self.codebooks):
            indices.append(codebook.f_to_idx(splited_features[i]))
        indices = torch.stack(indices, dim=1)
        return indices

    def idx_to_f(self, indices):
        """Look up per-codebook indices and concatenate features along the last dim."""
        assert indices.shape[1] == self.num_codebooks
        latent_features = []
        for i, codebook in enumerate(self.codebooks):
            sub_indices = indices[:, i].flatten(start_dim=1)
            latent_feature = codebook.codebook(sub_indices)
            latent_features.append(latent_feature)
        latent_features = torch.cat(latent_features, dim=-1)
        return latent_features

    def get_codebook_entry(self, indices):
        """Get codebook entries for multi-codebook indices.

        Args:
            indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W)

        Returns:
            z_quantized: Quantized features
        """
        if len(indices.shape) == 2:
            # indices shape: (N, num_codebooks)
            latent_features = []
            for i, codebook in enumerate(self.codebooks):
                sub_indices = indices[:, i]
                latent_feature = codebook.get_codebook_entry(sub_indices)
                latent_features.append(latent_feature)
            return torch.cat(latent_features, dim=-1)
        elif len(indices.shape) == 4:
            # indices shape: (B, num_codebooks, H, W)
            batch_size, _, height, width = indices.shape
            latent_features = []
            for i, codebook in enumerate(self.codebooks):
                sub_indices = indices[:, i]  # (B, H, W)
                latent_feature = codebook.get_codebook_entry(sub_indices.flatten())
                # Reshape to (B, H, W, token_size // num_codebooks)
                latent_feature = latent_feature.view(batch_size, height, width, -1)
                latent_features.append(latent_feature)
            # Concatenate along the last dimension and rearrange to (B, C, H, W)
            latent_features = torch.cat(latent_features, dim=-1)  # (B, H, W, C)
            return latent_features.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
        else:
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

    def forward(self, features):
        """Quantize (B, C, H, W) features chunk-wise; returns (z_quantized, result_dict).

        Losses in result_dict are the mean over the sub-codebooks;
        min_encoding_indices has shape (B, num_codebooks, H, W).
        """
        latent_features = []
        all_result_dicts = []
        chunk_size = features.shape[1] // self.num_codebooks
        splited_features = features.split(chunk_size, dim=1)

        for i, codebook in enumerate(self.codebooks):
            # Sub-quantizers run in float32; cast the output back afterwards.
            latent_feature, result_dict = codebook(splited_features[i].float())
            latent_features.append(latent_feature.to(features.dtype))
            all_result_dicts.append(result_dict)

        # Concatenate latent features
        z_quantized = torch.cat(latent_features, dim=1)  # Concatenate along channel dimension

        # Calculate global losses
        global_quantizer_loss = sum(rd['quantizer_loss'] for rd in all_result_dicts) / self.num_codebooks
        global_commitment_loss = sum(rd['commitment_loss'] for rd in all_result_dicts) / self.num_codebooks
        global_codebook_loss = sum(rd['codebook_loss'] for rd in all_result_dicts) / self.num_codebooks

        # Collect all min_encoding_indices
        # Each codebook returns indices of shape (B, H, W)
        # Stack them to get shape (B, num_codebooks, H, W)
        all_indices = torch.stack([rd['min_encoding_indices'] for rd in all_result_dicts], dim=1)

        result_dict = dict(
            quantizer_loss=global_quantizer_loss,
            commitment_loss=global_commitment_loss,
            codebook_loss=global_codebook_loss,
            min_encoding_indices=all_indices
        )

        return z_quantized, result_dict
modeling/quantizer/quantizer.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Vector quantizer.
2
+
3
+ Reference:
4
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
5
+ https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
6
+ https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
7
+ https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py
8
+ """
9
+ from typing import Mapping, Text, Tuple
10
+
11
+ import torch
12
+ from einops import rearrange
13
+ from accelerate.utils.operations import gather
14
+ from torch.cuda.amp import autocast
15
+
16
class VectorQuantizer(torch.nn.Module):
    """Nearest-neighbour vector quantizer (VQ-VAE style).

    Supports optional L2-normalized lookup and an optional CVQ-VAE-style
    "clustering" update that, during training, pulls rarely used codebook
    entries toward encoder features to combat codebook collapse.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False
                 ):
        """
        Args:
            codebook_size: Number of codebook entries.
            token_size: Dimensionality of each codebook entry.
            commitment_cost: Weight of the commitment loss term.
            use_l2_norm: If True, match and quantize on L2-normalized vectors.
            clustering_vq: If True, enable the usage-based codebook
                reactivation update (training mode only).
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost

        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm

        self.clustering_vq = clustering_vq
        if clustering_vq:
            # EMA decay for per-entry usage statistics (embed_prob).
            self.decay = 0.99
            self.register_buffer("embed_prob", torch.zeros(self.codebook_size))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Quantize `z` of shape (B, C, H, W).

        Returns:
            z_quantized: (B, C, H, W) with straight-through gradients.
            result_dict: losses plus hard indices of shape (B, H, W).
        """
        z = z.float()
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = rearrange(z, 'b h w c -> (b h w) c')
        # Keep un-normalized features for the clustering update below.
        unnormed_z_flattened = z_flattened

        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight
        # Squared Euclidean distances ||z||^2 + ||e||^2 - 2 z.e, shape (BHW, K).
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)

        min_encoding_indices = torch.argmin(d, dim=1) # num_ele
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z = torch.nn.functional.normalize(z, dim=-1)

        # compute loss for embedding
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2)
        codebook_loss = torch.mean((z_quantized - z.detach()) **2)

        if self.clustering_vq and self.training:
            with torch.no_grad():
                # Gather distance matrix from all GPUs.
                encoding_indices = gather(min_encoding_indices)
                if len(min_encoding_indices.shape) != 1:
                    raise ValueError(f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}")
                # Compute and update the usage of each entry in the codebook.
                encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device)
                encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
                avg_probs = torch.mean(encodings, dim=0)
                self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1-self.decay)
                # Closest sampling to update the codebook.
                all_d = gather(d)
                all_unnormed_z_flattened = gather(unnormed_z_flattened).detach()
                if all_d.shape[0] != all_unnormed_z_flattened.shape[0]:
                    raise ValueError(
                        "all_d and all_unnormed_z_flattened have different length" +
                        f"{all_d.shape}, {all_unnormed_z_flattened.shape}")
                # For each codebook entry, the feature closest to it.
                indices = torch.argmin(all_d, dim=0)
                random_feat = all_unnormed_z_flattened[indices]
                # Decay parameter based on the average usage: rarely used
                # entries (small embed_prob) are pulled strongly toward
                # encoder features; frequently used entries barely move.
                decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) /
                                  (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size)
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay

        loss = commitment_loss + codebook_loss

        # preserve gradients (straight-through estimator)
        z_quantized = z + (z_quantized - z).detach()

        # reshape back to match original input shape
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    @autocast(enabled=False)
    def get_codebook_entry(self, indices):
        """Look up codebook vectors for `indices`.

        1-D input is treated as hard entry ids; 2-D input is mixed via matmul
        against the codebook.
        NOTE(review): the 2-D path runs after `.long()`, which would truncate
        fractional (soft one-hot) weights to integers -- confirm callers only
        pass hard integer indices here.
        """
        indices = indices.long()
        if len(indices.shape) == 1:
            z_quantized = self.embedding(indices)
        elif len(indices.shape) == 2:
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
119
+
120
+
121
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by stacked mean/logvar.

    `parameters` has shape [B, 2*C, *]: the first C channels are the mean and
    the last C the log-variance. With `deterministic=True` the std is forced
    to zero, so sampling collapses onto the mean.
    """

    @autocast(enabled=False)
    def __init__(self, parameters, deterministic=False):
        """Split `parameters` into mean/logvar and precompute std and var.

        Args:
            parameters (torch.Tensor): [B, 2*C, *] tensor of mean and logvar.
            deterministic (bool): If True, force std = var = 0 so that
                sampling returns the mean exactly.
        """
        self.parameters = parameters
        self.deterministic = deterministic
        mean, logvar = torch.chunk(parameters.float(), 2, dim=1)
        self.mean = mean
        # Clamp logvar for numerical stability before exponentiating.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            zeros = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.var = zeros
            self.std = zeros

    @autocast(enabled=False)
    def sample(self):
        """Draw a reparameterized sample: mean + std * eps."""
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean.float() + self.std.float() * eps

    @autocast(enabled=False)
    def mode(self):
        """Return the distribution mode (the mean)."""
        return self.mean

    @autocast(enabled=False)
    def kl(self):
        """KL divergence to a standard normal, summed over dims 1 and 2."""
        if self.deterministic:
            return torch.Tensor([0.])
        mean_sq = torch.pow(self.mean.float(), 2)
        return 0.5 * torch.sum(mean_sq + self.var.float() - 1.0 - self.logvar.float(),
                               dim=[1, 2])
modeling/quantizer/softvq.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Mapping, Text, Tuple
5
+ from einops import rearrange
6
+ from torch.cuda.amp import autocast
7
+
8
+
9
class SoftVectorQuantizer(torch.nn.Module):
    """Soft (softmax-weighted) vector quantizer with multiple codebooks.

    Instead of hard nearest-neighbour assignment, each token is quantized as a
    temperature-scaled softmax mixture over codebook entries, so the mapping
    stays differentiable without a straight-through estimator. The token
    sequence is split into `num_codebooks` contiguous segments, each handled
    by its own codebook.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False,
                 entropy_loss_ratio: float = 0.01,
                 tau: float = 0.07,
                 num_codebooks: int = 1,
                 show_usage: bool = False
                 ):
        """
        Args:
            codebook_size: Entries per codebook.
            token_size: Dimensionality of each entry.
            commitment_cost: Kept for API parity with VectorQuantizer
                (unused by the soft path).
            use_l2_norm: If True, L2-normalize features and codebooks.
            clustering_vq: Kept for API parity (unused here).
            entropy_loss_ratio: Weight of the entropy regularizer.
            tau: Softmax temperature for the soft assignment.
            num_codebooks: Number of independent codebooks.
            show_usage: If True, track recent index usage in a ring buffer.
        """
        super().__init__()
        # Map new parameter names to internal names for compatibility
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.clustering_vq = clustering_vq

        # Keep soft quantization specific parameters
        self.num_codebooks = num_codebooks
        self.n_e = codebook_size
        self.e_dim = token_size
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = use_l2_norm
        self.show_usage = show_usage
        self.tau = tau

        # Single embedding tensor holding all codebooks: (num_codebooks, K, D).
        self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
        self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        if self.l2_norm:
            self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)

        if self.show_usage:
            # Ring buffer of the most recent 65536 assigned indices per codebook.
            self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Softly quantize `z` of shape (B, C, H, W); returns (z_q, result_dict)."""
        z = z.float()
        original_shape = z.shape

        # Handle input reshaping to match VectorQuantizer format:
        # (B, C, H, W) -> (B, H*W, C).
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z = z.view(z.size(0), -1, z.size(-1))

        batch_size, seq_length, _ = z.shape

        # Ensure sequence length is divisible by number of codebooks
        assert seq_length % self.num_codebooks == 0, \
            f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"

        segment_length = seq_length // self.num_codebooks
        z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)

        # Apply L2 norm if needed
        embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
        if self.l2_norm:
            z_segments = F.normalize(z_segments, p=2, dim=-1)

        # (num_codebooks, B*segment, D) for per-codebook matmuls.
        z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)

        # Inner-product logits against a detached codebook (gradient reaches
        # the codebook only through the soft mixture below).
        logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())

        # Calculate probabilities (soft quantization)
        probs = F.softmax(logits / self.tau, dim=-1)

        # Soft quantize
        z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)

        # Reshape back
        z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()

        # Calculate cosine similarity (diagnostic only; not returned).
        with torch.no_grad():
            zq_z_cos = F.cosine_similarity(
                z_segments.view(-1, self.e_dim),
                z_q.view(-1, self.e_dim),
                dim=-1
            ).mean()

        # Get indices for usage tracking
        indices = torch.argmax(probs, dim=-1)  # (num_codebooks, batch_size * segment_length)
        indices = indices.transpose(0, 1).contiguous()  # (batch_size * segment_length, num_codebooks)

        # Track codebook usage
        if self.show_usage and self.training:
            for k in range(self.num_codebooks):
                cur_len = indices.size(0)
                # Shift the ring buffer left and append the newest indices.
                self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
                self.codebook_used[k, -cur_len:].copy_(indices[:, k])

        # Calculate losses if training
        if self.training:
            # Soft quantization doesn't have traditional commitment/codebook loss
            # Map entropy loss to quantizer_loss for compatibility
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
            quantizer_loss = entropy_loss
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        else:
            quantizer_loss = torch.tensor(0.0, device=z.device)
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)

        # Calculate codebook usage (diagnostic only; not returned).
        codebook_usage = torch.tensor([
            len(torch.unique(self.codebook_used[k])) / self.n_e
            for k in range(self.num_codebooks)
        ]).mean() if self.show_usage else 0

        z_q = z_q.view(batch_size, -1, self.e_dim)

        # Reshape back to original input shape to match VectorQuantizer
        z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
        z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        # Calculate average probabilities (diagnostics only; not returned).
        avg_probs = torch.mean(torch.mean(probs, dim=-1))
        max_probs = torch.mean(torch.max(probs, dim=-1)[0])

        # Return format matching VectorQuantizer
        # NOTE(review): `indices` is laid out as (B*segment, num_codebooks);
        # viewing it as (B, num_codebooks, segment) only preserves element
        # order when num_codebooks == 1 -- verify before relying on these
        # indices for num_codebooks > 1.
        result_dict = dict(
            quantizer_loss=quantizer_loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Added for compatibility with VectorQuantizer API.

        NOTE(review): always reads codebook 0; confirm this is only used in
        the single-codebook configuration.
        """
        if len(indices.shape) == 1:
            # For single codebook case
            z_quantized = self.embedding[0][indices]
        elif len(indices.shape) == 2:
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
+
156
+
157
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
158
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
159
+ flat_affinity /= temperature
160
+ probs = F.softmax(flat_affinity, dim=-1)
161
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
162
+ if loss_type == "softmax":
163
+ target_probs = probs
164
+ else:
165
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
166
+ avg_probs = torch.mean(target_probs, dim=0)
167
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
168
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
169
+ loss = sample_entropy - avg_entropy
170
+ return loss
modeling/vibetoken_model.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VibeToken model definition."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+
7
+ from modeling.modules.base_model import BaseModel
8
+ from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder
9
+ from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer
10
+ from modeling.modules.maskgit_vqgan import Encoder as Pixel_Eecoder
11
+ from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
12
+ from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
13
+ import json
14
+ from omegaconf import OmegaConf
15
+ from pathlib import Path
16
+
17
+ from huggingface_hub import PyTorchModelHubMixin
18
+
19
+
20
class PretrainedTokenizer(nn.Module):
    """Frozen MaskGiT-VQGAN pixel tokenizer.

    Builds the fixed MaskGiT encoder/decoder/quantizer configuration, loads a
    pretrained checkpoint, and freezes everything; only inference
    (encode/decode) is supported.
    """

    def __init__(self, pretrained_weight):
        """
        Args:
            pretrained_weight: Path to a MaskGiT-VQGAN state_dict checkpoint.
        """
        super().__init__()
        # Fixed MaskGiT-VQGAN architecture (256x256 input, 256-d latents).
        conf = OmegaConf.create(
            {"channel_mult": [1, 1, 2, 2, 4],
             "num_resolutions": 5,
             "dropout": 0.0,
             "hidden_channels": 128,
             "num_channels": 3,
             "num_res_blocks": 2,
             "resolution": 256,
             "z_channels": 256})
        self.encoder = Pixel_Eecoder(conf)
        self.decoder = Pixel_Decoder(conf)
        self.quantize = Pixel_Quantizer(
            num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
        # Load pretrained weights
        self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)

        # Freeze: eval mode and no gradients for any parameter.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode(self, x):
        """Encode images `x` into discrete codebook indices."""
        hidden_states = self.encoder(x)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
        return codebook_indices.detach()

    @torch.no_grad()
    def decode(self, codes):
        """Decode codebook indices back to images, clamped to [0, 1]."""
        quantized_states = self.quantize.get_codebook_entry(codes)
        rec_images = self.decoder(quantized_states)
        rec_images = torch.clamp(rec_images, 0.0, 1.0)
        return rec_images.detach()

    @torch.no_grad()
    def decode_tokens(self, codes):
        """Alias of `decode` for API compatibility."""
        return self.decode(codes)
59
+
60
+
61
class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]):
    """VibeToken image tokenizer: ViT-style encoder/decoder around a quantizer.

    Supports four quantization modes ("vq", "vae", "softvq", "mvq"). In
    stage-2 training (`finetune_decoder=True`) the encoder, quantizer and
    latent tokens are frozen, and decoding goes through a frozen
    MaskGiT-VQGAN pixel quantizer/decoder pair.
    """

    def __init__(self, config):
        """
        Args:
            config: OmegaConf config (or plain dict) with a
                `model.vq_model` section describing the quantizer.
        """
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        # This should be False for stage1 and True for stage2.
        self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)

        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
        if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]:
            raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")

        # NOTE(review): message typo ("supprot") left as-is here; fixing it
        # would change the raised string.
        if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]:
            raise ValueError("Only supprot finetune_decoder with vq quantization for now.")

        self.encoder = ResolutionEncoder(config)
        self.decoder = ResolutionDecoder(config)

        # Learnable latent query tokens prepended/attended by the encoder.
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width))

        # Initialize weights before constructing the quantizer so the
        # quantizer's own initialization is not overwritten.
        self.apply(self._init_weights)

        if self.quantize_mode == "vq":
            self.quantize = VectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,)
        elif self.quantize_mode == "vae":
            # Stored as a class; instantiated per-forward on encoder output.
            self.quantize = DiagonalGaussianDistribution
        elif self.quantize_mode == "mvq":
            self.quantize = VectorQuantizerMVQ(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        elif self.quantize_mode == "softvq":
            self.quantize = SoftVectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        else:
            raise NotImplementedError

        if self.finetune_decoder:
            # Freeze encoder/quantizer/latent tokens
            self.latent_tokens.requires_grad_(False)
            self.encoder.eval()
            self.encoder.requires_grad_(False)
            self.quantize.eval()
            self.quantize.requires_grad_(False)

            # Include MaskGiT-VQGAN's quantizer and decoder
            self.pixel_quantize = Pixel_Quantizer(
                num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
            self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
                {"channel_mult": [1, 1, 2, 2, 4],
                 "num_resolutions": 5,
                 "dropout": 0.0,
                 "hidden_channels": 128,
                 "num_channels": 3,
                 "num_res_blocks": 2,
                 "resolution": 256,
                 "z_channels": 256}))

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Assume 'self.config' is your DictConfig object
        # Convert to a regular dictionary
        dict_config = OmegaConf.to_container(self.config)
        # Save as JSON
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)

    def _init_weights(self, module):
        """ Initialize the weights.
        :param:
            module -> torch.nn.Module: module to initialize
        """
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None):
        """Encode images `x` into (z_quantized, result_dict).

        In finetune-decoder (stage-2) mode the encoder/quantizer run frozen
        under no_grad and the quantizer losses are zeroed out. `length`, when
        given, truncates the latent token sequence (and drops the mask).
        """
        if self.finetune_decoder:
            with torch.no_grad():
                self.encoder.eval()
                self.quantize.eval()
                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
                z_quantized, result_dict = self.quantize(z)
                # Zero the (frozen) quantizer losses so they don't contribute
                # to the stage-2 objective.
                result_dict["quantizer_loss"] *= 0
                result_dict["commitment_loss"] *= 0
                result_dict["codebook_loss"] *= 0
        else:
            if length is not None:
                attention_mask = None
                # NOTE(review): slice is [:length+1] -- presumably to keep an
                # extra token alongside `length` latents; confirm intent.
                latent_tokens = self.latent_tokens[:length+1]
            else:
                latent_tokens = self.latent_tokens
            z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
            if self.quantize_mode in ["vq", "mvq", "softvq"]:
                z_quantized, result_dict = self.quantize(z)
            elif self.quantize_mode == "vae":
                posteriors = self.quantize(z)
                z_quantized = posteriors.sample()
                result_dict = posteriors

        return z_quantized, result_dict

    def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode quantized latents to pixels (or pixel-codebook logits).

        In finetune-decoder mode the decoder output is interpreted as logits
        over the frozen MaskGiT pixel codebook, softly mixed into codebook
        embeddings, and decoded by the frozen pixel decoder.
        """
        decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        if self.finetune_decoder:
            quantized_states = torch.einsum(
                'nchw,cd->ndhw', decoded.softmax(1),
                self.pixel_quantize.embedding.weight)
            decoded = self.pixel_decoder(quantized_states)
        return decoded

    def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode discrete token ids (or VAE latents) to images.

        Token layout depends on the quantize mode: "vq"/"softvq" expect
        (B, N) ids, "mvq" expects stacked per-codebook indices, and "vae"
        passes continuous latents straight through.
        """
        if self.quantize_mode in ["vq", "softvq"]:
            tokens = tokens.squeeze(1)
            batch, seq_len = tokens.shape # B x N
            z_quantized = self.quantize.get_codebook_entry(
                tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
            z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
        elif self.quantize_mode == "mvq":
            z_quantized = self.quantize.get_codebook_entry(tokens)
        elif self.quantize_mode == "vae":
            z_quantized = tokens
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        return decoded

    def forward(self, x, key_attention_mask=None, height=None, width=None, train=True):
        """Full autoencode pass: encode, quantize, decode.

        Returns (decoded, result_dict); `height`/`width` default to the
        input's spatial size.
        """
        if height is None:
            batch_size, channels, height, width = x.shape
        z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train)
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train)
        return decoded, result_dict
reconstruct.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Simple reconstruction script for VibeToken.
3
+
4
+ Usage:
5
+ # Auto mode (recommended) - automatically determines optimal settings
6
+ python reconstruct.py --auto \
7
+ --config configs/vibetoken_ll.yaml \
8
+ --checkpoint /path/to/checkpoint.bin \
9
+ --image assets/example_1.jpg \
10
+ --output assets/reconstructed.png
11
+
12
+ # Manual mode - specify all parameters
13
+ python reconstruct.py \
14
+ --config configs/vibetoken_ll.yaml \
15
+ --checkpoint /path/to/checkpoint.bin \
16
+ --image assets/example_1.jpg \
17
+ --output assets/reconstructed.png \
18
+ --input_height 512 --input_width 512 \
19
+ --encoder_patch_size 16,32 \
20
+ --decoder_patch_size 16
21
+ """
22
+
23
+ import argparse
24
+ from PIL import Image
25
+ from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
26
+
27
+
28
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Accepts None (passed through), a single integer string ('16' -> 16), or a
    comma-separated pair ('16,32' -> (16, 32)). Whitespace around components
    is tolerated (e.g. '16, 32'), generalizing the strict original parsing.

    Args:
        value: Raw CLI string, or None when the flag was not supplied.

    Returns:
        None, an int, or a (height, width) tuple of ints.

    Raises:
        ValueError: If a component is not a valid integer.
    """
    if value is None:
        return None
    if ',' in value:
        parts = value.split(',')
        return (int(parts[0].strip()), int(parts[1].strip()))
    return int(value.strip())
36
+
37
+
38
def main():
    """CLI entry point: encode an image with VibeToken and decode it back.

    Two modes: --auto derives input resolution and patch sizes automatically;
    manual mode takes explicit resize/patch-size flags. Loads the tokenizer
    from a config + checkpoint, round-trips the image, and saves the result.
    """
    parser = argparse.ArgumentParser(description="VibeToken image reconstruction")
    parser.add_argument("--config", type=str, default="configs/vibetoken_ll.yaml",
                        help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    # NOTE(review): repo assets include example_1.png; confirm the .jpg
    # default path actually exists.
    parser.add_argument("--image", type=str, default="assets/example_1.jpg",
                        help="Path to input image")
    parser.add_argument("--output", type=str, default="./assets/reconstructed.png",
                        help="Path to output image")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal input resolution and patch sizes")

    # Input resolution (optional - resize input before encoding)
    parser.add_argument("--input_height", type=int, default=None,
                        help="Resize input to this height before encoding (default: original)")
    parser.add_argument("--input_width", type=int, default=None,
                        help="Resize input to this width before encoding (default: original)")

    # Output resolution (optional - decode to this size)
    parser.add_argument("--output_height", type=int, default=None,
                        help="Decode to this height (default: same as input)")
    parser.add_argument("--output_width", type=int, default=None,
                        help="Decode to this width (default: same as input)")

    # Patch sizes (optional) - supports single int or tuple like "16,32"
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")

    args = parser.parse_args()

    # Load tokenizer
    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        args.config,
        args.checkpoint,
        device=args.device,
    )

    # Load image
    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size # (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE - use centralized auto_preprocess_image
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        input_width, input_height = info["cropped_size"]
        # Output matches the (cropped) input; one patch size for both ends.
        output_width, output_height = input_width, input_height
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        print("=================\n")

    else:
        # MANUAL MODE
        # Parse patch sizes
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Resize input if specified
        if args.input_width or args.input_height:
            input_width = args.input_width or original_size[0]
            input_height = args.input_height or original_size[1]
            print(f"Resizing input to {input_width}x{input_height}")
            image = image.resize((input_width, input_height), Image.LANCZOS)

        # Always center crop to ensure dimensions divisible by 32
        image = center_crop_to_multiple(image, multiple=32)
        input_width, input_height = image.size
        if (input_width, input_height) != original_size:
            print(f"Center cropped to {input_width}x{input_height} (divisible by 32)")

        # Determine output size
        output_height = args.output_height or input_height
        output_width = args.output_width or input_width

    # Encode image to tokens (patch_size may be None -> tokenizer default).
    print("Encoding image to tokens...")
    if encoder_patch_size:
        print(f" Using encoder patch size: {encoder_patch_size}")
    tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
    print(f"Token shape: {tokens.shape}")

    # Decode back to image
    print(f"Decoding to {output_width}x{output_height}...")
    if decoder_patch_size:
        print(f" Using decoder patch size: {decoder_patch_size}")
    reconstructed = tokenizer.decode(
        tokens,
        height=output_height,
        width=output_width,
        patch_size=decoder_patch_size
    )
    print(f"Reconstructed shape: {reconstructed.shape}")

    # Convert tensor to PIL and save
    output_images = tokenizer.to_pil(reconstructed)
    output_images[0].save(args.output)
    print(f"Saved reconstructed image to {args.output}")
145
+
146
+
147
# Script entry point: run the reconstruction CLI when invoked directly.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ torch>=2.0.0
3
+ torchvision
4
+ einops>=0.6.0
5
+ omegaconf>=2.3.0
6
+ pillow>=9.0.0
7
+ numpy>=1.20.0
8
+ huggingface_hub>=0.16.0
9
+ accelerate
10
+ wandb
11
+ webdataset
12
+ timm
13
+ open_clip_torch
14
+ transformers
15
+ scipy
16
+ torch-fidelity
17
+ torchinfo
18
+ termcolor
19
+ iopath
20
+ opencv-python
21
+ diffusers
22
+ gdown
23
+ tqdm
24
+ requests
25
+ datasets
26
+ gradio>=4.0.0
scripts/train_vibetoken.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training script for VibeToken.
2
+
3
+ Reference:
4
+ https://github.com/huggingface/open-muse
5
+ """
6
+ import math
7
+ import os
8
+ import sys
9
+ from pathlib import Path
10
+ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
11
+ sys.path.append(parent_dir)
12
+
13
+ from accelerate.utils import set_seed
14
+ from accelerate import Accelerator
15
+
16
+ import torch
17
+ import wandb
18
+ from omegaconf import OmegaConf
19
+ from utils.logger import setup_logger
20
+
21
+ from utils.train_utils import (
22
+ get_config, create_pretrained_tokenizer,
23
+ create_model_and_loss_module,
24
+ create_optimizer, create_lr_scheduler, create_dataloader,
25
+ create_evaluator, auto_resume, save_checkpoint,
26
+ train_one_epoch)
27
+
28
+
29
def main():
    """Entry point for VibeToken training.

    Reads an OmegaConf config via ``utils.train_utils.get_config()``, sets up
    Accelerate (mixed precision, gradient accumulation, wandb/tensorboard
    tracking), builds the model/loss/optimizer/scheduler/dataloader stack,
    auto-resumes from the latest checkpoint, and trains until
    ``config.training.max_train_steps``.
    """
    # Redirect the torch hub cache into the workspace when one is configured.
    workspace = os.environ.get('WORKSPACE', '')
    if workspace:
        torch.hub.set_dir(workspace + "/models/hub")

    config = get_config()
    # Enable TF32 on Ampere GPUs.
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    output_dir = config.experiment.output_dir
    os.makedirs(output_dir, exist_ok=True)
    config.experiment.logging_dir = os.path.join(output_dir, "logs")

    # Whether logging to Wandb or Tensorboard.
    tracker = "tensorboard"
    if config.training.enable_wandb:
        tracker = "wandb"

    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=tracker,
        project_dir=config.experiment.logging_dir,
        split_batches=False,
    )

    # One log file per process so multi-GPU output does not interleave.
    logger = setup_logger(name="VibeToken", log_level="INFO",
                          output_file=f"{output_dir}/log{accelerator.process_index}.txt")

    if accelerator.is_main_process:
        if config.training.enable_wandb:
            # Optional `training.wandb` sub-config overrides the experiment
            # defaults for project/entity/name/tags/notes/resume id.
            wandb_config = config.training.get("wandb", {})
            wandb_project = wandb_config.get("project", config.experiment.project)
            wandb_entity = wandb_config.get("entity", None)
            wandb_name = wandb_config.get("name", config.experiment.name)
            wandb_tags = list(wandb_config.get("tags", []))
            wandb_notes = wandb_config.get("notes", None)
            wandb_resume_id = wandb_config.get("resume_id", None)

            wandb_init_kwargs = {
                "wandb": {
                    "name": wandb_name,
                    "dir": output_dir,
                    "resume": "allow",
                }
            }
            if wandb_entity:
                wandb_init_kwargs["wandb"]["entity"] = wandb_entity
            if wandb_tags:
                wandb_init_kwargs["wandb"]["tags"] = wandb_tags
            if wandb_notes:
                wandb_init_kwargs["wandb"]["notes"] = wandb_notes
            if wandb_resume_id:
                wandb_init_kwargs["wandb"]["id"] = wandb_resume_id

            accelerator.init_trackers(
                project_name=wandb_project,
                config=OmegaConf.to_container(config, resolve=True),
                init_kwargs=wandb_init_kwargs,
            )
            logger.info(f"WandB initialized - Project: {wandb_project}, Name: {wandb_name}")
        else:
            accelerator.init_trackers(config.experiment.name)

        # Persist the resolved config next to the checkpoints for reproducibility.
        config_path = Path(output_dir) / "config.yaml"
        logger.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)
        logger.info(f"Config:\n{OmegaConf.to_yaml(config)}")

    # If passed along, set the training seed now.
    if config.training.seed is not None:
        set_seed(config.training.seed, device_specific=True)

    accelerator.wait_for_everyone()

    # Create pretrained tokenizer in a synchronized manner
    # (legacy VQ models distill against a frozen pretrained tokenizer).
    if config.model.vq_model.is_legacy:
        if accelerator.is_main_process:
            logger.info("Creating pretrained tokenizer on main process...")
        accelerator.wait_for_everyone()
        pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            logger.info("Pretrained tokenizer creation completed.")
    else:
        pretrained_tokenizer = None

    if accelerator.is_main_process:
        logger.info("Creating model and loss module...")
    accelerator.wait_for_everyone()

    model, ema_model, loss_module = create_model_and_loss_module(
        config, logger, accelerator, model_type="vibetoken")

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Model creation completed.")

    optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module, model_type="vibetoken")

    lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler(
        config, logger, accelerator, optimizer, discriminator_optimizer)

    if accelerator.is_main_process:
        logger.info("Creating dataloaders...")
    train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator)
    accelerator.wait_for_everyone()

    # Set up evaluator.
    if accelerator.is_main_process:
        logger.info("Setting up evaluator...")
    evaluator = create_evaluator(config, logger, accelerator)

    # Prepare everything with accelerator.
    logger.info("Preparing model, optimizer and dataloaders")
    # The dataloader are already aware of distributed training, so we don't need to prepare them.
    if config.model.vq_model.is_legacy:
        if config.model.vq_model.finetune_decoder:
            # Discriminator/GAN branch active: its optimizer and scheduler must
            # be wrapped too so grads/LR steps happen under accelerate.
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
                model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
            )
        else:
            model, optimizer, lr_scheduler = accelerator.prepare(
                model, optimizer, lr_scheduler
            )
    else:
        model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
        )

    if config.training.use_ema:
        ema_model.to(accelerator.device)

    # Derive the epoch count from example budget / batch size / accumulation.
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    num_batches = math.ceil(
        config.experiment.max_train_examples / total_batch_size_without_accum)
    num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)

    # Start training.
    logger.info("***** Running training *****")
    logger.info(f" Num training steps = {config.training.max_train_steps}")
    logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
    logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}")
    logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {(
        config.training.per_gpu_batch_size *
        accelerator.num_processes *
        config.training.gradient_accumulation_steps)}""")
    global_step = 0
    first_epoch = 0

    # Resume from the most recent checkpoint in output_dir, if any.
    global_step, first_epoch = auto_resume(
        config, logger, accelerator, ema_model, num_update_steps_per_epoch,
        strict=True)

    for current_epoch in range(first_epoch, num_train_epochs):
        accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.")
        global_step = train_one_epoch(config, logger, accelerator,
                                      model, ema_model, loss_module,
                                      optimizer, discriminator_optimizer,
                                      lr_scheduler, discriminator_lr_scheduler,
                                      train_dataloader, eval_dataloader,
                                      evaluator,
                                      global_step,
                                      pretrained_tokenizer=pretrained_tokenizer,
                                      model_type="vibetoken")
        # Stop training if max steps is reached.
        if global_step >= config.training.max_train_steps:
            accelerator.print(
                f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
            )
            break

    accelerator.wait_for_everyone()
    # Save checkpoint at the end of training.
    save_checkpoint(model, output_dir, accelerator, global_step, logger=logger)
    # Save the final trained checkpoint
    if accelerator.is_main_process:
        model = accelerator.unwrap_model(model)
        if config.training.use_ema:
            # Bake the EMA weights into the exported model.
            ema_model.copy_to(model.parameters())
        model.save_pretrained_weight(output_dir)

    if accelerator.is_main_process and config.training.enable_wandb:
        wandb.finish()
        logger.info("WandB run finished")
    accelerator.end_training()


if __name__ == "__main__":
    main()
setup.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Data preparation script for VibeToken training.
# Set DATA_DIR to control where datasets are stored (defaults to ./data).
#
# Usage:
#   export DATA_DIR=/mnt/fastssd/datasets   # optional, defaults to ./data
#   bash setup.sh

# Fail fast: abort on any command error, unset variable, or pipeline failure,
# so a failed download never silently feeds an empty dir into the converter.
set -euo pipefail

DATA_DIR="${DATA_DIR:-./data}"

echo "Using DATA_DIR=${DATA_DIR}"
mkdir -p "${DATA_DIR}"

# Download ImageNet-1k via HuggingFace (hf_transfer accelerates large downloads)
export HF_HUB_ENABLE_HF_TRANSFER=1
huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"

# Convert to WebDataset format
python data/convert_imagenet_to_wds.py \
    --input_dir "${DATA_DIR}/imagenet-1k" \
    --output_dir "${DATA_DIR}/imagenet_wds"
train_tokenvibe.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch training with 8 processes on a single machine (the original comment
# claimed "2 nodes", but --num_machines=1 below configures one node).
# For multi-node runs, override RANK / MASTER_ADDR / MASTER_PORT per node.
set -euo pipefail

# Bug fix: default the machine rank to 0, not 1 — with --num_machines=1 a
# default rank of 1 makes the lone node wait forever for a nonexistent rank 0.
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-9871}

# NOTE(review): scripts/train_tokenvibe.py is not in the visible repo listing;
# confirm it is not a typo for scripts/train_vibetoken.py.
accelerate launch \
    --num_machines=1 \
    --num_processes=8 \
    --machine_rank=$NODE_RANK \
    --main_process_ip=$MASTER_ADDR \
    --main_process_port=$MASTER_PORT \
    --same_network \
    scripts/train_tokenvibe.py \
    config=configs/training/VibeToken_small.yaml
train_vibetoken.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch VibeToken training with 8 processes on a single machine (the original
# comment claimed "2 nodes", but --num_machines=1 below configures one node).
# For multi-node runs, override RANK / MASTER_ADDR / MASTER_PORT per node.
set -euo pipefail

# Bug fix: default the machine rank to 0, not 1 — with --num_machines=1 a
# default rank of 1 makes the lone node wait forever for a nonexistent rank 0.
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-9871}

accelerate launch \
    --num_machines=1 \
    --num_processes=8 \
    --machine_rank=$NODE_RANK \
    --main_process_ip=$MASTER_ADDR \
    --main_process_port=$MASTER_PORT \
    --same_network \
    scripts/train_vibetoken.py \
    config=configs/training/VibeToken_small.yaml
utils/__init__.py ADDED
File without changes
utils/logger.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Util functions supporting logging to terminal and files."""
2
+
3
+ import functools
4
+ import sys
5
+ from accelerate.logging import MultiProcessAdapter
6
+ import logging
7
+ from termcolor import colored
8
+
9
+ from iopath.common.file_io import PathManager as PathManagerClass
10
+
11
+ __all__ = ["setup_logger", "PathManager"]
12
+
13
+ PathManager = PathManagerClass()
14
+
15
+
16
class _ColorfulFormatter(logging.Formatter):
    """Formatter that abbreviates the root logger name and prepends a colored
    severity tag to WARNING/ERROR/CRITICAL records."""

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", self._root_name)
        if self._abbrev_name:
            self._abbrev_name += "."
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        # Shorten "<root>." occurrences in the record's logger name.
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        message = super().formatMessage(record)
        if record.levelno == logging.WARNING:
            tag = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            # Non-warning/error levels keep the plain formatted message.
            return message
        return tag + " " + message
34
+
35
+
36
@functools.lru_cache()
def setup_logger(name="TiTok", log_level: str = None, color=True, use_accelerate=True,
                 output_file=None):
    """Create (and cache per argument combination) a logger writing to stdout
    and, optionally, a file.

    Args:
        name: Logger name; also used as the colored prefix root.
        log_level: Level name such as "INFO"; DEBUG when None.
        color: Colorize terminal output via termcolor when True.
        use_accelerate: Wrap the logger in accelerate's MultiProcessAdapter so
            records are emitted on the main process only by default.
        output_file: Optional path of a plain-text log file.

    Returns:
        A `logging.Logger`, or a `MultiProcessAdapter` when `use_accelerate`.
    """
    logger = logging.getLogger(name)
    if log_level is None:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(log_level.upper())

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
            root_name=name,
        )
    else:
        formatter = plain_formatter
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if output_file is not None:
        fileHandler = logging.FileHandler(output_file)
        # Bug fix: always use the plain formatter for files — previously the
        # colored formatter was reused, filling log files with ANSI escapes.
        fileHandler.setFormatter(plain_formatter)
        logger.addHandler(fileHandler)

    if use_accelerate:
        return MultiProcessAdapter(logger, {})
    else:
        return logger
utils/lr_schedulers.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Learning rate schedulers.
2
+
3
+ Reference:
4
+ https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py
5
+ """
6
+ import math
7
+ from enum import Enum
8
+ from typing import Optional, Union
9
+
10
+ import torch
11
+
12
+
13
class SchedulerType(Enum):
    """Supported LR schedule names; values are the config strings accepted by
    `get_scheduler` (e.g. "cosine", "constant")."""
    COSINE = "cosine"
    CONSTANT = "constant"
16
+
17
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Cosine learning-rate schedule with linear warm-up and an `end_lr` floor.

    The multiplier ramps linearly from 0 to 1 over `num_warmup_steps`, then
    follows `num_cycles` cosine periods mapping `base_lr` down to `end_lr`.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warm-up steps.
        num_training_steps: Total number of training steps.
        num_cycles: Number of cosine periods (0.5 = one half-cosine decay).
        last_epoch: Index of the last epoch when resuming training.
        base_lr: Base learning rate (should match the optimizer's LR).
        end_lr: Final learning rate.

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR` with the schedule above.
    """

    def lr_lambda(current_step):
        # Warm-up phase: linear ramp from 0 to 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Fraction of the post-warmup schedule completed, in [0, 1].
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        ratio = max(0.0, cosine)
        # Normalize by base_lr because LambdaLR multiplies the optimizer's LR.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
51
+
52
+
53
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
    last_epoch: int = -1,
):
    """UViT: Creates a constant learning rate schedule with warm-up.

    The multiplier rises linearly from 0 to 1 over `num_warmup_steps`, then
    stays at 1 for the rest of training.

    Args:
        optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
        num_warmup_steps: An integer, the number of steps for the warmup phase.
        num_training_steps: An integer, the total number of training steps
            (unused; kept for a signature uniform with the cosine schedule).
        base_lr: A float, the base learning rate (unused by this schedule).
        end_lr: A float, the final learning rate (unused by this schedule).
        last_epoch: An integer, the index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return 1.0

    # Fix: forward `last_epoch` (new trailing parameter, default -1, so it is
    # backward-compatible) instead of silently ignoring resume state, matching
    # get_cosine_schedule_with_warmup. The docstring previously documented a
    # nonexistent `num_cycles` argument; that has been removed.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
83
+
84
+
85
# Dispatch table mapping each SchedulerType member to its factory function;
# consumed by `get_scheduler` below.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule_with_warmup,
}
89
+
90
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Build a learning-rate scheduler by name.

    Args:
        name: Schedule name ("cosine"/"constant") or a SchedulerType member.
        optimizer: Optimizer to attach the schedule to.
        num_warmup_steps: Number of warm-up steps (required).
        num_training_steps: Total number of training steps (required).
        base_lr: Base learning rate.
        end_lr: Final learning rate.

    Returns:
        An instance of torch.optim.lr_scheduler.LambdaLR.

    Raises:
        ValueError: If num_warmup_steps or num_training_steps is not provided.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # Both step counts are mandatory for every supported schedule.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    scheduler = schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )
    return scheduler
utils/misc.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file is borrowed from https://github.com/LTH14/mar/blob/main/util/misc.py
2
+ """
3
+ import builtins
4
+ import datetime
5
+ import os
6
+ import time
7
+ from collections import defaultdict, deque
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
# Parse torch's "major.minor[...]" version string; `inf` moved out of
# torch._six in torch 1.8, so pick the compatible import for older installs.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import inf
else:
    from torch import inf
19
+ import copy
20
+
21
+
22
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    A bounded deque feeds the windowed statistics (median/avg/max/value) while
    running totals feed the global average.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default display: windowed median plus global average.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # The window records the raw value once; the running totals weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
82
+
83
+
84
class MetricLogger(object):
    """Collect named SmoothedValue meters and pretty-print training progress."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created lazily with default smoothing.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one scalar per keyword; tensors are .item()'d, None skipped."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow attribute-style access to meters, e.g. `logger.loss`.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """All-reduce every meter's running totals across distributed ranks."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing meters/ETA (and CUDA peak
        memory when available) every `print_freq` iterations plus a final
        total-time summary. Requires `iterable` to support len()."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        # Smoothed per-iteration timings: full step time vs. data-loading time.
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the dataloader since the previous yield.
            data_time.update(time.time() - end)
            yield obj
            # Time for the full iteration (data wait + caller's work).
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
166
+
167
+
168
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process

    Replaces builtins.print with a wrapper that emits only on the master rank
    (or when called with force=True, or on world sizes above 8), prefixing
    each message with a timestamp.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False) or (get_world_size() > 8)
        if is_master or force:
            timestamp = datetime.datetime.now().time()
            builtin_print('[{}] '.format(timestamp), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print
183
+
184
+
185
def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
191
+
192
+
193
def get_world_size():
    """Number of distributed processes; 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
197
+
198
+
199
def get_rank():
    """This process's distributed rank; 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
203
+
204
+
205
def is_main_process():
    """True on rank 0, i.e. the process responsible for writing artifacts."""
    rank = get_rank()
    return rank == 0
207
+
208
+
209
def save_on_master(*args, **kwargs):
    """torch.save(...) that is a no-op on non-master ranks."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
212
+
213
+
214
def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher environment exists.

    Mutates `args` in place, setting rank / world_size / gpu / distributed /
    dist_backend. Falls back to non-distributed mode (args.distributed=False)
    when no launcher environment variables are found.
    """
    if args.dist_on_itp:
        # OpenMPI launcher: derive topology from OMPI_* vars and mirror them
        # into the standard torchrun-style environment variables.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch style environment.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM launcher. NOTE(review): this branch does not set
        # args.world_size or args.dist_url — presumably they are expected to
        # come from the command line; confirm before relying on SLURM launches.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # Silence print on all non-zero ranks from here on.
    setup_for_distributed(args.rank == 0)
247
+
248
+
249
class NativeScalerWithGradNormCount:
    """AMP loss scaler: scaled backward, optional grad clipping, and the
    pre-step gradient norm returned on update steps.

    NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent torch in
    favor of torch.amp.GradScaler('cuda'); kept as-is for compatibility.
    """
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        # Backward on the scaled loss to avoid fp16 gradient underflow.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # No clipping requested: still unscale, and report the norm.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            # Gradient-accumulation step: no optimizer update, no norm.
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
276
+
277
+
278
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total gradient norm over `parameters` (like clip_grad_norm_, without
    clipping). Parameters lacking a gradient are ignored; returns tensor(0.)
    when none have gradients."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        # Infinity norm: max absolute gradient entry across all tensors.
        return max(g.abs().max().to(device) for g in grads)
    per_tensor = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_tensor, norm_type)
291
+
292
+
293
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split a model's trainable parameters into two optimizer groups.

    Biases, 1-D tensors (norm scales), names in `skip_list`, and any
    parameter whose name contains 'diffloss' get zero weight decay; all
    other parameters get `weight_decay`.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        exempt = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in skip_list
            or 'diffloss' in name
        )
        # no weight decay on bias, norm and diffloss
        (no_decay if exempt else decay).append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]
306
+
307
+
308
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
    """Write a training checkpoint (on rank 0 only, via save_on_master) to
    `{args.output_dir}/checkpoint-{epoch_name or epoch}.pth`.

    When `ema_params` is given it must align 1:1 with
    `model_without_ddp.named_parameters()`; an EMA state dict is built by
    overlaying those values onto a deep copy of the model state dict.
    """
    if epoch_name is None:
        epoch_name = str(epoch)
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)

    # ema
    if ema_params is not None:
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
    else:
        ema_state_dict = None

    to_save = {
        'model': model_without_ddp.state_dict(),
        'model_ema': ema_state_dict,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, checkpoint_path)
332
+
333
+
334
def all_reduce_mean(x):
    """Average a scalar across all distributed processes.

    Returns `x` unchanged in single-process runs; otherwise all-reduces on
    GPU and divides by the world size.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()