qic999 commited on
Commit
9999797
·
verified ·
1 Parent(s): 284fa92

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. VAR/ILSVRC2012_img_train.torrent +0 -0
  3. VAR/ILSVRC2012_img_val.torrent +0 -0
  4. VAR/Imagenette/imagenette2-160.tgz +3 -0
  5. VAR/Imagenette/imagenette2-320.tgz +3 -0
  6. VAR/Imagenette/imagenette2.tgz +3 -0
  7. VAR/cifar-10-batches-py/batches.meta +0 -0
  8. VAR/cifar-10-batches-py/data_batch_1 +3 -0
  9. VAR/cifar-10-batches-py/data_batch_2 +3 -0
  10. VAR/cifar-10-batches-py/data_batch_3 +3 -0
  11. VAR/cifar-10-batches-py/data_batch_4 +3 -0
  12. VAR/cifar-10-batches-py/data_batch_5 +3 -0
  13. VAR/cifar-10-batches-py/readme.html +1 -0
  14. VAR/cifar-10-batches-py/test_batch +3 -0
  15. VAR/cifar-10-python.tar.gz +3 -0
  16. VAR/cifar-100-python.tar.gz +3 -0
  17. VAR/cifar-100-python/file.txt~ +0 -0
  18. VAR/cifar-100-python/meta +10 -0
  19. VAR/cifar-100-python/test +3 -0
  20. VAR/cifar-100-python/train +3 -0
  21. VAR/code/VAR/LICENSE +21 -0
  22. VAR/code/VAR/README.md +232 -0
  23. VAR/code/VAR/__pycache__/dist.cpython-310.pyc +0 -0
  24. VAR/code/VAR/__pycache__/dist.cpython-311.pyc +0 -0
  25. VAR/code/VAR/__pycache__/trainer.cpython-310.pyc +0 -0
  26. VAR/code/VAR/config.sh +30 -0
  27. VAR/code/VAR/demo_sample.ipynb +127 -0
  28. VAR/code/VAR/demo_zero_shot_edit.ipynb +0 -0
  29. VAR/code/VAR/dist.py +211 -0
  30. VAR/code/VAR/models/__init__.py +39 -0
  31. VAR/code/VAR/models/__pycache__/__init__.cpython-310.pyc +0 -0
  32. VAR/code/VAR/models/__pycache__/basic_vae.cpython-310.pyc +0 -0
  33. VAR/code/VAR/models/__pycache__/basic_var.cpython-310.pyc +0 -0
  34. VAR/code/VAR/models/__pycache__/helpers.cpython-310.pyc +0 -0
  35. VAR/code/VAR/models/__pycache__/quant.cpython-310.pyc +0 -0
  36. VAR/code/VAR/models/__pycache__/var.cpython-310.pyc +0 -0
  37. VAR/code/VAR/models/__pycache__/vqvae.cpython-310.pyc +0 -0
  38. VAR/code/VAR/models/basic_vae.py +226 -0
  39. VAR/code/VAR/models/basic_var.py +174 -0
  40. VAR/code/VAR/models/helpers.py +59 -0
  41. VAR/code/VAR/models/quant.py +243 -0
  42. VAR/code/VAR/models/var.py +419 -0
  43. VAR/code/VAR/models/vqvae.py +95 -0
  44. VAR/code/VAR/requirements.txt +8 -0
  45. VAR/code/VAR/train.py +357 -0
  46. VAR/code/VAR/trainer.py +201 -0
  47. VAR/code/VAR/utils/__pycache__/amp_sc.cpython-310.pyc +0 -0
  48. VAR/code/VAR/utils/__pycache__/arg_util.cpython-310.pyc +0 -0
  49. VAR/code/VAR/utils/__pycache__/arg_util.cpython-311.pyc +0 -0
  50. VAR/code/VAR/utils/__pycache__/data.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -7967,3 +7967,12 @@ Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_
7967
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_7.png filter=lfs diff=lfs merge=lfs -text
7968
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_8.png filter=lfs diff=lfs merge=lfs -text
7969
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_9.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
7967
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_7.png filter=lfs diff=lfs merge=lfs -text
7968
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_8.png filter=lfs diff=lfs merge=lfs -text
7969
  Abnormal-CT-Generation-Healthy/logs/full_ct_2d_with_body_mask/inference/valid_5_a_2_sample_9.png filter=lfs diff=lfs merge=lfs -text
7970
+ VAR/cifar-10-batches-py/data_batch_1 filter=lfs diff=lfs merge=lfs -text
7971
+ VAR/cifar-10-batches-py/data_batch_2 filter=lfs diff=lfs merge=lfs -text
7972
+ VAR/cifar-10-batches-py/data_batch_3 filter=lfs diff=lfs merge=lfs -text
7973
+ VAR/cifar-10-batches-py/data_batch_4 filter=lfs diff=lfs merge=lfs -text
7974
+ VAR/cifar-10-batches-py/data_batch_5 filter=lfs diff=lfs merge=lfs -text
7975
+ VAR/cifar-10-batches-py/test_batch filter=lfs diff=lfs merge=lfs -text
7976
+ VAR/cifar-100-python/test filter=lfs diff=lfs merge=lfs -text
7977
+ VAR/cifar-100-python/train filter=lfs diff=lfs merge=lfs -text
7978
+ VAR/imagenet/LOC_train_solution.csv filter=lfs diff=lfs merge=lfs -text
VAR/ILSVRC2012_img_train.torrent ADDED
File without changes
VAR/ILSVRC2012_img_val.torrent ADDED
File without changes
VAR/Imagenette/imagenette2-160.tgz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64d0c4859f35a461889e0147755a999a48b49bf38a7e0f9bd27003f10db02fe5
3
+ size 99003388
VAR/Imagenette/imagenette2-320.tgz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:569b4497c98db6dd29f335d1f109cf315fe127053cedf69010d047f0188e158c
3
+ size 341663724
VAR/Imagenette/imagenette2.tgz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cbfac238434d89fe99e651496f0812ebc7a10fa62bd42d6874042bf01de4efd
3
+ size 1557161267
VAR/cifar-10-batches-py/batches.meta ADDED
Binary file (158 Bytes). View file
 
VAR/cifar-10-batches-py/data_batch_1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54636561a3ce25bd3e19253c6b0d8538147b0ae398331ac4a2d86c6d987368cd
3
+ size 31035704
VAR/cifar-10-batches-py/data_batch_2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:766b2cef9fbc745cf056b3152224f7cf77163b330ea9a15f9392beb8b89bc5a8
3
+ size 31035320
VAR/cifar-10-batches-py/data_batch_3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f00d98ebfb30b3ec0ad19f9756dc2630b89003e10525f5e148445e82aa6a1f9
3
+ size 31035999
VAR/cifar-10-batches-py/data_batch_4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f7bb240661948b8f4d53e36ec720d8306f5668bd0071dcb4e6c947f78e9682b
3
+ size 31035696
VAR/cifar-10-batches-py/data_batch_5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d91802434d8376bbaeeadf58a737e3a1b12ac839077e931237e0dcd43adcb154
3
+ size 31035623
VAR/cifar-10-batches-py/readme.html ADDED
@@ -0,0 +1 @@
 
 
1
+ <meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
VAR/cifar-10-batches-py/test_batch ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f53d8d457504f7cff4ea9e021afcf0e0ad8e24a91f3fc42091b8adef61157831
3
+ size 31035526
VAR/cifar-10-python.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce
3
+ size 170498071
VAR/cifar-100-python.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7
3
+ size 169001437
VAR/cifar-100-python/file.txt~ ADDED
File without changes
VAR/cifar-100-python/meta ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ �}q(Ufine_label_namesq]q(UappleqU
2
+ Ubeetleq Ubicycleq Ubottleq
3
+ chimpanzeeqUclockqUcloudqU cockroachqUcouchqUcrabqU crocodileqUcupq Udinosaurq!Udolphinq"Uelephantq#Uflatfishq$Uforestq%Ufoxq&Ugirlq'Uhamsterq(Uhouseq)Ukangarooq*Ukeyboardq+Ulampq,U
4
+ lawn_mowerq-Uleopardq.Ulionq/Ulizardq0Ulobsterq1Umanq2U
5
+ maple_treeq3U
6
+ motorcycleq4Umountainq5Umouseq6Umushroomq7Uoak_treeq8Uorangeq9Uorchidq:Uotterq;U palm_treeq<Upearq=U pickup_truckq>U pine_treeq?Uplainq@UplateqAUpoppyqBU porcupineqCUpossumqDUrabbitqEUraccoonqFUrayqGUroadqHUrocketqIUroseqJUseaqKUsealqLUsharkqMUshrewqNUskunkqOU
7
+ skyscraperqPUsnailqQUsnakeqRUspiderqSUsquirrelqTU streetcarqUU sunflowerqVU sweet_pepperqWUtableqXUtankqYU telephoneqZU
8
+ televisionq[Utigerq\Utractorq]Utrainq^Utroutq_Utulipq`UturtleqaUwardrobeqbUwhaleqcU willow_treeqdUwolfqeUwomanqfUwormqgeUcoarse_label_namesqh]qi(Uaquatic_mammalsqjUfishqkUflowersqlUfood_containersqmUfruit_and_vegetablesqnUhousehold_electrical_devicesqoUhousehold_furnitureqpUinsectsqqUlarge_carnivoresqrUlarge_man-made_outdoor_thingsqsUlarge_natural_outdoor_scenesqtUlarge_omnivores_and_herbivoresquUmedium_mammalsqvUnon-insect_invertebratesqwUpeopleqxUreptilesqyU
9
+ vehicles_1q|U
10
+ vehicles_2q}eu.
VAR/cifar-100-python/test ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b67687d9933c4db8f0831104447f15b93774f4f464bd0516f0f0f2ac83b7864
3
+ size 31049707
VAR/cifar-100-python/train ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:735e79b04f092ca3d2e6d07f368c0a7d70d48c48d28865950cc24454cf45129b
3
+ size 155249918
VAR/code/VAR/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 FoundationVision
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
VAR/code/VAR/README.md ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
2
+
3
+ <div align="center">
4
+
5
+ [![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://opensource.bytedance.com/gmpt/t2i/invite)&nbsp;
6
+ [![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)&nbsp;
7
+ [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)&nbsp;
8
+ [![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjNDRGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
9
+
10
+
11
+ </div>
12
+ <p align="center" style="font-size: larger;">
13
+ <a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
14
+ </p>
15
+
16
+ <div>
17
+ <p align="center" style="font-size: larger;">
18
+ <strong>NeurIPS 2024 Best Paper</strong>
19
+ </p>
20
+ </div>
21
+
22
+ <p align="center">
23
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
24
+ <p>
25
+
26
+ <br>
27
+
28
+ ## News
29
+
30
+ * **2024-12:** 🏆 VAR received **NeurIPS 2024 Best Paper Award**.
31
+ * **2024-12:** 🔥 We Release our Text-to-Image research based on VAR, please check [Infinity](https://github.com/FoundationVision/Infinity).
32
+ * **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation.
33
+ * **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.
34
+
35
+ ## 🕹️ Try and Play with VAR!
36
+
37
+ ~~We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!~~
38
+
39
+ We provide a [demo website](https://opensource.bytedance.com/gmpt/t2i/invite) for you to play with VAR Text-to-Image and generate images interactively. Enjoy the fun of visual autoregressive modeling!
40
+
41
+ We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
42
+
43
+ [//]: # (<p align="center">)
44
+ [//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
45
+ [//]: # (<p>)
46
+
47
+
48
+ ## What's New?
49
+
50
+ ### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
51
+
52
+ Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
53
+
54
+ <p align="center">
55
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
56
+ <p>
57
+
58
+ ### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
59
+ <p align="center">
60
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
61
+ <p>
62
+
63
+
64
+ ### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
65
+
66
+
67
+ <p align="center">
68
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
69
+ <p>
70
+ <p align="center">
71
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
72
+ <p>
73
+
74
+
75
+ ### 🔥 Zero-shot generalizability🛠️:
76
+
77
+ <p align="center">
78
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
79
+ <p>
80
+
81
+ #### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
82
+
83
+
84
+ ## VAR zoo
85
+ We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
86
+
87
+ | model | reso. | FID | rel. cost | #params | HF weights🤗 |
88
+ |:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
89
+ | VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
90
+ | VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
91
+ | VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
92
+ | VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
93
+ | VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
94
+ | VAR-d36 | 512 | **2.63** | - | 2.3B | [var_d36.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d36.pth) |
95
+
96
+ You can load these models to generate images via the codes in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
97
+
98
+
99
+ ## Installation
100
+
101
+ 1. Install `torch>=2.0.0`.
102
+ 2. Install other pip packages via `pip3 install -r requirements.txt`.
103
+ 3. Prepare the [ImageNet](http://image-net.org/) dataset
104
+ <details>
105
+ <summary> assume the ImageNet is in `/path/to/imagenet`. It should be like this:</summary>
106
+
107
+ ```
108
+ /path/to/imagenet/:
109
+ train/:
110
+ n01440764:
111
+ many_images.JPEG ...
112
+ n01443537:
113
+ many_images.JPEG ...
114
+ val/:
115
+ n01440764:
116
+ ILSVRC2012_val_00000293.JPEG ...
117
+ n01443537:
118
+ ILSVRC2012_val_00000236.JPEG ...
119
+ ```
120
+ **NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
121
+ </details>
122
+
123
+ 5. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
124
+
125
+
126
+ ## Training Scripts
127
+
128
+ To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
129
+ ```shell
130
+ # d16, 256x256
131
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
132
+ --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
133
+ # d20, 256x256
134
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
135
+ --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
136
+ # d24, 256x256
137
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
138
+ --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
139
+ # d30, 256x256
140
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
141
+ --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
142
+ # d36-s, 512x512 (-s means saln=1, shared AdaLN)
143
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
144
+ --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
145
+ ```
146
+ A folder named `local_output` will be created to save the checkpoints and logs.
147
+ You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
148
+
149
+ If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
150
+
151
+ ## Sampling & Zero-shot Inference
152
+
153
+ For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L344](utils/misc.py#L360).
154
+ Then use the [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and reference ground truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
155
+
156
+ Note a relatively small `cfg=1.5` is used for trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
157
+ We'll provide the sampling script later.
158
+
159
+
160
+ ## Third-party Usage and Research
161
+
162
+ ***In this pargraph, we cross link third-party repositories or research which use VAR and report results. You can let us know by raising an issue***
163
+
164
+ (`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)
165
+
166
+ | **Time** | **Research** | **Link** |
167
+ |--------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------|
168
+ | [5/12/2025] | [ICML 2025]Continuous Visual Autoregressive Generation via Score Maximization | https://github.com/shaochenze/EAR |
169
+ | [5/8/2025] | Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction | https://github.com/icon-lab/FedGAT |
170
+ | [4/7/2025] | FastVAR: Linear Visual Autoregressive Modeling via Cached Token Pruning | https://github.com/csguoh/FastVAR |
171
+ | [4/3/2025] | VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning | https://github.com/VARGPT-family/VARGPT-v1.1 |
172
+ | [3/31/2025] | Training-Free Text-Guided Image Editing with Visual Autoregressive Model | https://github.com/wyf0912/AREdit |
173
+ | [3/17/2025] | Next-Scale Autoregressive Models are Zero-Shot Single-Image Object View Synthesizers | https://github.com/Shiran-Yuan/ArchonView |
174
+ | [3/14/2025] | Safe-VAR: Safe Visual Autoregressive Model for Text-to-Image Generative Watermarking | https://arxiv.org/abs/2503.11324 |
175
+ | [3/3/2025] | [ICML 2025]Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator | https://research.nvidia.com/labs/dir/ddo/ |
176
+ | [2/28/2025] | Autoregressive Medical Image Segmentation via Next-Scale Mask Prediction | https://arxiv.org/abs/2502.20784 |
177
+ | [2/27/2025] | FlexVAR: Flexible Visual Autoregressive Modeling without Residual Prediction | https://github.com/jiaosiyu1999/FlexVAR |
178
+ | [2/17/2025] | MARS: Mesh AutoRegressive Model for 3D Shape Detailization | https://arxiv.org/abs/2502.11390 |
179
+ | [1/31/2025] | [ICML 2025]Visual Autoregressive Modeling for Image Super-Resolution | https://github.com/quyp2000/VARSR |
180
+ | [1/21/2025] | VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model | https://github.com/VARGPT-family/VARGPT |
181
+ | [1/26/2025] | [ICML 2025]Visual Generation Without Guidance | https://github.com/thu-ml/GFT |
182
+ | [12/30/2024] | Next Token Prediction Towards Multimodal Intelligence | https://github.com/LMM101/Awesome-Multimodal-Next-Token-Prediction |
183
+ | [12/30/2024] | Varformer: Adapting VAR’s Generative Prior for Image Restoration | https://arxiv.org/abs/2412.21063 |
184
+ | [12/22/2024] | [ICLR 2025]Distilled Decoding 1: One-step Sampling of Image Auto-regressive Models with Flow Matching | https://github.com/imagination-research/distilled-decoding |
185
+ | [12/19/2024] | FlowAR: Scale-wise Autoregressive Image Generation Meets Flow Matching | https://github.com/OliverRensu/FlowAR |
186
+ | [12/13/2024] | 3D representation in 512-Byte: Variational tokenizer is the key for autoregressive 3D generation | https://github.com/sparse-mvs-2/VAT |
187
+ | [12/9/2024] | CARP: Visuomotor Policy Learning via Coarse-to-Fine Autoregressive Prediction | https://carp-robot.github.io/ |
188
+ | [12/5/2024] | [CVPR 2025]Infinity ∞: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis | https://github.com/FoundationVision/Infinity |
189
+ | [12/5/2024] | [CVPR 2025]Switti: Designing Scale-Wise Transformers for Text-to-Image Synthesis | https://github.com/yandex-research/switti |
190
+ | [12/4/2024] | [CVPR 2025]TokenFlow🚀: Unified Image Tokenizer for Multimodal Understanding and Generation | https://github.com/ByteFlow-AI/TokenFlow |
191
+ | [12/3/2024] | XQ-GAN🚀: An Open-source Image Tokenization Framework for Autoregressive Generation | https://github.com/lxa9867/ImageFolder |
192
+ | [11/28/2024] | [CVPR 2025]CoDe: Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient | https://github.com/czg1225/CoDe |
193
+ | [11/28/2024] | [CVPR 2025]Scalable Autoregressive Monocular Depth Estimation | https://arxiv.org/abs/2411.11361 |
194
+ | [11/27/2024] | [CVPR 2025]SAR3D: Autoregressive 3D Object Generation and Understanding via Multi-scale 3D VQVAE | https://github.com/cyw-3d/SAR3D |
195
+ | [11/26/2024] | LiteVAR: Compressing Visual Autoregressive Modelling with Efficient Attention and Quantization | https://arxiv.org/abs/2411.17178 |
196
+ | [11/15/2024] | M-VAR: Decoupled Scale-wise Autoregressive Modeling for High-Quality Image Generation | https://github.com/OliverRensu/MVAR |
197
+ | [10/14/2024] | [ICLR 2025]HART: Efficient Visual Generation with Hybrid Autoregressive Transformer | https://github.com/mit-han-lab/hart |
198
+ | [10/12/2024] | [ICLR 2025 Oral]Toward Guidance-Free AR Visual Generation via Condition Contrastive Alignment | https://github.com/thu-ml/CCA |
199
+ | [10/3/2024] | [ICLR 2025]ImageFolder🚀: Autoregressive Image Generation with Folded Tokens | https://github.com/lxa9867/ImageFolder |
200
+ | [07/25/2024] | ControlVAR: Exploring Controllable Visual Autoregressive Modeling | https://github.com/lxa9867/ControlVAR |
201
+ | [07/3/2024] | VAR-CLIP: Text-to-Image Generator with Visual Auto-Regressive Modeling | https://github.com/daixiangzi/VAR-CLIP |
202
+ | [06/16/2024] | STAR: Scale-wise Text-to-image generation via Auto-Regressive representations | https://arxiv.org/abs/2406.10797 |
203
+
204
+
205
+ ## License
206
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
207
+
208
+
209
+ ## Citation
210
+ If our work assists your research, feel free to give us a star ⭐ or cite us using:
211
+ ```
212
+ @Article{VAR,
213
+ title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
214
+ author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
215
+ year={2024},
216
+ eprint={2404.02905},
217
+ archivePrefix={arXiv},
218
+ primaryClass={cs.CV}
219
+ }
220
+ ```
221
+
222
+ ```
223
+ @misc{Infinity,
224
+ title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis},
225
+ author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu},
226
+ year={2024},
227
+ eprint={2412.04431},
228
+ archivePrefix={arXiv},
229
+ primaryClass={cs.CV},
230
+ url={https://arxiv.org/abs/2412.04431},
231
+ }
232
+ ```
VAR/code/VAR/__pycache__/dist.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
VAR/code/VAR/__pycache__/dist.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
VAR/code/VAR/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (6.91 kB). View file
 
VAR/code/VAR/config.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ # d16, 256x256
5
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
6
+ --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
7
+ # d20, 256x256
8
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
9
+ --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
10
+ # d24, 256x256
11
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
12
+ --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
13
+ # d30, 256x256
14
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
15
+ --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
16
+ # d36-s, 512x512 (-s means saln=1, shared AdaLN)
17
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
18
+ --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
19
+
20
+
21
+
22
+ torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
23
+ --master_addr=127.0.0.1 --master_port=29500 \
24
+ train.py --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
25
+
26
+ torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
27
+ --master_addr=127.0.0.1 --master_port=29500 \
28
+ train.py --depth=16 --bs=96 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
29
+
30
+ python train.py --depth=16 --bs=96 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --data_path=/sd/qichen/VAR/imagenet_var
VAR/code/VAR/demo_sample.ipynb ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟"
7
+ ],
8
+ "metadata": {
9
+ "collapsed": false
10
+ }
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "outputs": [],
16
+ "source": [
17
+ "################## 1. Download checkpoints and build models\n",
18
+ "import os\n",
19
+ "import os.path as osp\n",
20
+ "import torch, torchvision\n",
21
+ "import random\n",
22
+ "import numpy as np\n",
23
+ "import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
24
+ "setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
25
+ "setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
26
+ "from models import VQVAE, build_vae_var\n",
27
+ "\n",
28
+ "MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====\n",
29
+ "assert MODEL_DEPTH in {16, 20, 24, 30}\n",
30
+ "\n",
31
+ "\n",
32
+ "# download checkpoint\n",
33
+ "hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
34
+ "vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'\n",
35
+ "if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')\n",
36
+ "if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')\n",
37
+ "\n",
38
+ "# build vae, var\n",
39
+ "patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)\n",
40
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
41
+ "if 'vae' not in globals() or 'var' not in globals():\n",
42
+ " vae, var = build_vae_var(\n",
43
+ " V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters\n",
44
+ " device=device, patch_nums=patch_nums,\n",
45
+ " num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
46
+ " )\n",
47
+ "\n",
48
+ "# load checkpoints\n",
49
+ "vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)\n",
50
+ "var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)\n",
51
+ "vae.eval(), var.eval()\n",
52
+ "for p in vae.parameters(): p.requires_grad_(False)\n",
53
+ "for p in var.parameters(): p.requires_grad_(False)\n",
54
+ "print(f'prepare finished.')"
55
+ ],
56
+ "metadata": {
57
+ "collapsed": false,
58
+ "is_executing": true
59
+ }
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "outputs": [],
65
+ "source": [
66
+ "############################# 2. Sample with classifier-free guidance\n",
67
+ "\n",
68
+ "# set args\n",
69
+ "seed = 0 #@param {type:\"number\"}\n",
70
+ "torch.manual_seed(seed)\n",
71
+ "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
72
+ "cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
73
+ "class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:\"raw\"}\n",
74
+ "more_smooth = False # True for more smooth output\n",
75
+ "\n",
76
+ "# seed\n",
77
+ "torch.manual_seed(seed)\n",
78
+ "random.seed(seed)\n",
79
+ "np.random.seed(seed)\n",
80
+ "torch.backends.cudnn.deterministic = True\n",
81
+ "torch.backends.cudnn.benchmark = False\n",
82
+ "\n",
83
+ "# run faster\n",
84
+ "tf32 = True\n",
85
+ "torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
86
+ "torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
87
+ "torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
88
+ "\n",
89
+ "# sample\n",
90
+ "B = len(class_labels)\n",
91
+ "label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
92
+ "with torch.inference_mode():\n",
93
+ " with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True): # using bfloat16 can be faster\n",
94
+ " recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
95
+ "\n",
96
+ "chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
97
+ "chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
98
+ "chw = PImage.fromarray(chw.astype(np.uint8))\n",
99
+ "chw.show()\n"
100
+ ],
101
+ "metadata": {
102
+ "collapsed": false
103
+ }
104
+ }
105
+ ],
106
+ "metadata": {
107
+ "kernelspec": {
108
+ "display_name": "Python 3",
109
+ "language": "python",
110
+ "name": "python3"
111
+ },
112
+ "language_info": {
113
+ "codemirror_mode": {
114
+ "name": "ipython",
115
+ "version": 2
116
+ },
117
+ "file_extension": ".py",
118
+ "mimetype": "text/x-python",
119
+ "name": "python",
120
+ "nbconvert_exporter": "python",
121
+ "pygments_lexer": "ipython2",
122
+ "version": "2.7.6"
123
+ }
124
+ },
125
+ "nbformat": 4,
126
+ "nbformat_minor": 0
127
+ }
VAR/code/VAR/demo_zero_shot_edit.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
VAR/code/VAR/dist.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torch.distributed as tdist
10
+ import torch.multiprocessing as mp
11
+
12
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ __initialized = False
14
+
15
+
16
def initialized():
    """Return True iff initialize() has successfully set up torch.distributed."""
    return __initialized
18
+
19
+
20
def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
    """Set up torch.distributed from torchrun-style env variables.

    Behavior by environment:
      - no CUDA: prints a warning and stays in single-process CPU mode;
      - CUDA but no ``RANK`` env var: binds to ``gpu_id_if_not_distibuted`` and
        stays non-distributed;
      - ``RANK`` set: derives the local rank as RANK % num_gpus, initializes the
        process group, and records rank/world-size in module-level globals.

    Args:
        fork: use the 'fork' multiprocessing start method instead of 'spawn'.
        backend: process-group backend, e.g. 'nccl' or 'gloo'.
        gpu_id_if_not_distibuted: GPU index to use in the non-distributed case.
        timeout: process-group timeout in MINUTES (multiplied by 60 below).
    """
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        # allocate a tiny tensor so __device reflects the actually-bound CUDA device
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    # NOTE: `timeout` is interpreted as minutes here
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
50
+
51
+
52
def get_rank():
    """Global rank of this process (0 when not distributed)."""
    return __rank
54
+
55
+
56
def get_local_rank():
    """Rank of this process within its node (0 when not distributed)."""
    return __local_rank
58
+
59
+
60
def get_world_size():
    """Total number of distributed processes (1 when not distributed)."""
    return __world_size
62
+
63
+
64
def get_device():
    """Device bound to this process: 'cpu' or the torch.device chosen at init."""
    return __device
66
+
67
+
68
def set_gpu_id(gpu_id: int):
    """Bind this process to CUDA device `gpu_id` and refresh the module device.

    Accepts an int or a numeric string; ``None`` is a silent no-op.
    Raises NotImplementedError for any other type.
    """
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        # allocate a tiny tensor so __device reflects the actually-bound device
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError
76
+
77
+
78
def is_master():
    """True on the global rank-0 process (also True when not distributed)."""
    return __rank == 0
80
+
81
+
82
def is_local_master():
    """True on the rank-0 process of each node (also True when not distributed)."""
    return __local_rank == 0
84
+
85
+
86
def new_group(ranks: List[int]):
    """Create a communication group over `ranks`; returns None when not distributed."""
    return tdist.new_group(ranks=ranks) if __initialized else None
90
+
91
+
92
def barrier():
    """Synchronize all ranks; a no-op when not distributed."""
    if __initialized:
        tdist.barrier()
95
+
96
+
97
def allreduce(t: torch.Tensor, async_op=False):
    """All-reduce (sum) `t` across ranks, in place; no-op (returns None) when not distributed.

    CPU tensors are round-tripped through a CUDA copy because the NCCL backend
    only supports CUDA tensors; the result is copied back into `t`.
    NOTE(review): with async_op=True and a CPU tensor, the copy back happens
    before the async reduce necessarily completes — callers appear to use the
    synchronous path; confirm before relying on async CPU reduction.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None
107
+
108
+
109
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather `t` from every rank (all ranks must pass the same shape).

    Returns the per-rank tensors concatenated along dim 0 when `cat` is True,
    otherwise a list of tensors. When not distributed, the input is returned
    as-is (or as a one-element list).
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()  # NCCL requires CUDA tensors
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
120
+
121
+
122
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """All-gather tensors whose first (batch) dimension may differ across ranks.

    Each rank first gathers every rank's full shape, pads its own tensor along
    dim 0 up to the largest batch size, gathers the padded tensors, then trims
    each gathered tensor back to its true length. Trailing dimensions must
    still match across ranks. Returns a dim-0 concatenation when `cat` is
    True, else a list; when not distributed, returns the input unchanged.
    """
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            # new_empty is uninitialized garbage; safe because the padded rows
            # are sliced off again below
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        # NOTE: `t` is intentionally reused (shadowed) as the loop variable here
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])  # trim each rank's padding
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
147
+
148
+
149
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast `t` from `src_rank` to all ranks, in place; no-op when not distributed.

    CPU tensors are routed through a CUDA copy (NCCL needs CUDA tensors) and
    the broadcast result is copied back into `t`.
    """
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
157
+
158
+
159
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect one scalar per rank into a world-size vector.

    Each rank writes `val` into its own slot of a zero vector; a sum
    all-reduce then fills every slot. Returns the raw tensor when `fmt` is
    None, otherwise a list of `fmt`-formatted strings (one per rank). In
    non-distributed mode the result has a single entry.
    """
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)  # sum of one-hot contributions => per-rank values
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]
169
+
170
+
171
def master_only(func):
    """Decorator: run `func` only on the global master rank, then barrier.

    Non-master ranks get None. Passing ``force=True`` (consumed here, not
    forwarded) runs the function on every rank.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_master()
        result = func(*args, **kwargs) if should_run else None
        barrier()  # keep all ranks in lock-step around the master-only work
        return result
    return wrapper
182
+
183
+
184
def local_master_only(func):
    """Decorator: run `func` only on each node's local master rank, then barrier.

    Other ranks get None. Passing ``force=True`` (consumed here, not
    forwarded) runs the function on every rank.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_local_master()
        result = func(*args, **kwargs) if should_run else None
        barrier()  # keep all ranks in lock-step around the local-master work
        return result
    return wrapper
195
+
196
+
197
def for_visualize(func):
    """Decorator for visualization hooks: run on the master rank only, no barrier.

    Non-master ranks simply receive None.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_master():
            return None
        return func(*args, **kwargs)
    return wrapper
207
+
208
+
209
def finalize():
    """Destroy the process group created by initialize(); no-op otherwise."""
    if __initialized:
        tdist.destroy_process_group()
VAR/code/VAR/models/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch.nn as nn
3
+
4
+ from .quant import VectorQuantizer2
5
+ from .var import VAR
6
+ from .vqvae import VQVAE
7
+
8
+
9
def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    """Construct the VQVAE tokenizer and the VAR transformer on `device`.

    The transformer width/heads are derived from `depth` (width = depth*64,
    heads = depth), and drop-path rate scales linearly with depth.
    Returns (vae, var) with the VAE in test mode and VAR weights initialized.
    """
    heads = depth
    width = depth * 64
    dpr = 0.1 * depth/24  # drop-path rate, scaled so depth=24 gives 0.1

    # disable built-in initialization for speed
    # NOTE: this monkeypatches reset_parameters GLOBALLY on these nn classes,
    # affecting any module built after this call anywhere in the process
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available,
    ).to(device)
    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)

    return vae_local, var_wo_ddp
VAR/code/VAR/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
VAR/code/VAR/models/__pycache__/basic_vae.cpython-310.pyc ADDED
Binary file (6.83 kB). View file
 
VAR/code/VAR/models/__pycache__/basic_var.cpython-310.pyc ADDED
Binary file (7.44 kB). View file
 
VAR/code/VAR/models/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (2.8 kB). View file
 
VAR/code/VAR/models/__pycache__/quant.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
VAR/code/VAR/models/__pycache__/var.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
VAR/code/VAR/models/__pycache__/vqvae.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
VAR/code/VAR/models/basic_vae.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # this file only provides the 2 modules used in VQVAE
7
+ __all__ = ['Encoder', 'Decoder',]
8
+
9
+
10
+ """
11
+ References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
12
+ """
13
+ # swish
14
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return x.mul(torch.sigmoid(x))
16
+
17
+
18
def Normalize(in_channels, num_groups=32):
    """GroupNorm factory used throughout the VAE (affine, eps=1e-6)."""
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True,
    )
20
+
21
+
22
class Upsample2x(nn.Module):
    """Nearest-neighbor 2x spatial upsampling followed by a 3x3 conv."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
29
+
30
+
31
class Downsample2x(nn.Module):
    """2x spatial downsampling via a stride-2 3x3 conv with asymmetric (right/bottom) padding."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # pad one pixel on the right and bottom only, then convolve
        padded = F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
38
+
39
+
40
+ class ResnetBlock(nn.Module):
41
+ def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ out_channels = in_channels if out_channels is None else out_channels
45
+ self.out_channels = out_channels
46
+
47
+ self.norm1 = Normalize(in_channels)
48
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
49
+ self.norm2 = Normalize(out_channels)
50
+ self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
51
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
52
+ if self.in_channels != self.out_channels:
53
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
54
+ else:
55
+ self.nin_shortcut = nn.Identity()
56
+
57
+ def forward(self, x):
58
+ h = self.conv1(F.silu(self.norm1(x), inplace=True))
59
+ h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
60
+ return self.nin_shortcut(x) + h
61
+
62
+
63
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over feature maps, with residual connection.

    Computes q, k, v with one 1x1 conv, attends over all H*W positions, then
    projects back and adds the input (identity residual).
    """
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels  # channel count, cached for reshapes in forward

        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
        self.w_ratio = int(in_channels) ** (-0.5)  # 1/sqrt(C) attention scaling
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape  # should be B,3C,H,W
        C = self.C
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)

        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous()     # B,HW,C
        k = k.view(B, C, H * W).contiguous()    # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW    w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)  # softmax over key positions

        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)      # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()

        return x + self.proj_out(h)
93
+
94
+
95
def make_attn(in_channels, using_sa=True):
    """Build a spatial self-attention block, or an Identity when attention is disabled."""
    if using_sa:
        return AttnBlock(in_channels)
    return nn.Identity()
97
+
98
+
99
class Encoder(nn.Module):
    """Convolutional VAE encoder (LDM-style): conv-in, a pyramid of residual
    stages with 2x downsampling between them, a middle block with optional
    attention, and a final conv producing `z_channels` (or 2x for double_z).

    Self-attention is only inserted at the lowest resolution (last stage).
    """
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,
        z_channels, double_z=False, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch                                           # base channel width
        self.num_resolutions = len(ch_mult)                    # number of pyramid stages
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)  # total spatial reduction
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)  # input multiplier of stage i is output multiplier of stage i-1
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                # attention only at the deepest (last) resolution, one per res-block
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                # attn is non-empty only at the last level (one entry per res-block)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
161
+
162
+
163
class Decoder(nn.Module):
    """Convolutional VAE decoder (LDM-style), mirroring Encoder: conv-in from
    latents, middle block with optional attention, then a pyramid of residual
    stages with 2x upsampling between them, finishing with a conv back to
    `in_channels` image channels.

    Self-attention is only inserted at the lowest resolution (deepest stage).
    """
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,  # in_channels: raw img channels
        z_channels, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)  # NOTE: computed but unused below (kept for parity with Encoder)
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling (built deepest-first, prepended so self.up[0] is the shallowest)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # decoder stages use num_res_blocks + 1 residual blocks each
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions-1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling, deepest stage first
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                # attn is non-empty only at the deepest level
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
VAR/code/VAR/models/basic_var.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from models.helpers import DropPath, drop_path
8
+
9
+
10
+ # this file only provides the 3 blocks used in VAR transformer
11
+ __all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
12
+
13
+
14
+ # automatically import fused operators
15
+ dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
16
+ try:
17
+ from flash_attn.ops.layer_norm import dropout_add_layer_norm
18
+ from flash_attn.ops.fused_dense import fused_mlp_func
19
+ except ImportError: pass
20
+ # automatically import faster attention implementations
21
+ try: from xformers.ops import memory_efficient_attention
22
+ except ImportError: pass
23
+ try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
24
+ except ImportError: pass
25
+ try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
26
+ except ImportError:
27
+ def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
28
+ attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
29
+ if attn_mask is not None: attn.add_(attn_mask)
30
+ return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
31
+
32
+
33
class FFN(nn.Module):
    """Transformer feed-forward block: Linear -> tanh-approx GELU -> Linear -> Dropout.

    Uses flash-attn's fused MLP kernel when it was importable at module load
    time and `fused_if_available` is True; otherwise falls back to plain
    PyTorch ops with identical semantics.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        # module-level `fused_mlp_func` is None when flash-attn is not installed
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()

    def forward(self, x):
        if self.fused_mlp_func is not None:
            # fused path must match the eager path: gelu_approx == GELU(approximate='tanh')
            return self.drop(self.fused_mlp_func(
                x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
                activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
                heuristic=0, process_group=None,
            ))
        else:
            return self.drop(self.fc2( self.act(self.fc1(x)) ))

    def extra_repr(self) -> str:
        return f'fused_mlp_func={self.fused_mlp_func is not None}'
56
+
57
+
58
class SelfAttention(nn.Module):
    """Multi-head self-attention with optional QK L2-normalization, KV caching
    for autoregressive inference, and three interchangeable backends
    (flash-attn, xformers memory-efficient attention, PyTorch SDPA fallback).
    """
    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64
        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            # with L2-normalized q/k the softmax scale becomes a learned,
            # per-head multiplier (stored in log space, clamped at log(100))
            self.scale = 1
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 0.25 / math.sqrt(self.head_dim)

        # single projection for q,k,v; k gets a fixed zero bias (buffer), q and v learnable biases
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        self.attn_drop: float = attn_drop
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        B, L, C = x.shape

        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
        main_type = qkv.dtype
        # qkv: BL3Hc

        # flash-attn cannot take an attention bias and needs half precision
        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2               # q or k or v: BHLc

        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)

        if self.caching:
            # append this step's k/v along the sequence dim (dim_cat depends on layout)
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)

        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
        elif self.using_xform:
            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
        else:
            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)

        return self.proj_drop(self.proj(oup))
        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL
        # attn = self.attn_drop(attn.softmax(dim=-1))
        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)     # BHLL @ BHLc = BHLc => BLHc => BLC

    def extra_repr(self) -> str:
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
126
+
127
+
128
class AdaLNSelfAttn(nn.Module):
    """Transformer block with adaptive LayerNorm conditioning (AdaLN).

    The condition vector produces 6 modulation terms (gamma/scale/shift for
    both the attention and FFN sub-blocks), either via a shared learned table
    (`shared_aln`) or a per-block SiLU+Linear projection.
    """
    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        super(AdaLNSelfAttn, self).__init__()
        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)

        # LayerNorm without learnable affine — modulation comes from the condition
        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)

        self.fused_add_norm_fn = None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):   # C: embed_dim, D: cond_dim
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        # AdaLN residual: x + drop_path(gamma * sublayer(LN(x) * (1+scale) + shift))
        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}'
163
+
164
+
165
class AdaLNBeforeHead(nn.Module):
    """Final adaptive LayerNorm applied before the output head.

    The condition vector is projected (SiLU + Linear) into per-channel
    scale/shift terms that modulate an affine-free LayerNorm of the input.
    """

    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        modulation = self.ada_lin(cond_BD).view(-1, 1, 2, self.C)
        scale, shift = modulation.unbind(2)
        normed = self.ln_wo_grad(x_BLC)
        return normed.mul(scale.add(1)).add_(shift)
VAR/code/VAR/models/helpers.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
    """Sample token indices with top-k and/or nucleus (top-p) filtering.

    MUTATES `logits_BlV` in place (trailing underscore): filtered entries are
    set to -inf before the softmax. A negative `num_samples` means sample
    |num_samples| tokens WITHOUT replacement; positive means with replacement.
    Returns indices shaped (B, l, num_samples).
    """
    B, l, V = logits_BlV.shape
    if top_k > 0:
        # drop everything below the k-th largest logit per position
        idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
    if top_p > 0:
        # ascending sort: the low-probability tail whose cumulative mass is
        # <= (1 - top_p) is removed; the largest logit is always kept
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False
        logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
20
+
21
+
22
def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
    """Gumbel-softmax that accepts an explicit torch.Generator for reproducibility.

    Falls back to F.gumbel_softmax when `rng` is None. With hard=True the
    forward output is one-hot while gradients flow through the soft sample
    (straight-through estimator).
    """
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    # Gumbel(0,1) noise via -log(Exponential(1)), drawn from the given generator
    noise = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
    gumbel_noise = -noise.exponential_(generator=rng).log()
    y_soft = ((logits + gumbel_noise) / tau).softmax(dim)

    if not hard:
        return y_soft
    # straight-through: one-hot forward, soft backward
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
37
+
38
+
39
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):  # taken from timm
    """Stochastic depth: randomly zero whole samples of the residual branch.

    Identity when drop_prob is 0 or when not training. Otherwise each sample
    in the batch survives with probability 1 - drop_prob; survivors are
    rescaled by 1/keep_prob when `scale_by_keep` is set.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
47
+
48
+
49
class DropPath(nn.Module):  # taken from timm
    """Module wrapper around drop_path() (per-sample stochastic depth)."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob          # probability of dropping a sample's residual branch
        self.scale_by_keep = scale_by_keep  # rescale survivors by 1/keep_prob

    def forward(self, x):
        # active only in training mode; identity during eval
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # report the actual rate instead of the uninformative '(drop_prob=...)' placeholder
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
VAR/code/VAR/models/quant.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Sequence, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import distributed as tdist, nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ import dist
9
+
10
+
11
+ # this file only provides the VectorQuantizer2 used in VQVAE
12
+ __all__ = ['VectorQuantizer2',]
13
+
14
+
15
class VectorQuantizer2(nn.Module):
    """Multi-scale (residual) vector quantizer used by VQVAE.

    The feature map is quantized over len(v_patch_nums) scales, small to large:
    at each scale the current residual is downsampled to (pn, pn), snapped to the
    nearest codebook entry, upsampled back to full resolution, passed through a
    per-scale conv \\phi (`quant_resi`), accumulated into `f_hat`, and subtracted
    from the residual before the next scale.
    """
    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
        default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4,  # share_quant_resi: args.qsr
    ):
        super().__init__()
        self.vocab_size: int = vocab_size        # codebook size V
        self.Cvae: int = Cvae                    # codebook embedding dimension
        self.using_znorm: bool = using_znorm     # True: cosine-similarity NN search; False: squared-L2
        self.v_patch_nums: Tuple[int] = v_patch_nums  # token-map side lengths per scale, small to large

        self.quant_resi_ratio = quant_resi
        # \phi layers blend identity with a 3x3 conv; |quant_resi| <= 1e-6 disables them entirely.
        if share_quant_resi == 0:   # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
        elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
        else:                       # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))

        # EMA of per-scale codebook hit counts; only used to report usage statistics.
        self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
        self.record_hit = 0

        self.beta: float = beta  # commitment-loss weight
        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)

        # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
        self.prog_si = -1   # progressive training: not supported yet, prog_si always -1

    def eini(self, eini):
        # Codebook init: eini > 0 -> trunc-normal(std=eini); eini < 0 -> uniform in +-|eini|/V; eini == 0 -> keep default.
        if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
        elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)

    def extra_repr(self) -> str:
        return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
        """Quantize `f_BChw` (B, Cvae, H, W).

        Returns (straight-through f_hat, per-scale usage percentages or None, mean VQ loss).
        """
        dtype = f_BChw.dtype
        if dtype != torch.float32: f_BChw = f_BChw.float()
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()

        f_rest = f_no_grad.clone()       # residual still to be explained by the remaining scales
        f_hat = torch.zeros_like(f_rest) # running reconstruction

        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
            SN = len(self.v_patch_nums)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                # find the nearest embedding
                if self.using_znorm:
                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    rest_NC = F.normalize(rest_NC, dim=-1)
                    idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
                else:
                    rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    # squared L2: ||x||^2 + ||e||^2 - 2 x.e (addmm_ contributes the -2 x e^T term)
                    d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                    d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                    idx_N = torch.argmin(d_no_grad, dim=1)

                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                if self.training:
                    # async all-reduce of hit counts across ranks; waited on below
                    if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)

                # calc loss
                idx_Bhw = idx_N.view(B, pn, pn)
                h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
                f_hat = f_hat + h_BChw
                f_rest -= h_BChw

                if self.training and dist.initialized():
                    handler.wait()
                    # EMA of codebook usage; decays faster for the first 100 recorded batches
                    if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)
                    elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
                    else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
                    self.record_hit += 1
                vocab_hit_V.add_(hit_V)
                # commitment term (beta-weighted, grad stopped on f_hat) + codebook term
                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)

            mean_vq_loss *= 1. / SN
            # straight-through estimator: forward value is f_hat, gradient flows into f_BChw
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)

        # NOTE(review): tdist.get_world_size() raises unless torch.distributed is initialized,
        # so this line assumes distributed training — confirm before single-process use.
        margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
        # margin = pn*pn / 100
        if ret_usages: usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in enumerate(self.v_patch_nums)]
        else: usages = None
        return f_hat, usages, mean_vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Re-assemble f_hat from per-scale embeddings.

        Returns every partial reconstruction (a list), or only the final one when `last_one`.
        """
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')
                h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purpose
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')
                h_BChw = self.quant_resi[si/(SN-1)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat)

        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        """Encode a feature map into per-scale token indices (or partial f_hats when `to_fhat`)."""
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)]    # from small to large
        # the final scale must match the input resolution, since no interpolation happens there
        assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws): # from small to large
            if 0 <= self.prog_si < si: break    # progressive training: not supported yet, prog_si always -1
            # find the nearest embedding
            z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            if self.using_znorm:
                z_NC = F.normalize(z_NC, dim=-1)
                idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
            else:
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                idx_N = torch.argmin(d_no_grad, dim=1)

            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
            h_BChw = self.quant_resi[si/(SN-1)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))

        return f_hat_or_idx_Bl

    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
        """Build the teacher-forcing input for VAR: for each scale si, the downsampled
        reconstruction of scales 0..si, concatenated along the token dimension (BLC)."""
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN-1):
            if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break   # progressive training: not supported yet, prog_si always -1
            h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
            f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
            pn_next = self.v_patch_nums[si+1]
            next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
        return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
        """Accumulate scale si's decoded tokens into f_hat; returns (f_hat, next stage's input map)."""
        HW = self.v_patch_nums[-1]
        if si != SN-1:
            h = self.quant_resi[si/(SN-1)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic'))    # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')
        else:
            h = self.quant_resi[si/(SN-1)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat
197
+
198
+
199
class Phi(nn.Conv2d):
    """Residual 3x3 conv: out = (1 - r) * x + r * conv(x), where r = |quant_resi|."""

    def __init__(self, embed_dim, quant_resi):
        kernel_size = 3
        super().__init__(
            in_channels=embed_dim, out_channels=embed_dim,
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        conv_out = super().forward(h_BChw)
        return h_BChw.mul(1 - ratio) + conv_out.mul_(ratio)
207
+
208
+
209
class PhiShared(nn.Module):
    """A single Phi shared by every scale: indexing with any key returns the same module."""

    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        # The index is intentionally ignored — all scales share one \phi.
        return self.qresi
216
+
217
+
218
class PhiPartiallyShared(nn.Module):
    """K Phi layers shared among more-than-K scales.

    Indexing with a float in [0, 1] selects the layer whose anchor tick is closest.
    """

    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        # Evenly spaced anchors in (0, 1); K == 4 uses slightly wider margins.
        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        nearest = np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
        return self.qresi_ls[nearest]

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
230
+
231
+
232
class PhiNonShared(nn.ModuleList):
    """One dedicated Phi per scale.

    Indexing with a float in [0, 1] selects the member whose anchor tick is closest.
    """

    def __init__(self, qresi: List):
        super().__init__(qresi)
        # self.qresi = qresi
        K = len(qresi)
        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        nearest = np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
        return super().__getitem__(nearest)

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
VAR/code/VAR/models/var.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+ import dist
10
+ from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
11
+ from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
12
+ from models.vqvae import VQVAE, VectorQuantizer2
13
+
14
+
15
class SharedAdaLin(nn.Linear):
    """Linear layer whose output is reshaped to (B, 1, 6, C): six AdaLN parameter groups."""

    def forward(self, cond_BD):
        out = super().forward(cond_BD)
        per_group = self.weight.shape[0] // 6
        return out.view(-1, 1, 6, per_group)  # B16C
19
+
20
+
21
class VAR(nn.Module):
    """Visual AutoRegressive transformer (next-scale prediction over a token pyramid),
    extended with learnable "register" tokens that are prepended during the training
    forward pass and dropped before the classifier head.
    """
    def __init__(
        self, vae_local: VQVAE,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads

        self.cond_drop_rate = cond_drop_rate  # label-dropout prob for classifier-free guidance
        self.prog_si = -1   # progressive training

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn ** 2 for pn in self.patch_nums)  # total token count across all scales
        self.first_l = self.patch_nums[0] ** 2           # token count of the first (smallest) scale
        self.begin_ends = []                             # (begin, end) token index per scale
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur+pn ** 2))
            cur += pn ** 2

        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device=dist.get_device())

        # 1. input (word) embedding
        quant: VectorQuantizer2 = vae_local.quantize
        # stored in tuples so the VAE is referenced without being registered as a submodule
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)  # +1 slot: the "unconditional" class for CFG
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding
        pos_1LC = []
        for i, pn in enumerate(self.patch_nums):
            pe = torch.empty(1, pn*pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)     # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule (linearly increasing)
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        # it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)    # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer('lvl_1L', lvl_1L)
        # token i may attend to token j iff level(i) >= level(j): blockwise-causal over scales
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())

        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)


        # 7. register tokens
        # NOTE(review): the attribute is never set before this line, so getattr always
        # returns the default 4 unless a subclass assigns num_register_tokens earlier.
        self.num_register_tokens = getattr(self, "num_register_tokens", 4)  # intended to be overridable externally
        self.register_tokens = nn.Parameter(torch.randn(self.num_register_tokens, self.C) * 0.02)

    def _expand_attn_bias(self, attn_bias: torch.Tensor, add_len: int) -> torch.Tensor:
        """
        attn_bias: (B_or_1, H, L, L) or (1, 1, L, L)
        add_len  : number of register tokens to add along the sequence dimension
        returns  : (B_or_1, H, L+add_len, L+add_len)
        """
        if add_len <= 0:
            return attn_bias

        *prefix, L, _ = attn_bias.shape     # (..., L, L)
        newL = L + add_len
        # Pad with 0 (no extra attention suppression); the original LxL causal/blockwise
        # region stays unchanged in the top-left corner, so registers attend everywhere.
        expanded = attn_bias.new_zeros(*prefix, newL, newL)
        expanded[..., :L, :L] = attn_bias
        return expanded




    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
        """Project final hidden states to vocabulary logits (handles the fused-add-norm tuple form)."""
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual   # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:                               # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    @torch.no_grad()
    def autoregressive_infer_cfg(
        self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
        g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
        more_smooth=False,
    ) -> torch.Tensor:   # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        only used for inference, on autoregressive mode
        :param B: batch size
        :param label_B: imagenet label; if None, randomly sampled
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: if returns_vemb: list of embedding h_BChw := vae_embed(idx_Bl), else: list of idx_Bl

        NOTE(review): the register tokens used in the training `forward` are NOT
        prepended here, so train/inference sequences differ — confirm this is intended.
        """
        if g_seed is None: rng = None
        else: self.rng.manual_seed(g_seed); rng = self.rng

        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)

        # CFG setup: first B rows conditional, last B rows use the "unconditional" class id.
        sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))

        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])

        for b in self.blocks: b.attn.kv_caching(True)
        for si, pn in enumerate(self.patch_nums):   # si: i-th segment
            ratio = si / self.num_stages_minus_1
            # last_L = cur_L
            cur_L += pn*pn
            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map
            AdaLNSelfAttn.forward  # no-op attribute access (kept as an IDE navigation aid)
            for b in self.blocks:
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)  # attn_bias=None: kv cache makes masking unnecessary
            logits_BlV = self.get_logits(x, cond_BD)

            # guidance strength ramps linearly over scales
            t = cfg * ratio
            logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
            if not more_smooth: # this is the default case
                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)   # B, l, Cvae
            else:   # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)   # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
            if si != self.num_stages_minus_1:   # prepare for next stage
                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
                next_token_map = next_token_map.repeat(2, 1, 1)   # double the batch sizes due to CFG

        for b in self.blocks: b.attn.kv_caching(False)
        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)   # de-normalize, from [-1, 1] to [0, 1]

    # (a commented-out copy of the original forward, without register tokens, was removed here;
    #  the implementation below supersedes it)

    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor) -> torch.Tensor:  # returns logits_BLV
        """
        :param label_B: label_B
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :return: logits BLV, V is vocab_size
        """
        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
        B = x_BLCv_wo_first_l.shape[0]

        with torch.cuda.amp.autocast(enabled=False):
            # classifier-free guidance: randomly replace labels with the "unconditional" class id
            label_B = torch.where(
                torch.rand(B, device=label_B.device) < self.cond_drop_rate,
                self.num_classes,
                label_B
            )
            sos = cond_BD = self.class_emb(label_B)  # (B, D)
            sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

            if self.prog_si == 0:
                x_BLC = sos  # (B, first_l, C)
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)  # (B, ed, C)

            # add lvl/pos embeddings only to the original tokens (registers get no positional info)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

        # original attention bias (for sequence length ed)
        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]

        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the main compute dtype (works with mixed precision)
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        # ====== insert register tokens ======
        K = self.num_register_tokens
        if K > 0:
            # [K, C] -> [B, K, C]
            r_BKC = self.register_tokens.to(device=x_BLC.device, dtype=x_BLC.dtype).unsqueeze(0).expand(B, -1, -1)

            # concat order: [REG ... | TOKENS ...] — registers only mediate attention and
            # are dropped before the classifier head
            x_with_reg = torch.cat([r_BKC, x_BLC], dim=1)  # (B, K+ed, C)

            # expand the attention bias to length K+ed
            attn_bias_expanded = self._expand_attn_bias(attn_bias, add_len=K)
        else:
            x_with_reg = x_BLC
            attn_bias_expanded = attn_bias
        # ===========================================

        # run the transformer blocks (registers participate in attention)
        for i, b in enumerate(self.blocks):
            x_with_reg = b(x=x_with_reg, cond_BD=cond_BD_or_gss, attn_bias=attn_bias_expanded)

        # drop the registers; keep only the original tokens so get_logits sees length ed
        if K > 0:
            x_BLC = x_with_reg[:, K:, :]  # (B, ed, C)
        else:
            x_BLC = x_with_reg

        # compute logits (original logic unchanged)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        # keep a zero-valued reference to word_embed so it is not pruned during compile/export
        if self.prog_si == 0:
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s

        return x_BLC    # logits BLV, V is vocab_size


    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        """Initialize all submodules; negative init_std switches to the automated value."""
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5   # init_std < 0: automated

        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv, only VQVAE has conv
            elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        # shrink the head so early training predicts near-uniform logits
        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            # GPT-2-style depth-scaled init of the residual projections
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate:g}'
397
+
398
+
399
class VARHF(VAR, PyTorchModelHubMixin):
    """Hub-ready VAR: builds its own VQVAE from `vae_kwargs` and mixes in
    huggingface_hub's PyTorchModelHubMixin for save/load support."""

    def __init__(
        self,
        vae_kwargs,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        # Construct the tokenizer first, then forward every hyperparameter to VAR unchanged.
        super().__init__(
            vae_local=VQVAE(**vae_kwargs),
            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads,
            mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available, fused_if_available=fused_if_available,
        )
VAR/code/VAR/models/vqvae.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ References:
3
+ - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
4
+ - GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
5
+ - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
6
+ """
7
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .basic_vae import Decoder, Encoder
13
+ from .quant import VectorQuantizer2
14
+
15
+
16
class VQVAE(nn.Module):
    """Multi-scale VQ-VAE used by VAR.

    A convolutional encoder/decoder pair wrapped around a shared-codebook
    `VectorQuantizer2` that tokenizes an image at K progressively finer scales
    (`v_patch_nums`). Images are assumed to live in [-1, 1].
    """

    def __init__(
        self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
        beta=0.25,               # commitment loss weight
        using_znorm=False,       # whether to normalize when computing the nearest neighbors
        quant_conv_ks=3,         # quant conv kernel size
        quant_resi=0.5,          # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
        share_quant_resi=4,      # use 4 \phi layers for K scales: partially-shared \phi
        default_qresi_counts=0,  # if is 0: automatically set to len(v_patch_nums)
        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # tokens per side at each scale
        test_mode=True,
    ):
        super().__init__()
        self.test_mode = test_mode
        self.V, self.Cvae = vocab_size, z_channels
        # ddconfig mirrors vq-f16/config.yaml from CompVis/latent-diffusion.
        ddconfig = dict(
            dropout=dropout, ch=ch, z_channels=z_channels,
            in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,  # from vq-f16/config.yaml above
            using_sa=True, using_mid_sa=True,                          # from vq-f16/config.yaml above
        )
        ddconfig.pop('double_z', None)  # only KL-VAE should use double_z=True
        self.encoder = Encoder(double_z=False, **ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.vocab_size = vocab_size
        self.downsample = 2 ** (len(ddconfig['ch_mult']) - 1)
        self.quantize: VectorQuantizer2 = VectorQuantizer2(
            vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
            default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums,
            quant_resi=quant_resi, share_quant_resi=share_quant_resi,
        )
        self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)
        self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)

        if self.test_mode:
            # inference-only: freeze everything up front
            self.eval()
            for p in self.parameters():
                p.requires_grad_(False)

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, inp, ret_usages=False):  # -> rec_B3HW, idx_N, loss
        """Encode, quantize and decode; returns (reconstruction, codebook usages, vq loss)."""
        # dispatches to VectorQuantizer2.forward
        f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def fhat_to_img(self, f_hat: torch.Tensor):
        """Decode a quantized feature map into an image clamped to [-1, 1]."""
        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)

    def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]:
        """Tokenize an image into per-scale index tensors (a list of B x l_k)."""
        f = self.quant_conv(self.encoder(inp_img_no_grad))
        return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)

    def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Look up embeddings for multi-scale token indices and decode them to image(s)."""
        B = ms_idx_Bl[0].shape[0]
        ms_h_BChw = []
        for idx_Bl in ms_idx_Bl:
            l = idx_Bl.shape[1]
            pn = round(l ** 0.5)  # each scale is a square grid, so side = sqrt(token count)
            ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
        return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)

    def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """Decode multi-scale embeddings; `last_one=True` returns only the finest scale."""
        fhat_or_list = self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=last_one)
        if last_one:
            return self.decoder(self.post_quant_conv(fhat_or_list)).clamp_(-1, 1)
        return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in fhat_or_list]

    def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
        """Round-trip an image through encode → quantize → decode at every scale."""
        f = self.quant_conv(self.encoder(x))
        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
        if last_one:
            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
        return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]

    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
        """Load weights, tolerating a codebook-usage EMA buffer sized for a different vocab.

        NOTE: mutates the passed-in `state_dict` in place when the shapes differ,
        substituting this model's own EMA buffer for the checkpoint's.
        """
        ema_key = 'quantize.ema_vocab_hit_SV'
        if ema_key in state_dict and state_dict[ema_key].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
            state_dict[ema_key] = self.quantize.ema_vocab_hit_SV
        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
VAR/code/VAR/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ Pillow
3
+ huggingface_hub
4
+ numpy
5
+ pytz
6
+ transformers
7
+ typed-argument-parser
8
+ tensorboard
VAR/code/VAR/train.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import time
6
+ import warnings
7
+ from functools import partial
8
+
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+
12
+ import dist
13
+ from utils import arg_util, misc
14
+ from utils.data import build_dataset
15
+ from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
16
+ from utils.misc import auto_resume
17
+
18
+
19
def build_everything(args: arg_util.Args):
    """Build the full training context: logger, dataloaders, models, optimizer, trainer.

    Returns (tb_lg, trainer, start_ep, start_it, iters_train, ld_train, ld_val).
    Resumes from the newest 'ar-ckpt*.pth' checkpoint when one exists.
    """
    # try to resume first, so start_ep/start_it feed the batch sampler below
    auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')

    # tensorboard logger only on the master rank; everyone else gets a no-op proxy
    tb_lg: misc.TensorboardLogger
    if dist.is_master():
        os.makedirs(args.tb_log_dir_path, exist_ok=True)
        # noinspection PyTypeChecker
        tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
        tb_lg.flush()
    else:
        # noinspection PyTypeChecker
        tb_lg = misc.DistLogger(None, verbose=False)
    dist.barrier()

    # log args
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')

    # build data (skipped entirely in local-debug mode)
    if not args.local_debug:
        print(f'[build PT data] ...\n')
        num_classes, dataset_train, dataset_val = build_dataset(
            args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
        )
        types = str((type(dataset_train).__name__, type(dataset_val).__name__))

        ld_val = DataLoader(
            dataset_val, num_workers=0, pin_memory=True,
            batch_size=round(args.batch_size * 1.5),
            sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
            shuffle=False, drop_last=False,
        )
        del dataset_val

        ld_train = DataLoader(
            dataset=dataset_train, num_workers=args.workers, pin_memory=True,
            generator=args.get_different_generator_for_each_rank(),  # worker_init_fn=worker_init_fn,
            batch_sampler=DistInfiniteBatchSampler(
                dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
                same_seed_for_all_ranks=args.same_seed_for_all_ranks,
                shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
                start_ep=start_ep, start_it=start_it,
            ),
        )
        del dataset_train

        for info_line in auto_resume_info:
            print(info_line)
        print(f'[dataloader multi processing] ...', end='', flush=True)
        t0 = time.time()
        iters_train = len(ld_train)
        ld_train = iter(ld_train)
        # noinspection PyArgumentList
        print(f' [dataloader multi processing](*) finished! ({time.time()-t0:.2f}s)', flush=True, clean=True)
        print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')
    else:
        num_classes = 1000
        ld_val = ld_train = None
        iters_train = 10

    # build models — heavy imports deferred until after dataloader creation
    from torch.nn.parallel import DistributedDataParallel as DDP
    from models import VAR, VQVAE, build_vae_var
    from trainer import VARTrainer
    from utils.amp_sc import AmpOptimizer
    from utils.lr_control import filter_params

    vae_local, var_wo_ddp = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device=dist.get_device(), patch_nums=args.patch_nums,
        num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
        flash_if_available=args.fuse, fused_if_available=args.fuse,
        init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
    )

    # download the frozen VQVAE checkpoint on the local master only, then sync everyone
    vae_ckpt = 'vae_ch160v4096z32.pth'
    if dist.is_local_master() and not os.path.exists(vae_ckpt):
        os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
    dist.barrier()
    vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)

    vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
    var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
    var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)

    print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
    count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'
    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
    print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')

    # build optimizer — no weight decay on embeddings, norms, and positional params
    names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
        'cls_token', 'start_token', 'task_token', 'cfg_uncond',
        'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
        'gamma', 'beta',
        'ada_gss', 'moe_bias',
        'scale_mul',
    })
    opt_clz = {
        'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
    }[args.opt.lower().strip()]
    opt_kw = dict(lr=args.tlr, weight_decay=0)
    print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')

    var_optim = AmpOptimizer(
        mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
        grad_clip=args.tclip, n_gradient_accumulation=args.ac,
    )
    del names, paras, para_groups

    # build trainer
    trainer = VARTrainer(
        device=args.device, patch_nums=args.patch_nums, resos=args.resos,
        vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
        var_opt=var_optim, label_smooth=args.ls,
    )
    if trainer_state is not None and len(trainer_state):
        trainer.load_state_dict(trainer_state, strict=False, skip_vae=True)  # don't load vae again
    del vae_local, var_wo_ddp, var, var_optim

    # local-debug: run two train steps on random data, then exit
    if args.local_debug:
        rng = torch.Generator('cpu')
        rng.manual_seed(0)
        B = 4
        inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
        label = torch.ones(B, dtype=torch.long)

        me = misc.MetricLogger(delimiter=' ')
        trainer.train_step(
            it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
        )
        trainer.load_state_dict(trainer.state_dict())
        trainer.train_step(
            it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
        )
        print({k: meter.global_avg for k, meter in me.meters.items()})

        args.dump_log(); tb_lg.flush(); tb_lg.close()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
            sys.stdout.close(), sys.stderr.close()
        exit(0)

    dist.barrier()
    return (
        tb_lg, trainer, start_ep, start_it,
        iters_train, ld_train, ld_val
    )
169
+
170
+
171
def main_training():
    """Entry point: parse args, build everything, run the epoch loop, checkpoint and log.

    Saves 'ar-ckpt-last.pth' every 10 epochs (and at the end), copying it to
    'ar-ckpt-best.pth' whenever the validation tail loss improves.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    if args.local_debug:
        torch.autograd.set_detect_anomaly(True)

    (
        tb_lg, trainer,
        start_ep, start_it,
        iters_train, ld_train, ld_val
    ) = build_everything(args)

    # train
    start_time = time.time()
    best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
    best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1

    L_mean, L_tail = -1, -1
    # FIX: `stats` was only bound inside the loop body; resuming a finished run
    # (start_ep >= args.ep) used to hit `del stats` below and raise NameError.
    stats = None
    for ep in range(start_ep, args.ep):
        if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
            ld_train.sampler.set_epoch(ep)
            if ep < 3:
                # noinspection PyArgumentList
                print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
        tb_lg.set_step(ep * iters_train)

        stats, (sec, remain_time, finish_time) = train_one_ep(
            ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
        )

        # roll best-so-far training metrics into args (for the dashboard dump)
        L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
        best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
        if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
        args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
        args.cur_ep = f'{ep+1}/{args.ep}'
        args.remain_time, args.finish_time = remain_time, finish_time

        AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
        is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
        if is_val_and_also_saving:
            # validate, then checkpoint (best = lowest val tail loss)
            val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
            best_updated = best_val_loss_tail > val_loss_tail
            best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
            best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
            AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
            args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
            print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')

            if dist.is_local_master():
                local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
                local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
                print(f'[saving ckpt] ...', end='', flush=True)
                torch.save({
                    'epoch': ep+1,
                    'iter': 0,
                    'trainer': trainer.state_dict(),
                    'args': args.state_dict(),
                }, local_out_ckpt)
                if best_updated:
                    shutil.copy(local_out_ckpt, local_out_ckpt_best)
                print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
            dist.barrier()

        print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
        tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
        tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
        args.dump_log(); tb_lg.flush()

    total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
    print('\n\n')
    print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
    print('\n\n')

    del stats
    del iters_train, ld_train
    # let async CUDA work drain before freeing the cache
    time.sleep(3); gc.collect(); torch.cuda.empty_cache(); time.sleep(3)

    args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
    print(f'final args:\n\n{str(args)}')
    args.dump_log(); tb_lg.flush(); tb_lg.close()
    dist.barrier()
273
+
274
+
275
def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
    """Run one training epoch.

    Returns (dict of global-average meters, (sec, remain_time, finish_time))
    where the time tuple is the iteration-timer's ETA prediction.
    """
    # import heavy packages after Dataloader object creation
    from trainer import VARTrainer
    from utils.lr_control import lr_wd_annealing
    trainer: VARTrainer

    step_cnt = 0
    ml = misc.MetricLogger(delimiter=' ')
    ml.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
    ml.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    for meter_name in ('Lm', 'Lt'):
        ml.add_meter(meter_name, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})'))
    for meter_name in ('Accm', 'Acct'):
        ml.add_meter(meter_name, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})'))
    header = f'[Ep]: [{ep:4d}/{args.ep}]'

    if is_first_ep:
        # silence resume-time warning noise until the first real iteration runs
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)
    g_it, max_it = ep * iters_train, args.ep * iters_train

    for it, (inp, label) in ml.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
        g_it = ep * iters_train + it
        if it < start_it:
            continue
        if is_first_ep and it == start_it:
            warnings.resetwarnings()

        inp = inp.to(args.device, non_blocking=True)
        label = label.to(args.device, non_blocking=True)

        args.cur_it = f'{it+1}/{iters_train}'

        # anneal learning rate and weight decay for this global step
        wp_it = args.wp * iters_train
        min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
        args.cur_lr, args.cur_wd = max_tlr, max_twd

        if args.pg:  # default: args.pg == 0.0, means no progressive training, won't get into this
            if g_it <= wp_it:
                prog_si = args.pg0
            elif g_it >= max_it * args.pg:
                prog_si = len(args.patch_nums) - 1
            else:
                delta = len(args.patch_nums) - 1 - args.pg0
                progress = min(max((g_it - wp_it) / (max_it * args.pg - wp_it), 0), 1)  # from 0 to 1
                prog_si = args.pg0 + round(progress * delta)  # from args.pg0 to len(args.patch_nums)-1
        else:
            prog_si = -1

        stepping = (g_it + 1) % args.ac == 0  # only step the optimizer every `ac` accumulation steps
        step_cnt += int(stepping)

        grad_norm, scale_log2 = trainer.train_step(
            it=it, g_it=g_it, stepping=stepping, metric_lg=ml, tb_lg=tb_lg,
            inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
        )

        ml.update(tlr=max_tlr)
        tb_lg.set_step(step=g_it)
        tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
        tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
        tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
        tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
        tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)

        if args.tclip > 0:
            tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
            tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)

    ml.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in ml.meters.items()}, ml.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15)  # +15: other cost
340
+
341
+
342
class NullDDP(torch.nn.Module):
    """Single-process stand-in for DistributedDataParallel.

    Wraps `module`, delegates every forward call to it, and exposes the
    `require_backward_grad_sync` flag the trainer toggles on real DDP.
    Extra positional/keyword constructor args (device_ids, ...) are ignored.
    """

    def __init__(self, module, *args, **kwargs):
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = False

    def forward(self, *args, **kwargs):
        # pure pass-through: same inputs, same outputs as the wrapped module
        return self.module(*args, **kwargs)
350
+
351
+
352
if __name__ == '__main__':
    try:
        main_training()
    finally:
        # always tear down the process group and the synchronized stdio wrappers
        dist.finalize()
        if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
            sys.stdout.close()
            sys.stderr.close()
VAR/code/VAR/trainer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+ from torch.utils.data import DataLoader
8
+
9
+ import dist
10
+ from models import VAR, VQVAE, VectorQuantizer2
11
+ from utils.amp_sc import AmpOptimizer
12
+ from utils.misc import MetricLogger, TensorboardLogger
13
+
14
+ Ten = torch.Tensor
15
+ FTen = torch.Tensor
16
+ ITen = torch.LongTensor
17
+ BTen = torch.BoolTensor
18
+
19
+
20
class VARTrainer(object):
    """Training/evaluation driver for VAR.

    Holds the frozen VQVAE tokenizer, the raw VAR transformer (`var_wo_ddp`),
    its DDP wrapper (`var`), and the AMP optimizer, plus the bookkeeping
    needed for progressive (scale-by-scale) training.
    """

    def __init__(
        self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
        vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
        var_opt: AmpOptimizer, label_smooth: float,
    ):
        super(VARTrainer, self).__init__()

        self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
        self.quantize_local: VectorQuantizer2
        self.var_wo_ddp: VAR = var_wo_ddp  # after torch.compile
        self.var_opt = var_opt

        # recreate the sampling RNG on the training device (original lives on CPU)
        del self.var_wo_ddp.rng
        self.var_wo_ddp.rng = torch.Generator(device=device)

        self.label_smooth = label_smooth
        self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
        self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
        self.L = sum(pn * pn for pn in patch_nums)      # total tokens across all scales
        self.last_l = patch_nums[-1] * patch_nums[-1]   # tokens in the finest scale
        self.loss_weight = torch.ones(1, self.L, device=device) / self.L  # uniform per-token weight

        self.patch_nums, self.resos = patch_nums, resos
        self.begin_ends = []  # (begin, end) token ranges, one per scale
        cur = 0
        for i, pn in enumerate(patch_nums):
            self.begin_ends.append((cur, cur + pn * pn))
            cur += pn * pn

        # progressive-training state
        self.prog_it = 0
        self.last_prog_si = -1
        self.first_prog = True

    @torch.no_grad()
    def eval_ep(self, ld_val: DataLoader):
        """Evaluate one pass over `ld_val`.

        Returns (L_mean, L_tail, acc_mean, acc_tail, total_samples, seconds);
        losses/accuracies are all-reduced sample-weighted means across ranks.
        Assumes `ld_val` yields at least one batch on every rank.
        """
        tot = 0
        L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
        stt = time.time()
        training = self.var_wo_ddp.training
        self.var_wo_ddp.eval()
        for inp_B3HW, label_B in ld_val:
            B, V = label_B.shape[0], self.vae_local.vocab_size
            inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
            label_B = label_B.to(dist.get_device(), non_blocking=True)

            # tokenize with the frozen VQVAE, then teacher-force the transformer
            gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
            gt_BL = torch.cat(gt_idx_Bl, dim=1)
            x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)

            self.var_wo_ddp.forward
            logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
            # accumulate sample-weighted sums; divided by `tot` after the all-reduce
            L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
            L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
            acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100 / gt_BL.shape[1])
            acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
            tot += B
        self.var_wo_ddp.train(training)

        stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
        dist.allreduce(stats)
        tot = round(stats[-1].item())
        stats /= tot
        L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
        return L_mean, L_tail, acc_mean, acc_tail, tot, time.time() - stt

    def train_step(
        self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
        inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
    ) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
        """One forward/backward (and optionally optimizer) step.

        `prog_si` selects the active progressive-training stage (-1 = all scales);
        `prog_wp_it` is the warmup length for a freshly-activated stage.
        Returns (grad_norm, loss-scaler log2) from the AMP optimizer.
        """
        # progressive-training bookkeeping: warm a newly activated scale in gradually
        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
        if self.last_prog_si != prog_si:
            if self.last_prog_si != -1: self.first_prog = False
            self.last_prog_si = prog_si
            self.prog_it = 0
        self.prog_it += 1
        prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
        if self.first_prog: prog_wp = 1  # no prog warmup at first prog stage, as it's already solved in wp
        if prog_si == len(self.patch_nums) - 1: prog_si = -1  # max prog, as if no prog

        # forward
        B, V = label_B.shape[0], self.vae_local.vocab_size
        self.var.require_backward_grad_sync = stepping  # skip DDP all-reduce on accumulation steps

        gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
        gt_BL = torch.cat(gt_idx_Bl, dim=1)
        x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)

        with self.var_opt.amp_ctx:
            self.var_wo_ddp.forward
            logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
            loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
            if prog_si >= 0:  # in progressive training: down-weight the newest scale during warmup
                bg, ed = self.begin_ends[prog_si]
                assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
                lw = self.loss_weight[:, :ed].clone()
                lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
            else:  # not in progressive training
                lw = self.loss_weight
            loss = loss.mul(lw).sum(dim=-1).mean()

        # backward
        grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)

        # log
        pred_BL = logits_BLV.data.argmax(dim=-1)
        if it == 0 or it in metric_lg.log_iters:
            Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
            acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
            if prog_si >= 0:  # in progressive training: tail scale isn't trained yet
                Ltail = acc_tail = -1
            else:  # not in progressive training
                Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
                acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
            grad_norm = grad_norm.item()
            metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)

        # log to tensorboard (every 500 global steps)
        if g_it == 0 or (g_it + 1) % 500 == 0:
            prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
            dist.allreduce(prob_per_class_is_chosen)
            prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
            cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
            if dist.is_master():
                if g_it == 0:
                    # pre-seed the curve so the first real point doesn't look like a jump
                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
                    tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
                kw = dict(z_voc_usage=cluster_usage)
                for si, (bg, ed) in enumerate(self.begin_ends):
                    if 0 <= prog_si < si: break  # scales beyond the active stage are untrained
                    pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
                    acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
                    ce = self.val_loss(pred, tar).item()
                    kw[f'acc_{self.resos[si]}'] = acc
                    kw[f'L_{self.resos[si]}'] = ce
                tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
                tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)

        self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
        return grad_norm, scale_log2

    def get_config(self):
        """Serializable hyper-parameter + progressive-state snapshot."""
        return {
            'patch_nums': self.patch_nums, 'resos': self.resos,
            'label_smooth': self.label_smooth,
            'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
        }

    def state_dict(self):
        """Collect config + weights of the model/vae/optimizer (unwrapping torch.compile)."""
        state = {'config': self.get_config()}
        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
            m = getattr(self, k)
            if m is not None:
                if hasattr(m, '_orig_mod'):
                    m = m._orig_mod  # save the uncompiled module's keys
                state[k] = m.state_dict()
        return state

    def load_state_dict(self, state, strict=True, skip_vae=False):
        """Restore from `state_dict()` output; `skip_vae=True` leaves the frozen VAE untouched.

        NOTE: pops 'config' from the passed-in `state` dict (in-place mutation,
        same as before).
        """
        for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
            if skip_vae and 'vae' in k: continue
            m = getattr(self, k)
            if m is not None:
                if hasattr(m, '_orig_mod'):
                    m = m._orig_mod
                ret = m.load_state_dict(state[k], strict=strict)
                if ret is not None:
                    missing, unexpected = ret
                    print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
                    print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')

        config: dict = state.pop('config', None)
        # FIX: the progressive-state reads below used to run BEFORE the None
        # check, so a checkpoint without a 'config' entry crashed with
        # AttributeError on `None.get(...)`. They now run only when present.
        if config is not None:
            self.prog_it = config.get('prog_it', 0)
            self.last_prog_si = config.get('last_prog_si', -1)
            self.first_prog = config.get('first_prog', True)
            for k, v in self.get_config().items():
                if config.get(k, None) != v:
                    err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
                    if strict: raise AttributeError(err)
                    else: print(err)
VAR/code/VAR/utils/__pycache__/amp_sc.cpython-310.pyc ADDED
Binary file (3.14 kB). View file
 
VAR/code/VAR/utils/__pycache__/arg_util.cpython-310.pyc ADDED
Binary file (9.26 kB). View file
 
VAR/code/VAR/utils/__pycache__/arg_util.cpython-311.pyc ADDED
Binary file (17.7 kB). View file
 
VAR/code/VAR/utils/__pycache__/data.cpython-310.pyc ADDED
Binary file (1.84 kB). View file