Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +9 -0
- VAR/ILSVRC2012_img_train.torrent +0 -0
- VAR/ILSVRC2012_img_val.torrent +0 -0
- VAR/Imagenette/imagenette2-160.tgz +3 -0
- VAR/Imagenette/imagenette2-320.tgz +3 -0
- VAR/Imagenette/imagenette2.tgz +3 -0
- VAR/cifar-10-batches-py/batches.meta +0 -0
- VAR/cifar-10-batches-py/data_batch_1 +3 -0
- VAR/cifar-10-batches-py/data_batch_2 +3 -0
- VAR/cifar-10-batches-py/data_batch_3 +3 -0
- VAR/cifar-10-batches-py/data_batch_4 +3 -0
- VAR/cifar-10-batches-py/data_batch_5 +3 -0
- VAR/cifar-10-batches-py/readme.html +1 -0
- VAR/cifar-10-batches-py/test_batch +3 -0
- VAR/cifar-10-python.tar.gz +3 -0
- VAR/cifar-100-python.tar.gz +3 -0
- VAR/cifar-100-python/file.txt~ +0 -0
- VAR/cifar-100-python/meta +10 -0
- VAR/cifar-100-python/test +3 -0
- VAR/cifar-100-python/train +3 -0
- VAR/code/VAR/LICENSE +21 -0
- VAR/code/VAR/README.md +232 -0
- VAR/code/VAR/__pycache__/dist.cpython-310.pyc +0 -0
- VAR/code/VAR/__pycache__/dist.cpython-311.pyc +0 -0
- VAR/code/VAR/__pycache__/trainer.cpython-310.pyc +0 -0
- VAR/code/VAR/config.sh +30 -0
- VAR/code/VAR/demo_sample.ipynb +127 -0
- VAR/code/VAR/demo_zero_shot_edit.ipynb +0 -0
- VAR/code/VAR/dist.py +211 -0
- VAR/code/VAR/models/__init__.py +39 -0
- VAR/code/VAR/models/__pycache__/__init__.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/basic_vae.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/basic_var.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/helpers.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/quant.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/var.cpython-310.pyc +0 -0
- VAR/code/VAR/models/__pycache__/vqvae.cpython-310.pyc +0 -0
- VAR/code/VAR/models/basic_vae.py +226 -0
- VAR/code/VAR/models/basic_var.py +174 -0
- VAR/code/VAR/models/helpers.py +59 -0
- VAR/code/VAR/models/quant.py +243 -0
- VAR/code/VAR/models/var.py +419 -0
- VAR/code/VAR/models/vqvae.py +95 -0
- VAR/code/VAR/requirements.txt +8 -0
- VAR/code/VAR/train.py +357 -0
- VAR/code/VAR/trainer.py +201 -0
- VAR/code/VAR/utils/__pycache__/amp_sc.cpython-310.pyc +0 -0
- VAR/code/VAR/utils/__pycache__/arg_util.cpython-310.pyc +0 -0
- VAR/code/VAR/utils/__pycache__/arg_util.cpython-311.pyc +0 -0
- VAR/code/VAR/utils/__pycache__/data.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -7967,3 +7967,12 @@ Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_
|
|
| 7967 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_7.png filter=lfs diff=lfs merge=lfs -text
|
| 7968 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_8.png filter=lfs diff=lfs merge=lfs -text
|
| 7969 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_9.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7967 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_7.png filter=lfs diff=lfs merge=lfs -text
|
| 7968 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_8.png filter=lfs diff=lfs merge=lfs -text
|
| 7969 |
Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_9.png filter=lfs diff=lfs merge=lfs -text
|
| 7970 |
+
VAR/cifar-10-batches-py/data_batch_1 filter=lfs diff=lfs merge=lfs -text
|
| 7971 |
+
VAR/cifar-10-batches-py/data_batch_2 filter=lfs diff=lfs merge=lfs -text
|
| 7972 |
+
VAR/cifar-10-batches-py/data_batch_3 filter=lfs diff=lfs merge=lfs -text
|
| 7973 |
+
VAR/cifar-10-batches-py/data_batch_4 filter=lfs diff=lfs merge=lfs -text
|
| 7974 |
+
VAR/cifar-10-batches-py/data_batch_5 filter=lfs diff=lfs merge=lfs -text
|
| 7975 |
+
VAR/cifar-10-batches-py/test_batch filter=lfs diff=lfs merge=lfs -text
|
| 7976 |
+
VAR/cifar-100-python/test filter=lfs diff=lfs merge=lfs -text
|
| 7977 |
+
VAR/cifar-100-python/train filter=lfs diff=lfs merge=lfs -text
|
| 7978 |
+
VAR/imagenet/LOC_train_solution.csv filter=lfs diff=lfs merge=lfs -text
|
VAR/ILSVRC2012_img_train.torrent
ADDED
|
File without changes
|
VAR/ILSVRC2012_img_val.torrent
ADDED
|
File without changes
|
VAR/Imagenette/imagenette2-160.tgz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64d0c4859f35a461889e0147755a999a48b49bf38a7e0f9bd27003f10db02fe5
|
| 3 |
+
size 99003388
|
VAR/Imagenette/imagenette2-320.tgz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:569b4497c98db6dd29f335d1f109cf315fe127053cedf69010d047f0188e158c
|
| 3 |
+
size 341663724
|
VAR/Imagenette/imagenette2.tgz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6cbfac238434d89fe99e651496f0812ebc7a10fa62bd42d6874042bf01de4efd
|
| 3 |
+
size 1557161267
|
VAR/cifar-10-batches-py/batches.meta
ADDED
|
Binary file (158 Bytes). View file
|
|
|
VAR/cifar-10-batches-py/data_batch_1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54636561a3ce25bd3e19253c6b0d8538147b0ae398331ac4a2d86c6d987368cd
|
| 3 |
+
size 31035704
|
VAR/cifar-10-batches-py/data_batch_2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:766b2cef9fbc745cf056b3152224f7cf77163b330ea9a15f9392beb8b89bc5a8
|
| 3 |
+
size 31035320
|
VAR/cifar-10-batches-py/data_batch_3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f00d98ebfb30b3ec0ad19f9756dc2630b89003e10525f5e148445e82aa6a1f9
|
| 3 |
+
size 31035999
|
VAR/cifar-10-batches-py/data_batch_4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f7bb240661948b8f4d53e36ec720d8306f5668bd0071dcb4e6c947f78e9682b
|
| 3 |
+
size 31035696
|
VAR/cifar-10-batches-py/data_batch_5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d91802434d8376bbaeeadf58a737e3a1b12ac839077e931237e0dcd43adcb154
|
| 3 |
+
size 31035623
|
VAR/cifar-10-batches-py/readme.html
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
<meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
|
VAR/cifar-10-batches-py/test_batch
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f53d8d457504f7cff4ea9e021afcf0e0ad8e24a91f3fc42091b8adef61157831
|
| 3 |
+
size 31035526
|
VAR/cifar-10-python.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce
|
| 3 |
+
size 170498071
|
VAR/cifar-100-python.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7
|
| 3 |
+
size 169001437
|
VAR/cifar-100-python/file.txt~
ADDED
|
File without changes
|
VAR/cifar-100-python/meta
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
�}q(Ufine_label_namesq]q(UappleqU
|
| 2 |
+
UbeetleqUbicycleqUbottleq
|
| 3 |
+
chimpanzeeqUclockqUcloudqU cockroachqUcouchqUcrabqU crocodileqUcupq Udinosaurq!Udolphinq"Uelephantq#Uflatfishq$Uforestq%Ufoxq&Ugirlq'Uhamsterq(Uhouseq)Ukangarooq*Ukeyboardq+Ulampq,U
|
| 4 |
+
lawn_mowerq-Uleopardq.Ulionq/Ulizardq0Ulobsterq1Umanq2U
|
| 5 |
+
maple_treeq3U
|
| 6 |
+
motorcycleq4Umountainq5Umouseq6Umushroomq7Uoak_treeq8Uorangeq9Uorchidq:Uotterq;U palm_treeq<Upearq=Upickup_truckq>U pine_treeq?Uplainq@UplateqAUpoppyqBU porcupineqCUpossumqDUrabbitqEUraccoonqFUrayqGUroadqHUrocketqIUroseqJUseaqKUsealqLUsharkqMUshrewqNUskunkqOU
|
| 7 |
+
skyscraperqPUsnailqQUsnakeqRUspiderqSUsquirrelqTU streetcarqUU sunflowerqVUsweet_pepperqWUtableqXUtankqYU telephoneqZU
|
| 8 |
+
televisionq[Utigerq\Utractorq]Utrainq^Utroutq_Utulipq`UturtleqaUwardrobeqbUwhaleqcUwillow_treeqdUwolfqeUwomanqfUwormqgeUcoarse_label_namesqh]qi(Uaquatic_mammalsqjUfishqkUflowersqlUfood_containersqmUfruit_and_vegetablesqnUhousehold_electrical_devicesqoUhousehold_furnitureqpUinsectsqqUlarge_carnivoresqrUlarge_man-made_outdoor_thingsqsUlarge_natural_outdoor_scenesqtUlarge_omnivores_and_herbivoresquUmedium_mammalsqvUnon-insect_invertebratesqwUpeopleqxUreptilesqyU
|
| 9 |
+
vehicles_1q|U
|
| 10 |
+
vehicles_2q}eu.
|
VAR/cifar-100-python/test
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b67687d9933c4db8f0831104447f15b93774f4f464bd0516f0f0f2ac83b7864
|
| 3 |
+
size 31049707
|
VAR/cifar-100-python/train
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:735e79b04f092ca3d2e6d07f368c0a7d70d48c48d28865950cc24454cf45129b
|
| 3 |
+
size 155249918
|
VAR/code/VAR/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 FoundationVision
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
VAR/code/VAR/README.md
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
|
| 5 |
+
[](https://opensource.bytedance.com/gmpt/t2i/invite)
|
| 6 |
+
[](https://arxiv.org/abs/2404.02905)
|
| 7 |
+
[](https://huggingface.co/FoundationVision/var)
|
| 8 |
+
[](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
</div>
|
| 12 |
+
<p align="center" style="font-size: larger;">
|
| 13 |
+
<a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
|
| 14 |
+
</p>
|
| 15 |
+
|
| 16 |
+
<div>
|
| 17 |
+
<p align="center" style="font-size: larger;">
|
| 18 |
+
<strong>NeurIPS 2024 Best Paper</strong>
|
| 19 |
+
</p>
|
| 20 |
+
</div>
|
| 21 |
+
|
| 22 |
+
<p align="center">
|
| 23 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
|
| 24 |
+
<p>
|
| 25 |
+
|
| 26 |
+
<br>
|
| 27 |
+
|
| 28 |
+
## News
|
| 29 |
+
|
| 30 |
+
* **2024-12:** 🏆 VAR received **NeurIPS 2024 Best Paper Award**.
|
| 31 |
+
* **2024-12:** 🔥 We Release our Text-to-Image research based on VAR, please check [Infinity](https://github.com/FoundationVision/Infinity).
|
| 32 |
+
* **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation.
|
| 33 |
+
* **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.
|
| 34 |
+
|
| 35 |
+
## 🕹️ Try and Play with VAR!
|
| 36 |
+
|
| 37 |
+
~~We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!~~
|
| 38 |
+
|
| 39 |
+
We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with VAR Text-to-Image and generate images interactively. Enjoy the fun of visual autoregressive modeling!
|
| 40 |
+
|
| 41 |
+
We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
|
| 42 |
+
|
| 43 |
+
[//]: # (<p align="center">)
|
| 44 |
+
[//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
|
| 45 |
+
[//]: # (<p>)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## What's New?
|
| 49 |
+
|
| 50 |
+
### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
|
| 51 |
+
|
| 52 |
+
Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
|
| 53 |
+
|
| 54 |
+
<p align="center">
|
| 55 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
|
| 56 |
+
<p>
|
| 57 |
+
|
| 58 |
+
### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
|
| 59 |
+
<p align="center">
|
| 60 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
|
| 61 |
+
<p>
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
<p align="center">
|
| 68 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
|
| 69 |
+
<p>
|
| 70 |
+
<p align="center">
|
| 71 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
|
| 72 |
+
<p>
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
### 🔥 Zero-shot generalizability🛠️:
|
| 76 |
+
|
| 77 |
+
<p align="center">
|
| 78 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
|
| 79 |
+
<p>
|
| 80 |
+
|
| 81 |
+
#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
## VAR zoo
|
| 85 |
+
We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
|
| 86 |
+
|
| 87 |
+
| model | reso. | FID | rel. cost | #params | HF weights🤗 |
|
| 88 |
+
|:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
|
| 89 |
+
| VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
|
| 90 |
+
| VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
|
| 91 |
+
| VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
|
| 92 |
+
| VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
|
| 93 |
+
| VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
|
| 94 |
+
| VAR-d36 | 512 | **2.63** | - | 2.3B | [var_d36.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d36.pth) |
|
| 95 |
+
|
| 96 |
+
You can load these models to generate images via the codes in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
## Installation
|
| 100 |
+
|
| 101 |
+
1. Install `torch>=2.0.0`.
|
| 102 |
+
2. Install other pip packages via `pip3 install -r requirements.txt`.
|
| 103 |
+
3. Prepare the [ImageNet](http://image-net.org/) dataset
|
| 104 |
+
<details>
|
| 105 |
+
<summary> assume the ImageNet is in `/path/to/imagenet`. It should be like this:</summary>
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
/path/to/imagenet/:
|
| 109 |
+
train/:
|
| 110 |
+
n01440764:
|
| 111 |
+
many_images.JPEG ...
|
| 112 |
+
n01443537:
|
| 113 |
+
many_images.JPEG ...
|
| 114 |
+
val/:
|
| 115 |
+
n01440764:
|
| 116 |
+
ILSVRC2012_val_00000293.JPEG ...
|
| 117 |
+
n01443537:
|
| 118 |
+
ILSVRC2012_val_00000236.JPEG ...
|
| 119 |
+
```
|
| 120 |
+
**NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
|
| 121 |
+
</details>
|
| 122 |
+
|
| 123 |
+
5. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## Training Scripts
|
| 127 |
+
|
| 128 |
+
To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
|
| 129 |
+
```shell
|
| 130 |
+
# d16, 256x256
|
| 131 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 132 |
+
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
|
| 133 |
+
# d20, 256x256
|
| 134 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 135 |
+
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
|
| 136 |
+
# d24, 256x256
|
| 137 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 138 |
+
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
|
| 139 |
+
# d30, 256x256
|
| 140 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 141 |
+
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
|
| 142 |
+
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
|
| 143 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 144 |
+
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
|
| 145 |
+
```
|
| 146 |
+
A folder named `local_output` will be created to save the checkpoints and logs.
|
| 147 |
+
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
|
| 148 |
+
|
| 149 |
+
If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
|
| 150 |
+
|
| 151 |
+
## Sampling & Zero-shot Inference
|
| 152 |
+
|
| 153 |
+
For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L344](utils/misc.py#L360).
|
| 154 |
+
Then use the [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and reference ground truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
|
| 155 |
+
|
| 156 |
+
Note a relatively small `cfg=1.5` is used for trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
|
| 157 |
+
We'll provide the sampling script later.
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
## Third-party Usage and Research
|
| 161 |
+
|
| 162 |
+
***In this pargraph, we cross link third-party repositories or research which use VAR and report results. You can let us know by raising an issue***
|
| 163 |
+
|
| 164 |
+
(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)
|
| 165 |
+
|
| 166 |
+
| **Time** | **Research** | **Link** |
|
| 167 |
+
|--------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------|
|
| 168 |
+
| [5/12/2025] | [ICML 2025]Continuous Visual Autoregressive Generation via Score Maximization | https://github.com/shaochenze/EAR |
|
| 169 |
+
| [5/8/2025] | Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction | https://github.com/icon-lab/FedGAT |
|
| 170 |
+
| [4/7/2025] | FastVAR: Linear Visual Autoregressive Modeling via Cached Token Pruning | https://github.com/csguoh/FastVAR |
|
| 171 |
+
| [4/3/2025] | VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning | https://github.com/VARGPT-family/VARGPT-v1.1 |
|
| 172 |
+
| [3/31/2025] | Training-Free Text-Guided Image Editing with Visual Autoregressive Model | https://github.com/wyf0912/AREdit |
|
| 173 |
+
| [3/17/2025] | Next-Scale Autoregressive Models are Zero-Shot Single-Image Object View Synthesizers | https://github.com/Shiran-Yuan/ArchonView |
|
| 174 |
+
| [3/14/2025] | Safe-VAR: Safe Visual Autoregressive Model for Text-to-Image Generative Watermarking | https://arxiv.org/abs/2503.11324 |
|
| 175 |
+
| [3/3/2025] | [ICML 2025]Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator | https://research.nvidia.com/labs/dir/ddo/ |
|
| 176 |
+
| [2/28/2025] | Autoregressive Medical Image Segmentation via Next-Scale Mask Prediction | https://arxiv.org/abs/2502.20784 |
|
| 177 |
+
| [2/27/2025] | FlexVAR: Flexible Visual Autoregressive Modeling without Residual Prediction | https://github.com/jiaosiyu1999/FlexVAR |
|
| 178 |
+
| [2/17/2025] | MARS: Mesh AutoRegressive Model for 3D Shape Detailization | https://arxiv.org/abs/2502.11390 |
|
| 179 |
+
| [1/31/2025] | [ICML 2025]Visual Autoregressive Modeling for Image Super-Resolution | https://github.com/quyp2000/VARSR |
|
| 180 |
+
| [1/21/2025] | VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model | https://github.com/VARGPT-family/VARGPT |
|
| 181 |
+
| [1/26/2025] | [ICML 2025]Visual Generation Without Guidance | https://github.com/thu-ml/GFT |
|
| 182 |
+
| [12/30/2024] | Next Token Prediction Towards Multimodal Intelligence | https://github.com/LMM101/Awesome-Multimodal-Next-Token-Prediction |
|
| 183 |
+
| [12/30/2024] | Varformer: Adapting VAR’s Generative Prior for Image Restoration | https://arxiv.org/abs/2412.21063 |
|
| 184 |
+
| [12/22/2024] | [ICLR 2025]Distilled Decoding 1: One-step Sampling of Image Auto-regressive Models with Flow Matching | https://github.com/imagination-research/distilled-decoding |
|
| 185 |
+
| [12/19/2024] | FlowAR: Scale-wise Autoregressive Image Generation Meets Flow Matching | https://github.com/OliverRensu/FlowAR |
|
| 186 |
+
| [12/13/2024] | 3D representation in 512-Byte: Variational tokenizer is the key for autoregressive 3D generation | https://github.com/sparse-mvs-2/VAT |
|
| 187 |
+
| [12/9/2024] | CARP: Visuomotor Policy Learning via Coarse-to-Fine Autoregressive Prediction | https://carp-robot.github.io/ |
|
| 188 |
+
| [12/5/2024] | [CVPR 2025]Infinity ∞: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis | https://github.com/FoundationVision/Infinity |
|
| 189 |
+
| [12/5/2024] | [CVPR 2025]Switti: Designing Scale-Wise Transformers for Text-to-Image Synthesis | https://github.com/yandex-research/switti |
|
| 190 |
+
| [12/4/2024] | [CVPR 2025]TokenFlow🚀: Unified Image Tokenizer for Multimodal Understanding and Generation | https://github.com/ByteFlow-AI/TokenFlow |
|
| 191 |
+
| [12/3/2024] | XQ-GAN🚀: An Open-source Image Tokenization Framework for Autoregressive Generation | https://github.com/lxa9867/ImageFolder |
|
| 192 |
+
| [11/28/2024] | [CVPR 2025]CoDe: Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient | https://github.com/czg1225/CoDe |
|
| 193 |
+
| [11/28/2024] | [CVPR 2025]Scalable Autoregressive Monocular Depth Estimation | https://arxiv.org/abs/2411.11361 |
|
| 194 |
+
| [11/27/2024] | [CVPR 2025]SAR3D: Autoregressive 3D Object Generation and Understanding via Multi-scale 3D VQVAE | https://github.com/cyw-3d/SAR3D |
|
| 195 |
+
| [11/26/2024] | LiteVAR: Compressing Visual Autoregressive Modelling with Efficient Attention and Quantization | https://arxiv.org/abs/2411.17178 |
|
| 196 |
+
| [11/15/2024] | M-VAR: Decoupled Scale-wise Autoregressive Modeling for High-Quality Image Generation | https://github.com/OliverRensu/MVAR |
|
| 197 |
+
| [10/14/2024] | [ICLR 2025]HART: Efficient Visual Generation with Hybrid Autoregressive Transformer | https://github.com/mit-han-lab/hart |
|
| 198 |
+
| [10/12/2024] | [ICLR 2025 Oral]Toward Guidance-Free AR Visual Generation via Condition Contrastive Alignment | https://github.com/thu-ml/CCA |
|
| 199 |
+
| [10/3/2024] | [ICLR 2025]ImageFolder🚀: Autoregressive Image Generation with Folded Tokens | https://github.com/lxa9867/ImageFolder |
|
| 200 |
+
| [07/25/2024] | ControlVAR: Exploring Controllable Visual Autoregressive Modeling | https://github.com/lxa9867/ControlVAR |
|
| 201 |
+
| [07/3/2024] | VAR-CLIP: Text-to-Image Generator with Visual Auto-Regressive Modeling | https://github.com/daixiangzi/VAR-CLIP |
|
| 202 |
+
| [06/16/2024] | STAR: Scale-wise Text-to-image generation via Auto-Regressive representations | https://arxiv.org/abs/2406.10797 |
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
## License
|
| 206 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
## Citation
|
| 210 |
+
If our work assists your research, feel free to give us a star ⭐ or cite us using:
|
| 211 |
+
```
|
| 212 |
+
@Article{VAR,
|
| 213 |
+
title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
|
| 214 |
+
author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
|
| 215 |
+
year={2024},
|
| 216 |
+
eprint={2404.02905},
|
| 217 |
+
archivePrefix={arXiv},
|
| 218 |
+
primaryClass={cs.CV}
|
| 219 |
+
}
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
```
|
| 223 |
+
@misc{Infinity,
|
| 224 |
+
title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis},
|
| 225 |
+
author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu},
|
| 226 |
+
year={2024},
|
| 227 |
+
eprint={2412.04431},
|
| 228 |
+
archivePrefix={arXiv},
|
| 229 |
+
primaryClass={cs.CV},
|
| 230 |
+
url={https://arxiv.org/abs/2412.04431},
|
| 231 |
+
}
|
| 232 |
+
```
|
VAR/code/VAR/__pycache__/dist.cpython-310.pyc
ADDED
|
Binary file (6.43 kB). View file
|
|
|
VAR/code/VAR/__pycache__/dist.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
VAR/code/VAR/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (6.91 kB). View file
|
|
|
VAR/code/VAR/config.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# d16, 256x256
|
| 5 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 6 |
+
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
|
| 7 |
+
# d20, 256x256
|
| 8 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 9 |
+
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
|
| 10 |
+
# d24, 256x256
|
| 11 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 12 |
+
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
|
| 13 |
+
# d30, 256x256
|
| 14 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 15 |
+
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
|
| 16 |
+
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
|
| 17 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
| 18 |
+
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
|
| 23 |
+
--master_addr=127.0.0.1 --master_port=29500 \
|
| 24 |
+
train.py --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
|
| 25 |
+
|
| 26 |
+
torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
|
| 27 |
+
--master_addr=127.0.0.1 --master_port=29500 \
|
| 28 |
+
train.py --depth=16 --bs=96 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
|
| 29 |
+
|
| 30 |
+
python train.py --depth=16 --bs=96 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
|
VAR/code/VAR/demo_sample.ipynb
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"source": [
|
| 6 |
+
"### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟"
|
| 7 |
+
],
|
| 8 |
+
"metadata": {
|
| 9 |
+
"collapsed": false
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"outputs": [],
|
| 16 |
+
"source": [
|
| 17 |
+
"################## 1. Download checkpoints and build models\n",
|
| 18 |
+
"import os\n",
|
| 19 |
+
"import os.path as osp\n",
|
| 20 |
+
"import torch, torchvision\n",
|
| 21 |
+
"import random\n",
|
| 22 |
+
"import numpy as np\n",
|
| 23 |
+
"import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
|
| 24 |
+
"setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
|
| 25 |
+
"setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
|
| 26 |
+
"from models import VQVAE, build_vae_var\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====\n",
|
| 29 |
+
"assert MODEL_DEPTH in {16, 20, 24, 30}\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"# download checkpoint\n",
|
| 33 |
+
"hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
|
| 34 |
+
"vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'\n",
|
| 35 |
+
"if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')\n",
|
| 36 |
+
"if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# build vae, var\n",
|
| 39 |
+
"patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)\n",
|
| 40 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 41 |
+
"if 'vae' not in globals() or 'var' not in globals():\n",
|
| 42 |
+
" vae, var = build_vae_var(\n",
|
| 43 |
+
" V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters\n",
|
| 44 |
+
" device=device, patch_nums=patch_nums,\n",
|
| 45 |
+
" num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
|
| 46 |
+
" )\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"# load checkpoints\n",
|
| 49 |
+
"vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)\n",
|
| 50 |
+
"var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)\n",
|
| 51 |
+
"vae.eval(), var.eval()\n",
|
| 52 |
+
"for p in vae.parameters(): p.requires_grad_(False)\n",
|
| 53 |
+
"for p in var.parameters(): p.requires_grad_(False)\n",
|
| 54 |
+
"print(f'prepare finished.')"
|
| 55 |
+
],
|
| 56 |
+
"metadata": {
|
| 57 |
+
"collapsed": false,
|
| 58 |
+
"is_executing": true
|
| 59 |
+
}
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "code",
|
| 63 |
+
"execution_count": null,
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"############################# 2. Sample with classifier-free guidance\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# set args\n",
|
| 69 |
+
"seed = 0 #@param {type:\"number\"}\n",
|
| 70 |
+
"torch.manual_seed(seed)\n",
|
| 71 |
+
"num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
|
| 72 |
+
"cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
|
| 73 |
+
"class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:\"raw\"}\n",
|
| 74 |
+
"more_smooth = False # True for more smooth output\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# seed\n",
|
| 77 |
+
"torch.manual_seed(seed)\n",
|
| 78 |
+
"random.seed(seed)\n",
|
| 79 |
+
"np.random.seed(seed)\n",
|
| 80 |
+
"torch.backends.cudnn.deterministic = True\n",
|
| 81 |
+
"torch.backends.cudnn.benchmark = False\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"# run faster\n",
|
| 84 |
+
"tf32 = True\n",
|
| 85 |
+
"torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
|
| 86 |
+
"torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
|
| 87 |
+
"torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"# sample\n",
|
| 90 |
+
"B = len(class_labels)\n",
|
| 91 |
+
"label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
|
| 92 |
+
"with torch.inference_mode():\n",
|
| 93 |
+
" with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True): # using bfloat16 can be faster\n",
|
| 94 |
+
" recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
|
| 97 |
+
"chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
|
| 98 |
+
"chw = PImage.fromarray(chw.astype(np.uint8))\n",
|
| 99 |
+
"chw.show()\n"
|
| 100 |
+
],
|
| 101 |
+
"metadata": {
|
| 102 |
+
"collapsed": false
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
],
|
| 106 |
+
"metadata": {
|
| 107 |
+
"kernelspec": {
|
| 108 |
+
"display_name": "Python 3",
|
| 109 |
+
"language": "python",
|
| 110 |
+
"name": "python3"
|
| 111 |
+
},
|
| 112 |
+
"language_info": {
|
| 113 |
+
"codemirror_mode": {
|
| 114 |
+
"name": "ipython",
|
| 115 |
+
"version": 2
|
| 116 |
+
},
|
| 117 |
+
"file_extension": ".py",
|
| 118 |
+
"mimetype": "text/x-python",
|
| 119 |
+
"name": "python",
|
| 120 |
+
"nbconvert_exporter": "python",
|
| 121 |
+
"pygments_lexer": "ipython3",
"version": "3"
|
| 123 |
+
}
|
| 124 |
+
},
|
| 125 |
+
"nbformat": 4,
|
| 126 |
+
"nbformat_minor": 0
|
| 127 |
+
}
|
VAR/code/VAR/demo_zero_shot_edit.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
VAR/code/VAR/dist.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from typing import List
|
| 6 |
+
from typing import Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as tdist
|
| 10 |
+
import torch.multiprocessing as mp
|
| 11 |
+
|
| 12 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 13 |
+
__initialized = False
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def initialized():
    """Return True iff initialize() completed torch.distributed setup."""
    return __initialized
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
    """Set up torch.distributed from torchrun-style env variables.

    Falls back to single-process mode (no process group) when CUDA is
    unavailable or the 'RANK' env variable is not set.

    Args:
        fork: use the 'fork' multiprocessing start method instead of 'spawn'.
        backend: torch.distributed backend ('nccl' by default).
        gpu_id_if_not_distibuted: GPU index used when not launched distributed.
        timeout: process-group timeout in MINUTES (converted to seconds below).
    """
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        # not launched via torchrun: single-GPU, non-distributed mode
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus  # assumes homogeneous nodes (same GPU count per node)
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    # timeout is given in minutes; init_process_group wants a timedelta
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_rank():
    """Global rank of this process (0 when not distributed)."""
    return __rank


def get_local_rank():
    """Rank of this process within its node (0 when not distributed)."""
    return __local_rank


def get_world_size():
    """Total number of processes (1 when not distributed)."""
    return __world_size


def get_device():
    """The torch device selected for this process by initialize()/set_gpu_id()."""
    return __device
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def set_gpu_id(gpu_id: int):
    """Pin this process to CUDA device `gpu_id` and update the module device.

    A None argument is a no-op; anything other than an int/str index raises
    NotImplementedError.
    """
    if gpu_id is None: return
    global __device
    if not isinstance(gpu_id, (str, int)):
        raise NotImplementedError
    torch.cuda.set_device(int(gpu_id))
    __device = torch.empty(1).cuda().device
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def is_master():
    """True on the global rank-0 process."""
    return __rank == 0


def is_local_master():
    """True on the rank-0 process of each node."""
    return __local_rank == 0
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def new_group(ranks: List[int]):
    """Create a process sub-group over `ranks`; returns None when not distributed."""
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def barrier():
    """Synchronize all ranks; no-op when torch.distributed is not initialized."""
    if __initialized:
        tdist.barrier()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def allreduce(t: torch.Tensor, async_op=False):
    """All-reduce (sum) `t` across ranks in place; returns the async work handle or None.

    CPU tensors are staged through a CUDA copy (NCCL requires CUDA tensors)
    and copied back into `t`.
    NOTE(review): on the CPU path with async_op=True, the copy-back happens
    before the reduction is guaranteed complete — verify callers only use
    async_op with CUDA tensors.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather `t` from every rank (all ranks must pass the same shape).

    Returns a dim-0 concatenation when cat=True, else a list of per-rank
    tensors. Degenerates to [t] / t when not distributed. CPU tensors are
    moved to CUDA first (NCCL requirement).
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """All-gather tensors whose dim-0 sizes may differ across ranks.

    Strategy: gather every rank's shape, zero-... pad each local tensor along
    dim 0 up to the max batch size (padding is uninitialized memory, sliced
    off afterwards), all_gather the padded tensors, then trim each result
    back to its true length. Non-dim-0 shapes are assumed identical across
    ranks. Returns a dim-0 concatenation when cat=True, else a list.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        # exchange per-rank shapes first
        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            # pad with uninitialized rows; they are sliced away below
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        # note: `t` is intentionally rebound by this loop
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast `t` from `src_rank` into every rank's `t`, in place.

    No-op when not distributed. CPU tensors are staged through a CUDA copy
    (NCCL requirement) and copied back.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
|
| 158 |
+
|
| 159 |
+
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect a scalar from every rank; return a tensor (fmt=None) or formatted strings.

    Each rank writes its value into its own slot of a world-size vector and the
    vector is summed across ranks, so slot i ends up holding rank i's value.
    """
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def master_only(func):
    """Decorator: run `func` only on the global master rank, then barrier.

    Non-master ranks get None. A truthy keyword argument `force=True`
    (consumed here, never passed through) runs `func` on any rank.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_master()
        ret = func(*args, **kwargs) if should_run else None
        barrier()
        return ret
    return wrapper
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def local_master_only(func):
    """Decorator: run `func` only on each node's local master rank, then barrier.

    Other ranks get None. A truthy keyword argument `force=True` (consumed
    here, never passed through) runs `func` on any rank.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_local_master()
        ret = func(*args, **kwargs) if should_run else None
        barrier()
        return ret
    return wrapper
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def for_visualize(func):
    """Decorator: run `func` on the master rank only (no barrier); others get None."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) if is_master() else None
    return wrapper
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def finalize():
    """Destroy the process group if one was created.

    NOTE(review): __initialized is not reset here, so initialized() keeps
    returning True afterwards — confirm nothing relies on re-initializing.
    """
    if __initialized:
        tdist.destroy_process_group()
|
VAR/code/VAR/models/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .quant import VectorQuantizer2
|
| 5 |
+
from .var import VAR
|
| 6 |
+
from .vqvae import VQVAE
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    """Build the VQVAE tokenizer and the VAR transformer on `device`.

    VAR width/head count are derived from `depth` (64 channels per head);
    drop-path rate scales linearly with depth. Returns (vae, var) with the
    VAR weights initialized via init_weights.

    Side effect: globally disables `reset_parameters` on common nn layer
    classes (process-wide monkey-patch) to speed up construction.
    """
    heads = depth
    width = depth * 64        # 64 channels per attention head
    dpr = 0.1 * depth/24      # drop-path rate, normalized so depth=24 -> 0.1

    # disable built-in initialization for speed
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available,
    ).to(device)
    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)

    return vae_local, var_wo_ddp
|
VAR/code/VAR/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/basic_vae.cpython-310.pyc
ADDED
|
Binary file (6.83 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/basic_var.cpython-310.pyc
ADDED
|
Binary file (7.44 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/helpers.cpython-310.pyc
ADDED
|
Binary file (2.8 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/quant.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/var.cpython-310.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
VAR/code/VAR/models/__pycache__/vqvae.cpython-310.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
VAR/code/VAR/models/basic_vae.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# this file only provides the 2 modules used in VQVAE
|
| 7 |
+
__all__ = ['Encoder', 'Decoder',]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
|
| 12 |
+
"""
|
| 13 |
+
# swish
|
| 14 |
+
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x).mul(x)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def Normalize(in_channels, num_groups=32):
    """GroupNorm over `in_channels` with affine parameters and eps=1e-6."""
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Upsample2x(nn.Module):
    """2x nearest-neighbor upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Downsample2x(nn.Module):
    """2x downsampling via a stride-2 3x3 conv with asymmetric (right/bottom) zero-padding."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # pad one zero column on the right and one zero row at the bottom
        padded = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GN -> SiLU -> 3x3 conv) twice, plus a skip path.

    When in/out channel counts differ, a 1x1 conv projects the skip path;
    otherwise the input passes through unchanged.
    """

    def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x):
        skip = self.nin_shortcut(x)
        h = F.silu(self.norm1(x), inplace=True)
        h = self.conv1(h)
        h = F.silu(self.norm2(h), inplace=True)
        h = self.conv2(self.dropout(h))
        return skip + h
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AttnBlock(nn.Module):
    """Single-head full self-attention over spatial positions with a residual connection.

    Q/K/V come from one 1x1 conv; attention is computed densely over all
    H*W positions (O((HW)^2) memory), VQGAN/LDM style.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels  # attention feature dimension == channel count

        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
        self.w_ratio = int(in_channels) ** (-0.5)  # 1/sqrt(C) attention scaling
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape  # should be B,3C,H,W
        C = self.C
        # split the fused projection into the three attention inputs
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)

        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # B,HW,C
        k = k.view(B, C, H * W).contiguous() # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW    w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)  # softmax over key positions

        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)  # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()

        # residual connection around the attention output projection
        return x + self.proj_out(h)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def make_attn(in_channels, using_sa=True):
    """Return a self-attention block for `in_channels`, or an identity when disabled."""
    if using_sa:
        return AttnBlock(in_channels)
    return nn.Identity()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Encoder(nn.Module):
    """VQGAN-style conv encoder: downsampling resnet stages -> middle (resnet+attn) -> z projection.

    Output spatial size is input_size / 2**(len(ch_mult)-1). Output channels
    are z_channels, or 2*z_channels when double_z (for mean+logvar heads).
    Submodule construction order below fixes the state_dict key layout.
    """
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,
        z_channels, double_z=False, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch  # base channel width; per-stage widths are ch * ch_mult[i]
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)  # input multiplier of each stage
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                # self-attention only at the lowest-resolution (last) stage
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                # attn is non-empty only at the last stage (see __init__)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Decoder(nn.Module):
    """VQGAN-style conv decoder: z projection -> middle (resnet+attn) -> upsampling stages -> image.

    Mirrors Encoder; each stage has num_res_blocks+1 resnet blocks. Submodule
    construction order below fixes the state_dict key layout.
    """
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,  # in_channels: raw img channels
        z_channels, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)  # NOTE(review): unused below; kept for parity with Encoder
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):  # one extra block per stage vs. the encoder
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                # self-attention only at the lowest-resolution stage
                if i_level == self.num_resolutions-1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling (lowest resolution first)
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                # attn is non-empty only at the lowest-resolution stage (see __init__)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
|
VAR/code/VAR/models/basic_var.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from models.helpers import DropPath, drop_path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# this file only provides the 3 blocks used in VAR transformer
|
| 11 |
+
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# automatically import fused operators
|
| 15 |
+
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
|
| 16 |
+
try:
|
| 17 |
+
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
| 18 |
+
from flash_attn.ops.fused_dense import fused_mlp_func
|
| 19 |
+
except ImportError: pass
|
| 20 |
+
# automatically import faster attention implementations
|
| 21 |
+
try: from xformers.ops import memory_efficient_attention
|
| 22 |
+
except ImportError: pass
|
| 23 |
+
try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
|
| 24 |
+
except ImportError: pass
|
| 25 |
+
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
|
| 26 |
+
except ImportError:
|
| 27 |
+
def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
|
| 28 |
+
attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
|
| 29 |
+
if attn_mask is not None: attn.add_(attn_mask)
|
| 30 |
+
return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FFN(nn.Module):
    """Transformer feed-forward block: fc1 -> tanh-approx GELU -> fc2 (-> dropout).

    When flash-attn's fused MLP kernel was importable at module load time and
    fused_if_available is True, the fused path replaces the three-layer chain.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()

    def forward(self, x):
        if self.fused_mlp_func is None:
            hidden = self.act(self.fc1(x))
            return self.drop(self.fc2(hidden))
        return self.drop(self.fused_mlp_func(
            x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
            activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
            heuristic=0, process_group=None,
        ))

    def extra_repr(self) -> str:
        return f'fused_mlp_func={self.fused_mlp_func is not None}'
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with three interchangeable backends (flash-attn,
    xformers memory-efficient attention, or torch SDPA / the slow fallback),
    optional L2-normalized QK attention, and a KV cache for autoregressive
    inference.
    """
    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        """
        Args:
            block_idx: index of the transformer block this attention belongs to.
            embed_dim: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            attn_drop: dropout prob applied to attention weights (training only).
            proj_drop: dropout prob applied after the output projection.
            attn_l2_norm: if True, L2-normalize q and k and use a learnable
                per-head temperature instead of the fixed 1/sqrt(d) scale.
            flash_if_available: opt in to flash-attn / xformers kernels when the
                packages were importable at module load.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64
        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            # QK-norm attention: fixed scale 1, learnable log-temperature per head,
            # initialized to log(4) and clamped at log(100) in forward.
            self.scale = 1
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            # 0.25/sqrt(d) rather than the usual 1/sqrt(d) — intentional per upstream VAR.
            self.scale = 0.25 / math.sqrt(self.head_dim)
        
        # Fused QKV projection; k gets a frozen zero bias (buffer), q/v learnable biases.
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
        
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        self.attn_drop: float = attn_drop
        # Backend availability, decided once at construction from module-level probes.
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None
        
        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None
    
    # Toggling the cache also clears it, so each generation pass starts fresh.
    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        """x: (B, L, C); attn_bias: additive attention mask or None. Returns (B, L, C)."""
        B, L, C = x.shape
        
        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
        main_type = qkv.dtype
        # qkv: BL3Hc
        
        # flash-attn supports neither an additive bias nor fp32 inputs, so gate on both.
        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        # dim_cat records which axis is the sequence axis for the chosen layout,
        # so the KV cache below concatenates along the right dimension.
        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2          # q or k or v: BHLc
        
        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)
        
        if self.caching:
            # Accumulate keys/values across autoregressive steps; attention then
            # runs over the full prefix while q covers only the new tokens.
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
        
        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
        elif self.using_xform:
            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
        else:
            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
        
        return self.proj_drop(self.proj(oup))
        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL
        # attn = self.attn_drop(attn.softmax(dim=-1))
        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)  # BHLL @ BHLc = BHLc => BLHc => BLC
    
    def extra_repr(self) -> str:
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class AdaLNSelfAttn(nn.Module):
    """Transformer block with adaptive LayerNorm (AdaLN) conditioning:
    pre-norm self-attention and FFN sub-blocks, each modulated by
    condition-derived (gamma, scale, shift) triples, with DropPath residuals.
    """
    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        """
        Args:
            block_idx: index of this block in the transformer stack.
            last_drop_p: stored for reference; not used inside this block's forward.
            embed_dim: token width (C).
            cond_dim: condition embedding width (D).
            shared_aln: if True, use a learnable per-block offset `ada_gss` added to a
                condition tensor shared across blocks; otherwise a per-block SiLU+Linear.
            norm_layer: norm constructor, called as norm_layer(dim, elementwise_affine=False).
            num_heads / mlp_ratio / drop / attn_drop / drop_path / attn_l2_norm:
                usual transformer hyperparameters, forwarded to the sub-modules.
        """
        super(AdaLNSelfAttn, self).__init__()
        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        # NOTE(review): self.C is assigned twice (here and on the line above) — harmless, kept as-is.
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
        
        # Norm without learnable affine: scale/shift come from the condition instead.
        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.shared_aln = shared_aln
        if self.shared_aln:
            # Learnable (1, 1, 6, C) offset added to the shared condition tensor.
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            # Per-block AdaLN projection producing all 6 modulation vectors at once.
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
        
        self.fused_add_norm_fn = None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):   # C: embed_dim, D: cond_dim
        """x: (B, L, C); cond_BD: condition, (B, D) or pre-projected (B, 1, 6, C) when shared_aln."""
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        # Each sub-block: modulate(norm(x)) -> sub-module -> gate by gamma -> DropPath residual.
        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
        return x
    
    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}'
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class AdaLNBeforeHead(nn.Module):
    """Adaptive LayerNorm applied right before the classifier head.

    The condition vector is mapped (SiLU -> Linear) to a (scale, shift) pair
    that modulates the affine-free normalized activations.
    """

    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2 * C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        # Project condition to two C-wide vectors: (B, 1, 2, C) -> scale, shift.
        modulation = self.ada_lin(cond_BD).view(-1, 1, 2, self.C)
        scale, shift = modulation.unbind(2)
        normed = self.ln_wo_grad(x_BLC)
        # y = norm(x) * (1 + scale) + shift
        return normed.mul(scale.add(1)).add_(shift)
|
VAR/code/VAR/models/helpers.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
    """Sample token indices from (B, l, V) logits with optional top-k / top-p filtering.

    NOTE: mutates `logits_BlV` in place (hence the trailing underscore).
    A negative `num_samples` means sample abs(num_samples) draws WITHOUT replacement.
    Returns a LongTensor of shape (B, l, abs(num_samples)).
    """
    B, l, V = logits_BlV.shape
    if top_k > 0:
        # Threshold = smallest of the k best logits; everything below it is masked out.
        kth_best = logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(logits_BlV < kth_best, -torch.inf)
    if top_p > 0:
        # Ascending sort: drop the low-probability prefix whose cumulative mass <= 1 - top_p.
        asc_logits, asc_idx = logits_BlV.sort(dim=-1, descending=False)
        drop_sorted = asc_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        drop_sorted[..., -1:] = False  # never drop the single most-likely token
        logits_BlV.masked_fill_(drop_sorted.scatter(asc_idx.ndim - 1, asc_idx, drop_sorted), -torch.inf)
    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    probs_2d = logits_BlV.softmax(dim=-1).view(-1, V)
    draws = torch.multinomial(probs_2d, num_samples=num_samples, replacement=replacement, generator=rng)
    return draws.view(B, l, num_samples)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
    """Gumbel-softmax that accepts an explicit torch.Generator for reproducibility.

    Falls back to F.gumbel_softmax when no generator is supplied. With
    hard=True, returns a one-hot sample with a straight-through gradient
    to the soft distribution.
    """
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    # Gumbel(0,1) noise via -log(Exp(1)), drawn from the caller's generator.
    noise = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log()
    y_soft = ((logits + noise) / tau).softmax(dim)

    if not hard:
        return y_soft
    # Straight-through: forward = one-hot argmax, backward = soft gradient.
    top1 = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, top1, 1.0)
    return y_hard - y_soft.detach() + y_soft
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):  # taken from timm
    """Stochastic depth: zero out whole samples (dim 0) with prob drop_prob.

    Identity when drop_prob == 0 or not training. Surviving samples are
    rescaled by 1/keep_prob when scale_by_keep is set, so expectation matches.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        keep_mask.div_(keep_prob)
    return x * keep_mask


class DropPath(nn.Module):  # taken from timm
    """Module wrapper around drop_path; uses self.training as the gate."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'(drop_prob=...)'
|
VAR/code/VAR/models/quant.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import distributed as tdist, nn as nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
import dist
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# this file only provides the VectorQuantizer2 used in VQVAE
|
| 12 |
+
__all__ = ['VectorQuantizer2',]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class VectorQuantizer2(nn.Module):
    """Multi-scale residual vector quantizer used by the VQVAE.

    A feature map is quantized scale by scale (v_patch_nums, small to large):
    at each scale the residual is downsampled, matched to the nearest codebook
    entry, decoded, refined by a Phi conv (quant_resi), and subtracted from the
    residual. Codebook usage is tracked in an EMA buffer for logging.
    """
    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
        default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4,  # share_quant_resi: args.qsr
    ):
        """
        Args:
            vocab_size: number of codebook entries.
            Cvae: codebook / latent channel dimension.
            using_znorm: if True, match codes by cosine similarity instead of L2.
            beta: commitment-loss weight (VQGAN used 1.0; SD-style 0.25).
            default_qresi_counts: overrides len(v_patch_nums) as the number of Phi convs when nonzero.
            v_patch_nums: sequence of per-scale token-map side lengths, small to large.
            quant_resi: residual-refinement mixing ratio; ~0 means identity (no Phi conv).
            share_quant_resi: 0 = one Phi per scale, 1 = one shared Phi, else that many partially-shared Phis.
        """
        super().__init__()
        self.vocab_size: int = vocab_size
        self.Cvae: int = Cvae
        self.using_znorm: bool = using_znorm
        self.v_patch_nums: Tuple[int] = v_patch_nums
        
        self.quant_resi_ratio = quant_resi
        if share_quant_resi == 0:   # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
        elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
        else:                       # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))
        
        # EMA of per-scale codebook hit counts, (num_scales, vocab_size); used only for usage stats.
        self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
        self.record_hit = 0
        
        self.beta: float = beta
        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
        
        # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
        self.prog_si = -1   # progressive training: not supported yet, prog_si always -1
    
    def eini(self, eini):
        """Codebook init: trunc-normal with std=eini if positive, uniform in +-|eini|/vocab_size if negative, no-op if 0."""
        if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
        elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
    
    def extra_repr(self) -> str:
        return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
    
    # ===================== `forward` is only used in VAE training =====================
    def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
        """Quantize f_BChw across all scales.

        Returns:
            f_hat: straight-through reconstruction (gradients flow to f_BChw).
            usages: per-scale codebook usage percentages, or None unless ret_usages.
            mean_vq_loss: commitment + codebook MSE loss averaged over scales.
        """
        dtype = f_BChw.dtype
        if dtype != torch.float32: f_BChw = f_BChw.float()
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        
        f_rest = f_no_grad.clone()       # residual still to be explained by remaining scales
        f_hat = torch.zeros_like(f_rest) # running reconstruction
        
        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
            SN = len(self.v_patch_nums)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                # find the nearest embedding
                if self.using_znorm:
                    # cosine-similarity matching (both sides L2-normalized)
                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    rest_NC = F.normalize(rest_NC, dim=-1)
                    idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
                else:
                    # squared-L2 matching via ||x||^2 + ||e||^2 - 2 x.e (addmm_ computes the cross term)
                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                    d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                    idx_N = torch.argmin(d_no_grad, dim=1)
                
                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                if self.training:
                    # async all-reduce of hit counts; awaited below after the decode work overlaps it
                    if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
                
                # calc loss
                idx_Bhw = idx_N.view(B, pn, pn)
                h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
                f_hat = f_hat + h_BChw
                f_rest -= h_BChw
                
                if self.training and dist.initialized():
                    handler.wait()
                    # EMA warm-up schedule: copy first, 0.9/0.1 for the first 100 updates, then 0.99/0.01
                    if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)
                    elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
                    else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
                    self.record_hit += 1
                vocab_hit_V.add_(hit_V)
                # commitment term (codes fixed) weighted by beta + codebook term (encoder fixed)
                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
            
            mean_vq_loss *= 1. / SN
            # straight-through estimator: forward value is f_hat, gradient flows to f_BChw
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
        
        # usage threshold scaled by world size and tokens-per-codeword
        # NOTE(review): tdist.get_world_size() here presumably requires an initialized process group — verify single-GPU path.
        margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
        # margin = pn*pn / 100
        if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)]
        else: usages = None
        return f_hat, usages, mean_vq_loss
    # ===================== `forward` is only used in VAE training =====================
    
    def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Accumulate per-scale decoded maps into reconstruction(s).

        Returns the final f_hat if last_one else a list of partial f_hats, one per scale.
        """
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')
                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purpose
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')
                h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat)
        
        return ls_f_hat_BChw
    
    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        """Encode f_BChw into per-scale token indices (to_fhat=False) or partial reconstructions (to_fhat=True)."""
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)
        
        f_hat_or_idx_Bl: List[torch.Tensor] = []
        
        patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)]    # from small to large
        assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
        
        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws): # from small to large
            if 0 <= self.prog_si < si: break    # progressive training: not supported yet, prog_si always -1
            # find the nearest embedding
            z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            if self.using_znorm:
                z_NC = F.normalize(z_NC, dim=-1)
                idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
            else:
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                idx_N = torch.argmin(d_no_grad, dim=1)
            
            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
            h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))
        
        return f_hat_or_idx_Bl
    
    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
        """Build the VAR transformer's teacher-forcing input: for each scale si,
        the accumulated reconstruction of scales <= si downsampled to scale si+1,
        concatenated along the token dimension as a (B, L, Cvae) float tensor.
        """
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        
        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN-1):
            if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break   # progressive training: not supported yet, prog_si always -1
            h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
            f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
            pn_next = self.v_patch_nums[si+1]
            next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
        return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32
    
    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
        """Add scale si's decoded map into f_hat (in place) and return
        (f_hat, next step's input): the next input is f_hat downsampled to the
        next scale, or f_hat itself at the last scale.
        """
        HW = self.v_patch_nums[-1]
        if si != SN-1:
            h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic'))     # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')
        else:
            h = self.quant_resi[si/(SN-1)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class Phi(nn.Conv2d):
    """3x3 residual refinement conv: out = (1 - r) * h + r * conv(h), with r = |quant_resi|."""

    def __init__(self, embed_dim, quant_resi):
        ks = 3
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        conv_out = super().forward(h_BChw)
        # Blend the identity branch with the conv branch by resi_ratio.
        return h_BChw.mul(1 - self.resi_ratio) + conv_out.mul_(self.resi_ratio)


class PhiShared(nn.Module):
    """A single Phi shared by every scale; indexing always yields the same module."""

    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        return self.qresi


class PhiPartiallyShared(nn.Module):
    """K Phis shared across scales; a position in [0, 1] picks the nearest tick's Phi."""

    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        # Tick spacing differs for K == 4 (1/3 margins) vs other K (1/2 margins), per upstream.
        frac = 1 / 3 if K == 4 else 1 / 2
        self.ticks = np.linspace(frac / K, 1 - frac / K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        # Select the Phi whose tick is closest to the requested position.
        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'


class PhiNonShared(nn.ModuleList):
    """One Phi per scale, stored directly in a ModuleList; indexed by position in [0, 1]."""

    def __init__(self, qresi: List):
        super().__init__(qresi)
        K = len(qresi)
        frac = 1 / 3 if K == 4 else 1 / 2
        self.ticks = np.linspace(frac / K, 1 - frac / K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
|
VAR/code/VAR/models/var.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 8 |
+
|
| 9 |
+
import dist
|
| 10 |
+
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
|
| 11 |
+
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
|
| 12 |
+
from models.vqvae import VQVAE, VectorQuantizer2
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SharedAdaLin(nn.Linear):
    """Linear projection for shared AdaLN conditioning.

    Maps a (B, D) condition to six C-wide modulation vectors, reshaped to
    (B, 1, 6, C) so every transformer block can add its own learnable offset.
    """

    def forward(self, cond_BD):
        out = super().forward(cond_BD)
        chunk_width = self.weight.shape[0] // 6  # C: per-chunk width of the 6 modulation vectors
        return out.view(-1, 1, 6, chunk_width)   # B16C
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class VAR(nn.Module):
    """Visual AutoRegressive transformer over a frozen multi-scale VQVAE.

    Predicts token maps scale-by-scale (``patch_nums`` gives the side length of
    each scale); training uses a block-triangular attention mask, inference uses
    kv-caching with classifier-free guidance.

    NOTE(review): this variant prepends learnable "register" tokens in the
    training ``forward`` only — ``autoregressive_infer_cfg`` never inserts them,
    so there is a train/inference input mismatch. Confirm this is intentional.
    """
    def __init__(
        self, vae_local: VQVAE,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads

        self.cond_drop_rate = cond_drop_rate
        self.prog_si = -1   # progressive training: current stage index, -1 = disabled

        self.patch_nums: Tuple[int] = patch_nums
        # total sequence length over all scales: sum of pn^2
        self.L = sum(pn ** 2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        # (begin, end) token index of each scale within the flat length-L sequence
        self.begin_ends = []
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur+pn ** 2))
            cur += pn ** 2

        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device=dist.get_device())

        # 1. input (word) embedding: maps VQVAE latent channels to transformer width
        quant: VectorQuantizer2 = vae_local.quantize
        # tuples so the VAE is not registered as a submodule (kept frozen, excluded from state_dict)
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding (index num_classes is the "unconditional" class for CFG / label dropout)
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding, one segment per scale, concatenated to length L
        pos_1LC = []
        for i, pn in enumerate(self.patch_nums):
            pe = torch.empty(1, pn*pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)     # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule (linearly increasing)
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor]  ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f'    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f'    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        # it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)    # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer('lvl_1L', lvl_1L)
        # token i may attend to token j iff level(i) >= level(j): block-triangular over scales
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())

        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

        # 7. register tokens
        # NOTE(review): intended to be externally configurable, but __init__ takes no
        # such argument and the attribute does not exist yet at this point, so
        # getattr(...) always returns the default 4 here — confirm the override path.
        self.num_register_tokens = getattr(self, "num_register_tokens", 4)
        self.register_tokens = nn.Parameter(torch.randn(self.num_register_tokens, self.C) * 0.02)

    def _expand_attn_bias(self, attn_bias: torch.Tensor, add_len: int) -> torch.Tensor:
        """Grow an attention-bias matrix to make room for register tokens.

        :param attn_bias: (B_or_1, H, L, L), or (1, 1, L, L)
        :param add_len: number of register positions to add to the sequence length
        :return: (B_or_1, H, L+add_len, L+add_len)

        New rows/columns are zero-padded (no extra attention suppression), so
        registers attend everywhere and everything attends to registers; the
        original LxL region keeps its causal/masking structure unchanged.
        """
        if add_len <= 0:
            return attn_bias

        *prefix, L, _ = attn_bias.shape   # (..., L, L)
        newL = L + add_len
        expanded = attn_bias.new_zeros(*prefix, newL, newL)
        expanded[..., :L, :L] = attn_bias
        return expanded

    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
        """Project hidden states to vocabulary logits via the AdaLN head.

        Accepts either a plain hidden tensor, or an (h, residual) pair produced
        by a fused add+norm backbone, which is combined here first.
        """
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual   # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:                               # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    @torch.no_grad()
    def autoregressive_infer_cfg(
        self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
        g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
        more_smooth=False,
    ) -> torch.Tensor:   # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        only used for inference, on autoregressive mode
        :param B: batch size
        :param label_B: imagenet label; if None, randomly sampled
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: if returns_vemb: list of embedding h_BChw := vae_embed(idx_Bl), else: list of idx_Bl
        """
        if g_seed is None: rng = None
        else: self.rng.manual_seed(g_seed); rng = self.rng

        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            # negative int means "unconditional" (the extra class index)
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)

        # batch is doubled: first B conditional, last B unconditional (for CFG)
        sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))

        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])

        for b in self.blocks: b.attn.kv_caching(True)
        for si, pn in enumerate(self.patch_nums):   # si: i-th segment
            ratio = si / self.num_stages_minus_1
            # last_L = cur_L
            cur_L += pn*pn
            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map
            AdaLNSelfAttn.forward
            for b in self.blocks:
                # attn_bias=None: kv cache already restricts attention to the past
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
            logits_BlV = self.get_logits(x, cond_BD)

            # CFG mix, guidance strength ramps up linearly with the scale index
            t = cfg * ratio
            logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
            if not more_smooth: # this is the default case
                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)   # B, l, Cvae
            else:   # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)   # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
            if si != self.num_stages_minus_1:   # prepare for next stage
                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
                next_token_map = next_token_map.repeat(2, 1, 1)   # double the batch sizes due to CFG

        for b in self.blocks: b.attn.kv_caching(False)
        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)   # de-normalize, from [-1, 1] to [0, 1]

    # (a commented-out pre-register-token copy of `forward` previously lived here;
    # removed — it duplicated the method below minus the register-token insertion)

    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor:  # returns logits_BLV
        """
        :param label_B: label_B
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :return: logits BLV, V is vocab_size
        """
        # bg is unused below; kept for parity with begin_ends unpacking
        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
        B = x_BLCv_wo_first_l.shape[0]

        with torch.cuda.amp.autocast(enabled=False):
            # label dropout: replace a random cond_drop_rate fraction with the unconditional class
            label_B = torch.where(
                torch.rand(B, device=label_B.device) < self.cond_drop_rate,
                self.num_classes,
                label_B
            )
            sos = cond_BD = self.class_emb(label_B)  # (B, D)
            sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

            if self.prog_si == 0:
                x_BLC = sos  # (B, first_l, C)
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)  # (B, ed, C)

            # lvl/pos embeddings are added to the original tokens only (registers get none)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

        # original attention bias (for the ed-length sequence)
        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]

        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        # ====== insert register tokens ======
        K = self.num_register_tokens
        if K > 0:
            # [K, C] -> [B, K, C]
            r_BKC = self.register_tokens.to(device=x_BLC.device, dtype=x_BLC.dtype).unsqueeze(0).expand(B, -1, -1)

            # order: [REG ... | TOKENS ...] — registers are intermediate-only and dropped before output
            x_with_reg = torch.cat([r_BKC, x_BLC], dim=1)  # (B, K+ed, C)

            # widen the attention bias to K+ed (new rows/cols zero: registers are fully visible)
            attn_bias_expanded = self._expand_attn_bias(attn_bias, add_len=K)
        else:
            x_with_reg = x_BLC
            attn_bias_expanded = attn_bias
        # ===========================================

        # run the backbone (registers participate in attention at every block)
        for i, b in enumerate(self.blocks):
            x_with_reg = b(x=x_with_reg, cond_BD=cond_BD_or_gss, attn_bias=attn_bias_expanded)

        # drop the registers; keep only the original tokens so get_logits sees the expected length
        if K > 0:
            x_BLC = x_with_reg[:, K:, :]  # (B, ed, C)
        else:
            x_BLC = x_with_reg

        # compute logits (unchanged downstream logic)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        # keep a zero-valued reference to word_embed so it is not pruned away by compilation/export
        if self.prog_si == 0:
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s

        return x_BLC    # logits BLV, V is vocab_size

    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        """Re-initialize all weights in place.

        :param init_adaln: scale applied to the AdaLN shift/scale rows
        :param init_adaln_gamma: scale applied to the AdaLN gamma rows
        :param init_head: multiplier for the classifier head (>= 0 to apply)
        :param init_std: trunc-normal std for linears/embeddings; negative => auto (1/C/3)^0.5
        :param conv_std_or_gain: >0: trunc-normal std for convs; <=0: xavier gain = -value
        """
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5   # init_std < 0: automated

        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv, only VQVAE has conv
            elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            # depth-aware residual-branch scaling (GPT-2 style)
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        """One-line module summary shown by repr()."""
        return f'drop_path_rate={self.drop_path_rate:g}'
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class VARHF(VAR, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper around :class:`VAR`.

    Builds the frozen VQVAE from keyword arguments (so the whole model is
    constructible from a serialized config) and forwards every other argument
    to ``VAR.__init__`` unchanged.
    """
    # repo_url="https://github.com/FoundationVision/VAR",
    # tags=["image-generation"]):
    def __init__(
        self,
        vae_kwargs,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        # construct the tokenizer VAE first; VAR reads Cvae/vocab_size from it
        vae_local = VQVAE(**vae_kwargs)
        super().__init__(
            vae_local=vae_local,
            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available, fused_if_available=fused_if_available,
        )
|
VAR/code/VAR/models/vqvae.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
|
| 4 |
+
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
|
| 5 |
+
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
|
| 6 |
+
"""
|
| 7 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from .basic_vae import Decoder, Encoder
|
| 13 |
+
from .quant import VectorQuantizer2
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class VQVAE(nn.Module):
    """Multi-scale VQ-VAE: conv encoder/decoder around a residual vector quantizer.

    In ``test_mode`` (the default) the module is set to eval and all parameters
    are frozen, matching its role as a fixed tokenizer for VAR training.
    """
    def __init__(
        self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
        beta=0.25,              # commitment loss weight
        using_znorm=False,      # whether to normalize when computing the nearest neighbors
        quant_conv_ks=3,        # quant conv kernel size
        quant_resi=0.5,         # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
        share_quant_resi=4,     # use 4 \phi layers for K scales: partially-shared \phi
        default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
        test_mode=True,
    ):
        super().__init__()
        self.test_mode = test_mode
        self.V, self.Cvae = vocab_size, z_channels
        # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
        ddconfig = dict(
            dropout=dropout, ch=ch, z_channels=z_channels,
            in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,   # from vq-f16/config.yaml above
            using_sa=True, using_mid_sa=True,                           # from vq-f16/config.yaml above
            # resamp_with_conv=True,   # always True, removed.
        )
        ddconfig.pop('double_z', None)  # only KL-VAE should use double_z=True
        self.encoder = Encoder(double_z=False, **ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.vocab_size = vocab_size
        # spatial downsample factor of the encoder: 2^(#ch_mult stages - 1) = 16 here
        self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
        self.quantize: VectorQuantizer2 = VectorQuantizer2(
            vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
            default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
        )
        self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
        self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)

        if self.test_mode:
            # freeze everything: the VQVAE is used as a fixed tokenizer
            self.eval()
            [p.requires_grad_(False) for p in self.parameters()]

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, inp, ret_usages=False):   # -> rec_B3HW, idx_N, loss
        """Encode, quantize, decode. Returns (reconstruction, codebook usages, vq loss)."""
        VectorQuantizer2.forward
        f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def fhat_to_img(self, f_hat: torch.Tensor):
        """Decode a quantized feature map to an image clamped to [-1, 1]."""
        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)

    def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]:    # return List[Bl]
        """Encode an image to per-scale token-index tensors (one (B, l) tensor per scale)."""
        f = self.quant_conv(self.encoder(inp_img_no_grad))
        return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)

    def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Decode per-scale token indices back to image(s).

        :param ms_idx_Bl: one (B, l) index tensor per scale; l must be a perfect square
        :param same_shape: upsample all scales to the max scale before decoding
        :param last_one: if True, return only the final (full-resolution) image
        """
        B = ms_idx_Bl[0].shape[0]
        ms_h_BChw = []
        for idx_Bl in ms_idx_Bl:
            l = idx_Bl.shape[1]
            pn = round(l ** 0.5)  # side length of this square scale
            ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
        return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)

    def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Decode per-scale embeddings; returns one image if last_one, else a list (all in [-1, 1])."""
        if last_one:
            return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]

    def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
        """Round-trip an image through encode -> quantize -> decode at each scale."""
        f = self.quant_conv(self.encoder(x))
        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
        if last_one:
            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]

    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
        """Load weights, tolerating a vocab-size mismatch in the EMA hit-count buffer.

        The buffer only tracks codebook usage statistics, so replacing a
        mismatched checkpoint entry with the current buffer is safe.
        """
        if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
            state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
|
VAR/code/VAR/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Pillow
|
| 3 |
+
huggingface_hub
|
| 4 |
+
numpy
|
| 5 |
+
pytz
|
| 6 |
+
transformers
|
| 7 |
+
typed-argument-parser
|
| 8 |
+
tensorboard
|
VAR/code/VAR/train.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
import warnings
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
import dist
|
| 13 |
+
from utils import arg_util, misc
|
| 14 |
+
from utils.data import build_dataset
|
| 15 |
+
from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
|
| 16 |
+
from utils.misc import auto_resume
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_everything(args: arg_util.Args):
    """Build logger, dataloaders, models, optimizer and trainer for VAR pre-training.

    Returns:
        (tb_lg, trainer, start_ep, start_it, iters_train, ld_train, ld_val)
        where ld_train is an *iterator* (already wrapped with iter()) over an
        infinite distributed batch sampler, and ld_val is a plain DataLoader.

    Side effects: creates the tensorboard log dir (master only), may download
    the VQVAE checkpoint via wget (local master only), and in local-debug mode
    runs two smoke-test train steps and exits the process.
    """
    # resume: scan for the newest 'ar-ckpt*.pth' and restore epoch/iter/trainer/args state
    auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
    # create tensorboard logger (a real one only on the master rank; others get a no-op proxy)
    tb_lg: misc.TensorboardLogger
    with_tb_lg = dist.is_master()
    if with_tb_lg:
        os.makedirs(args.tb_log_dir_path, exist_ok=True)
        # noinspection PyTypeChecker
        tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
        tb_lg.flush()
    else:
        # noinspection PyTypeChecker
        tb_lg = misc.DistLogger(None, verbose=False)
    dist.barrier()  # make sure the log dir exists before any rank proceeds

    # log args
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')

    # build data
    if not args.local_debug:
        print(f'[build PT data] ...\n')
        num_classes, dataset_train, dataset_val = build_dataset(
            args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
        )
        types = str((type(dataset_train).__name__, type(dataset_val).__name__))

        # validation loader: larger batch (no backward pass), deterministic per-rank sharding
        ld_val = DataLoader(
            dataset_val, num_workers=0, pin_memory=True,
            batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
            shuffle=False, drop_last=False,
        )
        del dataset_val

        # training loader: infinite batch sampler handles sharding/shuffling/resume (start_ep/start_it)
        ld_train = DataLoader(
            dataset=dataset_train, num_workers=args.workers, pin_memory=True,
            generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,
            batch_sampler=DistInfiniteBatchSampler(
                dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,
                shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,
            ),
        )
        del dataset_train

        [print(line) for line in auto_resume_info]
        print(f'[dataloader multi processing] ...', end='', flush=True)
        stt = time.time()
        iters_train = len(ld_train)
        ld_train = iter(ld_train)  # iterating once here spawns the worker processes up front
        # noinspection PyArgumentList
        print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)
        print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')

    else:
        # local-debug mode: no real data; random tensors are fed below instead
        num_classes = 1000
        ld_val = ld_train = None
        iters_train = 10

    # build models (heavy imports deferred until after DataLoader worker creation)
    from torch.nn.parallel import DistributedDataParallel as DDP
    from models import VAR, VQVAE, build_vae_var
    from trainer import VARTrainer
    from utils.amp_sc import AmpOptimizer
    from utils.lr_control import filter_params

    vae_local, var_wo_ddp = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
        device=dist.get_device(), patch_nums=args.patch_nums,
        num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
        flash_if_available=args.fuse, fused_if_available=args.fuse,
        init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
    )

    # fetch the pretrained (frozen) VQVAE weights; only one rank per node downloads
    vae_ckpt = 'vae_ch160v4096z32.pth'
    if dist.is_local_master():
        if not os.path.exists(vae_ckpt):
            os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
    dist.barrier()  # all ranks wait for the download to finish
    vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)

    vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
    var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
    # fall back to NullDDP when torch.distributed is not initialized (single-process run)
    var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)

    print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
    count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'  # param count in millions
    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')

    # build optimizer: parameters matching nowd_keys get weight decay disabled
    names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
        'cls_token', 'start_token', 'task_token', 'cfg_uncond',
        'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
        'gamma', 'beta',
        'ada_gss', 'moe_bias',
        'scale_mul',
    })
    # NOTE: both 'adam' and 'adamw' map to AdamW here (intentional in upstream VAR)
    opt_clz = {
        'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
    }[args.opt.lower().strip()]
    opt_kw = dict(lr=args.tlr, weight_decay=0)  # actual wd is set per-step by lr_wd_annealing
    print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')

    var_optim = AmpOptimizer(
        mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
        grad_clip=args.tclip, n_gradient_accumulation=args.ac
    )
    del names, paras, para_groups

    # build trainer
    trainer = VARTrainer(
        device=args.device, patch_nums=args.patch_nums, resos=args.resos,
        vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
        var_opt=var_optim, label_smooth=args.ls,
    )
    if trainer_state is not None and len(trainer_state):
        trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again
    del vae_local, var_wo_ddp, var, var_optim

    if args.local_debug:
        # smoke test: two train steps on random data, then a state_dict round-trip, then exit
        rng = torch.Generator('cpu')
        rng.manual_seed(0)
        B = 4
        inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
        label = torch.ones(B, dtype=torch.long)

        me = misc.MetricLogger(delimiter=' ')
        trainer.train_step(
            it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
        )
        trainer.load_state_dict(trainer.state_dict())
        trainer.train_step(
            it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
        )
        print({k: meter.global_avg for k, meter in me.meters.items()})

        args.dump_log(); tb_lg.flush(); tb_lg.close()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
            sys.stdout.close(), sys.stderr.close()
        exit(0)

    dist.barrier()
    return (
        tb_lg, trainer, start_ep, start_it,
        iters_train, ld_train, ld_val
    )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def main_training():
    """Entry point: initialize distributed training, build everything, and run the epoch loop.

    Tracks best train/val losses and accuracies across epochs, validates and
    checkpoints every 10 epochs (and on the final epoch), and tears down
    loggers/dataloaders at the end.
    """

    # # Run this in an environment where Args can be imported (debug snippet kept from upstream)
    # import copy, types
    # from utils.arg_util import Args

    # bad = []
    # # candidate fields: annotations + non-private class attributes
    # names = set(getattr(Args, '__annotations__', {}).keys()) | {
    #     k for k in vars(Args).keys() if not k.startswith('_')
    # }
    # for name in sorted(names):
    #     if hasattr(Args, name):
    #         val = getattr(Args, name)
    #         try:
    #             copy.deepcopy(val)
    #         except Exception as e:
    #             bad.append((name, type(val).__name__, str(e)))

    # print("Non-deepcopy-able defaults:")
    # for n,t,e in bad:
    #     print(f" {n}: {t} -> {e}")

    args: arg_util.Args = arg_util.init_dist_and_get_args()
    if args.local_debug:
        torch.autograd.set_detect_anomaly(True)

    (
        tb_lg, trainer,
        start_ep, start_it,
        iters_train, ld_train, ld_val
    ) = build_everything(args)

    # train
    start_time = time.time()
    # running bests over training metrics (lower loss / higher accuracy is better)
    best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
    best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1

    L_mean, L_tail = -1, -1
    for ep in range(start_ep, args.ep):
        # keep per-epoch shuffling deterministic across ranks when the sampler supports it
        if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
            ld_train.sampler.set_epoch(ep)
            if ep < 3:
                # noinspection PyArgumentList
                print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
        tb_lg.set_step(ep * iters_train)

        # resume mid-epoch only for the first (resumed) epoch
        stats, (sec, remain_time, finish_time) = train_one_ep(
            ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
        )

        L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
        best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
        # L_tail == -1 means "tail metrics unavailable" (progressive-training stage); skip best-tracking
        if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
        args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
        args.cur_ep = f'{ep+1}/{args.ep}'
        args.remain_time, args.finish_time = remain_time, finish_time

        AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
        # validate + checkpoint every 10 epochs and on the final epoch
        is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
        if is_val_and_also_saving:
            val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
            best_updated = best_val_loss_tail > val_loss_tail  # "best" is judged on the tail (finest-scale) val loss
            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
            AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
            args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
            print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')

            # one writer per node; 'last' is always overwritten, 'best' copied when val improved
            if dist.is_local_master():
                local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
                local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
                print(f'[saving ckpt] ...', end='', flush=True)
                torch.save({
                    'epoch': ep+1,
                    'iter': 0,
                    'trainer': trainer.state_dict(),
                    'args': args.state_dict(),
                }, local_out_ckpt)
                if best_updated:
                    shutil.copy(local_out_ckpt, local_out_ckpt_best)
                print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
            dist.barrier()  # non-writers wait for the checkpoint to land

        print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
        tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
        tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
        args.dump_log(); tb_lg.flush()

    total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
    print('\n\n')
    print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
    print('\n\n')

    # release dataloader workers and cached CUDA memory before final teardown
    del stats
    del iters_train, ld_train
    time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)

    args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
    print(f'final args:\n\n{str(args)}')
    args.dump_log(); tb_lg.flush(); tb_lg.close()
    dist.barrier()
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
    """Run one training epoch.

    Args:
        ep: current epoch index.
        is_first_ep: True only for the first (possibly resumed) epoch.
        start_it: iteration to resume from within this epoch (0 otherwise).
        ld_or_itrt: the training dataloader or an iterator over it.
        iters_train: number of iterations per epoch.
        trainer: the VARTrainer driving each step.

    Returns:
        (metrics_dict, (sec, remain_time, finish_time)) — global-average
        metrics plus an ETA estimate from the metric logger.
    """
    # import heavy packages after Dataloader object creation
    from trainer import VARTrainer
    from utils.lr_control import lr_wd_annealing
    trainer: VARTrainer

    step_cnt = 0
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
    me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    [me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]
    [me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]
    header = f'[Ep]: [{ep:4d}/{args.ep}]'

    if is_first_ep:
        # silence resume-related warnings until we actually reach start_it
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)
    g_it, max_it = ep * iters_train, args.ep * iters_train

    for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
        g_it = ep * iters_train + it
        if it < start_it: continue  # fast-forward to the resume point
        if is_first_ep and it == start_it: warnings.resetwarnings()

        inp = inp.to(args.device, non_blocking=True)
        label = label.to(args.device, non_blocking=True)

        args.cur_it = f'{it+1}/{iters_train}'

        # anneal LR and weight decay for this step (returns per-group min/max)
        wp_it = args.wp * iters_train
        min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
        args.cur_lr, args.cur_wd = max_tlr, max_twd

        if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this
            if g_it <= wp_it: prog_si = args.pg0
            elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1
            else:
                delta = len(args.patch_nums) - 1 - args.pg0
                progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1
                prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1
        else:
            prog_si = -1

        # only step the optimizer every args.ac iterations (gradient accumulation)
        stepping = (g_it + 1) % args.ac == 0
        step_cnt += int(stepping)

        grad_norm, scale_log2 = trainer.train_step(
            it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
        )

        me.update(tlr=max_tlr)
        tb_lg.set_step(step=g_it)
        tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
        tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
        tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
        tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
        tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)

        if args.tclip > 0:
            tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
            tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class NullDDP(torch.nn.Module):
    """Single-process stand-in for DistributedDataParallel.

    Used when torch.distributed is not initialized: it simply wraps the given
    module and forwards every call to it, while exposing the
    ``require_backward_grad_sync`` attribute that the training loop toggles
    (it has no effect here). Extra positional/keyword constructor arguments
    (e.g. ``device_ids``) are accepted and ignored.
    """

    def __init__(self, module, *args, **kwargs):
        super().__init__()
        # keep the wrapped module accessible under .module, mirroring DDP's API
        self.module = module
        self.require_backward_grad_sync = False

    def forward(self, *args, **kwargs):
        # transparent pass-through to the wrapped module
        return self.module(*args, **kwargs)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
if __name__ == '__main__':
    # Always tear down the process group and close the synced stdout/stderr
    # wrappers, even if training raises.
    try: main_training()
    finally:
        dist.finalize()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
            sys.stdout.close(), sys.stderr.close()
|
VAR/code/VAR/trainer.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
import dist
|
| 10 |
+
from models import VAR, VQVAE, VectorQuantizer2
|
| 11 |
+
from utils.amp_sc import AmpOptimizer
|
| 12 |
+
from utils.misc import MetricLogger, TensorboardLogger
|
| 13 |
+
|
| 14 |
+
# Short tensor-type aliases used in annotations throughout this module.
# Note these are purely documentary: FTen is the same class as Ten.
Ten = torch.Tensor
FTen = torch.Tensor       # float-valued tensor
ITen = torch.LongTensor   # integer (token-index) tensor
BTen = torch.BoolTensor   # boolean (mask) tensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VARTrainer(object):
    """Training/evaluation driver for the VAR transformer.

    Holds the frozen VQVAE tokenizer, the VAR model (both the raw module and
    its DDP wrapper) and an AMP-aware optimizer, and exposes:
      * ``train_step``: one forward/backward/(optional) optimizer step,
      * ``eval_ep``: a full validation pass with cross-rank reduction,
      * ``state_dict`` / ``load_state_dict``: checkpointing of model, VAE
        and optimizer state plus trainer config.
    """

    def __init__(
        self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
        vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
        var_opt: AmpOptimizer, label_smooth: float,
    ):
        """
        Args:
            device: device the loss weights and model RNG live on.
            patch_nums: token-map side lengths per scale (coarsest to finest).
            resos: pixel resolutions corresponding to each scale.
            vae_local: frozen VQVAE used to tokenize images.
            var_wo_ddp: the raw VAR module (possibly after torch.compile).
            var: DDP (or NullDDP) wrapper around var_wo_ddp used for training.
            var_opt: AMP optimizer wrapper.
            label_smooth: label smoothing factor for the training CE loss.
        """
        super(VARTrainer, self).__init__()

        self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
        self.quantize_local: VectorQuantizer2
        self.var_wo_ddp: VAR = var_wo_ddp  # after torch.compile
        self.var_opt = var_opt

        # replace the model's RNG with one living on the training device
        del self.var_wo_ddp.rng
        self.var_wo_ddp.rng = torch.Generator(device=device)

        self.label_smooth = label_smooth
        # per-token loss (reduction='none') so tokens can be re-weighted per scale
        self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
        self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
        self.L = sum(pn * pn for pn in patch_nums)      # total tokens across all scales
        self.last_l = patch_nums[-1] * patch_nums[-1]   # token count of the finest scale
        self.loss_weight = torch.ones(1, self.L, device=device) / self.L  # uniform token weight

        self.patch_nums, self.resos = patch_nums, resos
        # [begin, end) token-index range of each scale inside the flattened sequence
        self.begin_ends = []
        cur = 0
        for pn in patch_nums:
            self.begin_ends.append((cur, cur + pn * pn))
            cur += pn * pn

        # progressive-training bookkeeping
        self.prog_it = 0
        self.last_prog_si = -1
        self.first_prog = True

    @torch.no_grad()
    def eval_ep(self, ld_val: DataLoader):
        """Run a full validation epoch and reduce metrics across ranks.

        Returns:
            (L_mean, L_tail, acc_mean, acc_tail, tot, seconds): mean loss and
            accuracy over all tokens and over the finest scale only ("tail"),
            the total validated sample count across ranks, and wall time.
        """
        tot = 0
        L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
        stt = time.time()
        training = self.var_wo_ddp.training
        self.var_wo_ddp.eval()
        for inp_B3HW, label_B in ld_val:
            B, V = label_B.shape[0], self.vae_local.vocab_size
            inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
            label_B = label_B.to(dist.get_device(), non_blocking=True)

            # tokenize the images with the frozen VQVAE, then build VAR's teacher-forced input
            gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
            gt_BL = torch.cat(gt_idx_Bl, dim=1)
            x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)

            self.var_wo_ddp.forward  # no-op attribute access, kept from upstream (presumably an IDE/compile hint)
            logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
            # accumulate sample-weighted sums; divided by total count after allreduce
            L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
            L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
            acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100 / gt_BL.shape[1])
            acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
            tot += B
        self.var_wo_ddp.train(training)

        # NOTE(review): assumes ld_val yields at least one batch on this rank,
        # otherwise L_mean is still an int and .new_tensor below fails.
        stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
        dist.allreduce(stats)
        tot = round(stats[-1].item())
        stats /= tot
        L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
        return L_mean, L_tail, acc_mean, acc_tail, tot, time.time() - stt

    def train_step(
        self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
        inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
    ) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
        """One training iteration on a batch.

        Args:
            it / g_it: iteration index within the epoch / globally.
            stepping: whether the optimizer steps on this call (gradient accumulation).
            metric_lg / tb_lg: metric logger and tensorboard logger.
            inp_B3HW: input images, batch x 3 x H x W.
            label_B: class labels for the batch.
            prog_si: progressive-training stage index, or -1 when disabled.
            prog_wp_it: warm-up length (iterations) for a new progressive stage.

        Returns:
            (grad_norm, scale_log2) from the AMP optimizer's backward/step.
        """
        # if progressive training: reset the stage warm-up counter whenever the stage changes
        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
        if self.last_prog_si != prog_si:
            if self.last_prog_si != -1: self.first_prog = False
            self.last_prog_si = prog_si
            self.prog_it = 0
        self.prog_it += 1
        prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
        if self.first_prog: prog_wp = 1  # no prog warmup at first prog stage, as it's already solved in wp
        if prog_si == len(self.patch_nums) - 1: prog_si = -1  # max prog, as if no prog

        # forward
        B, V = label_B.shape[0], self.vae_local.vocab_size
        # only sync gradients across ranks on steps that actually update the optimizer
        self.var.require_backward_grad_sync = stepping

        gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
        gt_BL = torch.cat(gt_idx_Bl, dim=1)
        x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)

        with self.var_opt.amp_ctx:
            self.var_wo_ddp.forward  # no-op attribute access, kept from upstream
            logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
            loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
            if prog_si >= 0:    # in progressive training: warm up the newest scale's loss weight
                bg, ed = self.begin_ends[prog_si]
                assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
                lw = self.loss_weight[:, :ed].clone()
                lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
            else:               # not in progressive training
                lw = self.loss_weight
            loss = loss.mul(lw).sum(dim=-1).mean()

        # backward (clips gradients and steps the optimizer only when stepping=True)
        grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)

        # log to the metric logger at its scheduled iterations
        pred_BL = logits_BLV.data.argmax(dim=-1)
        if it == 0 or it in metric_lg.log_iters:
            Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
            acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
            if prog_si >= 0:    # in progressive training: tail metrics undefined
                Ltail = acc_tail = -1
            else:               # not in progressive training
                Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
                acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
            grad_norm = grad_norm.item()
            metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)

        # log to tensorboard (every 500 global iterations)
        if g_it == 0 or (g_it + 1) % 500 == 0:
            # codebook usage: fraction of vocabulary entries actually predicted
            prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
            dist.allreduce(prob_per_class_is_chosen)
            prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
            cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
            if dist.is_master():
                if g_it == 0:
                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
                kw = dict(z_voc_usage=cluster_usage)
                for si, (bg, ed) in enumerate(self.begin_ends):
                    if 0 <= prog_si < si: break  # scales beyond the current stage were not trained
                    pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
                    acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
                    ce = self.val_loss(pred, tar).item()
                    kw[f'acc_{self.resos[si]}'] = acc
                    kw[f'L_{self.resos[si]}'] = ce
                tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
                tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)

        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
        return grad_norm, scale_log2

    def get_config(self):
        """Return the trainer's hyperparameters and progressive-training state."""
        return {
            'patch_nums': self.patch_nums, 'resos': self.resos,
            'label_smooth': self.label_smooth,
            'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
        }

    def state_dict(self):
        """Serialize config plus model/VAE/optimizer state for checkpointing."""
        state = {'config': self.get_config()}
        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
            m = getattr(self, k)
            if m is not None:
                # unwrap torch.compile'd modules so keys match the raw module
                if hasattr(m, '_orig_mod'):
                    m = m._orig_mod
                state[k] = m.state_dict()
        return state

    def load_state_dict(self, state, strict=True, skip_vae=False):
        """Restore trainer state produced by :meth:`state_dict`.

        Args:
            state: checkpoint dict with 'var_wo_ddp'/'vae_local'/'var_opt' and
                optionally 'config'.
            strict: pass-through to each module's load_state_dict; also makes
                config mismatches raise instead of print.
            skip_vae: skip restoring the (frozen, separately-loaded) VAE.
        """
        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
            if skip_vae and 'vae' in k: continue
            m = getattr(self, k)
            if m is not None:
                if hasattr(m, '_orig_mod'):
                    m = m._orig_mod
                ret = m.load_state_dict(state[k], strict=strict)
                if ret is not None:
                    missing, unexpected = ret
                    print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
                    print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')

        # BUGFIX: previously config.get(...) was called before the None check,
        # crashing with AttributeError on checkpoints without a 'config' entry.
        config: dict = state.pop('config', None)
        if config is not None:
            self.prog_it = config.get('prog_it', 0)
            self.last_prog_si = config.get('last_prog_si', -1)
            self.first_prog = config.get('first_prog', True)
            for k, v in self.get_config().items():
                if config.get(k, None) != v:
                    err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
                    if strict: raise AttributeError(err)
                    else: print(err)
|
VAR/code/VAR/utils/__pycache__/amp_sc.cpython-310.pyc
ADDED
|
Binary file (3.14 kB). View file
|
|
|
VAR/code/VAR/utils/__pycache__/arg_util.cpython-310.pyc
ADDED
|
Binary file (9.26 kB). View file
|
|
|
VAR/code/VAR/utils/__pycache__/arg_util.cpython-311.pyc
ADDED
|
Binary file (17.7 kB). View file
|
|
|
VAR/code/VAR/utils/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (1.84 kB). View file
|
|
|