Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +0 -0
- var/D3HR/DiT-XL-2-256x256.pt +3 -0
- var/D3HR/DiT-XL/.gitattributes +34 -0
- var/D3HR/DiT-XL/README.md +10 -0
- var/D3HR/DiT-XL/model_index.json +1018 -0
- var/D3HR/DiT-XL/scheduler/scheduler_config.json +13 -0
- var/D3HR/DiT-XL/transformer/config.json +23 -0
- var/D3HR/DiT-XL/transformer/diffusion_pytorch_model.bin +3 -0
- var/D3HR/DiT-XL/vae/config.json +30 -0
- var/D3HR/DiT-XL/vae/diffusion_pytorch_model.bin +3 -0
- var/D3HR/README.md +98 -0
- var/D3HR/ds_inf/imagenet1k_train.txt +3 -0
- var/D3HR/ds_inf/imagenet_1k_mapping.json +0 -0
- var/D3HR/ds_inf/tiny-imagenet-mapping.txt +200 -0
- var/D3HR/generation/__init__.py +0 -0
- var/D3HR/generation/__pycache__/__init__.cpython-310.pyc +0 -0
- var/D3HR/generation/__pycache__/dit_inversion_save_statistic.cpython-310.pyc +0 -0
- var/D3HR/generation/dit_inversion_save_statistic.py +437 -0
- var/D3HR/generation/dit_inversion_save_statistic.sh +24 -0
- var/D3HR/generation/group_sampling.py +368 -0
- var/D3HR/generation/group_sampling.sh +20 -0
- var/D3HR/imgs/framework.jpg +3 -0
- var/D3HR/imgs/framework.pdf +3 -0
- var/D3HR/requirements.txt +9 -0
- var/D3HR/validation/__pycache__/argument.cpython-310.pyc +0 -0
- var/D3HR/validation/argument.py +310 -0
- var/D3HR/validation/get_train_list.py +26 -0
- var/D3HR/validation/models/__init__.py +190 -0
- var/D3HR/validation/models/__pycache__/__init__.cpython-310.pyc +0 -0
- var/D3HR/validation/models/__pycache__/__init__.cpython-37.pyc +0 -0
- var/D3HR/validation/models/__pycache__/convnet.cpython-310.pyc +0 -0
- var/D3HR/validation/models/__pycache__/convnet.cpython-37.pyc +0 -0
- var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-310.pyc +0 -0
- var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-37.pyc +0 -0
- var/D3HR/validation/models/__pycache__/resnet.cpython-310.pyc +0 -0
- var/D3HR/validation/models/__pycache__/resnet.cpython-37.pyc +0 -0
- var/D3HR/validation/models/convnet.py +147 -0
- var/D3HR/validation/models/dit_models.py +438 -0
- var/D3HR/validation/models/mobilenet_v2.py +151 -0
- var/D3HR/validation/models/pipeline_stable_unclip_img2img.py +854 -0
- var/D3HR/validation/models/resnet.py +80 -0
- var/D3HR/validation/models/scheduling_ddim.py +522 -0
- var/D3HR/validation/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
- var/D3HR/validation/utils/__pycache__/data_utils.cpython-37.pyc +0 -0
- var/D3HR/validation/utils/__pycache__/validate_utils.cpython-310.pyc +0 -0
- var/D3HR/validation/utils/__pycache__/validate_utils.cpython-37.pyc +0 -0
- var/D3HR/validation/utils/data_utils.py +431 -0
- var/D3HR/validation/utils/download.py +50 -0
- var/D3HR/validation/utils/syn_utils_dit.py +172 -0
- var/D3HR/validation/utils/syn_utils_img2img.py +134 -0
.gitattributes
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
var/D3HR/DiT-XL-2-256x256.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ec1876e4c03471bca126663a30e2d1b20610b6d2f87850a39a36f25cc685521
|
| 3 |
+
size 2700611775
|
var/D3HR/DiT-XL/.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
var/D3HR/DiT-XL/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# Scalable Diffusion Models with Transformers (DiT)
|
| 6 |
+
|
| 7 |
+
## Abstract
|
| 8 |
+
|
| 9 |
+
We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
|
| 10 |
+
|
var/D3HR/DiT-XL/model_index.json
ADDED
|
@@ -0,0 +1,1018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "DiTPipeline",
|
| 3 |
+
"_diffusers_version": "0.12.0.dev0",
|
| 4 |
+
"scheduler": [
|
| 5 |
+
"diffusers",
|
| 6 |
+
"DDIMScheduler"
|
| 7 |
+
],
|
| 8 |
+
"transformer": [
|
| 9 |
+
"diffusers",
|
| 10 |
+
"Transformer2DModel"
|
| 11 |
+
],
|
| 12 |
+
"vae": [
|
| 13 |
+
"diffusers",
|
| 14 |
+
"AutoencoderKL"
|
| 15 |
+
],
|
| 16 |
+
"id2label": {
|
| 17 |
+
"0": "tench, Tinca tinca",
|
| 18 |
+
"1": "goldfish, Carassius auratus",
|
| 19 |
+
"2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
|
| 20 |
+
"3": "tiger shark, Galeocerdo cuvieri",
|
| 21 |
+
"4": "hammerhead, hammerhead shark",
|
| 22 |
+
"5": "electric ray, crampfish, numbfish, torpedo",
|
| 23 |
+
"6": "stingray",
|
| 24 |
+
"7": "cock",
|
| 25 |
+
"8": "hen",
|
| 26 |
+
"9": "ostrich, Struthio camelus",
|
| 27 |
+
"10": "brambling, Fringilla montifringilla",
|
| 28 |
+
"11": "goldfinch, Carduelis carduelis",
|
| 29 |
+
"12": "house finch, linnet, Carpodacus mexicanus",
|
| 30 |
+
"13": "junco, snowbird",
|
| 31 |
+
"14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
|
| 32 |
+
"15": "robin, American robin, Turdus migratorius",
|
| 33 |
+
"16": "bulbul",
|
| 34 |
+
"17": "jay",
|
| 35 |
+
"18": "magpie",
|
| 36 |
+
"19": "chickadee",
|
| 37 |
+
"20": "water ouzel, dipper",
|
| 38 |
+
"21": "kite",
|
| 39 |
+
"22": "bald eagle, American eagle, Haliaeetus leucocephalus",
|
| 40 |
+
"23": "vulture",
|
| 41 |
+
"24": "great grey owl, great gray owl, Strix nebulosa",
|
| 42 |
+
"25": "European fire salamander, Salamandra salamandra",
|
| 43 |
+
"26": "common newt, Triturus vulgaris",
|
| 44 |
+
"27": "eft",
|
| 45 |
+
"28": "spotted salamander, Ambystoma maculatum",
|
| 46 |
+
"29": "axolotl, mud puppy, Ambystoma mexicanum",
|
| 47 |
+
"30": "bullfrog, Rana catesbeiana",
|
| 48 |
+
"31": "tree frog, tree-frog",
|
| 49 |
+
"32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
|
| 50 |
+
"33": "loggerhead, loggerhead turtle, Caretta caretta",
|
| 51 |
+
"34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
|
| 52 |
+
"35": "mud turtle",
|
| 53 |
+
"36": "terrapin",
|
| 54 |
+
"37": "box turtle, box tortoise",
|
| 55 |
+
"38": "banded gecko",
|
| 56 |
+
"39": "common iguana, iguana, Iguana iguana",
|
| 57 |
+
"40": "American chameleon, anole, Anolis carolinensis",
|
| 58 |
+
"41": "whiptail, whiptail lizard",
|
| 59 |
+
"42": "agama",
|
| 60 |
+
"43": "frilled lizard, Chlamydosaurus kingi",
|
| 61 |
+
"44": "alligator lizard",
|
| 62 |
+
"45": "Gila monster, Heloderma suspectum",
|
| 63 |
+
"46": "green lizard, Lacerta viridis",
|
| 64 |
+
"47": "African chameleon, Chamaeleo chamaeleon",
|
| 65 |
+
"48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
|
| 66 |
+
"49": "African crocodile, Nile crocodile, Crocodylus niloticus",
|
| 67 |
+
"50": "American alligator, Alligator mississipiensis",
|
| 68 |
+
"51": "triceratops",
|
| 69 |
+
"52": "thunder snake, worm snake, Carphophis amoenus",
|
| 70 |
+
"53": "ringneck snake, ring-necked snake, ring snake",
|
| 71 |
+
"54": "hognose snake, puff adder, sand viper",
|
| 72 |
+
"55": "green snake, grass snake",
|
| 73 |
+
"56": "king snake, kingsnake",
|
| 74 |
+
"57": "garter snake, grass snake",
|
| 75 |
+
"58": "water snake",
|
| 76 |
+
"59": "vine snake",
|
| 77 |
+
"60": "night snake, Hypsiglena torquata",
|
| 78 |
+
"61": "boa constrictor, Constrictor constrictor",
|
| 79 |
+
"62": "rock python, rock snake, Python sebae",
|
| 80 |
+
"63": "Indian cobra, Naja naja",
|
| 81 |
+
"64": "green mamba",
|
| 82 |
+
"65": "sea snake",
|
| 83 |
+
"66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
|
| 84 |
+
"67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
|
| 85 |
+
"68": "sidewinder, horned rattlesnake, Crotalus cerastes",
|
| 86 |
+
"69": "trilobite",
|
| 87 |
+
"70": "harvestman, daddy longlegs, Phalangium opilio",
|
| 88 |
+
"71": "scorpion",
|
| 89 |
+
"72": "black and gold garden spider, Argiope aurantia",
|
| 90 |
+
"73": "barn spider, Araneus cavaticus",
|
| 91 |
+
"74": "garden spider, Aranea diademata",
|
| 92 |
+
"75": "black widow, Latrodectus mactans",
|
| 93 |
+
"76": "tarantula",
|
| 94 |
+
"77": "wolf spider, hunting spider",
|
| 95 |
+
"78": "tick",
|
| 96 |
+
"79": "centipede",
|
| 97 |
+
"80": "black grouse",
|
| 98 |
+
"81": "ptarmigan",
|
| 99 |
+
"82": "ruffed grouse, partridge, Bonasa umbellus",
|
| 100 |
+
"83": "prairie chicken, prairie grouse, prairie fowl",
|
| 101 |
+
"84": "peacock",
|
| 102 |
+
"85": "quail",
|
| 103 |
+
"86": "partridge",
|
| 104 |
+
"87": "African grey, African gray, Psittacus erithacus",
|
| 105 |
+
"88": "macaw",
|
| 106 |
+
"89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
| 107 |
+
"90": "lorikeet",
|
| 108 |
+
"91": "coucal",
|
| 109 |
+
"92": "bee eater",
|
| 110 |
+
"93": "hornbill",
|
| 111 |
+
"94": "hummingbird",
|
| 112 |
+
"95": "jacamar",
|
| 113 |
+
"96": "toucan",
|
| 114 |
+
"97": "drake",
|
| 115 |
+
"98": "red-breasted merganser, Mergus serrator",
|
| 116 |
+
"99": "goose",
|
| 117 |
+
"100": "black swan, Cygnus atratus",
|
| 118 |
+
"101": "tusker",
|
| 119 |
+
"102": "echidna, spiny anteater, anteater",
|
| 120 |
+
"103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
|
| 121 |
+
"104": "wallaby, brush kangaroo",
|
| 122 |
+
"105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
|
| 123 |
+
"106": "wombat",
|
| 124 |
+
"107": "jellyfish",
|
| 125 |
+
"108": "sea anemone, anemone",
|
| 126 |
+
"109": "brain coral",
|
| 127 |
+
"110": "flatworm, platyhelminth",
|
| 128 |
+
"111": "nematode, nematode worm, roundworm",
|
| 129 |
+
"112": "conch",
|
| 130 |
+
"113": "snail",
|
| 131 |
+
"114": "slug",
|
| 132 |
+
"115": "sea slug, nudibranch",
|
| 133 |
+
"116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
|
| 134 |
+
"117": "chambered nautilus, pearly nautilus, nautilus",
|
| 135 |
+
"118": "Dungeness crab, Cancer magister",
|
| 136 |
+
"119": "rock crab, Cancer irroratus",
|
| 137 |
+
"120": "fiddler crab",
|
| 138 |
+
"121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
|
| 139 |
+
"122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
|
| 140 |
+
"123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
|
| 141 |
+
"124": "crayfish, crawfish, crawdad, crawdaddy",
|
| 142 |
+
"125": "hermit crab",
|
| 143 |
+
"126": "isopod",
|
| 144 |
+
"127": "white stork, Ciconia ciconia",
|
| 145 |
+
"128": "black stork, Ciconia nigra",
|
| 146 |
+
"129": "spoonbill",
|
| 147 |
+
"130": "flamingo",
|
| 148 |
+
"131": "little blue heron, Egretta caerulea",
|
| 149 |
+
"132": "American egret, great white heron, Egretta albus",
|
| 150 |
+
"133": "bittern",
|
| 151 |
+
"134": "crane",
|
| 152 |
+
"135": "limpkin, Aramus pictus",
|
| 153 |
+
"136": "European gallinule, Porphyrio porphyrio",
|
| 154 |
+
"137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
|
| 155 |
+
"138": "bustard",
|
| 156 |
+
"139": "ruddy turnstone, Arenaria interpres",
|
| 157 |
+
"140": "red-backed sandpiper, dunlin, Erolia alpina",
|
| 158 |
+
"141": "redshank, Tringa totanus",
|
| 159 |
+
"142": "dowitcher",
|
| 160 |
+
"143": "oystercatcher, oyster catcher",
|
| 161 |
+
"144": "pelican",
|
| 162 |
+
"145": "king penguin, Aptenodytes patagonica",
|
| 163 |
+
"146": "albatross, mollymawk",
|
| 164 |
+
"147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
|
| 165 |
+
"148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
|
| 166 |
+
"149": "dugong, Dugong dugon",
|
| 167 |
+
"150": "sea lion",
|
| 168 |
+
"151": "Chihuahua",
|
| 169 |
+
"152": "Japanese spaniel",
|
| 170 |
+
"153": "Maltese dog, Maltese terrier, Maltese",
|
| 171 |
+
"154": "Pekinese, Pekingese, Peke",
|
| 172 |
+
"155": "Shih-Tzu",
|
| 173 |
+
"156": "Blenheim spaniel",
|
| 174 |
+
"157": "papillon",
|
| 175 |
+
"158": "toy terrier",
|
| 176 |
+
"159": "Rhodesian ridgeback",
|
| 177 |
+
"160": "Afghan hound, Afghan",
|
| 178 |
+
"161": "basset, basset hound",
|
| 179 |
+
"162": "beagle",
|
| 180 |
+
"163": "bloodhound, sleuthhound",
|
| 181 |
+
"164": "bluetick",
|
| 182 |
+
"165": "black-and-tan coonhound",
|
| 183 |
+
"166": "Walker hound, Walker foxhound",
|
| 184 |
+
"167": "English foxhound",
|
| 185 |
+
"168": "redbone",
|
| 186 |
+
"169": "borzoi, Russian wolfhound",
|
| 187 |
+
"170": "Irish wolfhound",
|
| 188 |
+
"171": "Italian greyhound",
|
| 189 |
+
"172": "whippet",
|
| 190 |
+
"173": "Ibizan hound, Ibizan Podenco",
|
| 191 |
+
"174": "Norwegian elkhound, elkhound",
|
| 192 |
+
"175": "otterhound, otter hound",
|
| 193 |
+
"176": "Saluki, gazelle hound",
|
| 194 |
+
"177": "Scottish deerhound, deerhound",
|
| 195 |
+
"178": "Weimaraner",
|
| 196 |
+
"179": "Staffordshire bullterrier, Staffordshire bull terrier",
|
| 197 |
+
"180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
|
| 198 |
+
"181": "Bedlington terrier",
|
| 199 |
+
"182": "Border terrier",
|
| 200 |
+
"183": "Kerry blue terrier",
|
| 201 |
+
"184": "Irish terrier",
|
| 202 |
+
"185": "Norfolk terrier",
|
| 203 |
+
"186": "Norwich terrier",
|
| 204 |
+
"187": "Yorkshire terrier",
|
| 205 |
+
"188": "wire-haired fox terrier",
|
| 206 |
+
"189": "Lakeland terrier",
|
| 207 |
+
"190": "Sealyham terrier, Sealyham",
|
| 208 |
+
"191": "Airedale, Airedale terrier",
|
| 209 |
+
"192": "cairn, cairn terrier",
|
| 210 |
+
"193": "Australian terrier",
|
| 211 |
+
"194": "Dandie Dinmont, Dandie Dinmont terrier",
|
| 212 |
+
"195": "Boston bull, Boston terrier",
|
| 213 |
+
"196": "miniature schnauzer",
|
| 214 |
+
"197": "giant schnauzer",
|
| 215 |
+
"198": "standard schnauzer",
|
| 216 |
+
"199": "Scotch terrier, Scottish terrier, Scottie",
|
| 217 |
+
"200": "Tibetan terrier, chrysanthemum dog",
|
| 218 |
+
"201": "silky terrier, Sydney silky",
|
| 219 |
+
"202": "soft-coated wheaten terrier",
|
| 220 |
+
"203": "West Highland white terrier",
|
| 221 |
+
"204": "Lhasa, Lhasa apso",
|
| 222 |
+
"205": "flat-coated retriever",
|
| 223 |
+
"206": "curly-coated retriever",
|
| 224 |
+
"207": "golden retriever",
|
| 225 |
+
"208": "Labrador retriever",
|
| 226 |
+
"209": "Chesapeake Bay retriever",
|
| 227 |
+
"210": "German short-haired pointer",
|
| 228 |
+
"211": "vizsla, Hungarian pointer",
|
| 229 |
+
"212": "English setter",
|
| 230 |
+
"213": "Irish setter, red setter",
|
| 231 |
+
"214": "Gordon setter",
|
| 232 |
+
"215": "Brittany spaniel",
|
| 233 |
+
"216": "clumber, clumber spaniel",
|
| 234 |
+
"217": "English springer, English springer spaniel",
|
| 235 |
+
"218": "Welsh springer spaniel",
|
| 236 |
+
"219": "cocker spaniel, English cocker spaniel, cocker",
|
| 237 |
+
"220": "Sussex spaniel",
|
| 238 |
+
"221": "Irish water spaniel",
|
| 239 |
+
"222": "kuvasz",
|
| 240 |
+
"223": "schipperke",
|
| 241 |
+
"224": "groenendael",
|
| 242 |
+
"225": "malinois",
|
| 243 |
+
"226": "briard",
|
| 244 |
+
"227": "kelpie",
|
| 245 |
+
"228": "komondor",
|
| 246 |
+
"229": "Old English sheepdog, bobtail",
|
| 247 |
+
"230": "Shetland sheepdog, Shetland sheep dog, Shetland",
|
| 248 |
+
"231": "collie",
|
| 249 |
+
"232": "Border collie",
|
| 250 |
+
"233": "Bouvier des Flandres, Bouviers des Flandres",
|
| 251 |
+
"234": "Rottweiler",
|
| 252 |
+
"235": "German shepherd, German shepherd dog, German police dog, alsatian",
|
| 253 |
+
"236": "Doberman, Doberman pinscher",
|
| 254 |
+
"237": "miniature pinscher",
|
| 255 |
+
"238": "Greater Swiss Mountain dog",
|
| 256 |
+
"239": "Bernese mountain dog",
|
| 257 |
+
"240": "Appenzeller",
|
| 258 |
+
"241": "EntleBucher",
|
| 259 |
+
"242": "boxer",
|
| 260 |
+
"243": "bull mastiff",
|
| 261 |
+
"244": "Tibetan mastiff",
|
| 262 |
+
"245": "French bulldog",
|
| 263 |
+
"246": "Great Dane",
|
| 264 |
+
"247": "Saint Bernard, St Bernard",
|
| 265 |
+
"248": "Eskimo dog, husky",
|
| 266 |
+
"249": "malamute, malemute, Alaskan malamute",
|
| 267 |
+
"250": "Siberian husky",
|
| 268 |
+
"251": "dalmatian, coach dog, carriage dog",
|
| 269 |
+
"252": "affenpinscher, monkey pinscher, monkey dog",
|
| 270 |
+
"253": "basenji",
|
| 271 |
+
"254": "pug, pug-dog",
|
| 272 |
+
"255": "Leonberg",
|
| 273 |
+
"256": "Newfoundland, Newfoundland dog",
|
| 274 |
+
"257": "Great Pyrenees",
|
| 275 |
+
"258": "Samoyed, Samoyede",
|
| 276 |
+
"259": "Pomeranian",
|
| 277 |
+
"260": "chow, chow chow",
|
| 278 |
+
"261": "keeshond",
|
| 279 |
+
"262": "Brabancon griffon",
|
| 280 |
+
"263": "Pembroke, Pembroke Welsh corgi",
|
| 281 |
+
"264": "Cardigan, Cardigan Welsh corgi",
|
| 282 |
+
"265": "toy poodle",
|
| 283 |
+
"266": "miniature poodle",
|
| 284 |
+
"267": "standard poodle",
|
| 285 |
+
"268": "Mexican hairless",
|
| 286 |
+
"269": "timber wolf, grey wolf, gray wolf, Canis lupus",
|
| 287 |
+
"270": "white wolf, Arctic wolf, Canis lupus tundrarum",
|
| 288 |
+
"271": "red wolf, maned wolf, Canis rufus, Canis niger",
|
| 289 |
+
"272": "coyote, prairie wolf, brush wolf, Canis latrans",
|
| 290 |
+
"273": "dingo, warrigal, warragal, Canis dingo",
|
| 291 |
+
"274": "dhole, Cuon alpinus",
|
| 292 |
+
"275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
|
| 293 |
+
"276": "hyena, hyaena",
|
| 294 |
+
"277": "red fox, Vulpes vulpes",
|
| 295 |
+
"278": "kit fox, Vulpes macrotis",
|
| 296 |
+
"279": "Arctic fox, white fox, Alopex lagopus",
|
| 297 |
+
"280": "grey fox, gray fox, Urocyon cinereoargenteus",
|
| 298 |
+
"281": "tabby, tabby cat",
|
| 299 |
+
"282": "tiger cat",
|
| 300 |
+
"283": "Persian cat",
|
| 301 |
+
"284": "Siamese cat, Siamese",
|
| 302 |
+
"285": "Egyptian cat",
|
| 303 |
+
"286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
|
| 304 |
+
"287": "lynx, catamount",
|
| 305 |
+
"288": "leopard, Panthera pardus",
|
| 306 |
+
"289": "snow leopard, ounce, Panthera uncia",
|
| 307 |
+
"290": "jaguar, panther, Panthera onca, Felis onca",
|
| 308 |
+
"291": "lion, king of beasts, Panthera leo",
|
| 309 |
+
"292": "tiger, Panthera tigris",
|
| 310 |
+
"293": "cheetah, chetah, Acinonyx jubatus",
|
| 311 |
+
"294": "brown bear, bruin, Ursus arctos",
|
| 312 |
+
"295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
|
| 313 |
+
"296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
|
| 314 |
+
"297": "sloth bear, Melursus ursinus, Ursus ursinus",
|
| 315 |
+
"298": "mongoose",
|
| 316 |
+
"299": "meerkat, mierkat",
|
| 317 |
+
"300": "tiger beetle",
|
| 318 |
+
"301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
|
| 319 |
+
"302": "ground beetle, carabid beetle",
|
| 320 |
+
"303": "long-horned beetle, longicorn, longicorn beetle",
|
| 321 |
+
"304": "leaf beetle, chrysomelid",
|
| 322 |
+
"305": "dung beetle",
|
| 323 |
+
"306": "rhinoceros beetle",
|
| 324 |
+
"307": "weevil",
|
| 325 |
+
"308": "fly",
|
| 326 |
+
"309": "bee",
|
| 327 |
+
"310": "ant, emmet, pismire",
|
| 328 |
+
"311": "grasshopper, hopper",
|
| 329 |
+
"312": "cricket",
|
| 330 |
+
"313": "walking stick, walkingstick, stick insect",
|
| 331 |
+
"314": "cockroach, roach",
|
| 332 |
+
"315": "mantis, mantid",
|
| 333 |
+
"316": "cicada, cicala",
|
| 334 |
+
"317": "leafhopper",
|
| 335 |
+
"318": "lacewing, lacewing fly",
|
| 336 |
+
"319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
| 337 |
+
"320": "damselfly",
|
| 338 |
+
"321": "admiral",
|
| 339 |
+
"322": "ringlet, ringlet butterfly",
|
| 340 |
+
"323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
|
| 341 |
+
"324": "cabbage butterfly",
|
| 342 |
+
"325": "sulphur butterfly, sulfur butterfly",
|
| 343 |
+
"326": "lycaenid, lycaenid butterfly",
|
| 344 |
+
"327": "starfish, sea star",
|
| 345 |
+
"328": "sea urchin",
|
| 346 |
+
"329": "sea cucumber, holothurian",
|
| 347 |
+
"330": "wood rabbit, cottontail, cottontail rabbit",
|
| 348 |
+
"331": "hare",
|
| 349 |
+
"332": "Angora, Angora rabbit",
|
| 350 |
+
"333": "hamster",
|
| 351 |
+
"334": "porcupine, hedgehog",
|
| 352 |
+
"335": "fox squirrel, eastern fox squirrel, Sciurus niger",
|
| 353 |
+
"336": "marmot",
|
| 354 |
+
"337": "beaver",
|
| 355 |
+
"338": "guinea pig, Cavia cobaya",
|
| 356 |
+
"339": "sorrel",
|
| 357 |
+
"340": "zebra",
|
| 358 |
+
"341": "hog, pig, grunter, squealer, Sus scrofa",
|
| 359 |
+
"342": "wild boar, boar, Sus scrofa",
|
| 360 |
+
"343": "warthog",
|
| 361 |
+
"344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
|
| 362 |
+
"345": "ox",
|
| 363 |
+
"346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
|
| 364 |
+
"347": "bison",
|
| 365 |
+
"348": "ram, tup",
|
| 366 |
+
"349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
|
| 367 |
+
"350": "ibex, Capra ibex",
|
| 368 |
+
"351": "hartebeest",
|
| 369 |
+
"352": "impala, Aepyceros melampus",
|
| 370 |
+
"353": "gazelle",
|
| 371 |
+
"354": "Arabian camel, dromedary, Camelus dromedarius",
|
| 372 |
+
"355": "llama",
|
| 373 |
+
"356": "weasel",
|
| 374 |
+
"357": "mink",
|
| 375 |
+
"358": "polecat, fitch, foulmart, foumart, Mustela putorius",
|
| 376 |
+
"359": "black-footed ferret, ferret, Mustela nigripes",
|
| 377 |
+
"360": "otter",
|
| 378 |
+
"361": "skunk, polecat, wood pussy",
|
| 379 |
+
"362": "badger",
|
| 380 |
+
"363": "armadillo",
|
| 381 |
+
"364": "three-toed sloth, ai, Bradypus tridactylus",
|
| 382 |
+
"365": "orangutan, orang, orangutang, Pongo pygmaeus",
|
| 383 |
+
"366": "gorilla, Gorilla gorilla",
|
| 384 |
+
"367": "chimpanzee, chimp, Pan troglodytes",
|
| 385 |
+
"368": "gibbon, Hylobates lar",
|
| 386 |
+
"369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
|
| 387 |
+
"370": "guenon, guenon monkey",
|
| 388 |
+
"371": "patas, hussar monkey, Erythrocebus patas",
|
| 389 |
+
"372": "baboon",
|
| 390 |
+
"373": "macaque",
|
| 391 |
+
"374": "langur",
|
| 392 |
+
"375": "colobus, colobus monkey",
|
| 393 |
+
"376": "proboscis monkey, Nasalis larvatus",
|
| 394 |
+
"377": "marmoset",
|
| 395 |
+
"378": "capuchin, ringtail, Cebus capucinus",
|
| 396 |
+
"379": "howler monkey, howler",
|
| 397 |
+
"380": "titi, titi monkey",
|
| 398 |
+
"381": "spider monkey, Ateles geoffroyi",
|
| 399 |
+
"382": "squirrel monkey, Saimiri sciureus",
|
| 400 |
+
"383": "Madagascar cat, ring-tailed lemur, Lemur catta",
|
| 401 |
+
"384": "indri, indris, Indri indri, Indri brevicaudatus",
|
| 402 |
+
"385": "Indian elephant, Elephas maximus",
|
| 403 |
+
"386": "African elephant, Loxodonta africana",
|
| 404 |
+
"387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
|
| 405 |
+
"388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
|
| 406 |
+
"389": "barracouta, snoek",
|
| 407 |
+
"390": "eel",
|
| 408 |
+
"391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
|
| 409 |
+
"392": "rock beauty, Holocanthus tricolor",
|
| 410 |
+
"393": "anemone fish",
|
| 411 |
+
"394": "sturgeon",
|
| 412 |
+
"395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
|
| 413 |
+
"396": "lionfish",
|
| 414 |
+
"397": "puffer, pufferfish, blowfish, globefish",
|
| 415 |
+
"398": "abacus",
|
| 416 |
+
"399": "abaya",
|
| 417 |
+
"400": "academic gown, academic robe, judge's robe",
|
| 418 |
+
"401": "accordion, piano accordion, squeeze box",
|
| 419 |
+
"402": "acoustic guitar",
|
| 420 |
+
"403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
|
| 421 |
+
"404": "airliner",
|
| 422 |
+
"405": "airship, dirigible",
|
| 423 |
+
"406": "altar",
|
| 424 |
+
"407": "ambulance",
|
| 425 |
+
"408": "amphibian, amphibious vehicle",
|
| 426 |
+
"409": "analog clock",
|
| 427 |
+
"410": "apiary, bee house",
|
| 428 |
+
"411": "apron",
|
| 429 |
+
"412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
|
| 430 |
+
"413": "assault rifle, assault gun",
|
| 431 |
+
"414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
|
| 432 |
+
"415": "bakery, bakeshop, bakehouse",
|
| 433 |
+
"416": "balance beam, beam",
|
| 434 |
+
"417": "balloon",
|
| 435 |
+
"418": "ballpoint, ballpoint pen, ballpen, Biro",
|
| 436 |
+
"419": "Band Aid",
|
| 437 |
+
"420": "banjo",
|
| 438 |
+
"421": "bannister, banister, balustrade, balusters, handrail",
|
| 439 |
+
"422": "barbell",
|
| 440 |
+
"423": "barber chair",
|
| 441 |
+
"424": "barbershop",
|
| 442 |
+
"425": "barn",
|
| 443 |
+
"426": "barometer",
|
| 444 |
+
"427": "barrel, cask",
|
| 445 |
+
"428": "barrow, garden cart, lawn cart, wheelbarrow",
|
| 446 |
+
"429": "baseball",
|
| 447 |
+
"430": "basketball",
|
| 448 |
+
"431": "bassinet",
|
| 449 |
+
"432": "bassoon",
|
| 450 |
+
"433": "bathing cap, swimming cap",
|
| 451 |
+
"434": "bath towel",
|
| 452 |
+
"435": "bathtub, bathing tub, bath, tub",
|
| 453 |
+
"436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
|
| 454 |
+
"437": "beacon, lighthouse, beacon light, pharos",
|
| 455 |
+
"438": "beaker",
|
| 456 |
+
"439": "bearskin, busby, shako",
|
| 457 |
+
"440": "beer bottle",
|
| 458 |
+
"441": "beer glass",
|
| 459 |
+
"442": "bell cote, bell cot",
|
| 460 |
+
"443": "bib",
|
| 461 |
+
"444": "bicycle-built-for-two, tandem bicycle, tandem",
|
| 462 |
+
"445": "bikini, two-piece",
|
| 463 |
+
"446": "binder, ring-binder",
|
| 464 |
+
"447": "binoculars, field glasses, opera glasses",
|
| 465 |
+
"448": "birdhouse",
|
| 466 |
+
"449": "boathouse",
|
| 467 |
+
"450": "bobsled, bobsleigh, bob",
|
| 468 |
+
"451": "bolo tie, bolo, bola tie, bola",
|
| 469 |
+
"452": "bonnet, poke bonnet",
|
| 470 |
+
"453": "bookcase",
|
| 471 |
+
"454": "bookshop, bookstore, bookstall",
|
| 472 |
+
"455": "bottlecap",
|
| 473 |
+
"456": "bow",
|
| 474 |
+
"457": "bow tie, bow-tie, bowtie",
|
| 475 |
+
"458": "brass, memorial tablet, plaque",
|
| 476 |
+
"459": "brassiere, bra, bandeau",
|
| 477 |
+
"460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
|
| 478 |
+
"461": "breastplate, aegis, egis",
|
| 479 |
+
"462": "broom",
|
| 480 |
+
"463": "bucket, pail",
|
| 481 |
+
"464": "buckle",
|
| 482 |
+
"465": "bulletproof vest",
|
| 483 |
+
"466": "bullet train, bullet",
|
| 484 |
+
"467": "butcher shop, meat market",
|
| 485 |
+
"468": "cab, hack, taxi, taxicab",
|
| 486 |
+
"469": "caldron, cauldron",
|
| 487 |
+
"470": "candle, taper, wax light",
|
| 488 |
+
"471": "cannon",
|
| 489 |
+
"472": "canoe",
|
| 490 |
+
"473": "can opener, tin opener",
|
| 491 |
+
"474": "cardigan",
|
| 492 |
+
"475": "car mirror",
|
| 493 |
+
"476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
|
| 494 |
+
"477": "carpenter's kit, tool kit",
|
| 495 |
+
"478": "carton",
|
| 496 |
+
"479": "car wheel",
|
| 497 |
+
"480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
|
| 498 |
+
"481": "cassette",
|
| 499 |
+
"482": "cassette player",
|
| 500 |
+
"483": "castle",
|
| 501 |
+
"484": "catamaran",
|
| 502 |
+
"485": "CD player",
|
| 503 |
+
"486": "cello, violoncello",
|
| 504 |
+
"487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
|
| 505 |
+
"488": "chain",
|
| 506 |
+
"489": "chainlink fence",
|
| 507 |
+
"490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
|
| 508 |
+
"491": "chain saw, chainsaw",
|
| 509 |
+
"492": "chest",
|
| 510 |
+
"493": "chiffonier, commode",
|
| 511 |
+
"494": "chime, bell, gong",
|
| 512 |
+
"495": "china cabinet, china closet",
|
| 513 |
+
"496": "Christmas stocking",
|
| 514 |
+
"497": "church, church building",
|
| 515 |
+
"498": "cinema, movie theater, movie theatre, movie house, picture palace",
|
| 516 |
+
"499": "cleaver, meat cleaver, chopper",
|
| 517 |
+
"500": "cliff dwelling",
|
| 518 |
+
"501": "cloak",
|
| 519 |
+
"502": "clog, geta, patten, sabot",
|
| 520 |
+
"503": "cocktail shaker",
|
| 521 |
+
"504": "coffee mug",
|
| 522 |
+
"505": "coffeepot",
|
| 523 |
+
"506": "coil, spiral, volute, whorl, helix",
|
| 524 |
+
"507": "combination lock",
|
| 525 |
+
"508": "computer keyboard, keypad",
|
| 526 |
+
"509": "confectionery, confectionary, candy store",
|
| 527 |
+
"510": "container ship, containership, container vessel",
|
| 528 |
+
"511": "convertible",
|
| 529 |
+
"512": "corkscrew, bottle screw",
|
| 530 |
+
"513": "cornet, horn, trumpet, trump",
|
| 531 |
+
"514": "cowboy boot",
|
| 532 |
+
"515": "cowboy hat, ten-gallon hat",
|
| 533 |
+
"516": "cradle",
|
| 534 |
+
"517": "crane",
|
| 535 |
+
"518": "crash helmet",
|
| 536 |
+
"519": "crate",
|
| 537 |
+
"520": "crib, cot",
|
| 538 |
+
"521": "Crock Pot",
|
| 539 |
+
"522": "croquet ball",
|
| 540 |
+
"523": "crutch",
|
| 541 |
+
"524": "cuirass",
|
| 542 |
+
"525": "dam, dike, dyke",
|
| 543 |
+
"526": "desk",
|
| 544 |
+
"527": "desktop computer",
|
| 545 |
+
"528": "dial telephone, dial phone",
|
| 546 |
+
"529": "diaper, nappy, napkin",
|
| 547 |
+
"530": "digital clock",
|
| 548 |
+
"531": "digital watch",
|
| 549 |
+
"532": "dining table, board",
|
| 550 |
+
"533": "dishrag, dishcloth",
|
| 551 |
+
"534": "dishwasher, dish washer, dishwashing machine",
|
| 552 |
+
"535": "disk brake, disc brake",
|
| 553 |
+
"536": "dock, dockage, docking facility",
|
| 554 |
+
"537": "dogsled, dog sled, dog sleigh",
|
| 555 |
+
"538": "dome",
|
| 556 |
+
"539": "doormat, welcome mat",
|
| 557 |
+
"540": "drilling platform, offshore rig",
|
| 558 |
+
"541": "drum, membranophone, tympan",
|
| 559 |
+
"542": "drumstick",
|
| 560 |
+
"543": "dumbbell",
|
| 561 |
+
"544": "Dutch oven",
|
| 562 |
+
"545": "electric fan, blower",
|
| 563 |
+
"546": "electric guitar",
|
| 564 |
+
"547": "electric locomotive",
|
| 565 |
+
"548": "entertainment center",
|
| 566 |
+
"549": "envelope",
|
| 567 |
+
"550": "espresso maker",
|
| 568 |
+
"551": "face powder",
|
| 569 |
+
"552": "feather boa, boa",
|
| 570 |
+
"553": "file, file cabinet, filing cabinet",
|
| 571 |
+
"554": "fireboat",
|
| 572 |
+
"555": "fire engine, fire truck",
|
| 573 |
+
"556": "fire screen, fireguard",
|
| 574 |
+
"557": "flagpole, flagstaff",
|
| 575 |
+
"558": "flute, transverse flute",
|
| 576 |
+
"559": "folding chair",
|
| 577 |
+
"560": "football helmet",
|
| 578 |
+
"561": "forklift",
|
| 579 |
+
"562": "fountain",
|
| 580 |
+
"563": "fountain pen",
|
| 581 |
+
"564": "four-poster",
|
| 582 |
+
"565": "freight car",
|
| 583 |
+
"566": "French horn, horn",
|
| 584 |
+
"567": "frying pan, frypan, skillet",
|
| 585 |
+
"568": "fur coat",
|
| 586 |
+
"569": "garbage truck, dustcart",
|
| 587 |
+
"570": "gasmask, respirator, gas helmet",
|
| 588 |
+
"571": "gas pump, gasoline pump, petrol pump, island dispenser",
|
| 589 |
+
"572": "goblet",
|
| 590 |
+
"573": "go-kart",
|
| 591 |
+
"574": "golf ball",
|
| 592 |
+
"575": "golfcart, golf cart",
|
| 593 |
+
"576": "gondola",
|
| 594 |
+
"577": "gong, tam-tam",
|
| 595 |
+
"578": "gown",
|
| 596 |
+
"579": "grand piano, grand",
|
| 597 |
+
"580": "greenhouse, nursery, glasshouse",
|
| 598 |
+
"581": "grille, radiator grille",
|
| 599 |
+
"582": "grocery store, grocery, food market, market",
|
| 600 |
+
"583": "guillotine",
|
| 601 |
+
"584": "hair slide",
|
| 602 |
+
"585": "hair spray",
|
| 603 |
+
"586": "half track",
|
| 604 |
+
"587": "hammer",
|
| 605 |
+
"588": "hamper",
|
| 606 |
+
"589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
|
| 607 |
+
"590": "hand-held computer, hand-held microcomputer",
|
| 608 |
+
"591": "handkerchief, hankie, hanky, hankey",
|
| 609 |
+
"592": "hard disc, hard disk, fixed disk",
|
| 610 |
+
"593": "harmonica, mouth organ, harp, mouth harp",
|
| 611 |
+
"594": "harp",
|
| 612 |
+
"595": "harvester, reaper",
|
| 613 |
+
"596": "hatchet",
|
| 614 |
+
"597": "holster",
|
| 615 |
+
"598": "home theater, home theatre",
|
| 616 |
+
"599": "honeycomb",
|
| 617 |
+
"600": "hook, claw",
|
| 618 |
+
"601": "hoopskirt, crinoline",
|
| 619 |
+
"602": "horizontal bar, high bar",
|
| 620 |
+
"603": "horse cart, horse-cart",
|
| 621 |
+
"604": "hourglass",
|
| 622 |
+
"605": "iPod",
|
| 623 |
+
"606": "iron, smoothing iron",
|
| 624 |
+
"607": "jack-o'-lantern",
|
| 625 |
+
"608": "jean, blue jean, denim",
|
| 626 |
+
"609": "jeep, landrover",
|
| 627 |
+
"610": "jersey, T-shirt, tee shirt",
|
| 628 |
+
"611": "jigsaw puzzle",
|
| 629 |
+
"612": "jinrikisha, ricksha, rickshaw",
|
| 630 |
+
"613": "joystick",
|
| 631 |
+
"614": "kimono",
|
| 632 |
+
"615": "knee pad",
|
| 633 |
+
"616": "knot",
|
| 634 |
+
"617": "lab coat, laboratory coat",
|
| 635 |
+
"618": "ladle",
|
| 636 |
+
"619": "lampshade, lamp shade",
|
| 637 |
+
"620": "laptop, laptop computer",
|
| 638 |
+
"621": "lawn mower, mower",
|
| 639 |
+
"622": "lens cap, lens cover",
|
| 640 |
+
"623": "letter opener, paper knife, paperknife",
|
| 641 |
+
"624": "library",
|
| 642 |
+
"625": "lifeboat",
|
| 643 |
+
"626": "lighter, light, igniter, ignitor",
|
| 644 |
+
"627": "limousine, limo",
|
| 645 |
+
"628": "liner, ocean liner",
|
| 646 |
+
"629": "lipstick, lip rouge",
|
| 647 |
+
"630": "Loafer",
|
| 648 |
+
"631": "lotion",
|
| 649 |
+
"632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
|
| 650 |
+
"633": "loupe, jeweler's loupe",
|
| 651 |
+
"634": "lumbermill, sawmill",
|
| 652 |
+
"635": "magnetic compass",
|
| 653 |
+
"636": "mailbag, postbag",
|
| 654 |
+
"637": "mailbox, letter box",
|
| 655 |
+
"638": "maillot",
|
| 656 |
+
"639": "maillot, tank suit",
|
| 657 |
+
"640": "manhole cover",
|
| 658 |
+
"641": "maraca",
|
| 659 |
+
"642": "marimba, xylophone",
|
| 660 |
+
"643": "mask",
|
| 661 |
+
"644": "matchstick",
|
| 662 |
+
"645": "maypole",
|
| 663 |
+
"646": "maze, labyrinth",
|
| 664 |
+
"647": "measuring cup",
|
| 665 |
+
"648": "medicine chest, medicine cabinet",
|
| 666 |
+
"649": "megalith, megalithic structure",
|
| 667 |
+
"650": "microphone, mike",
|
| 668 |
+
"651": "microwave, microwave oven",
|
| 669 |
+
"652": "military uniform",
|
| 670 |
+
"653": "milk can",
|
| 671 |
+
"654": "minibus",
|
| 672 |
+
"655": "miniskirt, mini",
|
| 673 |
+
"656": "minivan",
|
| 674 |
+
"657": "missile",
|
| 675 |
+
"658": "mitten",
|
| 676 |
+
"659": "mixing bowl",
|
| 677 |
+
"660": "mobile home, manufactured home",
|
| 678 |
+
"661": "Model T",
|
| 679 |
+
"662": "modem",
|
| 680 |
+
"663": "monastery",
|
| 681 |
+
"664": "monitor",
|
| 682 |
+
"665": "moped",
|
| 683 |
+
"666": "mortar",
|
| 684 |
+
"667": "mortarboard",
|
| 685 |
+
"668": "mosque",
|
| 686 |
+
"669": "mosquito net",
|
| 687 |
+
"670": "motor scooter, scooter",
|
| 688 |
+
"671": "mountain bike, all-terrain bike, off-roader",
|
| 689 |
+
"672": "mountain tent",
|
| 690 |
+
"673": "mouse, computer mouse",
|
| 691 |
+
"674": "mousetrap",
|
| 692 |
+
"675": "moving van",
|
| 693 |
+
"676": "muzzle",
|
| 694 |
+
"677": "nail",
|
| 695 |
+
"678": "neck brace",
|
| 696 |
+
"679": "necklace",
|
| 697 |
+
"680": "nipple",
|
| 698 |
+
"681": "notebook, notebook computer",
|
| 699 |
+
"682": "obelisk",
|
| 700 |
+
"683": "oboe, hautboy, hautbois",
|
| 701 |
+
"684": "ocarina, sweet potato",
|
| 702 |
+
"685": "odometer, hodometer, mileometer, milometer",
|
| 703 |
+
"686": "oil filter",
|
| 704 |
+
"687": "organ, pipe organ",
|
| 705 |
+
"688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
|
| 706 |
+
"689": "overskirt",
|
| 707 |
+
"690": "oxcart",
|
| 708 |
+
"691": "oxygen mask",
|
| 709 |
+
"692": "packet",
|
| 710 |
+
"693": "paddle, boat paddle",
|
| 711 |
+
"694": "paddlewheel, paddle wheel",
|
| 712 |
+
"695": "padlock",
|
| 713 |
+
"696": "paintbrush",
|
| 714 |
+
"697": "pajama, pyjama, pj's, jammies",
|
| 715 |
+
"698": "palace",
|
| 716 |
+
"699": "panpipe, pandean pipe, syrinx",
|
| 717 |
+
"700": "paper towel",
|
| 718 |
+
"701": "parachute, chute",
|
| 719 |
+
"702": "parallel bars, bars",
|
| 720 |
+
"703": "park bench",
|
| 721 |
+
"704": "parking meter",
|
| 722 |
+
"705": "passenger car, coach, carriage",
|
| 723 |
+
"706": "patio, terrace",
|
| 724 |
+
"707": "pay-phone, pay-station",
|
| 725 |
+
"708": "pedestal, plinth, footstall",
|
| 726 |
+
"709": "pencil box, pencil case",
|
| 727 |
+
"710": "pencil sharpener",
|
| 728 |
+
"711": "perfume, essence",
|
| 729 |
+
"712": "Petri dish",
|
| 730 |
+
"713": "photocopier",
|
| 731 |
+
"714": "pick, plectrum, plectron",
|
| 732 |
+
"715": "pickelhaube",
|
| 733 |
+
"716": "picket fence, paling",
|
| 734 |
+
"717": "pickup, pickup truck",
|
| 735 |
+
"718": "pier",
|
| 736 |
+
"719": "piggy bank, penny bank",
|
| 737 |
+
"720": "pill bottle",
|
| 738 |
+
"721": "pillow",
|
| 739 |
+
"722": "ping-pong ball",
|
| 740 |
+
"723": "pinwheel",
|
| 741 |
+
"724": "pirate, pirate ship",
|
| 742 |
+
"725": "pitcher, ewer",
|
| 743 |
+
"726": "plane, carpenter's plane, woodworking plane",
|
| 744 |
+
"727": "planetarium",
|
| 745 |
+
"728": "plastic bag",
|
| 746 |
+
"729": "plate rack",
|
| 747 |
+
"730": "plow, plough",
|
| 748 |
+
"731": "plunger, plumber's helper",
|
| 749 |
+
"732": "Polaroid camera, Polaroid Land camera",
|
| 750 |
+
"733": "pole",
|
| 751 |
+
"734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
|
| 752 |
+
"735": "poncho",
|
| 753 |
+
"736": "pool table, billiard table, snooker table",
|
| 754 |
+
"737": "pop bottle, soda bottle",
|
| 755 |
+
"738": "pot, flowerpot",
|
| 756 |
+
"739": "potter's wheel",
|
| 757 |
+
"740": "power drill",
|
| 758 |
+
"741": "prayer rug, prayer mat",
|
| 759 |
+
"742": "printer",
|
| 760 |
+
"743": "prison, prison house",
|
| 761 |
+
"744": "projectile, missile",
|
| 762 |
+
"745": "projector",
|
| 763 |
+
"746": "puck, hockey puck",
|
| 764 |
+
"747": "punching bag, punch bag, punching ball, punchball",
|
| 765 |
+
"748": "purse",
|
| 766 |
+
"749": "quill, quill pen",
|
| 767 |
+
"750": "quilt, comforter, comfort, puff",
|
| 768 |
+
"751": "racer, race car, racing car",
|
| 769 |
+
"752": "racket, racquet",
|
| 770 |
+
"753": "radiator",
|
| 771 |
+
"754": "radio, wireless",
|
| 772 |
+
"755": "radio telescope, radio reflector",
|
| 773 |
+
"756": "rain barrel",
|
| 774 |
+
"757": "recreational vehicle, RV, R.V.",
|
| 775 |
+
"758": "reel",
|
| 776 |
+
"759": "reflex camera",
|
| 777 |
+
"760": "refrigerator, icebox",
|
| 778 |
+
"761": "remote control, remote",
|
| 779 |
+
"762": "restaurant, eating house, eating place, eatery",
|
| 780 |
+
"763": "revolver, six-gun, six-shooter",
|
| 781 |
+
"764": "rifle",
|
| 782 |
+
"765": "rocking chair, rocker",
|
| 783 |
+
"766": "rotisserie",
|
| 784 |
+
"767": "rubber eraser, rubber, pencil eraser",
|
| 785 |
+
"768": "rugby ball",
|
| 786 |
+
"769": "rule, ruler",
|
| 787 |
+
"770": "running shoe",
|
| 788 |
+
"771": "safe",
|
| 789 |
+
"772": "safety pin",
|
| 790 |
+
"773": "saltshaker, salt shaker",
|
| 791 |
+
"774": "sandal",
|
| 792 |
+
"775": "sarong",
|
| 793 |
+
"776": "sax, saxophone",
|
| 794 |
+
"777": "scabbard",
|
| 795 |
+
"778": "scale, weighing machine",
|
| 796 |
+
"779": "school bus",
|
| 797 |
+
"780": "schooner",
|
| 798 |
+
"781": "scoreboard",
|
| 799 |
+
"782": "screen, CRT screen",
|
| 800 |
+
"783": "screw",
|
| 801 |
+
"784": "screwdriver",
|
| 802 |
+
"785": "seat belt, seatbelt",
|
| 803 |
+
"786": "sewing machine",
|
| 804 |
+
"787": "shield, buckler",
|
| 805 |
+
"788": "shoe shop, shoe-shop, shoe store",
|
| 806 |
+
"789": "shoji",
|
| 807 |
+
"790": "shopping basket",
|
| 808 |
+
"791": "shopping cart",
|
| 809 |
+
"792": "shovel",
|
| 810 |
+
"793": "shower cap",
|
| 811 |
+
"794": "shower curtain",
|
| 812 |
+
"795": "ski",
|
| 813 |
+
"796": "ski mask",
|
| 814 |
+
"797": "sleeping bag",
|
| 815 |
+
"798": "slide rule, slipstick",
|
| 816 |
+
"799": "sliding door",
|
| 817 |
+
"800": "slot, one-armed bandit",
|
| 818 |
+
"801": "snorkel",
|
| 819 |
+
"802": "snowmobile",
|
| 820 |
+
"803": "snowplow, snowplough",
|
| 821 |
+
"804": "soap dispenser",
|
| 822 |
+
"805": "soccer ball",
|
| 823 |
+
"806": "sock",
|
| 824 |
+
"807": "solar dish, solar collector, solar furnace",
|
| 825 |
+
"808": "sombrero",
|
| 826 |
+
"809": "soup bowl",
|
| 827 |
+
"810": "space bar",
|
| 828 |
+
"811": "space heater",
|
| 829 |
+
"812": "space shuttle",
|
| 830 |
+
"813": "spatula",
|
| 831 |
+
"814": "speedboat",
|
| 832 |
+
"815": "spider web, spider's web",
|
| 833 |
+
"816": "spindle",
|
| 834 |
+
"817": "sports car, sport car",
|
| 835 |
+
"818": "spotlight, spot",
|
| 836 |
+
"819": "stage",
|
| 837 |
+
"820": "steam locomotive",
|
| 838 |
+
"821": "steel arch bridge",
|
| 839 |
+
"822": "steel drum",
|
| 840 |
+
"823": "stethoscope",
|
| 841 |
+
"824": "stole",
|
| 842 |
+
"825": "stone wall",
|
| 843 |
+
"826": "stopwatch, stop watch",
|
| 844 |
+
"827": "stove",
|
| 845 |
+
"828": "strainer",
|
| 846 |
+
"829": "streetcar, tram, tramcar, trolley, trolley car",
|
| 847 |
+
"830": "stretcher",
|
| 848 |
+
"831": "studio couch, day bed",
|
| 849 |
+
"832": "stupa, tope",
|
| 850 |
+
"833": "submarine, pigboat, sub, U-boat",
|
| 851 |
+
"834": "suit, suit of clothes",
|
| 852 |
+
"835": "sundial",
|
| 853 |
+
"836": "sunglass",
|
| 854 |
+
"837": "sunglasses, dark glasses, shades",
|
| 855 |
+
"838": "sunscreen, sunblock, sun blocker",
|
| 856 |
+
"839": "suspension bridge",
|
| 857 |
+
"840": "swab, swob, mop",
|
| 858 |
+
"841": "sweatshirt",
|
| 859 |
+
"842": "swimming trunks, bathing trunks",
|
| 860 |
+
"843": "swing",
|
| 861 |
+
"844": "switch, electric switch, electrical switch",
|
| 862 |
+
"845": "syringe",
|
| 863 |
+
"846": "table lamp",
|
| 864 |
+
"847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
|
| 865 |
+
"848": "tape player",
|
| 866 |
+
"849": "teapot",
|
| 867 |
+
"850": "teddy, teddy bear",
|
| 868 |
+
"851": "television, television system",
|
| 869 |
+
"852": "tennis ball",
|
| 870 |
+
"853": "thatch, thatched roof",
|
| 871 |
+
"854": "theater curtain, theatre curtain",
|
| 872 |
+
"855": "thimble",
|
| 873 |
+
"856": "thresher, thrasher, threshing machine",
|
| 874 |
+
"857": "throne",
|
| 875 |
+
"858": "tile roof",
|
| 876 |
+
"859": "toaster",
|
| 877 |
+
"860": "tobacco shop, tobacconist shop, tobacconist",
|
| 878 |
+
"861": "toilet seat",
|
| 879 |
+
"862": "torch",
|
| 880 |
+
"863": "totem pole",
|
| 881 |
+
"864": "tow truck, tow car, wrecker",
|
| 882 |
+
"865": "toyshop",
|
| 883 |
+
"866": "tractor",
|
| 884 |
+
"867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
|
| 885 |
+
"868": "tray",
|
| 886 |
+
"869": "trench coat",
|
| 887 |
+
"870": "tricycle, trike, velocipede",
|
| 888 |
+
"871": "trimaran",
|
| 889 |
+
"872": "tripod",
|
| 890 |
+
"873": "triumphal arch",
|
| 891 |
+
"874": "trolleybus, trolley coach, trackless trolley",
|
| 892 |
+
"875": "trombone",
|
| 893 |
+
"876": "tub, vat",
|
| 894 |
+
"877": "turnstile",
|
| 895 |
+
"878": "typewriter keyboard",
|
| 896 |
+
"879": "umbrella",
|
| 897 |
+
"880": "unicycle, monocycle",
|
| 898 |
+
"881": "upright, upright piano",
|
| 899 |
+
"882": "vacuum, vacuum cleaner",
|
| 900 |
+
"883": "vase",
|
| 901 |
+
"884": "vault",
|
| 902 |
+
"885": "velvet",
|
| 903 |
+
"886": "vending machine",
|
| 904 |
+
"887": "vestment",
|
| 905 |
+
"888": "viaduct",
|
| 906 |
+
"889": "violin, fiddle",
|
| 907 |
+
"890": "volleyball",
|
| 908 |
+
"891": "waffle iron",
|
| 909 |
+
"892": "wall clock",
|
| 910 |
+
"893": "wallet, billfold, notecase, pocketbook",
|
| 911 |
+
"894": "wardrobe, closet, press",
|
| 912 |
+
"895": "warplane, military plane",
|
| 913 |
+
"896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
|
| 914 |
+
"897": "washer, automatic washer, washing machine",
|
| 915 |
+
"898": "water bottle",
|
| 916 |
+
"899": "water jug",
|
| 917 |
+
"900": "water tower",
|
| 918 |
+
"901": "whiskey jug",
|
| 919 |
+
"902": "whistle",
|
| 920 |
+
"903": "wig",
|
| 921 |
+
"904": "window screen",
|
| 922 |
+
"905": "window shade",
|
| 923 |
+
"906": "Windsor tie",
|
| 924 |
+
"907": "wine bottle",
|
| 925 |
+
"908": "wing",
|
| 926 |
+
"909": "wok",
|
| 927 |
+
"910": "wooden spoon",
|
| 928 |
+
"911": "wool, woolen, woollen",
|
| 929 |
+
"912": "worm fence, snake fence, snake-rail fence, Virginia fence",
|
| 930 |
+
"913": "wreck",
|
| 931 |
+
"914": "yawl",
|
| 932 |
+
"915": "yurt",
|
| 933 |
+
"916": "web site, website, internet site, site",
|
| 934 |
+
"917": "comic book",
|
| 935 |
+
"918": "crossword puzzle, crossword",
|
| 936 |
+
"919": "street sign",
|
| 937 |
+
"920": "traffic light, traffic signal, stoplight",
|
| 938 |
+
"921": "book jacket, dust cover, dust jacket, dust wrapper",
|
| 939 |
+
"922": "menu",
|
| 940 |
+
"923": "plate",
|
| 941 |
+
"924": "guacamole",
|
| 942 |
+
"925": "consomme",
|
| 943 |
+
"926": "hot pot, hotpot",
|
| 944 |
+
"927": "trifle",
|
| 945 |
+
"928": "ice cream, icecream",
|
| 946 |
+
"929": "ice lolly, lolly, lollipop, popsicle",
|
| 947 |
+
"930": "French loaf",
|
| 948 |
+
"931": "bagel, beigel",
|
| 949 |
+
"932": "pretzel",
|
| 950 |
+
"933": "cheeseburger",
|
| 951 |
+
"934": "hotdog, hot dog, red hot",
|
| 952 |
+
"935": "mashed potato",
|
| 953 |
+
"936": "head cabbage",
|
| 954 |
+
"937": "broccoli",
|
| 955 |
+
"938": "cauliflower",
|
| 956 |
+
"939": "zucchini, courgette",
|
| 957 |
+
"940": "spaghetti squash",
|
| 958 |
+
"941": "acorn squash",
|
| 959 |
+
"942": "butternut squash",
|
| 960 |
+
"943": "cucumber, cuke",
|
| 961 |
+
"944": "artichoke, globe artichoke",
|
| 962 |
+
"945": "bell pepper",
|
| 963 |
+
"946": "cardoon",
|
| 964 |
+
"947": "mushroom",
|
| 965 |
+
"948": "Granny Smith",
|
| 966 |
+
"949": "strawberry",
|
| 967 |
+
"950": "orange",
|
| 968 |
+
"951": "lemon",
|
| 969 |
+
"952": "fig",
|
| 970 |
+
"953": "pineapple, ananas",
|
| 971 |
+
"954": "banana",
|
| 972 |
+
"955": "jackfruit, jak, jack",
|
| 973 |
+
"956": "custard apple",
|
| 974 |
+
"957": "pomegranate",
|
| 975 |
+
"958": "hay",
|
| 976 |
+
"959": "carbonara",
|
| 977 |
+
"960": "chocolate sauce, chocolate syrup",
|
| 978 |
+
"961": "dough",
|
| 979 |
+
"962": "meat loaf, meatloaf",
|
| 980 |
+
"963": "pizza, pizza pie",
|
| 981 |
+
"964": "potpie",
|
| 982 |
+
"965": "burrito",
|
| 983 |
+
"966": "red wine",
|
| 984 |
+
"967": "espresso",
|
| 985 |
+
"968": "cup",
|
| 986 |
+
"969": "eggnog",
|
| 987 |
+
"970": "alp",
|
| 988 |
+
"971": "bubble",
|
| 989 |
+
"972": "cliff, drop, drop-off",
|
| 990 |
+
"973": "coral reef",
|
| 991 |
+
"974": "geyser",
|
| 992 |
+
"975": "lakeside, lakeshore",
|
| 993 |
+
"976": "promontory, headland, head, foreland",
|
| 994 |
+
"977": "sandbar, sand bar",
|
| 995 |
+
"978": "seashore, coast, seacoast, sea-coast",
|
| 996 |
+
"979": "valley, vale",
|
| 997 |
+
"980": "volcano",
|
| 998 |
+
"981": "ballplayer, baseball player",
|
| 999 |
+
"982": "groom, bridegroom",
|
| 1000 |
+
"983": "scuba diver",
|
| 1001 |
+
"984": "rapeseed",
|
| 1002 |
+
"985": "daisy",
|
| 1003 |
+
"986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
| 1004 |
+
"987": "corn",
|
| 1005 |
+
"988": "acorn",
|
| 1006 |
+
"989": "hip, rose hip, rosehip",
|
| 1007 |
+
"990": "buckeye, horse chestnut, conker",
|
| 1008 |
+
"991": "coral fungus",
|
| 1009 |
+
"992": "agaric",
|
| 1010 |
+
"993": "gyromitra",
|
| 1011 |
+
"994": "stinkhorn, carrion fungus",
|
| 1012 |
+
"995": "earthstar",
|
| 1013 |
+
"996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
|
| 1014 |
+
"997": "bolete",
|
| 1015 |
+
"998": "ear, spike, capitulum",
|
| 1016 |
+
"999": "toilet tissue, toilet paper, bathroom tissue"
|
| 1017 |
+
}
|
| 1018 |
+
}
|
var/D3HR/DiT-XL/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "DDIMScheduler",
|
| 3 |
+
"_diffusers_version": "0.12.0.dev0",
|
| 4 |
+
"beta_end": 0.02,
|
| 5 |
+
"beta_schedule": "linear",
|
| 6 |
+
"beta_start": 0.0001,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"num_train_timesteps": 1000,
|
| 9 |
+
"prediction_type": "epsilon",
|
| 10 |
+
"set_alpha_to_one": true,
|
| 11 |
+
"steps_offset": 0,
|
| 12 |
+
"trained_betas": null
|
| 13 |
+
}
|
var/D3HR/DiT-XL/transformer/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "Transformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.12.0.dev0",
|
| 4 |
+
"activation_fn": "gelu-approximate",
|
| 5 |
+
"attention_bias": true,
|
| 6 |
+
"attention_head_dim": 72,
|
| 7 |
+
"cross_attention_dim": null,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"in_channels": 4,
|
| 10 |
+
"norm_elementwise_affine": false,
|
| 11 |
+
"norm_num_groups": 32,
|
| 12 |
+
"norm_type": "ada_norm_zero",
|
| 13 |
+
"num_attention_heads": 16,
|
| 14 |
+
"num_embeds_ada_norm": 1000,
|
| 15 |
+
"num_layers": 28,
|
| 16 |
+
"num_vector_embeds": null,
|
| 17 |
+
"only_cross_attention": false,
|
| 18 |
+
"out_channels": 8,
|
| 19 |
+
"patch_size": 2,
|
| 20 |
+
"sample_size": 32,
|
| 21 |
+
"upcast_attention": false,
|
| 22 |
+
"use_linear_projection": false
|
| 23 |
+
}
|
var/D3HR/DiT-XL/transformer/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e592d64df5a579691e65d2b245641a00bb070b652e2c5ca775cce20a729ce9d9
|
| 3 |
+
size 2999533581
|
var/D3HR/DiT-XL/vae/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.12.0.dev0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-ema",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 4,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"norm_num_groups": 32,
|
| 22 |
+
"out_channels": 3,
|
| 23 |
+
"sample_size": 256,
|
| 24 |
+
"up_block_types": [
|
| 25 |
+
"UpDecoderBlock2D",
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
"UpDecoderBlock2D",
|
| 28 |
+
"UpDecoderBlock2D"
|
| 29 |
+
]
|
| 30 |
+
}
|
var/D3HR/DiT-XL/vae/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ec99ed96663e2418dba665762930f9eae8884e6b0a223fd53507931e8446eba
|
| 3 |
+
size 334711857
|
var/D3HR/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Taming Diffusion for Dataset Distillation with High Representativeness (ICML 2025)
|
| 2 |
+
|
| 3 |
+
This repository is the official implementation of the paper:
|
| 4 |
+
|
| 5 |
+
[**Taming Diffusion for Dataset Distillation with High Representativeness**](https://www.arxiv.org/pdf/2505.18399)
|
| 6 |
+
[*Lin Zhao*](https://lin-zhao-resolve.github.io/),
|
| 7 |
+
[*Yushu Wu*](https://wuyushuwys.github.io/),
|
| 8 |
+
[*Xinru Jiang*](https://oshikaka.github.io/),
|
| 9 |
+
[*Jianyang Gu*](https://vimar-gu.github.io/),
|
| 10 |
+
[*Yanzhi Wang*](https://coe.northeastern.edu/people/wang-yanzhi/),
|
| 11 |
+
[*Xiaolin Xu*](https://www.xiaolinxu.com/),
|
| 12 |
+
[*Pu Zhao*](https://puzhao.info/),
|
| 13 |
+
[*Xue Lin*](https://coe.northeastern.edu/people/lin-xue/),
|
| 14 |
+
ICML, 2025.
|
| 15 |
+
|
| 16 |
+
<div align=center>
|
| 17 |
+
<img width=85% src="./imgs/framework.jpg"/>
|
| 18 |
+
</div>
|
| 19 |
+
|
| 20 |
+
## Usage
|
| 21 |
+
|
| 22 |
+
1. [Distilled Datasets](#distilled-datasets)
|
| 23 |
+
2. [Setup](#setup)
|
| 24 |
+
3. [Step1: DDIM inversion and distribution matching](#step1-ddim-inversion-and-distribution-matching)
|
| 25 |
+
4. [Step2: Group sampling](#step2-group-sampling)
|
| 26 |
+
6. [Evaluation](#evaluation)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## Distilled Datasets
|
| 31 |
+
We provide distilled datasets with different IPCs generated by our method on Huggingface🤗! [*Imagenet-1K*](https://www.image-net.org/), [*Tiny-Imagenet*](https://www.kaggle.com/c/tiny-imagenet), [*CIFAR10*](https://www.cs.toronto.edu/~kriz/cifar.html), [*CIFAR100*](https://www.cs.toronto.edu/~kriz/cifar.html) datasets for users to use directly.
|
| 32 |
+
|
| 33 |
+
🔥Distilled datasets for Imagenet-1K: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/imagenet1k_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/imagenet1k_50ipc)
|
| 34 |
+
|
| 35 |
+
🔥Distilled datasets for Tiny-Imagenet: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/tinyimagenet_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/tinyimagenet_50ipc)
|
| 36 |
+
|
| 37 |
+
🔥Distilled datasets for CIFAR10: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar10_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar10_50ipc)
|
| 38 |
+
|
| 39 |
+
🔥Distilled datasets for CIFAR100: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar100_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar100_50ipc)
|
| 40 |
+
|
| 41 |
+
Besides, if you want to use D3HR to generate distilled datasets by yourself, run the following steps:
|
| 42 |
+
|
| 43 |
+
## Setup
|
| 44 |
+
|
| 45 |
+
To install the required dependencies, use the following commands:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
conda create -n D3HR python=3.10
|
| 49 |
+
conda activate D3HR
|
| 50 |
+
cd D3HR
|
| 51 |
+
pip install -e .
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Step1: DDIM inversion and distribution matching
|
| 55 |
+
|
| 56 |
+
### (1) load pretrained model
|
| 57 |
+
For Imagenet-1K dataset, you can just use the pretrained [DiT](https://github.com/facebookresearch/DiT) model in huggingface:
|
| 58 |
+
```bash
|
| 59 |
+
huggingface-cli download facebook/DiT-XL-2-256 --local-dir <your_local_path>
|
| 60 |
+
```
|
| 61 |
+
For other datasets, you must first fine-tune the pretrained DiT model on the dataset ([github repo](https://github.com/facebookresearch/DiT)), then continue.
|
| 62 |
+
|
| 63 |
+
### (2) perform DDIM inversion and distribution matching to obtain the statistic information
|
| 64 |
+
```bash
|
| 65 |
+
sh generation/dit_inversion_save_statistic.sh
|
| 66 |
+
```
|
| 67 |
+
Note: By default, we store the results at 15 timesteps (23 < t < 39) to support the experiments in Section 6.2.
|
| 68 |
+
|
| 69 |
+
## Step2: Group sampling
|
| 70 |
+
```bash
|
| 71 |
+
sh generation/group_sampling.sh
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
## Evaluation
|
| 76 |
+
```bash
|
| 77 |
+
sh validation/validate.sh
|
| 78 |
+
```
|
| 79 |
+
Note: The .sh script includes several configuration options—select the one that best fits your needs.
|
| 80 |
+
|
| 81 |
+
## Acknowledgement
|
| 82 |
+
This project is mainly developed based on:
|
| 83 |
+
[DiT](https://github.com/facebookresearch/DiT)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
## Contact
|
| 87 |
+
If you have any questions, please contact zhao.lin1@northeastern.edu.
|
| 88 |
+
|
| 89 |
+
## Citation
|
| 90 |
+
If you find our work useful, please cite:
|
| 91 |
+
|
| 92 |
+
```BiBTeX
|
| 93 |
+
@inproceedings{zhaotaming,
|
| 94 |
+
title={Taming Diffusion for Dataset Distillation with High Representativeness},
|
| 95 |
+
author={Zhao, Lin and Wu, Yushu and Jiang, Xinru and Gu, Jianyang and Wang, Yanzhi and Xu, Xiaolin and Zhao, Pu and Lin, Xue},
|
| 96 |
+
booktitle={Forty-second International Conference on Machine Learning}
|
| 97 |
+
}
|
| 98 |
+
```
|
var/D3HR/ds_inf/imagenet1k_train.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fe7183ea2495a3d15bf50036db1047ea387bf397c8c6f1a0bcc30f42df957ce
|
| 3 |
+
size 64470500
|
var/D3HR/ds_inf/imagenet_1k_mapping.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
var/D3HR/ds_inf/tiny-imagenet-mapping.txt
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
n02124075 285
|
| 2 |
+
n04067472 758
|
| 3 |
+
n04540053 890
|
| 4 |
+
n04099969 765
|
| 5 |
+
n07749582 951
|
| 6 |
+
n01641577 30
|
| 7 |
+
n02802426 430
|
| 8 |
+
n09246464 972
|
| 9 |
+
n07920052 967
|
| 10 |
+
n03970156 731
|
| 11 |
+
n03891332 704
|
| 12 |
+
n02106662 235
|
| 13 |
+
n03201208 532
|
| 14 |
+
n02279972 323
|
| 15 |
+
n02132136 294
|
| 16 |
+
n04146614 779
|
| 17 |
+
n07873807 963
|
| 18 |
+
n02364673 338
|
| 19 |
+
n04507155 879
|
| 20 |
+
n03854065 687
|
| 21 |
+
n03838899 683
|
| 22 |
+
n03733131 645
|
| 23 |
+
n01443537 1
|
| 24 |
+
n07875152 964
|
| 25 |
+
n03544143 604
|
| 26 |
+
n09428293 978
|
| 27 |
+
n03085013 508
|
| 28 |
+
n02437312 354
|
| 29 |
+
n07614500 928
|
| 30 |
+
n03804744 677
|
| 31 |
+
n04265275 811
|
| 32 |
+
n02963159 474
|
| 33 |
+
n02486410 372
|
| 34 |
+
n01944390 113
|
| 35 |
+
n09256479 973
|
| 36 |
+
n02058221 146
|
| 37 |
+
n04275548 815
|
| 38 |
+
n02321529 329
|
| 39 |
+
n02769748 414
|
| 40 |
+
n02099712 208
|
| 41 |
+
n07695742 932
|
| 42 |
+
n02056570 145
|
| 43 |
+
n02281406 325
|
| 44 |
+
n01774750 76
|
| 45 |
+
n02509815 387
|
| 46 |
+
n03983396 737
|
| 47 |
+
n07753592 954
|
| 48 |
+
n04254777 806
|
| 49 |
+
n02233338 314
|
| 50 |
+
n04008634 744
|
| 51 |
+
n02823428 440
|
| 52 |
+
n02236044 315
|
| 53 |
+
n03393912 565
|
| 54 |
+
n07583066 924
|
| 55 |
+
n04074963 761
|
| 56 |
+
n01629819 25
|
| 57 |
+
n09332890 975
|
| 58 |
+
n02481823 367
|
| 59 |
+
n03902125 707
|
| 60 |
+
n03404251 568
|
| 61 |
+
n09193705 970
|
| 62 |
+
n03637318 619
|
| 63 |
+
n04456115 862
|
| 64 |
+
n02666196 398
|
| 65 |
+
n03796401 675
|
| 66 |
+
n02795169 427
|
| 67 |
+
n02123045 281
|
| 68 |
+
n01855672 99
|
| 69 |
+
n01882714 105
|
| 70 |
+
n02917067 466
|
| 71 |
+
n02988304 485
|
| 72 |
+
n04398044 849
|
| 73 |
+
n02843684 448
|
| 74 |
+
n02423022 353
|
| 75 |
+
n02669723 400
|
| 76 |
+
n04465501 866
|
| 77 |
+
n02165456 301
|
| 78 |
+
n03770439 655
|
| 79 |
+
n02099601 207
|
| 80 |
+
n04486054 873
|
| 81 |
+
n02950826 471
|
| 82 |
+
n03814639 678
|
| 83 |
+
n04259630 808
|
| 84 |
+
n03424325 570
|
| 85 |
+
n02948072 470
|
| 86 |
+
n03179701 526
|
| 87 |
+
n03400231 567
|
| 88 |
+
n02206856 309
|
| 89 |
+
n03160309 525
|
| 90 |
+
n01984695 123
|
| 91 |
+
n03977966 734
|
| 92 |
+
n03584254 605
|
| 93 |
+
n04023962 747
|
| 94 |
+
n02814860 437
|
| 95 |
+
n01910747 107
|
| 96 |
+
n04596742 909
|
| 97 |
+
n03992509 739
|
| 98 |
+
n04133789 774
|
| 99 |
+
n03937543 720
|
| 100 |
+
n02927161 467
|
| 101 |
+
n01945685 114
|
| 102 |
+
n02395406 341
|
| 103 |
+
n02125311 286
|
| 104 |
+
n03126707 517
|
| 105 |
+
n04532106 887
|
| 106 |
+
n02268443 319
|
| 107 |
+
n02977058 480
|
| 108 |
+
n07734744 947
|
| 109 |
+
n03599486 612
|
| 110 |
+
n04562935 900
|
| 111 |
+
n03014705 492
|
| 112 |
+
n04251144 801
|
| 113 |
+
n04356056 837
|
| 114 |
+
n02190166 308
|
| 115 |
+
n03670208 627
|
| 116 |
+
n02002724 128
|
| 117 |
+
n02074367 149
|
| 118 |
+
n04285008 817
|
| 119 |
+
n04560804 899
|
| 120 |
+
n04366367 839
|
| 121 |
+
n02403003 345
|
| 122 |
+
n07615774 929
|
| 123 |
+
n04501370 877
|
| 124 |
+
n03026506 496
|
| 125 |
+
n02906734 462
|
| 126 |
+
n01770393 71
|
| 127 |
+
n04597913 910
|
| 128 |
+
n03930313 716
|
| 129 |
+
n04118538 768
|
| 130 |
+
n04179913 786
|
| 131 |
+
n04311004 821
|
| 132 |
+
n02123394 283
|
| 133 |
+
n04070727 760
|
| 134 |
+
n02793495 425
|
| 135 |
+
n02730930 411
|
| 136 |
+
n02094433 187
|
| 137 |
+
n04371430 842
|
| 138 |
+
n04328186 826
|
| 139 |
+
n03649909 621
|
| 140 |
+
n04417672 853
|
| 141 |
+
n03388043 562
|
| 142 |
+
n01774384 75
|
| 143 |
+
n02837789 445
|
| 144 |
+
n07579787 923
|
| 145 |
+
n04399382 850
|
| 146 |
+
n02791270 424
|
| 147 |
+
n03089624 509
|
| 148 |
+
n02814533 436
|
| 149 |
+
n04149813 781
|
| 150 |
+
n07747607 950
|
| 151 |
+
n03355925 557
|
| 152 |
+
n01983481 122
|
| 153 |
+
n04487081 874
|
| 154 |
+
n03250847 542
|
| 155 |
+
n03255030 543
|
| 156 |
+
n02892201 458
|
| 157 |
+
n02883205 457
|
| 158 |
+
n03100240 511
|
| 159 |
+
n02415577 349
|
| 160 |
+
n02480495 365
|
| 161 |
+
n01698640 50
|
| 162 |
+
n01784675 79
|
| 163 |
+
n04376876 845
|
| 164 |
+
n03444034 573
|
| 165 |
+
n01917289 109
|
| 166 |
+
n01950731 115
|
| 167 |
+
n03042490 500
|
| 168 |
+
n07711569 935
|
| 169 |
+
n04532670 888
|
| 170 |
+
n03763968 652
|
| 171 |
+
n07768694 957
|
| 172 |
+
n02999410 488
|
| 173 |
+
n03617480 614
|
| 174 |
+
n06596364 917
|
| 175 |
+
n01768244 69
|
| 176 |
+
n02410509 347
|
| 177 |
+
n03976657 733
|
| 178 |
+
n01742172 61
|
| 179 |
+
n03980874 735
|
| 180 |
+
n02808440 435
|
| 181 |
+
n02226429 311
|
| 182 |
+
n02231487 313
|
| 183 |
+
n02085620 151
|
| 184 |
+
n01644900 32
|
| 185 |
+
n02129165 291
|
| 186 |
+
n02699494 406
|
| 187 |
+
n03837869 682
|
| 188 |
+
n02815834 438
|
| 189 |
+
n07720875 945
|
| 190 |
+
n02788148 421
|
| 191 |
+
n02909870 463
|
| 192 |
+
n03706229 635
|
| 193 |
+
n07871810 962
|
| 194 |
+
n03447447 576
|
| 195 |
+
n02113799 267
|
| 196 |
+
n12267677 988
|
| 197 |
+
n03662601 625
|
| 198 |
+
n02841315 447
|
| 199 |
+
n07715103 938
|
| 200 |
+
n02504458 386
|
var/D3HR/generation/__init__.py
ADDED
|
File without changes
|
var/D3HR/generation/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
var/D3HR/generation/__pycache__/dit_inversion_save_statistic.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
var/D3HR/generation/dit_inversion_save_statistic.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from matplotlib import pyplot as plt
|
| 7 |
+
from torchvision import transforms as tfms
|
| 8 |
+
from diffusers import StableDiffusionPipeline, DDIMScheduler, DiTPipeline
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
from scipy import io
|
| 12 |
+
from diffusers import DiTPipeline
|
| 13 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 14 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 15 |
+
import ipdb
|
| 16 |
+
from torch.utils.data import Dataset
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from itertools import islice
|
| 19 |
+
import json
|
| 20 |
+
|
| 21 |
+
# sample
@torch.no_grad()
def sample(
    pipe,
    class_labels,
    start_step=0,
    start_latents=None,
    guidance_scale=4.0,
    num_inference_steps=30,
    do_classifier_free_guidance=True,
    device=None,
    generator=None,
):
    """Run the DDIM sampling loop of a DiT pipeline and decode the result.

    Args:
        pipe: a diffusers ``DiTPipeline`` (provides transformer, scheduler, vae).
        class_labels: sequence of ImageNet class ids, one per sample.
        start_step: index of the first denoising step to run.
        start_latents: optional latents to start from; drawn from N(0, I) when None.
        guidance_scale: classifier-free guidance weight (> 1 enables guidance).
        num_inference_steps: number of scheduler timesteps.
        do_classifier_free_guidance: duplicate the batch for cond/uncond passes.
        device: torch device for the label tensors.
        generator: optional ``torch.Generator`` for reproducible initial noise.
            (fix: the original referenced an undefined global ``generator``,
            raising NameError whenever ``start_latents`` was None.)

    Returns:
        numpy float32 array of decoded images in [0, 1], shape (B, H, W, C).
    """
    batch_size = len(class_labels)
    latent_size = pipe.transformer.config.sample_size
    latent_channels = pipe.transformer.config.in_channels

    if start_latents is None:  # fix: was `start_latents == None`
        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=pipe._execution_device,
            dtype=pipe.transformer.dtype,
        )
    else:
        latents = start_latents.clone()

    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

    # Conditional labels followed by the "null" class (1000) for the uncond half.
    class_labels = torch.tensor(class_labels, device=device).reshape(-1)
    class_null = torch.tensor([1000] * batch_size, device=device)
    class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
    class_labels_input = class_labels_input.to(device)

    # set step values
    pipe.scheduler.set_timesteps(num_inference_steps)

    for i in tqdm(range(start_step, num_inference_steps)):

        t = pipe.scheduler.timesteps[i]

        if do_classifier_free_guidance:
            # Keep both halves identical before the transformer pass.
            half = latent_model_input[: len(latent_model_input) // 2]
            latent_model_input = torch.cat([half, half], dim=0)

        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        timesteps = t
        if not torch.is_tensor(timesteps):
            is_mps = latent_model_input.device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(latent_model_input.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(latent_model_input.shape[0])
        # NOTE: the original called scale_model_input a second time here; for the
        # DDIM scheduler it is the identity, so the redundant call was dropped.

        # predict noise model_output
        noise_pred = pipe.transformer(
            latent_model_input, timestep=timesteps, class_labels=class_labels_input
        ).sample

        # Classifier-free guidance on the epsilon channels only (DiT convention).
        if do_classifier_free_guidance:
            eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            noise_pred = torch.cat([eps, rest], dim=1)

        # learned sigma: keep only the epsilon half of the output channels
        if pipe.transformer.config.out_channels // 2 == latent_channels:
            model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
        else:
            model_output = noise_pred

        # Manual DDIM update step: compute x_{t-1} from x_t.
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latent_model_input - (1 - alpha_t).sqrt() * model_output) / alpha_t.sqrt()
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * model_output
        latent_model_input = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt
        # latent_model_input = pipe.scheduler.step(model_output, t, latent_model_input).prev_sample

    if guidance_scale > 1:
        latents, _ = latent_model_input.chunk(2, dim=0)
    else:
        latents = latent_model_input

    # Decode latents to pixel space and map from [-1, 1] to [0, 1].
    latents = 1 / pipe.vae.config.scaling_factor * latents
    samples = pipe.vae.decode(latents).sample
    samples = (samples / 2 + 0.5).clamp(0, 1)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()

    return samples
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Inversion
@torch.no_grad()
def invert(
    pipe,
    start_latents,
    class_labels,
    guidance_scale=4.0,
    num_inference_steps=80,
    do_classifier_free_guidance=True,
    device=None,
):
    """DDIM inversion: run the scheduler update backwards from ``start_latents``.

    Args:
        pipe: a diffusers ``DiTPipeline``.
        start_latents: VAE-encoded (and scaled) latents of real images.
        class_labels: sequence of ImageNet class ids, one per latent.
        guidance_scale: classifier-free guidance weight (> 1 enables guidance).
        num_inference_steps: number of scheduler timesteps.
        do_classifier_free_guidance: duplicate the batch for cond/uncond passes.
        device: torch device for the label tensors.

    Returns:
        Intermediate inverted latents stacked along dim 0, collected for the
        steps 23 < i < 39 (15 timesteps, supporting the Section 6.2 experiments).
    """
    batch_size = len(class_labels)
    latent_channels = pipe.transformer.config.in_channels

    latents = start_latents.clone()
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

    # Conditional labels followed by the "null" class (1000) for the uncond half.
    class_labels = torch.tensor(class_labels, device=device).reshape(-1)
    class_null = torch.tensor([1000] * batch_size, device=device)
    class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
    class_labels_input = class_labels_input.to(device)

    # set step values
    pipe.scheduler.set_timesteps(num_inference_steps)
    intermediate_latents = []

    # Reversed timesteps. fix: `reversed(...)` returns a one-shot iterator that
    # does not support indexing; materialize it so `timesteps_all[i]` works.
    timesteps_all = list(reversed(pipe.scheduler.timesteps))

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # Skip the final iterations (their latents are never stored below).
        if i >= num_inference_steps - 1 - 10:
            continue

        t = timesteps_all[i]

        if do_classifier_free_guidance:
            half = latent_model_input[: len(latent_model_input) // 2]
            latent_model_input = torch.cat([half, half], dim=0)

        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        timesteps = t
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            is_mps = latent_model_input.device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(latent_model_input.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(latent_model_input.shape[0])
        # NOTE: the original called scale_model_input a second time here; for the
        # DDIM scheduler it is the identity, so the redundant call was dropped.

        # predict noise model_output
        noise_pred = pipe.transformer(
            latent_model_input, timestep=timesteps, class_labels=class_labels_input
        ).sample

        # Classifier-free guidance on the epsilon channels only (DiT convention).
        if do_classifier_free_guidance:
            eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            noise_pred = torch.cat([eps, rest], dim=1)

        # learned sigma: keep only the epsilon half of the output channels
        if pipe.transformer.config.out_channels // 2 == latent_channels:
            model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
        else:
            model_output = noise_pred

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps))  # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the DDIM update to get x(t) as a
        # function of x(t-1)).
        latent_model_input = (latent_model_input - (1 - alpha_t).sqrt() * model_output) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * model_output

        if guidance_scale > 1:
            latents_out, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents_out = latent_model_input

        # Store the 15 intermediate steps with 23 < i < 39.
        if i > 23 and i < 39:
            intermediate_latents.append(latents_out)

    return torch.stack(intermediate_latents, dim=0)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def parse_args():
    """Build and parse the command-line options for statistic extraction."""
    parser = argparse.ArgumentParser(
        description="Script to train Stable Diffusion XL for InstructPix2Pix."
    )
    # Where the per-class mean/variance .pt files are written.
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/scratch/zhao.lin1/ddim_inversion_statistic",
        help="statistic save path",
    )
    # wnid -> class-index mapping (JSON for ImageNet-1K, txt for Tiny-ImageNet).
    parser.add_argument(
        "--mapping_file",
        type=str,
        default="ds_inf/imagenet_1k_mapping.json",
    )
    parser.add_argument("--txt_file", default='ds_inf/imagenet1k_train.txt', type=str)
    parser.add_argument("--pretrained_path", default='/scratch/zhao.lin1/DiT-XL-2-256', type=str)
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--num_workers", type=int, default=24)
    # [start, end) slice over the sorted class ids handled by this process.
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=25)
    parser.add_argument("--gpu", type=int, default=1)
    return parser.parse_args()
|
| 292 |
+
|
| 293 |
+
def view_latents(pipe = None, inverted_latents = None):
    """Debug helper: decode the last inverted latent and convert it to PIL."""
    with torch.no_grad():
        decoded = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
        pipe.numpy_to_pil(decoded)[0]
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def collate_fn(batch):
    """Collate per-sample dicts into one batched dict, dropping None samples.

    Returns None when every sample in the batch failed to load.
    """
    valid = [sample for sample in batch if sample is not None]

    if not valid:
        return None

    return {
        'images': torch.stack([s['images'] for s in valid]),
        'labels': torch.tensor([s['labels'] for s in valid]),
        'idx': torch.tensor([s['idx'] for s in valid]),
        'paths': [s['paths'] for s in valid],
    }
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def save_latent(latent, save_path):
    """Persist one latent tensor to ``save_path`` via ``torch.save``."""
    torch.save(latent, save_path)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class ImageNetDataset(Dataset):
    """Dataset over the images of a single ImageNet class listed in a txt file."""

    def __init__(self, txt_file='', mapping_file=None, class_dir=None):
        self.images = []
        self.img_labels = []
        self.class_dir = class_dir
        self.transform = self.get_transforms()

        # Load the wnid -> class-index mapping, then the image list.
        self.wnid_to_index = load_mapping(mapping_file)
        self._load_from_txt(txt_file)

    def _load_from_txt(self, txt_file):
        """Keep only paths whose parent directory equals ``self.class_dir``."""
        with open(txt_file, "r") as handle:
            selected = [line.strip() for line in handle
                        if line.split('/')[-2] == self.class_dir]
        for img_path in selected:
            self.images.append(img_path)
            # The parent directory name is the wnid; map it to its class index.
            self.img_labels.append(self.wnid_to_index[img_path.split('/')[-2]])

    def get_transforms(self):
        """Resize/crop to 256x256 and convert to tensor (no normalization)."""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Fall back to a black image so the batch is not interrupted.
            image = Image.new('RGB', (256, 256))

        if self.transform:
            image = self.transform(image)

        return {
            'images': image,
            'paths': img_path,
            'labels': self.img_labels[idx],
            'idx': idx,
        }
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping from ``mapping_file``.

    Two formats are supported, selected by the file name:
      * paths containing "tiny": whitespace-separated text lines
        ``<wnid> <imagenet_index>``; the returned index is the 0-based line
        number (Tiny-ImageNet label ids).
      * otherwise: a JSON dict whose values carry "wnid" and "index" fields.

    Fix: the original unconditionally ran ``json.load(file)`` first, which
    fails on the plain-text "tiny" mapping and, even if it parsed, would have
    exhausted the file handle before the line loop could read anything.
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            for index, line in enumerate(file):
                # wnid (e.g. n01443537) is the first token; its id is the line number
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
|
| 395 |
+
|
| 396 |
+
def main():
    """Per-class DDIM inversion pipeline.

    For every class in the [start, end) slice: load its training images,
    encode them with the DiT VAE, run DDIM inversion, and save the mean and
    variance of the inverted latents to ``<save_dir>/<wnid>.pt``.

    Fixes: ``torch.cuda.set_device`` was called twice, and the first call
    happened before checking CUDA availability (crashing on CPU-only hosts).
    Also renamed the batch loop variable, which shadowed the module-level
    ``sample`` function.
    """
    args = parse_args()
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)

    # Slice of class ids this process is responsible for.
    wnid_to_index = load_mapping(args.mapping_file)
    class_dirs = sorted(list(wnid_to_index.keys()))[args.start:args.end]

    pipe = DiTPipeline.from_pretrained(args.pretrained_path, torch_dtype=torch.float16)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    os.makedirs(args.save_dir, exist_ok=True)

    pipe = pipe.to(device)

    for class_dir in tqdm(class_dirs):
        imgnet1k_dataset = ImageNetDataset(args.txt_file, args.mapping_file, class_dir)
        trainloader = torch.utils.data.DataLoader(
            imgnet1k_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            drop_last=False,
            collate_fn=collate_fn,
        )
        latents = []
        for batch in tqdm(trainloader):
            with torch.no_grad():
                images = batch['images'].to(device)
                # VAE expects inputs in [-1, 1]; 0.18215 is the SD latent scale.
                latent = pipe.vae.encode(images.to(device, dtype=torch.float16) * 2 - 1)
                ls = 0.18215 * latent.latent_dist.sample()

                inverted_latents = invert(pipe, start_latents=ls, class_labels=batch['labels'], num_inference_steps=50, device=device).cpu()
                # (batch, steps, C, H, W) flattened over the spatial/channel dims.
                latents.append(torch.flatten(inverted_latents.permute(1, 0, 2, 3, 4), start_dim=2))

        latents = torch.cat(latents, dim=0).cpu()
        mean = latents.mean(dim=0)
        variance = latents.var(dim=0)
        torch.save({"mean": mean, "variance": variance}, os.path.join(args.save_dir, class_dir + '.pt'))
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
if __name__ == "__main__":
|
| 437 |
+
main()
|
var/D3HR/generation/dit_inversion_save_statistic.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch DDIM-inversion statistic extraction (dit_inversion_save_statistic.py)
# over disjoint class-id ranges, one background process per GPU.

# Output directory for the per-class mean/variance .pt files.
save_dir='/home/v-qichen3/debug/var/D3HR/ddim_inversion_statistic'
# Local path of the pretrained (or fine-tuned) DiT-XL checkpoint.
pretrained_path='/home/v-qichen3/debug/var/D3HR/DiT-XL'

# the range of class ids
# To improve efficiency, we distribute the generation of different classes across separate GPUs. You can change it to your own setting.
n=0
declare -a gpus=(0 1)
declare -a starts=($n $(($n+100)))
declare -a ends=($(($n+100)) $(($n+200)))

# Start one worker per GPU; each handles classes [start, end).
for i in ${!gpus[@]}; do
    gpu=${gpus[$i]}
    start=${starts[$i]}
    end=${ends[$i]}

    echo "Running on GPU $gpu with start=$start and end=$end"
    # Run in the background so all GPU ranges proceed in parallel.
    python generation/dit_inversion_save_statistic.py --start $start --end $end --gpu $gpu --save_dir $save_dir --pretrained_path $pretrained_path &
done

# waiting for all tasks
wait
echo "All tasks completed."
|
var/D3HR/generation/group_sampling.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append('/home/zhao.lin1/D3HR')
|
| 5 |
+
from generation.dit_inversion_save_statistic import sample
|
| 6 |
+
from diffusers import DiTPipeline, DDIMScheduler
|
| 7 |
+
import json
|
| 8 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 9 |
+
import argparse
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import ipdb
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
def save_average_latents(input_dir, save_dir):
    """Average the per-image latent tensors of each class folder and save them.

    For every class subfolder of `input_dir`, loads all `*.pt` latent tensors,
    averages them elementwise, and writes one `<class>_average_latent.pth`
    file into `save_dir`.

    Fixes vs. original: creates `save_dir` if missing, skips classes that
    contain no `.pt` files (``torch.stack([])`` would raise), and drops an
    unused accumulator dict.

    Args:
        input_dir: directory whose subfolders are class names holding .pt latents.
        save_dir: output directory for the averaged latents.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Iterate over each class folder in input_dir (sorted for determinism).
    for class_name in sorted(os.listdir(input_dir)):
        class_dir = os.path.join(input_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        latents = []
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.endswith('.pt'):
                latents.append(torch.load(os.path.join(class_dir, file_name)))

        if not latents:
            # Nothing to average for this class; skip instead of crashing.
            continue

        average_latent = torch.mean(torch.stack(latents), dim=0)
        save_file_path = os.path.join(save_dir, f'{class_name}_average_latent.pth')
        torch.save(average_latent, save_file_path)
        print(f"Saved average latent for class '{class_name}' to {save_file_path}")
|
| 39 |
+
|
| 40 |
+
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping from either format.

    Two file formats are supported, distinguished by the filename:
      - files whose path contains "tiny": a plain-text file with one wnid per
        line (index = line number);
      - otherwise: a JSON file whose values are objects with "wnid"/"index".

    Fix vs. original: ``json.load(file)`` ran unconditionally, which both
    raised on the plain-text tiny-imagenet file and exhausted the handle
    before the line loop could read it. The JSON parse now happens only in
    the JSON branch.

    Returns:
        dict mapping wnid (str) to class index (int).
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            # Plain text: wnid is the first whitespace-separated token per line.
            for index, line in enumerate(file):
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def process_class(class_dir, folder_path):
    """Compute per-dimension mean/variance statistics for one class folder.

    Loads every ``.pt`` tensor under ``folder_path/class_dir`` onto the CPU,
    flattens each from dim 1 onward, stacks them, and reduces over the stack.

    Returns:
        (class_dir, {"mean": Tensor, "variance": Tensor}) on success, or
        (class_dir, None) when the folder holds no .pt files.
    """
    class_path = os.path.join(folder_path, class_dir)
    flattened = [
        torch.flatten(
            torch.load(os.path.join(class_path, fname), map_location=torch.device('cpu')),
            start_dim=1,
        )
        for fname in os.listdir(class_path)
        if fname.endswith(".pt")
    ]

    if not flattened:
        return class_dir, None

    stacked = torch.stack(flattened, dim=0)
    return class_dir, {"mean": stacked.mean(dim=0), "variance": stacked.var(dim=0)}
|
| 69 |
+
|
| 70 |
+
def process_p_sample(class_dir, folder_path):
    """Load every .pt latent in a class folder as a flat row vector.

    Each tensor is loaded onto the CPU and fully flattened; the rows are
    vertically stacked into a single 2-D tensor (one row per file).
    Raises if the folder holds no .pt files (vstack of an empty list),
    matching the original behavior.
    """
    class_path = os.path.join(folder_path, class_dir)
    rows = []
    for fname in os.listdir(class_path):
        if not fname.endswith(".pt"):
            continue
        loaded = torch.load(os.path.join(class_path, fname), map_location=torch.device('cpu'))
        rows.append(loaded.flatten())

    return torch.vstack(rows)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def kl_divergence(selected_points, mean, std, device):
    """
    Closed-form KL divergence between the target diagonal Gaussian N(mean, std^2)
    and the diagonal Gaussian fitted to `selected_points` (its empirical
    mean/std), computed per dimension and summed.

    Args:
        selected_points: 2-D tensor, samples x features.
        mean, std: per-feature parameters of the target Gaussian.
        device: accepted for interface compatibility; not used here.

    Returns:
        The KL divergence as a Python float.
    """
    fitted_mean = selected_points.mean(dim=0)
    fitted_std = torch.sqrt(selected_points.var(dim=0))

    # Per-dimension univariate Gaussian KL, then summed over features.
    gap = mean - fitted_mean
    per_dim = (
        torch.log(fitted_std / std)
        + (std ** 2 + gap ** 2) / (2 * fitted_std ** 2)
        - 0.5
    )
    return torch.sum(per_dim).item()
|
| 102 |
+
|
| 103 |
+
def kl_divergence_independent_batch(mean, std, samples, device):
    """Per-sample KL-style score of each row of `samples` against N(mean, std^2).

    All tensors are moved to `device`; the result is
    ``sum((x - mean)^2, dim=1) / (2*std^2) + log(std) - 0.5`` per row.

    NOTE(review): the row-wise sum has shape [n_samples] and is divided by
    ``std**2`` via broadcasting — verify std's expected shape (scalar?) at
    the call sites.
    """
    mean = mean.to(device)
    std = std.to(device)
    samples = samples.to(device)

    deviation = samples - mean
    quad_term = torch.sum(deviation ** 2, dim=1) / (2 * std ** 2)
    # The log-variance ratio of matched std terms is log(std/std) = 0,
    # leaving only log(std).
    return quad_term + torch.log(std) - 0.5
|
| 113 |
+
|
| 114 |
+
def sinkhorn(A, B, epsilon=0.1, max_iter=1000, tol=1e-9):
    """
    Estimate the Wasserstein distance using the Sinkhorn algorithm, which supports distributions with different numbers of samples.
    A, B: The two input distributions (2-D tensors, samples x features)
    epsilon: Sinkhorn regularization parameter
    max_iter: Maximum number of iterations
    tol: Convergence tolerance
    """

    # The number of samples in each distribution
    n_a, n_b = A.size(0), B.size(0)

    # Uniform sample weights, normalized to sum to 1.
    weight_a = torch.ones(n_a, device=A.device) / n_a
    weight_b = torch.ones(n_b, device=B.device) / n_b

    # Pairwise cost matrix
    C = torch.cdist(A, B, p=2) ** 2  # Squared Euclidean distance

    # Initialize dual variables
    u = torch.zeros(n_a, device=A.device)
    v = torch.zeros(n_b, device=B.device)

    K = torch.exp(-C / epsilon)  # Regularized (Gibbs) kernel matrix

    for _ in range(max_iter):
        # Update u and v, taking the sample weights into account.
        # NOTE(review): textbook log-domain Sinkhorn uses -C/epsilon inside the
        # logsumexp; here the already-exponentiated kernel K is divided by
        # epsilon again — verify this is intentional.
        u_new = epsilon * torch.log(weight_a) - epsilon * torch.logsumexp(-K / epsilon + v.view(1, -1), dim=1)
        v_new = epsilon * torch.log(weight_b) - epsilon * torch.logsumexp(-K / epsilon + u_new.view(-1, 1), dim=0)

        # Stop once both dual vectors have converged.
        if torch.max(torch.abs(u_new - u)) < tol and torch.max(torch.abs(v_new - v)) < tol:
            break

        u, v = u_new, v_new

    # NOTE(review): the transport cost uses the unscaled kernel K rather than
    # the converged plan diag(exp(u/eps)) K diag(exp(v/eps)) — confirm.
    transport_cost = torch.sum(K * C)
    wasserstein_distance = transport_cost + epsilon * (torch.sum(u * weight_a) + torch.sum(v * weight_b))

    return wasserstein_distance
|
| 154 |
+
|
| 155 |
+
def skewness(tensor):
    """Bias-corrected sample skewness per feature, computed over dim 0."""
    n = tensor.size(0)
    z = (tensor - torch.mean(tensor, dim=0)) / torch.std(tensor, dim=0)
    return torch.sum(z ** 3, dim=0) * (n / ((n - 1) * (n - 2)))
|
| 161 |
+
|
| 162 |
+
def kurtosis(tensor):
    """Bias-corrected sample excess kurtosis per feature, computed over dim 0."""
    n = tensor.size(0)
    z = (tensor - torch.mean(tensor, dim=0)) / torch.std(tensor, dim=0)
    # Standard unbiased-estimator correction term.
    correction = (3 * (n - 1) ** 2) / ((n - 2) * (n - 3))
    return torch.sum(z ** 4, dim=0) * (n * (n + 1)) / ((n - 1) * (n - 2) * (n - 3)) - correction
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def skewness_batch(tensor):
    """Skewness statistic for a batch of candidate sample sets.

    Expects a 3-D tensor [n_trials, n_samples, n_features]; mean and std are
    taken over the sample dimension (dim 1), the cube-sum over features (dim 2),
    yielding a [n_trials, n_samples] result.

    NOTE(review): the bias-correction factor uses n = tensor.size(2) (the
    feature count) while the moments are over dim 1 — confirm this asymmetry
    is intended.
    """
    centered = (
        tensor - torch.mean(tensor, dim=1, keepdim=True)
    ) / torch.std(tensor, dim=1, keepdim=True)

    n = tensor.size(2)
    return torch.sum(centered ** 3, dim=2) * (n / ((n - 1) * (n - 2)))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def evaluate_distribution(samples, mean, std):
    """Score how closely `samples` matches the target Gaussian N(mean, std^2).

    Lower is better. The score combines the L2 gaps of the empirical mean,
    empirical std, and (weighted x10) the skewness statistic.
    """
    sample_mean = torch.mean(samples, dim=0)
    mean_diff = torch.norm(sample_mean - mean)

    sample_std = torch.std(samples, dim=0)
    std_diff = torch.norm(sample_std - std)

    # NOTE(review): skewness_batch reduces over dims 1 and 2, so it requires a
    # 3-D input, while `samples` is treated as 2-D above (reductions over
    # dim 0) — confirm the expected shape before relying on this path.
    sample_skew = skewness_batch(samples)

    # NOTE(review): sample_skew is already a tensor; torch.tensor(...) re-wraps
    # it (copies, and warns on modern torch).
    skew_diff = torch.norm(torch.tensor(sample_skew) - 0)  # Sample Skewness close to 0
    # kurt_diff = torch.norm(torch.tensor(sample_kurt) - 3)  # Sample Kurtosis close to 3

    # Weighted sum; note the skewness weight is 10 here but 0.1 in
    # evaluate_distribution_batch.
    score = mean_diff + std_diff + 10*skew_diff
    return score
|
| 195 |
+
|
| 196 |
+
def select_algorithm(n_trials, n_samples, mean, std, device):
    """Repeated-trial search over Gaussian candidate sets.

    Draws `n_trials` independent sets of `n_samples` vectors from
    N(mean, std^2), scores each with evaluate_distribution (printing every
    score, as the original did), and returns the lowest-scoring set.
    """
    winner = None
    winner_score = float('inf')

    for _ in range(n_trials):
        candidate = torch.normal(
            mean.expand(n_samples, -1), std.expand(n_samples, -1)
        ).to(device)

        candidate_score = evaluate_distribution(candidate, mean, std)
        print(candidate_score)

        # Keep the candidate with the best (lowest) score so far.
        if candidate_score < winner_score:
            winner_score = candidate_score
            winner = candidate

    return winner
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def evaluate_distribution_batch(samples, mean, std):
    """Score every trial in a batch of candidate sample sets at once.

    `samples` is [n_trials, n_samples, n_features]; each trial's score is the
    L2 gap of its empirical mean plus that of its empirical std plus 0.1x the
    L2 norm of its skewness statistic (target skewness of a Gaussian is 0).
    Lower is better; returns a [n_trials] tensor.
    """
    mean_gap = torch.norm(torch.mean(samples, dim=1) - mean, dim=1)
    std_gap = torch.norm(torch.std(samples, dim=1) - std, dim=1)
    skew_gap = torch.norm(skewness_batch(samples), dim=1)

    # Weighted combination; adjust weights as needed.
    return mean_gap + std_gap + 0.1 * skew_gap
|
| 228 |
+
|
| 229 |
+
def select_algorithm_batch(n_trials, n_samples, mean, std, device, seed):
    """Vectorized group sampling: draw and score `n_trials` candidate sets in one batch.

    Seeds the torch RNG (when `seed` is not None) before drawing so results
    are reproducible, samples a [n_trials, n_samples, d] Gaussian batch, scores
    every trial with evaluate_distribution_batch, and returns
    (best_sample, worst_sample, best_score, worst_score) where best = lowest
    score and worst = highest.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # One batched draw: n_trials acts as the batch dimension.
    draws = torch.normal(
        mean.expand(n_trials, n_samples, -1), std.expand(n_trials, n_samples, -1)
    ).to(device)
    scores = evaluate_distribution_batch(draws, mean, std)

    low_score, low_idx = torch.min(scores, dim=0)
    high_score, high_idx = torch.max(scores, dim=0)

    return draws[low_idx], draws[high_idx], low_score, high_score
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def parse_args():
    """Build and parse the command-line arguments for group sampling.

    Returns an argparse.Namespace with mapping/model/output paths, the class
    index range [start, end), the target GPU, and the sampling hyperparameters
    (ipc, start_step, i_step, m).
    """
    parser = argparse.ArgumentParser(
        description="Script to train Stable Diffusion XL for InstructPix2Pix."
    )
    # Paths
    parser.add_argument("--mapping_file", type=str, default="ds_inf/imagenet_1k_mapping.json")
    parser.add_argument("--pretrained_path", default='/scratch/zhao.lin1/DiT-XL-2-256', type=str)
    parser.add_argument("--save_dir", default='/scratch/zhao.lin1/distilled_images/', type=str)
    parser.add_argument("--statistic_path", default='/scratch/zhao.lin1/ddim_inversion_statistic', type=str)
    # Class-range sharding and device
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=1000)
    parser.add_argument("--gpu", type=int, default=0)
    # Sampling hyperparameters
    parser.add_argument("--ipc", type=int, default=20)
    parser.add_argument("--start_step", type=int, default=18)
    parser.add_argument("--i_step", type=int, default=6)
    parser.add_argument("--m", type=int, default=100000)
    return parser.parse_args()
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def main():
    """Group-sampling entry point.

    For each class in the [start, end) shard: load the saved DDIM-inversion
    statistics, search for the ipc-sized latent set whose empirical statistics
    best match them, then decode those latents with a DiT pipeline into
    224x224 PNGs under save_dir/<class>/.
    """
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.cuda.set_device raises on a CPU device, so the CPU
    # fallback above is effectively unsupported — confirm GPU-only is intended.
    torch.cuda.set_device(device)

    # x5_step i=[3, 8, 13, 18, 23, 28, 33, 38, 43, 48] start_step = [45, 40, 35, 30, 25, 20, 15, 10, 5, 0]
    # 10-20_step i=[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39] start_step = [24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9]

    wnid_to_index = load_mapping(args.mapping_file)
    # Only process this worker's [start, end) slice of the sorted class list
    # (classes are sharded across GPUs by group_sampling.sh).
    class_dirs = sorted(list(wnid_to_index.keys()))[args.start:args.end]

    # fp16 DiT pipeline with a DDIM scheduler (matching the inversion setup).
    pipe = DiTPipeline.from_pretrained(args.pretrained_path, torch_dtype=torch.float16)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    for class_dir in tqdm(class_dirs):
        # Per-class target statistics indexed by inversion step i_step.
        statics = torch.load(os.path.join(args.statistic_path, class_dir+'.pt'))
        mean = statics['mean'][args.i_step].to(device)
        variance = statics['variance'][args.i_step].to(device)
        std = torch.sqrt(variance)

        latents_best = None
        latents_worst = None
        best_overall_score = float('inf')    # initialize as +inf
        worst_overall_score = float('-inf')  # initialize as -inf

        # Group sampling: rounds of 10000 batched trials each, re-seeded per
        # round for reproducibility.
        # NOTE(review): args.m//10 rounds x 10000 trials = 1000 * args.m total
        # trials — confirm the intended trial count.
        for i in range(args.m//10):
            seed = i * 12345
            best_sample, worst_sample, best_score, worst_score = select_algorithm_batch(10000, args.ipc, mean, std, device, seed)

            # Track the global best (lowest score) and worst (highest score) sets.
            if best_score < best_overall_score:
                best_overall_score = best_score
                latents_best = best_sample

            if worst_score > worst_overall_score:
                worst_overall_score = worst_score
                latents_worst = worst_sample

        # Output results
        print("Best overall score:", best_overall_score)
        print("Worst overall score:", worst_overall_score)

        # Reshape flat latent rows back to the DiT latent layout (4 x 32 x 32).
        latents_best = latents_best.view(-1,4,32,32)
        # latents_worst = latents_worst.view(-1,4,32,32)

        # Generate images: decode each selected latent from `start_step` onward.
        for k, latent in enumerate(latents_best):
            image = sample(
                pipe,
                class_labels=torch.tensor(wnid_to_index[class_dir]).unsqueeze(0),
                start_latents=latent.unsqueeze(0).to(torch.float16),
                start_step=args.start_step,
                num_inference_steps=50,
                device=device
            )
            os.makedirs(os.path.join(args.save_dir,class_dir.split('/')[-1]), exist_ok=True)
            pipe.numpy_to_pil(image)[0].resize((224, 224), Image.LANCZOS).save(os.path.join(args.save_dir,class_dir.split('/')[-1], str(k)+'.png'))


if __name__ == "__main__":
    main()
|
| 368 |
+
|
var/D3HR/generation/group_sampling.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Shard synthetic-image generation across GPUs: each entry in `gpus` gets a
# contiguous range of 100 class indices [start, end) and runs
# generation/group_sampling.py on that range.
# To improve efficiency, we distribute the generation of different classes
# across separate GPUs. You can change it to your own setting.
n=0
declare -a gpus=(0 1)
declare -a starts=($n $(($n+100)))
declare -a ends=($(($n+100)) $(($n+200)))

for i in ${!gpus[@]}; do
    gpu=${gpus[$i]}
    start=${starts[$i]}
    end=${ends[$i]}

    echo "Running on GPU $gpu with start=$start and end=$end"
    # Launch in the background (&) so all shards run concurrently.
    python generation/group_sampling.py --start $start --end $end --gpu $gpu &
done

# Block until every background shard has finished.
wait
echo "All tasks completed."
|
var/D3HR/imgs/framework.jpg
ADDED
|
Git LFS Details
|
var/D3HR/imgs/framework.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:adefcf20eb3eb575bc93846882e857b3cc38bb57e415c0317b0a9f2c9114d27d
|
| 3 |
+
size 301019
|
var/D3HR/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.6.0
|
| 2 |
+
torchvision==0.21.0
|
| 3 |
+
transformers==4.46.2
|
| 4 |
+
diffusers==0.28.0
|
| 5 |
+
matplotlib==3.10.1
|
| 6 |
+
ipdb==0.13.13
|
| 7 |
+
scipy==1.15.2
|
| 8 |
+
huggingface-hub==0.30.2
|
| 9 |
+
accelerate==1.3.0
|
var/D3HR/validation/__pycache__/argument.cpython-310.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|
var/D3HR/validation/argument.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def str2bool(v):
    """Cast a string (or bool) to a boolean for use as an argparse type.

    Accepts yes/true/t/y/1 (case-insensitive) as True and no/false/f/n/0 as
    False; raises argparse.ArgumentTypeError for anything else.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Command-line configuration for validation: declares all flags, parses them
# once at import time, then derives dataset / batch-size / optimizer settings
# from the parsed values.
parser = argparse.ArgumentParser("EEF")
parser.add_argument("--arch-name", type=str, default="resnet18", help="arch name from pretrained torchvision models")
parser.add_argument("--subset", type=str, default="imagenet-1k")
parser.add_argument("--spec", type=str, default="none")
parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing")
parser.add_argument("--data-dir", nargs='+', default=["../data/imagenet"], help="path to imagenet dataset")
parser.add_argument("--nclass", type=int, default=10, help="number of classes for synthesis or validation")
parser.add_argument("--ipc", type=int, default=10, help="number of images per class for synthesis or validation")
parser.add_argument("--target-ipc", type=int, default=50, help="number of images per class for synthesis or validation")
parser.add_argument("--phase", type=int, default=0)
parser.add_argument("--input-size", default=224, type=int, metavar="S")
parser.add_argument("--save-size", default=224, type=int, metavar="S")
parser.add_argument("--repeat", default=1, type=int, help="Repeat times for the validation")
parser.add_argument("--factor", default=2, type=int)
parser.add_argument("--batch-size", default=0, type=int, metavar="N")
parser.add_argument("--accum-steps", type=int, default=1, help="gradient accumulation steps for small gpu memory")
parser.add_argument("--mix-type", default="cutmix", type=str, choices=["mixup", "cutmix", None], help="mixup or cutmix or None")
parser.add_argument("--stud-name", type=str, default="resnet18", help="arch name from torchvision models")
parser.add_argument("--workers", default=24, type=int, metavar="N", help="number of data loading workers (default: 4)")
parser.add_argument("--temperature", type=float, help="temperature for distillation loss")
parser.add_argument("--min-scale-crops", type=float, default=0.08, help="argument in RandomResizedCrop")
parser.add_argument("--max-scale-crops", type=float, default=1, help="argument in RandomResizedCrop")
parser.add_argument("--epochs", default=300, type=int)
parser.add_argument("--results-dir", type=str, default="results", help="where to store synthetic data")
parser.add_argument("--seed", default=42, type=int, help="seed for initializing training. ")
parser.add_argument("--mixup", type=float, default=0.8, help="mixup alpha, mixup enabled if > 0. (default: 0.8)")
parser.add_argument("--cutmix", type=float, default=1.0, help="cutmix alpha, cutmix enabled if > 0. (default: 1.0)")
parser.add_argument("--cos", default=True, help="cosine lr scheduler")
parser.add_argument("--verbose", type=str2bool, default=False)
parser.add_argument("--mapping_file", default="ds_inf/imagenet_1k_mapping.json", type=str)
parser.add_argument("--txt_file", default='/home/zhao.lin1/DD-DDIM-inversion/ds_inf/imagenet-1k/biggest_20%_ipc_for_all_1k.txt', type=str)
parser.add_argument("--val_txt_file", default='/home/zhao.lin1/CONCORD/val.txt', type=str)

# diffusion
parser.add_argument("--dit-model", default='DiT-XL/2')
parser.add_argument("--ckpt", type=str, default='pretrained_models/DiT-XL-2-256x256.pt',
                    help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
parser.add_argument("--dit-image-size", default=256, type=int)
parser.add_argument("--num-dit-classes", default=1000, type=int)
parser.add_argument("--diffusion-steps", default=1000, type=int)
parser.add_argument("--cfg-scale", type=float, default=4.0)
parser.add_argument("--vae-path", default='stabilityai/sd-vae-ft-ema')

# distillation
parser.add_argument("--save-path", default='./results/test')
parser.add_argument("--description-path", default='./misc/class_description.json')
parser.add_argument("--clip-alpha", type=float, default=10.0)
parser.add_argument("--cls-alpha", type=float, default=10.0)
parser.add_argument("--num-neg-samples", type=int, default=5)
parser.add_argument("--neg-policy", type=str, default="weighted")

# sgd
parser.add_argument("--sgd", default=False, action="store_true", help="sgd optimizer")
parser.add_argument("-lr", "--learning-rate", type=float, default=0.1, help="sgd init learning rate")
parser.add_argument("--momentum", type=float, default=0.9, help="sgd momentum")
parser.add_argument("--weight-decay", type=float, default=1e-4, help="sgd weight decay")

# adamw
parser.add_argument("--adamw-lr", type=float, default=0, help="adamw learning rate")
parser.add_argument("--adamw-weight-decay", type=float, default=0.01, help="adamw weight decay")
parser.add_argument("--exp-name", type=str, help="name of the experiment, subfolder under syn_data_path")
args = parser.parse_args()

# Distillation temperature depends on the mixing augmentation.
if args.mix_type == "mixup":
    args.temperature = 4
elif args.mix_type == "cutmix":
    args.temperature = 20

# Dataset-specific class count / eval ipc / resolution / epoch schedule.
# NOTE(review): this branch tests "imagenet_1k" (underscore) while the
# --subset default is "imagenet-1k" (hyphen), so the default never enters
# this branch — confirm which spelling callers pass.
if args.subset == "imagenet_1k":
    args.nclass = 1000
    args.classes = range(args.nclass)
    args.val_ipc = 50
    args.input_size = 224
elif args.subset == "imagewoof":
    args.nclass = 10
    args.classes = range(args.nclass)
    args.val_ipc = 50
    args.input_size = 224
    if args.ipc == 10:
        args.epochs = 2000
    elif args.ipc == 50:
        args.epochs = 1500
    else:
        args.epochs = 1000
elif args.subset == "cifar10":
    args.nclass = 10
    args.classes = range(args.nclass)
    args.val_ipc = 1000
    args.input_size = 32
    args.epochs = 1000
elif args.subset == "cifar100":
    args.nclass = 100
    args.classes = range(args.nclass)
    args.val_ipc = 100
    args.input_size = 32
    args.epochs = 400
elif args.subset == "tinyimagenet":
    args.nclass = 200
    args.classes = range(args.nclass)
    args.val_ipc = 50
    args.input_size = 64
    args.epochs = 300

# set up batch size (0 is a sentinel meaning "derive from ipc")
if args.batch_size == 0:
    if args.ipc >= 50:
        args.batch_size = 100
    elif args.ipc >= 10:
        args.batch_size = 50
    elif args.ipc > 0:
        args.batch_size = 15
    elif args.ipc == -1:
        args.batch_size = 100

if args.nclass == 10:
    args.batch_size *= 1
if args.nclass == 100:
    args.batch_size *= 2
# if args.nclass == 1000:
#     args.batch_size *= 2

# reset batch size below ipc * nclass
if args.ipc != -1 and args.batch_size > args.ipc * args.nclass:
    args.batch_size = int(args.ipc * args.nclass)

# reset batch size with accum_steps
if args.accum_steps != 1:
    args.batch_size = int(args.batch_size / args.accum_steps)

# result dir for saving
args.exp_name = f"{args.spec}_{args.arch_name}_f{args.factor}_ipc{args.ipc}"
if not os.path.exists(f"./exp/{args.exp_name}"):
    os.makedirs(f"./exp/{args.exp_name}")

# adamw learning rate per student architecture; models not listed here keep
# the value supplied via --adamw-lr (default 0).
_ADAMW_LRS = {
    "vgg11": 0.0005,
    "conv3": 0.001,
    "conv4": 0.001,
    "conv5": 0.001,
    "conv6": 0.001,
    "resnet18": 0.001,
    "resnet18_modified": 0.001,
    "efficientnet_b0": 0.002,
    "mobilenet_v2": 0.0025,
    "alexnet": 0.0001,
    "resnet50": 0.001,
    "resnet50_modified": 0.001,
    "resnet101": 0.001,
    "resnet101_modified": 0.001,
    "vit_b_16": 0.0001,
    "swin_v2_t": 0.0001,
}
if args.stud_name in _ADAMW_LRS:
    args.adamw_lr = _ADAMW_LRS[args.stud_name]
|
var/D3HR/validation/get_train_list.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping from either format.

    Two file formats are supported, distinguished by the filename:
      - files whose path contains "tiny": a plain-text file with one wnid per
        line (index = line number);
      - otherwise: a JSON file whose values are objects with "wnid"/"index".

    Fix vs. original: ``json.load(file)`` ran unconditionally, which both
    raised on the plain-text tiny-imagenet file and exhausted the handle
    before the line loop could read it. The JSON parse now happens only in
    the JSON branch.

    Returns:
        dict mapping wnid (str) to class index (int).
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            # Plain text: wnid is the first whitespace-separated token per line.
            for index, line in enumerate(file):
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Build the synthetic-dataset file list: 20 generated PNGs per ImageNet-1k
# class (matching ipc=20, files named 0.png .. 19.png by the generator),
# written as one absolute path per line into train.txt in the dataset root.
wnid_to_index = load_mapping("ds_inf/imagenet_1k_mapping.json")
class_dirs = sorted(list(wnid_to_index.keys()))
path_list = []
for class_dir in class_dirs:
    for i in range(20):
        path_list.append(os.path.join('/scratch/zhao.lin1/imagenet1k_256_4.0classfree_start_step_18_ddim_inversion_20_min_images_2/', class_dir, str(i)+'.png'))
# Paths are emitted without checking existence; the generation step must have
# produced all 20 images per class first.
output_file = "/scratch/zhao.lin1/imagenet1k_256_4.0classfree_start_step_18_ddim_inversion_20_min_images_2/train.txt"
with open(output_file, "w") as file:
    for path in path_list:
        file.write(path + "\n")
|
var/D3HR/validation/models/__init__.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision.models as thmodels
|
| 4 |
+
|
| 5 |
+
from .convnet import ConvNet
|
| 6 |
+
from .resnet import resnet18, resnet50, resnet101, resnet152
|
| 7 |
+
from .mobilenet_v2 import mobilenetv2
|
| 8 |
+
# import timm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_model(model_name="resnet18", dataset="cifar10", spec='full', pretrained=True, input_size=224, classes=[]):
    """Construct a backbone network by name and optionally load pre-trained weights.

    Args:
        model_name: one of "conv{N}" (ConvNet of depth N), "resnet18/50/101/152",
            "mobilenet_v2", "efficientnet_b0", "{resnet}_modified" (CIFAR-style
            3x3 stem without max-pool), or any torchvision model name.
        dataset: dataset tag; selects which pre-trained checkpoint to load.
        spec: accepted for caller compatibility; not used in this function.
        pretrained: when True, load checkpoint weights for dataset/model_name.
        input_size: square input resolution used to size ConvNet feature maps.
        classes: class indices used to prune the classifier head. The mutable
            default is safe here because the list is only read, never mutated.

    Returns:
        The constructed (and possibly pre-trained / head-pruned) nn.Module.

    Raises:
        AttributeError: if pretrained weights are requested for an
            imagenet_1k model that has no entry in the pre-trained pool.
    """

    def get_model(model_name="resnet18"):
        # Build the raw architecture, always without downloaded weights.
        if "conv" in model_name:
            size = input_size
            # NOTE(review): nclass is hard-coded to 1000 (ImageNet-1k) regardless
            # of `classes`; pruning_classifier below narrows the head afterwards.
            nclass = 1000

            model = ConvNet(
                num_classes=nclass,
                net_norm="batch",
                net_act="relu",
                net_pooling="avgpooling",
                net_depth=int(model_name[-1]),  # e.g. "conv4" -> depth 4
                net_width=128,
                channel=3,
                im_size=(size, size),
            )
        elif model_name == 'resnet18':
            model = resnet18(weights=None)
        elif model_name == 'resnet50':
            model = resnet50(weights=None)
        elif model_name == 'resnet101':
            model = resnet101(weights=None)
        elif model_name == 'resnet152':
            model = resnet152(weights=None)
        elif model_name == 'mobilenet_v2':
            model = mobilenetv2()
        elif model_name == 'efficientnet_b0':
            # NOTE(review): `timm` is referenced but its import is commented out
            # at the top of this module, so this branch raises NameError until
            # `import timm` is restored.
            model = timm.create_model('efficientnet_b0.ra_in1k', pretrained=False)
        elif model_name == "resnet18_modified":
            # CIFAR-style stem: 3x3 conv, stride 1, and no initial max-pool.
            model = thmodels.__dict__["resnet18"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        elif model_name == "resnet50_modified":
            model = thmodels.__dict__["resnet50"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        elif model_name == "resnet101_modified":
            model = thmodels.__dict__["resnet101"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        else:
            # Fall back to any torchvision model by name.
            model = thmodels.__dict__[model_name](weights=None)

        return model

    def pruning_classifier(model=None, classes=[]):
        # Narrow the classifier head by keeping only the weight/bias rows of the
        # last two named parameters (assumed to be the final Linear's W and b).
        try:
            model_named_parameters = [name for name, x in model.named_parameters()]
            for name, x in model.named_parameters():
                if (
                    name == model_named_parameters[-1]
                    or name == model_named_parameters[-2]
                ):
                    x.data = x[classes]
        except Exception:
            # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
            # are no longer swallowed. The best-effort behavior is preserved.
            print("ERROR in changing the number of classes.")

        return model

    model = get_model(model_name)
    model = pruning_classifier(model, classes)

    if pretrained:
        if dataset == 'imagenet_1k':
            if model_name == "efficientnet_b0":
                checkpoint = timm.create_model('efficientnet_b0.ra_in1k', pretrained=True).state_dict()
                model.load_state_dict(checkpoint)
            elif model_name == 'conv4':
                state_dict = torch.load('/home/linz/CONCORD/pretrained_models/imagenet-1k_conv4.pth')
                model.load_state_dict(state_dict['model'])
            elif model_name == 'resnet18':
                # NOTE(review): this replaces the pruned model wholesale with a
                # default-weight torchvision resnet18 (full 1000-way head),
                # discarding the pruning above — confirm this is intended.
                model = resnet18(weights='DEFAULT')
            elif model_name == 'mobilenet_v2':
                model.load_state_dict(torch.load('/home/zhao.lin1/CONCORD/pretrained_models/mobilenetv2_1.0-0c6065bc.pth'))
            else:
                raise AttributeError(f'{model_name} is not supported in the pre-trained pool')
        else:
            # Generic per-dataset checkpoint layout: pretrain_models/<dataset>_<model>.pth
            checkpoint = torch.load(
                f"pretrain_models/{dataset}_{model_name}.pth", map_location="cpu"
            )
            model.load_state_dict(checkpoint["model"])

    return model
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]):
|
| 105 |
+
# def get_model(model_name="resnet18"):
|
| 106 |
+
# if "conv" in model_name:
|
| 107 |
+
# if dataset in ["cifar10", "cifar100"]:
|
| 108 |
+
# size = 32
|
| 109 |
+
# elif dataset == "tinyimagenet":
|
| 110 |
+
# size = 64
|
| 111 |
+
# elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]:
|
| 112 |
+
# size = 128
|
| 113 |
+
# else:
|
| 114 |
+
# size = 224
|
| 115 |
+
|
| 116 |
+
# nclass = len(classes)
|
| 117 |
+
|
| 118 |
+
# model = ConvNet(
|
| 119 |
+
# num_classes=nclass,
|
| 120 |
+
# net_norm="batch",
|
| 121 |
+
# net_act="relu",
|
| 122 |
+
# net_pooling="avgpooling",
|
| 123 |
+
# net_depth=int(model_name[-1]),
|
| 124 |
+
# net_width=128,
|
| 125 |
+
# channel=3,
|
| 126 |
+
# im_size=(size, size),
|
| 127 |
+
# )
|
| 128 |
+
# elif model_name == "resnet18_modified":
|
| 129 |
+
# model = thmodels.__dict__["resnet18"](pretrained=False)
|
| 130 |
+
# model.conv1 = nn.Conv2d(
|
| 131 |
+
# 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
|
| 132 |
+
# )
|
| 133 |
+
# model.maxpool = nn.Identity()
|
| 134 |
+
# elif model_name == "resnet101_modified":
|
| 135 |
+
# model = thmodels.__dict__["resnet101"](pretrained=False)
|
| 136 |
+
# model.conv1 = nn.Conv2d(
|
| 137 |
+
# 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
|
| 138 |
+
# )
|
| 139 |
+
# model.maxpool = nn.Identity()
|
| 140 |
+
# else:
|
| 141 |
+
# model = thmodels.__dict__[model_name](pretrained=False)
|
| 142 |
+
|
| 143 |
+
# return model
|
| 144 |
+
|
| 145 |
+
# def pruning_classifier(model=None, classes=[]):
|
| 146 |
+
# try:
|
| 147 |
+
# model_named_parameters = [name for name, x in model.named_parameters()]
|
| 148 |
+
# for name, x in model.named_parameters():
|
| 149 |
+
# if (
|
| 150 |
+
# name == model_named_parameters[-1]
|
| 151 |
+
# or name == model_named_parameters[-2]
|
| 152 |
+
# ):
|
| 153 |
+
# x.data = x[classes]
|
| 154 |
+
# except:
|
| 155 |
+
# print("ERROR in changing the number of classes.")
|
| 156 |
+
|
| 157 |
+
# return model
|
| 158 |
+
|
| 159 |
+
# # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof"
|
| 160 |
+
# model = get_model(model_name)
|
| 161 |
+
# model = pruning_classifier(model, classes)
|
| 162 |
+
# if pretrained:
|
| 163 |
+
# if dataset in [
|
| 164 |
+
# "imagenet-100",
|
| 165 |
+
# "imagenet-10",
|
| 166 |
+
# "imagenet-nette",
|
| 167 |
+
# "imagenet-woof",
|
| 168 |
+
# "tinyimagenet",
|
| 169 |
+
# "cifar10",
|
| 170 |
+
# "cifar100",
|
| 171 |
+
# ]:
|
| 172 |
+
# checkpoint = torch.load(
|
| 173 |
+
# f"./data/pretrain_models/{dataset}_{model_name}.pth", map_location="cpu"
|
| 174 |
+
# )
|
| 175 |
+
# model.load_state_dict(checkpoint["model"])
|
| 176 |
+
# elif dataset in ["imagenet-1k"]:
|
| 177 |
+
# if model_name == "efficientNet-b0":
|
| 178 |
+
# # Specifically, for loading the pre-trained EfficientNet model, the following modifications are made
|
| 179 |
+
# from torchvision.models._api import WeightsEnum
|
| 180 |
+
# from torch.hub import load_state_dict_from_url
|
| 181 |
+
|
| 182 |
+
# def get_state_dict(self, *args, **kwargs):
|
| 183 |
+
# kwargs.pop("check_hash")
|
| 184 |
+
# return load_state_dict_from_url(self.url, *args, **kwargs)
|
| 185 |
+
|
| 186 |
+
# WeightsEnum.get_state_dict = get_state_dict
|
| 187 |
+
|
| 188 |
+
# model = thmodels.__dict__[model_name](pretrained=True)
|
| 189 |
+
|
| 190 |
+
# return model
|
var/D3HR/validation/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.97 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (2.9 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/convnet.cpython-310.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/convnet.cpython-37.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-310.pyc
ADDED
|
Binary file (4.43 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-37.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/resnet.cpython-310.pyc
ADDED
|
Binary file (2 kB). View file
|
|
|
var/D3HR/validation/models/__pycache__/resnet.cpython-37.pyc
ADDED
|
Binary file (2.33 kB). View file
|
|
|
var/D3HR/validation/models/convnet.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Conv-3 model
|
| 6 |
+
# Conv-3 model
class ConvNet(nn.Module):
    """Simple configurable ConvNet: `net_depth` conv(+norm+act+pool) stages
    followed by a single Linear classifier over the flattened feature map.

    The stages are stored as a ModuleDict of ModuleLists ("conv", "norm",
    "act", "pool") and indexed per depth in forward().
    """
    def __init__(
        self,
        num_classes,
        net_norm="batch",
        net_depth=3,
        net_width=128,
        channel=3,
        net_act="relu",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        # print(f"Define Convnet (depth {net_depth}, width {net_width}, norm {net_norm})")
        super(ConvNet, self).__init__()
        # Shared activation module reused across all depths (stateless, so safe).
        if net_act == "sigmoid":
            self.net_act = nn.Sigmoid()
        elif net_act == "relu":
            self.net_act = nn.ReLU()
        elif net_act == "leakyrelu":
            self.net_act = nn.LeakyReLU(negative_slope=0.01)
        else:
            # NOTE: exit() aborts the whole process on a bad config value.
            exit("unknown activation function: %s" % net_act)

        # Shared 2x down-sampling pooling module, or None for "none".
        if net_pooling == "maxpooling":
            self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            self.net_pooling = None
        else:
            exit("unknown net_pooling: %s" % net_pooling)

        self.depth = net_depth
        self.net_norm = net_norm

        # shape_feat is the (C, H, W) of the final feature map; it sizes the head.
        self.layers, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_pooling, im_size
        )
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x, return_features=False):
        """Run all conv stages and the classifier.

        Returns logits, or (logits, flattened_features) if return_features.
        """
        for d in range(self.depth):
            x = self.layers["conv"][d](x)
            # "norm"/"pool" lists are empty when disabled, so guard by length.
            if len(self.layers["norm"]) > 0:
                x = self.layers["norm"][d](x)
            x = self.layers["act"][d](x)
            if len(self.layers["pool"]) > 0:
                x = self.layers["pool"][d](x)

        # x = nn.functional.avg_pool2d(x, x.shape[-1])
        out = x.view(x.shape[0], -1)
        logit = self.classifier(out)

        if return_features:
            return logit, out
        else:
            return logit

    def get_feature(
        self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False
    ):
        """Collect per-stage feature maps.

        Returns features[idx_from:idx_to+1], optionally paired with class
        probabilities or logits computed from the last feature map.
        NOTE(review): the early `return` inside the loop fires as soon as
        idx_to < number of stages processed so far — confirm this cut-off
        behavior is intended rather than collecting all self.depth stages.
        """
        if idx_to == -1:
            idx_to = idx_from
        features = []

        for d in range(self.depth):
            x = self.layers["conv"][d](x)
            # Here truthiness of the config strings gates the norm/pool branches
            # (any non-"none" string is truthy, including the literal "none").
            if self.net_norm:
                x = self.layers["norm"][d](x)
            x = self.layers["act"][d](x)
            if self.net_pooling:
                x = self.layers["pool"][d](x)
            features.append(x)
            if idx_to < len(features):
                return features[idx_from : idx_to + 1]

        if return_prob:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            prob = torch.softmax(logit, dim=-1)
            return features, prob
        elif return_logit:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            return features, logit
        else:
            return features[idx_from : idx_to + 1]

    def _get_normlayer(self, net_norm, shape_feat):
        """Build the normalization layer for a stage of shape (C, H, W)."""
        # shape_feat = (c * h * w)
        if net_norm == "batch":
            norm = nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layer":
            norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instance":
            # Instance norm expressed as GroupNorm with one group per channel.
            norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "group":
            norm = nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            norm = None
        else:
            norm = None
            exit("unknown net_norm: %s" % net_norm)
        return norm

    def _make_layers(
        self, channel, net_width, net_depth, net_norm, net_pooling, im_size
    ):
        """Build the per-stage module lists and track the feature-map shape.

        Returns (ModuleDict of stage lists, final [C, H, W] shape).
        """
        layers = {"conv": [], "norm": [], "act": [], "pool": []}

        in_channels = channel
        # 28x28 inputs (e.g. MNIST) are treated as 32x32 for shape bookkeeping.
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]

        for d in range(net_depth):
            layers["conv"] += [
                nn.Conv2d(
                    in_channels,
                    net_width,
                    kernel_size=3,
                    # Extra padding on the first layer of 1-channel inputs
                    # keeps the spatial size consistent with 3-channel inputs.
                    padding=3 if channel == 1 and d == 0 else 1,
                )
            ]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers["norm"] += [self._get_normlayer(net_norm, shape_feat)]
            layers["act"] += [self.net_act]
            in_channels = net_width
            if net_pooling != "none":
                layers["pool"] += [self.net_pooling]
                # Each pooling stage halves the spatial resolution.
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        layers["conv"] = nn.ModuleList(layers["conv"])
        layers["norm"] = nn.ModuleList(layers["norm"])
        layers["act"] = nn.ModuleList(layers["act"])
        layers["pool"] = nn.ModuleList(layers["pool"])
        layers = nn.ModuleDict(layers)

        return layers, shape_feat
|
var/D3HR/validation/models/dit_models.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 9 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.jit import Final
|
| 16 |
+
import numpy as np
|
| 17 |
+
import math
|
| 18 |
+
from timm.models.vision_transformer import PatchEmbed, Mlp
|
| 19 |
+
from timm.layers import use_fused_attn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def modulate(x, shift, scale):
    """Apply adaLN affine modulation to token features.

    Broadcasts per-sample (shift, scale) of shape (N, D) over the token
    dimension of x (N, T, D): out = x * (1 + scale) + shift.
    """
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Attention(nn.Module):
    """Multi-head self-attention with optional learnable residual scales.

    Mirrors timm's Attention, with two additions: optional per-channel
    gamma parameters applied to the qkv projection (gamma1) and the output
    projection (gamma2), and a fused/manual attention path selected once at
    construction via timm's use_fused_attn().
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
        use_gamma=False
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(d) attention scaling; only used on the manual path
        # (F.scaled_dot_product_attention applies its own scaling).
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Optional per-head-dim normalization of q and k (QK-norm).
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if use_gamma:
            # Learnable elementwise scales, initialized to the identity (ones).
            self.gamma1 = nn.Parameter(torch.ones(dim * 3))
            self.gamma2 = nn.Parameter(torch.ones(dim))
        else:
            # Plain ints so multiplication is a no-op when gammas are disabled.
            self.gamma1 = 1
            self.gamma2 = 1

    def forward(self, x):
        """Self-attention over x of shape (B, N, C); returns same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (self.gamma1 * self.qkv(x)).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # Fused kernel path; dropout only applied while training.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # Manual softmax(QK^T/sqrt(d))V path.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.gamma2 * self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
#################################################################################
|
| 85 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 86 |
+
#################################################################################
|
| 87 |
+
|
| 88 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal frequency embedding is followed by a small
    Linear-SiLU-Linear MLP projecting to hidden_size.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to 1/max_period.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        # First half cosines, second half sines.
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dims get one zero channel of padding to reach exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        # (N,) timesteps -> (N, hidden_size) conditioning vectors.
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class LabelEmbedder(nn.Module):
    """Maps class labels to embedding vectors.

    Also handles label dropout for classifier-free guidance: when
    dropout_prob > 0, an extra "null" embedding row (index num_classes)
    stands in for dropped labels.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra row serves as the unconditional/null token iff dropout is on.
        has_null_token = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + has_null_token, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is not None:
            # Caller explicitly marks which samples to drop (ids == 1).
            mask = force_drop_ids == 1
        else:
            # Random per-sample drop with probability dropout_prob.
            mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        """Embed (N,) integer labels into (N, hidden_size) vectors."""
        wants_drop = force_drop_ids is not None or (train and self.dropout_prob > 0)
        if wants_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
#################################################################################
|
| 159 |
+
# Core DiT Model #
|
| 160 |
+
#################################################################################
|
| 161 |
+
|
| 162 |
+
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    The conditioning vector c produces six modulation signals (shift/scale/gate
    for both the attention and MLP sub-blocks). Optional gamma parameters add
    learnable elementwise scales on the two residual branches.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, use_gamma=False, **block_kwargs):
        super().__init__()
        # elementwise_affine=False: all scale/shift comes from adaLN modulation.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, use_gamma=use_gamma, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # Produces the six (shift, scale, gate) modulation vectors from c.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        if use_gamma:
            # Learnable residual-branch scales, initialized to identity.
            self.gamma1 = nn.Parameter(torch.ones(hidden_size))
            self.gamma2 = nn.Parameter(torch.ones(hidden_size))
        else:
            # Plain ints make the multiplications no-ops when disabled.
            self.gamma1 = 1
            self.gamma2 = 1

    def forward(self, x, c):
        """x: (N, T, D) tokens; c: (N, D) conditioning. Returns (N, T, D)."""
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # Gated residual attention branch, then gated residual MLP branch.
        x = x + self.gamma1 * gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + self.gamma2 * gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class FinalLayer(nn.Module):
    """Final DiT layer: adaLN-modulated norm followed by a linear projection
    from hidden_size to patch_size^2 * out_channels per token."""
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        # SiLU + Linear producing the (shift, scale) pair from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        """x: (N, T, D) tokens; c: (N, D) conditioning. Returns (N, T, p*p*C)."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Patchifies latent inputs, adds fixed sin-cos positional embeddings,
    conditions every block on (timestep + class label) via adaLN-Zero, and
    unpatchifies back to image space.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # When learning sigma, the model predicts (eps, sigma) -> 2x channels.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # First 14 blocks get learnable gamma residual scales, last 14 do not.
        # NOTE(review): this list has exactly 28 entries; depth > 28 would
        # raise IndexError below — confirm depth is always <= 28.
        use_gamma = [True] * 14 + [False] * 14
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, use_gamma=use_gamma[depth_index]) for depth_index in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Apply DiT's initialization scheme (xavier linears, zeroed adaLN)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        # (adaLN-Zero: each block starts as the identity mapping.)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Assumes a square token grid (h == w), enforced by the assert.
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # (h*p, h*p) is correct because h == w here.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D) combined conditioning
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale, **kwargs):
        """
        Forward pass of DiT, but also batches the unconditional forward pass
        for classifier-free guidance.

        Expects x, t, y already doubled: first half conditional, second half
        unconditional (null labels). Guidance is applied to the eps channels
        only; sigma channels pass through unguided.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        # Standard CFG combination: uncond + scale * (cond - uncond).
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a 2D sine/cosine positional embedding for a square token grid.

    grid_size: int side length of the grid (height == width)
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or
        [extra_tokens + grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)  # w varies fastest
    pos_embed = get_2d_sincos_pos_embed_from_grid(
        embed_dim, mesh.reshape([2, 1, grid_size, grid_size])
    )
    if cls_token and extra_tokens > 0:
        # Prepend all-zero rows for the class / extra tokens.
        pos_embed = np.concatenate(
            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-axis coordinate grid, half the channels per axis."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Each call yields (H*W, D/2); concatenate to (H*W, D).
    return np.concatenate(
        [
            get_1d_sincos_pos_embed_from_grid(half, grid[0]),
            get_1d_sincos_pos_embed_from_grid(half, grid[1]),
        ],
        axis=1,
    )


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Geometric frequency ladder: 1 / 10000^(2i/D).
    freqs = np.arange(embed_dim // 2, dtype=np.float64)
    freqs = 1.0 / 10000 ** (freqs / (embed_dim / 2.0))  # (D/2,)

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
#################################################################################
|
| 393 |
+
# DiT Configs #
|
| 394 |
+
#################################################################################
|
| 395 |
+
|
| 396 |
+
#################################################################################
#                                  DiT Configs                                  #
#################################################################################
# One constructor per variant; names follow the "DiT-<size>/<patch>" convention
# (depth / width / heads fixed per size letter, patch size given by the suffix).

def DiT_XL_2(**kwargs):
    """DiT-XL with 2x2 patches: 28 layers, width 1152, 16 heads."""
    return DiT(num_heads=16, hidden_size=1152, depth=28, patch_size=2, **kwargs)


def DiT_XL_4(**kwargs):
    """DiT-XL with 4x4 patches."""
    return DiT(num_heads=16, hidden_size=1152, depth=28, patch_size=4, **kwargs)


def DiT_XL_8(**kwargs):
    """DiT-XL with 8x8 patches."""
    return DiT(num_heads=16, hidden_size=1152, depth=28, patch_size=8, **kwargs)


def DiT_L_2(**kwargs):
    """DiT-L with 2x2 patches: 24 layers, width 1024, 16 heads."""
    return DiT(num_heads=16, hidden_size=1024, depth=24, patch_size=2, **kwargs)


def DiT_L_4(**kwargs):
    """DiT-L with 4x4 patches."""
    return DiT(num_heads=16, hidden_size=1024, depth=24, patch_size=4, **kwargs)


def DiT_L_8(**kwargs):
    """DiT-L with 8x8 patches."""
    return DiT(num_heads=16, hidden_size=1024, depth=24, patch_size=8, **kwargs)


def DiT_B_2(**kwargs):
    """DiT-B with 2x2 patches: 12 layers, width 768, 12 heads."""
    return DiT(num_heads=12, hidden_size=768, depth=12, patch_size=2, **kwargs)


def DiT_B_4(**kwargs):
    """DiT-B with 4x4 patches."""
    return DiT(num_heads=12, hidden_size=768, depth=12, patch_size=4, **kwargs)


def DiT_B_8(**kwargs):
    """DiT-B with 8x8 patches."""
    return DiT(num_heads=12, hidden_size=768, depth=12, patch_size=8, **kwargs)


def DiT_S_2(**kwargs):
    """DiT-S with 2x2 patches: 12 layers, width 384, 6 heads."""
    return DiT(num_heads=6, hidden_size=384, depth=12, patch_size=2, **kwargs)


def DiT_S_4(**kwargs):
    """DiT-S with 4x4 patches."""
    return DiT(num_heads=6, hidden_size=384, depth=12, patch_size=4, **kwargs)


def DiT_S_8(**kwargs):
    """DiT-S with 8x8 patches."""
    return DiT(num_heads=6, hidden_size=384, depth=12, patch_size=8, **kwargs)


# Registry mapping config names to their constructors.
DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
|
var/D3HR/validation/models/mobilenet_v2.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Creates a MobileNetV2 Model as defined in:
|
| 3 |
+
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. (2018).
|
| 4 |
+
MobileNetV2: Inverted Residuals and Linear Bottlenecks
|
| 5 |
+
arXiv preprint arXiv:1801.04381.
|
| 6 |
+
import from https://github.com/tonylins/pytorch-mobilenet-v2
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
__all__ = ['mobilenetv2']
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _make_divisible(v, divisor, min_value=None):
|
| 16 |
+
"""
|
| 17 |
+
This function is taken from the original tf repo.
|
| 18 |
+
It ensures that all layers have a channel number that is divisible by 8
|
| 19 |
+
It can be seen here:
|
| 20 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
| 21 |
+
:param v:
|
| 22 |
+
:param divisor:
|
| 23 |
+
:param min_value:
|
| 24 |
+
:return:
|
| 25 |
+
"""
|
| 26 |
+
if min_value is None:
|
| 27 |
+
min_value = divisor
|
| 28 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
| 29 |
+
# Make sure that round down does not go down by more than 10%.
|
| 30 |
+
if new_v < 0.9 * v:
|
| 31 |
+
new_v += divisor
|
| 32 |
+
return new_v
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def conv_3x3_bn(inp, oup, stride):
    """3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU6 as one Sequential."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def conv_1x1_bn(inp, oup):
    """1x1 conv (no padding, no bias) -> BatchNorm -> ReLU6 as one Sequential."""
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> project,
    with a residual shortcut when the input shape is preserved."""

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # Shortcut is only valid when neither resolution nor width changes.
        self.identity = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: pointwise expansion
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: depthwise 3x3
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: projection with no activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.identity else out
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone plus linear classifier (Sandler et al., 2018)."""

    def __init__(self, num_classes=1000, width_mult=1.):
        super(MobileNetV2, self).__init__()
        # Inverted-residual stage settings:
        # t = expansion, c = channels, n = repeats, s = stride of first repeat.
        self.cfgs = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # width_mult == 0.1 rounds channels to multiples of 4 instead of 8.
        divisor = 4 if width_mult == 0.1 else 8
        in_ch = _make_divisible(32 * width_mult, divisor)
        # Stem, then the inverted-residual stages.
        stages = [conv_3x3_bn(3, in_ch, 2)]
        for t, c, n, s in self.cfgs:
            out_ch = _make_divisible(c * width_mult, divisor)
            for i in range(n):
                # Only the first block of a stage may downsample.
                stages.append(InvertedResidual(in_ch, out_ch, s if i == 0 else 1, t))
                in_ch = out_ch
        self.features = nn.Sequential(*stages)
        # Final 1x1 conv head; widened only when width_mult > 1.
        head_ch = _make_divisible(1280 * width_mult, divisor) if width_mult > 1.0 else 1280
        self.conv = conv_1x1_bn(in_ch, head_ch)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(head_ch, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.avgpool(self.conv(self.features(x)))
        return self.classifier(x.view(x.size(0), -1))

    def _initialize_weights(self):
        # He-style init for convs, unit-gamma BN, small-normal linear.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
|
| 146 |
+
|
| 147 |
+
def mobilenetv2(**kwargs):
    """
    Constructs a MobileNet V2 model

    Keyword arguments are forwarded to `MobileNetV2` (e.g. `num_classes`,
    `width_mult`).
    """
    return MobileNetV2(**kwargs)
|
var/D3HR/validation/models/pipeline_stable_unclip_img2img.py
ADDED
|
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import PIL.Image
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 21 |
+
|
| 22 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 23 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
| 24 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 25 |
+
from diffusers.models.embeddings import get_timestep_embedding
|
| 26 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 27 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 28 |
+
from diffusers.utils import (
|
| 29 |
+
USE_PEFT_BACKEND,
|
| 30 |
+
deprecate,
|
| 31 |
+
logging,
|
| 32 |
+
replace_example_docstring,
|
| 33 |
+
scale_lora_layers,
|
| 34 |
+
unscale_lora_layers,
|
| 35 |
+
)
|
| 36 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 37 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 38 |
+
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
EXAMPLE_DOC_STRING = """
|
| 44 |
+
Examples:
|
| 45 |
+
```py
|
| 46 |
+
>>> import requests
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from PIL import Image
|
| 49 |
+
>>> from io import BytesIO
|
| 50 |
+
|
| 51 |
+
>>> from diffusers import StableUnCLIPImg2ImgPipeline
|
| 52 |
+
|
| 53 |
+
>>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
| 54 |
+
... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
|
| 55 |
+
... ) # TODO update model path
|
| 56 |
+
>>> pipe = pipe.to("cuda")
|
| 57 |
+
|
| 58 |
+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
| 59 |
+
|
| 60 |
+
>>> response = requests.get(url)
|
| 61 |
+
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 62 |
+
>>> init_image = init_image.resize((768, 512))
|
| 63 |
+
|
| 64 |
+
>>> prompt = "A fantasy landscape, trending on artstation"
|
| 65 |
+
|
| 66 |
+
>>> images = pipe(prompt, init_image).images
|
| 67 |
+
>>> images[0].save("fantasy_landscape.png")
|
| 68 |
+
```
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
| 73 |
+
"""
|
| 74 |
+
Pipeline for text-guided image-to-image generation using stable unCLIP.
|
| 75 |
+
|
| 76 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 77 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 78 |
+
|
| 79 |
+
The pipeline also inherits the following loading methods:
|
| 80 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 81 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 82 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
feature_extractor ([`CLIPImageProcessor`]):
|
| 86 |
+
Feature extractor for image pre-processing before being encoded.
|
| 87 |
+
image_encoder ([`CLIPVisionModelWithProjection`]):
|
| 88 |
+
CLIP vision model for encoding images.
|
| 89 |
+
image_normalizer ([`StableUnCLIPImageNormalizer`]):
|
| 90 |
+
Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
|
| 91 |
+
embeddings after the noise has been applied.
|
| 92 |
+
image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
|
| 93 |
+
Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
|
| 94 |
+
by the `noise_level`.
|
| 95 |
+
tokenizer (`~transformers.CLIPTokenizer`):
|
| 96 |
+
A [`~transformers.CLIPTokenizer`)].
|
| 97 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 98 |
+
Frozen [`~transformers.CLIPTextModel`] text-encoder.
|
| 99 |
+
unet ([`UNet2DConditionModel`]):
|
| 100 |
+
A [`UNet2DConditionModel`] to denoise the encoded image latents.
|
| 101 |
+
scheduler ([`KarrasDiffusionSchedulers`]):
|
| 102 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 103 |
+
vae ([`AutoencoderKL`]):
|
| 104 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 108 |
+
_exclude_from_cpu_offload = ["image_normalizer"]
|
| 109 |
+
|
| 110 |
+
# image encoding components
|
| 111 |
+
feature_extractor: CLIPImageProcessor
|
| 112 |
+
image_encoder: CLIPVisionModelWithProjection
|
| 113 |
+
|
| 114 |
+
# image noising components
|
| 115 |
+
image_normalizer: StableUnCLIPImageNormalizer
|
| 116 |
+
image_noising_scheduler: KarrasDiffusionSchedulers
|
| 117 |
+
|
| 118 |
+
# regular denoising components
|
| 119 |
+
tokenizer: CLIPTokenizer
|
| 120 |
+
text_encoder: CLIPTextModel
|
| 121 |
+
unet: UNet2DConditionModel
|
| 122 |
+
scheduler: KarrasDiffusionSchedulers
|
| 123 |
+
|
| 124 |
+
vae: AutoencoderKL
|
| 125 |
+
|
| 126 |
+
    def __init__(
        self,
        # image encoding components
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        # image noising components
        image_normalizer: StableUnCLIPImageNormalizer,
        image_noising_scheduler: KarrasDiffusionSchedulers,
        # regular denoising components
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        # vae
        vae: AutoencoderKL,
    ):
        """Register all sub-models on the pipeline and build the VAE image
        processor; the registered names match the class-level annotations."""
        super().__init__()

        self.register_modules(
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            image_normalizer=image_normalizer,
            image_noising_scheduler=image_noising_scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            vae=vae,
        )

        # Spatial downsampling factor of the VAE: 2^(number of downsample
        # stages implied by its block_out_channels config).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 158 |
+
|
| 159 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Delegates to the registered AutoencoderKL instance.
        self.vae.enable_slicing()
|
| 166 |
+
|
| 167 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Delegates to the registered AutoencoderKL instance.
        self.vae.disable_slicing()
|
| 174 |
+
|
| 175 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
| 176 |
+
def _encode_prompt(
|
| 177 |
+
self,
|
| 178 |
+
prompt,
|
| 179 |
+
device,
|
| 180 |
+
num_images_per_prompt,
|
| 181 |
+
do_classifier_free_guidance,
|
| 182 |
+
negative_prompt=None,
|
| 183 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 184 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 185 |
+
lora_scale: Optional[float] = None,
|
| 186 |
+
**kwargs,
|
| 187 |
+
):
|
| 188 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| 189 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
| 190 |
+
|
| 191 |
+
prompt_embeds_tuple = self.encode_prompt(
|
| 192 |
+
prompt=prompt,
|
| 193 |
+
device=device,
|
| 194 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 195 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 196 |
+
negative_prompt=negative_prompt,
|
| 197 |
+
prompt_embeds=prompt_embeds,
|
| 198 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 199 |
+
lora_scale=lora_scale,
|
| 200 |
+
**kwargs,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# concatenate for backwards comp
|
| 204 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
| 205 |
+
|
| 206 |
+
return prompt_embeds
|
| 207 |
+
|
| 208 |
+
def _encode_image(
|
| 209 |
+
self,
|
| 210 |
+
image,
|
| 211 |
+
device,
|
| 212 |
+
batch_size,
|
| 213 |
+
num_images_per_prompt,
|
| 214 |
+
do_classifier_free_guidance,
|
| 215 |
+
noise_level,
|
| 216 |
+
generator,
|
| 217 |
+
image_embeds,
|
| 218 |
+
):
|
| 219 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 220 |
+
|
| 221 |
+
if isinstance(image, PIL.Image.Image):
|
| 222 |
+
# the image embedding should repeated so it matches the total batch size of the prompt
|
| 223 |
+
repeat_by = batch_size
|
| 224 |
+
else:
|
| 225 |
+
# assume the image input is already properly batched and just needs to be repeated so
|
| 226 |
+
# it matches the num_images_per_prompt.
|
| 227 |
+
#
|
| 228 |
+
# NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched
|
| 229 |
+
# `image_embeds`. If those happen to be common use cases, let's think harder about
|
| 230 |
+
# what the expected dimensions of inputs should be and how we handle the encoding.
|
| 231 |
+
repeat_by = num_images_per_prompt
|
| 232 |
+
|
| 233 |
+
if image_embeds is None:
|
| 234 |
+
if not isinstance(image, torch.Tensor):
|
| 235 |
+
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
| 236 |
+
|
| 237 |
+
image = image.to(device=device, dtype=dtype)
|
| 238 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 239 |
+
|
| 240 |
+
image_embeds = self.noise_image_embeddings(
|
| 241 |
+
image_embeds=image_embeds,
|
| 242 |
+
noise_level=noise_level,
|
| 243 |
+
generator=generator,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
| 247 |
+
image_embeds = image_embeds.unsqueeze(1)
|
| 248 |
+
bs_embed, seq_len, _ = image_embeds.shape
|
| 249 |
+
image_embeds = image_embeds.repeat(1, repeat_by, 1)
|
| 250 |
+
image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
|
| 251 |
+
image_embeds = image_embeds.squeeze(1)
|
| 252 |
+
|
| 253 |
+
if do_classifier_free_guidance:
|
| 254 |
+
negative_prompt_embeds = torch.zeros_like(image_embeds)
|
| 255 |
+
|
| 256 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 257 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 258 |
+
# to avoid doing two forward passes
|
| 259 |
+
image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
|
| 260 |
+
|
| 261 |
+
return image_embeds
|
| 262 |
+
|
| 263 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
| 264 |
+
def encode_prompt(
|
| 265 |
+
self,
|
| 266 |
+
prompt,
|
| 267 |
+
device,
|
| 268 |
+
num_images_per_prompt,
|
| 269 |
+
do_classifier_free_guidance,
|
| 270 |
+
negative_prompt=None,
|
| 271 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 272 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 273 |
+
lora_scale: Optional[float] = None,
|
| 274 |
+
clip_skip: Optional[int] = None,
|
| 275 |
+
):
|
| 276 |
+
r"""
|
| 277 |
+
Encodes the prompt into text encoder hidden states.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 281 |
+
prompt to be encoded
|
| 282 |
+
device: (`torch.device`):
|
| 283 |
+
torch device
|
| 284 |
+
num_images_per_prompt (`int`):
|
| 285 |
+
number of images that should be generated per prompt
|
| 286 |
+
do_classifier_free_guidance (`bool`):
|
| 287 |
+
whether to use classifier free guidance or not
|
| 288 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 289 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 290 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 291 |
+
less than `1`).
|
| 292 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 293 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 294 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 295 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 296 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 297 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 298 |
+
argument.
|
| 299 |
+
lora_scale (`float`, *optional*):
|
| 300 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 301 |
+
clip_skip (`int`, *optional*):
|
| 302 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 303 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 304 |
+
"""
|
| 305 |
+
# set lora scale so that monkey patched LoRA
|
| 306 |
+
# function of text encoder can correctly access it
|
| 307 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
| 308 |
+
self._lora_scale = lora_scale
|
| 309 |
+
|
| 310 |
+
# dynamically adjust the LoRA scale
|
| 311 |
+
if not USE_PEFT_BACKEND:
|
| 312 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 313 |
+
else:
|
| 314 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 315 |
+
|
| 316 |
+
if prompt is not None and isinstance(prompt, str):
|
| 317 |
+
batch_size = 1
|
| 318 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 319 |
+
batch_size = len(prompt)
|
| 320 |
+
else:
|
| 321 |
+
batch_size = prompt_embeds.shape[0]
|
| 322 |
+
|
| 323 |
+
if prompt_embeds is None:
|
| 324 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
| 325 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 326 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 327 |
+
|
| 328 |
+
text_inputs = self.tokenizer(
|
| 329 |
+
prompt,
|
| 330 |
+
padding="max_length",
|
| 331 |
+
max_length=self.tokenizer.model_max_length,
|
| 332 |
+
truncation=True,
|
| 333 |
+
return_tensors="pt",
|
| 334 |
+
)
|
| 335 |
+
text_input_ids = text_inputs.input_ids
|
| 336 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 337 |
+
|
| 338 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 339 |
+
text_input_ids, untruncated_ids
|
| 340 |
+
):
|
| 341 |
+
removed_text = self.tokenizer.batch_decode(
|
| 342 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 343 |
+
)
|
| 344 |
+
logger.warning(
|
| 345 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 346 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 350 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 351 |
+
else:
|
| 352 |
+
attention_mask = None
|
| 353 |
+
|
| 354 |
+
if clip_skip is None:
|
| 355 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
| 356 |
+
prompt_embeds = prompt_embeds[0]
|
| 357 |
+
else:
|
| 358 |
+
prompt_embeds = self.text_encoder(
|
| 359 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
| 360 |
+
)
|
| 361 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 362 |
+
# all the hidden states from the encoder layers. Then index into
|
| 363 |
+
# the tuple to access the hidden states from the desired layer.
|
| 364 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
| 365 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 366 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 367 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 368 |
+
# layer.
|
| 369 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
| 370 |
+
|
| 371 |
+
if self.text_encoder is not None:
|
| 372 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
| 373 |
+
elif self.unet is not None:
|
| 374 |
+
prompt_embeds_dtype = self.unet.dtype
|
| 375 |
+
else:
|
| 376 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
| 377 |
+
|
| 378 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 379 |
+
|
| 380 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 381 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 382 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 383 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 384 |
+
|
| 385 |
+
# get unconditional embeddings for classifier free guidance
|
| 386 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 387 |
+
uncond_tokens: List[str]
|
| 388 |
+
if negative_prompt is None:
|
| 389 |
+
uncond_tokens = [""] * batch_size
|
| 390 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
| 391 |
+
raise TypeError(
|
| 392 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 393 |
+
f" {type(prompt)}."
|
| 394 |
+
)
|
| 395 |
+
elif isinstance(negative_prompt, str):
|
| 396 |
+
uncond_tokens = [negative_prompt]
|
| 397 |
+
elif batch_size != len(negative_prompt):
|
| 398 |
+
raise ValueError(
|
| 399 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 400 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 401 |
+
" the batch size of `prompt`."
|
| 402 |
+
)
|
| 403 |
+
else:
|
| 404 |
+
uncond_tokens = negative_prompt
|
| 405 |
+
|
| 406 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
| 407 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 408 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
| 409 |
+
|
| 410 |
+
max_length = prompt_embeds.shape[1]
|
| 411 |
+
uncond_input = self.tokenizer(
|
| 412 |
+
uncond_tokens,
|
| 413 |
+
padding="max_length",
|
| 414 |
+
max_length=max_length,
|
| 415 |
+
truncation=True,
|
| 416 |
+
return_tensors="pt",
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 420 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 421 |
+
else:
|
| 422 |
+
attention_mask = None
|
| 423 |
+
|
| 424 |
+
negative_prompt_embeds = self.text_encoder(
|
| 425 |
+
uncond_input.input_ids.to(device),
|
| 426 |
+
attention_mask=attention_mask,
|
| 427 |
+
)
|
| 428 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 429 |
+
|
| 430 |
+
if do_classifier_free_guidance:
|
| 431 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 432 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 433 |
+
|
| 434 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 435 |
+
|
| 436 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 437 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 438 |
+
|
| 439 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 440 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 441 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 442 |
+
|
| 443 |
+
return prompt_embeds, negative_prompt_embeds
|
| 444 |
+
|
| 445 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
| 446 |
+
def decode_latents(self, latents):
|
| 447 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 448 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 449 |
+
|
| 450 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 451 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 452 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 453 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 454 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 455 |
+
return image
|
| 456 |
+
|
| 457 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 458 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 459 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 460 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 461 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 462 |
+
# and should be between [0, 1]
|
| 463 |
+
|
| 464 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 465 |
+
extra_step_kwargs = {}
|
| 466 |
+
if accepts_eta:
|
| 467 |
+
extra_step_kwargs["eta"] = eta
|
| 468 |
+
|
| 469 |
+
# check if the scheduler accepts generator
|
| 470 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 471 |
+
if accepts_generator:
|
| 472 |
+
extra_step_kwargs["generator"] = generator
|
| 473 |
+
return extra_step_kwargs
|
| 474 |
+
|
| 475 |
+
def check_inputs(
|
| 476 |
+
self,
|
| 477 |
+
prompt,
|
| 478 |
+
image,
|
| 479 |
+
height,
|
| 480 |
+
width,
|
| 481 |
+
callback_steps,
|
| 482 |
+
noise_level,
|
| 483 |
+
negative_prompt=None,
|
| 484 |
+
prompt_embeds=None,
|
| 485 |
+
negative_prompt_embeds=None,
|
| 486 |
+
image_embeds=None,
|
| 487 |
+
):
|
| 488 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 489 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 490 |
+
|
| 491 |
+
if (callback_steps is None) or (
|
| 492 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 493 |
+
):
|
| 494 |
+
raise ValueError(
|
| 495 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 496 |
+
f" {type(callback_steps)}."
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if prompt is not None and prompt_embeds is not None:
|
| 500 |
+
raise ValueError(
|
| 501 |
+
"Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if prompt is None and prompt_embeds is None:
|
| 505 |
+
raise ValueError(
|
| 506 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 510 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 511 |
+
|
| 512 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
"Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
if prompt is not None and negative_prompt is not None:
|
| 518 |
+
if type(prompt) is not type(negative_prompt):
|
| 519 |
+
raise TypeError(
|
| 520 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 521 |
+
f" {type(prompt)}."
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 525 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 526 |
+
raise ValueError(
|
| 527 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 528 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 529 |
+
f" {negative_prompt_embeds.shape}."
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
|
| 533 |
+
raise ValueError(
|
| 534 |
+
f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
if image is not None and image_embeds is not None:
|
| 538 |
+
raise ValueError(
|
| 539 |
+
"Provide either `image` or `image_embeds`. Please make sure to define only one of the two."
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
if image is None and image_embeds is None:
|
| 543 |
+
raise ValueError(
|
| 544 |
+
"Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
if image is not None:
|
| 548 |
+
if (
|
| 549 |
+
not isinstance(image, torch.Tensor)
|
| 550 |
+
and not isinstance(image, PIL.Image.Image)
|
| 551 |
+
and not isinstance(image, list)
|
| 552 |
+
):
|
| 553 |
+
raise ValueError(
|
| 554 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 555 |
+
f" {type(image)}"
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 559 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 560 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 561 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 562 |
+
raise ValueError(
|
| 563 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 564 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
if latents is None:
|
| 568 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 569 |
+
else:
|
| 570 |
+
latents = latents.to(device)
|
| 571 |
+
|
| 572 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 573 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 574 |
+
return latents
|
| 575 |
+
|
| 576 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
|
| 577 |
+
def noise_image_embeddings(
|
| 578 |
+
self,
|
| 579 |
+
image_embeds: torch.Tensor,
|
| 580 |
+
noise_level: int,
|
| 581 |
+
noise: Optional[torch.FloatTensor] = None,
|
| 582 |
+
generator: Optional[torch.Generator] = None,
|
| 583 |
+
):
|
| 584 |
+
"""
|
| 585 |
+
Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
|
| 586 |
+
`noise_level` increases the variance in the final un-noised images.
|
| 587 |
+
|
| 588 |
+
The noise is applied in two ways:
|
| 589 |
+
1. A noise schedule is applied directly to the embeddings.
|
| 590 |
+
2. A vector of sinusoidal time embeddings are appended to the output.
|
| 591 |
+
|
| 592 |
+
In both cases, the amount of noise is controlled by the same `noise_level`.
|
| 593 |
+
|
| 594 |
+
The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
|
| 595 |
+
"""
|
| 596 |
+
if noise is None:
|
| 597 |
+
noise = randn_tensor(
|
| 598 |
+
image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
|
| 602 |
+
|
| 603 |
+
self.image_normalizer.to(image_embeds.device)
|
| 604 |
+
image_embeds = self.image_normalizer.scale(image_embeds)
|
| 605 |
+
|
| 606 |
+
image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
|
| 607 |
+
|
| 608 |
+
image_embeds = self.image_normalizer.unscale(image_embeds)
|
| 609 |
+
|
| 610 |
+
noise_level = get_timestep_embedding(
|
| 611 |
+
timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
|
| 615 |
+
# but we might actually be running in fp16. so we need to cast here.
|
| 616 |
+
# there might be better ways to encapsulate this.
|
| 617 |
+
noise_level = noise_level.to(image_embeds.dtype)
|
| 618 |
+
|
| 619 |
+
image_embeds = torch.cat((image_embeds, noise_level), 1)
|
| 620 |
+
|
| 621 |
+
return image_embeds
|
| 622 |
+
|
| 623 |
+
@torch.no_grad()
|
| 624 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 625 |
+
def __call__(
|
| 626 |
+
self,
|
| 627 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
| 628 |
+
prompt: Union[str, List[str]] = None,
|
| 629 |
+
height: Optional[int] = None,
|
| 630 |
+
width: Optional[int] = None,
|
| 631 |
+
num_inference_steps: int = 20,
|
| 632 |
+
guidance_scale: float = 10,
|
| 633 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 634 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 635 |
+
eta: float = 0.0,
|
| 636 |
+
generator: Optional[torch.Generator] = None,
|
| 637 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 638 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 639 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 640 |
+
output_type: Optional[str] = "pil",
|
| 641 |
+
return_dict: bool = True,
|
| 642 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 643 |
+
callback_steps: int = 1,
|
| 644 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 645 |
+
noise_level: int = 0,
|
| 646 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
| 647 |
+
clip_skip: Optional[int] = None,
|
| 648 |
+
cond_fn = None,
|
| 649 |
+
):
|
| 650 |
+
r"""
|
| 651 |
+
The call function to the pipeline for generation.
|
| 652 |
+
|
| 653 |
+
Args:
|
| 654 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 655 |
+
The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
|
| 656 |
+
used or prompt is initialized to `""`.
|
| 657 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 658 |
+
`Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the
|
| 659 |
+
`unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the
|
| 660 |
+
denoising process like it is in the standard Stable Diffusion text-guided image variation process.
|
| 661 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 662 |
+
The height in pixels of the generated image.
|
| 663 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 664 |
+
The width in pixels of the generated image.
|
| 665 |
+
num_inference_steps (`int`, *optional*, defaults to 20):
|
| 666 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 667 |
+
expense of slower inference.
|
| 668 |
+
guidance_scale (`float`, *optional*, defaults to 10.0):
|
| 669 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 670 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 671 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 672 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 673 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 674 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 675 |
+
The number of images to generate per prompt.
|
| 676 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 677 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 678 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 679 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 680 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 681 |
+
generation deterministic.
|
| 682 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 683 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 684 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 685 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 686 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 687 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 688 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 689 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 690 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 691 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 692 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 693 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 694 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 695 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 696 |
+
callback (`Callable`, *optional*):
|
| 697 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 698 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 699 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 700 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 701 |
+
every step.
|
| 702 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 703 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 704 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 705 |
+
noise_level (`int`, *optional*, defaults to `0`):
|
| 706 |
+
The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
|
| 707 |
+
the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
|
| 708 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
| 709 |
+
Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising
|
| 710 |
+
process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`.
|
| 711 |
+
clip_skip (`int`, *optional*):
|
| 712 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 713 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 714 |
+
|
| 715 |
+
Examples:
|
| 716 |
+
|
| 717 |
+
Returns:
|
| 718 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
| 719 |
+
[`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
|
| 720 |
+
a tuple, the first element is a list with the generated images.
|
| 721 |
+
"""
|
| 722 |
+
# 0. Default height and width to unet
|
| 723 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 724 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 725 |
+
|
| 726 |
+
if prompt is None and prompt_embeds is None:
|
| 727 |
+
prompt = len(image) * [""] if isinstance(image, list) else ""
|
| 728 |
+
|
| 729 |
+
# 1. Check inputs. Raise error if not correct
|
| 730 |
+
self.check_inputs(
|
| 731 |
+
prompt=prompt,
|
| 732 |
+
image=image,
|
| 733 |
+
height=height,
|
| 734 |
+
width=width,
|
| 735 |
+
callback_steps=callback_steps,
|
| 736 |
+
noise_level=noise_level,
|
| 737 |
+
negative_prompt=negative_prompt,
|
| 738 |
+
prompt_embeds=prompt_embeds,
|
| 739 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 740 |
+
image_embeds=image_embeds,
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
# 2. Define call parameters
|
| 744 |
+
if prompt is not None and isinstance(prompt, str):
|
| 745 |
+
batch_size = 1
|
| 746 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 747 |
+
batch_size = len(prompt)
|
| 748 |
+
else:
|
| 749 |
+
batch_size = prompt_embeds.shape[0]
|
| 750 |
+
|
| 751 |
+
batch_size = batch_size * num_images_per_prompt
|
| 752 |
+
|
| 753 |
+
device = self._execution_device
|
| 754 |
+
|
| 755 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 756 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 757 |
+
# corresponds to doing no classifier free guidance.
|
| 758 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 759 |
+
|
| 760 |
+
# 3. Encode input prompt
|
| 761 |
+
text_encoder_lora_scale = (
|
| 762 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
| 763 |
+
)
|
| 764 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 765 |
+
prompt=prompt,
|
| 766 |
+
device=device,
|
| 767 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 768 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 769 |
+
negative_prompt=negative_prompt,
|
| 770 |
+
prompt_embeds=prompt_embeds,
|
| 771 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 772 |
+
lora_scale=text_encoder_lora_scale,
|
| 773 |
+
)
|
| 774 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 775 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 776 |
+
# to avoid doing two forward passes
|
| 777 |
+
if do_classifier_free_guidance:
|
| 778 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 779 |
+
|
| 780 |
+
# 4. Encoder input image
|
| 781 |
+
noise_level = torch.tensor([noise_level], device=device)
|
| 782 |
+
image_embeds = self._encode_image(
|
| 783 |
+
image=image,
|
| 784 |
+
device=device,
|
| 785 |
+
batch_size=batch_size,
|
| 786 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 787 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 788 |
+
noise_level=noise_level,
|
| 789 |
+
generator=generator,
|
| 790 |
+
image_embeds=image_embeds,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# 5. Prepare timesteps
|
| 794 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 795 |
+
timesteps = self.scheduler.timesteps
|
| 796 |
+
|
| 797 |
+
# 6. Prepare latent variables
|
| 798 |
+
num_channels_latents = self.unet.config.in_channels
|
| 799 |
+
latents = self.prepare_latents(
|
| 800 |
+
batch_size=batch_size,
|
| 801 |
+
num_channels_latents=num_channels_latents,
|
| 802 |
+
height=height,
|
| 803 |
+
width=width,
|
| 804 |
+
dtype=prompt_embeds.dtype,
|
| 805 |
+
device=device,
|
| 806 |
+
generator=generator,
|
| 807 |
+
latents=latents,
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 811 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 812 |
+
|
| 813 |
+
# 8. Denoising loop
|
| 814 |
+
for i, t in enumerate(timesteps):
|
| 815 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 816 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 817 |
+
|
| 818 |
+
# predict the noise residual
|
| 819 |
+
noise_pred = self.unet(
|
| 820 |
+
latent_model_input,
|
| 821 |
+
t,
|
| 822 |
+
encoder_hidden_states=prompt_embeds,
|
| 823 |
+
class_labels=image_embeds,
|
| 824 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 825 |
+
return_dict=False,
|
| 826 |
+
)[0]
|
| 827 |
+
|
| 828 |
+
# perform guidance
|
| 829 |
+
if do_classifier_free_guidance:
|
| 830 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 831 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 832 |
+
|
| 833 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 834 |
+
latents = self.scheduler.step(noise_pred, t, latents, cond_fn, **extra_step_kwargs, return_dict=False)[0]
|
| 835 |
+
|
| 836 |
+
if callback is not None and i % callback_steps == 0:
|
| 837 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 838 |
+
callback(step_idx, t, latents)
|
| 839 |
+
|
| 840 |
+
# 9. Post-processing
|
| 841 |
+
if not output_type == "latent":
|
| 842 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 843 |
+
else:
|
| 844 |
+
image = latents
|
| 845 |
+
|
| 846 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 847 |
+
|
| 848 |
+
# Offload all models
|
| 849 |
+
self.maybe_free_model_hooks()
|
| 850 |
+
|
| 851 |
+
if not return_dict:
|
| 852 |
+
return (image,)
|
| 853 |
+
|
| 854 |
+
return ImagePipelineOutput(images=image)
|
var/D3HR/validation/models/resnet.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision.models.resnet import (
|
| 4 |
+
ResNet, ResNet18_Weights, ResNet50_Weights, ResNet101_Weights,
|
| 5 |
+
BasicBlock, Bottleneck,
|
| 6 |
+
_ovewrite_named_param
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FeatResNet(ResNet):
    """A `torchvision` ResNet that can also expose its penultimate features.

    Identical to `ResNet` except for `get_features`, which returns the
    pooled, flattened activations just before the final FC layer.
    """

    def __init__(self, block, layers, **kwargs):
        super(FeatResNet, self).__init__(block, layers, **kwargs)

    def get_features(self, x):
        """Return the flattened global-average-pooled features of `x`.

        Runs the standard ResNet trunk (stem + four residual stages +
        avgpool) but stops before `self.fc`.
        """
        # Stem: conv -> BN -> ReLU -> max pool.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # The four residual stages, applied in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # Global average pool, then flatten to (batch, features).
        return torch.flatten(self.avgpool(out), 1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def resnet18(*, weights=None, progress=True, **kwargs):
    """Build a ResNet-18 `FeatResNet`, optionally loading pretrained weights.

    Args:
        weights: optional torchvision weights spec; when given, `num_classes`
            is overridden to match the weights' category list.
        progress: show a download progress bar while fetching weights.
        **kwargs: forwarded to the `ResNet` constructor.

    Returns:
        The constructed `FeatResNet` (BasicBlock, [2, 2, 2, 2]).
    """
    resolved = ResNet18_Weights.verify(weights)
    if resolved is not None:
        _ovewrite_named_param(kwargs, 'num_classes', len(resolved.meta['categories']))

    net = FeatResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if resolved is not None:
        net.load_state_dict(resolved.get_state_dict(progress=progress))
    return net
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def resnet50(*, weights=None, progress=True, **kwargs):
    """Build a ResNet-50 `FeatResNet`, optionally loading pretrained weights.

    Args:
        weights: optional torchvision weights spec; when given, `num_classes`
            is overridden to match the weights' category list.
        progress: show a download progress bar while fetching weights.
        **kwargs: forwarded to the `ResNet` constructor.

    Returns:
        The constructed `FeatResNet` (Bottleneck, [3, 4, 6, 3]).
    """
    resolved = ResNet50_Weights.verify(weights)
    if resolved is not None:
        _ovewrite_named_param(kwargs, 'num_classes', len(resolved.meta['categories']))

    net = FeatResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if resolved is not None:
        net.load_state_dict(resolved.get_state_dict(progress=progress))
    return net
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def resnet101(*, weights=None, progress=True, **kwargs):
    """Build a ResNet-101 `FeatResNet`, optionally loading pretrained weights.

    Args:
        weights: optional torchvision weights spec; when given, `num_classes`
            is overridden to match the weights' category list.
        progress: show a download progress bar while fetching weights.
        **kwargs: forwarded to the `ResNet` constructor.

    Returns:
        The constructed `FeatResNet` (Bottleneck, [3, 4, 23, 3]).
    """
    resolved = ResNet101_Weights.verify(weights)
    if resolved is not None:
        _ovewrite_named_param(kwargs, 'num_classes', len(resolved.meta['categories']))

    net = FeatResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if resolved is not None:
        net.load_state_dict(resolved.get_state_dict(progress=progress))
    return net
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def resnet152(*, weights=None, progress=True, **kwargs):
    """Build a ResNet-152 `FeatResNet`, optionally loading pretrained weights.

    Args:
        weights: optional torchvision weights spec; when given, `num_classes`
            is overridden to match the weights' category list.
        progress: show a download progress bar while fetching weights.
        **kwargs: forwarded to the `ResNet` constructor.

    Returns:
        The constructed `FeatResNet` (Bottleneck, [3, 8, 36, 3]).
    """
    # Bug fix: this previously verified against `ResNet101_Weights`, which
    # rejects genuine ResNet-152 weight enums (and would accept 101 weights
    # that cannot be loaded into the [3, 8, 36, 3] topology below).
    weights = ResNet152_Weights.verify(weights)
    if weights is not None:
        _ovewrite_named_param(kwargs, 'num_classes', len(weights.meta['categories']))

    model = FeatResNet(Bottleneck, [3, 8, 36, 3], **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
|
var/D3HR/validation/models/scheduling_ddim.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
| 16 |
+
# and https://github.com/hojonathanho/diffusion
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from diffusers.utils import BaseOutput
|
| 27 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 28 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # x_{t-1}: the partially denoised latent to feed into the next step.
    prev_sample: torch.Tensor
    # Estimated x_0; populated by `DDIMScheduler.step` when `return_dict=True`.
    pred_original_sample: Optional[torch.Tensor] = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
| 51 |
+
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """Discretize a continuous alpha-bar curve into a beta schedule.

    `alpha_bar(t)` gives the cumulative product of `(1 - beta)` up to time
    `t` in `[0, 1]`; each discrete beta is recovered from the ratio of
    consecutive alpha-bar values and capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`: the betas used by the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is not `cosine` or `exp`.
    """
    if alpha_transform_type == "cosine":
        def alpha_bar_fn(t):
            # Nichol & Dhariwal cosine schedule (offset 0.008 avoids a
            # singular first step).
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    elif alpha_transform_type == "exp":
        def alpha_bar_fn(t):
            return math.exp(t * -12.0)
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    n = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped to max_beta.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / n) / alpha_bar_fn(i / n), max_beta)
        for i in range(n)
    ]
    return torch.tensor(betas, dtype=torch.float32)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the terminal SNR is exactly zero.

    Implements Algorithm 1 of "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf): shift and
    scale `sqrt(alpha_bar)` so its last entry is 0 while its first entry
    keeps the original value, then convert back to per-step betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR.
    """
    # sqrt of the cumulative product of (1 - beta).
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # Shift so the final entry is zero, then rescale so the first entry is
    # restored to its original value.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    # Undo the sqrt and the cumulative product to recover per-step alphas,
    # then convert alphas back to betas.
    abar = sqrt_abar ** 2
    step_alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])
    return 1 - step_alphas
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    NOTE(review): this is a local copy of the diffusers `DDIMScheduler` whose `step` method additionally accepts a
    `cond_fn` guidance callback (see `step`); everything else matches upstream.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        # Full reversed training schedule by default; `set_timesteps` replaces
        # this with the (shorter) inference schedule.
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # DDIM performs no input scaling; the sample is returned unchanged.
        return sample

    def _get_variance(self, timestep, prev_timestep):
        # sigma_t^2 of DDIM formula (16): depends on the alpha_bar values of
        # the current and previous timesteps.
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        cond_fn = None,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            cond_fn (callable, *optional*):
                Custom guidance hook (not part of the upstream diffusers API). When given, it is called as
                `cond_fn(sample)` and its output, scaled by `sqrt(1 - alpha_bar_t)`, is subtracted from
                `model_output` before the DDIM update.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # Custom (non-upstream) guidance hook: steer the predicted noise with
        # an external gradient-like signal, classifier-guidance style.
        if cond_fn is not None:
            model_output = model_output - (1 - alpha_prod_t) ** (0.5) * cond_fn(sample)

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # eta > 0 makes the update stochastic (DDPM-like); eta == 0 is the
        # deterministic DDIM update.
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """Diffuse `original_samples` forward to the given `timesteps` with `noise`."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Broadcast the per-timestep scalars over the sample's trailing dims.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """Compute the `v_prediction` target for the given sample/noise/timesteps."""
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        # Length of the scheduler == number of training timesteps.
        return self.config.num_train_timesteps
|
var/D3HR/validation/utils/__pycache__/data_utils.cpython-310.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
var/D3HR/validation/utils/__pycache__/data_utils.cpython-37.pyc
ADDED
|
Binary file (7.34 kB). View file
|
|
|
var/D3HR/validation/utils/__pycache__/validate_utils.cpython-310.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
var/D3HR/validation/utils/__pycache__/validate_utils.cpython-37.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
var/D3HR/validation/utils/data_utils.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
import json
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def find_subclasses(spec, nclass, phase=0):
|
| 13 |
+
classes = []
|
| 14 |
+
cls_from = nclass * phase
|
| 15 |
+
cls_to = nclass * (phase + 1)
|
| 16 |
+
if spec == 'woof':
|
| 17 |
+
file_list = './misc/class_woof.txt'
|
| 18 |
+
elif spec == 'im100':
|
| 19 |
+
file_list = './misc/class_100.txt'
|
| 20 |
+
else:
|
| 21 |
+
file_list = './misc/class_indices.txt'
|
| 22 |
+
with open(file_list, 'r') as f:
|
| 23 |
+
class_name = f.readlines()
|
| 24 |
+
for c in class_name:
|
| 25 |
+
c = c.split('\n')[0]
|
| 26 |
+
classes.append(c)
|
| 27 |
+
classes = classes[cls_from:cls_to]
|
| 28 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
| 29 |
+
|
| 30 |
+
return classes, class_to_idx
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def find_original_classes(spec, classes):
|
| 34 |
+
file_list = './misc/class_indices.txt'
|
| 35 |
+
with open(file_list, 'r') as f:
|
| 36 |
+
all_classes = f.readlines()
|
| 37 |
+
all_classes = [class_name.split('\n')[0] for class_name in all_classes]
|
| 38 |
+
original_classes = []
|
| 39 |
+
for class_name in classes:
|
| 40 |
+
original_classes.append(all_classes.index(class_name))
|
| 41 |
+
return original_classes
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_mapping_imgwoof(mapping_file, names):
|
| 45 |
+
new_mapping = {}
|
| 46 |
+
with open(mapping_file, 'r') as file:
|
| 47 |
+
data = json.load(file)
|
| 48 |
+
if "tiny" in mapping_file:
|
| 49 |
+
for index, line in enumerate(file):
|
| 50 |
+
# 提取每一行的编号(n开头部分)并将行号-1
|
| 51 |
+
key = line.split()[0]
|
| 52 |
+
new_mapping[key] = index
|
| 53 |
+
else:
|
| 54 |
+
new_mapping = {item["wnid"]: names.index(item["name"]) for item in data.values() if item['name'] in names}
|
| 55 |
+
return new_mapping
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_mapping(mapping_file):
|
| 59 |
+
new_mapping = {}
|
| 60 |
+
with open(mapping_file, 'r') as file:
|
| 61 |
+
data = json.load(file)
|
| 62 |
+
if "tiny" in mapping_file:
|
| 63 |
+
for index, line in enumerate(file):
|
| 64 |
+
# 提取每一行的编号(n开头部分)并将行号-1
|
| 65 |
+
key = line.split()[0]
|
| 66 |
+
new_mapping[key] = index
|
| 67 |
+
else:
|
| 68 |
+
new_mapping = {item["wnid"]: item["index"] for item in data.values()}
|
| 69 |
+
return new_mapping
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_mapping_txt(mapping_file):
|
| 74 |
+
wnid_to_index = {}
|
| 75 |
+
with open(mapping_file, 'r') as f:
|
| 76 |
+
for line in f:
|
| 77 |
+
wnid, index = line.strip().split('\t')
|
| 78 |
+
wnid_to_index[wnid] = int(index)
|
| 79 |
+
return wnid_to_index
|
| 80 |
+
|
| 81 |
+
def find_classes(class_file):
|
| 82 |
+
with open(class_file) as r:
|
| 83 |
+
classes = list(map(lambda s: s.strip(), r.readlines()))
|
| 84 |
+
|
| 85 |
+
classes.sort()
|
| 86 |
+
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
| 87 |
+
|
| 88 |
+
return class_to_idx
|
| 89 |
+
|
| 90 |
+
class ImageFolder(Dataset):
|
| 91 |
+
def __init__(self, split=None, txt_file=None, subset=None, mapping_file=None, transform=None):
|
| 92 |
+
super(ImageFolder, self).__init__()
|
| 93 |
+
self.split = split
|
| 94 |
+
self.image_paths = []
|
| 95 |
+
self.targets = []
|
| 96 |
+
self.samples = []
|
| 97 |
+
self.subset = subset
|
| 98 |
+
if self.subset == 'imagenet_1k':
|
| 99 |
+
self.wnid_to_index = load_mapping(mapping_file)
|
| 100 |
+
elif self.subset == 'tinyimagenet':
|
| 101 |
+
self.wnid_to_index = find_classes(mapping_file)
|
| 102 |
+
if split == 'train':
|
| 103 |
+
self._load_from_txt(txt_file)
|
| 104 |
+
else:
|
| 105 |
+
self._load_from_txt(txt_file)
|
| 106 |
+
self.transform = transform
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _load_from_txt(self, txt_file):
|
| 110 |
+
with open(txt_file, "r") as file:
|
| 111 |
+
image_paths = file.readlines()
|
| 112 |
+
|
| 113 |
+
# 去掉每行的换行符
|
| 114 |
+
self.image_paths = [path.strip() for path in image_paths]
|
| 115 |
+
for path in self.image_paths:
|
| 116 |
+
self.samples.append(path)
|
| 117 |
+
if self.subset == 'cifar10' or self.subset == 'cifar100':
|
| 118 |
+
class_index = int(path.split('/')[-2][-3:])
|
| 119 |
+
else:
|
| 120 |
+
# if self.split == 'test':
|
| 121 |
+
# class_index = self.wnid_to_index[path.split('/')[-2]]
|
| 122 |
+
# elif self.split == 'train':
|
| 123 |
+
class_index = self.wnid_to_index[path.split('/')[-2]]
|
| 124 |
+
self.targets.append(class_index)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
#combine ten txt
|
| 128 |
+
def _load_from_txt_1(self, txt_file):
|
| 129 |
+
|
| 130 |
+
image_paths_10 = []
|
| 131 |
+
for kk in range(10):
|
| 132 |
+
txt_file=f'/scratch/zhao.lin1/tinyimagenet_finetune_start_step_18_ddim_inversion_10_min_images_{kk}/train.txt'
|
| 133 |
+
with open(txt_file, "r") as file:
|
| 134 |
+
image_paths = file.readlines()
|
| 135 |
+
|
| 136 |
+
image_paths_10.append([path.strip() for path in image_paths])
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
for kk in range(len(image_paths)):
|
| 140 |
+
number = random.randint(0, 9)
|
| 141 |
+
self.image_paths.append(image_paths_10[number][kk])
|
| 142 |
+
if self.subset == 'cifar10' or self.subset == 'cifar100':
|
| 143 |
+
class_index = int(path.split('/')[-2][-3:])
|
| 144 |
+
else:
|
| 145 |
+
# if self.split == 'test':
|
| 146 |
+
# class_index = self.wnid_to_index[path.split('/')[-2]]
|
| 147 |
+
# elif self.split == 'train':
|
| 148 |
+
class_index = self.wnid_to_index[image_paths_10[number][kk].split('/')[-2]]
|
| 149 |
+
self.targets.append(class_index)
|
| 150 |
+
|
| 151 |
+
def __getitem__(self, index):
|
| 152 |
+
img_path = self.image_paths[index]
|
| 153 |
+
try:
|
| 154 |
+
sample = Image.open(img_path).convert('RGB')
|
| 155 |
+
except Exception as e:
|
| 156 |
+
print(f"Error loading image {img_path}: {e}")
|
| 157 |
+
# Return a black image in case of error
|
| 158 |
+
sample = Image.new('RGB', (256, 256))
|
| 159 |
+
sample = self.transform(sample)
|
| 160 |
+
# class_dir = img_path.split('/')[-2]
|
| 161 |
+
return sample, self.targets[index]
|
| 162 |
+
|
| 163 |
+
def __len__(self):
|
| 164 |
+
return len(self.targets)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Imagewoof(Dataset):
|
| 168 |
+
def __init__(self, split=None, txt_file=None, subset=None, mapping_file=None, transform=None):
|
| 169 |
+
super(Imagewoof, self).__init__()
|
| 170 |
+
self.split = split
|
| 171 |
+
self.image_paths = []
|
| 172 |
+
self.targets = []
|
| 173 |
+
self.samples = []
|
| 174 |
+
self.subset = subset
|
| 175 |
+
self.names = ["Australian_terrier", "Border_terrier", "Samoyed", "beagle", "Shih-Tzu", "English_foxhound", "Rhodesian_ridgeback", "dingo", "golden_retriever", "Old_English_sheepdog"]
|
| 176 |
+
self.wnid_to_index = load_mapping_imgwoof(mapping_file, self.names)
|
| 177 |
+
self._load_from_txt(txt_file)
|
| 178 |
+
self.transform = transform
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _load_from_txt(self, txt_file):
|
| 182 |
+
with open(txt_file, "r") as file:
|
| 183 |
+
image_paths = file.readlines()
|
| 184 |
+
|
| 185 |
+
# 去掉每行的换行符
|
| 186 |
+
image_paths = [path.strip() for path in image_paths]
|
| 187 |
+
for path in image_paths:
|
| 188 |
+
self.samples.append(path)
|
| 189 |
+
if self.subset == 'cifar10' or self.subset == 'cifar100':
|
| 190 |
+
class_index = int(path.split('/')[-2][-3:])
|
| 191 |
+
else:
|
| 192 |
+
# if self.split == 'test':
|
| 193 |
+
# class_index = self.wnid_to_index[path.split('/')[-2]]
|
| 194 |
+
# elif self.split == 'train':
|
| 195 |
+
if path.split('/')[-2] in list(self.wnid_to_index.keys()):
|
| 196 |
+
class_index = self.wnid_to_index[path.split('/')[-2]]
|
| 197 |
+
self.image_paths.append(path)
|
| 198 |
+
self.targets.append(class_index)
|
| 199 |
+
|
| 200 |
+
def __getitem__(self, index):
|
| 201 |
+
img_path = self.image_paths[index]
|
| 202 |
+
try:
|
| 203 |
+
sample = Image.open(img_path).convert('RGB')
|
| 204 |
+
except Exception as e:
|
| 205 |
+
print(f"Error loading image {img_path}: {e}")
|
| 206 |
+
# Return a black image in case of error
|
| 207 |
+
sample = Image.new('RGB', (256, 256))
|
| 208 |
+
sample = self.transform(sample)
|
| 209 |
+
# class_dir = img_path.split('/')[-2]
|
| 210 |
+
return sample, self.targets[index]
|
| 211 |
+
|
| 212 |
+
def __len__(self):
|
| 213 |
+
return len(self.targets)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# class ImageFolder(torchvision.datasets.ImageFolder):
|
| 218 |
+
# def __init__(self, nclass, ipc, mem=False, spec='none', phase=0, **kwargs):
|
| 219 |
+
# super(ImageFolder, self).__init__(**kwargs)
|
| 220 |
+
# self.mem = mem
|
| 221 |
+
# self.spec = spec
|
| 222 |
+
# self.classes, self.class_to_idx = find_subclasses(
|
| 223 |
+
# spec=spec, nclass=nclass, phase=phase
|
| 224 |
+
# )
|
| 225 |
+
# self.original_classes = find_original_classes(spec=self.spec, classes=self.classes)
|
| 226 |
+
# self.samples, self.targets = self.load_subset(ipc=ipc)
|
| 227 |
+
# if self.mem:
|
| 228 |
+
# self.samples = [self.loader(path) for path in self.samples]
|
| 229 |
+
|
| 230 |
+
# def load_subset(self, ipc=-1):
|
| 231 |
+
# all_samples = torchvision.datasets.folder.make_dataset(
|
| 232 |
+
# self.root, self.class_to_idx, self.extensions
|
| 233 |
+
# )
|
| 234 |
+
# samples = np.array([item[0] for item in all_samples])
|
| 235 |
+
# targets = np.array([item[1] for item in all_samples])
|
| 236 |
+
|
| 237 |
+
# if ipc == -1:
|
| 238 |
+
# return samples, targets
|
| 239 |
+
# else:
|
| 240 |
+
# sub_samples = []
|
| 241 |
+
# sub_targets = []
|
| 242 |
+
# for c in range(len(self.classes)):
|
| 243 |
+
# c_indices = np.where(targets == c)[0]
|
| 244 |
+
# #random.shuffle(c_indices)
|
| 245 |
+
# sub_samples.extend(samples[c_indices[:ipc]])
|
| 246 |
+
# sub_targets.extend(targets[c_indices[:ipc]])
|
| 247 |
+
# return sub_samples, sub_targets
|
| 248 |
+
|
| 249 |
+
# def __getitem__(self, index):
|
| 250 |
+
# if self.mem:
|
| 251 |
+
# sample = self.samples[index]
|
| 252 |
+
# else:
|
| 253 |
+
# sample = self.loader(self.samples[index])
|
| 254 |
+
# sample = self.transform(sample)
|
| 255 |
+
# return sample, self.targets[index]
|
| 256 |
+
|
| 257 |
+
# def __len__(self):
|
| 258 |
+
# return len(self.targets)
|
| 259 |
+
def random_stitch_crop_4(image):
|
| 260 |
+
"""随机从 stitch 的四个子区域中裁剪一个"""
|
| 261 |
+
w, h = image.size # 获取图像的宽和高
|
| 262 |
+
w_half, h_half = w // 2, h // 2
|
| 263 |
+
|
| 264 |
+
# 定义四个区域的坐标
|
| 265 |
+
regions = [
|
| 266 |
+
(0, 0, w_half, h_half), # 左上
|
| 267 |
+
(w_half, 0, w, h_half), # 右上
|
| 268 |
+
(0, h_half, w_half, h), # 左下
|
| 269 |
+
(w_half, h_half, w, h), # 右下
|
| 270 |
+
]
|
| 271 |
+
|
| 272 |
+
# 随机选择一个区域
|
| 273 |
+
x1, y1, x2, y2 = random.choice(regions)
|
| 274 |
+
return image.crop((x1, y1, x2, y2)) # 裁��并返回
|
| 275 |
+
|
| 276 |
+
def transform_imagenet(args):
|
| 277 |
+
resize_test = [transforms.Resize(args.input_size // 7 * 8), transforms.CenterCrop(args.input_size)]
|
| 278 |
+
# resize_test = [transforms.Resize(args.input_size), transforms.CenterCrop(args.input_size)]
|
| 279 |
+
|
| 280 |
+
cast = [transforms.ToTensor()]
|
| 281 |
+
|
| 282 |
+
aug = [
|
| 283 |
+
# transforms.Resize(224),
|
| 284 |
+
# transforms.Lambda(random_stitch_crop_4),
|
| 285 |
+
# ShufflePatches(args.factor),
|
| 286 |
+
transforms.RandomResizedCrop(
|
| 287 |
+
size=args.input_size,
|
| 288 |
+
# scale=(0.5, 1.0),
|
| 289 |
+
# scale=(1 / args.factor, args.max_scale_crops),
|
| 290 |
+
scale=(0.08, args.max_scale_crops),
|
| 291 |
+
antialias=True,
|
| 292 |
+
),
|
| 293 |
+
transforms.RandomHorizontalFlip()
|
| 294 |
+
]
|
| 295 |
+
|
| 296 |
+
normalize = [transforms.Normalize(
|
| 297 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 298 |
+
)]
|
| 299 |
+
|
| 300 |
+
train_transform = transforms.Compose(aug +cast+ normalize)
|
| 301 |
+
test_transform = transforms.Compose(resize_test + cast + normalize)
|
| 302 |
+
|
| 303 |
+
return train_transform, test_transform
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
sharing_strategy = "file_system"
|
| 307 |
+
torch.multiprocessing.set_sharing_strategy(sharing_strategy)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def set_worker_sharing_strategy(worker_id: int) -> None:
|
| 311 |
+
torch.multiprocessing.set_sharing_strategy(sharing_strategy)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def load_data(args, coreset=False, resize_only=False, mem_flag=True, trainset_only=False):
|
| 315 |
+
train_transform, test_transform = transform_imagenet(args)
|
| 316 |
+
# if len(args.data_dir) == 1:
|
| 317 |
+
# train_dir = os.path.join(args.data_dir[0], 'train')
|
| 318 |
+
# val_dir = os.path.join(args.data_dir[0], 'val')
|
| 319 |
+
# else:
|
| 320 |
+
# train_dir = args.data_dir[0]
|
| 321 |
+
# val_dir = os.path.join(args.data_dir[1], 'val')
|
| 322 |
+
|
| 323 |
+
if resize_only:
|
| 324 |
+
train_transform = transforms.Compose([
|
| 325 |
+
transforms.Resize((512, 512)),
|
| 326 |
+
])
|
| 327 |
+
elif coreset:
|
| 328 |
+
train_transform = test_transform
|
| 329 |
+
|
| 330 |
+
# train_dataset = ImageFolder(
|
| 331 |
+
# nclass=args.nclass,
|
| 332 |
+
# ipc=args.ipc,
|
| 333 |
+
# mem=mem_flag,
|
| 334 |
+
# spec=args.spec,
|
| 335 |
+
# phase=args.phase,
|
| 336 |
+
# root=train_dir,
|
| 337 |
+
# transform=train_transform,
|
| 338 |
+
# )
|
| 339 |
+
|
| 340 |
+
if args.subset == 'imagewoof':
|
| 341 |
+
# Imagewoor
|
| 342 |
+
train_dataset = Imagewoof(
|
| 343 |
+
split = 'train',
|
| 344 |
+
txt_file=args.txt_file,
|
| 345 |
+
mapping_file=args.mapping_file,
|
| 346 |
+
subset = args.subset,
|
| 347 |
+
transform=train_transform,
|
| 348 |
+
)
|
| 349 |
+
else:
|
| 350 |
+
train_dataset = ImageFolder(
|
| 351 |
+
split = 'train',
|
| 352 |
+
txt_file=args.txt_file,
|
| 353 |
+
mapping_file=args.mapping_file,
|
| 354 |
+
subset = args.subset,
|
| 355 |
+
transform=train_transform,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if trainset_only:
|
| 361 |
+
return train_dataset
|
| 362 |
+
|
| 363 |
+
train_loader = torch.utils.data.DataLoader(
|
| 364 |
+
train_dataset,
|
| 365 |
+
batch_size=args.batch_size,
|
| 366 |
+
shuffle=True,
|
| 367 |
+
num_workers=24,
|
| 368 |
+
pin_memory=True,
|
| 369 |
+
worker_init_fn=set_worker_sharing_strategy,
|
| 370 |
+
)
|
| 371 |
+
if args.subset == 'cifar10':
|
| 372 |
+
val_dataset = torchvision.datasets.CIFAR10(root='/scratch/zhao.lin1/', train=False, download=True, transform=test_transform)
|
| 373 |
+
elif args.subset == 'cifar100':
|
| 374 |
+
val_dataset = torchvision.datasets.CIFAR100(root='/scratch/zhao.lin1/', train=False, download=True, transform=test_transform)
|
| 375 |
+
elif args.subset == 'imagewoof':
|
| 376 |
+
val_dataset = Imagewoof(
|
| 377 |
+
split = 'test',
|
| 378 |
+
txt_file=args.val_txt_file,
|
| 379 |
+
mapping_file=args.mapping_file,
|
| 380 |
+
subset = args.subset,
|
| 381 |
+
transform=test_transform,
|
| 382 |
+
)
|
| 383 |
+
else:
|
| 384 |
+
val_dataset = ImageFolder(
|
| 385 |
+
split = 'test',
|
| 386 |
+
txt_file=args.val_txt_file,
|
| 387 |
+
mapping_file=args.mapping_file,
|
| 388 |
+
subset = args.subset,
|
| 389 |
+
transform=test_transform,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
val_loader = torch.utils.data.DataLoader(
|
| 395 |
+
val_dataset,
|
| 396 |
+
batch_size=256,
|
| 397 |
+
shuffle=False,
|
| 398 |
+
num_workers=24,
|
| 399 |
+
pin_memory=True,
|
| 400 |
+
worker_init_fn=set_worker_sharing_strategy,
|
| 401 |
+
)
|
| 402 |
+
print("load data successfully")
|
| 403 |
+
|
| 404 |
+
return train_dataset, train_loader, val_loader
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class ShufflePatches(torch.nn.Module):
|
| 408 |
+
def __init__(self, factor):
|
| 409 |
+
super().__init__()
|
| 410 |
+
self.factor = factor
|
| 411 |
+
|
| 412 |
+
def shuffle_weight(self, img, factor):
|
| 413 |
+
h, w = img.shape[1:]
|
| 414 |
+
tw = w // factor
|
| 415 |
+
patches = []
|
| 416 |
+
for i in range(factor):
|
| 417 |
+
i = i * tw
|
| 418 |
+
if i != factor - 1:
|
| 419 |
+
patches.append(img[..., i : i + tw])
|
| 420 |
+
else:
|
| 421 |
+
patches.append(img[..., i:])
|
| 422 |
+
random.shuffle(patches)
|
| 423 |
+
img = torch.cat(patches, -1)
|
| 424 |
+
return img
|
| 425 |
+
|
| 426 |
+
def forward(self, img):
|
| 427 |
+
img = self.shuffle_weight(img, self.factor)
|
| 428 |
+
img = img.permute(0, 2, 1)
|
| 429 |
+
img = self.shuffle_weight(img, self.factor)
|
| 430 |
+
img = img.permute(0, 2, 1)
|
| 431 |
+
return img
|
var/D3HR/validation/utils/download.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Functions for downloading pre-trained DiT models
|
| 9 |
+
"""
|
| 10 |
+
from torchvision.datasets.utils import download_url
|
| 11 |
+
import torch
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_model(model_name):
|
| 19 |
+
"""
|
| 20 |
+
Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
|
| 21 |
+
"""
|
| 22 |
+
if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
|
| 23 |
+
return download_model(model_name)
|
| 24 |
+
else: # Load a custom DiT checkpoint:
|
| 25 |
+
assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
|
| 26 |
+
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
|
| 27 |
+
if "ema" in checkpoint: # supports checkpoints from train.py
|
| 28 |
+
checkpoint = checkpoint["ema"]
|
| 29 |
+
return checkpoint
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def download_model(model_name):
|
| 33 |
+
"""
|
| 34 |
+
Downloads a pre-trained DiT model from the web.
|
| 35 |
+
"""
|
| 36 |
+
assert model_name in pretrained_models
|
| 37 |
+
local_path = f'pretrained_models/{model_name}'
|
| 38 |
+
if not os.path.isfile(local_path):
|
| 39 |
+
os.makedirs('pretrained_models', exist_ok=True)
|
| 40 |
+
web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
|
| 41 |
+
download_url(web_path, 'pretrained_models')
|
| 42 |
+
model = torch.load(local_path, map_location=lambda storage, loc: storage)
|
| 43 |
+
return model
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
# Download all DiT checkpoints
|
| 48 |
+
for model in pretrained_models:
|
| 49 |
+
download_model(model)
|
| 50 |
+
print('Done.')
|
var/D3HR/validation/utils/syn_utils_dit.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torchvision import transforms as T
|
| 8 |
+
from torchvision.models import resnet18
|
| 9 |
+
from transformers import CLIPModel, AutoTokenizer
|
| 10 |
+
|
| 11 |
+
from .download import find_model
|
| 12 |
+
from diffusion import create_diffusion
|
| 13 |
+
from models.dit_models import DiT_models
|
| 14 |
+
from diffusers.models import AutoencoderKL
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SupConLoss(nn.Module):
|
| 18 |
+
def __init__(self, temperature=0.05, base_temperatue=0.05):
|
| 19 |
+
super(SupConLoss, self).__init__()
|
| 20 |
+
self.temperature = temperature
|
| 21 |
+
self.base_temperature = base_temperatue
|
| 22 |
+
|
| 23 |
+
def forward(self, image_features, text_features, text_labels):
|
| 24 |
+
logits = (image_features @ text_features.T) / self.temperature
|
| 25 |
+
logits_max, _ = torch.max(logits, dim=1, keepdim=True)
|
| 26 |
+
logits = logits - logits_max.detach()
|
| 27 |
+
|
| 28 |
+
exp_logits = torch.exp(logits) * text_labels
|
| 29 |
+
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
|
| 30 |
+
mean_log_prob_pos = ((1 - text_labels) * log_prob).sum(1) / (1 - text_labels).sum(1)
|
| 31 |
+
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
|
| 32 |
+
loss = loss.mean()
|
| 33 |
+
|
| 34 |
+
return loss
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ImageSynthesizer(object):
|
| 38 |
+
def __init__(self, args):
|
| 39 |
+
self.vae = AutoencoderKL.from_pretrained(args.vae_path).to('cuda')
|
| 40 |
+
self.clip_model = CLIPModel.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K').to('cuda')
|
| 41 |
+
self.clip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K')
|
| 42 |
+
|
| 43 |
+
# DiT model
|
| 44 |
+
assert args.dit_image_size % 8 == 0, 'Image size must be divisible by 8'
|
| 45 |
+
latent_size = args.dit_image_size // 8
|
| 46 |
+
self.latent_size = latent_size
|
| 47 |
+
self.dit = DiT_models[args.dit_model](
|
| 48 |
+
input_size=latent_size,
|
| 49 |
+
num_classes=args.num_dit_classes
|
| 50 |
+
).to('cuda')
|
| 51 |
+
ckpt_path = args.ckpt
|
| 52 |
+
state_dict = find_model(ckpt_path)
|
| 53 |
+
self.dit.load_state_dict(state_dict, strict=False)
|
| 54 |
+
|
| 55 |
+
# Diffusion
|
| 56 |
+
self.diffusion = create_diffusion(str(args.diffusion_steps))
|
| 57 |
+
|
| 58 |
+
# Class description
|
| 59 |
+
self.description_file = args.description_path
|
| 60 |
+
self.load_class_description()
|
| 61 |
+
|
| 62 |
+
self.cfg_scale = args.cfg_scale
|
| 63 |
+
self.clip_alpha = args.clip_alpha
|
| 64 |
+
self.cls_alpha = args.cls_alpha
|
| 65 |
+
self.num_pos_samples = 5
|
| 66 |
+
self.num_neg_samples = args.num_neg_samples
|
| 67 |
+
self.clip_normalize = T.Normalize(
|
| 68 |
+
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
|
| 69 |
+
)
|
| 70 |
+
self.contrast_criterion = SupConLoss()
|
| 71 |
+
self.neg_policy = args.neg_policy
|
| 72 |
+
|
| 73 |
+
def load_class_description(self):
|
| 74 |
+
with open(self.description_file, 'r') as fp:
|
| 75 |
+
descriptions = json.load(fp)
|
| 76 |
+
self.class_names = {}
|
| 77 |
+
self.descriptions = {}
|
| 78 |
+
|
| 79 |
+
for class_index, (class_name, description) in descriptions.items():
|
| 80 |
+
self.class_names[class_index] = class_name
|
| 81 |
+
self.descriptions[class_index] = description
|
| 82 |
+
|
| 83 |
+
self.class_indices = list(self.class_names.keys())
|
| 84 |
+
self.class_name_list = list(self.class_names.values())
|
| 85 |
+
self.neighbors = {}
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
class_name_feat = self.extract_text_feature(self.class_name_list)
|
| 88 |
+
name_sims = (class_name_feat @ class_name_feat.T).cpu()
|
| 89 |
+
name_sims -= torch.eye(len(name_sims))
|
| 90 |
+
name_sims = name_sims.numpy()
|
| 91 |
+
for class_index, sim_indices in zip(self.class_indices, name_sims):
|
| 92 |
+
self.neighbors[class_index] = list(sim_indices)
|
| 93 |
+
|
| 94 |
+
def extract_text_feature(self, descriptions):
|
| 95 |
+
input_text = self.clip_tokenizer(descriptions, padding=True, return_tensors='pt').to('cuda')
|
| 96 |
+
text_feature = self.clip_model.get_text_features(**input_text)
|
| 97 |
+
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
| 98 |
+
return text_feature
|
| 99 |
+
|
| 100 |
+
def cond_fn(self, x, t, y=None, text_features=None, contrastive=True, **kwargs):
|
| 101 |
+
with torch.enable_grad():
|
| 102 |
+
x = nn.Parameter(x).requires_grad_()
|
| 103 |
+
real_x, _ = x.chunk(2, dim=0)
|
| 104 |
+
pseudo_image = self.vae.decode(real_x / 0.18215, return_dict=False)[0]
|
| 105 |
+
pseudo_image = T.Resize((224, 224))(pseudo_image) * 0.5 + 0.5
|
| 106 |
+
pseudo_image = self.clip_normalize(pseudo_image)
|
| 107 |
+
|
| 108 |
+
# Extract image embedding
|
| 109 |
+
clip_feat_image = self.clip_model.get_image_features(pseudo_image)
|
| 110 |
+
clip_feat_image = clip_feat_image / clip_feat_image.norm(dim=-1, keepdim=True)
|
| 111 |
+
|
| 112 |
+
# Extract text embedding
|
| 113 |
+
clip_feat_text_pos, clip_feat_text_neg = torch.split(
|
| 114 |
+
text_features, [self.num_pos_samples, self.num_neg_samples]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if contrastive:
|
| 118 |
+
clip_loss = self.contrast_criterion(
|
| 119 |
+
clip_feat_image, torch.cat((clip_feat_text_pos, clip_feat_text_neg), dim=0),
|
| 120 |
+
torch.cat((torch.zeros(self.num_pos_samples), torch.ones(self.num_neg_samples))).unsqueeze(0).cuda()
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
clip_loss = 1. - (clip_feat_image @ clip_feat_text_pos.T).mean()
|
| 124 |
+
|
| 125 |
+
loss = self.clip_alpha * clip_loss
|
| 126 |
+
|
| 127 |
+
return -torch.autograd.grad(loss, x, allow_unused=True)[0]
|
| 128 |
+
|
| 129 |
+
def sample(self, original_label, class_index, batch_size=1, device=None):
    """Generate `batch_size` images for one class via DDIM sampling with
    classifier-free guidance and CLIP-contrastive guidance.

    Args:
        original_label: integer ImageNet label fed to the DiT.
        class_index: key into `self.descriptions` / `self.class_names`.
        batch_size: number of images to generate.
        device: torch device for latents and labels.

    Returns:
        A tensor of decoded images resized to 224x224.
    """
    # Initial latent noise (4-channel DiT latent space).
    z = torch.randn(batch_size, 4, self.latent_size, self.latent_size, device=device)
    y = torch.tensor([original_label] * batch_size, device=device)

    # Classifier-free guidance: duplicate the batch; the second half is
    # conditioned on the null class (id 1000).
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * batch_size, device=device)
    y = torch.cat([y, y_null], 0)

    # Positive prompts: every description of the target class.
    pos_descriptions = self.descriptions[class_index]
    pos_descriptions = [self.class_names[class_index] + ' with ' + description for description in pos_descriptions]

    # Negative prompts: descriptions drawn from other classes per policy.
    neg_descriptions = []
    if self.neg_policy == 'random':
        neg_classes = random.choices(self.class_indices, k=self.num_neg_samples)
    elif self.neg_policy == 'similar':
        # Classes whose name embeddings are most similar to this class.
        max_indices = np.argsort(self.neighbors[class_index])[-self.num_neg_samples:]
        neg_classes = [self.class_indices[max_index] for max_index in max_indices]
    else:
        # Similarity-weighted random sampling.
        neg_classes = random.choices(self.class_indices, self.neighbors[class_index], k=self.num_neg_samples)
    for rand_index in neg_classes:
        # Pick a uniformly random description of the negative class.
        # (Was hard-coded to np.random.randint(0, 4); using the real list
        # length avoids an IndexError when a class has fewer than 4
        # descriptions and covers all of them when it has more.)
        neg_desc_list = self.descriptions[rand_index]
        neg_descriptions.append(
            self.class_names[rand_index] + ' with '
            + neg_desc_list[np.random.randint(0, len(neg_desc_list))]
        )
    all_descriptions = pos_descriptions + neg_descriptions
    text_features = self.extract_text_feature(all_descriptions)

    model_kwargs = dict(
        y=y, cfg_scale=self.cfg_scale,
        text_features=text_features, contrastive=True
    )

    def get_samples(z):
        # DDIM sampling with the CLIP-guidance cond_fn.
        samples = self.diffusion.ddim_sample_loop(
            self.dit.forward_with_cfg, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=False, device=device,
            cond_fn=self.cond_fn
        )
        # Keep only the conditional half of the CFG batch.
        samples, _ = samples.chunk(2, dim=0)
        # Decode latents; 0.18215 is the SD VAE latent scaling factor.
        samples = self.vae.decode(samples / 0.18215).sample
        samples = T.Resize((224, 224))(samples)

        return samples

    samples = get_samples(z)

    return samples
|
var/D3HR/validation/utils/syn_utils_img2img.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torchvision import transforms as T
|
| 6 |
+
from transformers import CLIPModel, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from misc import prompts
|
| 9 |
+
from models.scheduling_ddim import DDIMScheduler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SupConLoss(torch.nn.Module):
    """Supervised-contrastive style loss between image and text embeddings.

    `text_labels` marks negative text entries with 1 and positive entries
    with 0. Only negatives contribute to the softmax denominator; the loss
    is the negative mean log-probability over the positive entries, scaled
    by temperature / base_temperature.
    """

    def __init__(self, temperature=0.05, base_temperatue=0.05):
        # NOTE: the keyword is spelled "base_temperatue" upstream; kept
        # as-is for backward compatibility with existing callers.
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperatue

    def forward(self, image_features, text_features, text_labels):
        # Temperature-scaled similarity logits (features are assumed to be
        # pre-normalized by the caller).
        logits = (image_features @ text_features.T) / self.temperature
        # Subtract the per-row max for numerical stability.
        row_max = logits.max(dim=1, keepdim=True).values
        logits = logits - row_max.detach()

        # Denominator uses the exp-logits of the negative entries only.
        neg_exp = torch.exp(logits) * text_labels
        log_prob = logits - neg_exp.sum(1, keepdim=True).log()

        # Average log-probability over the positive entries of each row.
        pos_mask = 1 - text_labels
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1)

        scale = self.temperature / self.base_temperature
        return (-scale * mean_log_prob_pos).mean()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ImageSynthesizer(object):
    """Synthesizes class-conditional images with a Stable unCLIP img2img
    pipeline, steering the diffusion process toward per-class text
    descriptions via a CLIP-based contrastive guidance gradient."""

    def __init__(self, args):
        self.init_clip()
        self.description_file = args.description_path
        self.load_class_description()
        self.contrast_criterion = SupConLoss()

        self.prompts = prompts.prompt_templates
        self.diffusion_steps = args.diffusion_steps
        self.clip_alpha = args.clip_alpha          # weight of the CLIP guidance loss
        self.num_neg_samples = args.num_neg_samples
        self.neg_policy = args.neg_policy          # 'random' | 'similar' | similarity-weighted

    def load_class_description(self):
        """Load per-class names/descriptions from JSON and precompute
        pairwise CLIP name similarities used for negative sampling."""
        with open(self.description_file, 'r') as fp:
            descriptions = json.load(fp)
        self.class_names = {}
        self.descriptions = {}

        # Each JSON value is a (class_name, description_list) pair.
        for class_index, (class_name, description) in descriptions.items():
            self.class_names[class_index] = class_name
            self.descriptions[class_index] = description

        self.class_indices = list(self.class_names.keys())
        self.class_name_list = list(self.class_names.values())
        self.neighbors = {}
        with torch.no_grad():
            class_name_feat = self.extract_clip_text_embed(self.class_name_list)
            name_sims = (class_name_feat @ class_name_feat.T).cpu()
            # Zero the diagonal so a class never ranks as its own neighbor
            # (normalized features give self-similarity 1).
            name_sims -= torch.eye(len(name_sims))
            name_sims = name_sims.numpy()
            for class_index, sim_indices in zip(self.class_indices, name_sims):
                self.neighbors[class_index] = list(sim_indices)

    def init_clip(self):
        """Load the CLIP model/tokenizer used for text and image embeddings."""
        self.clip_model = CLIPModel.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K').to('cuda')
        self.clip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K')
        # Standard CLIP image normalization constants.
        self.clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )

    def extract_clip_image_embed(self, image):
        """Embed a single image with CLIP.

        NOTE(review): `self.clip_transform` is never defined on this class,
        so this helper would raise AttributeError if called — confirm the
        intended preprocessing before using it.
        """
        image = self.clip_transform(image).unsqueeze(0)
        # Fix: transformers' CLIPModel has no `encode_image`; the correct
        # API (used by cond_fn below) is `get_image_features`.
        clip_feat = self.clip_model.get_image_features(image)
        return clip_feat

    def extract_clip_text_embed(self, descriptions):
        """Return L2-normalized CLIP text embeddings for a list of strings."""
        input_text = self.clip_tokenizer(descriptions, padding=True, return_tensors='pt').to('cuda')
        text_features = self.clip_model.get_text_features(**input_text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def cond_fn(self, sample, **kwargs):
        """Guidance gradient: decode the latent, embed it with CLIP, and
        push it toward the positive descriptions / away from negatives.

        Returns the negative gradient of the weighted contrastive loss
        w.r.t. the latent `sample`.
        """
        with torch.enable_grad():
            sample = torch.nn.Parameter(sample).requires_grad_()
            # Decode with the SD VAE (0.18215 is the latent scaling factor),
            # map from [-1, 1] to [0, 1], then apply CLIP normalization.
            pseudo_image = self.pipe.vae.decode(sample / 0.18215, return_dict=False)[0]
            pseudo_image = T.Resize((224, 224))(pseudo_image) * 0.5 + 0.5
            pseudo_image = self.clip_normalize(pseudo_image)

            # Extract image embedding
            clip_feat_image = self.clip_model.get_image_features(pseudo_image)
            clip_feat_image = clip_feat_image / clip_feat_image.norm(dim=-1, keepdim=True)

            # Label layout: the first `num_pos` embeddings are positives (0),
            # the rest negatives (1). Was hard-coded to 5; now driven by
            # sample_img2img, defaulting to 5 for backward compatibility.
            num_pos = getattr(self, 'current_num_pos', 5)
            clip_loss = self.contrast_criterion(
                clip_feat_image, self.current_desc_embeddings,
                torch.cat((
                    torch.zeros(num_pos),
                    torch.ones(len(self.current_desc_embeddings) - num_pos)
                )).unsqueeze(0).cuda()
            )

            loss = self.clip_alpha * clip_loss

        return -torch.autograd.grad(loss, sample, allow_unused=True)[0]

    def init_img2img(self):
        """Build the Stable unCLIP img2img pipeline with a DDIM scheduler."""
        from models.pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
        self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            'radames/stable-diffusion-2-1-unclip-img2img'
        )
        self.pipe.scheduler = DDIMScheduler.from_pretrained('radames/stable-diffusion-2-1-unclip-img2img', subfolder='scheduler')
        self.pipe = self.pipe.to('cuda')

    def sample_img2img(self, image, class_index):
        """Regenerate `image` as a new sample of `class_index` under
        CLIP-contrastive guidance; returns a 224x224 PIL image."""
        class_name = self.class_names[class_index]
        class_name = class_name.split(',')[0]  # drop synonym list after the comma
        pos_descriptions = self.descriptions[class_index]
        prompt = random.choice(self.prompts).format(class_name, '')

        pos_descriptions = [self.class_names[class_index] + ' with ' + description for description in pos_descriptions]
        neg_descriptions = []
        if self.neg_policy == 'random':
            neg_classes = random.choices(self.class_indices, k=self.num_neg_samples)
        elif self.neg_policy == 'similar':
            # Classes whose CLIP name embeddings are most similar.
            max_indices = np.argsort(self.neighbors[class_index])[-self.num_neg_samples:]
            neg_classes = [self.class_indices[max_index] for max_index in max_indices]
        else:
            # Similarity-weighted random sampling.
            neg_classes = random.choices(self.class_indices, self.neighbors[class_index], k=self.num_neg_samples)
        for rand_index in neg_classes:
            # Fix: ' with ' (was 'with ' — the missing leading space produced
            # prompts like "goldfishwith ...", inconsistent with the positive
            # prompts above). Also index uniformly over the real description
            # list instead of the hard-coded first four.
            neg_desc_list = self.descriptions[rand_index]
            neg_descriptions.append(
                self.class_names[rand_index] + ' with '
                + neg_desc_list[np.random.randint(0, len(neg_desc_list))]
            )
        # Record how many positives lead the embedding list for cond_fn.
        self.current_num_pos = len(pos_descriptions)
        self.current_desc_embeddings = self.extract_clip_text_embed(pos_descriptions + neg_descriptions)
        new_image = self.pipe(image=image, prompt=prompt, cond_fn=self.cond_fn, num_inference_steps=self.diffusion_steps).images[0]

        new_image = new_image.resize((224, 224))

        return new_image
|