scripts
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- .gitattributes +4 -0
- LICENSE +21 -0
- README.md +177 -12
- TRAIN.md +110 -0
- app.py +471 -0
- assets/example_1.png +3 -0
- assets/generated_images.png +3 -0
- assets/reconstructed.png +3 -0
- assets/teaser.png +3 -0
- configs/training/VibeToken_small.yaml +102 -0
- configs/vibetoken_ll.yaml +40 -0
- configs/vibetoken_sl.yaml +40 -0
- data/__init__.py +1 -0
- data/convert_imagenet_to_wds.py +56 -0
- data/webdataset_reader.py +518 -0
- evaluator/__init__.py +1 -0
- evaluator/evaluator.py +230 -0
- evaluator/inception.py +215 -0
- examples/batch_inference.py +241 -0
- examples/encode_decode.py +172 -0
- generate.py +240 -0
- generator/__init__.py +4 -0
- modeling/__init__.py +0 -0
- modeling/modules/__init__.py +6 -0
- modeling/modules/base_model.py +124 -0
- modeling/modules/blocks.py +617 -0
- modeling/modules/discriminator.py +124 -0
- modeling/modules/ema_model.py +241 -0
- modeling/modules/encoder_decoder.py +1142 -0
- modeling/modules/fuzzy_embedding.py +70 -0
- modeling/modules/losses.py +339 -0
- modeling/modules/lpips.py +181 -0
- modeling/modules/maskgit_vqgan.py +346 -0
- modeling/modules/perceptual_loss.py +101 -0
- modeling/quantizer/__init__.py +3 -0
- modeling/quantizer/dist.py +302 -0
- modeling/quantizer/mvq.py +159 -0
- modeling/quantizer/quantizer.py +158 -0
- modeling/quantizer/softvq.py +170 -0
- modeling/vibetoken_model.py +219 -0
- reconstruct.py +148 -0
- requirements.txt +26 -0
- scripts/train_vibetoken.py +223 -0
- setup.sh +20 -0
- train_tokenvibe.sh +14 -0
- train_vibetoken.sh +14 -0
- utils/__init__.py +0 -0
- utils/logger.py +69 -0
- utils/lr_schedulers.py +129 -0
- utils/misc.py +342 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/example_1.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/generated_images.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/reconstructed.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Maitreya Patel
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,14 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title: VibeToken
|
| 3 |
-
emoji: 🦀
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 6.6.0
|
| 8 |
-
python_version: '3.12'
|
| 9 |
-
app_file: app.py
|
| 10 |
-
pinned: false
|
| 11 |
-
license: mit
|
| 12 |
-
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [CVPR 2026] VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<img src="assets/teaser.png" alt="VibeToken Teaser" width="100%">
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<b>CVPR 2026</b> |
|
| 9 |
+
<a href="#">Paper</a> |
|
| 10 |
+
<a href="#">Project Page</a> |
|
| 11 |
+
<a href="#-checkpoints">Checkpoints</a>
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="https://img.shields.io/badge/CVPR-2026-blue" alt="CVPR 2026">
|
| 16 |
+
<img src="https://img.shields.io/badge/arXiv-TODO-b31b1b" alt="arXiv">
|
| 17 |
+
<img src="https://img.shields.io/badge/License-MIT-green" alt="License">
|
| 18 |
+
<a href="https://huggingface.co/mpatel57/VibeToken"><img src="https://img.shields.io/badge/🤗-Model-yellow" alt="HuggingFace"></a>
|
| 19 |
+
</p>
|
| 20 |
+
|
| 21 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
We introduce an efficient, resolution-agnostic autoregressive (AR) image synthesis approach that generalizes to **arbitrary resolutions and aspect ratios**, narrowing the gap to diffusion models at scale. At its core is **VibeToken**, a novel resolution-agnostic 1D Transformer-based image tokenizer that encodes images into a dynamic, user-controllable sequence of 32--256 tokens, achieving state-of-the-art efficiency and performance trade-off. Building on VibeToken, we present **VibeToken-Gen**, a class-conditioned AR generator with out-of-the-box support for arbitrary resolutions while requiring significantly fewer compute resources.
|
| 24 |
+
|
| 25 |
+
### 🔥 Highlights
|
| 26 |
+
|
| 27 |
+
| | |
|
| 28 |
+
|---|---|
|
| 29 |
+
| 🎯 **1024×1024 in just 64 tokens** | Achieves **3.94 gFID** vs. 5.87 gFID for diffusion-based SOTA (1,024 tokens) |
|
| 30 |
+
| ⚡ **Constant 179G FLOPs** | 63× more efficient than LlamaGen (11T FLOPs at 1024×1024) |
|
| 31 |
+
| 🌐 **Resolution-agnostic** | Supports arbitrary resolutions and aspect ratios out of the box |
|
| 32 |
+
| 🎛️ **Dynamic token count** | User-controllable 32--256 tokens per image |
|
| 33 |
+
| 🔍 **Native super-resolution** | Supports image super-resolution out of the box |
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## 📰 News
|
| 37 |
+
|
| 38 |
+
- **[Feb 2026]** 🎉 VibeToken is accepted at **CVPR 2026**!
|
| 39 |
+
- **[Feb 2026]** Training scripts released.
|
| 40 |
+
- **[Feb 2026]** Inference code and checkpoints released.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## 🚀 Quick Start
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# 1. Clone and setup
|
| 47 |
+
git clone https://github.com/<your-org>/VibeToken.git
|
| 48 |
+
cd VibeToken
|
| 49 |
+
uv venv --python=3.11.6
|
| 50 |
+
source .venv/bin/activate
|
| 51 |
+
uv pip install -r requirements.txt
|
| 52 |
+
|
| 53 |
+
# 2. Download a checkpoint (see Checkpoints section below)
|
| 54 |
+
mkdir -p checkpoints
|
| 55 |
+
wget https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin -O ./checkpoints/VibeToken_LL.bin
|
| 56 |
+
|
| 57 |
+
# 3. Reconstruct an image
|
| 58 |
+
python reconstruct.py --auto \
|
| 59 |
+
--config configs/vibetoken_ll.yaml \
|
| 60 |
+
--checkpoint ./checkpoints/VibeToken_LL.bin \
|
| 61 |
+
--image ./assets/example_1.png \
|
| 62 |
+
--output ./assets/reconstructed.png
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
## 📦 Checkpoints
|
| 67 |
+
|
| 68 |
+
All checkpoints are hosted on [Hugging Face](https://huggingface.co/mpatel57/VibeToken).
|
| 69 |
+
|
| 70 |
+
#### Reconstruction Checkpoints
|
| 71 |
+
|
| 72 |
+
| Name | Resolution | rFID (256 tokens) | rFID (64 tokens) | Download |
|
| 73 |
+
|------|:----------:|:-----------------:|:----------------:|----------|
|
| 74 |
+
| VibeToken-LL | 1024×1024 | 3.76 | 4.12 | [VibeToken_LL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin) |
|
| 75 |
+
| VibeToken-LL | 256×256 | 5.12 | 0.90 | same as above |
|
| 76 |
+
| VibeToken-SL | 1024×1024 | 4.25 | 2.41 | [VibeToken_SL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_SL.bin) |
|
| 77 |
+
| VibeToken-SL | 256×256 | 5.44 | 0.40 | same as above |
|
| 78 |
+
|
| 79 |
+
#### Generation Checkpoints
|
| 80 |
+
|
| 81 |
+
| Name | Training Resolution(s) | Tokens | Best gFID | Download |
|
| 82 |
+
|------|:----------------------:|:------:|:---------:|----------|
|
| 83 |
+
| VibeToken-Gen-B | 256×256 | 65 | 7.62 | [VibeTokenGen-b-fixed65_dynamic_1500k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-b-fixed65_dynamic_1500k.pt) |
|
| 84 |
+
| VibeToken-Gen-B | 1024×1024 | 65 | 7.37 | same as above |
|
| 85 |
+
| VibeToken-Gen-XXL | 256×256 | 65 | 3.62 | [VibeTokenGen-xxl-dynamic-65_750k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-xxl-dynamic-65_750k.pt) |
|
| 86 |
+
| VibeToken-Gen-XXL | 1024×1024 | 65 | **3.54** | same as above |
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
## 🛠️ Setup
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
uv venv --python=3.11.6
|
| 93 |
+
source .venv/bin/activate
|
| 94 |
+
uv pip install -r requirements.txt
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
> **Tip:** If you don't have `uv`, install it via `pip install uv` or see [uv docs](https://github.com/astral-sh/uv). Alternatively, use `python -m venv .venv && pip install -r requirements.txt`.
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
## 🖼️ VibeToken Reconstruction
|
| 101 |
+
|
| 102 |
+
Download the VibeToken-LL checkpoint (see [Checkpoints](#-checkpoints)), then:
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
# Auto mode (recommended) -- automatically determines optimal patch sizes
|
| 106 |
+
python reconstruct.py --auto \
|
| 107 |
+
--config configs/vibetoken_ll.yaml \
|
| 108 |
+
--checkpoint ./checkpoints/VibeToken_LL.bin \
|
| 109 |
+
--image ./assets/example_1.png \
|
| 110 |
+
--output ./assets/reconstructed.png
|
| 111 |
+
|
| 112 |
+
# Manual mode -- specify patch sizes explicitly
|
| 113 |
+
python reconstruct.py \
|
| 114 |
+
--config configs/vibetoken_ll.yaml \
|
| 115 |
+
--checkpoint ./checkpoints/VibeToken_LL.bin \
|
| 116 |
+
--image ./assets/example_1.png \
|
| 117 |
+
--output ./assets/reconstructed.png \
|
| 118 |
+
--encoder_patch_size 16 \
|
| 119 |
+
--decoder_patch_size 16
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
> **Note:** For best performance, the input image resolution should be a multiple of 32. Images with other resolutions are automatically rescaled to the nearest multiple of 32.
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
## 🎨 VibeToken-Gen: ImageNet-1k Generation
|
| 126 |
+
|
| 127 |
+
Download both the VibeToken-LL and VibeToken-Gen-XXL checkpoints (see [Checkpoints](#-checkpoints)), then:
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
python generate.py \
|
| 131 |
+
--gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
|
| 132 |
+
--gpt-model GPT-XXL --num-output-layer 4 \
|
| 133 |
+
--num-codebooks 8 --codebook-size 32768 \
|
| 134 |
+
--image-size 256 --cfg-scale 4.0 --top-k 500 --temperature 1.0 \
|
| 135 |
+
--class-dropout-prob 0.1 \
|
| 136 |
+
--extra-layers "QKV" \
|
| 137 |
+
--latent-size 65 \
|
| 138 |
+
--config ./configs/vibetoken_ll.yaml \
|
| 139 |
+
--vq-ckpt ./checkpoints/VibeToken_LL.bin \
|
| 140 |
+
--sample-dir ./assets/ \
|
| 141 |
+
--skip-folder-creation \
|
| 142 |
+
--compile \
|
| 143 |
+
--decoder-patch-size 32,32 \
|
| 144 |
+
--target-resolution 1024,1024 \
|
| 145 |
+
--llamagen-target-resolution 256,256 \
|
| 146 |
+
--precision bf16 \
|
| 147 |
+
--global-seed 156464151
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
The `--target-resolution` controls the tokenizer output resolution, while `--llamagen-target-resolution` controls the generator's internal resolution (max 512×512; for higher resolutions, the tokenizer handles upscaling).
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
## 🏋️ Training
|
| 154 |
+
|
| 155 |
+
To train the VibeToken tokenizer from scratch, please refer to [TRAIN.md](TRAIN.md) for detailed instructions.
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
## 🙏 Acknowledgement
|
| 159 |
+
|
| 160 |
+
We would like to acknowledge the following repositories that inspired our work and upon which we directly build:
|
| 161 |
+
[1d-tokenizer](https://github.com/bytedance/1d-tokenizer),
|
| 162 |
+
[LlamaGen](https://github.com/FoundationVision/LlamaGen), and
|
| 163 |
+
[UniTok](https://github.com/FoundationVision/UniTok).
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
## 📝 Citation
|
| 167 |
+
|
| 168 |
+
If you find VibeToken useful in your research, please consider citing:
|
| 169 |
+
|
| 170 |
+
```bibtex
|
| 171 |
+
@inproceedings{vibetoken2026,
|
| 172 |
+
title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
|
| 173 |
+
author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
|
| 174 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 175 |
+
year = {2026}
|
| 176 |
+
}
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
If you have any questions, feel free to open an issue or reach out!
|
TRAIN.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Instructions
|
| 2 |
+
|
| 3 |
+
## VibeToken MVQ Tokenizer
|
| 4 |
+
|
| 5 |
+
This repository contains the training code for our tokenizer.
|
| 6 |
+
We provide the example config [VibeToken-Small](configs/training/VibeToken_small.yaml) that trains the small encoder/decoder architecture with 32-64 tokens.
|
| 7 |
+
|
| 8 |
+
### Data Preparation
|
| 9 |
+
|
| 10 |
+
All data paths are controlled by the `DATA_DIR` environment variable. Set it once to point to your preferred storage location:
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
export DATA_DIR=/path/to/your/storage # defaults to ./data if unset
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Download ImageNet-1k and convert to WebDataset format:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
source .venv/bin/activate
|
| 20 |
+
|
| 21 |
+
# Option 1: Use the setup script (recommended)
|
| 22 |
+
bash setup.sh
|
| 23 |
+
|
| 24 |
+
# Option 2: Run steps manually
|
| 25 |
+
export HF_HUB_ENABLE_HF_TRANSFER=1
|
| 26 |
+
huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"
|
| 27 |
+
python data/convert_imagenet_to_wds.py \
|
| 28 |
+
--input_dir "${DATA_DIR}/imagenet-1k" \
|
| 29 |
+
--output_dir "${DATA_DIR}/imagenet_wds"
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
After preparation, update the shard paths in your training config to match your `DATA_DIR`:
|
| 33 |
+
|
| 34 |
+
```yaml
|
| 35 |
+
dataset:
|
| 36 |
+
params:
|
| 37 |
+
train_shards_path_or_url: "<DATA_DIR>/imagenet_wds/imagenet-train-{000001..000128}.tar"
|
| 38 |
+
eval_shards_path_or_url: "<DATA_DIR>/imagenet_wds/imagenet-val-{000001..000004}.tar"
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Launch Training
|
| 42 |
+
|
| 43 |
+
Start training on 1 node with 8 GPUs:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
source .venv/bin/activate
|
| 47 |
+
bash train_vibetoken.sh
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Config Reference
|
| 51 |
+
|
| 52 |
+
Below are the important hyperparameters to manage the training.
|
| 53 |
+
|
| 54 |
+
```yaml
|
| 55 |
+
model:
|
| 56 |
+
vq_model:
|
| 57 |
+
vit_enc_model_size: "small" # this can be small/base/large
|
| 58 |
+
vit_dec_model_size: "small" # this can be small/base/large
|
| 59 |
+
num_latent_tokens: 64 # in paper we set this to 256
|
| 60 |
+
|
| 61 |
+
losses:
|
| 62 |
+
discriminator_start: 100_000 # set based on convergence, in paper we set this to 250_000
|
| 63 |
+
|
| 64 |
+
dataset:
|
| 65 |
+
params:
|
| 66 |
+
pretokenization: True # keep this true if using the current setup
|
| 67 |
+
train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
|
| 68 |
+
eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
|
| 69 |
+
preprocessing:
|
| 70 |
+
resize_shorter_edge: 512 # maximum size during pretraining but can be any value
|
| 71 |
+
crop_size: 512 # maximum size during pretraining but can be any value
|
| 72 |
+
min_tokens: 32 # minimum number of tokens to generate
|
| 73 |
+
max_tokens: 64 # maximum number of tokens to generate
|
| 74 |
+
|
| 75 |
+
training:
|
| 76 |
+
gradient_accumulation_steps: 1 # increase for LL model that does not fit on single node
|
| 77 |
+
per_gpu_batch_size: 32 # decrease to 16 for LL model; during GAN training this is halved
|
| 78 |
+
max_train_steps: 400_000 # in paper we train up to 650_000; model may diverge after 600_000
|
| 79 |
+
num_generated_images: 2 # for validation
|
| 80 |
+
variable_resolution: # any-to-any resolution training
|
| 81 |
+
any2any: True
|
| 82 |
+
dim:
|
| 83 |
+
- [256, 256]
|
| 84 |
+
- [512, 512]
|
| 85 |
+
- [384, 256]
|
| 86 |
+
- [256, 384]
|
| 87 |
+
- [512, 384]
|
| 88 |
+
- [384, 512]
|
| 89 |
+
ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1] # probability per resolution; must sum to 1.0
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Remove patch mixture parameters unless the model does not fit in memory.
|
| 93 |
+
# This will slow down training and may hurt performance.
|
| 94 |
+
# We do not use this in our normal setup.
|
| 95 |
+
model:
|
| 96 |
+
vq_model:
|
| 97 |
+
encoder:
|
| 98 |
+
patch_mixture_start_layer: 2
|
| 99 |
+
patch_mixture_end_layer: 22
|
| 100 |
+
decoder:
|
| 101 |
+
patch_mixture_start_layer: 2
|
| 102 |
+
patch_mixture_end_layer: 22
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
<!-- ### Reproduced Results on Small Baseline
|
| 107 |
+
|
| 108 |
+
> Note: Our released checkpoints are from a different codebase and may observe +/- changes in results.
|
| 109 |
+
|
| 110 |
+
Below we report the performance on the above training script on the small baseline. This baseline is not reported in the paper but achieves competitive performance as expected. -->
|
app.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VibeToken-Gen Gradio Demo
|
| 3 |
+
Class-conditional ImageNet generation with dynamic resolution support.
|
| 4 |
+
"""
|
| 5 |
+
import spaces
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# Global PyTorch runtime configuration for this inference-only demo:
# allow TF32 matmuls/convolutions (small precision trade-off for speed on
# Ampere+ GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
# The app never trains, so autograd tracking is disabled process-wide.
torch.set_grad_enabled(False)
# Skip random weight initialization when modules are constructed — weights are
# always overwritten by checkpoint loading below, so the init work is wasted.
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 20 |
+
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from vibetokengen.generate import generate
|
| 25 |
+
from vibetokengen.model import GPT_models
|
| 26 |
+
from vibetoken import VibeTokenTokenizer
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Hugging Face repository hosting both tokenizer and generator checkpoints.
HF_REPO = "mpatel57/VibeToken"
# Set VIBETOKEN_XXL=1 to serve the larger GPT-XXL generator instead of GPT-B.
USE_XXL = os.environ.get("VIBETOKEN_XXL", "0") == "1"

if USE_XXL:
    GPT_MODEL_NAME = "GPT-XXL"
    GPT_CKPT_FILENAME = "VibeTokenGen-xxl-dynamic-65_750k.pt"
    NUM_OUTPUT_LAYER = 4
    EXTRA_LAYERS = "QKV"
else:
    GPT_MODEL_NAME = "GPT-B"
    GPT_CKPT_FILENAME = "VibeTokenGen-b-fixed65_dynamic_1500k.pt"
    NUM_OUTPUT_LAYER = 4
    EXTRA_LAYERS = "QKV"

# Tokenizer checkpoint and its YAML config (resolved relative to this file).
VQ_CKPT_FILENAME = "VibeToken_LL.bin"
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "vibetoken_ll.yaml")

# Generator/tokenizer hyperparameters, passed to the GPT constructor below.
# NOTE(review): these presumably must match how the released checkpoints were
# trained — confirm before changing any of them.
CODEBOOK_SIZE = 32768
NUM_CODEBOOKS = 8
LATENT_SIZE = 65        # number of latent tokens the generator produces
NUM_CLASSES = 1000      # ImageNet-1k class conditioning
CLS_TOKEN_NUM = 1
CLASS_DROPOUT_PROB = 0.1
CAPPING = 50.0          # NOTE(review): forwarded as `capping=`; semantics not visible here — confirm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPU; CPU falls back to fp32.
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# torch.compile is enabled by default on CUDA; opt out with VIBETOKEN_NO_COMPILE=1.
COMPILE = os.environ.get("VIBETOKEN_NO_COMPILE", "0") != "1" and DEVICE == "cuda"
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
# ImageNet class labels (curated popular subset)
# ---------------------------------------------------------------------------

# Human-readable label -> ImageNet-1k class index, used to populate the demo's
# class dropdown. Indices presumably follow the standard ILSVRC-2012 0-999
# ordering — spot-check against the reference label list before editing.
IMAGENET_CLASSES = {
    "Golden Retriever": 207,
    "Labrador Retriever": 208,
    "German Shepherd": 235,
    "Siberian Husky": 250,
    "Pembroke Corgi": 263,
    "Tabby Cat": 281,
    "Persian Cat": 283,
    "Siamese Cat": 284,
    "Tiger": 292,
    "Lion": 291,
    "Cheetah": 293,
    "Brown Bear": 294,
    "Giant Panda": 388,
    "Red Fox": 277,
    "Arctic Fox": 279,
    "Timber Wolf": 269,
    "Bald Eagle": 22,
    "Macaw": 88,
    "Flamingo": 130,
    "Peacock": 84,
    "Goldfish": 1,
    "Great White Shark": 2,
    "Jellyfish": 107,
    "Monarch Butterfly": 323,
    "Ladybug": 301,
    "Snail": 113,
    "Red Sports Car": 817,
    "School Bus": 779,
    "Steam Locomotive": 820,
    "Sailboat": 914,
    "Space Shuttle": 812,
    "Castle": 483,
    "Church": 497,
    "Lighthouse": 437,
    "Volcano": 980,
    "Lakeside": 975,
    "Cliff": 972,
    "Coral Reef": 973,
    "Valley": 979,
    "Seashore": 978,
    "Mushroom": 947,
    "Broccoli": 937,
    "Pizza": 963,
    "Ice Cream": 928,
    "Cheeseburger": 933,
    "Espresso": 967,
    "Acoustic Guitar": 402,
    "Grand Piano": 579,
    "Violin": 889,
    "Balloon": 417,
}
|
| 117 |
+
|
| 118 |
+
# Resolutions the AR generator itself is conditioned on (largest preset is
# 512x512). Values are consumed as (gen_h, gen_w) in generate_image().
# NOTE(review): labels appear to read "H × W" based on tuple unpacking order —
# confirm against the UI.
GENERATOR_RESOLUTION_PRESETS = {
    "256 × 256": (256, 256),
    "384 × 256": (384, 256),
    "256 × 384": (256, 384),
    "384 × 384": (384, 384),
    "512 × 256": (512, 256),
    "256 × 512": (256, 512),
    "512 × 512": (512, 512),
}

# Final output (decode) resolutions. None means "decode at the generator's
# resolution"; any other value overrides the output size independently of the
# generator resolution.
OUTPUT_RESOLUTION_PRESETS = {
    "Same as generator": None,
    "256 × 256": (256, 256),
    "384 × 384": (384, 384),
    "512 × 512": (512, 512),
    "768 × 768": (768, 768),
    "1024 × 1024": (1024, 1024),
    "512 × 256 (2:1)": (512, 256),
    "256 × 512 (1:2)": (256, 512),
    "768 × 512 (3:2)": (768, 512),
    "512 × 768 (2:3)": (512, 768),
    "1024 × 512 (2:1)": (1024, 512),
    "512 × 1024 (1:2)": (512, 1024),
}
|
| 142 |
+
|
| 143 |
+
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

# Module-level singletons populated by load_models(); generate_image() raises
# a gr.Error while either is still None.
vq_model = None
gpt_model = None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def download_checkpoint(filename: str) -> str:
    """Fetch *filename* from the HF_REPO repository and return its local cache path."""
    local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
    return local_path
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _make_res_tensors(gen_h: int, gen_w: int, multiplier: int):
    """Create normalized resolution tensors for the GPT generator.

    Each dimension is divided by 1536 and tiled to shape ``(multiplier, 1)``
    (multiplier is 2 on the CFG path, so conditional and unconditional rows
    both get a value). Returns ``(height_tensor, width_tensor)``.
    """
    def tiled(dim: int):
        scalar = torch.tensor(dim / 1536, device=DEVICE, dtype=DTYPE)
        return scalar.unsqueeze(0).repeat(multiplier, 1)

    return tiled(gen_h), tiled(gen_w)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _warmup(model):
    """Run a throwaway generation to trigger torch.compile and warm CUDA caches."""
    print("Warming up (first call triggers compilation, may take ~30-60s)...")
    # Class index 0, batch of one — only the shapes matter for compilation.
    dummy_cond = torch.tensor([0], device=DEVICE)
    # multiplier=2 mirrors the serving-time CFG path (cond + uncond rows).
    th, tw = _make_res_tensors(256, 256, multiplier=2)
    with torch.inference_mode():
        generate(
            model, dummy_cond, LATENT_SIZE, NUM_CODEBOOKS,
            cfg_scale=4.0, cfg_interval=-1,
            target_h=th, target_w=tw,
            temperature=1.0, top_k=500, top_p=1.0, sample_logits=True,
        )
    if DEVICE == "cuda":
        # Ensure all warmup kernels have actually finished before reporting.
        torch.cuda.synchronize()
    print("Warmup complete — subsequent generations will be fast.")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def load_models():
    """Download, construct, and initialize the tokenizer and GPT generator.

    Populates the module-level ``vq_model`` / ``gpt_model`` singletons. When
    COMPILE is set, the generator is additionally wrapped in torch.compile and
    warmed up so the first user request is fast.
    """
    global vq_model, gpt_model

    print("Downloading checkpoints (if needed)...")
    vq_path = download_checkpoint(VQ_CKPT_FILENAME)
    gpt_path = download_checkpoint(GPT_CKPT_FILENAME)

    print(f"Loading VibeToken tokenizer from {vq_path}...")
    vq_model = VibeTokenTokenizer.from_config(
        CONFIG_PATH, vq_path, device=DEVICE, dtype=DTYPE,
    )
    print("VibeToken tokenizer loaded.")

    print(f"Loading {GPT_MODEL_NAME} from {gpt_path}...")
    gpt_model = GPT_models[GPT_MODEL_NAME](
        vocab_size=CODEBOOK_SIZE,
        block_size=LATENT_SIZE,
        num_classes=NUM_CLASSES,
        cls_token_num=CLS_TOKEN_NUM,
        model_type="c2i",
        num_codebooks=NUM_CODEBOOKS,
        n_output_layer=NUM_OUTPUT_LAYER,
        class_dropout_prob=CLASS_DROPOUT_PROB,
        extra_layers=EXTRA_LAYERS,
        capping=CAPPING,
    ).to(device=DEVICE, dtype=DTYPE)

    # Checkpoints from different training setups nest the weights under
    # different keys; fall back to treating the whole file as a state dict.
    state = torch.load(gpt_path, map_location="cpu", weights_only=False)
    weights = state
    for key in ("model", "module", "state_dict"):
        if key in state:
            weights = state[key]
            break
    gpt_model.load_state_dict(weights, strict=True)
    gpt_model.eval()
    del state
    print(f"{GPT_MODEL_NAME} loaded.")

    if COMPILE:
        print("Compiling GPT model with torch.compile (max-autotune)...")
        gpt_model = torch.compile(gpt_model, mode="max-autotune", fullgraph=True)
        _warmup(gpt_model)
    else:
        print("Skipping torch.compile (set VIBETOKEN_NO_COMPILE=0 to enable).")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
# Decoder patch-size heuristic
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
|
| 232 |
+
def auto_decoder_patch_size(h: int, w: int) -> tuple[int, int]:
    """Choose a square decoder patch size from the output resolution.

    The heuristic keys off the longer edge: <=256 -> 8, <=512 -> 16,
    anything larger -> 32. Returns the patch size as ``(ps, ps)``.
    """
    longest_edge = max(h, w)
    for limit, size in ((256, 8), (512, 16)):
        if longest_edge <= limit:
            return (size, size)
    return (32, 32)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
# Generation
|
| 245 |
+
# ---------------------------------------------------------------------------
|
| 246 |
+
|
| 247 |
+
@torch.inference_mode()
@spaces.GPU(duration=90)
def generate_image(
    class_name: str,
    class_id: int,
    gen_resolution_preset: str,
    out_resolution_preset: str,
    decoder_ps_choice: str,
    cfg_scale: float,
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    randomize_seed: bool,
):
    """Sample class-conditional tokens with the AR model and decode to a PIL image.

    Returns a tuple of (PIL image, seed actually used) for the Gradio outputs.
    """
    if vq_model is None or gpt_model is None:
        raise gr.Error("Models are still loading. Please wait a moment and try again.")

    # Optionally draw a fresh seed, then seed every RNG the pipeline touches.
    if randomize_seed:
        seed = random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Resolve the class label, clamped into the valid range.
    if class_name and class_name != "Custom (enter ID below)":
        label = IMAGENET_CLASSES[class_name]
    else:
        label = int(class_id)
    label = max(0, min(label, NUM_CLASSES - 1))

    gen_h, gen_w = GENERATOR_RESOLUTION_PRESETS[gen_resolution_preset]

    # Output resolution defaults to the generator resolution when the preset is None.
    preset = OUTPUT_RESOLUTION_PRESETS[out_resolution_preset]
    out_h, out_w = (gen_h, gen_w) if preset is None else preset

    # Decoder patch size: heuristic from output size, or the explicit choice.
    if decoder_ps_choice == "Auto":
        dec_ps = auto_decoder_patch_size(out_h, out_w)
    else:
        chosen = int(decoder_ps_choice)
        dec_ps = (chosen, chosen)

    # CFG > 1 doubles the batch (conditional + unconditional passes).
    multiplier = 2 if cfg_scale > 1.0 else 1

    c_indices = torch.tensor([label], device=DEVICE)
    th, tw = _make_res_tensors(gen_h, gen_w, multiplier)

    index_sample = generate(
        gpt_model,
        c_indices,
        LATENT_SIZE,
        NUM_CODEBOOKS,
        cfg_scale=cfg_scale,
        cfg_interval=-1,
        target_h=th,
        target_w=tw,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        sample_logits=True,
    )

    # The decoder expects an extra codebook axis on the sampled indices.
    index_sample = index_sample.unsqueeze(2)
    samples = vq_model.decode(
        index_sample,
        height=out_h,
        width=out_w,
        patch_size=dec_ps,
    )
    samples = torch.clamp(samples, 0, 1)

    # CHW float [0, 1] -> HWC uint8 -> PIL.
    pixels = (samples[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(pixels), seed
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
# Gradio UI
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
HEADER_MD = """
|
| 333 |
+
# VibeToken-Gen: Dynamic Resolution Image Generation
|
| 334 |
+
|
| 335 |
+
<p style="margin-top:4px;">
|
| 336 |
+
<b>Maitreya Patel, Jingtao Li, Weiming Zhuang, Yezhou Yang, Lingjuan Lyu</b>
|
| 337 |
+
|
|
| 338 |
+
</p>
|
| 339 |
+
<h3>CVPR 2026 (Main Conference)</h3>
|
| 340 |
+
|
| 341 |
+
<p>
|
| 342 |
+
<a href="https://huggingface.co/mpatel57/VibeToken" target="_blank">🤗 Model</a> |
|
| 343 |
+
<a href="https://github.com/patel-maitreya/VibeToken" target="_blank">💻 GitHub</a>
|
| 344 |
+
</p>
|
| 345 |
+
|
| 346 |
+
Generate ImageNet class-conditional images at **arbitrary resolutions** using only **65 tokens**.
|
| 347 |
+
VibeToken-Gen maintains a constant **179G FLOPs** regardless of output resolution.
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
CITATION_MD = """
|
| 351 |
+
### Citation
|
| 352 |
+
```bibtex
|
| 353 |
+
@inproceedings{vibetoken2026,
|
| 354 |
+
title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
|
| 355 |
+
author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
|
| 356 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 357 |
+
year = {2026}
|
| 358 |
+
}
|
| 359 |
+
```
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
class_choices = ["Custom (enter ID below)"] + sorted(IMAGENET_CLASSES.keys())
|
| 363 |
+
|
| 364 |
+
with gr.Blocks(
|
| 365 |
+
title="VibeToken-Gen Demo",
|
| 366 |
+
theme=gr.themes.Soft(),
|
| 367 |
+
) as demo:
|
| 368 |
+
gr.Markdown(HEADER_MD)
|
| 369 |
+
|
| 370 |
+
with gr.Row():
|
| 371 |
+
# ---- Left column: controls ----
|
| 372 |
+
with gr.Column(scale=1):
|
| 373 |
+
class_dropdown = gr.Dropdown(
|
| 374 |
+
label="ImageNet Class",
|
| 375 |
+
choices=class_choices,
|
| 376 |
+
value="Golden Retriever",
|
| 377 |
+
info="Pick a class or choose 'Custom' to enter an ID manually.",
|
| 378 |
+
)
|
| 379 |
+
class_id_input = gr.Number(
|
| 380 |
+
label="Custom Class ID (0–999)",
|
| 381 |
+
value=207,
|
| 382 |
+
minimum=0,
|
| 383 |
+
maximum=999,
|
| 384 |
+
step=1,
|
| 385 |
+
visible=False,
|
| 386 |
+
)
|
| 387 |
+
gen_resolution_dropdown = gr.Dropdown(
|
| 388 |
+
label="Generator Resolution",
|
| 389 |
+
choices=list(GENERATOR_RESOLUTION_PRESETS.keys()),
|
| 390 |
+
value="256 × 256",
|
| 391 |
+
info="Internal resolution for the AR generator (max 512×512).",
|
| 392 |
+
)
|
| 393 |
+
out_resolution_dropdown = gr.Dropdown(
|
| 394 |
+
label="Output Resolution (Decoder)",
|
| 395 |
+
choices=list(OUTPUT_RESOLUTION_PRESETS.keys()),
|
| 396 |
+
value="Same as generator",
|
| 397 |
+
info="Final image resolution. Set higher for super-resolution (e.g. generate at 256, decode at 1024).",
|
| 398 |
+
)
|
| 399 |
+
decoder_ps_dropdown = gr.Dropdown(
|
| 400 |
+
label="Decoder Patch Size",
|
| 401 |
+
choices=["Auto", "8", "16", "32"],
|
| 402 |
+
value="Auto",
|
| 403 |
+
info="'Auto' selects based on output resolution. Larger = faster but coarser.",
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
with gr.Accordion("Advanced Sampling Parameters", open=False):
|
| 407 |
+
cfg_slider = gr.Slider(
|
| 408 |
+
label="CFG Scale",
|
| 409 |
+
minimum=1.0, maximum=20.0, value=4.0, step=0.5,
|
| 410 |
+
info="Classifier-free guidance strength.",
|
| 411 |
+
)
|
| 412 |
+
temp_slider = gr.Slider(
|
| 413 |
+
label="Temperature",
|
| 414 |
+
minimum=0.1, maximum=2.0, value=1.0, step=0.05,
|
| 415 |
+
)
|
| 416 |
+
topk_slider = gr.Slider(
|
| 417 |
+
label="Top-k",
|
| 418 |
+
minimum=0, maximum=2000, value=500, step=10,
|
| 419 |
+
info="0 disables top-k filtering.",
|
| 420 |
+
)
|
| 421 |
+
topp_slider = gr.Slider(
|
| 422 |
+
label="Top-p",
|
| 423 |
+
minimum=0.0, maximum=1.0, value=1.0, step=0.05,
|
| 424 |
+
info="1.0 disables nucleus sampling.",
|
| 425 |
+
)
|
| 426 |
+
seed_input = gr.Number(
|
| 427 |
+
label="Seed", value=0, minimum=0, maximum=2**31 - 1, step=1,
|
| 428 |
+
)
|
| 429 |
+
randomize_cb = gr.Checkbox(label="Randomize seed", value=True)
|
| 430 |
+
|
| 431 |
+
generate_btn = gr.Button("Generate", variant="primary", size="lg")
|
| 432 |
+
|
| 433 |
+
# ---- Right column: output ----
|
| 434 |
+
with gr.Column(scale=2):
|
| 435 |
+
output_image = gr.Image(label="Generated Image", type="pil", height=512)
|
| 436 |
+
used_seed = gr.Number(label="Seed used", interactive=False)
|
| 437 |
+
|
| 438 |
+
# Show/hide custom class ID field
|
| 439 |
+
def toggle_custom_id(choice):
|
| 440 |
+
return gr.update(visible=(choice == "Custom (enter ID below)"))
|
| 441 |
+
|
| 442 |
+
class_dropdown.change(
|
| 443 |
+
fn=toggle_custom_id,
|
| 444 |
+
inputs=[class_dropdown],
|
| 445 |
+
outputs=[class_id_input],
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
generate_btn.click(
|
| 449 |
+
fn=generate_image,
|
| 450 |
+
inputs=[
|
| 451 |
+
class_dropdown,
|
| 452 |
+
class_id_input,
|
| 453 |
+
gen_resolution_dropdown,
|
| 454 |
+
out_resolution_dropdown,
|
| 455 |
+
decoder_ps_dropdown,
|
| 456 |
+
cfg_slider,
|
| 457 |
+
temp_slider,
|
| 458 |
+
topk_slider,
|
| 459 |
+
topp_slider,
|
| 460 |
+
seed_input,
|
| 461 |
+
randomize_cb,
|
| 462 |
+
],
|
| 463 |
+
outputs=[output_image, used_seed],
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
gr.Markdown(CITATION_MD)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
if __name__ == "__main__":
    # Load weights eagerly so the first request does not pay the startup cost.
    load_models()
    demo.launch()
|
assets/example_1.png
ADDED
|
Git LFS Details
|
assets/generated_images.png
ADDED
|
Git LFS Details
|
assets/reconstructed.png
ADDED
|
Git LFS Details
|
assets/teaser.png
ADDED
|
Git LFS Details
|
configs/training/VibeToken_small.yaml
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
experiment:
|
| 2 |
+
project: "VibeToken_mvq_tiny_main"
|
| 3 |
+
name: "VibeToken_mvq_tiny_main"
|
| 4 |
+
output_dir: "wandb/VibeToken_mvq_tiny_main"
|
| 5 |
+
max_train_examples: 1_281_167
|
| 6 |
+
save_every: 10_000
|
| 7 |
+
eval_every: 10_000
|
| 8 |
+
generate_every: 5_000
|
| 9 |
+
log_every: 50
|
| 10 |
+
log_grad_norm_every: 1_000
|
| 11 |
+
resume: True
|
| 12 |
+
|
| 13 |
+
model:
|
| 14 |
+
sub_model_type: "vibetoken"
|
| 15 |
+
train_with_attention: True
|
| 16 |
+
eval_with_attention: True
|
| 17 |
+
vq_model:
|
| 18 |
+
# encoder: # patch mixture is not supported
|
| 19 |
+
# patch_mixture_start_layer: 2
|
| 20 |
+
# patch_mixture_end_layer: 22
|
| 21 |
+
# decoder: # patch mixture is not supported
|
| 22 |
+
# patch_mixture_start_layer: 2
|
| 23 |
+
# patch_mixture_end_layer: 22
|
| 24 |
+
quantize_mode: mvq
|
| 25 |
+
codebook_size: 32768 # 32768 / 8 = 4096
|
| 26 |
+
token_size: 256 # 256 / 8 = 32
|
| 27 |
+
use_l2_norm: False
|
| 28 |
+
commitment_cost: 0.25
|
| 29 |
+
clustering_vq: False
|
| 30 |
+
num_codebooks: 8
|
| 31 |
+
# vit arch
|
| 32 |
+
vit_enc_model_size: "small"
|
| 33 |
+
vit_dec_model_size: "small"
|
| 34 |
+
vit_enc_patch_size: 32
|
| 35 |
+
vit_dec_patch_size: 32
|
| 36 |
+
num_latent_tokens: 64
|
| 37 |
+
finetune_decoder: False
|
| 38 |
+
is_legacy: False
|
| 39 |
+
|
| 40 |
+
losses:
|
| 41 |
+
discriminator_start: 100_000
|
| 42 |
+
quantizer_weight: 1.0
|
| 43 |
+
discriminator_factor: 1.0
|
| 44 |
+
discriminator_weight: 0.1
|
| 45 |
+
perceptual_loss: "lpips-convnext_s-1.0-0.1"
|
| 46 |
+
perceptual_weight: 1.1
|
| 47 |
+
reconstruction_loss: "l2"
|
| 48 |
+
reconstruction_weight: 1.0
|
| 49 |
+
lecam_regularization_weight: 0.001
|
| 50 |
+
|
| 51 |
+
dataset:
|
| 52 |
+
params:
|
| 53 |
+
pretokenization: True
|
| 54 |
+
train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
|
| 55 |
+
eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
|
| 56 |
+
num_workers_per_gpu: 12
|
| 57 |
+
preprocessing:
|
| 58 |
+
resize_shorter_edge: 512
|
| 59 |
+
crop_size: 512
|
| 60 |
+
random_crop: True
|
| 61 |
+
random_flip: True
|
| 62 |
+
res_ratio_filtering: True
|
| 63 |
+
min_tokens: 32
|
| 64 |
+
max_tokens: 64
|
| 65 |
+
|
| 66 |
+
optimizer:
|
| 67 |
+
name: adamw
|
| 68 |
+
params:
|
| 69 |
+
learning_rate: 1e-4
|
| 70 |
+
discriminator_learning_rate: 1e-4
|
| 71 |
+
beta1: 0.9
|
| 72 |
+
beta2: 0.999
|
| 73 |
+
weight_decay: 1e-4
|
| 74 |
+
|
| 75 |
+
lr_scheduler:
|
| 76 |
+
scheduler: "cosine"
|
| 77 |
+
params:
|
| 78 |
+
learning_rate: ${optimizer.params.learning_rate}
|
| 79 |
+
warmup_steps: 10_000
|
| 80 |
+
end_lr: 1e-5
|
| 81 |
+
|
| 82 |
+
training:
|
| 83 |
+
gradient_accumulation_steps: 1
|
| 84 |
+
per_gpu_batch_size: 32
|
| 85 |
+
mixed_precision: "fp16"
|
| 86 |
+
enable_tf32: True
|
| 87 |
+
enable_wandb: True
|
| 88 |
+
use_ema: True
|
| 89 |
+
seed: 42
|
| 90 |
+
max_train_steps: 400_000
|
| 91 |
+
num_generated_images: 2
|
| 92 |
+
max_grad_norm: 1.0
|
| 93 |
+
variable_resolution:
|
| 94 |
+
any2any: True
|
| 95 |
+
dim:
|
| 96 |
+
- [256, 256]
|
| 97 |
+
- [512, 512]
|
| 98 |
+
- [384, 256]
|
| 99 |
+
- [256, 384]
|
| 100 |
+
- [512, 384]
|
| 101 |
+
- [384, 512]
|
| 102 |
+
ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1]
|
configs/vibetoken_ll.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VibeToken Large-Large Configuration
|
| 2 |
+
# Large encoder + Large decoder for highest quality
|
| 3 |
+
#
|
| 4 |
+
# Usage:
|
| 5 |
+
# from vibetoken import VibeTokenTokenizer
|
| 6 |
+
# tokenizer = VibeTokenTokenizer.from_config(
|
| 7 |
+
# "configs/vibetoken_ll.yaml",
|
| 8 |
+
# "path/to/checkpoint.bin"
|
| 9 |
+
# )
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
sub_model_type: "vibetoken"
|
| 13 |
+
vq_model:
|
| 14 |
+
# Quantization settings
|
| 15 |
+
quantize_mode: mvq
|
| 16 |
+
codebook_size: 32768 # 32768 / 8 = 4096 per codebook
|
| 17 |
+
token_size: 256 # 256 / 8 = 32 per codebook
|
| 18 |
+
num_codebooks: 8
|
| 19 |
+
use_l2_norm: false
|
| 20 |
+
commitment_cost: 0.25
|
| 21 |
+
|
| 22 |
+
# Encoder architecture
|
| 23 |
+
vit_enc_model_size: "large"
|
| 24 |
+
vit_enc_patch_size: 32
|
| 25 |
+
|
| 26 |
+
# Decoder architecture
|
| 27 |
+
vit_dec_model_size: "large"
|
| 28 |
+
vit_dec_patch_size: 32
|
| 29 |
+
|
| 30 |
+
# Latent tokens
|
| 31 |
+
num_latent_tokens: 256
|
| 32 |
+
|
| 33 |
+
# Mode flags
|
| 34 |
+
is_legacy: false
|
| 35 |
+
finetune_decoder: false
|
| 36 |
+
|
| 37 |
+
# Dataset preprocessing defaults (for reference)
|
| 38 |
+
dataset:
|
| 39 |
+
preprocessing:
|
| 40 |
+
crop_size: 512
|
configs/vibetoken_sl.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VibeToken Small-Large Configuration
|
| 2 |
+
# Small encoder + Large decoder for faster encoding
|
| 3 |
+
#
|
| 4 |
+
# Usage:
|
| 5 |
+
# from vibetoken import VibeTokenTokenizer
|
| 6 |
+
# tokenizer = VibeTokenTokenizer.from_config(
|
| 7 |
+
# "configs/vibetoken_sl.yaml",
|
| 8 |
+
# "path/to/checkpoint.bin"
|
| 9 |
+
# )
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
sub_model_type: "vibetoken"
|
| 13 |
+
vq_model:
|
| 14 |
+
# Quantization settings
|
| 15 |
+
quantize_mode: mvq
|
| 16 |
+
codebook_size: 32768 # 32768 / 8 = 4096 per codebook
|
| 17 |
+
token_size: 256 # 256 / 8 = 32 per codebook
|
| 18 |
+
num_codebooks: 8
|
| 19 |
+
use_l2_norm: false
|
| 20 |
+
commitment_cost: 0.25
|
| 21 |
+
|
| 22 |
+
# Encoder architecture (Small for faster encoding)
|
| 23 |
+
vit_enc_model_size: "small"
|
| 24 |
+
vit_enc_patch_size: 32
|
| 25 |
+
|
| 26 |
+
# Decoder architecture (Large for quality)
|
| 27 |
+
vit_dec_model_size: "large"
|
| 28 |
+
vit_dec_patch_size: 32
|
| 29 |
+
|
| 30 |
+
# Latent tokens
|
| 31 |
+
num_latent_tokens: 256
|
| 32 |
+
|
| 33 |
+
# Mode flags
|
| 34 |
+
is_legacy: false
|
| 35 |
+
finetune_decoder: false
|
| 36 |
+
|
| 37 |
+
# Dataset preprocessing defaults (for reference)
|
| 38 |
+
dataset:
|
| 39 |
+
preprocessing:
|
| 40 |
+
crop_size: 512
|
data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset
|
data/convert_imagenet_to_wds.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import webdataset as wds
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _write_split(dataset, shard_pattern, max_samples_per_shard):
    """Write one dataset split to WebDataset shards and return the sample count.

    The writer is always closed (``finally``) so partially-written shards are
    flushed even if iteration raises.
    """
    count = 0
    output = wds.ShardWriter(shard_pattern, maxcount=max_samples_per_shard)
    try:
        for i, example in enumerate(dataset):
            if i % max_samples_per_shard == 0:
                # Progress marker at every shard boundary.
                print(i, file=sys.stderr)
            img, label = example["image"], example["label"]
            output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
            count += 1
    finally:
        output.close()
    return count


def convert_imagenet_to_wds(input_dir, output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
    """Convert a HuggingFace-format ImageNet dataset into WebDataset tar shards.

    Args:
        input_dir: Path/name of the ImageNet dataset loadable via ``load_dataset``.
        output_dir: Directory receiving ``imagenet-{train,val}-NNNNNN.tar`` shards.
        max_train_samples_per_shard: Max samples per training shard.
        max_val_samples_per_shard: Max samples per validation shard.
    """
    # Refuse to overwrite an existing conversion.
    assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar"))
    assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar"))

    now = time.time()
    n_train = _write_split(
        load_dataset(input_dir, split="train"),
        os.path.join(output_dir, "imagenet-train-%06d.tar"),
        max_train_samples_per_shard,
    )
    time_taken = time.time() - now
    print(f"Wrote {n_train} train examples in {time_taken // 3600} hours.")

    now = time.time()
    n_val = _write_split(
        load_dataset(input_dir, split="validation"),
        os.path.join(output_dir, "imagenet-val-%06d.tar"),
        max_val_samples_per_shard,
    )
    time_taken = time.time() - now
    print(f"Wrote {n_val} val examples in {time_taken // 60} min.")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", type=str, required=True,
                     help="Path to the ImageNet-1k dataset (HuggingFace format).")
    cli.add_argument("--output_dir", type=str, required=True,
                     help="Path to the output directory for WebDataset shards.")
    cli.add_argument("--max_train_samples_per_shard", type=int, default=10000,
                     help="Maximum number of training samples per shard.")
    cli.add_argument("--max_val_samples_per_shard", type=int, default=10000,
                     help="Maximum number of validation samples per shard.")
    args = cli.parse_args()

    # Ensure the target directory exists before the writer opens its first shard.
    os.makedirs(args.output_dir, exist_ok=True)
    convert_imagenet_to_wds(args.input_dir, args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard)
|
data/webdataset_reader.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data loader using webdataset.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
|
| 5 |
+
https://github.com/huggingface/open-muse/blob/main/training/data.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from typing import List, Union, Text
|
| 10 |
+
import webdataset as wds
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import default_collate
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
import linecache
|
| 17 |
+
import json
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import random
|
| 20 |
+
import cv2
|
| 21 |
+
import numpy as np
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_json(sample):
    """Decode and parse the raw ``json`` field of a webdataset sample in place."""
    raw = sample['json']
    sample['json'] = json.loads(raw.decode('utf-8'))
    return sample
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def filter_keys(key_set):
    """Return a callable that keeps only dict entries whose key is in *key_set*."""
    def _keep(dictionary):
        return {key: value for key, value in dictionary.items() if key in key_set}

    return _keep
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0):
    """Build a webdataset predicate on image geometry.

    Keeps samples whose aspect ratio (height / width) lies in
    [min_ratio, max_ratio] and whose longer side is at least *min_res* pixels.
    Reads ``original_height`` / ``original_width`` from the sample's ``json`` field.
    """
    def _keep(sample):
        meta = sample['json']
        h, w = meta['original_height'], meta['original_width']
        aspect = h / w
        return min_ratio <= aspect <= max_ratio and max(h, w) >= min_res

    return _keep
|
| 47 |
+
|
| 48 |
+
def calculate_laplacian_variance(image):
    """Calculate the variance of Laplacian which is a measure of image sharpness/blur."""
    arr = np.array(image)
    # Collapse RGB to a single luminance channel before filtering.
    if arr.ndim == 3:
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    else:
        gray = arr
    # Variance of the Laplacian response: high for sharp images, low for blurry.
    return cv2.Laplacian(gray, cv2.CV_64F).var()
|
| 62 |
+
|
| 63 |
+
# Add this function to map Laplacian values to token lengths
|
| 64 |
+
def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128):
    """
    Maps Laplacian values to token lengths using a bell curve approach.
    At the mean Laplacian value, uses mean_tokens.
    Values further from the mean get mapped to shorter/longer token lengths.
    """
    # Degenerate spread: cannot position the value, fall back to the mean budget.
    if std <= 0:
        return mean_tokens

    # tanh of the scaled z-score squashes into (-1, 1); shift/scale to (0, 1).
    # The factor 2.0 controls how quickly we saturate at min/max tokens.
    z = (laplacian_value - mean) / std
    position = 0.5 * (1 + math.tanh(2.0 * z))

    # Linearly place the position inside [min_tokens, max_tokens].
    return int(round(min_tokens + position * (max_tokens - min_tokens)))
|
| 87 |
+
|
| 88 |
+
# Add this function to map Laplacian values to token lengths
|
| 89 |
+
def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128):
    """
    Maps Laplacian values to token lengths using a linear mapping.

    Piecewise linear: ``laplacian_value=0`` maps to *min_tokens*, *mean* maps to
    *mean_tokens*, ``2 * mean`` maps to *max_tokens*; the result is clamped to
    ``[min_tokens, max_tokens]``.

    Note: *std* is only used as a degenerate-input guard, matching the v1 API.
    """
    # Prevent division by zero and handle edge cases.
    if std <= 0:
        return mean_tokens

    # (The unused intermediate `normalized` from the original was dead code
    # and has been removed; the piecewise branches below do the real mapping.)
    if laplacian_value <= mean:
        # Linear interpolation between min_tokens and mean_tokens.
        ratio = laplacian_value / mean
        token_length = min_tokens + (mean_tokens - min_tokens) * ratio
    else:
        # Linear interpolation between mean_tokens and max_tokens.
        ratio = (laplacian_value - mean) / mean  # how far past the mean
        token_length = mean_tokens + (max_tokens - mean_tokens) * ratio

    # Clamp to valid range.
    token_length = max(min_tokens, min(max_tokens, token_length))
    return int(round(token_length))
|
| 116 |
+
|
| 117 |
+
def get_laplacian_attention_mask(sample):
    """Process sample to add Laplacian variance and attention mask."""
    # Shallow-copy so the caller's dict is left untouched.
    out = dict(sample)

    # Sharpness-driven token budget.
    variance = calculate_laplacian_variance(out["image"])
    n_tokens = get_dynamic_length(variance)

    # Fixed-width mask: first n_tokens+1 positions attended, the rest masked out.
    mask = torch.zeros((128,), dtype=torch.float32)
    mask[:n_tokens + 1] = 1.0

    out["laplacian_var"] = variance
    out["attention_mask"] = mask
    return out
|
| 135 |
+
|
| 136 |
+
def get_uniform_attention_mask(min_tokens=32, max_tokens=128):
    """Process sample to add uniform random attention mask."""
    def _apply(dictionary):
        # Uniform draw over [min_tokens, max_tokens] inclusive.
        n_tokens = torch.randint(min_tokens, max_tokens + 1, (1,)).item()

        # The +1 keeps one extra visible slot; the slice silently clips at
        # max_tokens when n_tokens == max_tokens.
        mask = torch.zeros((max_tokens,), dtype=torch.float32)
        mask[:n_tokens + 1] = 1.0

        dictionary["attention_mask"] = mask
        return dictionary

    return _apply
|
| 150 |
+
|
| 151 |
+
def process_recap_text(p):
    """With probability *p*, clean the sample's recap caption.

    Strips a leading "The image depicts/displays/showcases/features/shows"
    prefix, re-capitalizes, and stores the cleaned bytes under ``text``.
    Samples without a ``recap_txt`` field pass through unchanged.
    """
    def _apply(dictionary):
        if "recap_txt" in dictionary:
            if random.random() < p:
                prefixes = ["The image " + verb for verb in ['depicts', "displays", 'showcases', 'features', 'shows']]
                # Convert input to string and strip whitespace.
                text = dictionary["recap_txt"].decode("utf-8").strip()
                for prefix in prefixes:
                    if text.startswith(prefix):
                        # Drop the prefix and surrounding whitespace.
                        text = text[len(prefix):].strip()
                        # Re-capitalize the first letter of what remains.
                        text = text[0].upper() + text[1:] if text else ""
                        break

                dictionary["text"] = text.encode("utf-8")
        return dictionary

    return _apply
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def identity(x):
    """Pass-through used where webdataset expects a transform callable."""
    return x
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class ImageTransform:
    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: List[float] = [0., 0., 0.],
                 normalize_std: List[float] = [1., 1., 1.]):
        """Initializes the WebDatasetReader with specified augmentation parameters.

        Args:
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
        """
        interpolation = transforms.InterpolationMode.BICUBIC

        # Train pipeline: resize -> (random|center) crop -> optional flip -> tensor -> normalize.
        pipeline = [transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True)]
        if random_crop:
            pipeline.append(transforms.RandomCrop(crop_size))
        else:
            pipeline.append(transforms.CenterCrop(crop_size))
        if random_flip:
            pipeline.append(transforms.RandomHorizontalFlip())
        pipeline.append(transforms.ToTensor())
        # mean [0,0,0] / std [1,1,1] keeps images in [0, 1];
        # mean [0.5,...] / std [0.5,...] maps them to [-1, 1].
        pipeline.append(transforms.Normalize(normalize_mean, normalize_std))
        self.train_transform = transforms.Compose(pipeline)

        # Eval always resizes to crop_size so results are comparable against
        # reference numbers on ImageNet etc.
        self.eval_transform = transforms.Compose(
            [
                transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std)
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class SimpleImageDataset:
|
| 230 |
+
def __init__(
|
| 231 |
+
self,
|
| 232 |
+
train_shards_path: Union[Text, List[Text]],
|
| 233 |
+
eval_shards_path: Union[Text, List[Text]],
|
| 234 |
+
num_train_examples: int,
|
| 235 |
+
per_gpu_batch_size: int,
|
| 236 |
+
global_batch_size: int,
|
| 237 |
+
num_workers_per_gpu: int = 12,
|
| 238 |
+
resize_shorter_edge: int = 256,
|
| 239 |
+
crop_size: int = 256,
|
| 240 |
+
random_crop = True,
|
| 241 |
+
random_flip = True,
|
| 242 |
+
normalize_mean: List[float] = [0., 0., 0.],
|
| 243 |
+
normalize_std: List[float] = [1., 1., 1.],
|
| 244 |
+
dataset_with_class_label: bool = True,
|
| 245 |
+
dataset_with_text_label: bool = False,
|
| 246 |
+
res_ratio_filtering = False,
|
| 247 |
+
min_tokens = 32,
|
| 248 |
+
max_tokens = 128,
|
| 249 |
+
):
|
| 250 |
+
"""Initializes the WebDatasetReader class.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
train_shards_path: A string or list of string, path to the training data shards in webdataset format.
|
| 254 |
+
eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
|
| 255 |
+
num_train_examples: An integer, total number of training examples.
|
| 256 |
+
per_gpu_batch_size: An integer, number of examples per GPU batch.
|
| 257 |
+
global_batch_size: An integer, total number of examples in a batch across all GPUs.
|
| 258 |
+
num_workers_per_gpu: An integer, number of workers per GPU.
|
| 259 |
+
resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
|
| 260 |
+
crop_size: An integer, the size to crop the input image to.
|
| 261 |
+
random_crop: A boolean, whether to use random crop augmentation during training.
|
| 262 |
+
random_flip: A boolean, whether to use random flipping augmentation during training.
|
| 263 |
+
normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
|
| 264 |
+
normalize_std: A list of float, the normalization std used to normalize the image tensor.
|
| 265 |
+
"""
|
| 266 |
+
transform = ImageTransform(
|
| 267 |
+
resize_shorter_edge, crop_size, random_crop, random_flip,
|
| 268 |
+
normalize_mean, normalize_std)
|
| 269 |
+
|
| 270 |
+
if dataset_with_class_label:
|
| 271 |
+
train_processing_pipeline = [
|
| 272 |
+
wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
|
| 273 |
+
wds.rename(
|
| 274 |
+
image="jpg;png;jpeg;webp",
|
| 275 |
+
class_id="cls",
|
| 276 |
+
handler=wds.warn_and_continue,
|
| 277 |
+
),
|
| 278 |
+
wds.map(filter_keys(set(["image", "class_id", "filename"]))),
|
| 279 |
+
wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
|
| 280 |
+
wds.map_dict(
|
| 281 |
+
image=transform.train_transform,
|
| 282 |
+
class_id=lambda x: int(x),
|
| 283 |
+
attention_mask=lambda x: x,
|
| 284 |
+
handler=wds.warn_and_continue,
|
| 285 |
+
),
|
| 286 |
+
]
|
| 287 |
+
elif dataset_with_text_label:
|
| 288 |
+
train_processing_pipeline = [
|
| 289 |
+
wds.map(load_json),
|
| 290 |
+
wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity),
|
| 291 |
+
wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue),
|
| 292 |
+
wds.rename(
|
| 293 |
+
image="jpg;png;jpeg;webp",
|
| 294 |
+
text="txt",
|
| 295 |
+
handler=wds.warn_and_continue,
|
| 296 |
+
),
|
| 297 |
+
wds.map(filter_keys(set(["image", "text", "__key__"]))),
|
| 298 |
+
wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
|
| 299 |
+
wds.map_dict(
|
| 300 |
+
image=transform.train_transform,
|
| 301 |
+
attention_mask=lambda x: x,
|
| 302 |
+
handler=wds.warn_and_continue,
|
| 303 |
+
),
|
| 304 |
+
]
|
| 305 |
+
else:
|
| 306 |
+
raise NotImplementedError
|
| 307 |
+
|
| 308 |
+
test_processing_pipeline = [
|
| 309 |
+
wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
|
| 310 |
+
wds.rename(
|
| 311 |
+
image="jpg;png;jpeg;webp",
|
| 312 |
+
class_id="cls",
|
| 313 |
+
handler=wds.warn_and_continue,
|
| 314 |
+
),
|
| 315 |
+
wds.map(filter_keys(set(["image", "class_id", "filename"]))),
|
| 316 |
+
wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
|
| 317 |
+
wds.map_dict(
|
| 318 |
+
image=transform.eval_transform,
|
| 319 |
+
class_id=lambda x: int(x),
|
| 320 |
+
# laplacian_var=lambda x: x,
|
| 321 |
+
attention_mask=lambda x: x,
|
| 322 |
+
handler=wds.warn_and_continue,
|
| 323 |
+
),
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
# Create train dataset and loader.
|
| 327 |
+
pipeline = [
|
| 328 |
+
wds.ResampledShards(train_shards_path),
|
| 329 |
+
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
| 330 |
+
wds.shuffle(bufsize=5000,
|
| 331 |
+
initial=1000),
|
| 332 |
+
*train_processing_pipeline,
|
| 333 |
+
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
| 334 |
+
]
|
| 335 |
+
|
| 336 |
+
num_batches = math.ceil(num_train_examples / global_batch_size)
|
| 337 |
+
num_worker_batches = math.ceil(num_train_examples /
|
| 338 |
+
(global_batch_size * num_workers_per_gpu))
|
| 339 |
+
num_batches = num_worker_batches * num_workers_per_gpu
|
| 340 |
+
num_samples = num_batches * global_batch_size
|
| 341 |
+
|
| 342 |
+
# Each worker is iterating over the complete dataset.
|
| 343 |
+
self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
|
| 344 |
+
self._train_dataloader = wds.WebLoader(
|
| 345 |
+
self._train_dataset,
|
| 346 |
+
batch_size=None,
|
| 347 |
+
shuffle=False,
|
| 348 |
+
num_workers=num_workers_per_gpu,
|
| 349 |
+
pin_memory=True,
|
| 350 |
+
persistent_workers=True,
|
| 351 |
+
)
|
| 352 |
+
# Add meta-data to dataloader instance for convenience.
|
| 353 |
+
self._train_dataloader.num_batches = num_batches
|
| 354 |
+
self._train_dataloader.num_samples = num_samples
|
| 355 |
+
|
| 356 |
+
# Create eval dataset and loader.
|
| 357 |
+
pipeline = [
|
| 358 |
+
wds.SimpleShardList(eval_shards_path),
|
| 359 |
+
wds.split_by_worker,
|
| 360 |
+
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
|
| 361 |
+
*test_processing_pipeline,
|
| 362 |
+
wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
|
| 363 |
+
]
|
| 364 |
+
self._eval_dataset = wds.DataPipeline(*pipeline)
|
| 365 |
+
self._eval_dataloader = wds.WebLoader(
|
| 366 |
+
self._eval_dataset,
|
| 367 |
+
batch_size=None,
|
| 368 |
+
shuffle=False,
|
| 369 |
+
num_workers=num_workers_per_gpu,
|
| 370 |
+
pin_memory=True,
|
| 371 |
+
persistent_workers=True,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
@property
|
| 375 |
+
def train_dataset(self):
|
| 376 |
+
return self._train_dataset
|
| 377 |
+
|
| 378 |
+
@property
|
| 379 |
+
def train_dataloader(self):
|
| 380 |
+
return self._train_dataloader
|
| 381 |
+
|
| 382 |
+
@property
|
| 383 |
+
def eval_dataset(self):
|
| 384 |
+
return self._eval_dataset
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def eval_dataloader(self):
|
| 388 |
+
return self._eval_dataloader
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class PretoeknizedDataSetJSONL(Dataset):
|
| 392 |
+
def __init__(self, data_path):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.jsonl_file = data_path
|
| 395 |
+
self.num_lines = sum(1 for _ in open(self.jsonl_file))
|
| 396 |
+
# Ensure the file is cached
|
| 397 |
+
linecache.checkcache(self.jsonl_file)
|
| 398 |
+
print("Number of data:", self.num_lines)
|
| 399 |
+
|
| 400 |
+
def __len__(self):
|
| 401 |
+
return self.num_lines
|
| 402 |
+
|
| 403 |
+
def __getitem__(self, idx):
|
| 404 |
+
line = linecache.getline(self.jsonl_file, idx + 1).strip()
|
| 405 |
+
data = json.loads(line)
|
| 406 |
+
return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class PretokenizedWebDataset(SimpleImageDataset):
|
| 410 |
+
def __init__ (
|
| 411 |
+
self,
|
| 412 |
+
train_shards_path: Union[Text, List[Text]],
|
| 413 |
+
eval_shards_path: Union[Text, List[Text]],
|
| 414 |
+
num_train_examples: int,
|
| 415 |
+
per_gpu_batch_size: int,
|
| 416 |
+
global_batch_size: int,
|
| 417 |
+
num_workers_per_gpu: int,
|
| 418 |
+
resize_shorter_edge: int = 256,
|
| 419 |
+
crop_size: int = 256,
|
| 420 |
+
random_crop = True,
|
| 421 |
+
random_flip = True,
|
| 422 |
+
normalize_mean: List[float] = [0., 0., 0.],
|
| 423 |
+
normalize_std: List[float] = [1., 1., 1.],
|
| 424 |
+
process_recap = False,
|
| 425 |
+
use_recap_prob = 0.95,
|
| 426 |
+
):
|
| 427 |
+
"""Initializes the PretokenizedWebDataset class.
|
| 428 |
+
|
| 429 |
+
Text-to-image datasets are pretokenized with careful filtering (Tab. 7 in Supp.) to speed up the training
|
| 430 |
+
"""
|
| 431 |
+
transform = ImageTransform(
|
| 432 |
+
resize_shorter_edge, crop_size, random_crop, random_flip,
|
| 433 |
+
normalize_mean, normalize_std)
|
| 434 |
+
|
| 435 |
+
def decode_npy(x):
|
| 436 |
+
arr = np.frombuffer(x, dtype=np.float16)
|
| 437 |
+
ret = torch.tensor(arr)
|
| 438 |
+
return ret
|
| 439 |
+
|
| 440 |
+
def decode_text(x):
|
| 441 |
+
ret = x.decode("utf-8")
|
| 442 |
+
return ret
|
| 443 |
+
|
| 444 |
+
train_processing_pipeline = [
|
| 445 |
+
wds.rename(
|
| 446 |
+
tokens="token.npy",
|
| 447 |
+
text="txt",
|
| 448 |
+
handler=wds.warn_and_continue,
|
| 449 |
+
),
|
| 450 |
+
wds.map(process_recap_text(use_recap_prob) if process_recap else wds.map(identity)),
|
| 451 |
+
wds.map(filter_keys(set(["tokens", "text", "aes_score", "__key__"]))),
|
| 452 |
+
wds.map_dict(
|
| 453 |
+
tokens=decode_npy,
|
| 454 |
+
text=decode_text,
|
| 455 |
+
handler=wds.warn_and_continue,
|
| 456 |
+
),
|
| 457 |
+
]
|
| 458 |
+
|
| 459 |
+
test_processing_pipeline = [
|
| 460 |
+
wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
|
| 461 |
+
wds.rename(
|
| 462 |
+
image="jpg;png;jpeg;webp",
|
| 463 |
+
handler=wds.warn_and_continue,
|
| 464 |
+
),
|
| 465 |
+
wds.map_dict(
|
| 466 |
+
image=transform.eval_transform,
|
| 467 |
+
handler=wds.warn_and_continue,
|
| 468 |
+
),
|
| 469 |
+
]
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# Create train dataset and loader.
|
| 473 |
+
pipeline = [
|
| 474 |
+
wds.ResampledShards(train_shards_path),
|
| 475 |
+
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
| 476 |
+
wds.shuffle(bufsize=5000,
|
| 477 |
+
initial=1000),
|
| 478 |
+
*train_processing_pipeline,
|
| 479 |
+
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
| 480 |
+
]
|
| 481 |
+
|
| 482 |
+
num_batches = math.ceil(num_train_examples / global_batch_size)
|
| 483 |
+
num_worker_batches = math.ceil(num_train_examples /
|
| 484 |
+
(global_batch_size * num_workers_per_gpu))
|
| 485 |
+
num_batches = num_worker_batches * num_workers_per_gpu
|
| 486 |
+
num_samples = num_batches * global_batch_size
|
| 487 |
+
|
| 488 |
+
# Each worker is iterating over the complete dataset.
|
| 489 |
+
self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
|
| 490 |
+
self._train_dataloader = wds.WebLoader(
|
| 491 |
+
self._train_dataset,
|
| 492 |
+
batch_size=None,
|
| 493 |
+
shuffle=False,
|
| 494 |
+
num_workers=num_workers_per_gpu,
|
| 495 |
+
pin_memory=True,
|
| 496 |
+
persistent_workers=True,
|
| 497 |
+
)
|
| 498 |
+
# Add meta-data to dataloader instance for convenience.
|
| 499 |
+
self._train_dataloader.num_batches = num_batches
|
| 500 |
+
self._train_dataloader.num_samples = num_samples
|
| 501 |
+
|
| 502 |
+
# Create eval dataset and loader.
|
| 503 |
+
pipeline = [
|
| 504 |
+
wds.SimpleShardList(eval_shards_path),
|
| 505 |
+
wds.split_by_worker,
|
| 506 |
+
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
|
| 507 |
+
*test_processing_pipeline,
|
| 508 |
+
wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
|
| 509 |
+
]
|
| 510 |
+
self._eval_dataset = wds.DataPipeline(*pipeline)
|
| 511 |
+
self._eval_dataloader = wds.WebLoader(
|
| 512 |
+
self._eval_dataset,
|
| 513 |
+
batch_size=None,
|
| 514 |
+
shuffle=False,
|
| 515 |
+
num_workers=num_workers_per_gpu,
|
| 516 |
+
pin_memory=True,
|
| 517 |
+
persistent_workers=True,
|
| 518 |
+
)
|
evaluator/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .evaluator import VQGANEvaluator
|
evaluator/evaluator.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluator for reconstruction results."""
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from typing import Sequence, Optional, Mapping, Text
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy import linalg
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from .inception import get_inception_model
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
|
| 15 |
+
"""Computes covariance of the input tensor.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
sigma: A torch.Tensor, sum of outer products of input features.
|
| 19 |
+
total: A torch.Tensor, sum of all input features.
|
| 20 |
+
num_examples: An integer, number of examples in the input tensor.
|
| 21 |
+
Returns:
|
| 22 |
+
A torch.Tensor, covariance of the input tensor.
|
| 23 |
+
"""
|
| 24 |
+
if num_examples == 0:
|
| 25 |
+
return torch.zeros_like(sigma)
|
| 26 |
+
|
| 27 |
+
sub_matrix = torch.outer(total, total)
|
| 28 |
+
sub_matrix = sub_matrix / num_examples
|
| 29 |
+
|
| 30 |
+
return (sigma - sub_matrix) / (num_examples - 1)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class VQGANEvaluator:
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
device,
|
| 37 |
+
enable_rfid: bool = True,
|
| 38 |
+
enable_inception_score: bool = True,
|
| 39 |
+
enable_codebook_usage_measure: bool = False,
|
| 40 |
+
enable_codebook_entropy_measure: bool = False,
|
| 41 |
+
num_codebook_entries: int = 1024
|
| 42 |
+
):
|
| 43 |
+
"""Initializes VQGAN Evaluator.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
device: The device to use for evaluation.
|
| 47 |
+
enable_rfid: A boolean, whether enabling rFID score.
|
| 48 |
+
enable_inception_score: A boolean, whether enabling Inception Score.
|
| 49 |
+
enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
|
| 50 |
+
enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
|
| 51 |
+
num_codebook_entries: An integer, the number of codebook entries.
|
| 52 |
+
"""
|
| 53 |
+
self._device = device
|
| 54 |
+
|
| 55 |
+
self._enable_rfid = enable_rfid
|
| 56 |
+
self._enable_inception_score = enable_inception_score
|
| 57 |
+
self._enable_codebook_usage_measure = enable_codebook_usage_measure
|
| 58 |
+
self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
|
| 59 |
+
self._num_codebook_entries = num_codebook_entries
|
| 60 |
+
|
| 61 |
+
# Variables related to Inception score and rFID.
|
| 62 |
+
self._inception_model = None
|
| 63 |
+
self._is_num_features = 0
|
| 64 |
+
self._rfid_num_features = 0
|
| 65 |
+
if self._enable_inception_score or self._enable_rfid:
|
| 66 |
+
self._rfid_num_features = 2048
|
| 67 |
+
self._is_num_features = 1008
|
| 68 |
+
self._inception_model = get_inception_model().to(self._device)
|
| 69 |
+
self._inception_model.eval()
|
| 70 |
+
self._is_eps = 1e-16
|
| 71 |
+
self._rfid_eps = 1e-6
|
| 72 |
+
|
| 73 |
+
self.reset_metrics()
|
| 74 |
+
|
| 75 |
+
def reset_metrics(self):
|
| 76 |
+
"""Resets all metrics."""
|
| 77 |
+
self._num_examples = 0
|
| 78 |
+
self._num_updates = 0
|
| 79 |
+
|
| 80 |
+
self._is_prob_total = torch.zeros(
|
| 81 |
+
self._is_num_features, dtype=torch.float64, device=self._device
|
| 82 |
+
)
|
| 83 |
+
self._is_total_kl_d = torch.zeros(
|
| 84 |
+
self._is_num_features, dtype=torch.float64, device=self._device
|
| 85 |
+
)
|
| 86 |
+
self._rfid_real_sigma = torch.zeros(
|
| 87 |
+
(self._rfid_num_features, self._rfid_num_features),
|
| 88 |
+
dtype=torch.float64, device=self._device
|
| 89 |
+
)
|
| 90 |
+
self._rfid_real_total = torch.zeros(
|
| 91 |
+
self._rfid_num_features, dtype=torch.float64, device=self._device
|
| 92 |
+
)
|
| 93 |
+
self._rfid_fake_sigma = torch.zeros(
|
| 94 |
+
(self._rfid_num_features, self._rfid_num_features),
|
| 95 |
+
dtype=torch.float64, device=self._device
|
| 96 |
+
)
|
| 97 |
+
self._rfid_fake_total = torch.zeros(
|
| 98 |
+
self._rfid_num_features, dtype=torch.float64, device=self._device
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self._set_of_codebook_indices = set()
|
| 102 |
+
self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)
|
| 103 |
+
|
| 104 |
+
def update(
|
| 105 |
+
self,
|
| 106 |
+
real_images: torch.Tensor,
|
| 107 |
+
fake_images: torch.Tensor,
|
| 108 |
+
codebook_indices: Optional[torch.Tensor] = None
|
| 109 |
+
):
|
| 110 |
+
"""Updates the metrics with the given images.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
real_images: A torch.Tensor, the real images.
|
| 114 |
+
fake_images: A torch.Tensor, the fake images.
|
| 115 |
+
codebook_indices: A torch.Tensor, the indices of the codebooks for each image.
|
| 116 |
+
|
| 117 |
+
Raises:
|
| 118 |
+
ValueError: If the fake images is not in RGB (3 channel).
|
| 119 |
+
ValueError: If the fake and real images have different shape.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
batch_size = real_images.shape[0]
|
| 123 |
+
dim = tuple(range(1, real_images.ndim))
|
| 124 |
+
self._num_examples += batch_size
|
| 125 |
+
self._num_updates += 1
|
| 126 |
+
|
| 127 |
+
if self._enable_inception_score or self._enable_rfid:
|
| 128 |
+
# Quantize to uint8 as a real image.
|
| 129 |
+
fake_inception_images = (fake_images * 255).to(torch.uint8)
|
| 130 |
+
features_fake = self._inception_model(fake_inception_images)
|
| 131 |
+
inception_logits_fake = features_fake["logits_unbiased"]
|
| 132 |
+
inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)
|
| 133 |
+
|
| 134 |
+
if self._enable_inception_score:
|
| 135 |
+
probabiliies_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)
|
| 136 |
+
|
| 137 |
+
log_prob = torch.log(inception_probabilities_fake + self._is_eps)
|
| 138 |
+
if log_prob.dtype != inception_probabilities_fake.dtype:
|
| 139 |
+
log_prob = log_prob.to(inception_probabilities_fake)
|
| 140 |
+
kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)
|
| 141 |
+
|
| 142 |
+
self._is_prob_total += probabiliies_sum
|
| 143 |
+
self._is_total_kl_d += kl_sum
|
| 144 |
+
|
| 145 |
+
if self._enable_rfid:
|
| 146 |
+
real_inception_images = (real_images * 255).to(torch.uint8)
|
| 147 |
+
features_real = self._inception_model(real_inception_images)
|
| 148 |
+
if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or
|
| 149 |
+
features_real['2048'].shape[1] != features_fake['2048'].shape[1]):
|
| 150 |
+
raise ValueError(f"Number of features should be equal for real and fake.")
|
| 151 |
+
|
| 152 |
+
for f_real, f_fake in zip(features_real['2048'], features_fake['2048']):
|
| 153 |
+
self._rfid_real_total += f_real
|
| 154 |
+
self._rfid_fake_total += f_fake
|
| 155 |
+
|
| 156 |
+
self._rfid_real_sigma += torch.outer(f_real, f_real)
|
| 157 |
+
self._rfid_fake_sigma += torch.outer(f_fake, f_fake)
|
| 158 |
+
|
| 159 |
+
if self._enable_codebook_usage_measure:
|
| 160 |
+
self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist())
|
| 161 |
+
|
| 162 |
+
if self._enable_codebook_entropy_measure:
|
| 163 |
+
entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
|
| 164 |
+
self._codebook_frequencies.index_add_(0, entries.int(), counts.double())
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def result(self) -> Mapping[Text, torch.Tensor]:
|
| 168 |
+
"""Returns the evaluation result."""
|
| 169 |
+
eval_score = {}
|
| 170 |
+
|
| 171 |
+
if self._num_examples < 1:
|
| 172 |
+
raise ValueError("No examples to evaluate.")
|
| 173 |
+
|
| 174 |
+
if self._enable_inception_score:
|
| 175 |
+
mean_probs = self._is_prob_total / self._num_examples
|
| 176 |
+
log_mean_probs = torch.log(mean_probs + self._is_eps)
|
| 177 |
+
if log_mean_probs.dtype != self._is_prob_total.dtype:
|
| 178 |
+
log_mean_probs = log_mean_probs.to(self._is_prob_total)
|
| 179 |
+
excess_entropy = self._is_prob_total * log_mean_probs
|
| 180 |
+
avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples
|
| 181 |
+
|
| 182 |
+
inception_score = torch.exp(avg_kl_d).item()
|
| 183 |
+
eval_score["InceptionScore"] = inception_score
|
| 184 |
+
|
| 185 |
+
if self._enable_rfid:
|
| 186 |
+
mu_real = self._rfid_real_total / self._num_examples
|
| 187 |
+
mu_fake = self._rfid_fake_total / self._num_examples
|
| 188 |
+
sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
|
| 189 |
+
sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)
|
| 190 |
+
|
| 191 |
+
mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
|
| 192 |
+
sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()
|
| 193 |
+
|
| 194 |
+
diff = mu_real - mu_fake
|
| 195 |
+
|
| 196 |
+
# Product might be almost singular.
|
| 197 |
+
covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
|
| 198 |
+
# Numerical error might give slight imaginary component.
|
| 199 |
+
if np.iscomplexobj(covmean):
|
| 200 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 201 |
+
m = np.max(np.abs(covmean.imag))
|
| 202 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 203 |
+
covmean = covmean.real
|
| 204 |
+
|
| 205 |
+
tr_covmean = np.trace(covmean)
|
| 206 |
+
|
| 207 |
+
if not np.isfinite(covmean).all():
|
| 208 |
+
tr_covmean = np.sum(np.sqrt((
|
| 209 |
+
(np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
|
| 210 |
+
/ (self._rfid_eps * self._rfid_eps)
|
| 211 |
+
))
|
| 212 |
+
|
| 213 |
+
rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
|
| 214 |
+
- 2 * tr_covmean
|
| 215 |
+
)
|
| 216 |
+
if torch.isnan(torch.tensor(rfid)) or torch.isinf(torch.tensor(rfid)):
|
| 217 |
+
warnings.warn("The product of covariance of train and test features is out of bounds.")
|
| 218 |
+
|
| 219 |
+
eval_score["rFID"] = rfid
|
| 220 |
+
|
| 221 |
+
if self._enable_codebook_usage_measure:
|
| 222 |
+
usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
|
| 223 |
+
eval_score["CodebookUsage"] = usage
|
| 224 |
+
|
| 225 |
+
if self._enable_codebook_entropy_measure:
|
| 226 |
+
probs = self._codebook_frequencies / self._codebook_frequencies.sum()
|
| 227 |
+
entropy = (-torch.log2(probs + 1e-8) * probs).sum()
|
| 228 |
+
eval_score["CodebookEntropy"] = entropy
|
| 229 |
+
|
| 230 |
+
return eval_score
|
evaluator/inception.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inception model for FID evaluation.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from torch_fidelity.feature_extractor_base import FeatureExtractorBase
|
| 10 |
+
from torch_fidelity.helpers import vassert
|
| 11 |
+
from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
|
| 12 |
+
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from torchvision.models.utils import load_state_dict_from_url
|
| 16 |
+
except ImportError:
|
| 17 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Note: Compared shasum and models should be the same.
|
| 21 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
| 22 |
+
|
| 23 |
+
class FeatureExtractorInceptionV3(FeatureExtractorBase):
|
| 24 |
+
INPUT_IMAGE_SIZE = 299
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
name,
|
| 29 |
+
features_list,
|
| 30 |
+
**kwargs,
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
InceptionV3 feature extractor for 2D RGB 24bit images.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
|
| 37 |
+
name (str): Unique name of the feature extractor, must be the same as used in
|
| 38 |
+
:func:`register_feature_extractor`.
|
| 39 |
+
|
| 40 |
+
features_list (list): A list of the requested feature names, which will be produced for each input. This
|
| 41 |
+
feature extractor provides the following features:
|
| 42 |
+
|
| 43 |
+
- '64'
|
| 44 |
+
- '192'
|
| 45 |
+
- '768'
|
| 46 |
+
- '2048'
|
| 47 |
+
- 'logits_unbiased'
|
| 48 |
+
- 'logits'
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
|
| 52 |
+
self.feature_extractor_internal_dtype = torch.float64
|
| 53 |
+
|
| 54 |
+
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
|
| 55 |
+
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
|
| 56 |
+
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
|
| 57 |
+
self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
|
| 58 |
+
|
| 59 |
+
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
|
| 60 |
+
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
|
| 61 |
+
self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
|
| 62 |
+
|
| 63 |
+
self.Mixed_5b = InceptionA(192, pool_features=32)
|
| 64 |
+
self.Mixed_5c = InceptionA(256, pool_features=64)
|
| 65 |
+
self.Mixed_5d = InceptionA(288, pool_features=64)
|
| 66 |
+
self.Mixed_6a = InceptionB(288)
|
| 67 |
+
self.Mixed_6b = InceptionC(768, channels_7x7=128)
|
| 68 |
+
self.Mixed_6c = InceptionC(768, channels_7x7=160)
|
| 69 |
+
self.Mixed_6d = InceptionC(768, channels_7x7=160)
|
| 70 |
+
self.Mixed_6e = InceptionC(768, channels_7x7=192)
|
| 71 |
+
|
| 72 |
+
self.Mixed_7a = InceptionD(768)
|
| 73 |
+
self.Mixed_7b = InceptionE_1(1280)
|
| 74 |
+
self.Mixed_7c = InceptionE_2(2048)
|
| 75 |
+
self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
| 76 |
+
|
| 77 |
+
self.fc = torch.nn.Linear(2048, 1008)
|
| 78 |
+
|
| 79 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
|
| 80 |
+
#state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
|
| 81 |
+
self.load_state_dict(state_dict)
|
| 82 |
+
|
| 83 |
+
self.to(self.feature_extractor_internal_dtype)
|
| 84 |
+
self.requires_grad_(False)
|
| 85 |
+
self.eval()
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
|
| 89 |
+
vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
|
| 90 |
+
features = {}
|
| 91 |
+
remaining_features = self.features_list.copy()
|
| 92 |
+
|
| 93 |
+
x = x.to(self.feature_extractor_internal_dtype)
|
| 94 |
+
# N x 3 x ? x ?
|
| 95 |
+
|
| 96 |
+
x = interpolate_bilinear_2d_like_tensorflow1x(
|
| 97 |
+
x,
|
| 98 |
+
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
|
| 99 |
+
align_corners=False,
|
| 100 |
+
)
|
| 101 |
+
# N x 3 x 299 x 299
|
| 102 |
+
|
| 103 |
+
# x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph
|
| 104 |
+
x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too
|
| 105 |
+
# N x 3 x 299 x 299
|
| 106 |
+
|
| 107 |
+
x = self.Conv2d_1a_3x3(x)
|
| 108 |
+
# N x 32 x 149 x 149
|
| 109 |
+
x = self.Conv2d_2a_3x3(x)
|
| 110 |
+
# N x 32 x 147 x 147
|
| 111 |
+
x = self.Conv2d_2b_3x3(x)
|
| 112 |
+
# N x 64 x 147 x 147
|
| 113 |
+
x = self.MaxPool_1(x)
|
| 114 |
+
# N x 64 x 73 x 73
|
| 115 |
+
|
| 116 |
+
if '64' in remaining_features:
|
| 117 |
+
features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
|
| 118 |
+
remaining_features.remove('64')
|
| 119 |
+
if len(remaining_features) == 0:
|
| 120 |
+
return features
|
| 121 |
+
|
| 122 |
+
x = self.Conv2d_3b_1x1(x)
|
| 123 |
+
# N x 80 x 73 x 73
|
| 124 |
+
x = self.Conv2d_4a_3x3(x)
|
| 125 |
+
# N x 192 x 71 x 71
|
| 126 |
+
x = self.MaxPool_2(x)
|
| 127 |
+
# N x 192 x 35 x 35
|
| 128 |
+
|
| 129 |
+
if '192' in remaining_features:
|
| 130 |
+
features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
|
| 131 |
+
remaining_features.remove('192')
|
| 132 |
+
if len(remaining_features) == 0:
|
| 133 |
+
return features
|
| 134 |
+
|
| 135 |
+
x = self.Mixed_5b(x)
|
| 136 |
+
# N x 256 x 35 x 35
|
| 137 |
+
x = self.Mixed_5c(x)
|
| 138 |
+
# N x 288 x 35 x 35
|
| 139 |
+
x = self.Mixed_5d(x)
|
| 140 |
+
# N x 288 x 35 x 35
|
| 141 |
+
x = self.Mixed_6a(x)
|
| 142 |
+
# N x 768 x 17 x 17
|
| 143 |
+
x = self.Mixed_6b(x)
|
| 144 |
+
# N x 768 x 17 x 17
|
| 145 |
+
x = self.Mixed_6c(x)
|
| 146 |
+
# N x 768 x 17 x 17
|
| 147 |
+
x = self.Mixed_6d(x)
|
| 148 |
+
# N x 768 x 17 x 17
|
| 149 |
+
x = self.Mixed_6e(x)
|
| 150 |
+
# N x 768 x 17 x 17
|
| 151 |
+
|
| 152 |
+
if '768' in remaining_features:
|
| 153 |
+
features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
|
| 154 |
+
remaining_features.remove('768')
|
| 155 |
+
if len(remaining_features) == 0:
|
| 156 |
+
return features
|
| 157 |
+
|
| 158 |
+
x = self.Mixed_7a(x)
|
| 159 |
+
# N x 1280 x 8 x 8
|
| 160 |
+
x = self.Mixed_7b(x)
|
| 161 |
+
# N x 2048 x 8 x 8
|
| 162 |
+
x = self.Mixed_7c(x)
|
| 163 |
+
# N x 2048 x 8 x 8
|
| 164 |
+
x = self.AvgPool(x)
|
| 165 |
+
# N x 2048 x 1 x 1
|
| 166 |
+
|
| 167 |
+
x = torch.flatten(x, 1)
|
| 168 |
+
# N x 2048
|
| 169 |
+
|
| 170 |
+
if '2048' in remaining_features:
|
| 171 |
+
features['2048'] = x
|
| 172 |
+
remaining_features.remove('2048')
|
| 173 |
+
if len(remaining_features) == 0:
|
| 174 |
+
return features
|
| 175 |
+
|
| 176 |
+
if 'logits_unbiased' in remaining_features:
|
| 177 |
+
x = x.mm(self.fc.weight.T)
|
| 178 |
+
# N x 1008 (num_classes)
|
| 179 |
+
features['logits_unbiased'] = x
|
| 180 |
+
remaining_features.remove('logits_unbiased')
|
| 181 |
+
if len(remaining_features) == 0:
|
| 182 |
+
return features
|
| 183 |
+
|
| 184 |
+
x = x + self.fc.bias.unsqueeze(0)
|
| 185 |
+
else:
|
| 186 |
+
x = self.fc(x)
|
| 187 |
+
# N x 1008 (num_classes)
|
| 188 |
+
|
| 189 |
+
features['logits'] = x
|
| 190 |
+
return features
|
| 191 |
+
|
| 192 |
+
@staticmethod
|
| 193 |
+
def get_provided_features_list():
|
| 194 |
+
return '64', '192', '768', '2048', 'logits_unbiased', 'logits'
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def get_default_feature_layer_for_metric(metric):
|
| 198 |
+
return {
|
| 199 |
+
'isc': 'logits_unbiased',
|
| 200 |
+
'fid': '2048',
|
| 201 |
+
'kid': '2048',
|
| 202 |
+
'prc': '2048',
|
| 203 |
+
}[metric]
|
| 204 |
+
|
| 205 |
+
    @staticmethod
    def can_be_compiled():
        """Whether this extractor supports being wrapped by torch.compile (always True)."""
        return True
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
def get_dummy_input_for_compile():
|
| 211 |
+
return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
|
| 212 |
+
|
| 213 |
+
def get_inception_model():
    """Instantiate the InceptionV3 feature extractor configured for FID/ISC evaluation.

    Requests the '2048' pooled features and the 'logits_unbiased' head.
    """
    return FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
|
examples/batch_inference.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Batch inference example for VibeToken.
|
| 3 |
+
|
| 4 |
+
Demonstrates how to process multiple images efficiently in batches.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# Auto mode (recommended)
|
| 8 |
+
python examples/batch_inference.py --auto \
|
| 9 |
+
--config configs/vibetoken_ll.yaml \
|
| 10 |
+
--checkpoint path/to/checkpoint.bin \
|
| 11 |
+
--input_dir path/to/images/ \
|
| 12 |
+
--output_dir path/to/output/ \
|
| 13 |
+
--batch_size 4
|
| 14 |
+
|
| 15 |
+
# Manual mode
|
| 16 |
+
python examples/batch_inference.py \
|
| 17 |
+
--config configs/vibetoken_ll.yaml \
|
| 18 |
+
--checkpoint path/to/checkpoint.bin \
|
| 19 |
+
--input_dir path/to/images/ \
|
| 20 |
+
--output_dir path/to/output/ \
|
| 21 |
+
--batch_size 4 \
|
| 22 |
+
--resolution 512 \
|
| 23 |
+
--encoder_patch_size 16,32 \
|
| 24 |
+
--decoder_patch_size 16
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from PIL import Image
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
import sys
|
| 36 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 37 |
+
|
| 38 |
+
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: None, a single int string (e.g. '16'), or an 'H,W' pair
            (e.g. '16,32').

    Returns:
        None, an int, or a 2-tuple of ints (H, W).

    Raises:
        ValueError: if the string is not a valid int or 'H,W' pair
            (previously '16,32,64' silently dropped the extra parts and
            '16,' crashed with an opaque int() error).
    """
    if value is None:
        return None
    if ',' in value:
        parts = [p.strip() for p in value.split(',')]
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"patch size must be a single int or an 'H,W' pair, got {value!r}"
            )
        return (int(parts[0]), int(parts[1]))
    return int(value)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple:
    """Load an image from disk and prepare it for the tokenizer.

    Args:
        path: Path to the image file.
        target_size: Optional (width, height) to resize to (manual mode only).
        auto_mode: If True, delegate cropping and patch-size selection to
            auto_preprocess_image.

    Returns:
        A (image_array, patch_size, info) triple; in manual mode both
        patch_size and info are None.
    """
    loaded = Image.open(path).convert("RGB")

    if auto_mode:
        # Centralized helper chooses the crop and the matching patch size.
        cropped, patch_size, info = auto_preprocess_image(loaded, verbose=False)
        return np.array(cropped), patch_size, info

    if target_size:
        loaded = loaded.resize(target_size, Image.LANCZOS)
    # The model requires spatial dimensions divisible by 32.
    loaded = center_crop_to_multiple(loaded, multiple=32)
    return np.array(loaded), None, None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def main():
    """Run batched VibeToken reconstruction over a directory of images.

    Auto mode processes images one at a time (each image may get its own
    crop/patch size); manual mode resizes everything to a uniform resolution
    and reconstructs in batches of --batch_size.
    """
    parser = argparse.ArgumentParser(description="VibeToken batch inference example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings per image")

    # Manual mode options
    parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    args = parser.parse_args()

    # Parse patch sizes (None means "let the tokenizer pick its default")
    encoder_patch_size = parse_patch_size(args.encoder_patch_size)
    decoder_patch_size = parse_patch_size(args.decoder_patch_size)

    # Check CUDA
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )

    if args.auto:
        print("Running in AUTO MODE - optimal settings determined per image")
    else:
        print(f"Running in MANUAL MODE - resolution: {args.resolution}")
        if encoder_patch_size:
            print(f"  Encoder patch size: {encoder_patch_size}")
        if decoder_patch_size:
            print(f"  Decoder patch size: {decoder_patch_size}")

    # Find all images (non-recursive; only direct children of input_dir)
    input_dir = Path(args.input_dir)
    image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions]
    print(f"Found {len(image_paths)} images")

    if not image_paths:
        print("No images found!")
        return

    # Process in batches
    target_size = (args.resolution, args.resolution) if not args.auto else None
    total_time = 0
    num_processed = 0

    if args.auto:
        # AUTO MODE: Process images one by one since each may have different sizes
        for i, path in enumerate(image_paths):
            try:
                img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True)
                batch_array = img_array[np.newaxis, ...]  # Add batch dim

                start_time = time.time()

                # Reconstruct with auto-determined patch size.
                # info['cropped_size'] is (W, H), so indices are swapped here.
                height, width = info['cropped_size'][1], info['cropped_size'][0]
                reconstructed = tokenizer.reconstruct(
                    batch_array,
                    encode_patch_size=patch_size,
                    decode_patch_size=patch_size,
                    target_height=height,
                    target_width=width,
                )

                # Synchronize so the timing below covers the actual GPU work.
                if args.device == "cuda":
                    torch.cuda.synchronize()

                batch_time = time.time() - start_time
                total_time += batch_time
                num_processed += 1

                # Save output
                output_images = tokenizer.to_pil(reconstructed)
                output_path = output_dir / f"{path.stem}_recon.png"
                output_images[0].save(output_path)

                print(f"[{i+1}/{len(image_paths)}] {path.name}: "
                      f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, "
                      f"patch_size={patch_size}, {batch_time:.2f}s")

            except Exception as e:
                # Best-effort batch job: log and keep going on a bad image.
                print(f"Error processing {path}: {e}")
                continue
    else:
        # MANUAL MODE: Batch processing with uniform size
        for batch_start in range(0, len(image_paths), args.batch_size):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]
            batch_names = [p.stem for p in batch_paths]

            # Load batch
            batch_images = []
            for path in batch_paths:
                try:
                    img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False)
                    batch_images.append(img_array)
                except Exception as e:
                    print(f"Error loading {path}: {e}")
                    continue

            if not batch_images:
                continue

            # Stack into batch tensor (all images share target_size here)
            batch_array = np.stack(batch_images, axis=0)

            # Measure time
            start_time = time.time()

            # Reconstruct
            reconstructed = tokenizer.reconstruct(
                batch_array,
                encode_patch_size=encoder_patch_size,
                decode_patch_size=decoder_patch_size,
                target_height=args.resolution,
                target_width=args.resolution,
            )

            # Synchronize if GPU (so batch_time reflects the real compute)
            if args.device == "cuda":
                torch.cuda.synchronize()

            batch_time = time.time() - start_time
            total_time += batch_time
            num_processed += len(batch_images)

            # Save outputs. NOTE(review): names are matched to outputs by
            # position; assumes reconstruct/to_pil preserve input order
            # and count -- confirm against VibeTokenTokenizer.
            output_images = tokenizer.to_pil(reconstructed)
            for name, img in zip(batch_names[:len(output_images)], output_images):
                output_path = output_dir / f"{name}_recon.png"
                img.save(output_path)

            print(f"Processed batch {batch_start // args.batch_size + 1}: "
                  f"{len(batch_images)} images in {batch_time:.2f}s "
                  f"({len(batch_images) / batch_time:.2f} img/s)")

    # Summary
    if num_processed > 0:
        print(f"\nTotal: {num_processed} images in {total_time:.2f}s")
        print(f"Average: {num_processed / total_time:.2f} images/sec")
        print(f"Per image: {total_time / num_processed * 1000:.1f}ms")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
examples/encode_decode.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Basic encode-decode example for VibeToken.
|
| 3 |
+
|
| 4 |
+
Demonstrates how to:
|
| 5 |
+
1. Load the tokenizer from config and checkpoint
|
| 6 |
+
2. Encode an image to discrete tokens
|
| 7 |
+
3. Decode tokens back to an image
|
| 8 |
+
4. Save the reconstructed image
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# Auto mode (recommended)
|
| 12 |
+
python examples/encode_decode.py --auto \
|
| 13 |
+
--config configs/vibetoken_ll.yaml \
|
| 14 |
+
--checkpoint path/to/checkpoint.bin \
|
| 15 |
+
--image path/to/image.jpg \
|
| 16 |
+
--output reconstructed.png
|
| 17 |
+
|
| 18 |
+
# Manual mode
|
| 19 |
+
python examples/encode_decode.py \
|
| 20 |
+
--config configs/vibetoken_ll.yaml \
|
| 21 |
+
--checkpoint path/to/checkpoint.bin \
|
| 22 |
+
--image path/to/image.jpg \
|
| 23 |
+
--output reconstructed.png \
|
| 24 |
+
--encoder_patch_size 16,32 \
|
| 25 |
+
--decoder_patch_size 16
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from PIL import Image
|
| 33 |
+
|
| 34 |
+
import sys
|
| 35 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 36 |
+
|
| 37 |
+
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: None, a single int string (e.g. '16'), or an 'H,W' pair
            (e.g. '16,32').

    Returns:
        None, an int, or a 2-tuple of ints (H, W).

    Raises:
        ValueError: if the string is not a valid int or 'H,W' pair
            (previously '16,32,64' silently dropped the extra parts and
            '16,' crashed with an opaque int() error).
    """
    if value is None:
        return None
    if ',' in value:
        parts = [p.strip() for p in value.split(',')]
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"patch size must be a single int or an 'H,W' pair, got {value!r}"
            )
        return (int(parts[0]), int(parts[1]))
    return int(value)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
    """Encode one image to discrete tokens, decode it back, and report PSNR.

    Auto mode lets auto_preprocess_image pick the crop and patch size;
    manual mode center-crops to a multiple of 32 and uses the CLI-provided
    patch sizes / output dimensions.
    """
    parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings")

    parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
    parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")

    args = parser.parse_args()

    # Check if CUDA is available
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )
    print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
          f"num_latent_tokens={tokenizer.num_latent_tokens}")

    # Load image
    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size  # (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE - use centralized auto_preprocess_image
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        # Encoder and decoder share the auto-selected patch size.
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        # info['cropped_size'] is (W, H); swap into (height, width).
        height, width = info['cropped_size'][1], info['cropped_size'][0]
        print("=================\n")

        # Encode to tokens.
        # NOTE(review): --num_tokens is ignored in this branch (only the
        # manual branch passes it to encode) -- confirm this is intended.
        print("Encoding image to tokens...")
        print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
        print(f"Token shape: {tokens.shape}")

        # Decode back to image
        print(f"Decoding tokens to image ({width}x{height})...")
        print(f"  Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    else:
        # MANUAL MODE
        # Parse patch sizes
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Always center crop to ensure dimensions divisible by 32
        image = center_crop_to_multiple(image, multiple=32)
        cropped_size = image.size  # (W, H)
        if cropped_size != original_size:
            print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")

        # Encode to tokens
        print("Encoding image to tokens...")
        if encoder_patch_size:
            print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
        print(f"Token shape: {tokens.shape}")

        # Token layout differs by quantizer: mvq adds a codebook axis.
        if tokenizer.model.quantize_mode == "mvq":
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Num codebooks: {tokens.shape[1]}")
            print(f"  - Sequence length: {tokens.shape[2]}")
        else:
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Sequence length: {tokens.shape[1]}")

        # Decode back to image (use cropped size as default)
        height = args.height or cropped_size[1]
        width = args.width or cropped_size[0]
        print(f"Decoding tokens to image ({width}x{height})...")
        if decoder_patch_size:
            print(f"  Using decoder patch size: {decoder_patch_size}")

        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Convert to PIL and save
    output_images = tokenizer.to_pil(reconstructed)
    output_path = Path(args.output)
    output_images[0].save(output_path)
    print(f"Saved reconstructed image to {output_path}")

    # Compute PSNR (compare with cropped image). Only computed when shapes
    # match, i.e. the output was not resized away from the cropped input.
    import numpy as np
    original_np = np.array(image).astype(np.float32)
    recon_np = np.array(output_images[0]).astype(np.float32)
    if original_np.shape == recon_np.shape:
        mse = np.mean((original_np - recon_np) ** 2)
        # mse == 0 would mean a perfect (infinite-PSNR) reconstruction;
        # skip the log in that case.
        if mse > 0:
            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
            print(f"PSNR: {psnr:.2f} dB")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
generate.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from:
|
| 2 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
|
| 3 |
+
|
| 4 |
+
"""Example run:
|
| 5 |
+
python generate.py \
|
| 6 |
+
--gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
|
| 7 |
+
--gpt-model GPT-XXL --num-output-layer 4 \
|
| 8 |
+
--num-codebooks 8 --codebook-size 32768 \
|
| 9 |
+
--image-size 256 --cfg-scale 2.0 --top-k 0 --temperature 1.0 \
|
| 10 |
+
--class-dropout-prob 0.1 \
|
| 11 |
+
--extra-layers "QKV" \
|
| 12 |
+
--latent-size 65 \
|
| 13 |
+
--config ./configs/vibetoken_ll.yaml \
|
| 14 |
+
--vq-ckpt ./checkpoints/VibeToken_LL.bin \
|
| 15 |
+
--sample-dir ./assets/ \
|
| 16 |
+
--skip-folder-creation \
|
| 17 |
+
--compile \
|
| 18 |
+
--decoder-patch-size 16,16 \
|
| 19 |
+
--target-resolution 1024,1024 \
|
| 20 |
+
--llamagen-target-resolution 256,256 \
|
| 21 |
+
--precision bf16
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 27 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 28 |
+
torch.set_float32_matmul_precision('high')
|
| 29 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
|
| 30 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import torch.distributed as dist
|
| 33 |
+
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
import os
|
| 36 |
+
from PIL import Image
|
| 37 |
+
import numpy as np
|
| 38 |
+
import math
|
| 39 |
+
import argparse
|
| 40 |
+
import sys
|
| 41 |
+
from omegaconf import OmegaConf
|
| 42 |
+
|
| 43 |
+
from vibetokengen.model import GPT_models
|
| 44 |
+
from vibetokengen.generate import generate
|
| 45 |
+
|
| 46 |
+
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Build a single .npz file from a folder of numbered .png samples.

    Args:
        sample_dir: Directory containing images named 000000.png, 000001.png, ...
        num: Number of samples expected in the folder.

    Returns:
        Path of the written .npz file (f"{sample_dir}.npz").

    Raises:
        ValueError: if the stacked samples are not `num` RGB (H, W, 3) images.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # Explicit check instead of `assert`, which is silently stripped under
    # `python -O` and would let a malformed array reach the FID evaluator.
    if samples.shape != (num, samples.shape[1], samples.shape[2], 3):
        raise ValueError(f"Expected {num} RGB samples, got array of shape {samples.shape}")
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def main(args):
    """Sample a fixed set of class-conditional images with the GPT prior and
    decode them with the VibeToken tokenizer into one saved image grid.

    Loads the tokenizer and GPT model, restores GPT weights from a checkpoint
    (handling FSDP/DDP/DeepSpeed/state_dict layouts), optionally compiles the
    model, autoregressively samples token indices for 8 fixed ImageNet
    classes, decodes them, and writes a 2x4 grid PNG to the sample folder.
    """
    # Setup PyTorch:
    # NOTE(review): `assert` is stripped under `python -O`; this GPU check
    # would then silently pass on CPU-only machines.
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set global seed for reproducibility
    torch.manual_seed(args.global_seed)
    np.random.seed(args.global_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.global_seed)
        torch.cuda.manual_seed_all(args.global_seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Map the CLI precision flag to a torch dtype (raises KeyError on unknown).
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]

    # Load VibeToken model
    vq_model = VibeTokenTokenizer.from_config(
        args.config,
        args.vq_ckpt,
        device=device,
        dtype=precision,
    )
    print(f"VibeToken image tokenizer is loaded")

    # create and load gpt model
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=args.latent_size,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        num_codebooks=args.num_codebooks,
        n_output_layer=args.num_output_layer,
        class_dropout_prob=args.class_dropout_prob,
        extra_layers=args.extra_layers,
        capping=args.capping,
    ).to(device=device, dtype=precision)
    print(f"GPT model is loaded")

    # Restore weights; checkpoints differ in layout by training backend.
    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu", weights_only=False)
    if args.from_fsdp:  # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint:  # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight, maybe add --from-fsdp to run command")
    gpt_model.load_state_dict(model_weight, strict=True)
    gpt_model.eval()
    # Free the CPU-side checkpoint copy before sampling.
    del checkpoint

    print(f"GPT model weights are loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print(f"no model compile")

    print(f"GPT model is compiled")

    # Create folder to save samples:
    model_string_name = args.gpt_model.replace("/", "-")
    if args.from_fsdp:
        # FSDP checkpoints live in a per-run directory; use its name.
        ckpt_string_name = args.gpt_ckpt.split('/')[-2]
    else:
        ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    folder_name = f"{model_string_name}-{ckpt_string_name}-target-resolution-{args.target_resolution}-llamagen-target-resolution-{args.llamagen_target_resolution}-vibetoken-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    if args.skip_folder_creation:
        sample_folder_dir = args.sample_dir
    else:
        sample_folder_dir = f"{args.sample_dir}/{folder_name}"

    os.makedirs(sample_folder_dir, exist_ok=True)
    print(f"Saving .png samples at {sample_folder_dir}")

    # Classifier-free guidance doubles the batch (cond + uncond passes).
    multiplier = 2 if args.cfg_scale > 1.0 else 1

    # Use fixed class labels
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    c_indices = torch.tensor(class_labels, device=device)
    n = len(class_labels)
    nrow = 4  # 2 rows x 4 columns for 8 images

    # NOTE(review): 1536 appears to be the normalization base for the
    # resolution conditioning fed to the GPT -- confirm against the
    # vibetokengen.generate implementation.
    index_sample = generate(
        gpt_model, c_indices, args.latent_size, args.num_codebooks,
        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
        target_h=torch.tensor(args.llamagen_target_resolution[0]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
        target_w=torch.tensor(args.llamagen_target_resolution[1]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
        temperature=args.temperature, top_k=args.top_k,
        top_p=args.top_p, sample_logits=True,
    )

    # Use VibeToken decode_tokens method
    # VibeToken expects tokens in shape (batch_size, seq_len, 1)
    index_sample = index_sample.unsqueeze(2)
    samples = vq_model.decode(
        index_sample,
        height=args.target_resolution[0],
        width=args.target_resolution[1],
        patch_size=args.decoder_patch_size
    )

    # VibeToken output is in [0, 1] range, clamp and convert to uint8
    samples = torch.clamp(samples, 0, 1)

    # Create a grid of images (2 rows x 4 columns)
    from torchvision.utils import make_grid
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)

    # Convert to PIL and save (CHW float -> HWC uint8)
    grid_np = (grid.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype('uint8')
    Image.fromarray(grid_np).save(f"{sample_folder_dir}/generated_images.png")
    print(f"Saved grid of {n} images to {sample_folder_dir}/generated_images.png")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
    # CLI entry point: parse all generation options and hand them to main().
    parser = argparse.ArgumentParser()
    # --- GPT (autoregressive generator) options ---
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i",
                        help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    # NOTE(review): action='store_true' combined with default=True means passing
    # --compile is a no-op and compilation can never be disabled from the CLI;
    # consider a --no-compile flag if that is unintended.
    parser.add_argument("--compile", action='store_true', default=True)
    # parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    # --- VQ tokenizer options ---
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--config", type=str, required=True, help="Path to VibeToken config file")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    # --- sampling / output options ---
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50000)
    parser.add_argument("--global-seed", type=int, default=0) # not used
    parser.add_argument("--top-k", type=int, default=500, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    parser.add_argument("--num-codebooks", type=int, default=1)
    parser.add_argument("--num-output-layer", type=int, default=1)
    parser.add_argument("--class-dropout-prob", type=float, default=0.1)
    parser.add_argument("--extra-layers", type=str, choices=['QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip'], default=None,
                        help="Type of extra layers to add: QK (query-key), QKV (query-key-value), FC (fully connected), cap (caption), clip (clip), QK_cap (query-key-caption), QKV_cap (query-key-value-caption), QK_clip (query-key-clip), QKV_clip (query-key-value-clip), QK_FC_cap (query-key-fully-connected-caption), QKV_FC_cap (query-key-value-fully-connected-caption), QK_FC_clip (query-key-fully-connected-clip), QKV_FC_clip (query-key-value-fully-connected-clip)")
    parser.add_argument("--capping", type=float, default=50.0, help="Capping for attention softmax")

    # VibeToken dynamic
    parser.add_argument("--decoder-patch-size", type=str, default="8,8", help="Decoder patch size as 'width,height'")
    parser.add_argument("--target-resolution", type=str, default="256,256", help="Target resolution as 'width,height'")
    parser.add_argument("--llamagen-target-resolution", type=str, default="256,256", help="LlamaGen target resolution as 'width,height'")

    parser.add_argument("--latent-size", type=int, default=16, help="Latent size")
    parser.add_argument("--skip-folder-creation", action='store_true', default=False, help="skip folder creation")

    args = parser.parse_args()

    # Convert the "W,H" string arguments into integer tuples before main() uses them.
    args.decoder_patch_size = tuple(map(int, args.decoder_patch_size.split(",")))
    args.target_resolution = tuple(map(int, args.target_resolution.split(",")))
    args.llamagen_target_resolution = tuple(map(int, args.llamagen_target_resolution.split(",")))

    main(args)
|
generator/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generator module placeholder for VibeToken-Gen integration."""
|
| 2 |
+
|
| 3 |
+
# Future: Add GPT-based generator for image synthesis
|
| 4 |
+
# from .gpt import VibeTokenGenerator
|
modeling/__init__.py
ADDED
|
File without changes
|
modeling/modules/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_model import BaseModel
|
| 2 |
+
from .ema_model import EMAModel
|
| 3 |
+
from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, ReconstructionLoss_Single_Stage
|
| 4 |
+
from .blocks import TiTokEncoder, TiTokDecoder, TATiTokDecoder, UViTBlock
|
| 5 |
+
from .maskgit_vqgan import Decoder as Pixel_Decoder
|
| 6 |
+
from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
|
modeling/modules/base_model.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base class implementation for models.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union, Callable, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseModel(torch.nn.Module):
    """Minimal base class adding checkpoint save/load and parameter counting
    on top of ``torch.nn.Module``."""

    def __init__(self):
        super().__init__()

    def save_pretrained_weight(
        self,
        save_directory: Union[str, os.PathLike],
        save_function: Callable = None,
        state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Save the model's weights to ``save_directory/pytorch_model.bin``.

        Args:
            save_directory: Target directory; created if it does not exist.
            save_function: Callable used to persist the state dict. Defaults
                to ``torch.save`` (override e.g. on TPU setups).
            state_dict: State dict to persist; defaults to ``self.state_dict()``.
        """
        # A file path is not a valid target; warn and bail out without raising.
        if os.path.isfile(save_directory):
            print(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        if state_dict is None:
            state_dict = self.state_dict()
        weights_name = "pytorch_model.bin"
        target_path = os.path.join(save_directory, weights_name)

        save_function(state_dict, target_path)

        print(f"Model weights saved in {target_path}")

    def load_pretrained_weight(
        self,
        pretrained_model_path: Union[str, os.PathLike],
        strict_loading: bool = True,
        torch_dtype: Optional[torch.dtype] = None
    ):
        r"""Load weights from a checkpoint file or a directory containing one.

        The model is left in evaluation mode (``self.eval()``) after loading.

        Args:
            pretrained_model_path: Path to a weights file, or to a directory
                containing ``pytorch_model.bin``.
            strict_loading: Passed through to ``load_state_dict``.
            torch_dtype: Optional dtype to cast the model to after loading.

        Raises:
            ValueError: If the path does not exist, or ``torch_dtype`` is not
                a ``torch.dtype``.
        """
        # Resolve the actual weights file from either a file or directory path.
        if os.path.isfile(pretrained_model_path):
            model_file = pretrained_model_path
        elif os.path.isdir(pretrained_model_path):
            pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
            if not os.path.isfile(pretrained_model_path):
                raise ValueError(f"{pretrained_model_path} does not exist")
            model_file = pretrained_model_path
        else:
            raise ValueError(f"{pretrained_model_path} does not exist")

        # NOTE(review): torch.load unpickles the checkpoint; only load files
        # from trusted sources (state dicts load fine under weights_only=True).
        checkpoint = torch.load(model_file, map_location="cpu")
        msg = self.load_state_dict(checkpoint, strict=strict_loading)
        print(f"loading weight from {model_file}, msg: {msg}")

        if torch_dtype is not None:
            if not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
                )
            self.to(torch_dtype)

        # Default to inference mode so Dropout & co. are disabled.
        self.eval()

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """Count the module's parameters.

        Args:
            only_trainable: Count only parameters with ``requires_grad=True``.
            exclude_embeddings: Skip ``nn.Embedding`` weight parameters.

        Returns:
            The total number of (selected) parameter elements.
        """
        def _tally(params) -> int:
            return sum(p.numel() for p in params if p.requires_grad or not only_trainable)

        if not exclude_embeddings:
            return _tally(self.parameters())

        embedding_weight_names = {
            f"{name}.weight"
            for name, submodule in self.named_modules()
            if isinstance(submodule, torch.nn.Embedding)
        }
        return _tally(
            param for name, param in self.named_parameters()
            if name not in embedding_weight_names
        )
|
| 124 |
+
|
modeling/modules/blocks.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transformer building blocks.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
|
| 5 |
+
https://github.com/baofff/U-ViT/blob/main/libs/timm.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.utils.checkpoint import checkpoint
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
+
import einops
|
| 14 |
+
from einops.layers.torch import Rearrange
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def modulate(x, shift, scale):
    """Affine (AdaLN-style) modulation: returns ``x * (1 + scale) + shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: self-attention plus an optional MLP, each
    wrapped in a residual connection. Expects sequence-first input (L, N, D),
    the default layout of ``nn.MultiheadAttention``."""

    def __init__(
            self,
            d_model,
            n_head,
            mlp_ratio = 4.0,
            act_layer = nn.GELU,
            norm_layer = nn.LayerNorm
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio
        # A non-positive mlp_ratio disables the feed-forward sub-block entirely.
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden_width = int(d_model * mlp_ratio)
            self.mlp = nn.Sequential(OrderedDict([
                ("c_fc", nn.Linear(d_model, hidden_width)),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(hidden_width, d_model))
            ]))

    def attention(
            self,
            x: torch.Tensor
    ):
        """Self-attend x to itself; attention weights are not materialized."""
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(
            self,
            x: torch.Tensor,
    ):
        x = x + self.attention(x=self.ln_1(x))
        if self.mlp_ratio > 0:
            x = x + self.mlp(self.ln_2(x))
        return x
|
| 60 |
+
|
| 61 |
+
# Select the attention backend once at import time:
#   'flash'    -> torch's fused scaled_dot_product_attention (preferred)
#   'xformers' -> xformers memory-efficient attention, if installed
#   'math'     -> plain softmax(QK^T)V fallback
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    # Fix: was a bare `except:`, which would also swallow KeyboardInterrupt /
    # SystemExit; only a failed import should trigger the fallback.
    except ImportError:
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Attention(nn.Module):
    """Multi-head self-attention whose compute path is chosen by the
    module-level ``ATTENTION_MODE`` ('flash' -> torch SDPA, 'xformers',
    or 'math' fallback).

    Args:
        dim: Embedding dimension of the input (last axis).
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_scale: Optional override for the softmax scale; defaults to
            ``head_dim ** -0.5`` (only used by the 'math' path).
        attn_drop: Dropout on attention weights ('math' path only).
        proj_drop: Dropout after the output projection.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """x: (B, L, C) -> (B, L, C)."""
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            # NOTE(review): .float() upcasts q/k/v to fp32 here — presumably
            # for SDPA numerical stability; confirm intended under fp16/bf16.
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # Fix: `raise NotImplemented` raised a TypeError at runtime because
            # NotImplemented is a sentinel value, not an exception class.
            raise NotImplementedError

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero out whole samples on a residual branch.

    Each sample in the batch is independently dropped with probability
    ``drop_prob``; survivors are rescaled by ``1 / (1 - drop_prob)`` so the
    expectation is unchanged. A no-op at eval time or with ``drop_prob == 0``.
    (Same mechanism as 'DropConnect' in EfficientNet-style nets; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for
    the naming discussion.)
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample; shape (B, 1, 1, ...) broadcasts over the
    # remaining dims, so this works for any tensor rank, not just 2D ConvNets.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize to {0, 1}
    return x.div(keep_prob) * mask
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample);
    drop is active only while the module is in training mode."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # NOTE(review): a None drop_prob would fail inside drop_path when
        # training; presumably callers always pass a float — confirm.
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward network (linear -> activation -> dropout ->
    linear -> dropout). Hidden and output widths default to the input width."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # A single Dropout module is applied after both layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class UViTBlock(nn.Module):
    """U-ViT transformer block: pre-LN self-attention and MLP residual
    branches (each with optional stochastic depth), an optional long skip
    connection fused via concat + linear, and optional activation
    checkpointing."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity when disabled.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        # U-ViT long skip: fuse [x, skip] (2*dim) back down to dim.
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if not self.use_checkpoint:
            return self._forward(x, skip)
        # Trade compute for memory: recompute activations in backward.
        return torch.utils.checkpoint.checkpoint(self._forward, x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _expand_token(token, batch_size: int):
|
| 194 |
+
return token.unsqueeze(0).expand(batch_size, -1, -1)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class TiTokEncoder(nn.Module):
    """ViT encoder that maps an image plus a set of learnable latent tokens to
    compact latent-token embeddings (TiTok-style 1D tokenization).

    The image is patch-embedded, prefixed with a class token, concatenated
    with the latent tokens, run through a transformer, and only the latent
    positions are kept and projected to ``token_size`` channels.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Geometry: square crop of image_size, split into a grid of
        # (image_size // patch_size)^2 patches.
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_enc_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_enc_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size

        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
            self.token_size = self.token_size * 2 # needs to split into mean and std

        self.is_legacy = config.model.vq_model.get("is_legacy", True)

        # Preset width / depth / head-count for each named model size.
        self.width = {
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Non-overlapping patch embedding via strided convolution.
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True)

        # Init scale follows the 1/sqrt(width) convention.
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        # One positional embedding per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)
        # 1x1 conv projecting transformer width down to token_size channels.
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

    def forward(self, pixel_values, latent_tokens):
        """Encode ``pixel_values`` (B, 3, H, W) with the given latent tokens.

        Returns a tensor of shape (B, token_size, 1, num_latent_tokens).
        """
        batch_size = pixel_values.shape[0]
        x = pixel_values
        x = self.patch_embed(x)
        # Flatten the spatial grid into a token sequence.
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width]


        # Append the (broadcast) latent tokens, with their own positions,
        # after the image tokens.
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2) # LND -> NLD

        # Keep only the latent-token positions (drop cls + image tokens).
        latent_tokens = x[:, 1+self.grid_size**2:]
        latent_tokens = self.ln_post(latent_tokens)
        # fake 2D shape
        if self.is_legacy:
            # NOTE(review): this reshape reinterprets (B, N, width) memory as
            # (B, width, N, 1) without a transpose, mixing the token and
            # channel axes; the else-branch below ("Fix legacy problem")
            # suggests this is kept deliberately for checkpoint compatibility
            # — confirm before changing.
            latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
        else:
            # Fix legacy problem.
            latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
        return latent_tokens
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class TiTokDecoder(nn.Module):
    """ViT decoder that reconstructs an image-feature grid (or RGB pixels)
    from quantized latent tokens (TiTok-style).

    Mask tokens stand in for every image patch position; the latent tokens
    are appended to them, the sequence is run through a transformer, and the
    patch positions are read back out as a 2D feature map.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Geometry mirrors the encoder but uses the decoder patch size.
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_dec_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_dec_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.is_legacy = config.model.vq_model.get("is_legacy", True)
        # Preset width / depth / head-count per named model size.
        self.width = {
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Project quantized tokens (token_size) up to transformer width.
        self.decoder_embed = nn.Linear(
            self.token_size, self.width, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        # add mask token and query pos embed
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            # Legacy head outputs a 1024-channel feature map (consumed by a
            # separate pixel decoder elsewhere).
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Directly predicting RGB pixels
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
                Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                    p1 = self.patch_size, p2 = self.patch_size),)
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def forward(self, z_quantized):
        """Decode quantized latents of shape (N, C, 1, num_latent_tokens).

        Returns a (N, 1024, grid, grid) feature map in legacy mode, or RGB
        pixels (N, 3, image_size, image_size) otherwise.
        """
        N, C, H, W = z_quantized.shape
        assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        # One mask token per patch position, prefixed with the class token,
        # plus positional embeddings — these are the decoder "queries".
        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                 mask_tokens], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
        # Sliced to seq_len so fewer latent tokens than configured still work.
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2) # LND -> NLD
        x = x[:, 1:1+self.grid_size**2] # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
        x = self.ffn(x.contiguous())
        x = self.conv_out(x)
        return x
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class TATiTokDecoder(TiTokDecoder):
    """Text-aligned TiTok decoder.

    Extends `TiTokDecoder` by appending projected text-guidance embeddings to
    the transformer input sequence, so decoding is conditioned on text.
    """

    def __init__(self, config):
        super().__init__(config)
        scale = self.width ** -0.5
        # Text-guidance sequence length and embedding width; defaults of
        # 77/768 presumably match a CLIP text encoder — TODO confirm.
        self.text_context_length = config.model.vq_model.get("text_context_length", 77)
        self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width))

    def forward(self, z_quantized, text_guidance):
        """Decode quantized latents conditioned on text guidance.

        Args:
            z_quantized: Tensor of shape (N, C, 1, num_latent_tokens).
            text_guidance: Tensor of shape (N, text_context_length,
                text_embed_dim) with text embeddings used as extra tokens.

        Returns:
            Output of `self.ffn` followed by `self.conv_out`.
        """
        N, C, H, W = z_quantized.shape
        assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                mask_tokens], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        # Text tokens are appended after mask/latent tokens: they influence
        # attention but are discarded when the grid is sliced out below.
        text_guidance = self.text_guidance_proj(text_guidance)
        text_guidance = text_guidance + self.text_guidance_positional_embedding
        x = torch.cat([x, text_guidance], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2) # NLD -> LND
        for i in range(self.num_layers):
            x = self.transformer[i](x)
        x = x.permute(1, 0, 2) # LND -> NLD
        x = x[:, 1:1+self.grid_size**2] # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
        x = self.ffn(x.contiguous())
        x = self.conv_out(x)
        return x
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class WeightTiedLMHead(nn.Module):
    """LM head whose projection weights are tied to an embedding table.

    No new parameters are created; logits are computed against the first
    `target_codebook_size` rows of the shared embedding matrix, so the head
    can address a sub-vocabulary of the tied table.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        # Share the embedding matrix directly (weight tying).
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        """Project hidden states to logits over the target codebook.

        Args:
            x: Hidden states of shape [batch_size, seq_len, embed_dim].

        Returns:
            Logits of shape [batch_size, seq_len, target_codebook_size].
        """
        tied_rows = self.weight[:self.target_codebook_size]
        return x @ tied_rows.t()
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations.

    A fixed sinusoidal encoding of the timestep is passed through a small
    SiLU MLP to produce an embedding of width `hidden_size`.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            # Pad odd widths with a zero column so the output is exactly `dim` wide.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class ResBlock(nn.Module):
    """Adaptive-LayerNorm residual MLP block (adaLN-Zero style).

    The conditioning vector produces shift/scale/gate terms that modulate a
    LayerNorm followed by a two-layer SiLU MLP, applied as a gated residual.

    :param channels: the number of input (and output) channels.
    """

    def __init__(
        self,
        channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        # Produces 3 * channels values: shift, scale and residual gate.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = modulate(self.in_ln(x), shift, scale)
        hidden = self.mlp(hidden)
        return x + gate * hidden
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class FinalLayer(nn.Module):
    """
    The final layer adopted from DiT.

    Applies a non-affine LayerNorm modulated by the conditioning vector,
    then projects to `out_channels`.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        # Produces 2 * model_channels values: shift and scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        res_blocks = []
        for i in range(num_res_blocks):
            res_blocks.append(ResBlock(
                model_channels,
            ))

        self.res_blocks = nn.ModuleList(res_blocks)
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linears, then zero-init adaLN and output layers so the
        network starts as (near) identity, as in DiT's adaLN-Zero scheme."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        # Combined conditioning (timestep + AR condition) drives all adaLN blocks.
        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Trade compute for memory during training; checkpointing is
            # skipped under TorchScript, which does not support it.
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free-guidance forward pass.

        Expects the batch to be [conditional half; unconditional half] with
        identical inputs in both halves; only the first `in_channels` output
        channels (eps) are guided, the rest are passed through unchanged.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
modeling/modules/discriminator.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Discriminator implementation."""
|
| 2 |
+
import functools
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from .maskgit_vqgan import Conv2dSame
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BlurBlock(torch.nn.Module):
    """Anti-aliased 2x spatial downsampling with a fixed blur kernel.

    The 1-D `kernel` is outer-multiplied with itself, normalized to sum to 1,
    and applied depthwise (one kernel per channel) with stride 2 using
    "same"-style asymmetric padding.
    """

    def __init__(self,
                 kernel: Tuple[int] = (1, 3, 3, 1)
                 ):
        super().__init__()

        taps = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        blur_2d = taps[None, :] * taps[:, None]
        blur_2d = blur_2d / blur_2d.sum()
        # Stored as a (1, 1, kh, kw) buffer so it follows .to()/.cuda() moves.
        self.register_buffer("kernel", blur_2d.unsqueeze(0).unsqueeze(0))

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total padding needed for "same" output size at stride `s`."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ic, ih, iw = x.size()[-3:]
        # NOTE: padding is computed for a 4-tap kernel regardless of the
        # actual kernel length passed at construction.
        pad_h = self.calc_same_pad(i=ih, k=4, s=2)
        pad_w = self.calc_same_pad(i=iw, k=4, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        # Depthwise conv: replicate the blur kernel across all channels.
        depthwise_weight = self.kernel.expand(ic, -1, -1, -1)
        return F.conv2d(input=x, weight=depthwise_weight, stride=2, groups=x.shape[1])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class NLayerDiscriminator(torch.nn.Module):
    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4
    ):
        """ Initializes the NLayerDiscriminator.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"

        # Per-stage channel multipliers: (1, 1, 2, 4, ...); channels double
        # from stage 1 onward.
        in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
        init_kernel_size = 5
        activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)

        self.block_in = torch.nn.Sequential(
            Conv2dSame(
                num_channels,
                hidden_channels,
                kernel_size=init_kernel_size
            ),
            activation(),
        )

        # 1-D blur taps keyed by kernel size; BlurBlock builds the 2-D kernel.
        BLUR_KERNEL_MAP = {
            3: (1,2,1),
            4: (1,3,3,1),
            5: (1,4,6,4,1),
        }

        discriminator_blocks = []
        for i_level in range(num_stages):
            in_channels = hidden_channels * in_channel_mult[i_level]
            out_channels = hidden_channels * in_channel_mult[i_level + 1]
            # Each stage: 3x3 conv -> 2x downsample (avg-pool or blur) ->
            # GroupNorm -> LeakyReLU.
            block = torch.nn.Sequential(
                Conv2dSame(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                ),
                torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
                torch.nn.GroupNorm(32, out_channels),
                activation(),
            )
            discriminator_blocks.append(block)

        self.blocks = torch.nn.ModuleList(discriminator_blocks)

        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))

        # NOTE: `out_channels` here is the final loop iteration's value, i.e.
        # the channel count of the last stage.
        self.to_logits = torch.nn.Sequential(
            Conv2dSame(out_channels, out_channels, 1),
            activation(),
            Conv2dSame(out_channels, 1, kernel_size=5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The output tensor (patch logits of shape
            (N, 1, 16, 16) after adaptive pooling).
        """
        hidden_states = self.block_in(x)
        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.pool(hidden_states)

        return self.to_logits(hidden_states)
|
modeling/modules/ema_model.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""EMA (Exponential Moving Average) model.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
from typing import Any, Iterable, Optional, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EMAModel:
    """Exponential Moving Average of models weights."""
    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        update_every: int = 1,
        current_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        **model_config_kwargs
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            update_every (int): The number of steps between each EMA update.
            current_step (int): The current training step.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.

        notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        parameters = list(parameters)
        # Shadow copies hold the EMA-averaged weights, detached from autograd.
        self.shadow_params = [p.clone().detach() for p in parameters]
        # Scratch storage used by store()/restore() round-trips.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.update_every = update_every
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = current_step
        self.cur_decay_value = None  # set in `step()`

        # Optional model class + kwargs enable from_pretrained/save_pretrained.
        self.model_cls = model_cls
        self.model_config_kwargs = model_config_kwargs

    @classmethod
    def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel":
        """Build an EMAModel whose shadow params are initialized from a
        checkpoint loaded via `model_cls.load_pretrained_weight`."""
        model = model_cls(**model_config_kwargs)
        model.load_pretrained_weight(checkpoint)

        ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """Materialize a model with the EMA weights and save it to `path`."""
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config_kwargs is None:
            raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.")

        model = self.model_cls(**self.model_config_kwargs)
        self.copy_to(model.parameters())
        model.save_pretrained_weight(path)

    def set_step(self, optimization_step: int):
        self.optimization_step = optimization_step

    def get_decay(self, optimization_step: int) -> float:
        """Computes the decay factor for the exponential moving average."""
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # Make sure decay is not smaller than min_decay.
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """Advance one optimization step and (every `update_every` steps)
        blend the live `parameters` into the shadow params."""
        parameters = list(parameters)

        self.optimization_step += 1

        # Skip the actual EMA update except every `update_every`-th step.
        if (self.optimization_step - 1) % self.update_every != 0:
            return

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                # In-place EMA: s = s - (1 - decay) * (s - p).
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                # Non-trainable tensors are mirrored directly.
                s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """Copies current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Moves internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Args:
            Save the current parameters for restoring later.
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""Restores the parameters stored with the `store` method. Useful to validate
        the model with EMA parameters without affecting the original optimization process.
        Store the parameters before the `copy_to()` method. After validation (or
        model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # Deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
|
modeling/modules/encoder_decoder.py
ADDED
|
@@ -0,0 +1,1142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Encoder and decoder building blocks for VibeToken.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
|
| 5 |
+
https://github.com/baofff/U-ViT/blob/main/libs/timm.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.utils.checkpoint import checkpoint
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
import einops
|
| 15 |
+
from einops.layers.torch import Rearrange
|
| 16 |
+
from typing import Optional, Sequence, Tuple, Union
|
| 17 |
+
from modeling.modules.fuzzy_embedding import FuzzyEmbedding
|
| 18 |
+
import collections.abc
|
| 19 |
+
from itertools import repeat
|
| 20 |
+
from typing import Any
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from torch import vmap
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
|
def to_2tuple(x: Any) -> Tuple:
    """Coerce *x* into a 2-tuple.

    Non-string iterables are converted element-wise with ``tuple(x)``;
    everything else (including strings) is duplicated into ``(x, x)``.
    """
    is_sequence = isinstance(x, collections.abc.Iterable) and not isinstance(x, str)
    return tuple(x) if is_sequence else tuple(repeat(x, 2))
| 31 |
+
|
class PatchMixture():
    """Magnitude-biased token dropping ("patch mixture").

    ``get_mask`` samples a per-sample permutation of the token axis in which a
    ``mask_ratio`` fraction of tokens is marked as dropped; ``start_route``
    extracts the kept tokens in permutation order, and ``end_route`` scatters
    processed tokens back to their original positions, filling dropped slots
    either from ``original_x`` or with the scalar ``mask_token``.
    """

    def __init__(self, seed=42):
        # NOTE(review): `seed` is stored but never used; masking relies on
        # torch's global RNG state — confirm whether seeding was intended.
        self.seed = seed

    def get_mask(self, x, mask_ratio=0.0, l1_reg=0.0, inverse=False):
        """Build a drop mask for ``x`` of shape (B, N, D).

        Tokens are ranked by ``noise = (1 - l1_reg) * U[0,1) + l1_reg * m``,
        where ``m`` is the per-token L1 magnitude min-max normalized within
        each sample (inverted when ``inverse``); the ``mask_ratio`` fraction
        with the largest noise is dropped, so ``l1_reg > 0`` biases dropping
        toward high-magnitude tokens.

        Returns a dict with:
            mask: (B, N) bool, True where the token is dropped
            ids_keep / ids_mask: indices of kept / dropped tokens
            ids_shuffle / ids_restore: the permutation and its inverse
        """
        batch_size, num_patches, _ = x.shape
        device = x.device
        num_mask = int(num_patches * mask_ratio)
        num_keep = num_patches - num_mask
        # Per-token L1 magnitude, min-max normalized per sample (eps avoids
        # division by zero when all tokens have equal magnitude).
        token_magnitudes = x.abs().sum(dim=-1)
        min_mags = token_magnitudes.min(dim=1, keepdim=True)[0]
        max_mags = token_magnitudes.max(dim=1, keepdim=True)[0]
        token_magnitudes = (token_magnitudes - min_mags) / (max_mags - min_mags + 1e-8)
        adjusted_magnitudes = 1.0 - token_magnitudes if inverse else token_magnitudes
        # Blend uniform noise with the magnitude prior; argsort is ascending,
        # so tokens with the smallest noise land in the "keep" prefix.
        noise_random = torch.rand(batch_size, num_patches, device=device)
        noise = (1.0 - l1_reg) * noise_random + l1_reg * adjusted_magnitudes
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_keep]
        ids_mask = ids_shuffle[:, num_keep:]
        mask = torch.ones((batch_size, num_patches), device=device, dtype=torch.bool)
        mask.scatter_(1, ids_keep, False)
        return {
            'mask': mask,
            'ids_keep': ids_keep,
            'ids_mask': ids_mask,
            'ids_shuffle': ids_shuffle,
            'ids_restore': ids_restore
        }

    def start_route(self, x, mask_info):
        """Return the kept tokens of ``x`` (B, N, D) in shuffled order,
        shape (B, num_keep, D)."""
        ids_shuffle = mask_info['ids_shuffle']
        num_keep = mask_info['ids_keep'].size(1)
        x_shuffled = x.gather(1, ids_shuffle.unsqueeze(-1).expand(-1, -1, x.size(2)))
        return x_shuffled[:, :num_keep, :]

    def end_route(self, masked_x, mask_info, original_x=None, mask_token=0.0):
        """Invert :meth:`start_route`: place ``masked_x`` (B, num_keep, D)
        back at the kept positions; dropped positions are copied from
        ``original_x`` when given, otherwise filled with ``mask_token``."""
        batch_size, num_patches = mask_info['mask'].shape
        num_keep = masked_x.size(1)
        dim = masked_x.size(2)
        device = masked_x.device
        ids_restore = mask_info['ids_restore']
        # Fix: match masked_x's dtype — torch.empty's float32 default silently
        # up/down-cast mixed-precision activations on the round trip.
        x_unshuffled = torch.empty((batch_size, num_patches, dim), device=device, dtype=masked_x.dtype)
        x_unshuffled[:, :num_keep, :] = masked_x
        if original_x is not None:
            x_shuffled = original_x.gather(1, mask_info['ids_shuffle'].unsqueeze(-1).expand(-1, -1, dim))
            x_unshuffled[:, num_keep:, :] = x_shuffled[:, num_keep:, :]
        else:
            x_unshuffled[:, num_keep:, :].fill_(mask_token)
        return x_unshuffled.gather(1, ids_restore.unsqueeze(-1).expand(-1, -1, dim))
| 89 |
+
|
class ResizableBlur(nn.Module):
    """
    Single-parameter anti-aliasing layer with a learnable blur kernel.

    One depth-wise K x K kernel (shared across channels) is stored at the
    maximum size; at call time it is bilinearly resized to a size matched to
    the requested downscale factor, re-normalized to sum to 1, and applied as
    a strided depth-wise convolution, with a final bilinear interpolation to
    hit the exact target size.
    Call with scale=1,2,4 to downsample by 1x (identity), 2x, or 4x.
    """
    def __init__(self, channels: int,
                 max_kernel_size: int = 9,
                 init_type: str = "gaussian"):
        # channels: number of input/output channels (kernel is depth-wise).
        # max_kernel_size: largest (odd) spatial size the kernel is stored at.
        # init_type: "gaussian" or "lanczos" initialization of the base kernel.
        super().__init__()
        self.C = channels
        K = max_kernel_size  # e.g. 9 for 4x
        assert K % 2 == 1, "kernel must be odd"

        # ----- initialise the largest kernel ---------------------------------
        if init_type == "gaussian":
            # 2-D separable Gaussian, sigma ~ K/6 (covers +-3 sigma within K).
            ax = torch.arange(-(K//2), K//2 + 1)
            g1d = torch.exp(-0.5 * (ax / (K/6.0))**2)
            g2d = torch.outer(g1d, g1d)
            kernel = g2d / g2d.sum()
        elif init_type == "lanczos":
            a = K//2  # window size parameter
            x = torch.arange(-a, a+1).float()
            # Normalized sinc with the 0/0 singularity patched to 1.
            sinc = lambda t: torch.where(t==0, torch.ones_like(t), torch.sin(torch.pi*t)/(torch.pi*t))
            k1d = sinc(x) * sinc(x/a)
            k2d = torch.outer(k1d, k1d)
            kernel = k2d / k2d.sum()
        else:
            raise ValueError("unknown init_type")

        # learnable base kernel (shape 1x1xKxK); expanded per-channel at call time
        self.weight = nn.Parameter(kernel.unsqueeze(0).unsqueeze(0))

    # ------------------------------------------------------------------------
    @staticmethod
    def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor:
        """
        Bilinearly interpolate weight (B,C,H,W) to target k_size x k_size,
        then L1-normalise over spatial dims so the kernel sums to 1
        (preserves mean brightness after filtering).
        """
        if weight.shape[-1] != k_size:
            weight = F.interpolate(weight, size=(k_size, k_size),
                                   mode="bilinear", align_corners=True)
        # clamp guards against a degenerate all-negative/zero learned kernel
        weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8)
        return weight

    # ------------------------------------------------------------------------
    def forward(self, x: torch.Tensor, input_size, target_size) -> torch.Tensor:
        """Blur-and-downsample `x` from `input_size` (H, W) to `target_size`.

        x is expected to be (B, C, H, W) with C == self.C (depth-wise conv
        uses groups=self.C). Returns a tensor of spatial size `target_size`.
        """
        # Unpack input and target dimensions
        input_h, input_w = input_size
        target_h, target_w = target_size

        # Calculate scale factors for height and width
        scale_h = input_h / target_h
        scale_w = input_w / target_w

        # Determine kernel size based on scale factors
        # Larger scale factors need larger kernels for better anti-aliasing
        # (capped at the stored kernel's size).
        k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3)))
        k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3)))

        # Make sure kernel sizes are odd
        k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1
        k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1

        # Use the maximum for a square kernel, or create a rectangular kernel if needed
        k_size = max(k_size_h, k_size_w)

        # Calculate appropriate stride and padding
        # NOTE(review): padding is derived from the per-axis sizes while the
        # kernel itself is square (k_size) — confirm this asymmetry is intended.
        stride_h = max(1, round(scale_h))
        stride_w = max(1, round(scale_w))
        pad_h = k_size_h // 2
        pad_w = k_size_w // 2

        # Get the kernel and normalize it
        k = self._resize_and_normalise(self.weight, k_size)  # (1,1,k,k)
        k = k.repeat(self.C, 1, 1, 1)  # depth-wise
        
        # Apply convolution with calculated parameters
        result = F.conv2d(x, weight=k, stride=(stride_h, stride_w),
                          padding=(pad_h, pad_w), groups=self.C)

        # If the convolution didn't get us exactly to the target size, use interpolation for fine adjustment
        if result.shape[2:] != target_size:
            result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True)

        return result
| 177 |
+
|
def modulate(x, shift, scale):
    """AdaLN-style modulation: scale ``x`` by ``1 + scale`` and add ``shift``
    (both broadcast against ``x``)."""
    return shift + x * (scale + 1)
| 180 |
+
|
| 181 |
+
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an
    optional MLP, each wrapped in a residual connection.

    Passing ``mlp_ratio <= 0`` omits the feed-forward sub-layer entirely.
    Inputs follow ``nn.MultiheadAttention``'s default (L, N, D) layout.
    """

    def __init__(
            self,
            d_model,
            n_head,
            mlp_ratio = 4.0,
            act_layer = nn.GELU,
            norm_layer = nn.LayerNorm
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio
        # The FFN is optional; submodule names (c_fc / c_proj) are kept
        # CLIP-style for checkpoint compatibility.
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden_width = int(d_model * mlp_ratio)
            self.mlp = nn.Sequential(OrderedDict([
                ("c_fc", nn.Linear(d_model, hidden_width)),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(hidden_width, d_model))
            ]))

    def attention(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None
    ):
        """Self-attend ``x`` (queries, keys, and values are all ``x``)."""
        out, _ = self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)
        return out

    def forward(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None
    ):
        x = x + self.attention(x=self.ln_1(x), attention_mask=attention_mask)
        if self.mlp_ratio > 0:
            x = x + self.mlp(self.ln_2(x))
        return x
| 223 |
+
|
# Select the attention backend once at import time: prefer PyTorch's fused
# scaled_dot_product_attention, fall back to xformers if installed, and
# finally to a plain matmul + softmax implementation.
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except ImportError:
        # Fix: the original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt; only a missing/broken xformers import should
        # trigger the math fallback.
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')
| 234 |
+
|
| 235 |
+
|
class Attention(nn.Module):
    """Multi-head self-attention whose kernel is chosen by the module-level
    ATTENTION_MODE flag (torch SDPA -> xformers -> plain matmul softmax).

    Input and output are (B, L, C) with C = num_heads * head_dim.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Softmax temperature; only applied explicitly in the 'math' path
        # (SDPA and xformers scale internally).
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # used only by the 'math' path
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            # NOTE: .float() upcasts q/k/v, so this path computes (and its
            # intermediate) in float32 regardless of the input dtype.
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # Fix: `raise NotImplemented` raised a TypeError (NotImplemented is
            # a sentinel value, not an exception class).
            raise NotImplementedError(f'unknown attention mode {ATTENTION_MODE!r}')

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
| 274 |
+
|
| 275 |
+
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth ("drop path"): with probability *drop_prob*, zero an
    entire sample of the batch; survivors are rescaled by 1/keep_prob so the
    expectation is unchanged.

    A no-op at inference time or when drop_prob == 0. The Bernoulli mask is
    drawn once per sample with shape (B, 1, ..., 1), so it broadcasts over
    all trailing dimensions. (Same scheme as the DropConnect-derived impl
    popularized by EfficientNet; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for
    the naming discussion.)
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # per-sample mask, any rank
    # rand in [0, 1): adding keep_prob then flooring yields 1 w.p. keep_prob.
    keep_mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * keep_mask
| 294 |
+
|
| 295 |
+
|
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample).

    Fix: the previous default of ``drop_prob=None`` crashed in training mode
    (``1 - None`` inside :func:`drop_path`); ``None`` is now normalized to
    0.0 so ``DropPath()`` is a safe identity.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Treat None as "disabled" rather than propagating it into arithmetic.
        self.drop_prob = drop_prob if drop_prob is not None else 0.

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
| 305 |
+
|
| 306 |
+
|
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> activation -> Linear) with a
    shared dropout applied after each linear layer.

    Falsy ``hidden_features`` / ``out_features`` (None or 0) default to the
    input width.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| 324 |
+
|
| 325 |
+
|
class UViTBlock(nn.Module):
    """U-ViT transformer block: pre-norm attention and MLP sub-layers with
    stochastic depth on both residual branches, plus an optional long-skip
    fusion (concat then linear projection) used on the decoder half of a
    U-shaped ViT. Supports activation checkpointing."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when the rate is zero.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)
        # Long-skip fusion: concat(x, skip) along channels, project back to dim.
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if not self.use_checkpoint:
            return self._forward(x, skip)
        # Trade compute for memory: recompute activations in backward.
        return torch.utils.checkpoint.checkpoint(self._forward, x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def _expand_token(token, batch_size: int):
|
| 357 |
+
return token.unsqueeze(0).expand(batch_size, -1, -1)
|
| 358 |
+
|
| 359 |
+
|
class ResolutionEncoder(nn.Module):
    """ViT-style encoder that compresses an image into ``num_latent_tokens``
    1-D latent tokens (TiTok-style), with resolution flexibility.

    Two mechanisms make it resolution-flexible:
      * FlexiViT-style patch embedding: the base conv kernel is resampled to
        arbitrary patch sizes via a cached pseudo-inverse resize matrix, so
        the effective patch size can be chosen per call.
      * An optional "patch mixture" stage that drops a fraction of patch
        tokens between ``patch_mixture_start_layer`` and
        ``patch_mixture_end_layer`` to cap sequence length on large grids.

    Output shape: (B, token_size, 1, num_latent_tokens).
    """
    def __init__(self, config):
        # config: OmegaConf-style nested config; only the fields read below
        # are required.
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_enc_patch_size
        self.model_size = config.model.vq_model.vit_enc_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
        # Default 100 (> any layer index) effectively disables patch mixture.
        self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
        self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)

        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
            self.token_size = self.token_size * 2  # needs to split into mean and std

        self.is_legacy = config.model.vq_model.get("is_legacy", True)

        # Width / depth / heads looked up from the named model size.
        self.width = {
            "tiny": 256,
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        # Base patchify conv; its kernel is resampled at call time for other
        # patch sizes (see apply_flexivit_patch_embed).
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True)

        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))

        # Grid-size-aware positional embedding (project-local module; takes
        # grid_height/grid_width at call time).
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)

        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)

        self.patch_mixture = PatchMixture()

        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))

        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
        # Cache of pseudo-inverse resize matrices, keyed by target patch size.
        self.pinvs = {}

    def apply_flexivit_patch_embed(self, x, target_patch_size):
        """Patchify `x` with the embedding kernel resampled to
        `target_patch_size` (no resampling when it equals the base size)."""
        patch_size = to_2tuple(target_patch_size)

        # Resize conv weights
        if patch_size == to_2tuple(self.patch_size):
            weight = self.patch_embed.weight
        else:
            weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)

        # Apply conv with resized weights
        x = F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)
        return x

    def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
        """Bilinearly resize a 2-D tensor to `shape` (adds/strips the fake
        batch and channel dims F.interpolate requires)."""
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode="bilinear",
            antialias=False,
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self, old_shape: Tuple[int, int], new_shape: Tuple[int, int], device=None
    ) -> Tensor:
        """Build the pseudo-inverse of the linear map "bilinear resize from
        old_shape to new_shape", one basis vector at a time (FlexiViT-style
        kernel resampling)."""
        # Use the device from patch_embed weights if available
        if device is None and hasattr(self, 'patch_embed'):
            device = self.patch_embed.weight.device

        mat = []
        for i in range(np.prod(old_shape)):
            basis_vec = torch.zeros(old_shape, device=device)  # Specify device here
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]):
        """Resize patch_embed to target resolution via pseudo-inverse resizing"""
        # Return original kernel if no resize is necessary
        if to_2tuple(self.patch_size) == new_patch_size:
            return patch_embed

        # Calculate pseudo-inverse of resize matrix (cached per target size)
        if new_patch_size not in self.pinvs:
            self.pinvs[new_patch_size] = self._calculate_pinv(
                to_2tuple(self.patch_size), new_patch_size, device=patch_embed.device
            )
        pinv = self.pinvs[new_patch_size]

        def resample_patch_embed(patch_embed: Tensor):
            # Resample one (H, W) kernel slice; pinv math is done in float32
            # and the result cast back to the kernel's dtype.
            h, w = new_patch_size
            original_dtype = patch_embed.dtype
            patch_embed_float = patch_embed.float()
            resampled_kernel = pinv @ patch_embed_float.reshape(-1)
            resampled_kernel = resampled_kernel.to(original_dtype)
            return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)

        # vmap over the out-channel and in-channel axes of the conv kernel.
        v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)

        return v_resample_patch_embed(patch_embed)

    def get_attention_mask(self, target_shape, attention_mask):
        """Expand a per-key mask into an additive (B*num_heads, S, S) mask for
        nn.MultiheadAttention, prepending `target_shape[1]` always-attendable
        positions (the image tokens) ahead of `attention_mask`.

        NOTE(review): assumes `attention_mask` covers the latent-token tail of
        the sequence — confirm against callers.
        """
        # Create mask for mask_tokens (all True since we want to attend to all mask tokens)
        mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
        # Combine with input attention mask
        attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
        sequence_length = attention_mask.shape[1]

        # Expand the per-key mask to a full key-padding mask (not causal):
        # every query row shares the same allowed-key pattern.
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            self.num_heads,
            sequence_length,
            sequence_length
        )

        # Reshape to [B*num_heads, S, S]
        attention_mask = attention_mask.reshape(
            -1, sequence_length, sequence_length
        )

        # Convert boolean mask to float
        attention_mask = attention_mask.float()

        # Convert mask values: True -> 1.0 stays, False -> -inf. Allowed keys
        # carry a bias of 1.0 rather than 0.0, but since that bias is constant
        # across the allowed keys of each row the softmax is unaffected.
        attention_mask = attention_mask.masked_fill(
            ~attention_mask.bool(),
            float('-inf')
        )
        return attention_mask

    def forward(self, pixel_values, latent_tokens, attention_mask=None, encode_patch_size=None, train=True):
        """Encode `pixel_values` (B, 3, H, W) into latent tokens.

        latent_tokens: (num_latent_tokens, width) learned queries (shared,
            expanded per batch here).
        attention_mask: optional per-key mask forwarded to the transformer.
        encode_patch_size: fixed patch size; when None a size is chosen
            (randomly during training) relative to a 512-px base resolution.
        Returns (B, token_size, 1, num_latent_tokens).
        """
        batch_size, _, H, W = pixel_values.shape
        x = pixel_values

        # Apply dynamic patch embedding
        # Determine patch size dynamically based on image resolution
        # Base patch size (32) is for 512x512 images
        # Scale proportionally for other resolutions to maintain ~256 tokens
        base_resolution = 512

        if encode_patch_size is None:
            base_patch_size = random.choice([16, 32])
            target_patch_size = min(int(min(H, W) / base_resolution * base_patch_size), 32)  # we force it to be at most 32 otherwise we lose information
        else:
            target_patch_size = encode_patch_size

        if isinstance(target_patch_size, int):
            target_patch_size = (target_patch_size, target_patch_size)

        x = self.apply_flexivit_patch_embed(x, target_patch_size)

        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)

        # Patch-grid dimensions (used for positional embedding and slicing).
        grid_height = H // target_patch_size[0]
        grid_width = W // target_patch_size[1]

        # Large grids (> 256 patches) get 50-70% of their patch tokens dropped
        # during training by the patch-mixture stage below.
        mask_ratio = 0.0
        if grid_height*grid_width > 256 and train:
            mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()

        num_latent_tokens = latent_tokens.shape[0]
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent_tokens]

        x = x + self.positional_embedding(grid_height, grid_width, train=train, dtype=x.dtype)

        # apply attention_mask before concatenating x and latent_tokens
        if attention_mask is not None:
            key_attention_mask = attention_mask.clone()
            attention_mask = self.get_attention_mask((batch_size, x.shape[1]), key_attention_mask)
            full_seq_attention_mask = attention_mask.clone()
        else:
            key_attention_mask = None
            full_seq_attention_mask = None

        # Concatenate x and latent_tokens first
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for i in range(self.num_layers):
            if i == self.patch_mixture_start_layer:
                # Drop a mask_ratio fraction of the patch tokens (positions
                # 1..grid_height*grid_width; position 0 is the class token).
                x = x.permute(1, 0, 2)
                x_D_last = x[:, 1:grid_height*grid_width+1].clone()
                mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio)
                # NOTE(review): start_route receives the FULL sequence while
                # mask_info was computed on the patch slice x[:, 1:...]; the
                # gather indices therefore address positions shifted by the
                # class token — confirm this offset is intended.
                new_x = self.patch_mixture.start_route(x, mask_info)
                x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1)
                x = x.permute(1, 0, 2)
                if key_attention_mask is not None:
                    attention_mask = self.get_attention_mask((batch_size, 1+new_x.shape[1]), key_attention_mask)
                else:
                    attention_mask = None

            x = self.transformer[i](x, attention_mask=attention_mask)

            if i == self.patch_mixture_end_layer:
                # Restore the full patch grid, refilling dropped positions
                # with their pre-mixture activations (x_D_last).
                x = x.permute(1, 0, 2)
                new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
                x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
                x = x.permute(1, 0, 2)
                if full_seq_attention_mask is not None:
                    attention_mask = full_seq_attention_mask.clone()
                else:
                    attention_mask = None

        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the latent-token tail (class + patch tokens are discarded).
        latent_tokens = x[:, 1+grid_height*grid_width:]
        latent_tokens = self.ln_post(latent_tokens)

        # fake 2D shape
        if self.is_legacy:
            # NOTE(review): this reshape of (B, N, width) into
            # (B, width, N, 1) reinterprets memory rather than permuting axes;
            # preserved verbatim for legacy-checkpoint compatibility.
            latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent_tokens, 1)
        else:
            # Fix legacy problem.
            latent_tokens = latent_tokens.reshape(batch_size, num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent_tokens)
        return latent_tokens
| 608 |
+
|
# Backwards-compatible alias for the pre-rename encoder class.
class TiTokEncoder(ResolutionEncoder):
    """Legacy name for :class:`ResolutionEncoder`.

    Kept so existing imports, configs, and checkpoints that reference
    ``TiTokEncoder`` continue to work; it adds no behavior of its own.
    """
    pass
| 613 |
+
|
| 614 |
+
class ResolutionDecoder(nn.Module):
    """Transformer decoder mapping quantized latent tokens back to RGB pixels
    at a (possibly variable) target resolution.

    A class token plus a grid of mask tokens (one per output patch) is
    prepended to the latent sequence, run through residual attention blocks —
    optionally routing patch tokens through a PatchMixture between
    ``patch_mixture_start_layer`` and ``patch_mixture_end_layer`` — and the
    patch tokens are finally projected to pixels and resampled to the exact
    requested (height, width).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_dec_patch_size
        self.model_size = config.model.vq_model.vit_dec_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
        # Default 100 is deeper than any supported depth, i.e. routing disabled.
        self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
        self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)

        self.is_legacy = config.model.vq_model.get("is_legacy", True)
        self.width = {
            "tiny": 256,
            "small": 512,
            "base": 768,
            "large": 1024,
        }[self.model_size]
        self.num_layers = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 24,
        }[self.model_size]
        self.num_heads = {
            "tiny": 4,
            "small": 8,
            "base": 12,
            "large": 16,
        }[self.model_size]

        self.decoder_embed = nn.Linear(
            self.token_size, self.width, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))

        # Resolution-agnostic positional embedding, resampled to any grid size.
        self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)

        # Mask token (one learned query per output patch) and latent pos embed.
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)

        self.patch_mixture = PatchMixture()

        self.transformer = nn.ModuleList()
        for i in range(self.num_layers):
            self.transformer.append(ResidualAttentionBlock(
                self.width, self.num_heads, mlp_ratio=4.0
            ))
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            raise NotImplementedError("Legacy mode is not implemented for ResolutionDecoder")
        else:
            # Directly predicting RGB pixels per patch, then un-patchify.
            self.ffn = nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True)
            self.rearrange = Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                                       p1=self.patch_size, p2=self.patch_size)
            self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def get_attention_mask(self, target_shape, attention_mask):
        """Build an additive attention mask of shape [B*num_heads, S, S].

        Args:
            target_shape: (batch, num_query_tokens) for the cls+patch queries,
                which are always attendable.
            attention_mask: [batch, num_key_tokens] validity flags for the
                latent key tokens (nonzero = valid).

        Returns:
            Float mask with 0.0 at attendable positions and -inf at masked
            ones, suitable for adding to attention logits before softmax.
        """
        # Queries (cls + mask tokens) are always valid keys as well.
        mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
        # Combine with input attention mask over the latent tokens.
        attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
        sequence_length = attention_mask.shape[1]

        # Broadcast the per-key mask over heads and query positions:
        # [B, S] -> [B, 1, 1, S] -> [B, H, S, S]
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            self.num_heads,
            sequence_length,
            sequence_length
        )

        # Reshape to [B*num_heads, S, S]
        attention_mask = attention_mask.reshape(
            -1, sequence_length, sequence_length
        )

        # Convert to an additive float mask: True -> 0.0, False -> -inf.
        # FIX(review): the previous code (`.float()` then masked_fill) yielded
        # 1.0 — not 0.0 — at attendable positions. That is softmax-equivalent
        # (a per-row constant shift), but contradicted the documented contract;
        # emit 0.0 so the mask is a conventional additive mask.
        additive_mask = torch.zeros_like(attention_mask, dtype=torch.float)
        additive_mask = additive_mask.masked_fill(~attention_mask, float('-inf'))
        return additive_mask

    def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode quantized latents ``z_quantized`` (N, C, H, W with H*W latent
        positions) into an RGB image of size (height, width).

        decode_patch_size: int or (ph, pw); when None a patch size is sampled
        from {8, 16, 32} so the patch grid has 256–1024 entries.
        """
        N, C, H, W = z_quantized.shape
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)  # NLD
        x = self.decoder_embed(x)

        batch_size, seq_len, _ = x.shape

        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        if decode_patch_size is None:
            # Target patch counts between 256 and 1024.
            min_patches = 256
            max_patches = 1024

            # Collect patch sizes whose grid lands in the target range.
            possible_patch_sizes = []
            for patch_size in [8, 16, 32]:
                grid_h = height // patch_size
                grid_w = width // patch_size
                total_patches = grid_h * grid_w
                if min_patches <= total_patches <= max_patches:
                    possible_patch_sizes.append(patch_size)

            if not possible_patch_sizes:
                # Fall back to the size whose patch count is closest to the range.
                patch_counts = []
                for patch_size in [8, 16, 32]:
                    grid_h = height // patch_size
                    grid_w = width // patch_size
                    patch_counts.append((patch_size, grid_h * grid_w))
                patch_counts.sort(key=lambda item: min(abs(item[1] - min_patches), abs(item[1] - max_patches)))
                possible_patch_sizes = [patch_counts[0][0]]

            selected_patch_size = random.choice(possible_patch_sizes)
        else:
            selected_patch_size = decode_patch_size

        if isinstance(selected_patch_size, int):
            selected_patch_size = (selected_patch_size, selected_patch_size)

        grid_height = height // selected_patch_size[0]
        grid_width = width // selected_patch_size[1]

        # Large grids get random patch masking during training so PatchMixture
        # routing has something to drop/restore.
        mask_ratio = 0.0
        if grid_height * grid_width > 256 and train:
            mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()

        mask_tokens = self.mask_token.repeat(batch_size, grid_height * grid_width, 1).to(x.dtype)
        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
                                 mask_tokens], dim=1)

        mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype)

        x = x + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([mask_tokens, x], dim=1)

        if attention_mask is not None:
            # Keep the raw key mask around: it must be re-expanded whenever the
            # sequence length changes (PatchMixture start/end routing).
            key_attention_mask = attention_mask.clone()
            attention_mask = self.get_attention_mask((batch_size, 1 + grid_height * grid_width), key_attention_mask)
            full_seq_attention_mask = attention_mask.clone()
        else:
            key_attention_mask = None
            full_seq_attention_mask = None

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for i in range(self.num_layers):
            if i == self.patch_mixture_start_layer:
                x = x.permute(1, 0, 2)
                # Snapshot patch tokens so end_route can restore dropped ones.
                x_D_last = x[:, 1:grid_height * grid_width + 1].clone()
                mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height * grid_width + 1], mask_ratio=mask_ratio)
                new_x = self.patch_mixture.start_route(x, mask_info)
                x = torch.cat([x[:, :1], new_x, x[:, grid_height * grid_width + 1:]], dim=1)
                x = x.permute(1, 0, 2)
                if key_attention_mask is not None:
                    attention_mask = self.get_attention_mask((batch_size, 1 + new_x.shape[1]), key_attention_mask)
                else:
                    attention_mask = None

            x = self.transformer[i](x, attention_mask=attention_mask)

            if i == self.patch_mixture_end_layer:
                x = x.permute(1, 0, 2)
                new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
                x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
                x = x.permute(1, 0, 2)
                if full_seq_attention_mask is not None:
                    attention_mask = full_seq_attention_mask.clone()
                else:
                    attention_mask = None

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = x[:, 1:1 + grid_height * grid_width]  # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batch_size, self.width, grid_height, grid_width)
        x = self.ffn(x.contiguous())
        x = self.rearrange(x)
        _, _, org_h, org_w = x.shape
        # Anti-aliased resample from the patchified size to the exact target size.
        x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
        x = self.conv_out(x)

        return x
|
| 826 |
+
|
| 827 |
+
# Keep the original TiTokDecoder as a legacy class that inherits from ResolutionDecoder
class TiTokDecoder(ResolutionDecoder):
    """Legacy TiTokDecoder - now inherits from ResolutionDecoder for backward compatibility.

    Differences from ResolutionDecoder: patch-mixture routing is disabled, the
    grid is fixed to ``image_size // patch_size``, and the output head is the
    simpler legacy projection.
    """

    def __init__(self, config):
        # Shallow-copy the config so patch mixture can be disabled without
        # mutating the caller's config object.
        config_copy = type(config)()
        for attr in dir(config):
            if not attr.startswith('__'):
                try:
                    setattr(config_copy, attr, getattr(config, attr))
                except Exception:
                    # FIX(review): was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Read-only attributes are
                    # still skipped best-effort.
                    pass

        # Disable patch mixture for legacy mode
        if hasattr(config_copy.model.vq_model, 'patch_mixture_start_layer'):
            config_copy.model.vq_model.patch_mixture_start_layer = -1
        if hasattr(config_copy.model.vq_model, 'patch_mixture_end_layer'):
            config_copy.model.vq_model.patch_mixture_end_layer = -1

        super().__init__(config_copy)

        # Override grid_size for legacy compatibility
        self.grid_size = self.image_size // self.patch_size

        # Replace ResolutionDecoder's advanced final layers with legacy ones if needed.
        # NOTE(review): ResolutionDecoder.__init__ raises NotImplementedError when
        # is_legacy is True, so the first branch below looks unreachable as-is —
        # confirm intended behavior before relying on it.
        if self.is_legacy:
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Use simpler final layers for backward compatibility
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
                Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
                          p1=self.patch_size, p2=self.patch_size),)
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode with fixed legacy geometry (image_size, patch_size defaults)."""
        # Legacy compatibility: use fixed grid size if height/width not provided
        if height is None:
            height = self.image_size
        if width is None:
            width = self.image_size

        # Force decode_patch_size to be the original patch_size for legacy compatibility
        if decode_patch_size is None:
            decode_patch_size = self.patch_size

        # Use the parent's forward method but with legacy parameters
        return super().forward(z_quantized, attention_mask, height, width, decode_patch_size, train)
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
class TATiTokDecoder(ResolutionDecoder):
    """Text-aligned TiTok decoder.

    Extends ResolutionDecoder by projecting CLIP-style text-guidance tokens
    into the transformer width and appending them to the input sequence.
    Geometry is fixed to the legacy image_size / patch_size defaults.
    """

    def __init__(self, config):
        super().__init__(config)
        scale = self.width ** -0.5
        self.text_context_length = config.model.vq_model.get("text_context_length", 77)
        self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width))

        # Fixed grid size kept for backward compatibility with legacy callers.
        self.grid_size = self.image_size // self.patch_size

    def forward(self, z_quantized, text_guidance, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        # NOTE(review): attention_mask is accepted but unused here — confirm
        # whether masking was meant to apply to the text-guided path.
        batch, channels, lat_h, lat_w = z_quantized.shape
        latents = z_quantized.reshape(batch, channels * lat_h, lat_w).permute(0, 2, 1)  # NLD
        latents = self.decoder_embed(latents)

        bsz, latent_len, _ = latents.shape

        # Backward-compatible fixed geometry.
        height = self.image_size if height is None else height
        width = self.image_size if width is None else width
        patch = self.patch_size if decode_patch_size is None else decode_patch_size

        gh = height // patch
        gw = width // patch

        queries = self.mask_token.repeat(bsz, gh * gw, 1).to(latents.dtype)
        queries = torch.cat([_expand_token(self.class_embedding, queries.shape[0]).to(queries.dtype),
                             queries], dim=1)
        queries = queries + self.positional_embedding(gh, gw, train=train).to(queries.dtype)

        latents = latents + self.latent_token_positional_embedding[:latent_len]
        seq = torch.cat([queries, latents], dim=1)

        # Project text guidance and append it to the sequence.
        text = self.text_guidance_proj(text_guidance)
        text = text + self.text_guidance_positional_embedding
        seq = torch.cat([seq, text], dim=1)

        seq = self.ln_pre(seq)
        seq = seq.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            seq = block(seq)
        seq = seq.permute(1, 0, 2)  # LND -> NLD
        patches = seq[:, 1:1 + gh * gw]  # drop cls token, keep only patch queries
        patches = self.ln_post(patches)
        # N L D -> N D H W
        patches = patches.permute(0, 2, 1).reshape(bsz, self.width, gh, gw)
        out = self.ffn(patches.contiguous())
        out = self.conv_out(out)
        return out
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
class WeightTiedLMHead(nn.Module):
    """LM head whose projection matrix is tied to an embedding table.

    Only the first ``target_codebook_size`` rows of the shared weight are
    used, so the output vocabulary may be a prefix of the embedding table.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        # Bind directly to the embedding's weight (no copy) so updates flow both ways.
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        # x: [batch, seq, embed_dim] -> logits: [batch, seq, target_codebook_size]
        tied_rows = self.weight[: self.target_codebook_size]
        return x @ tied_rows.transpose(0, 1)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps into ``hidden_size``-dim vectors via
    sinusoidal features followed by a two-layer SiLU MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices (possibly fractional), one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        args = t.float()[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dims get a zero pad so the output width is exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        frequencies = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(frequencies)
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
class ResBlock(nn.Module):
    """
    Adaptive-LayerNorm residual MLP block (DiT style).
    :param channels: number of input/output channels.
    """

    def __init__(
        self,
        channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        # Conditioning `y` yields a shift/scale for the norm and a residual gate.
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = modulate(self.in_ln(x), shift, scale)
        hidden = self.mlp(hidden)
        return x + gate * hidden
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
class FinalLayer(nn.Module):
    """
    The final output head adopted from DiT: adaLN-modulated LayerNorm
    followed by a linear projection to `out_channels`.
    """
    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP denoiser for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    :param grad_checkpointing: trade compute for memory in the residual stack.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            ResBlock(model_channels) for _ in range(num_res_blocks)
        )
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init all linears, then zero the adaLN and output projections
        so every block starts as (near) identity."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Timestep-embedding MLP gets a small normal init instead.
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers and the output projection.
        zero_layers = [block.adaLN_modulation[-1] for block in self.res_blocks]
        zero_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_layers.append(self.final_layer.linear)
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        h = self.input_proj(x)
        # Timestep and condition embeddings are summed into one modulation signal.
        cond = self.time_embed(t) + self.cond_embed(c)

        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            if use_checkpoint:
                h = checkpoint(block, h, cond)
            else:
                h = block(h, cond)

        return self.final_layer(h, cond)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free-guidance forward: the batch is [cond | uncond]
        duplicates of the same inputs; eps channels are blended by cfg_scale."""
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided_eps, guided_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
modeling/modules/fuzzy_embedding.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import einops
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class FuzzyEmbedding(nn.Module):
    """Resolution-agnostic positional embedding.

    Stores a learned table of ``grid_size`` positional vectors (laid out as a
    square sqrt(grid_size) x sqrt(grid_size) grid) and bilinearly resamples it
    with ``F.grid_sample`` to any requested ``grid_height x grid_width``.
    When ``apply_fuzzy`` is set, sample coordinates are jittered slightly
    during training. A separate learned embedding is prepended for the class
    token, so the output has ``1 + grid_height*grid_width`` rows.
    """

    def __init__(self, grid_size, scale, width, apply_fuzzy=False):
        super(FuzzyEmbedding, self).__init__()
        assert grid_size == 1024, "grid_size must be 1024 for now"

        self.grid_size = grid_size
        self.scale = scale
        self.width = width
        self.apply_fuzzy = apply_fuzzy
        # grid_size is the minimum possible token size;
        # grid_sample then yields the fuzzy embedding for any resolution.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(grid_size, width))

        self.class_positional_embedding = nn.Parameter(
            scale * torch.randn(1, width))

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, grid_height, grid_width, train=True, dtype=torch.float32):
        device = self.positional_embedding.device
        # FIX(review): pass indexing="ij" explicitly — this matches the old
        # default and silences the torch.meshgrid deprecation warning.
        meshx, meshy = torch.meshgrid(
            torch.arange(grid_height, device=device),
            torch.arange(grid_width, device=device),
            indexing="ij",
        )
        meshx = meshx.to(dtype)
        meshy = meshy.to(dtype)

        # Normalize coordinates to [-1, 1]. FIX(review): guard the denominator
        # so a 1-row or 1-column grid does not divide by zero.
        meshx = 2 * (meshx / max(grid_height - 1, 1)) - 1
        meshy = 2 * (meshy / max(grid_width - 1, 1)) - 1

        if self.apply_fuzzy and train:
            # Jitter sample coordinates by uniform noise in [-0.0004, 0.0004].
            meshx = meshx + (torch.rand_like(meshx) * 0.0008 - 0.0004)
            meshy = meshy + (torch.rand_like(meshy) * 0.0008 - 0.0004)

        # grid_sample expects (x, y) order, hence (meshy, meshx).
        grid = torch.stack((meshy, meshx), 2).unsqueeze(0)  # add batch dim

        side = int(math.sqrt(self.grid_size))
        # (h*w, d) -> (1, d, h, w); plain reshape/permute instead of einops.
        positional_embedding = (
            self.positional_embedding.reshape(side, side, self.width)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(dtype)
        )

        fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False)
        fuzzy_embedding = fuzzy_embedding.to(dtype)
        # (1, d, h, w) -> (h*w, d)
        fuzzy_embedding = fuzzy_embedding.squeeze(0).flatten(1).transpose(0, 1)

        final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0)
        return final_embedding
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
    # Smoke test. FIX(review): FuzzyEmbedding asserts grid_size == 1024, so the
    # demo must construct it with 1024 (the previous value of 256 tripped the
    # class's own assertion and crashed).
    fuzzy_embedding = FuzzyEmbedding(1024, 1.0, 1024)
    grid_height = 16
    grid_width = 32
    fuzzy_embedding = fuzzy_embedding(grid_height, grid_width, dtype=torch.bfloat16)
    print(fuzzy_embedding.shape)
|
modeling/modules/losses.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training loss implementation.
|
| 2 |
+
|
| 3 |
+
Ref:
|
| 4 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
|
| 5 |
+
"""
|
| 6 |
+
from typing import Mapping, Text, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from torch.cuda.amp import autocast
|
| 13 |
+
|
| 14 |
+
from modeling.modules.blocks import SimpleMLPAdaLN
|
| 15 |
+
from .perceptual_loss import PerceptualLoss
|
| 16 |
+
from .discriminator import NLayerDiscriminator
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
"""Hinge loss for discrminator.
|
| 21 |
+
|
| 22 |
+
This function is borrowed from
|
| 23 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
|
| 24 |
+
"""
|
| 25 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
| 26 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
| 27 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
| 28 |
+
return d_loss
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compute_lecam_loss(
|
| 32 |
+
logits_real_mean: torch.Tensor,
|
| 33 |
+
logits_fake_mean: torch.Tensor,
|
| 34 |
+
ema_logits_real_mean: torch.Tensor,
|
| 35 |
+
ema_logits_fake_mean: torch.Tensor
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
"""Computes the LeCam loss for the given average real and fake logits.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
logits_real_mean -> torch.Tensor: The average real logits.
|
| 41 |
+
logits_fake_mean -> torch.Tensor: The average fake logits.
|
| 42 |
+
ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
|
| 43 |
+
ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
lecam_loss -> torch.Tensor: The LeCam loss.
|
| 47 |
+
"""
|
| 48 |
+
lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2))
|
| 49 |
+
lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2))
|
| 50 |
+
return lecam_loss
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ReconstructionLoss_Stage1(torch.nn.Module):
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
config
|
| 57 |
+
):
|
| 58 |
+
super().__init__()
|
| 59 |
+
loss_config = config.losses
|
| 60 |
+
self.quantizer_weight = loss_config.quantizer_weight
|
| 61 |
+
self.target_codebook_size = 1024
|
| 62 |
+
|
| 63 |
+
def forward(self,
|
| 64 |
+
target_codes: torch.Tensor,
|
| 65 |
+
reconstructions: torch.Tensor,
|
| 66 |
+
quantizer_loss: torch.Tensor,
|
| 67 |
+
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
|
| 68 |
+
return self._forward_generator(target_codes, reconstructions, quantizer_loss)
|
| 69 |
+
|
| 70 |
+
def _forward_generator(self,
|
| 71 |
+
target_codes: torch.Tensor,
|
| 72 |
+
reconstructions: torch.Tensor,
|
| 73 |
+
quantizer_loss: Mapping[Text, torch.Tensor],
|
| 74 |
+
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
|
| 75 |
+
reconstructions = reconstructions.contiguous()
|
| 76 |
+
loss_fct = nn.CrossEntropyLoss(reduction="mean")
|
| 77 |
+
batch_size = reconstructions.shape[0]
|
| 78 |
+
reconstruction_loss = loss_fct(reconstructions.view(batch_size, self.target_codebook_size, -1),
|
| 79 |
+
target_codes.view(batch_size, -1))
|
| 80 |
+
total_loss = reconstruction_loss + \
|
| 81 |
+
self.quantizer_weight * quantizer_loss["quantizer_loss"]
|
| 82 |
+
|
| 83 |
+
loss_dict = dict(
|
| 84 |
+
total_loss=total_loss.clone().detach(),
|
| 85 |
+
reconstruction_loss=reconstruction_loss.detach(),
|
| 86 |
+
quantizer_loss=(self.quantizer_weight * quantizer_loss["quantizer_loss"]).detach(),
|
| 87 |
+
commitment_loss=quantizer_loss["commitment_loss"].detach(),
|
| 88 |
+
codebook_loss=quantizer_loss["codebook_loss"].detach(),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
return total_loss, loss_dict
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ReconstructionLoss_Stage2(torch.nn.Module):
    """Stage-2 image loss: pixel reconstruction + perceptual + quantizer + GAN terms.

    Owns the patch discriminator and dispatches between generator and
    discriminator training steps via the `mode` argument of `forward`.
    """

    def __init__(
        self,
        config
    ):
        """Initializes the losses module.

        Args:
            config: A dictionary, the configuration for the model and everything else.
        """
        super().__init__()
        loss_config = config.losses
        self.discriminator = NLayerDiscriminator()

        # "l1" or "l2"; selects the pixel reconstruction loss in _forward_generator.
        self.reconstruction_loss = loss_config.reconstruction_loss
        self.reconstruction_weight = loss_config.reconstruction_weight
        self.quantizer_weight = loss_config.quantizer_weight
        # Frozen perceptual network; .eval() disables dropout/batchnorm updates.
        self.perceptual_loss = PerceptualLoss(
            loss_config.perceptual_loss).eval()
        self.perceptual_weight = loss_config.perceptual_weight
        # Global step at which the adversarial terms switch on.
        self.discriminator_iter_start = loss_config.discriminator_start

        self.discriminator_factor = loss_config.discriminator_factor
        self.discriminator_weight = loss_config.discriminator_weight
        # LeCam regularization (0 disables it and skips buffer registration).
        self.lecam_regularization_weight = loss_config.lecam_regularization_weight
        self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999)
        if self.lecam_regularization_weight > 0.0:
            # EMA trackers of mean real/fake logits, updated in _forward_discriminator.
            self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
            self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))

        self.config = config

    @autocast(enabled=False)
    def forward(self,
                inputs: torch.Tensor,
                reconstructions: torch.Tensor,
                extra_result_dict: Mapping[Text, torch.Tensor],
                global_step: int,
                mode: str = "generator",
                ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        # Both inputs and reconstructions are in range [0, 1].
        # Computed in fp32 (autocast disabled) for loss numerical stability.
        inputs = inputs.float()
        reconstructions = reconstructions.float()

        if mode == "generator":
            return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
        elif mode == "discriminator":
            return self._forward_discriminator(inputs, reconstructions, global_step)
        else:
            raise ValueError(f"Unsupported mode {mode}")

    def should_discriminator_be_trained(self, global_step : int):
        # Adversarial training is delayed until discriminator_iter_start.
        return global_step >= self.discriminator_iter_start

    def _forward_generator(self,
                           inputs: torch.Tensor,
                           reconstructions: torch.Tensor,
                           extra_result_dict: Mapping[Text, torch.Tensor],
                           global_step: int
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step."""
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        elif self.reconstruction_loss == "l2":
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        else:
            raise ValueError(f"Unsuppored reconstruction_loss {self.reconstruction_loss}")
        reconstruction_loss *= self.reconstruction_weight

        # Compute perceptual loss.
        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # Compute discriminator loss.
        generator_loss = torch.zeros((), device=inputs.device)
        # Factor is zeroed before discriminator_iter_start so GAN terms are inert.
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Disable discriminator gradients (only the generator is updated here).
            for param in self.discriminator.parameters():
                param.requires_grad = False
            logits_fake = self.discriminator(reconstructions)
            # Non-saturating generator objective: maximize fake logits.
            generator_loss = -torch.mean(logits_fake)

        d_weight *= self.discriminator_weight

        # Compute quantizer loss.
        quantizer_loss = extra_result_dict["quantizer_loss"]
        total_loss = (
            reconstruction_loss
            + self.perceptual_weight * perceptual_loss
            + self.quantizer_weight * quantizer_loss
            + d_weight * discriminator_factor * generator_loss
        )
        # Detached, already-weighted copies for logging.
        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
            quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
            weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
            discriminator_factor=torch.tensor(discriminator_factor),
            commitment_loss=extra_result_dict["commitment_loss"].detach(),
            codebook_loss=extra_result_dict["codebook_loss"].detach(),
            d_weight=d_weight,
            gan_loss=generator_loss.detach(),
        )

        return total_loss, loss_dict

    def _forward_discriminator(self,
                               inputs: torch.Tensor,
                               reconstructions: torch.Tensor,
                               global_step: int,
                               ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Discriminator training step (hinge loss, optional LeCam regularization)."""
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        loss_dict = {}
        # Turn the gradients on (they were disabled in the generator step).
        for param in self.discriminator.parameters():
            param.requires_grad = True

        # Detach from the generator graph; requires_grad_ re-enables grads w.r.t. pixels.
        real_images = inputs.detach().requires_grad_(True)
        logits_real = self.discriminator(real_images)
        logits_fake = self.discriminator(reconstructions.detach())

        discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)

        # optional lecam regularization
        lecam_loss = torch.zeros((), device=inputs.device)
        if self.lecam_regularization_weight > 0.0:
            lecam_loss = compute_lecam_loss(
                torch.mean(logits_real),
                torch.mean(logits_fake),
                self.ema_real_logits_mean,
                self.ema_fake_logits_mean
            ) * self.lecam_regularization_weight

            # Update the EMA trackers AFTER computing the loss; detached so no
            # gradient flows into the buffers.
            self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
            self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)

        discriminator_loss += lecam_loss

        loss_dict = dict(
            discriminator_loss=discriminator_loss.detach(),
            logits_real=logits_real.detach().mean(),
            logits_fake=logits_fake.detach().mean(),
            lecam_loss=lecam_loss.detach(),
        )
        return discriminator_loss, loss_dict
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
    """Single-stage variant of the Stage-2 loss that also supports VAE training.

    Reuses the Stage-2 discriminator path; only the generator step differs,
    branching on `quantize_mode` ("vq"/"mvq"/"softvq" use the quantizer loss,
    "vae" uses a KL term with a learned-but-frozen log-variance).
    """

    def __init__(
        self,
        config
    ):
        super().__init__(config)
        loss_config = config.losses
        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")

        if self.quantize_mode == "vae":
            self.kl_weight = loss_config.get("kl_weight", 1e-6)
            logvar_init = loss_config.get("logvar_init", 0.0)
            # Fixed (requires_grad=False) log-variance used to rescale the
            # reconstruction loss in the VAE branch.
            self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init, requires_grad=False)

    def _forward_generator(self,
                           inputs: torch.Tensor,
                           reconstructions: torch.Tensor,
                           extra_result_dict: Mapping[Text, torch.Tensor],
                           global_step: int
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step."""
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        elif self.reconstruction_loss == "l2":
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        else:
            raise ValueError(f"Unsuppored reconstruction_loss {self.reconstruction_loss}")
        reconstruction_loss *= self.reconstruction_weight

        # Compute perceptual loss.
        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # Compute discriminator loss.
        generator_loss = torch.zeros((), device=inputs.device)
        # Zero before discriminator_iter_start so the GAN term is inert.
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Disable discriminator gradients (only the generator is updated here).
            for param in self.discriminator.parameters():
                param.requires_grad = False
            logits_fake = self.discriminator(reconstructions)
            generator_loss = -torch.mean(logits_fake)

        d_weight *= self.discriminator_weight

        if self.quantize_mode in ["vq", "mvq", "softvq"]:
            # Compute quantizer loss.
            quantizer_loss = extra_result_dict["quantizer_loss"]
            total_loss = (
                reconstruction_loss
                + self.perceptual_weight * perceptual_loss
                + self.quantizer_weight * quantizer_loss
                + d_weight * discriminator_factor * generator_loss
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
                quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
                weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                commitment_loss=extra_result_dict["commitment_loss"].detach(),
                codebook_loss=extra_result_dict["codebook_loss"].detach(),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        elif self.quantize_mode == "vae":
            # Compute kl loss.
            # Reconstruction term is rescaled by the (fixed) log-variance.
            reconstruction_loss = reconstruction_loss / torch.exp(self.logvar)
            # extra_result_dict here is a posterior distribution object exposing .kl().
            posteriors = extra_result_dict
            kl_loss = posteriors.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
            total_loss = (
                reconstruction_loss
                + self.perceptual_weight * perceptual_loss
                + self.kl_weight * kl_loss
                + d_weight * discriminator_factor * generator_loss
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
                kl_loss=(self.kl_weight * kl_loss).detach(),
                weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        else:
            raise NotImplementedError

        return total_loss, loss_dict
|
modeling/modules/lpips.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LPIPS perceptual loss.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/richzhang/PerceptualSimilarity/
|
| 5 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py
|
| 6 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/util.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import hashlib
|
| 11 |
+
import requests
|
| 12 |
+
from collections import namedtuple
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
from torchvision import models
|
| 19 |
+
|
| 20 |
+
_LPIPS_MEAN = [-0.030, -0.088, -0.188]
|
| 21 |
+
_LPIPS_STD = [0.458, 0.448, 0.450]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
URL_MAP = {
|
| 25 |
+
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
CKPT_MAP = {
|
| 29 |
+
"vgg_lpips": "vgg.pth"
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
MD5_MAP = {
|
| 33 |
+
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def download(url, local_path, chunk_size=1024):
    """Stream *url* to *local_path* with a progress bar.

    Creates the parent directory if needed. Progress is reported against the
    server's Content-Length header (0 when absent).

    Args:
        url: Source URL.
        local_path: Destination file path.
        chunk_size: Bytes per streamed chunk.
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the actual bytes received: the final chunk
                        # is usually shorter than chunk_size, and updating by
                        # chunk_size over-reports past the total.
                        pbar.update(len(data))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def md5_hash(path):
    """Return the hex MD5 digest of the file at *path* (read fully into memory)."""
    with open(path, "rb") as fh:
        digest = hashlib.md5(fh.read())
    return digest.hexdigest()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_ckpt_path(name, root, check=False):
    """Resolve the local checkpoint path for *name*, downloading it when missing.

    With check=True an existing file is also re-downloaded if its MD5 does not
    match MD5_MAP; the freshly downloaded file is always MD5-verified.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    stale = check and md5_hash(path) != MD5_MAP[name] if os.path.exists(path) else True
    if stale:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LPIPS(nn.Module):
    # Learned perceptual metric: per-stage VGG16 feature distances, reweighted
    # by learned 1x1 convolutions and summed. All parameters are frozen.
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One learned 1x1 conv head per VGG feature stage.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_pretrained()
        # Metric is frozen: nothing here is ever trained.
        for param in self.parameters():
            param.requires_grad = False

    def load_pretrained(self):
        workspace = os.environ.get('WORKSPACE', '')
        # NOTE(review): get_ckpt_path treats its second argument as a directory
        # root and joins CKPT_MAP's "vgg.pth" onto it, so the effective path is
        # "<workspace>/models/vgg_lpips.pth/vgg.pth" — confirm this is intended.
        VGG_PATH = get_ckpt_path("vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True)
        # strict=False: the checkpoint only covers a subset of this module's keys.
        self.load_state_dict(torch.load(VGG_PATH, map_location=torch.device("cpu"), weights_only=True), strict=False)

    def forward(self, input, target):
        # Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1].
        # However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed.
        input = input * 2. - 1.
        target = target * 2. - 1.
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            # Channel-normalize each stage's features, then squared difference.
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        # 1x1-conv reweighting, spatial mean, then sum over the five stages.
        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ScalingLayer(nn.Module):
    """Normalizes inputs with the fixed LPIPS channel shift/scale statistics."""

    def __init__(self):
        super().__init__()
        shift = torch.Tensor(_LPIPS_MEAN)[None, :, None, None]
        scale = torch.Tensor(_LPIPS_STD)[None, :, None, None]
        # Buffers (not parameters): move with the module but are never trained.
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        return (inp - self.shift) / self.scale
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        layers = []
        if use_dropout:
            layers.append(nn.Dropout())
        # Bias-free 1x1 conv: a learned per-channel reweighting.
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class vgg16(torch.nn.Module):
    """VGG16 feature extractor split into the five LPIPS stages (through relu5_3)."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        # NOTE(review): the `pretrained` argument is never read — ImageNet
        # weights are always loaded below regardless of its value.
        vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Slice boundaries fall after relu1_2, relu2_2, relu3_3, relu4_3, relu5_3.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        # Run the slices sequentially and return all five intermediate activations.
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def normalize_tensor(x, eps=1e-10):
    """L2-normalize *x* along the channel dimension (dim 1), eps-stabilized."""
    channel_norm = torch.sum(x ** 2, dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def spatial_average(x, keepdim=True):
    """Mean over the spatial dimensions (H, W) of an NCHW tensor."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
|
modeling/modules/maskgit_vqgan.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MaskGIT-VQGAN tokenizer.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
r"""MaskGIT Tokenizer based on VQGAN.
|
| 8 |
+
|
| 9 |
+
This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841]
|
| 10 |
+
with several modifications. The non-local layers are removed from VQGAN for
|
| 11 |
+
faster speed.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Conv2D with same padding
|
| 22 |
+
class Conv2dSame(nn.Conv2d):
    """nn.Conv2d with TensorFlow-style 'SAME' padding (output size = ceil(in / stride))."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        # Total padding so that the output covers ceil(i / s) positions.
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_h, in_w = x.shape[-2:]

        pad_h = self.calc_same_pad(i=in_h, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        pad_w = self.calc_same_pad(i=in_w, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if pad_h or pad_w:
            # Split asymmetric padding: extra element goes on the right/bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return super().forward(x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ResnetBlock(nn.Module):
    """GroupNorm + SiLU residual block used by the MaskGIT-VQGAN encoder/decoder."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Effective output channels; defaults to in_channels when not given.
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False)

        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False)

        if self.in_channels != self.out_channels_:
            # NOTE(review): unlike the standard VQGAN block, this 1x1 shortcut
            # maps out_channels_ -> out_channels_ and (see forward) is applied
            # to the transformed branch, not to the saved input. This mirrors
            # the open-muse reference this file cites; do not "fix" it without
            # checking pretrained-weight compatibility.
            self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels_:
            # When channels change, the "residual" is replaced by a 1x1 conv of
            # the transformed branch (the original input is discarded here).
            residual = self.nin_shortcut(hidden_states)

        return hidden_states + residual
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class DownsamplingBlock(nn.Module):
    """Encoder stage: a stack of ResnetBlocks followed by 2x average-pool downsampling.

    The last resolution level (block_idx == num_resolutions - 1) skips the pool.
    """

    def __init__(self, config, block_idx: int):
        super().__init__()

        self.config = config
        self.block_idx = block_idx

        # Input channels come from the previous level's multiplier (1 for the first level).
        in_channel_mult = (1,) + tuple(self.config.channel_mult)
        block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            # Only the first block changes channel count.
            block_in = block_out
        self.block = res_blocks

        self.downsample = self.block_idx != self.config.num_resolutions - 1

    def forward(self, hidden_states):
        for res_block in self.block:
            hidden_states = res_block(hidden_states)

        if self.downsample:
            hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2)

        return hidden_states
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class UpsamplingBlock(nn.Module):
    """Decoder stage: a stack of ResnetBlocks followed by nearest-neighbor 2x upsampling.

    The lowest level (block_idx == 0) skips the upsample.
    """

    def __init__(self, config, block_idx: int):
        super().__init__()

        self.config = config
        self.block_idx = block_idx

        # Input channels: the deepest level reads the bottleneck width; others
        # read the output width of the level above them.
        if self.block_idx == self.config.num_resolutions - 1:
            block_in = self.config.hidden_channels * self.config.channel_mult[-1]
        else:
            block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]

        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = []
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            block_in = block_out
        self.block = nn.ModuleList(res_blocks)

        self.add_upsample = self.block_idx != 0
        if self.add_upsample:
            # 3x3 conv smooths the nearest-neighbor upsampled features.
            self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)

    def forward(self, hidden_states):
        for res_block in self.block:
            hidden_states = res_block(hidden_states)

        if self.add_upsample:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            hidden_states = self.upsample_conv(hidden_states)

        return hidden_states
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Encoder(nn.Module):
    """MaskGIT-VQGAN encoder: conv stem, downsampling stages, mid blocks, 1x1 projection to z."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # downsampling
        self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False)

        downsample_blocks = []
        for i_level in range(self.config.num_resolutions):
            downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level))
        self.down = nn.ModuleList(downsample_blocks)

        # middle
        mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
        res_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout))
        self.mid = res_blocks

        # end: normalize then project to the latent channel count with a 1x1 conv.
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
        self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1)

    def forward(self, pixel_values):
        # downsampling
        hidden_states = self.conv_in(pixel_values)
        for block in self.down:
            hidden_states = block(hidden_states)

        # middle
        for block in self.mid:
            hidden_states = block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class Decoder(nn.Module):
    """Convolutional decoder: maps a latent feature map back to image space."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Channel width and spatial size at the lowest (input) resolution.
        block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)

        # Project latents up to the widest hidden width.
        self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3)

        # Middle residual stack at the lowest resolution.
        self.mid = nn.ModuleList(
            ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout)
            for _ in range(self.config.num_res_blocks)
        )

        # Upsampling tower, constructed coarse-to-fine and then reversed so the
        # stored ModuleList order matches the encoder's level indexing.
        coarse_to_fine = [
            UpsamplingBlock(self.config, block_idx=level)
            for level in reversed(range(self.config.num_resolutions))
        ]
        self.up = nn.ModuleList(reversed(coarse_to_fine))

        # Output head back to `num_channels` image planes.
        block_out = self.config.hidden_channels * self.config.channel_mult[0]
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
        self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3)

    def forward(self, hidden_states):
        """Decode latent features (B, z_channels, h, w) into reconstructed images."""
        features = self.conv_in(hidden_states)
        for res_block in self.mid:
            features = res_block(features)
        # Run coarsest-to-finest, i.e. the reverse of the stored order.
        for stage in reversed(self.up):
            features = stage(features)
        features = self.norm_out(features)
        features = F.silu(features)
        return self.conv_out(features)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    Discretization bottleneck part of the VQ-VAE.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        r"""
        Args:
            num_embeddings: number of vectors in the quantized space.
            embedding_dim: dimensionality of the tensors in the quantized space.
                Inputs to the modules must be in this format as well.
            commitment_cost: scalar which controls the weighting of the loss terms
                (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
        """
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Standard VQ-VAE init: uniform in [-1/K, 1/K].
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, hidden_states, return_loss=False):
        """
        Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the
        closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)

        Returns:
            (z_q, min_encoding_indices, loss); loss is None unless return_loss is True.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()

        distances = self.compute_distances(hidden_states)
        min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # One-hot encode the nearest-code indices to select rows of the codebook.
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

        # compute loss for embedding: codebook loss + beta * commitment loss
        loss = None
        if return_loss:
            loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
                (z_q - hidden_states.detach()) ** 2
            )
            # preserve gradients (straight-through estimator)
            z_q = hidden_states + (z_q - hidden_states).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, min_encoding_indices, loss

    def compute_distances(self, hidden_states):
        """Squared L2 distances from every input vector to every codebook entry.

        Uses the expansion (z - e)^2 = z^2 + e^2 - 2 e * z via a single addmm.
        Returns a (B*H*W, num_embeddings) tensor.
        """
        hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
        emb_weights = self.embedding.weight.t()

        inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
        codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
        distances = torch.addmm(
            inputs_norm_sq + codebook_t_norm_sq,
            hidden_states_flattended,
            emb_weights,
            alpha=-2.0,
        )
        return distances

    def get_codebook_entry(self, indices):
        """Look up codebook vectors for token indices.

        Args:
            indices: (batch, num_tokens) — num_tokens must be a perfect square —
                or (batch, height, width).

        Returns:
            z_q of shape (batch, embedding_dim, H, W).

        Raises:
            NotImplementedError: for any other indices rank.
        """
        if len(indices.shape) == 2:
            batch, num_tokens = indices.shape
            z_q = self.embedding(indices)
            z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
        elif len(indices.shape) == 3:
            batch, height, width = indices.shape
            indices = indices.view(batch, -1)
            z_q = self.embedding(indices)
            z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2)
        else:
            # Carry the offending shape in the exception instead of printing it.
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")
        return z_q

    # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
    def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
        """Soft (softmax over negative distances) and hard code assignments.

        Args:
            hidden_states: (batch, channel, height, width) encoder features.
            temp: softmax temperature for the soft assignment.
            stochastic: if True, sample the hard code from the soft distribution
                instead of taking the argmin.

        Returns:
            (soft_code, code): (batch, num_tokens, num_embeddings) and (batch, num_tokens).
        """
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()  # (batch, height, width, channel)
        distances = self.compute_distances(hidden_states)  # (batch * height * width, num_embeddings)

        soft_code = F.softmax(-distances / temp, dim=-1)  # (batch * height * width, num_embeddings)
        if stochastic:
            code = torch.multinomial(soft_code, 1)  # (batch * height * width, 1)
        else:
            code = distances.argmin(dim=-1)  # (batch * height * width)

        code = code.reshape(hidden_states.shape[0], -1)  # (batch, height * width)
        batch, num_tokens = code.shape
        soft_code = soft_code.reshape(batch, num_tokens, -1)  # (batch, height * width, num_embeddings)
        return soft_code, code

    def get_code(self, hidden_states):
        """Nearest-code indices for (batch, channel, height, width) features -> (batch, num_tokens)."""
        # reshape z -> (batch, height, width, channel)
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
        distances = self.compute_distances(hidden_states)
        indices = torch.argmin(distances, dim=1).unsqueeze(1)
        indices = indices.reshape(hidden_states.shape[0], -1)
        return indices
|
modeling/modules/perceptual_loss.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Perceptual loss module using LPIPS and ConvNeXt-S."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from torchvision import models
|
| 7 |
+
from .lpips import LPIPS
|
| 8 |
+
|
| 9 |
+
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 10 |
+
_IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PerceptualLoss(torch.nn.Module):
    # Frozen perceptual-loss network combining LPIPS and/or ConvNeXt-S feature
    # distances; which branches are active is parsed from `model_name`.
    def __init__(self, model_name: str = "convnext_s"):
        """Initializes the PerceptualLoss class.

        Args:
            model_name: A string, the name of the perceptual loss model to use.

        Raise:
            ValueError: If the model_name does not contain "lpips" or "convnext_s".
        """
        super().__init__()
        if ("lpips" not in model_name) and (
            "convnext_s" not in model_name):
            raise ValueError(f"Unsupported Perceptual Loss model name {model_name}")
        self.lpips = None
        self.convnext = None
        # Explicit weights only exist when BOTH branches are requested;
        # otherwise each active branch contributes with implicit weight 1.
        self.loss_weight_lpips = None
        self.loss_weight_convnext = None

        # Parsing the model name. We support name formatted in
        # "lpips-convnext_s-{float_number}-{float_number}", where the
        # {float_number} refers to the loss weight for each component.
        # E.g., lpips-convnext_s-1.0-2.0 refers to compute the perceptual loss
        # using both the convnext_s and lpips, and average the final loss with
        # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0).
        if "lpips" in model_name:
            self.lpips = LPIPS().eval()

        if "convnext_s" in model_name:
            # Downloads pretrained ImageNet weights on first use.
            self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()

        if "lpips" in model_name and "convnext_s" in model_name:
            # Last two dash-separated fields are the per-branch loss weights.
            loss_config = model_name.split('-')[-2:]
            self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1])
            print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")

        # ImageNet statistics for normalizing the ConvNeXt branch inputs.
        self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None])
        self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None])

        # The loss network is a fixed feature extractor; never trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Computes the perceptual loss.

        Args:
            input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1].
            target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1].

        Returns:
            A scalar tensor, the perceptual loss.
        """
        # Always in eval mode.
        self.eval()
        loss = 0.
        num_losses = 0.
        lpips_loss = 0.
        convnext_loss = 0.
        # Computes LPIPS loss, if available.
        if self.lpips is not None:
            lpips_loss = self.lpips(input, target)
            if self.loss_weight_lpips is None:
                # No explicit weight configured: plain (unweighted) average.
                loss += lpips_loss
                num_losses += 1
            else:
                num_losses += self.loss_weight_lpips
                loss += self.loss_weight_lpips * lpips_loss

        if self.convnext is not None:
            # Computes ConvNeXt-s loss, if available.
            # Resize to ConvNeXt's expected 224x224 input.
            input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True)
            target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True)
            pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std)
            pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std)
            convnext_loss = torch.nn.functional.mse_loss(
                pred_input,
                pred_target,
                reduction="mean")

            if self.loss_weight_convnext is None:
                num_losses += 1
                loss += convnext_loss
            else:
                num_losses += self.loss_weight_convnext
                loss += self.loss_weight_convnext * convnext_loss

        # weighted avg.
        loss = loss / num_losses
        return loss
|
modeling/quantizer/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .quantizer import VectorQuantizer, DiagonalGaussianDistribution
|
| 2 |
+
from .mvq import VectorQuantizerMVQ
|
| 3 |
+
from .softvq import SoftVectorQuantizer
|
modeling/quantizer/dist.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from typing import List
|
| 6 |
+
from typing import Union
|
| 7 |
+
|
| 8 |
+
import pytz
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as tdist
|
| 11 |
+
import torch.multiprocessing as mp
|
| 12 |
+
|
| 13 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 14 |
+
__rank_str_zfill = '0'
|
| 15 |
+
__initialized = False
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def initialized():
    """Return True once torch.distributed has been set up by this module."""
    return __initialized
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    """Initialize torch.distributed and the module-level rank/device globals.

    Falls back gracefully when CUDA is unavailable, or when the process was
    not launched by a distributed launcher (no RANK environment variable).
    """
    global __device
    if not torch.cuda.is_available():
        # CPU-only fallback: leave distributed uninitialized.
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        # Single-process run: just pin the requested GPU and return.
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))

    global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    # Zero-padded rank string, e.g. '03' when the world size needs two digits.
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_rank():
    """Global rank of this process (0 when not distributed)."""
    return __rank
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_rank_str_zfill():
    """Global rank as a zero-padded string (padded to world-size width)."""
    return __rank_str_zfill
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_local_rank():
    """Rank of this process within its machine (0 when not distributed)."""
    return __local_rank
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_world_size():
    """Total number of distributed processes (1 when not distributed)."""
    return __world_size
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_device():
    """Torch device selected for this process ('cpu' fallback, else its CUDA device)."""
    return __device
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def set_gpu_id(gpu_id: int):
    """Pin this process to CUDA device `gpu_id` and refresh the cached device.

    A None `gpu_id` is a no-op; non-int/str ids raise NotImplementedError.
    """
    if gpu_id is None:
        return
    global __device
    if not isinstance(gpu_id, (str, int)):
        raise NotImplementedError
    torch.cuda.set_device(int(gpu_id))
    __device = torch.empty(1).cuda().device
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def is_master():
    """True on the single global-master process (global rank 0)."""
    return __rank == 0
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def is_local_master():
    """True on each machine's local-master process (local rank 0)."""
    return __local_rank == 0
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def new_group(ranks: List[int]):
    """Create a torch.distributed group over `ranks`; None when not distributed."""
    return tdist.new_group(ranks=ranks) if __initialized else None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def new_local_machine_group():
    """Create a subgroup spanning this machine's ranks; None when not distributed."""
    if not __initialized:
        return None
    current_subgroup, _ = tdist.new_subgroups()
    return current_subgroup
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def barrier():
    """Synchronize all ranks; no-op when torch.distributed is not initialized."""
    if __initialized:
        tdist.barrier()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def allreduce(t: torch.Tensor, async_op=False):
    """All-reduce `t` across ranks in place; returns None when not distributed.

    CPU tensors are staged through a CUDA copy (the NCCL backend only reduces
    GPU tensors) and the result is copied back into `t`.
    NOTE(review): with async_op=True on a CPU tensor, the copy-back runs before
    the asynchronous reduction completes — confirm callers only pass async_op
    with CUDA tensors.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather `t` from every rank; concatenate along dim 0 when `cat` is True.

    Assumes identical shapes on every rank (see allgather_diff_shape otherwise).
    When not distributed, the result contains only the local tensor.
    """
    if __initialized:
        if not t.is_cuda:
            # NCCL requires CUDA tensors.
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """All-gather tensors whose first (batch) dimension differs across ranks.

    Protocol: exchange every rank's shape, pad each local tensor up to the
    largest batch size, gather the padded tensors, then trim each gathered
    tensor back to its true length. Non-batch dimensions are assumed equal on
    all ranks — TODO confirm with callers.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        # Exchange every rank's shape first.
        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            # Pad with uninitialized rows; they are sliced off after the gather.
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        # Trim each rank's padding back to its advertised batch size.
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast `t` from `src_rank` into every rank's `t` in place; no-op when not distributed.

    CPU tensors are staged through a CUDA copy for the NCCL backend and copied back.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect one scalar per rank; return a tensor (fmt=None) or formatted strings."""
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    per_rank = torch.zeros(__world_size)
    per_rank[__rank] = val
    # After the sum-reduce, each slot holds exactly one rank's value.
    allreduce(per_rank)
    if fmt is None:
        return per_rank
    return [fmt % v for v in per_rank.cpu().numpy().tolist()]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def master_only(func):
    """Decorator: run `func` only on the global master (override with force=True).

    Every rank hits a barrier afterwards; non-masters return None.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_master()
        ret = func(*args, **kwargs) if should_run else None
        barrier()
        return ret
    return wrapper
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def local_master_only(func):
    """Decorator: run `func` only on each machine's local master (override with force=True).

    Every rank hits a barrier afterwards; non-local-masters return None.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_local_master()
        ret = func(*args, **kwargs) if should_run else None
        barrier()
        return ret
    return wrapper
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def for_visualize(func):
    """Decorator: run `func` only on the global master; no barrier, others get None."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_master():
            return None
        return func(*args, **kwargs)
    return wrapper
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def finalize():
    """Tear down the torch.distributed process group, if it was initialized."""
    if __initialized:
        tdist.destroy_process_group()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
    """Entry point: initialize distributed state, patch print(), and tee stdout/stderr.

    Master ranks (global master if only_sync_master else each local master)
    additionally mirror their stdout/stderr into backup files under
    `local_out_path`.
    """
    try:
        __initialize(fork=False, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError as e:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise e

    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    # Only the local master prints by default; others stay silent unless force=True.
    _change_builtin_print(is_local_master())
    if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _change_builtin_print(is_master):
    """Replace builtins.print with a rank-aware version.

    Non-master ranks print nothing unless called with force=True. Master output
    is prefixed with a timestamp and the caller's file/line unless clean=True;
    deeper=True attributes the message one frame further up the stack.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print
    # NOTE(review): compares print's type against the builtin-function type
    # (type(open)); presumably this bails out when print was already replaced
    # (e.g. by a previous call or a notebook shim) — confirm that is the intent.
    if type(builtin_print) != type(open):
        return

    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                # Walk up the stack to attribute the message to its call site.
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
                builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)

    __builtin__.print = prt
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class BackupStreamToFile(object):
    """Tee sys.stdout or sys.stderr into a backup file under `local_output_dir`."""

    def __init__(self, local_output_dir, for_stdout=True):
        self.for_stdout = for_stdout
        # Keep a handle to the stream being mirrored so close() can restore it.
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
        already_there = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if already_there:
            # Separate successive runs with a visible RESTART banner.
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True

    def write(self, message):
        # Mirror every write to both the terminal and the backup file.
        self.terminal_stream.write(message)
        self.file_stream.write(message)

    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()

    def close(self):
        """Flush and close the backup file, then restore the original stream."""
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()

    def __del__(self):
        # Safety net: idempotent close on garbage collection.
        self.close()
|
modeling/quantizer/mvq.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch import distributed as tdist, nn as nn
|
| 5 |
+
|
| 6 |
+
from .quantizer import VectorQuantizer
|
| 7 |
+
|
| 8 |
+
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
|
| 9 |
+
E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
|
| 10 |
+
E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
|
| 11 |
+
logits = -E_dist.float().mul_(inv_entropy_tau)
|
| 12 |
+
# calc per_sample_entropy
|
| 13 |
+
prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
|
| 14 |
+
per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
|
| 15 |
+
# calc codebook_entropy
|
| 16 |
+
avg_prob = prob.mean(dim=0) # (vocab_size,)
|
| 17 |
+
log_avg_prob = torch.log(avg_prob + 1e-7)
|
| 18 |
+
codebook_entropy = (-avg_prob * log_avg_prob).sum()
|
| 19 |
+
# calc entropy_loss
|
| 20 |
+
entropy_loss = per_sample_entropy - codebook_entropy
|
| 21 |
+
return entropy_loss
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class NormalizedEmbedding(nn.Embedding):
    """Embedding table whose rows are L2-normalized at every lookup."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def forward(self, idx):
        # Look up against the unit-norm version of the weight matrix so the
        # returned vectors always lie on the unit sphere.
        unit_weight = F.normalize(self.weight, dim=1)
        return F.embedding(
            idx, unit_weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse
        )

    def get_norm_weight(self):
        """Return the L2-normalized weight matrix (each row has unit norm)."""
        return F.normalize(self.weight, dim=1)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ResConv(nn.Conv2d):
    """Residual convolution: out = (1 - r) * x + r * conv(x).

    `quant_resi` < 0 selects a 3x3 kernel, otherwise 1x1; |quant_resi| is the
    mixing ratio r of the convolved branch.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 1 if quant_resi >= 0 else 3
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=kernel, stride=1, padding=kernel // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        conv_branch = super().forward(h_BChw)
        return h_BChw.mul(1 - self.resi_ratio) + conv_branch.mul_(self.resi_ratio)
|
| 47 |
+
|
| 48 |
+
class VectorQuantizerMVQ(nn.Module):
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
codebook_size,
|
| 52 |
+
token_size,
|
| 53 |
+
commitment_cost=0.25,
|
| 54 |
+
use_l2_norm=False,
|
| 55 |
+
# entropy_temp=0.01, # we do not use this
|
| 56 |
+
clustering_vq=False,
|
| 57 |
+
num_codebooks=16
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.num_codebooks = num_codebooks
|
| 61 |
+
self.codebooks = nn.ModuleList()
|
| 62 |
+
for _ in range(num_codebooks):
|
| 63 |
+
codebook = VectorQuantizer(
|
| 64 |
+
codebook_size=codebook_size // num_codebooks,
|
| 65 |
+
token_size=token_size // num_codebooks,
|
| 66 |
+
commitment_cost=commitment_cost,
|
| 67 |
+
use_l2_norm=use_l2_norm,
|
| 68 |
+
clustering_vq=clustering_vq,
|
| 69 |
+
)
|
| 70 |
+
self.codebooks.append(codebook)
|
| 71 |
+
|
| 72 |
+
def init_vocab(self, eini: float):
|
| 73 |
+
for codebook in self.codebooks:
|
| 74 |
+
codebook.init_vocab(eini)
|
| 75 |
+
|
| 76 |
+
def f_to_idx(self, features):
|
| 77 |
+
indices = []
|
| 78 |
+
chunk_size = features.shape[-1] // self.num_codebooks
|
| 79 |
+
splited_features = features.split(chunk_size, dim=-1)
|
| 80 |
+
for i, codebook in enumerate(self.codebooks):
|
| 81 |
+
indices.append(codebook.f_to_idx(splited_features[i]))
|
| 82 |
+
indices = torch.stack(indices, dim=1)
|
| 83 |
+
return indices
|
| 84 |
+
|
| 85 |
+
def idx_to_f(self, indices):
|
| 86 |
+
assert indices.shape[1] == self.num_codebooks
|
| 87 |
+
latent_features = []
|
| 88 |
+
for i, codebook in enumerate(self.codebooks):
|
| 89 |
+
sub_indices = indices[:, i].flatten(start_dim=1)
|
| 90 |
+
latent_feature = codebook.codebook(sub_indices)
|
| 91 |
+
latent_features.append(latent_feature)
|
| 92 |
+
latent_features = torch.cat(latent_features, dim=-1)
|
| 93 |
+
return latent_features
|
| 94 |
+
|
| 95 |
+
def get_codebook_entry(self, indices):
|
| 96 |
+
"""Get codebook entries for multi-codebook indices.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W)
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
z_quantized: Quantized features
|
| 103 |
+
"""
|
| 104 |
+
if len(indices.shape) == 2:
|
| 105 |
+
# indices shape: (N, num_codebooks)
|
| 106 |
+
latent_features = []
|
| 107 |
+
for i, codebook in enumerate(self.codebooks):
|
| 108 |
+
sub_indices = indices[:, i]
|
| 109 |
+
latent_feature = codebook.get_codebook_entry(sub_indices)
|
| 110 |
+
latent_features.append(latent_feature)
|
| 111 |
+
return torch.cat(latent_features, dim=-1)
|
| 112 |
+
elif len(indices.shape) == 4:
|
| 113 |
+
# indices shape: (B, num_codebooks, H, W)
|
| 114 |
+
batch_size, _, height, width = indices.shape
|
| 115 |
+
latent_features = []
|
| 116 |
+
for i, codebook in enumerate(self.codebooks):
|
| 117 |
+
sub_indices = indices[:, i] # (B, H, W)
|
| 118 |
+
latent_feature = codebook.get_codebook_entry(sub_indices.flatten())
|
| 119 |
+
# Reshape to (B, H, W, token_size // num_codebooks)
|
| 120 |
+
latent_feature = latent_feature.view(batch_size, height, width, -1)
|
| 121 |
+
latent_features.append(latent_feature)
|
| 122 |
+
# Concatenate along the last dimension and rearrange to (B, C, H, W)
|
| 123 |
+
latent_features = torch.cat(latent_features, dim=-1) # (B, H, W, C)
|
| 124 |
+
return latent_features.permute(0, 3, 1, 2).contiguous() # (B, C, H, W)
|
| 125 |
+
else:
|
| 126 |
+
raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")
|
| 127 |
+
|
| 128 |
+
def forward(self, features):
    """Quantize `features` by splitting its channels across the codebooks.

    Each codebook quantizes its own channel slice (in float32); the slices
    are concatenated back and the per-codebook losses are averaged.
    """
    chunk = features.shape[1] // self.num_codebooks
    chunks = features.split(chunk, dim=1)

    quantized_parts = []
    sub_results = []
    for codebook, part in zip(self.codebooks, chunks):
        part_q, part_rd = codebook(part.float())
        quantized_parts.append(part_q.to(features.dtype))
        sub_results.append(part_rd)

    # Stitch the per-codebook outputs back along the channel dimension.
    z_quantized = torch.cat(quantized_parts, dim=1)

    # Average each loss term over the codebooks.
    def _mean_loss(key):
        return sum(rd[key] for rd in sub_results) / self.num_codebooks

    # Each codebook reports (B, H, W) indices; stack -> (B, num_codebooks, H, W).
    stacked_indices = torch.stack(
        [rd['min_encoding_indices'] for rd in sub_results], dim=1)

    result_dict = dict(
        quantizer_loss=_mean_loss('quantizer_loss'),
        commitment_loss=_mean_loss('commitment_loss'),
        codebook_loss=_mean_loss('codebook_loss'),
        min_encoding_indices=stacked_indices
    )

    return z_quantized, result_dict
|
modeling/quantizer/quantizer.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector quantizer.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
|
| 5 |
+
https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
|
| 6 |
+
https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
|
| 7 |
+
https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py
|
| 8 |
+
"""
|
| 9 |
+
from typing import Mapping, Text, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from accelerate.utils.operations import gather
|
| 14 |
+
from torch.cuda.amp import autocast
|
| 15 |
+
|
| 16 |
+
class VectorQuantizer(torch.nn.Module):
    """Classic VQ-VAE vector quantizer with optional L2-normalized (cosine)
    lookup and optional CVQ-VAE-style reinitialization of dead codebook
    entries during training.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False
                 ):
        """
        Args:
            codebook_size: Number of codebook entries.
            token_size: Dimensionality of each codebook entry.
            commitment_cost: Weight of the commitment loss term.
            use_l2_norm: If True, match L2-normalized vectors instead of raw ones.
            clustering_vq: If True, enable usage-based codebook updates while training.
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost

        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm

        self.clustering_vq = clustering_vq
        if clustering_vq:
            # EMA factor for the per-entry usage statistics below.
            self.decay = 0.99
            # Running (EMA) selection probability of each codebook entry.
            self.register_buffer("embed_prob", torch.zeros(self.codebook_size))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Quantize a (B, C, H, W) feature map to its nearest codebook entries.

        Returns:
            z_quantized: Straight-through quantized tensor, (B, C, H, W).
            result_dict: quantizer/commitment/codebook losses plus the chosen
                indices reshaped to (B, H, W).
        """
        z = z.float()
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = rearrange(z, 'b h w c -> (b h w) c')
        # Kept un-normalized so the clustering update below sees raw features.
        unnormed_z_flattened = z_flattened

        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight
        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e between each
        # flattened vector and every codebook entry.
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)

        min_encoding_indices = torch.argmin(d, dim=1) # num_ele
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z = torch.nn.functional.normalize(z, dim=-1)

        # compute loss for embedding
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2)
        codebook_loss = torch.mean((z_quantized - z.detach()) **2)

        if self.clustering_vq and self.training:
            # CVQ-VAE-style update: pull rarely-used codebook entries toward
            # features from the current (globally gathered) batch.
            with torch.no_grad():
                # Gather distance matrix from all GPUs.
                encoding_indices = gather(min_encoding_indices)
                if len(min_encoding_indices.shape) != 1:
                    raise ValueError(f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}")
                # Compute and update the usage of each entry in the codebook.
                encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device)
                encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
                avg_probs = torch.mean(encodings, dim=0)
                self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1-self.decay)
                # Closest sampling to update the codebook.
                all_d = gather(d)
                all_unnormed_z_flattened = gather(unnormed_z_flattened).detach()
                if all_d.shape[0] != all_unnormed_z_flattened.shape[0]:
                    raise ValueError(
                        "all_d and all_unnormed_z_flattened have different length" +
                        f"{all_d.shape}, {all_unnormed_z_flattened.shape}")
                # For every codebook entry, the gathered feature closest to it.
                indices = torch.argmin(all_d, dim=0)
                random_feat = all_unnormed_z_flattened[indices]
                # Decay parameter based on the average usage.
                # Rarely-used entries get decay ~1 (replaced by random_feat);
                # frequently-used ones get decay ~0 (left untouched).
                decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) /
                                  (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size)
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay

        loss = commitment_loss + codebook_loss

        # preserve gradients
        z_quantized = z + (z_quantized - z).detach()

        # reshape back to match original input shape
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    @autocast(enabled=False)
    def get_codebook_entry(self, indices):
        """Look up codebook entries.

        Args:
            indices: 1-D integer indices (hard lookup), or a 2-D matrix
                interpreted as per-row weights over the codebook.
        """
        indices = indices.long()
        if len(indices.shape) == 1:
            z_quantized = self.embedding(indices)
        elif len(indices.shape) == 2:
            # NOTE(review): this branch matmuls `indices` against the embedding,
            # i.e. it expects one-hot/soft float weights — but the `.long()`
            # cast above makes the einsum dtype-mismatched with the float
            # weights. Appears unreachable in-repo; confirm before relying on it.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by stacked mean/logvar."""

    @autocast(enabled=False)
    def __init__(self, parameters, deterministic=False):
        """Initializes a Gaussian distribution instance given the parameters.

        Args:
            parameters (torch.Tensor): Parameters of shape [B, 2 * C, *]; the
                first C channels are the mean and the last C the log-variance.
            deterministic (bool): If True, sampling returns the mean exactly
                (std is forced to 0).
        """
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters.float(), 2, dim=1)
        self.mean = mean
        # Clamp log-variance for numerical stability of exp().
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if deterministic:
            zeros = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.std = zeros
            self.var = zeros
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    @autocast(enabled=False)
    def sample(self):
        """Draw one reparameterized sample: mean + std * eps."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean.float() + self.std.float() * noise

    @autocast(enabled=False)
    def mode(self):
        """Return the distribution mode (equal to the mean)."""
        return self.mean

    @autocast(enabled=False)
    def kl(self):
        """KL divergence against a standard normal, summed over dims 1 and 2."""
        if self.deterministic:
            return torch.Tensor([0.])
        mean_sq = torch.pow(self.mean.float(), 2)
        return 0.5 * torch.sum(mean_sq + self.var.float() - 1.0 - self.logvar.float(),
                               dim=[1, 2])
|
modeling/quantizer/softvq.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Mapping, Text, Tuple
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch.cuda.amp import autocast
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SoftVectorQuantizer(torch.nn.Module):
    """Soft (softmax-weighted) vector quantizer over one or more codebooks.

    Instead of a hard nearest-neighbor lookup, the output is a softmax-weighted
    mixture of codebook entries, regularized by an entropy loss during training.
    The interface mirrors `VectorQuantizer` (same forward signature and
    result-dict keys).
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False,
                 entropy_loss_ratio: float = 0.01,
                 tau: float = 0.07,
                 num_codebooks: int = 1,
                 show_usage: bool = False
                 ):
        """
        Args:
            codebook_size: Entries per codebook.
            token_size: Dimensionality of each entry.
            commitment_cost: Stored for API compatibility; not used by the soft path.
            use_l2_norm: If True, L2-normalize both features and codebook entries.
            clustering_vq: Stored for API compatibility; not used here.
            entropy_loss_ratio: Weight of the entropy regularizer.
            tau: Softmax temperature for the soft assignment.
            num_codebooks: Number of independent codebooks the sequence is split over.
            show_usage: If True, keep a rolling buffer of recently used indices.
        """
        super().__init__()
        # Map new parameter names to internal names for compatibility
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.clustering_vq = clustering_vq

        # Keep soft quantization specific parameters
        self.num_codebooks = num_codebooks
        self.n_e = codebook_size
        self.e_dim = token_size
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = use_l2_norm
        self.show_usage = show_usage
        self.tau = tau

        # Single embedding layer for all codebooks
        self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
        self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        if self.l2_norm:
            self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)

        if self.show_usage:
            # Rolling window (FIFO of size 65536) of recently selected indices.
            self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Softly quantize a (B, C, H, W) feature map.

        The spatial sequence is split evenly across `num_codebooks`; each
        segment is softly assigned to its codebook via a temperature softmax.

        Returns:
            z_quantized: (B, C, H, W) soft-quantized features.
            result_dict: losses plus argmax indices viewed to (B, H, W);
                matches the `VectorQuantizer` result format.
        """
        z = z.float()
        original_shape = z.shape

        # Handle input reshaping to match VectorQuantizer format
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z = z.view(z.size(0), -1, z.size(-1))

        batch_size, seq_length, _ = z.shape

        # Ensure sequence length is divisible by number of codebooks
        assert seq_length % self.num_codebooks == 0, \
            f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"

        segment_length = seq_length // self.num_codebooks
        z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)

        # Apply L2 norm if needed
        embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
        if self.l2_norm:
            z_segments = F.normalize(z_segments, p=2, dim=-1)

        # (num_codebooks, B*segment_length, e_dim) for batched einsum below.
        z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)

        # Affinity of every feature to every entry of its codebook. The
        # codebook is detached here so logits gradients flow to features only;
        # the codebook is trained through the weighted sum below.
        logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())

        # Calculate probabilities (soft quantization)
        probs = F.softmax(logits / self.tau, dim=-1)

        # Soft quantize
        z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)

        # Reshape back
        z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()

        # Calculate cosine similarity
        # NOTE(review): zq_z_cos is computed but never used or returned.
        with torch.no_grad():
            zq_z_cos = F.cosine_similarity(
                z_segments.view(-1, self.e_dim),
                z_q.view(-1, self.e_dim),
                dim=-1
            ).mean()

        # Get indices for usage tracking
        indices = torch.argmax(probs, dim=-1)  # (num_codebooks, batch_size * segment_length)
        indices = indices.transpose(0, 1).contiguous()  # (batch_size * segment_length, num_codebooks)

        # Track codebook usage
        if self.show_usage and self.training:
            for k in range(self.num_codebooks):
                cur_len = indices.size(0)
                # Shift the FIFO left and append the newest indices.
                self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
                self.codebook_used[k, -cur_len:].copy_(indices[:, k])

        # Calculate losses if training
        if self.training:
            # Soft quantization doesn't have traditional commitment/codebook loss
            # Map entropy loss to quantizer_loss for compatibility
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
            quantizer_loss = entropy_loss
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        else:
            quantizer_loss = torch.tensor(0.0, device=z.device)
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)

        # Calculate codebook usage
        # NOTE(review): codebook_usage is computed but never used or returned.
        codebook_usage = torch.tensor([
            len(torch.unique(self.codebook_used[k])) / self.n_e
            for k in range(self.num_codebooks)
        ]).mean() if self.show_usage else 0

        z_q = z_q.view(batch_size, -1, self.e_dim)

        # Reshape back to original input shape to match VectorQuantizer
        # (assumes e_dim == original channel count — TODO confirm with encoder config)
        z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
        z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        # Calculate average probabilities
        # NOTE(review): averaging over dim=-1 makes avg_probs the constant
        # 1/n_e; a batch-dim mean was probably intended. Unused downstream.
        avg_probs = torch.mean(torch.mean(probs, dim=-1))
        max_probs = torch.mean(torch.max(probs, dim=-1)[0])

        # Return format matching VectorQuantizer
        # NOTE(review): `indices` is laid out (B*segment, num_codebooks); the
        # first view below assumes (B, num_codebooks, segment) element order,
        # which mismatches for num_codebooks > 1 — verify before using these
        # indices for anything beyond logging.
        result_dict = dict(
            quantizer_loss=quantizer_loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Added for compatibility with VectorQuantizer API"""
        # Only the first codebook is consulted here; multi-codebook decoding
        # goes through the MVQ wrapper instead.
        if len(indices.shape) == 1:
            # For single codebook case
            z_quantized = self.embedding[0][indices]
        elif len(indices.shape) == 2:
            # Soft lookup: rows of `indices` are weights over the codebook.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularizer over codebook-assignment logits.

    Encourages confident per-sample assignments (low sample entropy) while
    spreading usage across the whole codebook (high average entropy).

    Args:
        affinity: Logits of shape (..., codebook_size).
        loss_type: Only "softmax" is supported.
        temperature: Softmax temperature applied to the logits.

    Returns:
        Scalar tensor: sample_entropy - avg_entropy (lower is better).

    Raises:
        ValueError: If `loss_type` is not "softmax".
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Out-of-place division. The previous `flat_affinity /= temperature`
    # mutated the caller's tensor in place (reshape usually returns a view),
    # corrupting `affinity` for later use and raising on leaf tensors that
    # require grad.
    flat_affinity = flat_affinity / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    # NOTE(review): the +1e-5 is applied to the logits (not the probs),
    # matching the reference implementation this was ported from; kept for parity.
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
|
modeling/vibetoken_model.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""VibeToken model definition."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
from modeling.modules.base_model import BaseModel
|
| 8 |
+
from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder
|
| 9 |
+
from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer
|
| 10 |
+
from modeling.modules.maskgit_vqgan import Encoder as Pixel_Eecoder
|
| 11 |
+
from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
|
| 12 |
+
from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
|
| 13 |
+
import json
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PretrainedTokenizer(nn.Module):
    """Frozen MaskGiT-VQGAN pixel tokenizer loaded from a local checkpoint."""

    def __init__(self, pretrained_weight):
        super().__init__()
        # Fixed MaskGiT-VQGAN architecture (256x256 input, 1024-entry codebook).
        arch = OmegaConf.create(
            {"channel_mult": [1, 1, 2, 2, 4],
             "num_resolutions": 5,
             "dropout": 0.0,
             "hidden_channels": 128,
             "num_channels": 3,
             "num_res_blocks": 2,
             "resolution": 256,
             "z_channels": 256})
        self.encoder = Pixel_Eecoder(arch)
        self.decoder = Pixel_Decoder(arch)
        self.quantize = Pixel_Quantizer(
            num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
        # Load pretrained weights
        state = torch.load(pretrained_weight, map_location=torch.device("cpu"))
        self.load_state_dict(state, strict=True)

        # Inference-only: freeze everything.
        self.eval()
        self.requires_grad_(False)

    @torch.no_grad()
    def encode(self, x):
        """Map images to discrete codebook indices."""
        hidden = self.encoder(x)
        _, codebook_indices, _ = self.quantize(hidden)
        return codebook_indices.detach()

    @torch.no_grad()
    def decode(self, codes):
        """Map codebook indices back to images clamped to [0, 1]."""
        quantized = self.quantize.get_codebook_entry(codes)
        images = self.decoder(quantized)
        images = torch.clamp(images, 0.0, 1.0)
        return images.detach()

    @torch.no_grad()
    def decode_tokens(self, codes):
        """Alias of `decode`, kept for API compatibility."""
        return self.decode(codes)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]):
    """Image tokenizer that encodes pixels into 1-D latent tokens, quantizes
    them with a pluggable quantizer (vq / vae / softvq / mvq), and decodes
    back to pixels — optionally through a frozen MaskGiT-VQGAN pixel decoder
    when fine-tuning the decoder (stage 2).
    """

    def __init__(self, config):
        # Accept plain dicts (e.g. a loaded config.json from the Hub).
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        # This should be False for stage1 and True for stage2.
        self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)

        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
        if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]:
            raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")

        # NOTE(review): "supprot" is a typo ("support") in this user-facing
        # message; left as-is because this edit only adds documentation.
        if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]:
            raise ValueError("Only supprot finetune_decoder with vq quantization for now.")

        self.encoder = ResolutionEncoder(config)
        self.decoder = ResolutionDecoder(config)

        # Learnable 1-D latent query tokens consumed by the encoder.
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width))

        # Initialize encoder/decoder weights before building the quantizer so
        # the quantizer's own initialization is not overwritten.
        self.apply(self._init_weights)

        if self.quantize_mode == "vq":
            self.quantize = VectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,)
        elif self.quantize_mode == "vae":
            # Stored as a class; instantiated per-forward on the encoder output.
            self.quantize = DiagonalGaussianDistribution
        elif self.quantize_mode == "mvq":
            self.quantize = VectorQuantizerMVQ(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        elif self.quantize_mode == "softvq":
            self.quantize = SoftVectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm,
                num_codebooks=config.model.vq_model.num_codebooks,
            )
        else:
            raise NotImplementedError

        if self.finetune_decoder:
            # Freeze encoder/quantizer/latent tokens
            self.latent_tokens.requires_grad_(False)
            self.encoder.eval()
            self.encoder.requires_grad_(False)
            self.quantize.eval()
            self.quantize.requires_grad_(False)

            # Include MaskGiT-VQGAN's quantizer and decoder
            self.pixel_quantize = Pixel_Quantizer(
                num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
            self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
                {"channel_mult": [1, 1, 2, 2, 4],
                 "num_resolutions": 5,
                 "dropout": 0.0,
                 "hidden_channels": 128,
                 "num_channels": 3,
                 "num_res_blocks": 2,
                 "resolution": 256,
                 "z_channels": 256}))

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Assume 'self.config' is your DictConfig object
        # Convert to a regular dictionary
        dict_config = OmegaConf.to_container(self.config)
        # Save as JSON
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)

    def _init_weights(self, module):
        """ Initialize the weights.
        :param:
            module -> torch.nn.Module: module to initialize
        """
        # Truncated-normal init for projection layers and embeddings;
        # standard zero/one init for LayerNorm.
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None):
        """Encode pixels to (quantized) latents.

        Args:
            x: Input images (pixel values).
            attention_mask: Optional mask forwarded to the encoder; ignored
                (set to None) when `length` is given.
            encode_patch_size: Optional patch size forwarded to the encoder.
            train: Forwarded to the encoder.
            length: If given, only the first `length + 1` latent tokens are used.

        Returns:
            (z_quantized, result_dict) — for "vae" mode, result_dict is the
            posterior distribution object instead of a loss dict.
        """
        if self.finetune_decoder:
            # Stage 2: encoder and quantizer are frozen; zero out their losses.
            with torch.no_grad():
                self.encoder.eval()
                self.quantize.eval()
                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
                z_quantized, result_dict = self.quantize(z)
                result_dict["quantizer_loss"] *= 0
                result_dict["commitment_loss"] *= 0
                result_dict["codebook_loss"] *= 0
        else:
            if length is not None:
                attention_mask = None
                latent_tokens = self.latent_tokens[:length+1]
            else:
                latent_tokens = self.latent_tokens
            z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
            if self.quantize_mode in ["vq", "mvq", "softvq"]:
                z_quantized, result_dict = self.quantize(z)
            elif self.quantize_mode == "vae":
                posteriors = self.quantize(z)
                z_quantized = posteriors.sample()
                result_dict = posteriors

        return z_quantized, result_dict

    def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode latents back to pixels.

        In finetune_decoder mode, the decoder output is treated as logits over
        the frozen MaskGiT-VQGAN codebook and rendered by its pixel decoder.
        """
        decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        if self.finetune_decoder:
            # Soft lookup: softmax over codebook logits times the embedding table.
            quantized_states = torch.einsum(
                'nchw,cd->ndhw', decoded.softmax(1),
                self.pixel_quantize.embedding.weight)
            decoded = self.pixel_decoder(quantized_states)
        return decoded

    def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
        """Decode from discrete token ids (or raw latents for "vae" mode)."""
        if self.quantize_mode in ["vq", "softvq"]:
            tokens = tokens.squeeze(1)
            batch, seq_len = tokens.shape # B x N
            # Look up entries and lay them out as a (B, C, 1, N) "image".
            z_quantized = self.quantize.get_codebook_entry(
                tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
            z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
        elif self.quantize_mode == "mvq":
            z_quantized = self.quantize.get_codebook_entry(tokens)
        elif self.quantize_mode == "vae":
            # "tokens" are already continuous latents in this mode.
            z_quantized = tokens
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
        return decoded

    def forward(self, x, key_attention_mask=None, height=None, width=None, train=True):
        """Full encode -> quantize -> decode round trip."""
        if height is None:
            # Default the output resolution to the input resolution.
            batch_size, channels, height, width = x.shape
        z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train)
        z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
        decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train)
        return decoded, result_dict
|
reconstruct.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Simple reconstruction script for VibeToken.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
# Auto mode (recommended) - automatically determines optimal settings
|
| 6 |
+
python reconstruct.py --auto \
|
| 7 |
+
--config configs/vibetoken_ll.yaml \
|
| 8 |
+
--checkpoint /path/to/checkpoint.bin \
|
| 9 |
+
--image assets/example_1.jpg \
|
| 10 |
+
--output assets/reconstructed.png
|
| 11 |
+
|
| 12 |
+
# Manual mode - specify all parameters
|
| 13 |
+
python reconstruct.py \
|
| 14 |
+
--config configs/vibetoken_ll.yaml \
|
| 15 |
+
--checkpoint /path/to/checkpoint.bin \
|
| 16 |
+
--image assets/example_1.jpg \
|
| 17 |
+
--output assets/reconstructed.png \
|
| 18 |
+
--input_height 512 --input_width 512 \
|
| 19 |
+
--encoder_patch_size 16,32 \
|
| 20 |
+
--decoder_patch_size 16
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_patch_size(value):
    """Parse a patch-size CLI argument.

    Args:
        value: None, a single-int string (e.g. ``"16"``), or a comma-separated
            pair ``"H,W"`` (e.g. ``"16,32"``).

    Returns:
        None if ``value`` is None, an int for a square patch, or an
        ``(int, int)`` tuple for a rectangular (height, width) patch.

    Raises:
        ValueError: if a comma-separated value does not contain exactly two
            integers (the original code silently dropped extra components).
    """
    if value is None:
        return None
    if ',' in value:
        # Strip whitespace so "16, 32" is accepted as well as "16,32".
        parts = [p.strip() for p in value.split(',')]
        if len(parts) != 2:
            raise ValueError(
                f"Patch size must be a single int or an 'H,W' pair, got: {value!r}")
        return (int(parts[0]), int(parts[1]))
    return int(value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main():
    """CLI entry point: encode one image to VibeToken tokens and decode it back.

    Two modes:
      * ``--auto``: input resolution and patch size are chosen by
        ``auto_preprocess_image``; output size matches the cropped input.
      * manual: optional ``--input_height/--input_width`` resize, a center
        crop to a multiple of 32, and explicit encoder/decoder patch sizes.
    The reconstructed image is written to ``--output``.
    """
    parser = argparse.ArgumentParser(description="VibeToken image reconstruction")
    parser.add_argument("--config", type=str, default="configs/vibetoken_ll.yaml",
                        help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--image", type=str, default="assets/example_1.jpg",
                        help="Path to input image")
    parser.add_argument("--output", type=str, default="./assets/reconstructed.png",
                        help="Path to output image")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device (cuda/cpu)")

    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal input resolution and patch sizes")

    # Input resolution (optional - resize input before encoding)
    parser.add_argument("--input_height", type=int, default=None,
                        help="Resize input to this height before encoding (default: original)")
    parser.add_argument("--input_width", type=int, default=None,
                        help="Resize input to this width before encoding (default: original)")

    # Output resolution (optional - decode to this size)
    parser.add_argument("--output_height", type=int, default=None,
                        help="Decode to this height (default: same as input)")
    parser.add_argument("--output_width", type=int, default=None,
                        help="Decode to this width (default: same as input)")

    # Patch sizes (optional) - supports single int or tuple like "16,32"
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")

    args = parser.parse_args()

    # Load tokenizer (weights come from the checkpoint, architecture from the YAML).
    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        args.config,
        args.checkpoint,
        device=args.device,
    )

    # Load image
    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size  # (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE - use centralized auto_preprocess_image
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        input_width, input_height = info["cropped_size"]
        # Output mirrors the (cropped) input; one patch size is shared by
        # encoder and decoder in auto mode.
        output_width, output_height = input_width, input_height
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        print("=================\n")

    else:
        # MANUAL MODE
        # Parse patch sizes
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Resize input if specified
        if args.input_width or args.input_height:
            # Missing dimension defaults to the original side length.
            input_width = args.input_width or original_size[0]
            input_height = args.input_height or original_size[1]
            print(f"Resizing input to {input_width}x{input_height}")
            image = image.resize((input_width, input_height), Image.LANCZOS)

        # Always center crop to ensure dimensions divisible by 32
        image = center_crop_to_multiple(image, multiple=32)
        input_width, input_height = image.size
        # NOTE(review): compares against the ORIGINAL size, so after an
        # explicit resize this message can print even when no crop happened.
        if (input_width, input_height) != original_size:
            print(f"Center cropped to {input_width}x{input_height} (divisible by 32)")

        # Determine output size
        output_height = args.output_height or input_height
        output_width = args.output_width or input_width

    # Encode image to tokens
    print("Encoding image to tokens...")
    if encoder_patch_size:
        print(f"  Using encoder patch size: {encoder_patch_size}")
    tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
    print(f"Token shape: {tokens.shape}")

    # Decode back to image
    print(f"Decoding to {output_width}x{output_height}...")
    if decoder_patch_size:
        print(f"  Using decoder patch size: {decoder_patch_size}")
    reconstructed = tokenizer.decode(
        tokens,
        height=output_height,
        width=output_width,
        patch_size=decoder_patch_size
    )
    print(f"Reconstructed shape: {reconstructed.shape}")

    # Convert tensor to PIL and save (only the first image of the batch).
    output_images = tokenizer.to_pil(reconstructed)
    output_images[0].save(args.output)
    print(f"Saved reconstructed image to {args.output}")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# Script entry point: run the reconstruction CLI when executed directly.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spaces
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision
|
| 4 |
+
einops>=0.6.0
|
| 5 |
+
omegaconf>=2.3.0
|
| 6 |
+
pillow>=9.0.0
|
| 7 |
+
numpy>=1.20.0
|
| 8 |
+
huggingface_hub>=0.16.0
|
| 9 |
+
accelerate
|
| 10 |
+
wandb
|
| 11 |
+
webdataset
|
| 12 |
+
timm
|
| 13 |
+
open_clip_torch
|
| 14 |
+
transformers
|
| 15 |
+
scipy
|
| 16 |
+
torch-fidelity
|
| 17 |
+
torchinfo
|
| 18 |
+
termcolor
|
| 19 |
+
iopath
|
| 20 |
+
opencv-python
|
| 21 |
+
diffusers
|
| 22 |
+
gdown
|
| 23 |
+
tqdm
|
| 24 |
+
requests
|
| 25 |
+
datasets
|
| 26 |
+
gradio>=4.0.0
|
scripts/train_vibetoken.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training script for VibeToken.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://github.com/huggingface/open-muse
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
| 11 |
+
sys.path.append(parent_dir)
|
| 12 |
+
|
| 13 |
+
from accelerate.utils import set_seed
|
| 14 |
+
from accelerate import Accelerator
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import wandb
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
from utils.logger import setup_logger
|
| 20 |
+
|
| 21 |
+
from utils.train_utils import (
|
| 22 |
+
get_config, create_pretrained_tokenizer,
|
| 23 |
+
create_model_and_loss_module,
|
| 24 |
+
create_optimizer, create_lr_scheduler, create_dataloader,
|
| 25 |
+
create_evaluator, auto_resume, save_checkpoint,
|
| 26 |
+
train_one_epoch)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def main():
    """Entry point for VibeToken tokenizer training.

    Orchestrates an accelerate-based run: loads config, sets up logging and
    tracking (wandb or tensorboard), builds model/EMA/loss, optimizers,
    LR schedulers, dataloaders and evaluator, auto-resumes from the latest
    checkpoint, then loops over epochs until
    ``config.training.max_train_steps`` is reached and saves final weights.
    """
    # Optionally redirect the torch hub cache into a shared workspace.
    workspace = os.environ.get('WORKSPACE', '')
    if workspace:
        torch.hub.set_dir(workspace + "/models/hub")

    config = get_config()
    # Enable TF32 on Ampere GPUs.
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    output_dir = config.experiment.output_dir
    os.makedirs(output_dir, exist_ok=True)
    config.experiment.logging_dir = os.path.join(output_dir, "logs")

    # Whether logging to Wandb or Tensorboard.
    tracker = "tensorboard"
    if config.training.enable_wandb:
        tracker = "wandb"

    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=tracker,
        project_dir=config.experiment.logging_dir,
        split_batches=False,
    )

    # One log file per process so distributed runs don't interleave writes.
    logger = setup_logger(name="VibeToken", log_level="INFO",
                          output_file=f"{output_dir}/log{accelerator.process_index}.txt")

    if accelerator.is_main_process:
        if config.training.enable_wandb:
            # Pull optional wandb settings from config.training.wandb,
            # falling back to the experiment-level project/name.
            wandb_config = config.training.get("wandb", {})
            wandb_project = wandb_config.get("project", config.experiment.project)
            wandb_entity = wandb_config.get("entity", None)
            wandb_name = wandb_config.get("name", config.experiment.name)
            wandb_tags = list(wandb_config.get("tags", []))
            wandb_notes = wandb_config.get("notes", None)
            wandb_resume_id = wandb_config.get("resume_id", None)

            wandb_init_kwargs = {
                "wandb": {
                    "name": wandb_name,
                    "dir": output_dir,
                    "resume": "allow",
                }
            }
            # Only forward optional keys that were actually configured.
            if wandb_entity:
                wandb_init_kwargs["wandb"]["entity"] = wandb_entity
            if wandb_tags:
                wandb_init_kwargs["wandb"]["tags"] = wandb_tags
            if wandb_notes:
                wandb_init_kwargs["wandb"]["notes"] = wandb_notes
            if wandb_resume_id:
                wandb_init_kwargs["wandb"]["id"] = wandb_resume_id

            accelerator.init_trackers(
                project_name=wandb_project,
                config=OmegaConf.to_container(config, resolve=True),
                init_kwargs=wandb_init_kwargs,
            )
            logger.info(f"WandB initialized - Project: {wandb_project}, Name: {wandb_name}")
        else:
            accelerator.init_trackers(config.experiment.name)

        # Persist the resolved config alongside the run outputs.
        config_path = Path(output_dir) / "config.yaml"
        logger.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)
        logger.info(f"Config:\n{OmegaConf.to_yaml(config)}")

    # If passed along, set the training seed now.
    if config.training.seed is not None:
        set_seed(config.training.seed, device_specific=True)

    accelerator.wait_for_everyone()

    # Create pretrained tokenizer in a synchronized manner
    if config.model.vq_model.is_legacy:
        if accelerator.is_main_process:
            logger.info("Creating pretrained tokenizer on main process...")
        accelerator.wait_for_everyone()
        pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            logger.info("Pretrained tokenizer creation completed.")
    else:
        pretrained_tokenizer = None

    if accelerator.is_main_process:
        logger.info("Creating model and loss module...")
    accelerator.wait_for_everyone()

    model, ema_model, loss_module = create_model_and_loss_module(
        config, logger, accelerator, model_type="vibetoken")

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Model creation completed.")

    optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module, model_type="vibetoken")

    lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler(
        config, logger, accelerator, optimizer, discriminator_optimizer)

    if accelerator.is_main_process:
        logger.info("Creating dataloaders...")
    # Dataloaders are built on every process (they shard internally).
    train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator)
    accelerator.wait_for_everyone()

    # Set up evaluator.
    if accelerator.is_main_process:
        logger.info("Setting up evaluator...")
    evaluator = create_evaluator(config, logger, accelerator)

    # Prepare everything with accelerator.
    logger.info("Preparing model, optimizer and dataloaders")
    # The dataloader are already aware of distributed training, so we don't need to prepare them.
    if config.model.vq_model.is_legacy:
        if config.model.vq_model.finetune_decoder:
            # GAN-style finetuning: the loss module carries a discriminator,
            # so its optimizer/scheduler must be prepared too.
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
                model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
            )
        else:
            model, optimizer, lr_scheduler = accelerator.prepare(
                model, optimizer, lr_scheduler
            )
    else:
        model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
        )

    if config.training.use_ema:
        ema_model.to(accelerator.device)

    # Derive epoch bookkeeping from the dataset size and batch settings.
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    num_batches = math.ceil(
        config.experiment.max_train_examples / total_batch_size_without_accum)
    num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)

    # Start training.
    logger.info("***** Running training *****")
    logger.info(f" Num training steps = {config.training.max_train_steps}")
    logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
    logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}")
    logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {(
        config.training.per_gpu_batch_size *
        accelerator.num_processes *
        config.training.gradient_accumulation_steps)}""")
    global_step = 0
    first_epoch = 0

    # Resume from the most recent checkpoint in output_dir, if any.
    global_step, first_epoch = auto_resume(
        config, logger, accelerator, ema_model, num_update_steps_per_epoch,
        strict=True)

    for current_epoch in range(first_epoch, num_train_epochs):
        accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.")
        global_step = train_one_epoch(config, logger, accelerator,
                            model, ema_model, loss_module,
                            optimizer, discriminator_optimizer,
                            lr_scheduler, discriminator_lr_scheduler,
                            train_dataloader, eval_dataloader,
                            evaluator,
                            global_step,
                            pretrained_tokenizer=pretrained_tokenizer,
                            model_type="vibetoken")
        # Stop training if max steps is reached.
        if global_step >= config.training.max_train_steps:
            accelerator.print(
                f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
            )
            break

    accelerator.wait_for_everyone()
    # Save checkpoint at the end of training.
    save_checkpoint(model, output_dir, accelerator, global_step, logger=logger)
    # Save the final trained checkpoint
    if accelerator.is_main_process:
        model = accelerator.unwrap_model(model)
        if config.training.use_ema:
            # Export the EMA weights rather than the raw training weights.
            ema_model.copy_to(model.parameters())
        model.save_pretrained_weight(output_dir)

    if accelerator.is_main_process and config.training.enable_wandb:
        wandb.finish()
        logger.info("WandB run finished")
    accelerator.end_training()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Script entry point: launch training when executed directly (e.g. via accelerate).
if __name__ == "__main__":
    main()
|
setup.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Data preparation script for VibeToken training.
# Set DATA_DIR to control where datasets are stored (defaults to ./data).
#
# Usage:
#   export DATA_DIR=/mnt/fastssd/datasets  # optional, defaults to ./data
#   bash setup.sh

# Fail fast: abort on any command error, unset variable, or pipeline failure,
# so a failed download never silently proceeds to the conversion step.
set -euo pipefail

DATA_DIR="${DATA_DIR:-./data}"

echo "Using DATA_DIR=${DATA_DIR}"

# Download ImageNet-1k via HuggingFace (hf_transfer accelerates the download).
export HF_HUB_ENABLE_HF_TRANSFER=1
huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"

# Convert to WebDataset format
python data/convert_imagenet_to_wds.py \
    --input_dir "${DATA_DIR}/imagenet-1k" \
    --output_dir "${DATA_DIR}/imagenet_wds"
|
train_tokenvibe.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch training with accelerate: 1 machine x 8 GPU processes.
# (The previous header claimed "2 nodes", contradicting --num_machines=1.)
# For multi-node runs, export RANK / MASTER_ADDR / MASTER_PORT per node and
# raise --num_machines accordingly.
set -euo pipefail

# Machine rank must default to 0: with --num_machines=1 a rank of 1 is invalid
# and the launcher would wait forever for a nonexistent main process.
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-9871}

accelerate launch \
    --num_machines=1 \
    --num_processes=8 \
    --machine_rank=$NODE_RANK \
    --main_process_ip=$MASTER_ADDR \
    --main_process_port=$MASTER_PORT \
    --same_network \
    scripts/train_tokenvibe.py \
    config=configs/training/VibeToken_small.yaml
|
train_vibetoken.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch training with accelerate: 1 machine x 8 GPU processes.
# (The previous header claimed "2 nodes", contradicting --num_machines=1.)
# For multi-node runs, export RANK / MASTER_ADDR / MASTER_PORT per node and
# raise --num_machines accordingly.
set -euo pipefail

# Machine rank must default to 0: with --num_machines=1 a rank of 1 is invalid
# and the launcher would wait forever for a nonexistent main process.
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-9871}

accelerate launch \
    --num_machines=1 \
    --num_processes=8 \
    --machine_rank=$NODE_RANK \
    --main_process_ip=$MASTER_ADDR \
    --main_process_port=$MASTER_PORT \
    --same_network \
    scripts/train_vibetoken.py \
    config=configs/training/VibeToken_small.yaml
|
utils/__init__.py
ADDED
|
File without changes
|
utils/logger.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Util functions supporting logging to terminal and files."""
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import sys
|
| 5 |
+
from accelerate.logging import MultiProcessAdapter
|
| 6 |
+
import logging
|
| 7 |
+
from termcolor import colored
|
| 8 |
+
|
| 9 |
+
from iopath.common.file_io import PathManager as PathManagerClass
|
| 10 |
+
|
| 11 |
+
__all__ = ["setup_logger", "PathManager"]
|
| 12 |
+
|
| 13 |
+
PathManager = PathManagerClass()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class _ColorfulFormatter(logging.Formatter):
|
| 17 |
+
def __init__(self, *args, **kwargs):
|
| 18 |
+
self._root_name = kwargs.pop("root_name") + "."
|
| 19 |
+
self._abbrev_name = kwargs.pop("abbrev_name", self._root_name)
|
| 20 |
+
if len(self._abbrev_name):
|
| 21 |
+
self._abbrev_name = self._abbrev_name + "."
|
| 22 |
+
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
| 23 |
+
|
| 24 |
+
def formatMessage(self, record):
|
| 25 |
+
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
| 26 |
+
log = super(_ColorfulFormatter, self).formatMessage(record)
|
| 27 |
+
if record.levelno == logging.WARNING:
|
| 28 |
+
prefix = colored("WARNING", "red", attrs=["blink"])
|
| 29 |
+
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
| 30 |
+
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
| 31 |
+
else:
|
| 32 |
+
return log
|
| 33 |
+
return prefix + " " + log
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@functools.lru_cache()
def setup_logger(name="TiTok", log_level: str = None, color=True, use_accelerate=True,
                 output_file=None):
    """Create (and cache) a logger that writes to stdout and optionally a file.

    Args:
        name: Logger name; also the root prefix abbreviated by the colorful
            formatter.
        log_level: Level name such as "INFO"; when None the logger is set to
            DEBUG.
        color: Use the colored console formatter instead of the plain one.
        use_accelerate: Wrap the logger in accelerate's MultiProcessAdapter
            so records are coordinated across distributed processes.
        output_file: Optional path; when given, records are also written there.

    Returns:
        A logging.Logger, or a MultiProcessAdapter around it when
        use_accelerate is True.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG if log_level is None else log_level.upper())

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
            root_name=name,
        )
    else:
        formatter = plain_formatter

    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    if output_file is not None:
        file_handler = logging.FileHandler(output_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return MultiProcessAdapter(logger, {}) if use_accelerate else logger
|
utils/lr_schedulers.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Learning rate schedulers.
|
| 2 |
+
|
| 3 |
+
Reference:
|
| 4 |
+
https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SchedulerType(Enum):
    # Supported learning-rate schedule families; the string values are the
    # names accepted by get_scheduler().
    COSINE = "cosine"
    CONSTANT = "constant"
|
| 16 |
+
|
| 17 |
+
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Build a cosine LR schedule with linear warm-up and a final LR floor.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warm-up steps.
        num_training_steps: Total number of training steps.
        num_cycles: Number of cosine periods over the decay span (0.5 means
            a single half-cosine from max down to the floor).
        last_epoch: Index of the last epoch when resuming training.
        base_lr: The base learning rate the multiplier is expressed against.
        end_lr: The final learning rate reached at the end of training.

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR` implementing the schedule.
    """
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def scale_at(step):
        # LambdaLR expects a multiplier of the optimizer's base lr.
        if step < num_warmup_steps:
            return float(step) / float(warmup_span)
        progress = float(step - num_warmup_steps) / float(decay_span)
        cosine = 0.5 * (1.0 + math.cos(2.0 * math.pi * float(num_cycles) * progress))
        ratio = max(0.0, cosine)
        # Interpolate between end_lr and base_lr, normalized by base_lr.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale_at, last_epoch)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """UViT: Creates a constant learning rate schedule with linear warm-up.

    The multiplier ramps linearly from 0 to 1 over ``num_warmup_steps`` and
    then stays at 1 (the optimizer's own base lr) for the rest of training.

    Note: the previous docstring documented ``num_cycles`` and ``last_epoch``
    parameters that do not exist here and described cosine-decay behavior;
    this schedule never decays.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warm-up steps.
        num_training_steps: Total training steps; accepted only for signature
            parity with the cosine schedule (unused).
        base_lr: Accepted for signature parity (unused — the multiplier is
            relative to the optimizer's own lr).
        end_lr: Accepted for signature parity (unused).

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR` implementing the schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            return 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Dispatch table mapping each SchedulerType to its factory function; consumed
# by get_scheduler() below.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule_with_warmup,
}
|
| 89 |
+
|
| 90 |
+
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Look up and build the learning-rate scheduler for ``name``.

    Args:
        name: A string or SchedulerType naming the schedule ("cosine" or
            "constant").
        optimizer: The optimizer to attach the scheduler to.
        num_warmup_steps: Number of warm-up steps (required).
        num_training_steps: Total number of training steps (required).
        base_lr: The base learning rate.
        end_lr: The final learning rate.

    Returns:
        An instance of torch.optim.lr_scheduler.LambdaLR.

    Raises:
        ValueError: if num_warmup_steps or num_training_steps is None.
    """
    scheduler_type = SchedulerType(name)
    schedule_fn = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]

    # Both step counts are mandatory for every supported schedule.
    if num_warmup_steps is None:
        raise ValueError(f"{scheduler_type} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{scheduler_type} requires `num_training_steps`, please provide that argument.")

    return schedule_fn(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )
|
utils/misc.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file is borrowed from https://github.com/LTH14/mar/blob/main/util/misc.py
|
| 2 |
+
"""
|
| 3 |
+
import builtins
|
| 4 |
+
import datetime
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict, deque
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
| 13 |
+
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
| 14 |
+
|
| 15 |
+
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
|
| 16 |
+
from torch._six import inf
|
| 17 |
+
else:
|
| 18 |
+
from torch import inf
|
| 19 |
+
import copy
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SmoothedValue(object):
    """Track a series of values and expose smoothed statistics over a sliding
    window, as well as running totals across the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # Default display: windowed median plus running global average.
        self.fmt = fmt if fmt is not None else "{median:.4f} ({global_avg:.4f})"

    def update(self, value, n=1):
        # The window keeps the raw value once; the running totals weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Windowed median (torch.median returns the lower middle for even sizes).
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Windowed arithmetic mean.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Weighted average over everything ever passed to update().
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded raw value.
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MetricLogger(object):
    """Collect named SmoothedValue meters and periodically log progress over
    an iterable (iteration timing, ETA, meter values, and CUDA memory when
    available).
    """

    def __init__(self, delimiter="\t"):
        # Unknown meter names are lazily created with default smoothing.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update one meter per keyword argument; None values are skipped.

        Tensor values are converted to Python scalars via .item().
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss`-style access to meters before falling back to
        # ordinary instance attributes.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize meter totals (not windows) across distributed ranks."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a meter explicitly (e.g. with a custom format string)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of `iterable`, printing a progress line every
        `print_freq` iterations (and on the last one), plus a total-time
        summary at the end.

        Note: requires a sized iterable (len() is used for the ETA).
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the final count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for data vs. total time per iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Bug fix: guard against ZeroDivisionError when the iterable is empty
        # (e.g. an empty dataloader shard).
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    native_print = builtins.print

    def print(*args, **kwargs):
        # Callers may pass force=True to print on every rank; very large jobs
        # (> 8 ranks) always print so per-node logs are not silent.
        forced = kwargs.pop('force', False) or (get_world_size() > 8)
        if not (is_master or forced):
            return
        timestamp = datetime.datetime.now().time()
        native_print('[{}] '.format(timestamp), end='')  # prefix a time stamp
        native_print(*args, **kwargs)

    builtins.print = print
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available AND the default
    process group has been initialized."""
    return dist.is_available() and dist.is_initialized()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def get_world_size():
    """Number of processes in the default group (1 when not distributed)."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_rank():
    """Rank of this process in the default group (0 when not distributed)."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def is_main_process():
    """True on rank 0 (and always true outside distributed mode)."""
    rank = get_rank()
    return rank == 0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def save_on_master(*args, **kwargs):
    """torch.save(...) that is a no-op on non-master ranks, so a checkpoint
    is written exactly once per job."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher set the env.

    Mutates `args` in place: sets rank / world_size / gpu / distributed /
    dist_backend (and dist_url on the ITP path), pins the CUDA device, then
    creates the NCCL process group and silences printing on non-zero ranks.
    Falls back to single-process mode when no launcher env vars are found.
    """
    if args.dist_on_itp:
        # OpenMPI-launched (ITP) job: translate OMPI vars into the standard
        # torchrun-style env so downstream code sees a consistent picture.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch path.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM path. NOTE(review): this branch assumes args.world_size and
        # args.dist_url were already set by the caller — TODO confirm.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack: keep printing enabled
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Synchronize all ranks before disabling printing on non-master ranks.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class NativeScalerWithGradNormCount:
    """Thin wrapper around torch.cuda.amp.GradScaler that also reports the
    gradient norm of the parameters being updated."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backpropagate the scaled loss; when `update_grad` is set, unscale,
        optionally clip, step the optimizer, and return the gradient norm.

        Returns None during gradient-accumulation steps (update_grad=False).
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            # Accumulation step: gradients stay scaled, no optimizer update.
            return None
        # Gradients must be unscaled in-place before measuring or clipping.
        self._scaler.unscale_(optimizer)
        if clip_grad is None:
            norm = get_grad_norm_(parameters)
        else:
            assert parameters is not None
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Compute the total `norm_type`-norm of the gradients of `parameters`.

    Mirrors torch.nn.utils.clip_grad_norm_'s norm computation without doing
    any clipping. Parameters without a .grad are ignored; returns tensor(0.)
    when nothing has gradients. Supports norm_type == inf.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        # Infinity norm: largest absolute gradient entry across all params.
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split trainable parameters into two optimizer param groups: one without
    weight decay (1-D params such as norms, biases, names in `skip_list`, and
    anything whose name contains 'diffloss') and one with `weight_decay`."""
    with_decay, without_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        exempt = (
            param.ndim == 1
            or name.endswith(".bias")
            or name in skip_list
            or 'diffloss' in name
        )
        # no weight decay on bias, norm and diffloss
        (without_decay if exempt else with_decay).append(param)
    return [
        {'params': without_decay, 'weight_decay': 0.},
        {'params': with_decay, 'weight_decay': weight_decay}]
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
    """Write a training checkpoint to args.output_dir (rank 0 only, via
    save_on_master).

    When `ema_params` is given, an EMA state dict is built by overwriting a
    deep copy of the model weights with the EMA values; `ema_params` must be
    ordered like model_without_ddp.named_parameters().
    """
    if epoch_name is None:
        epoch_name = str(epoch)
    checkpoint_path = Path(args.output_dir) / ('checkpoint-%s.pth' % epoch_name)

    # ema
    ema_state_dict = None
    if ema_params is not None:
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for idx, (name, _param) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[idx]

    to_save = {
        'model': model_without_ddp.state_dict(),
        'model_ema': ema_state_dict,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, checkpoint_path)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def all_reduce_mean(x):
    """Average a scalar across all distributed ranks; identity when running
    single-process. Note: moves the value through a CUDA tensor (NCCL)."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
|