Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- New/REG/README.md +156 -0
- New/REG/eval.sh +52 -0
- New/REG/generate.py +253 -0
- New/REG/loss.py +102 -0
- New/REG/requirements.txt +97 -0
- back/evaluations/README.md +72 -0
- back/evaluations/evaluator.py +679 -0
- back/evaluations/requirements.txt +4 -0
- back/models/__pycache__/mocov3_vit.cpython-310.pyc +0 -0
- back/models/__pycache__/mocov3_vit.cpython-312.pyc +0 -0
- back/models/__pycache__/sit.cpython-310.pyc +0 -0
- back/models/__pycache__/sit.cpython-312.pyc +0 -0
- back/models/clip_vit.py +426 -0
- back/models/jepa.py +547 -0
- back/models/mae_vit.py +71 -0
- back/models/mocov3_vit.py +207 -0
- back/models/sit.py +420 -0
- back/preprocessing/README.md +25 -0
- back/preprocessing/dataset_image_encoder.py +353 -0
- back/preprocessing/dataset_prepare_convert.sh +11 -0
- back/preprocessing/dataset_prepare_encode.sh +9 -0
- back/preprocessing/dataset_tools.py +422 -0
- back/preprocessing/dnnlib/__init__.py +8 -0
- back/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc +0 -0
- back/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc +0 -0
- back/preprocessing/dnnlib/util.py +485 -0
- back/preprocessing/encoders.py +103 -0
- back/preprocessing/torch_utils/__init__.py +8 -0
- back/preprocessing/torch_utils/distributed.py +140 -0
- back/preprocessing/torch_utils/misc.py +277 -0
- back/preprocessing/torch_utils/persistence.py +257 -0
- back/preprocessing/torch_utils/training_stats.py +283 -0
- back/wandb/debug-internal.log +19 -0
- back/wandb/debug.log +20 -0
- back/wandb/run-20260322_141726-2yw08kz9/files/config.yaml +203 -0
- back/wandb/run-20260322_141726-2yw08kz9/files/output.log +27 -0
- back/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt +168 -0
- back/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json +101 -0
- back/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json +1 -0
- back/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log +7 -0
- back/wandb/run-20260322_141726-2yw08kz9/logs/debug.log +22 -0
- back/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb +0 -0
- back/wandb/run-20260322_141833-vm0y8t9t/files/output.log +0 -0
- back/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt +168 -0
- back/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json +101 -0
- back/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log +6 -0
- back/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log +20 -0
- back/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml +202 -0
- back/wandb/run-20260322_150022-yhxc5cgu/files/output.log +19 -0
- back/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt +168 -0
New/REG/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
|
| 3 |
+
</h1>
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup> 
|
| 6 |
+
<a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup> 
|
| 7 |
+
<a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup> 
|
| 8 |
+
<a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup> 
|
| 9 |
+
<a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup> 
|
| 10 |
+
<a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup> 
|
| 11 |
+
<a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup> 
|
| 12 |
+
<a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup> 
|
| 13 |
+
<a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup> 
|
| 14 |
+
<a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup> 
|
| 15 |
+
<a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup> 
|
| 16 |
+
<a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup> 
|
| 17 |
+
<p align="center">
|
| 18 |
+
$^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
|
| 19 |
+
$^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
|
| 20 |
+
<p align='center'>
|
| 21 |
+
<div align="center">
|
| 22 |
+
<a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
|
| 23 |
+
<a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
|
| 24 |
+
<a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
|
| 25 |
+
</div>
|
| 26 |
+
<p align='center'>
|
| 27 |
+
</p>
|
| 28 |
+
</p>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## 🚩 Overview
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
|
| 37 |
+
We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
|
| 38 |
+
|
| 39 |
+
In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
|
| 40 |
+
REG acquires the capability to produce coherent image-class pairs directly from pure noise,
|
| 41 |
+
substantially improving both generation quality and training efficiency.
|
| 42 |
+
This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
|
| 43 |
+
The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
|
| 44 |
+
|
| 45 |
+
On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
|
| 46 |
+
More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## 📰 News
|
| 51 |
+
|
| 52 |
+
- **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## 📝 Results
|
| 56 |
+
|
| 57 |
+
- Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
|
| 58 |
+
- $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
|
| 59 |
+
|
| 60 |
+
<div align="center">
|
| 61 |
+
<img src="fig/img.png" alt="Results">
|
| 62 |
+
</div>
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 📋 Plan
|
| 66 |
+
- More training steps on ImageNet 256&512 and T2I.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 👊 Usage
|
| 70 |
+
|
| 71 |
+
### 1. Environment setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
conda create -n reg python=3.10.16 -y
|
| 75 |
+
conda activate reg
|
| 76 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
|
| 77 |
+
pip install -r requirements.txt
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 2. Dataset
|
| 81 |
+
|
| 82 |
+
#### Dataset download
|
| 83 |
+
|
| 84 |
+
Currently, we provide experiments for ImageNet. You can place the data wherever you want and specify its location via the `--data-dir` argument in the training scripts.
|
| 85 |
+
|
| 86 |
+
#### Preprocessing data
|
| 87 |
+
Please refer to the preprocessing guide. And you can directly download our processed data, ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main), and ImageNet data after VAE encoder [link]( https://huggingface.co/WindATree/vae-sd/tree/main)
|
| 88 |
+
|
| 89 |
+
### 3. Training
|
| 90 |
+
Run train.sh
|
| 91 |
+
```bash
|
| 92 |
+
bash train.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
train.sh contains the following content.
|
| 96 |
+
```bash
|
| 97 |
+
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
|
| 98 |
+
--report-to="wandb" \
|
| 99 |
+
--allow-tf32 \
|
| 100 |
+
--mixed-precision="fp16" \
|
| 101 |
+
--seed=0 \
|
| 102 |
+
--path-type="linear" \
|
| 103 |
+
--prediction="v" \
|
| 104 |
+
--weighting="uniform" \
|
| 105 |
+
--model="SiT-B/2" \
|
| 106 |
+
--enc-type="dinov2-vit-b" \
|
| 107 |
+
--proj-coeff=0.5 \
|
| 108 |
+
--encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
|
| 109 |
+
--output-dir="your_path" \
|
| 110 |
+
--exp-name="linear-dinov2-b-enc4" \
|
| 111 |
+
--batch-size=256 \
|
| 112 |
+
--data-dir="data_path/imagenet_vae" \
|
| 113 |
+
--cls=0.03
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
|
| 117 |
+
|
| 118 |
+
- `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
|
| 119 |
+
- `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
|
| 120 |
+
- `--proj-coeff`: Any values larger than 0
|
| 121 |
+
- `--encoder-depth`: Any values between 1 to the depth of the model
|
| 122 |
+
- `--output-dir`: Any directory that you want to save checkpoints and logs
|
| 123 |
+
- `--exp-name`: Any string name (the folder will be created under `output-dir`)
|
| 124 |
+
- `--cls`: Weight coefficients of REG loss
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
### 4. Generate images and evaluation
|
| 128 |
+
You can generate images and get the final results through the following script.
|
| 129 |
+
The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
bash eval.sh
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Citation
|
| 137 |
+
If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
|
| 138 |
+
```
|
| 139 |
+
@article{wu2025representation,
|
| 140 |
+
title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
|
| 141 |
+
author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
|
| 142 |
+
journal={arXiv preprint arXiv:2507.01467},
|
| 143 |
+
year={2025}
|
| 144 |
+
}
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Contact
|
| 148 |
+
If you have any questions, please create an issue on this repository, contact us at gewu.nku@gmail.com, or reach out via WeChat (wg1158848).
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Acknowledgements
|
| 152 |
+
|
| 153 |
+
Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
New/REG/eval.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# End-to-end evaluation: sample 50k images with the trained SiT-XL/2 + REG
# model, then score them against the ImageNet-256 reference batch (FID etc.).

# Random master port in [1200, 1299] to avoid collisions with other jobs.
random_number=$((RANDOM % 100 + 1200))
NUM_GPUS=8
STEP="4000000"                 # training iteration of the checkpoint to evaluate
SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
VAE_PATH="your_vae_path/"      # NOTE(review): set but unused below — generate.py loads its own VAE
NUM_STEP=250                   # sampler steps
MODEL_SIZE='XL'
CFG_SCALE=2.3                  # classifier-free guidance scale for image latents
CLS_CFG_SCALE=2.3              # guidance scale for the class token
GH=0.85                        # guidance-high interval bound

# Disable P2P GPU transfers; presumably a workaround for NCCL hangs on this cluster.
export NCCL_P2P_DISABLE=1

# Stage 1: distributed sampling. Writes PNGs + a combined .npz under SAVE_PATH/checkpoints.
python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
  --model SiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
  --path-type=linear \
  --encoder-depth=8 \
  --projector-embed-dims=768 \
  --per-proc-batch-size=64 \
  --mode=sde \
  --num-steps=${NUM_STEP} \
  --cfg-scale=${CFG_SCALE} \
  --cls-cfg-scale=${CLS_CFG_SCALE} \
  --guidance-high=${GH} \
  --sample-dir ${SAVE_PATH}/checkpoints \
  --cls=768

# Stage 2: compute metrics. The --sample_batch name must match the folder-name
# pattern produced by generate.py (model-ckpt-size-vae-cfg-seed-mode-gh-clscfg).
python ./evaluations/evaluator.py \
  --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
  --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
  --save_path ${SAVE_PATH}/checkpoints \
  --cfg_cond 1 \
  --step ${STEP} \
  --num_steps ${NUM_STEP} \
  --cfg ${CFG_SCALE} \
  --cls_cfg ${CLS_CFG_SCALE} \
  --gh ${GH}
New/REG/generate.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 9 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 10 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 11 |
+
|
| 12 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from models.sit import SiT_models
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
from samplers import euler_maruyama_sampler, euler_sampler
|
| 25 |
+
from utils import load_legacy_checkpoints, download_model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Build a single .npz file from a folder of .png samples.

    Expects files named ``{sample_dir}/000000.png`` .. ``{sample_dir}/{num-1:06d}.png``,
    all RGB and of identical size. Writes ``{sample_dir}.npz`` with the stacked
    uint8 array under key ``arr_0`` (the format the ADM evaluator consumes).

    Args:
        sample_dir: folder containing the zero-padded .png samples.
        num: number of samples to collect (default 50,000 for FID).

    Returns:
        Path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        samples.append(np.asarray(sample_pil).astype(np.uint8))
    # np.stack raises ValueError if any image differs in size.
    samples = np.stack(samples)
    # The original assert compared dims 1/2 against themselves (tautology);
    # check the meaningful invariants instead: sample count and RGB channels.
    assert samples.shape[0] == num, f"expected {num} samples, got {samples.shape[0]}"
    assert samples.ndim == 4 and samples.shape[-1] == 3, \
        f"expected stacked RGB images, got shape {samples.shape}"
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main(args):
    """
    Run DDP sampling: load a SiT checkpoint, draw args.num_fid_samples images
    across all ranks, decode them with the SD VAE, save PNGs, and (rank 0)
    pack them into a single .npz for FID evaluation.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank seed so ranks draw disjoint noise; offset by global_seed.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load model. Latent resolution is 1/8 of pixel resolution (SD VAE downsampling).
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt

    if ckpt_path is None:
        # No checkpoint given: fall back to the default pre-trained SiT-XL/2
        # with a single 768-dim projector (the only published configuration).
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        # Use the EMA weights saved by train.py.
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap old-format checkpoints to the current module layout.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples. The folder name encodes every sampling
    # hyperparameter so eval.sh can reconstruct the matching .npz path.
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    # Optionally load pre-generated noise so runs are exactly reproducible.
    fixed_noise = None
    if args.fixed_noise_file:
        try:
            # weights_only is only available on newer torch versions.
            fixed_noise = torch.load(args.fixed_noise_file, map_location="cpu", weights_only=True)
        except TypeError:
            fixed_noise = torch.load(args.fixed_noise_file, map_location="cpu")
        for k in ("z", "y", "cls_z"):
            if k not in fixed_noise:
                raise KeyError(f"fixed noise file missing key: {k}")
        if int(fixed_noise["z"].shape[0]) < total_samples:
            raise ValueError(
                f"fixed noise size={fixed_noise['z'].shape[0]} < required total_samples={total_samples}"
            )
        if rank == 0:
            print(f"Using fixed noise file: {args.fixed_noise_file}")

    total = 0
    for _ in pbar:
        if fixed_noise is not None:
            # Rank-strided slice of the fixed noise; matches the index layout
            # used when saving PNGs below, so sample i always gets noise i.
            idx = torch.arange(total + rank, total + global_batch_size, dist.get_world_size(), dtype=torch.long)
            z = fixed_noise["z"][idx].to(device=device)
            y = fixed_noise["y"][idx].to(device=device)
            cls_z = fixed_noise["cls_z"][idx].to(device=device)
        else:
            z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
            y = torch.randint(0, args.num_classes, (n,), device=device)
            # cls_z is the noisy class-token latent that REG denoises jointly.
            cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":
                samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # Undo the SD-VAE latent scaling (0.18215) before decoding to pixels.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        # Map decoder output from [-1, 1] to uint8 [0, 255], NHWC.
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files. Global index interleaves
        # ranks so the union over ranks is a contiguous 0..total_samples-1 range.
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)

    # sampling related hyperparameters
    parser.add_argument("--mode", type=str, default="ode")  # "ode" or "sde"
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--cls-cfg-scale", type=float, default=1.5)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    # accepted for torch.distributed.launch compatibility; not read by main()
    parser.add_argument('--local-rank', default=-1, type=int)
    # dimensionality of the class-token latent denoised alongside the image
    parser.add_argument('--cls', default=768, type=int)
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument(
        "--fixed-noise-file",
        type=str,
        default=None,
        help="Optional .pt with keys z/y/cls_z to force identical initial states across runs.",
    )

    args = parser.parse_args()
    main(args)
|
New/REG/loss.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def mean_flat(x):
    """
    Average over every dimension except the batch (first) one.
    """
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)
|
| 10 |
+
|
| 11 |
+
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    (The original docstring said "mean" — a copy-paste error from
    ``mean_flat``; the implementation sums.)
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))
|
| 16 |
+
|
| 17 |
+
class SILoss:
    """Stochastic-interpolant training loss.

    Computes three terms, returned separately to the caller:
      * a velocity-prediction denoising loss on the image latents,
      * a matching denoising loss on a class token, and
      * a projection (alignment) loss between the model's internal features
        ``zs_tilde`` and external encoder features ``zs`` (negative cosine
        similarity).
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=None,
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
    ):
        """
        Args:
            prediction: target parameterization; only velocity ('v') is implemented.
            path_type: interpolant schedule, "linear" or "cosine".
            weighting: timestep sampling scheme, "uniform" or "lognormal".
            encoders: external representation encoders (stored only; not used
                inside this loss).
            accelerator: optional accelerator handle (stored only).
            latents_scale: optional latent normalization scale (stored only).
            latents_bias: optional latent normalization bias (stored only).
        """
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        # Avoid the mutable-default-argument pitfall; effective default is [].
        self.encoders = encoders if encoders is not None else []
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias

    def interpolant(self, t):
        """Return ``(alpha_t, sigma_t, d_alpha_t, d_sigma_t)`` for timestep(s) ``t``.

        The interpolant is ``x_t = alpha_t * x_0 + sigma_t * eps``; the ``d_*``
        values are the time derivatives of the schedules.
        """
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the training losses for one batch.

        Args:
            model: callable ``model(x_t, t, cls_token=..., **model_kwargs)`` that
                returns ``(model_output, zs_tilde, cls_output)``.
            images: image latents, shape (B, C, H, W).
            model_kwargs: extra keyword arguments forwarded to ``model``.
            zs: list of external encoder feature tensors, each (B, ...).
            cls_token: class token tensor; assumed (B, D) — the schedule is
                squeezed to broadcast over it.
            time_input: optional pre-sampled timesteps of shape (B, 1, 1, 1).
            noises: optional pre-sampled image noise.

        Returns:
            (denoising_loss, proj_loss, time_input, noises, denoising_loss_cls)
        """
        if model_kwargs is None:  # identity check, not `== None` (PEP 8)
            model_kwargs = {}
        # sample timesteps
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # sample timestep according to log-normal distribution of sigmas following EDM
                rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=images.dtype)

        if noises is None:
            noises = torch.randn_like(images)
        # Class-token noise is always freshly sampled; keeping it outside the
        # `noises is None` branch avoids a NameError when callers supply their
        # own image noise.
        noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        model_input = alpha_t * images + sigma_t * noises
        # Drop the trailing singleton dims so the (B,1,1,1) schedule broadcasts
        # over the 2-D class token instead of expanding it to 4-D.
        cls_input = alpha_t.squeeze(-1).squeeze(-1) * cls_token + sigma_t.squeeze(-1).squeeze(-1) * noises_cls
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
            # NOTE(review): for path_type="cosine", d_alpha_t/d_sigma_t are
            # (B,1,1,1) tensors; broadcasting them against a 2-D cls_token is
            # not squeezed like cls_input above — confirm intended shapes.
            cls_target = d_alpha_t * cls_token + d_sigma_t * noises_cls
        else:
            raise NotImplementedError()

        model_output, zs_tilde, cls_output = model(model_input, time_input.flatten(), **model_kwargs,
                                                   cls_token=cls_input)

        # denoising loss: per-sample MSE on latents and on the class token
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - cls_target) ** 2)

        # projection loss: negative cosine similarity between external encoder
        # features and the model's projected internal features
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = F.normalize(z_tilde_j, dim=-1)
                z_j = F.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls
|
New/REG/requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- pip:
|
| 2 |
+
absl-py==2.2.2
|
| 3 |
+
accelerate==1.2.1
|
| 4 |
+
aiohappyeyeballs==2.6.1
|
| 5 |
+
aiohttp==3.11.16
|
| 6 |
+
aiosignal==1.3.2
|
| 7 |
+
astunparse==1.6.3
|
| 8 |
+
async-timeout==5.0.1
|
| 9 |
+
attrs==25.3.0
|
| 10 |
+
certifi==2022.12.7
|
| 11 |
+
charset-normalizer==2.1.1
|
| 12 |
+
click==8.1.8
|
| 13 |
+
datasets==2.20.0
|
| 14 |
+
diffusers==0.32.1
|
| 15 |
+
dill==0.3.8
|
| 16 |
+
docker-pycreds==0.4.0
|
| 17 |
+
einops==0.8.1
|
| 18 |
+
filelock==3.13.1
|
| 19 |
+
flatbuffers==25.2.10
|
| 20 |
+
frozenlist==1.5.0
|
| 21 |
+
fsspec==2024.5.0
|
| 22 |
+
ftfy==6.3.1
|
| 23 |
+
gast==0.6.0
|
| 24 |
+
gitdb==4.0.12
|
| 25 |
+
gitpython==3.1.44
|
| 26 |
+
google-pasta==0.2.0
|
| 27 |
+
grpcio==1.71.0
|
| 28 |
+
h5py==3.13.0
|
| 29 |
+
huggingface-hub==0.27.1
|
| 30 |
+
idna==3.4
|
| 31 |
+
importlib-metadata==8.6.1
|
| 32 |
+
jinja2==3.1.4
|
| 33 |
+
joblib==1.4.2
|
| 34 |
+
keras==3.9.2
|
| 35 |
+
libclang==18.1.1
|
| 36 |
+
markdown==3.8
|
| 37 |
+
markdown-it-py==3.0.0
|
| 38 |
+
markupsafe==2.1.5
|
| 39 |
+
mdurl==0.1.2
|
| 40 |
+
ml-dtypes==0.3.2
|
| 41 |
+
mpmath==1.3.0
|
| 42 |
+
multidict==6.4.3
|
| 43 |
+
multiprocess==0.70.16
|
| 44 |
+
namex==0.0.8
|
| 45 |
+
networkx==3.3
|
| 46 |
+
numpy==1.26.4
|
| 47 |
+
opt-einsum==3.4.0
|
| 48 |
+
optree==0.15.0
|
| 49 |
+
packaging==24.2
|
| 50 |
+
pandas==2.2.3
|
| 51 |
+
pillow==11.0.0
|
| 52 |
+
platformdirs==4.3.7
|
| 53 |
+
propcache==0.3.1
|
| 54 |
+
protobuf==4.25.6
|
| 55 |
+
psutil==7.0.0
|
| 56 |
+
pyarrow==19.0.1
|
| 57 |
+
pyarrow-hotfix==0.6
|
| 58 |
+
pygments==2.19.1
|
| 59 |
+
python-dateutil==2.9.0.post0
|
| 60 |
+
pytz==2025.2
|
| 61 |
+
pyyaml==6.0.2
|
| 62 |
+
regex==2024.11.6
|
| 63 |
+
requests==2.32.3
|
| 64 |
+
rich==14.0.0
|
| 65 |
+
safetensors==0.5.3
|
| 66 |
+
scikit-learn==1.5.1
|
| 67 |
+
scipy==1.15.2
|
| 68 |
+
sentry-sdk==2.26.1
|
| 69 |
+
setproctitle==1.3.5
|
| 70 |
+
six==1.17.0
|
| 71 |
+
smmap==5.0.2
|
| 72 |
+
sympy==1.13.1
|
| 73 |
+
tensorboard==2.16.1
|
| 74 |
+
tensorboard-data-server==0.7.2
|
| 75 |
+
tensorflow==2.16.1
|
| 76 |
+
tensorflow-io-gcs-filesystem==0.37.1
|
| 77 |
+
termcolor==3.0.1
|
| 78 |
+
tf-keras==2.16.0
|
| 79 |
+
threadpoolctl==3.6.0
|
| 80 |
+
timm==1.0.12
|
| 81 |
+
tokenizers==0.21.0
|
| 82 |
+
tqdm==4.67.1
|
| 83 |
+
transformers==4.47.0
|
| 84 |
+
triton==2.1.0
|
| 85 |
+
typing-extensions==4.12.2
|
| 86 |
+
tzdata==2025.2
|
| 87 |
+
urllib3==1.26.13
|
| 88 |
+
wandb==0.17.6
|
| 89 |
+
wcwidth==0.2.13
|
| 90 |
+
werkzeug==3.1.3
|
| 91 |
+
wrapt==1.17.2
|
| 92 |
+
xformer==1.0.1
|
| 93 |
+
xformers==0.0.23
|
| 94 |
+
xxhash==3.5.0
|
| 95 |
+
yarl==1.20.0
|
| 96 |
+
zipp==3.21.0
|
| 97 |
+
|
back/evaluations/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluations
|
| 2 |
+
|
| 3 |
+
To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
|
| 4 |
+
|
| 5 |
+
# Download batches
|
| 6 |
+
|
| 7 |
+
We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
|
| 8 |
+
|
| 9 |
+
Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
|
| 10 |
+
|
| 11 |
+
Here are links to download all of the sample and reference batches:
|
| 12 |
+
|
| 13 |
+
* LSUN
|
| 14 |
+
* LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
|
| 15 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
|
| 16 |
+
* [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
|
| 17 |
+
* [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
|
| 18 |
+
* [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
|
| 19 |
+
* LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
|
| 20 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
|
| 21 |
+
* [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
|
| 22 |
+
* LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
|
| 23 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
|
| 24 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
|
| 25 |
+
|
| 26 |
+
* ImageNet
|
| 27 |
+
* ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
|
| 28 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
|
| 29 |
+
* [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
|
| 30 |
+
* [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
|
| 31 |
+
* ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
|
| 32 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
|
| 33 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
|
| 34 |
+
* [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
|
| 35 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
|
| 36 |
+
* ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
|
| 37 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
|
| 38 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
|
| 39 |
+
* [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
|
| 40 |
+
* [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
|
| 41 |
+
* [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
|
| 42 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
|
| 43 |
+
* ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
|
| 44 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
|
| 45 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
|
| 46 |
+
* [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
|
| 47 |
+
* [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
|
| 48 |
+
* [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
|
| 49 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
|
| 50 |
+
|
| 51 |
+
# Run evaluations
|
| 52 |
+
|
| 53 |
+
First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
|
| 54 |
+
|
| 55 |
+
Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
|
| 56 |
+
|
| 57 |
+
The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
$ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
|
| 61 |
+
...
|
| 62 |
+
computing reference batch activations...
|
| 63 |
+
computing/reading reference batch statistics...
|
| 64 |
+
computing sample batch activations...
|
| 65 |
+
computing/reading sample batch statistics...
|
| 66 |
+
Computing evaluations...
|
| 67 |
+
Inception Score: 215.8370361328125
|
| 68 |
+
FID: 3.9425574129223264
|
| 69 |
+
sFID: 6.140433703346162
|
| 70 |
+
Precision: 0.8265
|
| 71 |
+
Recall: 0.5309
|
| 72 |
+
```
|
back/evaluations/evaluator.py
ADDED
|
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import zipfile
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import partial
|
| 10 |
+
from multiprocessing import cpu_count
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from typing import Iterable, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import tensorflow.compat.v1 as tf
|
| 17 |
+
from scipy import linalg
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 21 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 22 |
+
|
| 23 |
+
FID_POOL_NAME = "pool_3:0"
|
| 24 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Evaluate a sample batch against a reference batch.

    Computes Inception Score, FID, sFID, Precision, and Recall between two
    ``.npz`` image batches, prints the metrics, and writes them to a text file
    under ``--save_path`` whose name encodes the sampling hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", help="path to reference batch npz file")
    parser.add_argument("--sample_batch", help="path to sample batch npz file")
    parser.add_argument("--save_path", help="directory where the metrics text file is written")
    parser.add_argument("--cfg_cond", default=1, type=int)
    parser.add_argument("--step", default=1, type=int)
    parser.add_argument("--cfg", default=1.0, type=float)
    parser.add_argument("--cls_cfg", default=1.0, type=float)
    parser.add_argument("--gh", default=1.0, type=float)
    parser.add_argument("--num_steps", default=250, type=int)
    args = parser.parse_args()

    # makedirs(..., exist_ok=True) is race-free and creates missing parents,
    # unlike the check-then-mkdir pattern it replaces.
    os.makedirs(args.save_path, exist_ok=True)

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    inception_score = evaluator.compute_inception_score(sample_acts[0])
    fid = sample_stats.frechet_distance(ref_stats)
    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])

    print("Inception Score:", inception_score)
    print("FID:", fid)
    print("sFID:", sfid)
    print("Precision:", prec)
    print("Recall:", recall)

    # Build the result path with os.path.join so the file always lands inside
    # save_path; the previous raw string concatenation wrote outside the
    # directory whenever the user omitted a trailing slash.
    suffix = "cfg_cond_true.txt" if args.cfg_cond else "cfg_cond_false.txt"
    file_name = f"{args.num_steps}{args.step}{args.cfg}{args.gh}{args.cls_cfg}{suffix}"
    file_path = os.path.join(args.save_path, file_name)
    with open(file_path, "w") as out_file:
        out_file.write("Inception Score: {}\n".format(inception_score))
        out_file.write("FID: {}\n".format(fid))
        out_file.write("sFID: {}\n".format(sfid))
        out_file.write("Precision: {}\n".format(prec))
        out_file.write("Recall: {}\n".format(recall))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class InvalidFIDException(Exception):
    # Marker exception for FID-computation failures. No raise sites are
    # visible in this chunk of the file.
    pass
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class FIDStatistics:
    """Gaussian summary (mean vector and covariance matrix) of a batch of
    activations, used to compute Frechet distances between batches."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1 = np.atleast_1d(self.mu)
        mu2 = np.atleast_1d(other.mu)
        sigma1 = np.atleast_2d(self.sigma)
        sigma2 = np.atleast_2d(other.sigma)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        mean_diff = mu1 - mu2

        # The matrix square root of sigma1 @ sigma2 can fail numerically when
        # the product is nearly singular; retry with a small diagonal offset.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            warnings.warn(
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a tiny imaginary component; strip it,
        # but fail loudly if it is not negligible.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                raise ValueError("Imaginary component {}".format(np.max(np.abs(covmean.imag))))
            covmean = covmean.real

        return mean_diff.dot(mean_diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Evaluator:
    """Runs Inception-based metric computations (FID/sFID activations,
    Inception Score softmax, precision/recall) through a shared TF1 session.

    Graph construction happens once in ``__init__``; all later methods only
    feed the placeholders created there, so construction order matters.
    """

    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        # Build the feature and softmax subgraphs inside the session's graph so
        # sess.run can resolve the tensors created below.
        with self.sess.graph.as_default():
            # NHWC image batch; spatial dims are dynamic.
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            # pool_3 activations (2048-d) fed back in for the softmax head.
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        # Force TF graph initialization (and its verbose logging) up front
        # with a tiny dummy batch.
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Stream images from an .npz file and return their (pool_3, spatial)
        activation arrays."""
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            # Flatten per-image features to one row each.
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """Return (pool_3, spatial) FID statistics for a batch, preferring
        precomputed mu/sigma arrays stored in the .npz when present."""
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            # Reference batches ship precomputed statistics: mu/sigma for
            # pool_3 and mu_s/sigma_s for the spatial features.
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        """Fit a Gaussian (mean, covariance) to a set of activation rows."""
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """Compute the Inception Score from pool_3 activations, averaging
        exp(KL) over splits of ``split_size`` samples."""
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            # KL(p(y|x) || p(y)) per sample, then mean over the split.
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """Estimate precision/recall between reference and sample activation
        manifolds (improved precision-and-recall metric)."""
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        # pr holds per-neighborhood-size arrays; index 0 is the configured k.
        return (float(pr[0][0]), float(pr[1][0]))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class ManifoldEstimator:
|
| 232 |
+
"""
|
| 233 |
+
A helper for comparing manifolds of feature vectors.
|
| 234 |
+
|
| 235 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
session,
|
| 241 |
+
row_batch_size=10000,
|
| 242 |
+
col_batch_size=10000,
|
| 243 |
+
nhood_sizes=(3,),
|
| 244 |
+
clamp_to_percentile=None,
|
| 245 |
+
eps=1e-5,
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Estimate the manifold of given feature vectors.
|
| 249 |
+
|
| 250 |
+
:param session: the TensorFlow session.
|
| 251 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 252 |
+
(parameter to trade-off between memory usage and performance).
|
| 253 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 254 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 255 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 256 |
+
the given percentile.
|
| 257 |
+
:param eps: small number for numerical stability.
|
| 258 |
+
"""
|
| 259 |
+
self.distance_block = DistanceBlock(session)
|
| 260 |
+
self.row_batch_size = row_batch_size
|
| 261 |
+
self.col_batch_size = col_batch_size
|
| 262 |
+
self.nhood_sizes = nhood_sizes
|
| 263 |
+
self.num_nhoods = len(nhood_sizes)
|
| 264 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 265 |
+
self.eps = eps
|
| 266 |
+
|
| 267 |
+
def warmup(self):
|
| 268 |
+
feats, radii = (
|
| 269 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 270 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 271 |
+
)
|
| 272 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 273 |
+
|
| 274 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 275 |
+
num_images = len(features)
|
| 276 |
+
|
| 277 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 278 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 279 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 280 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 281 |
+
|
| 282 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 283 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 284 |
+
row_batch = features[begin1:end1]
|
| 285 |
+
|
| 286 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 287 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 288 |
+
col_batch = features[begin2:end2]
|
| 289 |
+
|
| 290 |
+
# Compute distances between batches.
|
| 291 |
+
distance_batch[
|
| 292 |
+
0 : end1 - begin1, begin2:end2
|
| 293 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 294 |
+
|
| 295 |
+
# Find the k-nearest neighbor from the current batch.
|
| 296 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 297 |
+
[
|
| 298 |
+
x[:, self.nhood_sizes]
|
| 299 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 300 |
+
],
|
| 301 |
+
axis=0,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if self.clamp_to_percentile is not None:
|
| 305 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 306 |
+
radii[radii > max_distances] = 0
|
| 307 |
+
return radii
|
| 308 |
+
|
| 309 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
    """
    Evaluate if new feature vectors are at the manifold.

    :param features: [N_ref x D] reference feature vectors.
    :param radii: [N_ref x num_nhoods] hypersphere radii for the reference set.
    :param eval_features: [N_eval x D] feature vectors to test.
    :return: dict with per-sample membership predictions, the mean membership
        fraction, per-sample max realism scores, and nearest reference indices.
    """
    num_eval_images = eval_features.shape[0]
    num_ref_images = radii.shape[0]
    # Scratch distance buffer reused across row batches.
    distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
    batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
    max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
    nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

    for begin1 in range(0, num_eval_images, self.row_batch_size):
        end1 = min(begin1 + self.row_batch_size, num_eval_images)
        feature_batch = eval_features[begin1:end1]

        for begin2 in range(0, num_ref_images, self.col_batch_size):
            end2 = min(begin2 + self.col_batch_size, num_ref_images)
            ref_batch = features[begin2:end2]

            distance_batch[
                0 : end1 - begin1, begin2:end2
            ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

        # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
        # If a feature vector is inside a hypersphere of some reference sample, then
        # the new sample lies at the estimated manifold.
        # The radii of the hyperspheres are determined from distances of neighborhood size k.
        samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
        batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

        # Realism score: smallest-neighbourhood radius over distance, maximised
        # across references; eps guards against division by zero.
        max_realism_score[begin1:end1] = np.max(
            radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
        )
        nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

    return {
        "fraction": float(np.mean(batch_predictions)),
        "batch_predictions": batch_predictions,
        # NOTE: the misspelled key "max_realisim_score" is preserved as-is;
        # external callers may already depend on this exact spelling.
        "max_realisim_score": max_realism_score,
        "nearest_indices": nearest_indices,
    }
|
| 350 |
+
|
| 351 |
+
def evaluate_pr(
    self,
    features_1: np.ndarray,
    radii_1: np.ndarray,
    features_2: np.ndarray,
    radii_2: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate precision and recall efficiently.

    :param features_1: [N1 x D] feature vectors for reference batch.
    :param radii_1: [N1 x K1] radii for reference vectors.
    :param features_2: [N2 x D] feature vectors for the other batch.
    :param radii_2: [N x K2] radii for other vectors.
    :return: a tuple of arrays for (precision, recall):
             - precision: an np.ndarray of length K1
             - recall: an np.ndarray of length K2
    """
    # status[i, k] becomes True once sample i falls inside any hypersphere of
    # the *other* batch at neighbourhood size k.
    features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
    features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
    for begin_1 in range(0, len(features_1), self.row_batch_size):
        end_1 = begin_1 + self.row_batch_size
        batch_1 = features_1[begin_1:end_1]
        for begin_2 in range(0, len(features_2), self.col_batch_size):
            end_2 = begin_2 + self.col_batch_size
            batch_2 = features_2[begin_2:end_2]
            # Membership tests for both directions in a single backend call.
            batch_1_in, batch_2_in = self.distance_block.less_thans(
                batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
            )
            # OR-accumulate: membership in any column batch suffices.
            features_1_status[begin_1:end_1] |= batch_1_in
            features_2_status[begin_2:end_2] |= batch_2_in
    # Precision: fraction of batch-2 samples inside the batch-1 manifold;
    # recall: fraction of batch-1 samples inside the batch-2 manifold.
    return (
        np.mean(features_2_status.astype(np.float64), axis=0),
        np.mean(features_1_status.astype(np.float64), axis=0),
    )
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    # NOTE(review): this uses TF1-style graph APIs (tf.placeholder, session.run).
    # The accompanying requirements list tensorflow>=2.0, so `tf` is presumably
    # tf.compat.v1 with eager execution disabled — confirm at the import site.

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Try fast fp16 distances first; fall back to fp32 if any value
            # overflowed/became non-finite in half precision.
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            # batch_1_in[i, k]: row i of batch 1 is within some batch-2 radius at size k.
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            # batch_2_in[j, k]: column j of batch 2 is within some batch-1 radius at size k.
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        # Run both membership tests (see __init__) in a single session call.
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.

    Returns D with D[i, j] = ||U[i] - V[j]||^2 (squared Euclidean), computed
    via the expansion ||u||^2 - 2 u.v + ||v||^2.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances; clamp at 0 to absorb tiny
        # negative values caused by floating-point cancellation.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

        return D
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class NpzArrayReader(ABC):
    """Abstract reader yielding successive row batches of an array stored in an .npz file."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        # Return the next batch (possibly shorter than batch_size), or None when exhausted.
        pass

    @abstractmethod
    def remaining(self) -> int:
        # Number of rows not yet consumed.
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Return a length-aware iterable over all remaining batches."""

        def gen_fn():
            while True:
                batch = self.read_batch(batch_size)
                if batch is None:
                    break
                yield batch

        rem = self.remaining()
        # Ceiling division: a trailing partial batch counts as one batch.
        num_batches = rem // batch_size + int(rem % batch_size != 0)
        return BatchIterator(gen_fn, num_batches)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class BatchIterator:
    """A sized iterable backed by a zero-argument generator factory.

    Wrapping the factory (rather than a generator instance) means ``len()``
    works and each ``iter()`` call yields a fresh iterator.
    """

    def __init__(self, gen_fn, length):
        # gen_fn: callable returning a new iterator on each invocation.
        # length: the number of items the iterator is expected to produce.
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # Delegate to the factory so iteration is restartable.
        return self.gen_fn()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class StreamingNpzArrayReader(NpzArrayReader):
    """Stream rows of an uncompressed .npy member directly from an open file handle."""

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f  # file-like object positioned at the start of the raw array data
        self.shape = shape  # full array shape; shape[0] is the total row count
        self.dtype = dtype  # numpy dtype of the stored array
        self.idx = 0  # number of rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        if self.dtype.itemsize == 0:
            # Zero-byte dtype: there is nothing to read from the file.
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        # _read_bytes raises ValueError on a premature EOF.
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class MemoryNpzArrayReader(NpzArrayReader):
    """Batch reader over an array that has been fully loaded into memory."""

    def __init__(self, arr):
        self.arr = arr  # the whole array
        self.idx = 0  # number of rows consumed so far

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load array `arr_name` from the .npz file at `path` into memory."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        # Slicing past the end is safe; the final batch may be short.
        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """
    Context manager yielding a batch reader for array `arr_name` in the .npz at `path`.

    Streams from disk when the member has a known header version and is a
    C-order, non-object array; otherwise falls back to loading it fully into memory.
    """
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        else:
            # Unknown .npy format version: let numpy parse it in memory.
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            # Fortran order / object dtypes cannot be streamed row-by-row.
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 561 |
+
"""
|
| 562 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 563 |
+
|
| 564 |
+
Read from file-like object until size bytes are read.
|
| 565 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 566 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 567 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 568 |
+
requested.
|
| 569 |
+
"""
|
| 570 |
+
data = bytes()
|
| 571 |
+
while True:
|
| 572 |
+
# io files (default in python3) return None or raise on
|
| 573 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 574 |
+
# done about that. note that regular files can't be non-blocking
|
| 575 |
+
try:
|
| 576 |
+
r = fp.read(size - len(data))
|
| 577 |
+
data += r
|
| 578 |
+
if len(r) == 0 or len(data) == size:
|
| 579 |
+
break
|
| 580 |
+
except io.BlockingIOError:
|
| 581 |
+
pass
|
| 582 |
+
if len(data) != size:
|
| 583 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 584 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 585 |
+
else:
|
| 586 |
+
return data
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@contextmanager
def _open_npy_file(path: str, arr_name: str):
    """
    Context manager yielding an open file handle positioned at the start of
    the `{arr_name}.npy` member inside the .npz (zip) file at `path`.

    Raises ValueError if the member is absent.
    """
    with open(path, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_f:
            if f"{arr_name}.npy" not in zip_f.namelist():
                raise ValueError(f"missing {arr_name} in npz file")
            with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
                yield arr_f
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def _download_inception_model():
    """Download the InceptionV3 model file to INCEPTION_V3_PATH if it is not already present."""
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    with requests.get(INCEPTION_V3_URL, stream=True) as r:
        r.raise_for_status()
        # Write to a temp file and rename at the end so an interrupted
        # download never leaves a truncated file at the final path.
        tmp_path = INCEPTION_V3_PATH + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                f.write(chunk)
        os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def _create_feature_graph(input_batch):
    """Import the frozen InceptionV3 graph and return its (pool3, spatial) feature tensors wired to `input_batch`."""
    _download_inception_model()
    # Random name prefix so repeated imports into the same TF graph don't collide.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={f"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    # Relax the imported batch dimension to None (see _update_shapes).
    _update_shapes(pool3)
    # Keep only the first 7 channels of the spatial features.
    spatial = spatial[..., :7]
    return pool3, spatial
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def _create_softmax_graph(input_batch):
    """Build a softmax head over `input_batch` using the classifier weights from the frozen InceptionV3 graph."""
    _download_inception_model()
    # Random name prefix so repeated imports into the same TF graph don't collide.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
    )
    # Reuse only the classifier weight matrix; apply it to our own input.
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def _update_shapes(pool3):
    """Rewrite static shapes in pool3's graph so leading batch dims of 1 become None (variable batch size)."""
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        # Leading dim of 1 was baked in at export time; make it dynamic.
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # HACK: overwrite the cached private shape attribute directly;
                # there is no public API to change an imported tensor's shape.
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 663 |
+
num_workers = min(cpu_count(), len(arr))
|
| 664 |
+
chunk_size = len(arr) // num_workers
|
| 665 |
+
extra = len(arr) % num_workers
|
| 666 |
+
|
| 667 |
+
start_idx = 0
|
| 668 |
+
batches = []
|
| 669 |
+
for i in range(num_workers):
|
| 670 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 671 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 672 |
+
start_idx += size
|
| 673 |
+
|
| 674 |
+
with ThreadPool(num_workers) as pool:
|
| 675 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
back/evaluations/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow-gpu>=2.0
|
| 2 |
+
scipy
|
| 3 |
+
requests
|
| 4 |
+
tqdm
|
back/models/__pycache__/mocov3_vit.cpython-310.pyc
ADDED
|
Binary file (6.5 kB). View file
|
|
|
back/models/__pycache__/mocov3_vit.cpython-312.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
back/models/__pycache__/sit.cpython-310.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
back/models/__pycache__/sit.cpython-312.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
back/models/clip_vit.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
import clip
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Bottleneck(nn.Module):
    """CLIP's ResNet bottleneck block: 1x1 -> 3x3 -> (avgpool) -> 1x1, with a residual connection."""

    # Channel expansion factor of the final 1x1 convolution.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        # Match the identity path's resolution/channels when needed.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map to a single vector via multi-head attention,
    using the spatial mean as the query token."""

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One positional embedding per spatial location plus one for the mean/query token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        # Prepend the spatial mean; it serves as the (single) attention query below.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Functional MHA with separate q/k/v projection weights; only the
        # mean token (x[:1]) attends, so the output has a single "sequence" slot.
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # Drop the length-1 sequence dim: result is (N, output_dim).
        return x.squeeze(0)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        """
        :param layers: number of Bottleneck blocks in each of the 4 stages.
        :param output_dim: dimensionality of the pooled output embedding.
        :param heads: number of attention heads in the final attention pool.
        :param input_resolution: input image side length (pixels).
        :param width: base channel width of the stem.
        """
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # First block carries the stride; subsequent blocks keep resolution.
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        # Match the input dtype to the model weights (supports fp16 models).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    The normalization itself runs in float32 for numerical stability; the
    result is cast back to the caller's dtype.
    """

    def forward(self, x: torch.Tensor):
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(x.dtype)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class QuickGELU(nn.Module):
    """Fast sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a 4x-wide GELU MLP, each with a residual."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask (e.g. causal mask for text).
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # NOTE: reassigns self.attn_mask on every call to move it to the
        # input's dtype/device; harmless mutation but worth knowing about.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        # Expects (seq, batch, d_model) layout, as nn.MultiheadAttention defaults to.
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Transformer(nn.Module):
    """Stack of `layers` ResidualAttentionBlocks sharing the same width, heads, and attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        # Expects (seq, batch, width) layout; see ResidualAttentionBlock.
        return self.resblocks(x)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class UpdatedVisionTransformer(nn.Module):
    """Wrap a CLIP VisionTransformer to expose per-patch tokens instead of the pooled CLS embedding.

    The wrapped model's ln_post and output projection are deliberately skipped
    (see the commented-out lines), and the CLS token is dropped from the output.
    """

    def __init__(self, model):
        super().__init__()
        # model: a CLIP VisionTransformer providing conv1, class_embedding,
        # positional_embedding, ln_pre, and transformer.
        self.model = model

    def forward(self, x: torch.Tensor):
        x = self.model.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.positional_embedding.to(x.dtype)
        x = self.model.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        # Back to batch-first, dropping the CLS token: keep only patch tokens.
        x = x.permute(1, 0, 2)[:, 1:]  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])

        # if self.proj is not None:
        #     x = x @ self.proj

        return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class CLIP(nn.Module):
    """CLIP dual-encoder: a vision backbone (ResNet or ViT) and a text transformer
    projected into a shared embedding space."""

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        # A tuple/list of layer counts selects the ResNet backbone; an int
        # selects the ViT backbone.
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            # NOTE(review): UpdatedVisionTransformer defined above takes a single
            # `model` argument, yet it is called here with ViT constructor kwargs —
            # as written this branch would raise TypeError. Presumably the original
            # `VisionTransformer` class was intended here; confirm.
            self.visual = UpdatedVisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        # Projects final text features into the shared embedding space.
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, initialised to 1/0.07 (stored as log).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()
|
| 289 |
+
def initialize_parameters(self):
    """Initialise embeddings, attention-pool projections, and transformer weights
    following CLIP's original scheme."""
    nn.init.normal_(self.token_embedding.weight, std=0.02)
    nn.init.normal_(self.positional_embedding, std=0.01)

    if isinstance(self.visual, ModifiedResNet):
        if self.visual.attnpool is not None:
            # Scaled normal init for the attention-pool q/k/v/out projections.
            std = self.visual.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

        # Zero-init each bottleneck's final BN scale so residual branches
        # start as identity.
        for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    # Depth-scaled init for transformer projections.
    proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
    attn_std = self.transformer.width ** -0.5
    fc_std = (2 * self.transformer.width) ** -0.5
    for block in self.transformer.resblocks:
        nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

    if self.text_projection is not None:
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 318 |
+
def build_attention_mask(self):
    """Build the additive causal attention mask for the text transformer.

    Returns a (context_length, context_length) tensor whose entries strictly
    above the diagonal are -inf and all other entries are 0, so each position
    can attend only to itself and earlier positions (PyTorch adds this mask
    to the attention logits).
    """
    size = self.context_length
    causal = torch.full((size, size), float("-inf"))
    # triu_(1) zeroes the diagonal and everything below it, in place.
    return causal.triu_(1)
|
| 325 |
+
|
| 326 |
+
@property
def dtype(self):
    # Dtype of the visual stem's first conv weight; used to cast inputs so
    # they match the (possibly fp16-converted) model parameters.
    return self.visual.conv1.weight.dtype
|
| 329 |
+
|
| 330 |
+
def encode_image(self, image):
    """Run the visual tower on `image`, cast to the model's parameter dtype."""
    return self.visual(image.type(self.dtype))
|
| 332 |
+
|
| 333 |
+
def encode_text(self, text):
    """Encode a batch of token-id sequences into the joint embedding space.

    Features are read from the EOT position (the highest token id in each
    sequence) and projected through `text_projection`.
    """
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

    x = x + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x
|
| 347 |
+
|
| 348 |
+
def forward(self, image, text):
    """Return (logits_per_image, logits_per_text) cosine-similarity logits."""
    image_features = self.encode_image(image)
    text_features = self.encode_text(text)

    # normalized features (unit L2 norm per sample)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # cosine similarity as logits, scaled by the learned temperature
    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # shape = [global_batch_size, global_batch_size]
    return logits_per_image, logits_per_text
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Halves the weights/biases of conv and linear layers, the projection
    tensors of ``nn.MultiheadAttention`` modules, and any ``text_projection``
    or ``proj`` tensor attribute found on a submodule.
    """

    def _to_half(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attention_tensors = [
                "in_proj_weight", "q_proj_weight", "k_proj_weight",
                "v_proj_weight", "in_proj_bias", "bias_k", "bias_v",
            ]
            for attr_name in attention_tensors:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific projection parameters stored as raw tensors.
        for attr_name in ("text_projection", "proj"):
            projection = getattr(module, attr_name, None)
            if projection is not None:
                projection.data = projection.data.half()

    model.apply(_to_half)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def build_model(state_dict: dict):
    """Instantiate a CLIP model from a checkpoint state dict.

    Infers the architecture (ViT vs. modified-ResNet visual tower, plus all
    transformer dimensions) from tensor shapes and key names, converts the
    model to fp16, loads the weights, and returns it in eval mode.

    Note: the bookkeeping keys are deleted from the caller's `state_dict`.
    """
    vit = "visual.proj" in state_dict  # ViT checkpoints carry a visual projection

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding holds 1 cls token + grid**2 patch tokens
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # ResNet: count residual blocks in each of the four stages
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32  # ResNet tower downsamples 32x

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # 64 dims per attention head
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # Drop bookkeeping entries that are not parameters before loading.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
|
back/models/jepa.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from functools import partial
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 16 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 17 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 18 |
+
def norm_cdf(x):
|
| 19 |
+
# Computes standard normal cumulative distribution function
|
| 20 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 21 |
+
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
# Values are generated by using a truncated uniform distribution and
|
| 24 |
+
# then using the inverse CDF for the normal distribution.
|
| 25 |
+
# Get upper and lower cdf values
|
| 26 |
+
l = norm_cdf((a - mean) / std)
|
| 27 |
+
u = norm_cdf((b - mean) / std)
|
| 28 |
+
|
| 29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 30 |
+
# [2l-1, 2u-1].
|
| 31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 32 |
+
|
| 33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 34 |
+
# standard normal
|
| 35 |
+
tensor.erfinv_()
|
| 36 |
+
|
| 37 |
+
# Transform to proper mean, std
|
| 38 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 39 |
+
tensor.add_(mean)
|
| 40 |
+
|
| 41 |
+
# Clamp to ensure it's in the proper range
|
| 42 |
+
tensor.clamp_(min=a, max=b)
|
| 43 |
+
return tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Public wrapper: fill `tensor` in-place from N(mean, std) truncated to [a, b]."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def repeat_interleave_batch(x, B, repeat):
    """Repeat each consecutive chunk of B rows of `x` `repeat` times.

    `x` is treated as N = len(x)//B stacked chunks; the result keeps chunk
    order but duplicates every chunk `repeat` times along dim 0.
    """
    num_chunks = len(x) // B
    repeated = []
    for chunk_idx in range(num_chunks):
        chunk = x[chunk_idx * B:(chunk_idx + 1) * B]
        repeated.append(torch.cat([chunk] * repeat, dim=0))
    return torch.cat(repeated, dim=0)
|
| 57 |
+
|
| 58 |
+
def apply_masks(x, masks):
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    :return: gathered patches for each mask, concatenated along the batch dim
    """
    gathered = []
    for mask in masks:
        # Expand the [B, K] index tensor across the feature dim for gather.
        index = mask.unsqueeze(-1).repeat(1, 1, x.size(-1))
        gathered.append(torch.gather(x, dim=1, index=index))
    return torch.cat(gathered, dim=0)
|
| 68 |
+
|
| 69 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=float)
    # meshgrid with w first, matching the reference implementation.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        # Prepend an all-zero row for the [CLS] position.
        embedding = np.concatenate([np.zeros([1, embed_dim]), embedding], axis=0)
    return embedding
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2D grid: half the dims encode grid[0], half encode grid[1]."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid length
    return:
        pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    positions = np.arange(grid_size, dtype=float)
    embedding = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return embedding
    # Prepend an all-zero row for the [CLS] position.
    return np.concatenate([np.zeros([1, embed_dim]), embedding], axis=0)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Frequencies 1 / 10000^(i / (D/2)) for i in [0, D/2).
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=float) / half)  # (D/2,)

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2) outer product

    # First half sines, second half cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero whole samples, scaling survivors by 1/keep_prob.

    A no-op (returns `x` unchanged) when drop_prob == 0 or not training.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample; singleton dims broadcast over the rest.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    keep_mask.floor_()  # binarize to {0, 1}
    return x.div(keep_prob) * keep_mask
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping a sample's residual path during training.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the functional form; active only while self.training.
        return drop_path(x, self.drop_prob, self.training)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Fall back to in_features when hidden/output widths are not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Attention(nn.Module):
    """Multi-head self-attention that returns both the output and the attention map."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default logit scaling is 1/sqrt(head_dim); qk_scale overrides it.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused q/k/v projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """x: (B, N, C) -> (output (B, N, C), attention weights (B, heads, N, N))."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum over values, then merge heads back into C.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention + MLP, each with residual and drop-path."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Identity when drop_path == 0 avoids the module overhead entirely.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        # When return_attention is set, only the attention map is returned and
        # the residual/MLP path is skipped entirely.
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Projects non-overlapping patches to embed_dim with a strided convolution
    (kernel == stride == patch_size) and flattens them into a token sequence.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        grid = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid * grid

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models.

    Builds a stack of conv(3x3) [+ BatchNorm] + ReLU stages followed by a
    final 1x1 projection, then flattens the feature map into a patch-token
    sequence.

    Args:
        channels: per-stage output channel counts; the last entry is the
            embedding dim produced by the final 1x1 conv.
        strides: per-stage strides; strides[-1] is used by the 1x1 conv.
        img_size: input resolution, either an int or an indexable (H, W)
            pair; a square input is assumed (only the first value is used).
        in_chans: number of input image channels.
        batch_norm: insert BatchNorm2d after each 3x3 conv (conv bias off).
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Accept both a plain int and an (H, W) sequence for img_size.
        # (Previously `img_size[0]` raised TypeError for the int default.)
        side = img_size[0] if isinstance(img_size, (tuple, list)) else img_size

        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i + 1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i + 1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patch tokens after all stem strides.
        stride_prod = int(np.prod(strides))
        self.num_patches = (side // stride_prod) ** 2

    def forward(self, x):
        # (B, C, H, W) -> (B, num_patches, channels[-1])
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class VisionTransformerPredictor(nn.Module):
    """ Vision Transformer

    I-JEPA predictor: projects visible (context) tokens into a narrower
    predictor space, appends positional mask tokens for the target patch
    locations, runs transformer blocks over the combined sequence, and
    projects the predictions for the target tokens back to encoder dim.
    """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        # Encoder-dim -> predictor-dim input projection.
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        # Learnable token standing in for each masked (to-be-predicted) patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # --
        # Fixed (non-trainable) 2D sin-cos positional embedding; assumes a
        # square patch grid (num_patches must be a perfect square).
        self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
                                                requires_grad=False)
        predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
                                                      int(num_patches**.5),
                                                      cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
        # --
        self.predictor_blocks = nn.ModuleList([
            Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        # Predictor-dim -> encoder-dim output projection.
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
        # ------
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Rescale each block's residual output projections by 1/sqrt(2*layer_id)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        # Truncated-normal init for linear/conv weights; unit gain / zero bias
        # for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        """Predict target-patch representations.

        :param x: context tokens produced by the encoder
        :param masks_x: index tensor(s) of the context (visible) patches
        :param masks: index tensor(s) of the target patches to predict
        :return: predictions for the target tokens, back in encoder dim
        """
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size (x is stacked along dim 0, one slab per context mask)
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to pedictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens (gathered at context indices)
        x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, N_ctxt, D = x.shape

        # -- concat mask tokens to x: positional embeddings at the target
        #    indices, repeated to align with each context mask.
        pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
        # --
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        # --
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens (context positions are discarded)
        x = x[:, N_ctxt:]
        x = self.predictor_proj(x)

        return x
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class VisionTransformer(nn.Module):
    """ Vision Transformer

    I-JEPA-style encoder: patchify, add fixed 2D sin-cos positional
    embeddings, optionally keep only the patch indices given in `masks`,
    then run the transformer blocks. No cls token is used.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        # -- only img_size[0] is used (square input assumed)
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # -- fixed (non-trainable) 2D sin-cos positional embedding, no cls slot
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # --
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ------
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Rescale each block's residual output projections by 1/sqrt(2*layer_id)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        # Truncated-normal init for linear/conv weights; unit gain / zero bias
        # for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks=None):
        """Encode image batch `x`; if `masks` (index tensor or list of them)
        is given, keep only those patch tokens before the transformer."""
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)
        B, N, D = x.shape

        # -- add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- fwd prop
        for i, blk in enumerate(self.blocks):
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x

    def interpolate_pos_encoding(self, x, pos_embed):
        # Bicubic-resize the positional grid when the token count differs
        # from the grid the embedding was built for.
        # NOTE(review): this treats pos_embed[:, 0] as a class-token slot,
        # but self.pos_embed is built with cls_token=False and x carries no
        # cls token, so npatch/N are off by one on the resize path — confirm
        # whether this branch is ever exercised with mismatched sizes.
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def vit_predictor(**kwargs):
    """Build a VisionTransformerPredictor with the standard LayerNorm(eps=1e-6) config."""
    model = VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return model
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def vit_tiny(patch_size=16, **kwargs):
    """ViT-Tiny encoder (embed_dim 192, depth 12, 3 heads)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def vit_small(patch_size=16, **kwargs):
    """ViT-Small encoder (embed_dim 384, depth 12, 6 heads)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def vit_base(patch_size=16, **kwargs):
    """ViT-Base encoder (embed_dim 768, depth 12, 12 heads)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def vit_large(patch_size=16, **kwargs):
    """ViT-Large encoder (embed_dim 1024, depth 24, 16 heads)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def vit_huge(patch_size=16, **kwargs):
    """ViT-Huge encoder (embed_dim 1280, depth 32, 16 heads)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def vit_giant(patch_size=16, **kwargs):
    """ViT-Giant encoder (embed_dim 1408, depth 40, 16 heads, mlp_ratio 48/11)."""
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# Embedding dimension for each named ViT variant defined above.
VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}
|
back/models/mae_vit.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 9 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
import timm.models.vision_transformer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            # Replace the pre-head norm with a norm applied after pooling.
            # NOTE(review): fc_norm is created here but is not used by the
            # forward_features override below — it would only matter for a
            # head/forward path not shown in this file; confirm usage.
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        # Returns per-patch tokens with the cls token stripped; no final
        # norm is applied on this path.
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = x[:, 1:, :]  #.mean(dim=1) # global pool without cls token

        return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def vit_base_patch16(**kwargs):
    """MAE ViT-Base/16 encoder (768-d, 12 blocks, 12 heads), no classifier head."""
    return VisionTransformer(
        num_classes=0, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_large_patch16(**kwargs):
    """MAE ViT-Large/16 encoder (1024-d, 24 blocks, 16 heads), no classifier head."""
    return VisionTransformer(
        num_classes=0, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_huge_patch14(**kwargs):
    """MAE ViT-Huge/14 encoder (1280-d, 32 blocks, 16 heads).

    NOTE(review): unlike the base/large factories this one does not pass
    num_classes=0 -- confirm whether that asymmetry is intentional.
    """
    return VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
back/models/mocov3_vit.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from functools import partial, reduce
|
| 11 |
+
from operator import mul
|
| 12 |
+
|
| 13 |
+
from timm.layers.helpers import to_2tuple
|
| 14 |
+
from timm.models.vision_transformer import VisionTransformer, _cfg
|
| 15 |
+
from timm.models.vision_transformer import PatchEmbed
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'vit_small',
|
| 19 |
+
'vit_base',
|
| 20 |
+
'vit_large',
|
| 21 |
+
'vit_conv_small',
|
| 22 |
+
'vit_conv_base',
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def patchify_avg(input_tensor, patch_size):
    """Average-pool non-overlapping square patches of an image batch.

    Args:
        input_tensor: 4D tensor of shape (batch_size, channels, height, width).
        patch_size: side length of each square patch; must divide both
            height and width.

    Returns:
        Tensor of shape (batch_size, num_patches, channels) where each row is
        the per-channel mean of one patch, ordered row-major over the grid.

    Raises:
        ValueError: if the input is not 4D or its spatial dims are not
            divisible by ``patch_size``.
    """
    if input_tensor.dim() != 4:
        raise ValueError("Input tensor must be 4D (batch_size, channels, height, width)")

    batch, chans, height, width = input_tensor.shape

    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError("Input tensor dimensions must be divisible by patch_size")

    grid_h = height // patch_size
    grid_w = width // patch_size

    # Expose each patch as its own pair of axes, then average within the patch.
    blocks = input_tensor.reshape(batch, chans, grid_h, patch_size, grid_w, patch_size)
    pooled = blocks.mean(dim=(3, 5))  # (batch, chans, grid_h, grid_w)

    # Flatten the patch grid row-major and move channels last.
    return pooled.reshape(batch, chans, grid_h * grid_w).permute(0, 2, 1).contiguous()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class VisionTransformerMoCo(VisionTransformer):
    """MoCo-v3 ViT: a timm VisionTransformer with fixed 2D sin-cos positional
    embeddings and MoCo-v3-style weight initialization.

    Args:
        stop_grad_conv1: if True, freeze the patchify projection (conv1), a
            MoCo v3 trick for training stability.
    """
    def __init__(self, stop_grad_conv1=False, **kwargs):
        super().__init__(**kwargs)
        # Use fixed 2D sin-cos position embedding
        self.build_2d_sincos_position_embedding()

        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately:
                    # qkv is one fused Linear, so shape[0] // 3 is the fan-out
                    # of a single Q/K/V projection.
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        nn.init.normal_(self.cls_token, std=1e-6)

        if isinstance(self.patch_embed, PatchEmbed):
            # xavier_uniform initialization of the patchify conv, treating it
            # as a Linear with fan-in 3 * patch_h * patch_w.
            val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
            nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
            nn.init.zeros_(self.patch_embed.proj.bias)

            if stop_grad_conv1:
                self.patch_embed.proj.weight.requires_grad = False
                self.patch_embed.proj.bias.requires_grad = False

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        """Replace the learned pos_embed with a fixed 2D sin-cos embedding.

        Layout per position: [sin(w), cos(w), sin(h), cos(h)], each of width
        embed_dim // 4; one all-zero row is prepended for the cls token.
        """
        h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
        w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature**omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        self.pos_embed.requires_grad = False

    def forward_diffusion_output(self, x):
        """Run the transformer blocks on a spatial feature map.

        Flattens (N, C, H, W) -> (N, H*W, C) token layout, then applies the
        standard timm pipeline (pos-embed, patch drop, pre-norm, blocks, norm).
        Presumably C == embed_dim here -- TODO confirm against callers.
        """
        x = x.reshape(*x.shape[0:2], -1).permute(0, 2, 1)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x
|
| 107 |
+
|
| 108 |
+
class ConvStem(nn.Module):
    """
    ConvStem, from Early Convolutions Help Transformers See Better, Tete et al.
    https://arxiv.org/abs/2106.14881

    Four stride-2 3x3 convolutions (16x total downsampling, channels doubling
    each stage) followed by a 1x1 projection to ``embed_dim``. Drop-in
    replacement for a patch-embedding layer with ``patch_size == 16``.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()

        assert patch_size == 16, 'ConvStem only supports patch size of 16'
        assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (
            self.img_size[0] // self.patch_size[0],
            self.img_size[1] // self.patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        # build stem, similar to the design in https://arxiv.org/abs/2106.14881
        # NOTE(review): the first conv hard-codes 3 input channels; the
        # in_chans argument is not used -- confirm that is intentional.
        layers = []
        ch_in, ch_out = 3, embed_dim // 8
        for _ in range(4):
            layers.extend([
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
            ch_in, ch_out = ch_out, ch_out * 2
        layers.append(nn.Conv2d(ch_in, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*layers)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        out = self.proj(x)
        if self.flatten:
            out = out.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(out)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def vit_small(**kwargs):
    """MoCo-v3 ViT-Small/16 at 256px input (384-d, 12 blocks, 12 heads)."""
    model = VisionTransformerMoCo(
        img_size=256, patch_size=16, embed_dim=384, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_base(**kwargs):
    """MoCo-v3 ViT-Base/16 at 256px input (768-d, 12 blocks, 12 heads)."""
    model = VisionTransformerMoCo(
        img_size=256, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_large(**kwargs):
    """MoCo-v3 ViT-Large/16 at 256px input (1024-d, 24 blocks, 16 heads)."""
    model = VisionTransformerMoCo(
        img_size=256, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_conv_small(**kwargs):
    """MoCo-v3 ViT-Small with a convolutional stem; one fewer block (depth 11)
    to compensate for the stem's extra compute."""
    model = VisionTransformerMoCo(
        patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        embed_layer=ConvStem, **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_conv_base(**kwargs):
    """MoCo-v3 ViT-Base with a convolutional stem; one fewer block (depth 11)
    to compensate for the stem's extra compute."""
    model = VisionTransformerMoCo(
        patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        embed_layer=ConvStem, **kwargs)
    model.default_cfg = _cfg()
    return model
|
| 190 |
+
|
| 191 |
+
def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    """Build a SimCLR-style MLP projector/predictor.

    Hidden layers are Linear -> BatchNorm1d -> ReLU; the final layer is a bare
    Linear, optionally followed by an affine-free BatchNorm1d (``last_bn``).
    All Linear layers are bias-free (BN provides the shift).
    """
    layers = []
    for idx in range(num_layers):
        fan_in = input_dim if idx == 0 else mlp_dim
        is_last = idx == num_layers - 1
        fan_out = output_dim if is_last else mlp_dim

        layers.append(nn.Linear(fan_in, fan_out, bias=False))

        if not is_last:
            layers.extend([nn.BatchNorm1d(fan_out), nn.ReLU(inplace=True)])
        elif last_bn:
            # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # for simplicity, gamma is removed (affine=False).
            layers.append(nn.BatchNorm1d(fan_out, affine=False))

    return nn.Sequential(*layers)
|
back/models/sit.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
# References:
|
| 5 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 6 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_mlp(hidden_size, projector_dim, z_dim):
    """Three-layer SiLU MLP projector: hidden_size -> projector_dim -> projector_dim -> z_dim."""
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
|
| 24 |
+
|
| 25 |
+
def modulate(x, shift, scale):
    """AdaLN modulation: per-sample scale and shift broadcast over the token axis.

    Computes shift + x * (1 + scale), with shift/scale of shape (N, D)
    expanded to (N, 1, D) so they apply to every token of x (N, T, D).
    """
    return torch.addcmul(shift.unsqueeze(1), x, scale.unsqueeze(1) + 1)
|
| 27 |
+
|
| 28 |
+
#################################################################################
|
| 29 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 30 |
+
#################################################################################
|
| 31 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Fixed sinusoidal features of width ``frequency_embedding_size`` are lifted
    to ``hidden_size`` by a two-layer SiLU MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def positional_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad with one zero column so the output width is exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    # Backward-compatible alias: the previous implementation exposed the
    # sinusoid builder under this name (it re-bound it on every forward pass).
    timestep_embedding = positional_embedding

    def forward(self, t):
        # Fix: call the static helper directly instead of re-assigning it to an
        # instance attribute on every forward call (needless per-call mutation).
        t_freq = self.positional_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
        return self.mlp(t_freq)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label
    dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # When dropout is enabled, reserve one extra embedding row (index
        # num_classes) as the "null" label used for dropped samples.
        needs_null_row = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + needs_null_row, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        wants_dropout = train and self.dropout_prob > 0
        if wants_dropout or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
#################################################################################
|
| 103 |
+
# Core SiT Model #
|
| 104 |
+
#################################################################################
|
| 105 |
+
|
| 106 |
+
class SiTBlock(nn.Module):
    """
    A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Args:
        hidden_size: token dimension.
        num_heads: attention heads.
        mlp_ratio: MLP hidden width as a multiple of hidden_size.
        **block_kwargs: optional 'qk_norm' (bool, default False) and
            'fused_attn' (bool) forwarded to timm's Attention.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Fix: 'qk_norm' was a hard-required key (KeyError when absent);
        # default to False, which is timm Attention's own default.
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True,
            qk_norm=block_kwargs.get("qk_norm", False)
        )
        if "fused_attn" in block_kwargs:
            self.attn.fused_attn = block_kwargs["fused_attn"]
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
        )
        # Emits (shift, scale, gate) for both sub-blocks; zero-initialized by
        # the model's initialize_weights so each block starts as identity.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        """Apply attention and MLP residual sub-blocks, each modulated and
        gated by the conditioning vector c of shape (N, hidden_size)."""
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))

        return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class FinalLayer(nn.Module):
    """
    The final layer of SiT: adaLN modulation followed by a linear head for the
    patch tokens and, optionally, a separate linear head for a leading cls token.
    """
    def __init__(self, hidden_size, patch_size, out_channels, cls_token_dim):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.linear_cls = nn.Linear(hidden_size, cls_token_dim, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c, cls=None):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        # Inlined adaLN modulation: shift + x * (1 + scale), broadcast over tokens.
        h = self.norm_final(x)
        h = h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        if cls is None:
            return self.linear(h), None

        # First token is the cls token; the remainder are patch tokens.
        return self.linear(h[:, 1:]), self.linear_cls(h[:, 0])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Patchifies a latent image, prepends a projected cls token, runs adaLN-Zero
    transformer blocks conditioned on timestep + class embeddings, and decodes
    back to image space. After ``encoder_depth`` blocks, intermediate tokens
    are projected through MLP heads (REPA-style alignment features ``zs``).
    """
    def __init__(
        self,
        path_type='edm',
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        decoder_hidden_size=768,
        encoder_depth=8,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        use_cfg=False,
        z_dims=[768],  # NOTE: read-only default; never mutated, so safe as a list
        projector_dim=2048,
        cls_token_dim=768,
        **block_kwargs  # fused_attn
    ):
        super().__init__()
        self.path_type = path_type
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.use_cfg = use_cfg
        self.num_classes = num_classes
        self.z_dims = z_dims
        self.encoder_depth = encoder_depth

        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size, bias=True
        )
        self.t_embedder = TimestepEmbedder(hidden_size)  # timestep embedding type
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding (+1 slot for the cls token):
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
        ])
        # Alignment projector heads, one per target feature dimension.
        self.projectors = nn.ModuleList([
            build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims
        ])

        # NOTE(review): the cls_token_dim constructor argument is intentionally
        # overridden by z_dims[0] below -- the parameter is effectively ignored.
        z_dim = self.z_dims[0]
        cls_token_dim = z_dim
        self.final_layer = FinalLayer(decoder_hidden_size, patch_size, self.out_channels, cls_token_dim)

        # Projects the incoming cls token into the transformer width.
        self.cls_projectors2 = nn.Linear(in_features=cls_token_dim, out_features=hidden_size, bias=True)
        self.wg_norm = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linears, fixed sin-cos pos_embed, zero-init adaLN and
        output heads (adaLN-Zero: every block starts as identity)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding; the one
        # extra (zero) row is the cls-token slot.
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), cls_token=1, extra_tokens=1
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in SiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
        nn.init.constant_(self.final_layer.linear_cls.weight, 0)
        nn.init.constant_(self.final_layer.linear_cls.bias, 0)

    def unpatchify(self, x, patch_size=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)

        Assumes a square token grid (T must be a perfect square).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x, t, y, return_logvar=False, cls_token=None):
        """
        Forward pass of SiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        return_logvar: accepted for interface compatibility; not used here.
        cls_token: (N, z_dims[0]) global token; required (raises if None).

        Returns (denoised output, list of projected features zs, updated cls token).
        """
        x = self.x_embedder(x)  # (N, T, D), where T = H * W / patch_size ** 2

        if cls_token is None:
            # Fix: this path previously called exit(), killing the whole
            # interpreter with no diagnostic.
            raise ValueError("SiT.forward requires cls_token; got None")

        cls_token = self.cls_projectors2(cls_token)
        cls_token = self.wg_norm(cls_token)
        cls_token = cls_token.unsqueeze(1)  # [b, 1, d]
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed
        N, T, D = x.shape

        # timestep and class embedding
        t_embed = self.t_embedder(t)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        c = t_embed + y

        zs = []  # robustness: defined even when encoder_depth > depth
        for i, block in enumerate(self.blocks):
            x = block(x, c)
            if (i + 1) == self.encoder_depth:
                # Project intermediate tokens for representation alignment.
                zs = [projector(x.reshape(-1, D)).reshape(N, T, -1) for projector in self.projectors]

        x, cls_token = self.final_layer(x, c, cls=cls_token)
        x = self.unpatchify(x)

        return x, zs, cls_token
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
#################################################################################
|
| 319 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 320 |
+
#################################################################################
|
| 321 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 322 |
+
|
| 323 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
               [extra_tokens+grid_size*grid_size, embed_dim]
               (with extra zero rows prepended when cls_token and extra_tokens > 0)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.meshgrid(axis, axis)  # here w goes first
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        prefix = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([prefix, pos_embed], axis=0)
    return pos_embed
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-axis grid: half the channels for h, half for w, concatenated."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # use half of dimensions to encode each axis
    parts = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)  # (H*W, D)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) -- first D/2 columns are sines, last D/2 are cosines.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Geometric frequency ladder from 1 down to 1/10000.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)  # (D/2,)

    angles = np.outer(np.reshape(pos, -1), freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
#################################################################################
|
| 374 |
+
# SiT Configs #
|
| 375 |
+
#################################################################################
|
| 376 |
+
|
| 377 |
+
def SiT_XL_2(**kwargs):
    """SiT-XL (28 blocks, 1152-d, 16 heads) with patch size 2."""
    return SiT(depth=28, num_heads=16, hidden_size=1152, decoder_hidden_size=1152, patch_size=2, **kwargs)

def SiT_XL_4(**kwargs):
    """SiT-XL (28 blocks, 1152-d, 16 heads) with patch size 4."""
    return SiT(depth=28, num_heads=16, hidden_size=1152, decoder_hidden_size=1152, patch_size=4, **kwargs)

def SiT_XL_8(**kwargs):
    """SiT-XL (28 blocks, 1152-d, 16 heads) with patch size 8."""
    return SiT(depth=28, num_heads=16, hidden_size=1152, decoder_hidden_size=1152, patch_size=8, **kwargs)

def SiT_L_2(**kwargs):
    """SiT-L (24 blocks, 1024-d, 16 heads) with patch size 2."""
    return SiT(depth=24, num_heads=16, hidden_size=1024, decoder_hidden_size=1024, patch_size=2, **kwargs)

def SiT_L_4(**kwargs):
    """SiT-L (24 blocks, 1024-d, 16 heads) with patch size 4."""
    return SiT(depth=24, num_heads=16, hidden_size=1024, decoder_hidden_size=1024, patch_size=4, **kwargs)

def SiT_L_8(**kwargs):
    """SiT-L (24 blocks, 1024-d, 16 heads) with patch size 8."""
    return SiT(depth=24, num_heads=16, hidden_size=1024, decoder_hidden_size=1024, patch_size=8, **kwargs)

def SiT_B_2(**kwargs):
    """SiT-B (12 blocks, 768-d, 12 heads) with patch size 2."""
    return SiT(depth=12, num_heads=12, hidden_size=768, decoder_hidden_size=768, patch_size=2, **kwargs)

def SiT_B_4(**kwargs):
    """SiT-B (12 blocks, 768-d, 12 heads) with patch size 4."""
    return SiT(depth=12, num_heads=12, hidden_size=768, decoder_hidden_size=768, patch_size=4, **kwargs)

def SiT_B_8(**kwargs):
    """SiT-B (12 blocks, 768-d, 12 heads) with patch size 8."""
    return SiT(depth=12, num_heads=12, hidden_size=768, decoder_hidden_size=768, patch_size=8, **kwargs)

# NOTE(review): the S variants do not set decoder_hidden_size, so SiT's
# default (768) applies while hidden_size is 384 -- confirm this is intended.
def SiT_S_2(**kwargs):
    """SiT-S (12 blocks, 384-d, 6 heads) with patch size 2."""
    return SiT(depth=12, num_heads=6, hidden_size=384, patch_size=2, **kwargs)

def SiT_S_4(**kwargs):
    """SiT-S (12 blocks, 384-d, 6 heads) with patch size 4."""
    return SiT(depth=12, num_heads=6, hidden_size=384, patch_size=4, **kwargs)

def SiT_S_8(**kwargs):
    """SiT-S (12 blocks, 384-d, 6 heads) with patch size 8."""
    return SiT(depth=12, num_heads=6, hidden_size=384, patch_size=8, **kwargs)


# Registry mapping config names to factory functions.
SiT_models = {
    'SiT-XL/2': SiT_XL_2,
    'SiT-XL/4': SiT_XL_4,
    'SiT-XL/8': SiT_XL_8,
    'SiT-L/2': SiT_L_2,
    'SiT-L/4': SiT_L_4,
    'SiT-L/8': SiT_L_8,
    'SiT-B/2': SiT_B_2,
    'SiT-B/4': SiT_B_4,
    'SiT-B/8': SiT_B_8,
    'SiT-S/2': SiT_S_2,
    'SiT-S/4': SiT_S_4,
    'SiT-S/8': SiT_S_8,
}
|
| 420 |
+
|
back/preprocessing/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center"> Preprocessing Guide
|
| 2 |
+
</h1>
|
| 3 |
+
|
| 4 |
+
#### Dataset download
|
| 5 |
+
|
| 6 |
+
We follow the preprocessing code used in [edm2](https://github.com/NVlabs/edm2). In this code we made several edits: (1) we removed unnecessary parts except preprocessing, because this code is only used for preprocessing; (2) we use the [-1, 1] range for inputs to the stable diffusion VAE (similar to DiT or SiT), unlike edm2, which uses the [0, 1] range; and (3) we consider preprocessing at 256x256 resolution (or 512x512 resolution).
|
| 7 |
+
|
| 8 |
+
After downloading ImageNet, please run the following scripts (update 256x256 to 512x512 if you want to run experiments at 512x512 resolution):
|
| 9 |
+
|
| 10 |
+
Convert raw ImageNet data to a ZIP archive at 256x256 resolution:
```bash
bash dataset_prepare_convert.sh
```

Encode the pixel data to latents:

```bash
bash dataset_prepare_encode.sh
```
|
| 20 |
+
|
| 21 |
+
Here, `YOUR_DOWNLOAD_PATH` is the directory where you downloaded the dataset, and `TARGET_PATH` is the directory where the preprocessed images and corresponding compressed latent vectors will be saved. This directory will be used by your experiment scripts.
|
| 22 |
+
|
| 23 |
+
## Acknowledgement
|
| 24 |
+
|
| 25 |
+
This code is mainly built upon [edm2](https://github.com/NVlabs/edm2) repository.
|
back/preprocessing/dataset_image_encoder.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Tool for creating ZIP/PNG based datasets."""
|
| 9 |
+
|
| 10 |
+
from collections.abc import Iterator
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import functools
|
| 13 |
+
import io
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import zipfile
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Callable, Optional, Tuple, Union
|
| 20 |
+
import click
|
| 21 |
+
import numpy as np
|
| 22 |
+
import PIL.Image
|
| 23 |
+
import torch
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
from encoders import StabilityVAEEncoder
|
| 27 |
+
from utils import load_encoders
|
| 28 |
+
from torchvision.transforms import Normalize
|
| 29 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 30 |
+
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 31 |
+
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 32 |
+
|
| 33 |
+
def preprocess_raw_image(x, enc_type):
    """Normalize a raw image batch for a pretrained feature encoder.

    x: image tensor in [0, 255], NCHW layout (assumed -- TODO confirm callers).
    enc_type: encoder family name; substring-matched against known families.
    Returns the scaled/normalized tensor; encoders trained at 224px also get
    a bicubic resize. An unrecognized enc_type passes x through unchanged.
    """
    resolution = x.shape[-1]
    # Assumes resolution is a multiple of 256 so the 224-based target scales
    # with it -- TODO confirm for non-256-multiple inputs.
    target_size = 224 * (resolution // 256)

    if 'clip' in enc_type:
        x = x / 255.
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'dinov2' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
    elif 'dinov1' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'jepa' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')

    return x
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
#----------------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
@dataclass
class ImageEntry:
    """One dataset sample: decoded RGB pixels plus an optional class label."""
    img: np.ndarray        # decoded RGB image array (HWC per the PIL.Image.open call sites)
    label: Optional[int]   # class index, or None when the dataset is unlabeled
|
| 63 |
+
|
| 64 |
+
#----------------------------------------------------------------------------
|
| 65 |
+
# Parse a 'M,N' or 'MxN' integer tuple.
|
| 66 |
+
# Example: '4x2' returns (4,2)
|
| 67 |
+
|
| 68 |
+
def parse_tuple(s: str) -> Tuple[int, int]:
    """Parse a 'MxN' or 'M,N' string into an (M, N) integer pair.

    Raises click.ClickException when the string does not match.
    """
    match = re.match(r'^(\d+)[x,](\d+)$', s)
    if match is None:
        raise click.ClickException(f'cannot parse tuple {s}')
    return int(match.group(1)), int(match.group(2))
|
| 73 |
+
|
| 74 |
+
#----------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def maybe_min(a: int, b: Optional[int]) -> int:
    """Return min(a, b), treating b=None as 'no limit'."""
    if b is None:
        return a
    return min(a, b)
|
| 80 |
+
|
| 81 |
+
#----------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last '.' (the whole name when there is no dot)."""
    # rpartition yields ('', '', name) for dot-free names, matching
    # the behavior of str.split('.')[-1].
    return str(name).rpartition('.')[-1]
|
| 85 |
+
|
| 86 |
+
#----------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
def is_image_ext(fname: Union[str, Path]) -> bool:
    """True when fname's extension is one PIL can decode.

    NOTE(review): relies on PIL.Image.EXTENSION being populated, i.e.
    PIL.Image.init() must have been called beforehand.
    """
    lowered = f'.{file_ext(fname).lower()}'
    return lowered in PIL.Image.EXTENSION
|
| 91 |
+
|
| 92 |
+
#----------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Scan a directory tree of images and return (count, lazy ImageEntry iterator).

    Labels come from an optional dataset.json ({"labels": [[fname, label], ...]})
    in source_dir; when absent and more than one top-level subdirectory exists,
    subdirectory names are mapped to integer class indices.
    """
    input_images = []
    def _recurse_dirs(root: str): # workaround Path().rglob() slowness
        with os.scandir(root) as it:
            for e in it:
                if e.is_file():
                    input_images.append(os.path.join(root, e.name))
                elif e.is_dir():
                    _recurse_dirs(os.path.join(root, e.name))
    _recurse_dirs(source_dir)
    input_images = sorted([f for f in input_images if is_image_ext(f)])

    # Map absolute paths to archive-relative names with forward slashes.
    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        # Decode lazily so callers can stream without loading every image.
        for idx, fname in enumerate(input_images):
            img = np.array(PIL.Image.open(fname).convert('RGB'))#.transpose(2, 0, 1)
            yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
            if idx >= max_idx - 1:
                break
    return max_idx, iterate_images()
|
| 132 |
+
|
| 133 |
+
#----------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Open a ZIP archive of images and return (count, lazy ImageEntry iterator).

    Labels come from an optional dataset.json inside the archive
    ({"labels": [[fname, label], ...]}); entries without a label yield None.
    """
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        # Re-open the archive inside the generator so it owns its own handle
        # (the handle above is already closed by the time iteration starts).
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file).convert('RGB'))
                yield ImageEntry(img=img, label=labels.get(fname))
                if idx >= max_idx - 1:
                    break
    return max_idx, iterate_images()
|
| 157 |
+
|
| 158 |
+
#----------------------------------------------------------------------------
|
| 159 |
+
|
| 160 |
+
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build an image transform for the requested crop/resize mode.

    transform: None (plain resize), 'center-crop', 'center-crop-wide', or
        'center-crop-dhariwal' (ADM-style center crop).
    output_width / output_height: target size; required by the crop modes.
    Returns a callable mapping an HWC uint8 np.ndarray to the transformed
    array, or None when the image cannot satisfy the requested geometry
    (center-crop-wide only).
    Raises click.ClickException on missing/inconsistent resolution options.
    """
    def scale(width, height, img):
        # Plain resize; an axis left as None keeps its original size.
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img, 'RGB')
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        # Square center crop to the shorter side, then LANCZOS resize.
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Fit width, vertically center on a black canvas; returns None for
        # images smaller than the requested geometry.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    def center_crop_imagenet(image_size: int, arr: np.ndarray):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        pil_image = PIL.Image.fromarray(arr)
        # Repeated BOX halving until within 2x of the target reduces aliasing
        # compared to one large BICUBIC step.
        while min(*pil_image.size) >= 2 * image_size:
            new_size = tuple(x // 2 for x in pil_image.size)
            assert len(new_size) == 2
            pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)

        scale = image_size / min(*pil_image.size)
        new_size = tuple(round(x * scale) for x in pil_image.size)
        assert len(new_size) == 2
        pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            # Fixed: the original message was missing the space before 'transform'.
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    if transform == 'center-crop-dhariwal':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        if output_width != output_height:
            raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_imagenet, output_width)
    assert False, 'unknown transform'
|
| 235 |
+
|
| 236 |
+
#----------------------------------------------------------------------------
|
| 237 |
+
|
| 238 |
+
def open_dataset(source, *, max_images: Optional[int]):
    """Open a dataset from a directory tree or a .zip archive.

    Returns (count, lazy ImageEntry iterator); raises click.ClickException
    for missing paths or unsupported file types.
    """
    if os.path.isdir(source):
        return open_image_folder(source, max_images=max_images)
    if not os.path.isfile(source):
        raise click.ClickException(f'Missing input file or directory: {source}')
    if file_ext(source) != 'zip':
        raise click.ClickException(f'Only zip archives are supported: {source}')
    return open_image_zip(source, max_images=max_images)
|
| 248 |
+
|
| 249 |
+
#----------------------------------------------------------------------------
|
| 250 |
+
|
| 251 |
+
def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    """Prepare the output target and return (root_dir, save_bytes, close).

    When dest ends in .zip, writes go into a new uncompressed ZIP archive;
    otherwise into a plain directory tree. save_bytes(fname, data) accepts
    bytes or str (str is utf-8 encoded on the folder path).
    """
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        # ZIP_STORED: payloads are written without recompression.
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is
        # empty.
        #
        # Note: creating the output directory is not strictly
        # necessary as folder_write_bytes() also mkdirs, but it's better
        # to give an error message earlier in case the dest folder
        # somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None
|
| 280 |
+
|
| 281 |
+
#----------------------------------------------------------------------------
|
| 282 |
+
|
| 283 |
+
@click.group()
def cmdline():
    '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
    # Single-process only: the encode loop and output writer are not
    # coordinated across ranks, so refuse to run under a distributed launcher.
    if os.environ.get('WORLD_SIZE', '1') != '1':
        raise click.ClickException('Distributed execution is not supported.')
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
#----------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@cmdline.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--enc-type', help='Pretrained image encoder type', metavar='PATH', type=str, default='dinov2-vit-b')
@click.option('--resolution', help='Input image resolution', metavar='INT', type=int, default=256)
def encode(
    source: str,
    dest: str,
    max_images: Optional[int],
    enc_type,
    resolution
):
    """Encode images to pretrained-encoder features saved as per-image .npy files."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder, encoder_type, architectures = load_encoders(enc_type, device, resolution)
    encoder, encoder_type, architectures = encoder[0], encoder_type[0], architectures[0]
    print("Encoder is over!!!")

    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("Data is over!!!")
    labels = []

    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        with torch.no_grad():
            # HWC uint8 -> 1xCxHxW on the selected device. (Was hard-coded to
            # 'cuda', which crashed on CPU-only machines despite the fallback
            # above.)
            img_tensor = torch.tensor(image.img).to(device).permute(2, 0, 1).unsqueeze(0)
            raw_image_ = preprocess_raw_image(img_tensor, encoder_type)
            z = encoder.forward_features(raw_image_)
            if 'dinov2' in encoder_type:
                z = z['x_norm_patchtokens']
            # Move to CPU immediately and drop the tensor each iteration; the
            # previous temp_list1/temp_list2 accumulators kept every feature
            # (including the GPU tensors) alive for the entire run.
            z = z.detach().cpu().numpy()

        # Shard files into 00000/..., 00001/... subfolders by index prefix.
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img-feature-{idx_str}.npy'

        f = io.BytesIO()
        np.save(f, z)
        save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    # dataset.json mirrors the input labels; set to None when any image was unlabeled.
    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()
|
| 348 |
+
|
| 349 |
+
if __name__ == "__main__":
    # Dispatch to the click command group defined above.
    cmdline()


#----------------------------------------------------------------------------
|
back/preprocessing/dataset_prepare_convert.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
#256
# Crop/resize raw ImageNet to 256x256 and pack it into the dataset layout
# (change 256x256 to 512x512 for 512-resolution experiments).
# NOTE(review): paths are machine-specific; replace /home/share/... with
# your own download and target directories.
python preprocessing/dataset_tools.py convert \
    --source=/home/share/imagenet/train \
    --dest=/home/share/imagenet_vae/imagenet_256_vae \
    --resolution=256x256 \
    --transform=center-crop-dhariwal
|
back/preprocessing/dataset_prepare_encode.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
#256
# Encode the converted 256x256 pixel dataset to latents.
# NOTE(review): paths are machine-specific; replace /home/share/... with
# your own directories (source must be the output of the convert step).
python preprocessing/dataset_tools.py encode \
    --source=/home/share/imagenet_vae/imagenet_256_vae \
    --dest=/home/share/imagenet_vae/vae-sd-256
|
back/preprocessing/dataset_tools.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Tool for creating ZIP/PNG based datasets."""
|
| 9 |
+
|
| 10 |
+
from collections.abc import Iterator
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import functools
|
| 13 |
+
import io
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import zipfile
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Callable, Optional, Tuple, Union
|
| 20 |
+
import click
|
| 21 |
+
import numpy as np
|
| 22 |
+
import PIL.Image
|
| 23 |
+
import torch
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
from encoders import StabilityVAEEncoder
|
| 27 |
+
|
| 28 |
+
#----------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
@dataclass
class ImageEntry:
    """One dataset sample: decoded RGB pixels plus an optional class label."""
    img: np.ndarray        # decoded RGB image array (HWC per the PIL.Image.open call sites)
    label: Optional[int]   # class index, or None when the dataset is unlabeled
|
| 34 |
+
|
| 35 |
+
#----------------------------------------------------------------------------
|
| 36 |
+
# Parse a 'M,N' or 'MxN' integer tuple.
|
| 37 |
+
# Example: '4x2' returns (4,2)
|
| 38 |
+
|
| 39 |
+
def parse_tuple(s: str) -> Tuple[int, int]:
    """Parse a 'MxN' or 'M,N' string into an (M, N) integer pair.

    Raises click.ClickException when the string does not match.
    """
    match = re.match(r'^(\d+)[x,](\d+)$', s)
    if match is None:
        raise click.ClickException(f'cannot parse tuple {s}')
    return int(match.group(1)), int(match.group(2))
|
| 44 |
+
|
| 45 |
+
#----------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
def maybe_min(a: int, b: Optional[int]) -> int:
    """Return min(a, b), treating b=None as 'no limit'."""
    if b is None:
        return a
    return min(a, b)
|
| 51 |
+
|
| 52 |
+
#----------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last '.' (the whole name when there is no dot)."""
    # rpartition yields ('', '', name) for dot-free names, matching str.split('.')[-1].
    return str(name).rpartition('.')[-1]
|
| 56 |
+
|
| 57 |
+
#----------------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
def is_image_ext(fname: Union[str, Path]) -> bool:
    """True when fname's extension is one PIL can decode.

    NOTE(review): relies on PIL.Image.EXTENSION being populated, i.e.
    PIL.Image.init() must have been called beforehand.
    """
    lowered = f'.{file_ext(fname).lower()}'
    return lowered in PIL.Image.EXTENSION
|
| 62 |
+
|
| 63 |
+
#----------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Scan a directory tree of images and return (count, lazy ImageEntry iterator).

    Labels come from an optional dataset.json ({"labels": [[fname, label], ...]})
    in source_dir; when absent and more than one top-level subdirectory exists,
    subdirectory names are mapped to integer class indices.
    """
    input_images = []
    def _recurse_dirs(root: str): # workaround Path().rglob() slowness
        with os.scandir(root) as it:
            for e in it:
                if e.is_file():
                    input_images.append(os.path.join(root, e.name))
                elif e.is_dir():
                    _recurse_dirs(os.path.join(root, e.name))
    _recurse_dirs(source_dir)
    input_images = sorted([f for f in input_images if is_image_ext(f)])

    # Map absolute paths to archive-relative names with forward slashes.
    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        # Decode lazily so callers can stream without loading every image.
        for idx, fname in enumerate(input_images):
            img = np.array(PIL.Image.open(fname).convert('RGB'))
            yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
            if idx >= max_idx - 1:
                break
    return max_idx, iterate_images()
|
| 103 |
+
|
| 104 |
+
#----------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    """Open a ZIP archive of images and return (count, lazy ImageEntry iterator).

    Labels come from an optional dataset.json inside the archive
    ({"labels": [[fname, label], ...]}); entries without a label yield None.
    """
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        # Re-open the archive inside the generator so it owns its own handle
        # (the handle above is already closed by the time iteration starts).
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file).convert('RGB'))
                yield ImageEntry(img=img, label=labels.get(fname))
                if idx >= max_idx - 1:
                    break
    return max_idx, iterate_images()
|
| 128 |
+
|
| 129 |
+
#----------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build an image transform callable for the requested crop/resize mode.

    Args:
        transform: One of None, 'center-crop', 'center-crop-wide', or
            'center-crop-dhariwal'. None selects plain rescaling.
        output_width: Target width in pixels, or None to keep the input width.
        output_height: Target height in pixels, or None to keep the input height.

    Returns:
        A callable mapping an HWC uint8 RGB array to the transformed array.
        The callable may return None when the input cannot satisfy the
        requested geometry (center_crop_wide rejects too-small images).

    Raises:
        click.ClickException: If a crop mode is requested without a full
            --resolution, with a non-square one for 'center-crop-dhariwal',
            or if the transform name is unknown.
    """
    def scale(width, height, img):
        # Plain LANCZOS rescale; missing dimensions default to the input's.
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img, 'RGB')
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        # Crop the largest centered square, then rescale to the target size.
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Crop a centered horizontal band, rescale, then letterbox onto a
        # width x width canvas. Returns None for images that are too small.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    def center_crop_imagenet(image_size: int, arr: np.ndarray):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        # Repeatedly halve with BOX filtering while the image is at least
        # 2x the target, then a final BICUBIC resize and center crop.
        pil_image = PIL.Image.fromarray(arr)
        while min(*pil_image.size) >= 2 * image_size:
            new_size = tuple(x // 2 for x in pil_image.size)
            assert len(new_size) == 2
            pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)

        scale = image_size / min(*pil_image.size)
        new_size = tuple(round(x * scale) for x in pil_image.size)
        assert len(new_size) == 2
        pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            # Fixed: the original message was missing the space before 'transform'.
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    if transform == 'center-crop-dhariwal':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        if output_width != output_height:
            raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_imagenet, output_width)
    # Replaces 'assert False': asserts are stripped under 'python -O', which
    # would silently return None for an unknown mode.
    raise click.ClickException(f'unknown transform: {transform}')
|
| 206 |
+
|
| 207 |
+
#----------------------------------------------------------------------------
|
| 208 |
+
|
| 209 |
+
def open_dataset(source, *, max_images: Optional[int]):
    """Dispatch to the matching opener for a directory or zip source."""
    if os.path.isdir(source):
        return open_image_folder(source, max_images=max_images)
    if os.path.isfile(source):
        # Only zip archives are accepted as file inputs.
        if file_ext(source) != 'zip':
            raise click.ClickException(f'Only zip archives are supported: {source}')
        return open_image_zip(source, max_images=max_images)
    raise click.ClickException(f'Missing input file or directory: {source}')
|
| 219 |
+
|
| 220 |
+
#----------------------------------------------------------------------------
|
| 221 |
+
|
| 222 |
+
def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    """Open the output destination (zip archive or folder).

    Returns a tuple (archive_root_dir, save_bytes, close) where save_bytes
    writes one named file into the destination and close finalizes it.
    For a zip destination archive_root_dir is '' (paths are archive-relative).
    """
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        # NOTE: the ZipFile handle stays open on purpose; the caller must
        # invoke the returned close callback (zf.close) when done.
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is
        # empty.
        #
        # Note: creating the output directory is not strictly
        # necessary as folder_write_bytes() also mkdirs, but it's better
        # to give an error message earlier in case the dest folder
        # somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            # Creates intermediate directories; text payloads are UTF-8 encoded.
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None
|
| 251 |
+
|
| 252 |
+
#----------------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
@click.group()
def cmdline():
    '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
    # Click group entry point; subcommands are 'convert' and 'encode'.
    # torchrun-style launchers set WORLD_SIZE; this tool is single-process only.
    if os.environ.get('WORLD_SIZE', '1') != '1':
        raise click.ClickException('Distributed execution is not supported.')
|
| 259 |
+
|
| 260 |
+
#----------------------------------------------------------------------------
|
| 261 |
+
|
| 262 |
+
@cmdline.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-dhariwal']))
@click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)

def convert(
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[Tuple[int, int]]
):
    """Convert an image dataset into archive format for training.

    Specifying the input images:

    \b
    --source path/           Recursively load all images from path/
    --source dataset.zip     Load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir          Save output files under /path/to/dir
    --dest /path/to/dataset.zip  Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives makes it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the dataset
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, class labels are determined from
    top-level directory names.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input image size to a specific width and height, use the
    --resolution option. Output resolution will be either the original
    input resolution (if resolution was not specified) or the one specified with
    --resolution option.

    The --transform=center-crop-dhariwal selects a crop/rescale mode that is intended
    to exactly match with results obtained for ImageNet in common diffusion model literature:

    \b
    python dataset_tool.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \\
        --dest=datasets/img64.zip --resolution=64x64 --transform=center-crop-dhariwal
    """
    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')
    # NOTE(review): the prints below look like leftover debug output — consider
    # removing or switching to logging.
    print("Begin!!!!!!!!")
    num_files, input_iter = open_dataset(source, max_images=max_images)
    print("open_dataset is over")
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("open_dest is over")
    transform_image = make_transform(transform, *resolution if resolution is not None else (None, None))
    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        # Images are numbered sequentially and sharded into folders of 100k.
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image.img)
        if img is None:
            # The transform rejected this image (e.g. too small for the crop).
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        assert img.ndim == 3
        cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0]}
        if dataset_attrs is None:
            # First accepted image fixes the dataset-wide width/height.
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
            if width != 2 ** int(np.floor(np.log2(width))):
                raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img)
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    # 'labels' is written as None (unlabeled dataset) unless every image had a label.
    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()
|
| 377 |
+
|
| 378 |
+
#----------------------------------------------------------------------------
|
| 379 |
+
|
| 380 |
+
@cmdline.command()
@click.option('--model-url', help='VAE encoder model', metavar='URL', type=str, default='stabilityai/sd-vae-ft-mse', show_default=True)
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)

def encode(
    model_url: str,
    source: str,
    dest: str,
    max_images: Optional[int],
):
    """Encode pixel data to VAE latents.

    Reads images via open_dataset(), encodes each one with the Stability VAE,
    and stores the per-image result as an .npy file plus a dataset.json with
    optional labels (mirrors the layout produced by 'convert').
    """
    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    vae = StabilityVAEEncoder(vae_name=model_url, batch_size=1)
    print("VAE is over!!!")
    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("Data is over!!!")
    labels = []
    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        # NOTE(review): device is hard-coded to 'cuda'; the commented-out line
        # above suggests CPU fallback was intended — confirm before running on
        # machines without a GPU. HWC uint8 -> 1xCxHxW tensor.
        img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0)
        mean_std = vae.encode_pixels(img_tensor)[0].cpu()
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img-mean-std-{idx_str}.npy'

        f = io.BytesIO()
        np.save(f, mean_std)
        save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    # 'labels' is written as None (unlabeled dataset) unless every image had a label.
    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()
|
| 418 |
+
|
| 419 |
+
# Script entry point: dispatch to the 'convert' / 'encode' subcommands.
if __name__ == "__main__":
    cmdline()
|
| 421 |
+
|
| 422 |
+
#----------------------------------------------------------------------------
|
back/preprocessing/dnnlib/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
from .util import EasyDict, make_cache_dir_path
|
back/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|
back/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
back/preprocessing/dnnlib/util.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous utility classes and functions."""
|
| 9 |
+
|
| 10 |
+
import ctypes
|
| 11 |
+
import fnmatch
|
| 12 |
+
import importlib
|
| 13 |
+
import inspect
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
import io
|
| 20 |
+
import pickle
|
| 21 |
+
import re
|
| 22 |
+
import requests
|
| 23 |
+
import html
|
| 24 |
+
import hashlib
|
| 25 |
+
import glob
|
| 26 |
+
import tempfile
|
| 27 |
+
import urllib
|
| 28 |
+
import urllib.parse
|
| 29 |
+
import uuid
|
| 30 |
+
|
| 31 |
+
from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional
|
| 32 |
+
|
| 33 |
+
# Util classes
|
| 34 |
+
# ------------------------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Missing keys must surface as AttributeError so hasattr() and
        # getattr(..., default) behave correctly.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment writes straight into the dict storage.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the dict entry.
        del self[name]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
        # Optional mirror file; remains None when no file_name is given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep the real streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Globally install this object as both stdout and stderr.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Cache directories
|
| 113 |
+
# ------------------------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
# Process-wide override for the dnnlib cache location (None = auto-detect).
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the default dnnlib cache location for this process."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Build a path under the dnnlib cache directory.

    Resolution order: explicit set_cache_dir() value, DNNLIB_CACHE_DIR
    environment variable, ~/.cache/dnnlib (HOME, then USERPROFILE), and
    finally the system temp directory.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    env_dir = os.environ.get('DNNLIB_CACHE_DIR')
    if env_dir is not None:
        return os.path.join(env_dir, *paths)
    home = os.environ.get('HOME')
    if home is not None:
        return os.path.join(home, '.cache', 'dnnlib', *paths)
    profile = os.environ.get('USERPROFILE')
    if profile is not None:
        return os.path.join(profile, '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 131 |
+
|
| 132 |
+
# Small util functions
|
| 133 |
+
# ------------------------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    # Round to the nearest whole second before bucketing.
    total = int(np.rint(seconds))
    minute, hour, day = 60, 60 * 60, 24 * 60 * 60

    if total < minute:
        return "{0}s".format(total)
    if total < hour:
        return "{0}m {1:02}s".format(total // minute, total % minute)
    if total < day:
        return "{0}h {1:02}m {2:02}s".format(total // hour, (total // minute) % 60, total % minute)
    return "{0}d {1:02}h {2:02}m".format(total // day, (total // hour) % 24, (total // minute) % 60)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    # Round to the nearest whole second, then keep only the two largest units.
    total = int(np.rint(seconds))
    minute, hour, day = 60, 60 * 60, 24 * 60 * 60

    if total < minute:
        return "{0}s".format(total)
    if total < hour:
        return "{0}m {1:02}s".format(total // minute, total % minute)
    if total < day:
        return "{0}h {1:02}m".format(total // hour, (total // minute) % 60)
    return "{0}d {1:02}h".format(total // day, (total // hour) % 24)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    # Running product; an empty tuple yields the multiplicative identity 1.
    product = 1
    for element in t:
        product = product * element
    return product
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# Numpy-style dtype name -> ctypes type of identical byte width.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Accept a plain string, a Python/numpy type (__name__), or a dtype (.name).
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()
    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]
    # Sanity check: both representations must occupy the same number of bytes.
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def is_pickleable(obj: Any) -> bool:
    """Return True if *obj* can be serialized with pickle, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
            return True
    except Exception:
        # Pickle failures surface as several exception types (PicklingError,
        # TypeError, AttributeError, ...), so catching Exception broadly is
        # intentional. Fixed from a bare 'except:', which would also swallow
        # KeyboardInterrupt/SystemExit.
        return False
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 221 |
+
# ------------------------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    # e.g. "a.b.c" -> [("a.b.c", ""), ("a.b", "c"), ("a", "b.c")]
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    # NOTE(review): the bare 'except' is deliberate here — any failure just
    # means this split point is wrong, and diagnostics happen below.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    # Re-import each candidate and re-raise only "real" import errors, i.e.
    # anything other than the module itself being absent.
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    # This pass lets an AttributeError propagate as the diagnostic.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    # Empty name refers to the module itself.
    if obj_name == '':
        return module
    current = module
    for attr in obj_name.split("."):
        current = getattr(current, attr)
    return current
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    # Split the dotted name into (module, remainder), then walk the remainder.
    module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, local_name)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    # Strings are resolved through the module machinery; callables pass through.
    if isinstance(func_name, str):
        func_obj = get_obj_by_name(func_name)
    else:
        func_obj = func_name
    assert callable(func_obj)
    return func_obj(*args, **kwargs)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # Class construction is just calling the class object.
    return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    module_file = inspect.getfile(module)
    return os.path.dirname(module_file)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    # Non-callables are never top-level functions; otherwise check that the
    # object's name is bound in its defining module's namespace.
    if not callable(obj):
        return False
    return obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        # Scripts run directly have no importable module name; derive one
        # from the script's file name instead.
        fname = sys.modules[module].__file__
        assert fname is not None
        module = os.path.splitext(os.path.basename(fname))[0]
    return module + "." + obj.__name__
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# File system helpers
|
| 316 |
+
# ------------------------------------------------------------------------------------------
|
| 317 |
+
|
| 318 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.

    Returns a list of (absolute_path, relative_path) tuples. When
    `add_base_to_relative` is True, the relative paths are prefixed with the
    base name of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = ignores if ignores is not None else []
    collected: List[Tuple[str, str]] = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in patterns:
            # os.walk only honors pruning when `dirs` is edited in place.
            for skipped in [d for d in dirs if fnmatch.fnmatch(d, pattern)]:
                dirs.remove(skipped)
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        for fname in files:
            abs_path = os.path.join(root, fname)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            collected.append((abs_path, rel_path))

    return collected
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    Args:
        files: Sequence of (source_path, destination_path) pairs.
    """
    for src_path, dst_path in files:
        target_dir = os.path.dirname(dst_path)
        # Fix: os.makedirs('') raises FileNotFoundError. A bare destination
        # filename has no directory part, so only create when one exists.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        shutil.copyfile(src_path, dst_path)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# URL helpers
|
| 363 |
+
# ------------------------------------------------------------------------------------------
|
| 364 |
+
|
| 365 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Args:
        obj:             Candidate value; anything that is not a str fails.
        allow_file_urls: Accept 'file://...' URLs without further validation.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        # Also validate the root of the URL; urljoin can surface a malformed netloc.
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # Fix: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # propagate; urlparse raises ValueError on e.g. invalid port numbers.
        return False
    return True
|
| 381 |
+
|
| 382 |
+
# Note on static typing: a better API would be to split 'open_url' to 'openl_url' and
|
| 383 |
+
# 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True`
|
| 384 |
+
# case is somewhat uncommon, we just pretend like this function never returns a string
|
| 385 |
+
# and type ignore return value for those cases.
|
| 386 |
+
def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url:             URL, local path, or 'file://' URL to open.
        cache_dir:       Directory for cached downloads; defaults to make_cache_dir_path('downloads').
        num_attempts:    Number of download attempts before giving up.
        verbose:         Print download progress to stdout.
        return_filename: Return the cached filename instead of a file object (requires cache=True).
        cache:           Keep a copy of the downloaded data on disk.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb") # type: ignore

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb") # type: ignore

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # Cache files are named "<md5(url)>_<sanitized name>"; the md5 prefix is
    # the lookup key.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb") # type: ignore

    # Download with up to num_attempts retries; any non-KeyboardInterrupt
    # failure falls through to the next attempt.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive interstitial page
                    # rather than the payload; detect and handle those.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry via the confirm link on the next attempt.
                                url = urllib.parse.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    assert url_data is not None

    # Save to cache via a unique temp file + os.replace so concurrent
    # processes never observe a partially-written cache entry.
    if cache:
        assert url_name is not None
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file # type: ignore

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
|
back/preprocessing/encoders.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Converting between pixel and latent representations of image data."""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import warnings
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch_utils import persistence
|
| 15 |
+
from torch_utils import misc
|
| 16 |
+
|
| 17 |
+
warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
|
| 18 |
+
warnings.filterwarnings('ignore', '`resume_download` is deprecated')
|
| 19 |
+
|
| 20 |
+
#----------------------------------------------------------------------------
|
| 21 |
+
# Abstract base class for encoders/decoders that convert back and forth
|
| 22 |
+
# between pixel and latent representations of image data.
|
| 23 |
+
#
|
| 24 |
+
# Logically, "raw pixels" are first encoded into "raw latents" that are
|
| 25 |
+
# then further encoded into "final latents". Decoding, on the other hand,
|
| 26 |
+
# goes directly from the final latents to raw pixels. The final latents are
|
| 27 |
+
# used as inputs and outputs of the model, whereas the raw latents are
|
| 28 |
+
# stored in the dataset. This separation provides added flexibility in terms
|
| 29 |
+
# of performing just-in-time adjustments, such as data whitening, without
|
| 30 |
+
# having to construct a new dataset.
|
| 31 |
+
#
|
| 32 |
+
# All image data is represented as PyTorch tensors in NCHW order.
|
| 33 |
+
# Raw pixels are represented as 3-channel uint8.
|
| 34 |
+
|
| 35 |
+
@persistence.persistent_class
class Encoder:
    """Abstract base for converting between pixel and latent image representations.

    See the module-level comment above for the raw-pixel / raw-latent
    terminology. Subclasses override encode_pixels().
    """
    def __init__(self):
        pass

    def init(self, device): # force lazy init to happen now
        pass

    def __getstate__(self):
        return self.__dict__

    def encode_pixels(self, x): # raw pixels => raw latents
        raise NotImplementedError # to be overridden by subclass
|
| 48 |
+
#----------------------------------------------------------------------------
|
| 49 |
+
# Pre-trained VAE encoder from Stability AI.
|
| 50 |
+
|
| 51 |
+
@persistence.persistent_class
class StabilityVAEEncoder(Encoder):
    """Pre-trained VAE encoder from Stability AI; the VAE itself is loaded lazily."""
    def __init__(self,
        vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
        batch_size = 8, # Batch size to use when running the VAE.
    ):
        super().__init__()
        self.vae_name = vae_name
        self.batch_size = int(batch_size)
        self._vae = None # constructed lazily by init()

    def init(self, device): # force lazy init to happen now
        super().init(device)
        if self._vae is None:
            self._vae = load_stability_vae(self.vae_name, device=device)
        else:
            self._vae.to(device)

    def __getstate__(self):
        return dict(super().__getstate__(), _vae=None) # do not pickle the vae

    def _run_vae_encoder(self, x):
        # Raw latents are the concatenated mean and std of the VAE's latent distribution.
        d = self._vae.encode(x)['latent_dist']
        return torch.cat([d.mean, d.std], dim=1)

    def encode_pixels(self, x): # raw pixels => raw latents
        self.init(x.device)
        # Map uint8 pixels [0, 255] to the [-1, 1] range the VAE expects.
        x = x.to(torch.float32) / 127.5 - 1
        # Run the VAE in chunks of self.batch_size to bound peak memory.
        x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
        return x
|
| 81 |
+
|
| 82 |
+
#----------------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
    """Load a pre-trained Stability AI VAE in eval mode with gradients disabled.

    Args:
        vae_name: Hugging Face model name of the VAE.
        device:   Device to place the model on.

    Returns:
        diffusers.models.AutoencoderKL, in eval mode, requires_grad=False, on `device`.
    """
    import dnnlib
    cache_dir = dnnlib.make_cache_dir_path('diffusers')
    os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
    os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
    os.environ['HF_HOME'] = cache_dir

    import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
    try:
        # First try with local_files_only to avoid a network round-trip when
        # the model is already in the cache.
        vae = diffusers.models.AutoencoderKL.from_pretrained(
            vae_name, cache_dir=cache_dir, local_files_only=True
        )
    except Exception:
        # Fix: narrowed from a bare `except:` so Ctrl-C still interrupts.
        # Not in cache (or cache load failed); fall back to a full download.
        vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
    return vae.eval().requires_grad_(False).to(device)
|
| 102 |
+
|
| 103 |
+
#----------------------------------------------------------------------------
|
back/preprocessing/torch_utils/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
# empty
|
back/preprocessing/torch_utils/distributed.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import socket
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed
|
| 13 |
+
from . import training_stats
|
| 14 |
+
|
| 15 |
+
_sync_device = None
|
| 16 |
+
|
| 17 |
+
#----------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
def init():
    """Initialize torch.distributed (if needed) and training_stats multiprocessing.

    Fills in single-process defaults for the standard env:// variables when the
    launcher did not set them, so the same code runs with or without torchrun.
    Safe to call when a process group is already initialized.
    """
    global _sync_device

    if not torch.distributed.is_initialized():
        # Setup some reasonable defaults for env-based distributed init if
        # not set by the running environment.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            # Ask the OS for a free port by binding to port 0.
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(('', 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        # gloo on Windows (NCCL is unavailable there), NCCL elsewhere.
        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    # Cross-process stat syncing is only needed with more than one process.
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
|
| 45 |
+
|
| 46 |
+
#----------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
def get_rank():
    """Rank of the current process, or 0 when torch.distributed is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
|
| 50 |
+
|
| 51 |
+
#----------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
def get_world_size():
    """Number of distributed processes, or 1 when torch.distributed is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
|
| 55 |
+
|
| 56 |
+
#----------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
def should_stop():
    """Hook for cooperative early stopping; this implementation never requests a stop."""
    return False
|
| 60 |
+
|
| 61 |
+
#----------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
def should_suspend():
    """Hook for cooperative suspension; this implementation never requests a suspend."""
    return False
|
| 65 |
+
|
| 66 |
+
#----------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
def request_suspend():
    """Hook to request a training suspend; no-op in this implementation."""
    pass
|
| 70 |
+
|
| 71 |
+
#----------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
def update_progress(cur, total):
    """Hook to report progress (`cur` of `total` units); no-op in this implementation."""
    pass
|
| 75 |
+
|
| 76 |
+
#----------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
def print0(*args, **kwargs):
    """print() that is a no-op on every rank except rank 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
| 81 |
+
|
| 82 |
+
#----------------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
class CheckpointIO:
    """Save/load a collection of named state objects as a single .pt checkpoint.

    Each keyword argument is a named state object. Supported kinds, in the
    order they are tried: None (skipped), plain dict (stored/updated as-is),
    objects exposing state_dict()/load_state_dict(), objects with
    __getstate__/__setstate__, and plain objects (their __dict__ is used).
    """
    def __init__(self, **kwargs):
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        """Serialize all state objects; only rank 0 writes the file."""
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        # Every rank builds `data` (cheap) but only rank 0 touches the disk.
        if get_rank() == 0:
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        """Restore all state objects from the given checkpoint (loaded onto CPU)."""
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                # Update in place so existing references to the dict stay valid.
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        """Load the checkpoint whose filename has the highest numeric suffix.

        Returns the loaded path, or None when no file in `run_dir` matches
        `pattern`.
        """
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path
|
| 139 |
+
|
| 140 |
+
#----------------------------------------------------------------------------
|
back/preprocessing/torch_utils/misc.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import contextlib
|
| 10 |
+
import functools
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import warnings
|
| 14 |
+
import dnnlib
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
# Re-seed torch & numpy random generators based on the given arguments.
|
| 18 |
+
|
| 19 |
+
def set_random_seed(*args):
    """Re-seed the torch & numpy RNGs from a hash of the given arguments.

    NOTE(review): hash() of str/bytes arguments varies across interpreter runs
    unless PYTHONHASHSEED is fixed; integer-only args hash deterministically —
    confirm callers pass only ints if cross-run reproducibility matters.
    """
    seed = hash(args) % (1 << 31)
    torch.manual_seed(seed)
    np.random.seed(seed)
|
| 23 |
+
|
| 24 |
+
#----------------------------------------------------------------------------
|
| 25 |
+
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
| 26 |
+
# same constant is used multiple times.
|
| 27 |
+
|
| 28 |
+
_constant_cache = dict()
|
| 29 |
+
|
| 30 |
+
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor for `value`, avoiding repeated CPU=>GPU copies.

    Tensors are cached per (value bytes, shape, dtype, device, memory_format);
    callers must treat the returned tensor as read-only since it is shared.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    # value.tobytes() keys the cache on the actual data, not object identity.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            # Broadcast against an empty tensor to expand to the requested shape.
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
|
| 50 |
+
|
| 51 |
+
#----------------------------------------------------------------------------
|
| 52 |
+
# Variant of constant() that inherits dtype and device from the given
|
| 53 |
+
# reference tensor by default.
|
| 54 |
+
|
| 55 |
+
def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    """Variant of constant() that defaults dtype and device to those of `ref`."""
    dtype = ref.dtype if dtype is None else dtype
    device = ref.device if device is None else device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
|
| 61 |
+
|
| 62 |
+
#----------------------------------------------------------------------------
|
| 63 |
+
# Cached construction of temporary tensors in pinned CPU memory.
|
| 64 |
+
|
| 65 |
+
@functools.lru_cache(None)
def pinned_buf(shape, dtype):
    # One pinned CPU buffer per unique (shape, dtype). The cache is unbounded,
    # which is fine for a small fixed set of shapes but grows without limit
    # otherwise. Pinned memory enables fast async CPU<->GPU transfers.
    return torch.empty(shape, dtype=dtype).pin_memory()
|
| 68 |
+
|
| 69 |
+
#----------------------------------------------------------------------------
|
| 70 |
+
# Symbolic assert.
|
| 71 |
+
|
| 72 |
+
# Pick whichever symbolic assert the installed torch version provides.
try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0
|
| 76 |
+
|
| 77 |
+
#----------------------------------------------------------------------------
|
| 78 |
+
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
| 79 |
+
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
| 80 |
+
|
| 81 |
+
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Context manager that temporarily hides torch.jit.TracerWarning.

    Inserts an 'ignore' entry at the front of warnings.filters and removes it
    on exit. (catch_warnings is avoided due to https://bugs.python.org/issue29672.)
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        # Fix: remove the filter even when the managed block raises, so an
        # exception cannot leave tracer warnings permanently suppressed.
        warnings.filters.remove(flt)
|
| 87 |
+
|
| 88 |
+
#----------------------------------------------------------------------------
|
| 89 |
+
# Assert that the shape of a tensor matches the given list of integers.
|
| 90 |
+
# None indicates that the size of a dimension is allowed to vary.
|
| 91 |
+
# Performs symbolic assertion when used in torch.jit.trace().
|
| 92 |
+
|
| 93 |
+
def assert_shape(tensor, ref_shape):
    """Assert that tensor.shape matches ref_shape; None entries match any size.

    Falls back to symbolic_assert for traced sizes so the check also works
    inside torch.jit.trace().
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue  # wildcard dimension
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
| 107 |
+
|
| 108 |
+
#----------------------------------------------------------------------------
|
| 109 |
+
# Function decorator that calls torch.autograd.profiler.record_function().
|
| 110 |
+
|
| 111 |
+
def profiled_function(fn):
    """Decorator that wraps `fn` in torch.autograd.profiler.record_function().

    Each call to the wrapped function appears in profiler traces under the
    function's own name.
    """
    # Fix: functools.wraps preserves __name__ (as the hand-rolled assignment
    # did) plus __doc__, __module__, __qualname__, and __wrapped__.
    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
|
| 117 |
+
|
| 118 |
+
#----------------------------------------------------------------------------
|
| 119 |
+
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
| 120 |
+
# indefinitely, shuffling items as it goes.
|
| 121 |
+
|
| 122 |
+
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that loops over the dataset indefinitely, reshuffling each epoch.

    Supports multi-process sharding: each replica starts at start_idx + rank
    and advances by num_replicas, so replicas draw disjoint index streams.
    """
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
        super().__init__(dataset)
        self.dataset_size = len(dataset)   # items per epoch
        self.start_idx = start_idx + rank  # first global position for this replica
        self.stride = num_replicas         # step between consecutive positions
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        idx = self.start_idx
        epoch = None
        while True:
            # Recompute the shuffled order whenever a new epoch starts.
            if epoch != idx // self.dataset_size:
                epoch = idx // self.dataset_size
                order = np.arange(self.dataset_size)
                if self.shuffle:
                    # Seeded per (seed, epoch) so every replica derives the
                    # same per-epoch permutation.
                    np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
            yield int(order[idx % self.dataset_size])
            idx += self.stride
| 146 |
+
|
| 147 |
+
#----------------------------------------------------------------------------
|
| 148 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
| 149 |
+
|
| 150 |
+
def params_and_buffers(module):
    """Return all parameters and buffers of `module` as one flat list."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
|
| 153 |
+
|
| 154 |
+
def named_params_and_buffers(module):
    """Return all (name, tensor) pairs for the parameters and buffers of `module`."""
    assert isinstance(module, torch.nn.Module)
    pairs = list(module.named_parameters())
    pairs.extend(module.named_buffers())
    return pairs
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameters/buffers from src_module into same-named tensors of dst_module.

    With require_all=True, every destination tensor must have a source
    counterpart; otherwise unmatched destination tensors are left untouched.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, dst_tensor in named_params_and_buffers(dst_module):
        src_tensor = src_tensors.get(name)
        if src_tensor is not None:
            dst_tensor.copy_(src_tensor)
        else:
            assert not require_all
|
| 167 |
+
|
| 168 |
+
#----------------------------------------------------------------------------
|
| 169 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
| 170 |
+
# synchronization.
|
| 171 |
+
|
| 172 |
+
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that disables DDP gradient sync when `sync` is False.

    For modules that are not DistributedDataParallel, behaves as a no-op.
    """
    assert isinstance(module, torch.nn.Module)
    skip_sync = isinstance(module, torch.nn.parallel.DistributedDataParallel) and not sync
    if skip_sync:
        with module.no_sync():
            yield
    else:
        yield
|
| 180 |
+
|
| 181 |
+
#----------------------------------------------------------------------------
|
| 182 |
+
# Check DistributedDataParallel consistency across processes.
|
| 183 |
+
|
| 184 |
+
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that all params/buffers of `module` match across DDP processes.

    Rank 0's tensors are broadcast to every rank and compared element-wise.
    Requires an initialized torch.distributed process group.

    Args:
        module:       Module whose parameters/buffers to check.
        ignore_regex: Optional regex; tensors whose 'ClassName.param' full
                      name matches are skipped.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            # NaN != NaN would produce false mismatches; normalize first.
            tensor = torch.nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
|
| 196 |
+
|
| 197 |
+
#----------------------------------------------------------------------------
|
| 198 |
+
# Print summary table of module hierarchy.
|
| 199 |
+
|
| 200 |
+
@torch.no_grad()
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a table of its submodule hierarchy.

    The table lists, per submodule (up to `max_nesting` levels deep), the
    number of unique parameters and buffers and the shape/dtype of each
    tensor output, followed by totals. With `skip_redundant=True`,
    submodules contributing no new tensors are omitted.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]  # list so the closures can mutate the depth counter
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    # Tensors shared between submodules are attributed to the first entry
    # that reports them, so totals count each tensor exactly once.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        # Modules with multiple outputs get one row per output, suffixed :0, :1, ...
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
|
| 269 |
+
|
| 270 |
+
#----------------------------------------------------------------------------
|
| 271 |
+
# Tile a batch of images into a 2D grid.
|
| 272 |
+
|
| 273 |
+
def tile_images(x, w, h):
    """Tile a batch of images into a single 2D grid.

    Args:
        x: Tensor of shape [N, C, H, W] with N == w * h.
        w: Number of grid columns.
        h: Number of grid rows.

    Returns:
        Tensor of shape [C, h * H, w * W]. Image `n` occupies grid row
        `n // w`, column `n % w` (row-major order).
    """
    assert x.ndim == 4  # NCHW => CHW
    # Validate the grid size explicitly; otherwise reshape fails with an
    # opaque shape-mismatch error deep inside torch.
    assert x.shape[0] == w * h, f'batch size {x.shape[0]} does not match grid {h}x{w}'
    return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
|
| 276 |
+
|
| 277 |
+
#----------------------------------------------------------------------------
|
back/preprocessing/torch_utils/persistence.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Facilities for pickling Python code alongside other data.
|
| 9 |
+
|
| 10 |
+
The pickled code is automatically imported into a separate Python module
|
| 11 |
+
during unpickling. This way, any previously exported pickles will remain
|
| 12 |
+
usable even if the original code is no longer available, or if the current
|
| 13 |
+
version of the code is not consistent with what was originally pickled."""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
import pickle
|
| 17 |
+
import io
|
| 18 |
+
import inspect
|
| 19 |
+
import copy
|
| 20 |
+
import uuid
|
| 21 |
+
import types
|
| 22 |
+
import functools
|
| 23 |
+
import dnnlib
|
| 24 |
+
|
| 25 |
+
#----------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
# Module-level registries shared by the persistence machinery below.
_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}
|
| 32 |
+
|
| 33 |
+
#----------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    # Already wrapped: return unchanged so double decoration is harmless.
    if is_persistent(orig_class):
        return orig_class

    # Capture the full source of the module defining the class; this text
    # is what gets embedded in every pickle produced by Decorator.
    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    @functools.wraps(orig_class, updated=())
    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Snapshot constructor args unless the instance opted out via
            # `self._record_init_args = False` inside __init__.
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            assert orig_class.__name__ in orig_module.__dict__
            # Fail fast at construction if the instance would not pickle.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            # Deep copy so callers cannot mutate the recorded snapshot.
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            # Rewrite the default reduction so unpickling goes through
            # _reconstruct_persistent_obj with the embedded module source.
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    _decorators.add(Decorator)
    return Decorator
|
| 137 |
+
|
| 138 |
+
#----------------------------------------------------------------------------
|
| 139 |
+
|
| 140 |
+
def is_persistent(obj):
    r"""Test whether the given object or class will save its source code
    when pickled, i.e. whether it is persistent.
    """
    # `obj` may itself be a decorator class; membership testing can raise
    # TypeError for unhashable instances, in which case only the type of
    # the object is checked.
    candidates = [type(obj)]
    try:
        hash(obj)
        candidates.insert(0, obj)
    except TypeError:
        pass
    return any(c in _decorators for c in candidates)  # pylint: disable=unidiomatic-typecheck
|
| 150 |
+
|
| 151 |
+
#----------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
def import_hook(hook):
    r"""Register an import hook invoked whenever a persistent object is
    being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    Required signature:

        hook(meta) -> modified meta

    `meta` is a `dnnlib.EasyDict` with these fields:

        type:        Type of the persistent object, e.g. `'class'`.
        version:     Internal version number of `torch_utils.persistence`.
        module_src:  Original source code of the Python module.
        class_name:  Class name in the original Python module.
        state:       Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)
|
| 182 |
+
|
| 183 |
+
#----------------------------------------------------------------------------
|
| 184 |
+
|
| 185 |
+
def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    `meta` is the metadata dict embedded by `Decorator.__reduce__`; the
    original module source is re-executed in a fresh module and the class
    re-decorated before the instance state is restored.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    # Give registered hooks a chance to patch the metadata (e.g. fix up
    # source code for API changes) before anything is executed.
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    # Bypass __init__: state is restored below, not reconstructed.
    obj = decorator_class.__new__(decorator_class)

    # Standard pickle state restoration: prefer __setstate__ if defined.
    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj
|
| 209 |
+
|
| 210 |
+
#----------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
def _module_to_src(module):
    r"""Return the source code of a given Python module, caching the
    result for subsequent calls.
    """
    cached = _module_to_src_dict.get(module)
    if cached is not None:
        return cached
    src = inspect.getsource(module)
    # Populate both directions of the cache so _src_to_module() can map
    # this exact source text back to the live module.
    _module_to_src_dict[module] = src
    _src_to_module_dict[src] = module
    return src
|
| 221 |
+
|
| 222 |
+
def _src_to_module(src):
    r"""Get or create a Python module whose contents are defined by the
    given source code string.
    """
    existing = _src_to_module_dict.get(src)
    if existing is not None:
        return existing
    # Fresh, uniquely-named module registered in sys.modules so pickling
    # of its classes resolves correctly later.
    module_name = "_imported_module_" + uuid.uuid4().hex
    module = types.ModuleType(module_name)
    sys.modules[module_name] = module
    _module_to_src_dict[module] = src
    _src_to_module_dict[src] = module
    exec(src, module.__dict__) # pylint: disable=exec-used
    return module
|
| 234 |
+
|
| 235 |
+
#----------------------------------------------------------------------------
|
| 236 |
+
|
| 237 |
+
def _check_pickleable(obj):
    r"""Raise an exception if the given object is not pickleable.

    Considerably cheaper than pickling the full object: leaves that are
    known to pickle (primitives, arrays/tensors, persistent objects) are
    pruned to None before the trial pickle.
    """
    def prune(node):
        if isinstance(node, (list, tuple, set)):
            return [prune(item) for item in node]
        if isinstance(node, dict):
            return [[prune(key), prune(val)] for key, val in node.items()]
        if isinstance(node, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        qualname = f'{type(node).__module__}.{type(node).__name__}'
        if qualname in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(node):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return node
    with io.BytesIO() as buffer:
        pickle.dump(prune(obj), buffer)
|
| 256 |
+
|
| 257 |
+
#----------------------------------------------------------------------------
|
back/preprocessing/torch_utils/training_stats.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Facilities for reporting and collecting training statistics across
|
| 9 |
+
multiple processes and devices. The interface is designed to minimize
|
| 10 |
+
synchronization overhead as well as the amount of boilerplate in user
|
| 11 |
+
code."""
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import dnnlib
|
| 17 |
+
|
| 18 |
+
from . import misc
|
| 19 |
+
|
| 20 |
+
#----------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
# Module-level state shared by report(), Collector, and _sync().
_num_moments    = 3                 # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32     # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64     # Data type to use for the internal counters.
_rank           = 0                 # Rank of the current process.
_sync_device    = None              # Device to use for multiprocess communication. None = single-process.
_sync_called    = False             # Has _sync() been called yet?
_counters       = dict()            # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()            # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
| 30 |
+
|
| 31 |
+
#----------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
def init_multiprocessing(rank, sync_device):
    r"""Prepare `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    Must be called after `torch.distributed.init_process_group()` and
    before `Collector.update()`. The call is unnecessary when
    multi-process collection is not needed.

    Args:
        rank:        Rank of the current process.
        sync_device: PyTorch device for inter-process communication
                     (typically `torch.device('cuda', rank)`), or None to
                     disable multi-process collection.
    """
    global _rank, _sync_device
    # Configuration must happen before the first synchronization round.
    assert not _sync_called
    _rank, _sync_device = rank, sync_device
|
| 51 |
+
|
| 52 |
+
#----------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries. NaNs and Infs are
    ignored.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        # Empty contribution: the name is still registered above so the
        # cross-process name sets stay consistent.
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    square = elems.square()
    # A finite square implies a finite element; NaN/Inf entries are masked
    # out of all three moments below.
    finite = square.isfinite()
    moments = torch.stack([
        finite.sum(dtype=_reduce_dtype),
        torch.where(finite, elems, 0).sum(),
        torch.where(finite, square, 0).sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Accumulate into a per-(name, device) counter to avoid cross-device
    # transfers on the hot path; _sync() merges devices later.
    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value
|
| 102 |
+
|
| 103 |
+
#----------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    # Non-zero ranks still call report() with an empty contribution so the
    # set of names stays consistent across processes.
    contribution = value if _rank == 0 else []
    report(name, contribution)
    return value
|
| 112 |
+
|
| 113 |
+
#----------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: False).
    """
    def __init__(self, regex='.*', keep_previous=False):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()   # name => cumulative moments seen by this collector
        self._moments = dict()      # name => moments delta between last two update() calls
        # Prime the baseline against current global counters; the deltas
        # produced by this first update() are discarded.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            # delta[0] is the scalar count; only record rounds that
            # actually received data (preserves old averages when
            # keep_previous=True).
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # Population variance E[x^2] - E[x]^2, clamped at 0 against
        # floating-point round-off.
        return np.sqrt(max(raw_var - np.square(mean), 0))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)
|
| 233 |
+
|
| 234 |
+
#----------------------------------------------------------------------------
|
| 235 |
+
|
| 236 |
+
def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.

    Returns a list of (name, cumulative_moments) pairs; the tensors are
    the module-level `_cumulative` entries (shared, not copies).
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Check that all ranks have the same set of names.
    if _sync_device is not None:
        value = hash(tuple(tuple(ord(char) for char in name) for name in names))
        other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
        torch.distributed.broadcast(tensor=other, src=0)
        if value != int(other.cpu()):
            raise ValueError('Training statistics are inconsistent between ranks')

    # Collect deltas within current rank, merging per-device counters and
    # zeroing them for the next round.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]
|
| 277 |
+
|
| 278 |
+
#----------------------------------------------------------------------------
|
| 279 |
+
# Convenience.
|
| 280 |
+
|
| 281 |
+
default_collector = Collector()  # Shared module-level instance for simple use cases.
|
| 282 |
+
|
| 283 |
+
#----------------------------------------------------------------------------
|
back/wandb/debug-internal.log
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-23T13:58:41.647788404+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-23T13:58:42.578470875+08:00","level":"INFO","msg":"stream: created new stream","id":"w9holkos"}
|
| 3 |
+
{"time":"2026-03-23T13:58:42.578676113+08:00","level":"INFO","msg":"handler: started","stream_id":"w9holkos"}
|
| 4 |
+
{"time":"2026-03-23T13:58:42.579473589+08:00","level":"INFO","msg":"stream: started","id":"w9holkos"}
|
| 5 |
+
{"time":"2026-03-23T13:58:42.57951741+08:00","level":"INFO","msg":"sender: started","stream_id":"w9holkos"}
|
| 6 |
+
{"time":"2026-03-23T13:58:42.579478227+08:00","level":"INFO","msg":"writer: started","stream_id":"w9holkos"}
|
| 7 |
+
{"time":"2026-03-23T14:49:13.568442881+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 8 |
+
{"time":"2026-03-23T14:52:15.597652411+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
|
| 9 |
+
{"time":"2026-03-23T14:52:26.072213509+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": write tcp 172.20.98.27:52324->35.186.228.49:443: write: broken pipe"}
|
| 10 |
+
{"time":"2026-03-23T17:02:52.905542765+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 11 |
+
{"time":"2026-03-23T17:05:55.176103762+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
|
| 12 |
+
{"time":"2026-03-23T17:06:10.164453104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
|
| 13 |
+
{"time":"2026-03-23T22:05:06.25355716+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:44154->35.186.228.49:443: read: connection reset by peer"}
|
| 14 |
+
{"time":"2026-03-23T22:05:20.791067182+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:40392->35.186.228.49:443: read: connection reset by peer"}
|
| 15 |
+
{"time":"2026-03-24T02:18:38.770696332+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 16 |
+
{"time":"2026-03-24T06:25:41.879737278+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 17 |
+
{"time":"2026-03-24T06:30:14.989373032+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 18 |
+
{"time":"2026-03-24T09:05:02.85908394+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:46722->35.186.228.49:443: read: connection reset by peer"}
|
| 19 |
+
{"time":"2026-03-25T04:41:04.741907157+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
|
back/wandb/debug.log
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Configure stats pid to 400275
|
| 3 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug.log
|
| 5 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log
|
| 6 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-23 13:58:41,630 INFO MainThread:400275 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-23 13:58:41,643 INFO MainThread:400275 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-23 13:58:41,646 INFO MainThread:400275 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-23 13:58:41,659 INFO MainThread:400275 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-23 13:58:43,108 INFO MainThread:400275 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-23 13:58:43,202 INFO MainThread:400275 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-23 13:58:43,209 INFO MainThread:400275 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-23 13:58:43,210 INFO MainThread:400275 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.05, 't_c': 0.75, 'ot_cls': True}
|
back/wandb/run-20260322_141726-2yw08kz9/files/config.yaml
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_wandb:
|
| 2 |
+
value:
|
| 3 |
+
cli_version: 0.25.0
|
| 4 |
+
e:
|
| 5 |
+
257k9ot60u1bv0aiwlacsvutj9c72h7y:
|
| 6 |
+
args:
|
| 7 |
+
- --report-to
|
| 8 |
+
- wandb
|
| 9 |
+
- --allow-tf32
|
| 10 |
+
- --mixed-precision
|
| 11 |
+
- bf16
|
| 12 |
+
- --seed
|
| 13 |
+
- "0"
|
| 14 |
+
- --path-type
|
| 15 |
+
- linear
|
| 16 |
+
- --prediction
|
| 17 |
+
- v
|
| 18 |
+
- --weighting
|
| 19 |
+
- uniform
|
| 20 |
+
- --model
|
| 21 |
+
- SiT-XL/2
|
| 22 |
+
- --enc-type
|
| 23 |
+
- dinov2-vit-b
|
| 24 |
+
- --encoder-depth
|
| 25 |
+
- "8"
|
| 26 |
+
- --proj-coeff
|
| 27 |
+
- "0.5"
|
| 28 |
+
- --output-dir
|
| 29 |
+
- exps
|
| 30 |
+
- --exp-name
|
| 31 |
+
- jsflow-experiment
|
| 32 |
+
- --batch-size
|
| 33 |
+
- "256"
|
| 34 |
+
- --data-dir
|
| 35 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 36 |
+
- --semantic-features-dir
|
| 37 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 38 |
+
- --learning-rate
|
| 39 |
+
- "0.00005"
|
| 40 |
+
- --t-c
|
| 41 |
+
- "0.5"
|
| 42 |
+
- --cls
|
| 43 |
+
- "0.2"
|
| 44 |
+
- --ot-cls
|
| 45 |
+
codePath: train.py
|
| 46 |
+
codePathLocal: train.py
|
| 47 |
+
cpu_count: 96
|
| 48 |
+
cpu_count_logical: 192
|
| 49 |
+
cudaVersion: "13.0"
|
| 50 |
+
disk:
|
| 51 |
+
/:
|
| 52 |
+
total: "3838880616448"
|
| 53 |
+
used: "357556633600"
|
| 54 |
+
email: 2365972933@qq.com
|
| 55 |
+
executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
|
| 56 |
+
git:
|
| 57 |
+
commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
|
| 58 |
+
remote: https://github.com/Martinser/REG.git
|
| 59 |
+
gpu: NVIDIA H100 80GB HBM3
|
| 60 |
+
gpu_count: 4
|
| 61 |
+
gpu_nvidia:
|
| 62 |
+
- architecture: Hopper
|
| 63 |
+
cudaCores: 16896
|
| 64 |
+
memoryTotal: "85520809984"
|
| 65 |
+
name: NVIDIA H100 80GB HBM3
|
| 66 |
+
uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
|
| 67 |
+
- architecture: Hopper
|
| 68 |
+
cudaCores: 16896
|
| 69 |
+
memoryTotal: "85520809984"
|
| 70 |
+
name: NVIDIA H100 80GB HBM3
|
| 71 |
+
uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
|
| 72 |
+
- architecture: Hopper
|
| 73 |
+
cudaCores: 16896
|
| 74 |
+
memoryTotal: "85520809984"
|
| 75 |
+
name: NVIDIA H100 80GB HBM3
|
| 76 |
+
uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
|
| 77 |
+
- architecture: Hopper
|
| 78 |
+
cudaCores: 16896
|
| 79 |
+
memoryTotal: "85520809984"
|
| 80 |
+
name: NVIDIA H100 80GB HBM3
|
| 81 |
+
uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
|
| 82 |
+
host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
|
| 83 |
+
memory:
|
| 84 |
+
total: "2164115296256"
|
| 85 |
+
os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
|
| 86 |
+
program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
|
| 87 |
+
python: CPython 3.12.9
|
| 88 |
+
root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
|
| 89 |
+
startedAt: "2026-03-22T06:17:26.670763Z"
|
| 90 |
+
writerId: 257k9ot60u1bv0aiwlacsvutj9c72h7y
|
| 91 |
+
m: []
|
| 92 |
+
python_version: 3.12.9
|
| 93 |
+
t:
|
| 94 |
+
"1":
|
| 95 |
+
- 1
|
| 96 |
+
- 5
|
| 97 |
+
- 11
|
| 98 |
+
- 41
|
| 99 |
+
- 49
|
| 100 |
+
- 53
|
| 101 |
+
- 63
|
| 102 |
+
- 71
|
| 103 |
+
- 83
|
| 104 |
+
- 98
|
| 105 |
+
"2":
|
| 106 |
+
- 1
|
| 107 |
+
- 5
|
| 108 |
+
- 11
|
| 109 |
+
- 41
|
| 110 |
+
- 49
|
| 111 |
+
- 53
|
| 112 |
+
- 63
|
| 113 |
+
- 71
|
| 114 |
+
- 83
|
| 115 |
+
- 98
|
| 116 |
+
"3":
|
| 117 |
+
- 13
|
| 118 |
+
- 61
|
| 119 |
+
"4": 3.12.9
|
| 120 |
+
"5": 0.25.0
|
| 121 |
+
"6": 4.53.2
|
| 122 |
+
"12": 0.25.0
|
| 123 |
+
"13": linux-x86_64
|
| 124 |
+
adam_beta1:
|
| 125 |
+
value: 0.9
|
| 126 |
+
adam_beta2:
|
| 127 |
+
value: 0.999
|
| 128 |
+
adam_epsilon:
|
| 129 |
+
value: 1e-08
|
| 130 |
+
adam_weight_decay:
|
| 131 |
+
value: 0
|
| 132 |
+
allow_tf32:
|
| 133 |
+
value: true
|
| 134 |
+
batch_size:
|
| 135 |
+
value: 256
|
| 136 |
+
cfg_prob:
|
| 137 |
+
value: 0.1
|
| 138 |
+
checkpointing_steps:
|
| 139 |
+
value: 10000
|
| 140 |
+
cls:
|
| 141 |
+
value: 0.2
|
| 142 |
+
data_dir:
|
| 143 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 144 |
+
enc_type:
|
| 145 |
+
value: dinov2-vit-b
|
| 146 |
+
encoder_depth:
|
| 147 |
+
value: 8
|
| 148 |
+
epochs:
|
| 149 |
+
value: 1400
|
| 150 |
+
exp_name:
|
| 151 |
+
value: jsflow-experiment
|
| 152 |
+
fused_attn:
|
| 153 |
+
value: true
|
| 154 |
+
gradient_accumulation_steps:
|
| 155 |
+
value: 1
|
| 156 |
+
learning_rate:
|
| 157 |
+
value: 5e-05
|
| 158 |
+
legacy:
|
| 159 |
+
value: false
|
| 160 |
+
logging_dir:
|
| 161 |
+
value: logs
|
| 162 |
+
max_grad_norm:
|
| 163 |
+
value: 1
|
| 164 |
+
max_train_steps:
|
| 165 |
+
value: 1000000
|
| 166 |
+
mixed_precision:
|
| 167 |
+
value: bf16
|
| 168 |
+
model:
|
| 169 |
+
value: SiT-XL/2
|
| 170 |
+
num_classes:
|
| 171 |
+
value: 1000
|
| 172 |
+
num_workers:
|
| 173 |
+
value: 4
|
| 174 |
+
ops_head:
|
| 175 |
+
value: 16
|
| 176 |
+
ot_cls:
|
| 177 |
+
value: true
|
| 178 |
+
output_dir:
|
| 179 |
+
value: exps
|
| 180 |
+
path_type:
|
| 181 |
+
value: linear
|
| 182 |
+
prediction:
|
| 183 |
+
value: v
|
| 184 |
+
proj_coeff:
|
| 185 |
+
value: 0.5
|
| 186 |
+
qk_norm:
|
| 187 |
+
value: false
|
| 188 |
+
report_to:
|
| 189 |
+
value: wandb
|
| 190 |
+
resolution:
|
| 191 |
+
value: 256
|
| 192 |
+
resume_step:
|
| 193 |
+
value: 0
|
| 194 |
+
sampling_steps:
|
| 195 |
+
value: 10000
|
| 196 |
+
seed:
|
| 197 |
+
value: 0
|
| 198 |
+
semantic_features_dir:
|
| 199 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 200 |
+
t_c:
|
| 201 |
+
value: 0.5
|
| 202 |
+
weighting:
|
| 203 |
+
value: uniform
|
back/wandb/run-20260322_141726-2yw08kz9/files/output.log
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Steps: 0%| | 1/1000000 [00:02<614:34:39, 2.21s/it][[34m2026-03-22 14:17:31[0m] Generating EMA samples done.
|
| 2 |
+
[[34m2026-03-22 14:17:31[0m] Step: 1, Training Logs: loss_final: 3.278940, loss_mean: 1.706308, proj_loss: 0.001541, loss_mean_cls: 1.571091, grad_norm: 1.481672
|
| 3 |
+
Steps: 0%| | 2/1000000 [00:02<289:06:04, 1.04s/it, grad_norm=1.48, loss_final=3.28, loss_mean=1.71, loss_mean_cls=1.57, proj_loss=0.001[[34m2026-03-22 14:17:31[0m] Step: 2, Training Logs: loss_final: 3.211831, loss_mean: 1.688932, proj_loss: -0.010287, loss_mean_cls: 1.533185, grad_norm: 1.055476
|
| 4 |
+
Steps: 0%| | 3/1000000 [00:02<187:48:39, 1.48it/s, grad_norm=1.06, loss_final=3.21, loss_mean=1.69, loss_mean_cls=1.53, proj_loss=-0.01[[34m2026-03-22 14:17:31[0m] Step: 3, Training Logs: loss_final: 3.201248, loss_mean: 1.663205, proj_loss: -0.019184, loss_mean_cls: 1.557227, grad_norm: 1.116387
|
| 5 |
+
Steps: 0%| | 4/1000000 [00:02<140:12:43, 1.98it/s, grad_norm=1.12, loss_final=3.2, loss_mean=1.66, loss_mean_cls=1.56, proj_loss=-0.019[[34m2026-03-22 14:17:32[0m] Step: 4, Training Logs: loss_final: 3.198367, loss_mean: 1.682051, proj_loss: -0.026376, loss_mean_cls: 1.542691, grad_norm: 0.722294
|
| 6 |
+
Steps: 0%| | 5/1000000 [00:03<113:52:43, 2.44it/s, grad_norm=0.722, loss_final=3.2, loss_mean=1.68, loss_mean_cls=1.54, proj_loss=-0.02[[34m2026-03-22 14:17:32[0m] Step: 5, Training Logs: loss_final: 3.140483, loss_mean: 1.679105, proj_loss: -0.034564, loss_mean_cls: 1.495943, grad_norm: 0.811589
|
| 7 |
+
Steps: 0%| | 6/1000000 [00:03<97:59:40, 2.83it/s, grad_norm=0.812, loss_final=3.14, loss_mean=1.68, loss_mean_cls=1.5, proj_loss=-0.034[[34m2026-03-22 14:17:32[0m] Step: 6, Training Logs: loss_final: 2.988440, loss_mean: 1.682339, proj_loss: -0.039506, loss_mean_cls: 1.345606, grad_norm: 0.931524
|
| 8 |
+
Steps: 0%| | 7/1000000 [00:03<87:55:00, 3.16it/s, grad_norm=0.932, loss_final=2.99, loss_mean=1.68, loss_mean_cls=1.35, proj_loss=-0.03[[34m2026-03-22 14:17:32[0m] Step: 7, Training Logs: loss_final: 3.111949, loss_mean: 1.690802, proj_loss: -0.042757, loss_mean_cls: 1.463904, grad_norm: 0.830852
|
| 9 |
+
Steps: 0%| | 8/1000000 [00:03<81:19:20, 3.42it/s, grad_norm=0.831, loss_final=3.11, loss_mean=1.69, loss_mean_cls=1.46, proj_loss=-0.04[[34m2026-03-22 14:17:33[0m] Step: 8, Training Logs: loss_final: 3.278931, loss_mean: 1.660797, proj_loss: -0.045011, loss_mean_cls: 1.663145, grad_norm: 0.847438
|
| 10 |
+
Steps: 0%| | 9/1000000 [00:04<76:56:10, 3.61it/s, grad_norm=0.847, loss_final=3.28, loss_mean=1.66, loss_mean_cls=1.66, proj_loss=-0.04[[34m2026-03-22 14:17:33[0m] Step: 9, Training Logs: loss_final: 3.221569, loss_mean: 1.658834, proj_loss: -0.046031, loss_mean_cls: 1.608767, grad_norm: 0.909827
|
| 11 |
+
Steps: 0%| | 10/1000000 [00:04<73:57:18, 3.76it/s, grad_norm=0.91, loss_final=3.22, loss_mean=1.66, loss_mean_cls=1.61, proj_loss=-0.04[[34m2026-03-22 14:17:33[0m] Step: 10, Training Logs: loss_final: 3.216744, loss_mean: 1.665229, proj_loss: -0.047761, loss_mean_cls: 1.599277, grad_norm: 1.014574
|
| 12 |
+
Steps: 0%| | 11/1000000 [00:04<71:52:01, 3.87it/s, grad_norm=1.01, loss_final=3.22, loss_mean=1.67, loss_mean_cls=1.6, proj_loss=-0.047[[34m2026-03-22 14:17:33[0m] Step: 11, Training Logs: loss_final: 3.216658, loss_mean: 1.649915, proj_loss: -0.049347, loss_mean_cls: 1.616090, grad_norm: 1.028789
|
| 13 |
+
Steps: 0%| | 12/1000000 [00:04<70:26:20, 3.94it/s, grad_norm=1.03, loss_final=3.22, loss_mean=1.65, loss_mean_cls=1.62, proj_loss=-0.04[[34m2026-03-22 14:17:34[0m] Step: 12, Training Logs: loss_final: 3.155676, loss_mean: 1.624463, proj_loss: -0.049856, loss_mean_cls: 1.581069, grad_norm: 1.231291
|
| 14 |
+
Steps: 0%| | 13/1000000 [00:05<69:25:29, 4.00it/s, grad_norm=1.23, loss_final=3.16, loss_mean=1.62, loss_mean_cls=1.58, proj_loss=-0.04Traceback (most recent call last):
|
| 15 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
|
| 16 |
+
main(args)
|
| 17 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
|
| 18 |
+
"loss_final": accelerator.gather(loss).mean().detach().item(),
|
| 19 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 20 |
+
KeyboardInterrupt
|
| 21 |
+
[rank0]: Traceback (most recent call last):
|
| 22 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
|
| 23 |
+
[rank0]: main(args)
|
| 24 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
|
| 25 |
+
[rank0]: "loss_final": accelerator.gather(loss).mean().detach().item(),
|
| 26 |
+
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 27 |
+
[rank0]: KeyboardInterrupt
|
back/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|
back/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
|
| 3 |
+
"python": "CPython 3.12.9",
|
| 4 |
+
"startedAt": "2026-03-22T06:17:26.670763Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--report-to",
|
| 7 |
+
"wandb",
|
| 8 |
+
"--allow-tf32",
|
| 9 |
+
"--mixed-precision",
|
| 10 |
+
"bf16",
|
| 11 |
+
"--seed",
|
| 12 |
+
"0",
|
| 13 |
+
"--path-type",
|
| 14 |
+
"linear",
|
| 15 |
+
"--prediction",
|
| 16 |
+
"v",
|
| 17 |
+
"--weighting",
|
| 18 |
+
"uniform",
|
| 19 |
+
"--model",
|
| 20 |
+
"SiT-XL/2",
|
| 21 |
+
"--enc-type",
|
| 22 |
+
"dinov2-vit-b",
|
| 23 |
+
"--encoder-depth",
|
| 24 |
+
"8",
|
| 25 |
+
"--proj-coeff",
|
| 26 |
+
"0.5",
|
| 27 |
+
"--output-dir",
|
| 28 |
+
"exps",
|
| 29 |
+
"--exp-name",
|
| 30 |
+
"jsflow-experiment",
|
| 31 |
+
"--batch-size",
|
| 32 |
+
"256",
|
| 33 |
+
"--data-dir",
|
| 34 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
|
| 35 |
+
"--semantic-features-dir",
|
| 36 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
|
| 37 |
+
"--learning-rate",
|
| 38 |
+
"0.00005",
|
| 39 |
+
"--t-c",
|
| 40 |
+
"0.5",
|
| 41 |
+
"--cls",
|
| 42 |
+
"0.2",
|
| 43 |
+
"--ot-cls"
|
| 44 |
+
],
|
| 45 |
+
"program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
|
| 46 |
+
"codePath": "train.py",
|
| 47 |
+
"codePathLocal": "train.py",
|
| 48 |
+
"git": {
|
| 49 |
+
"remote": "https://github.com/Martinser/REG.git",
|
| 50 |
+
"commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
|
| 51 |
+
},
|
| 52 |
+
"email": "2365972933@qq.com",
|
| 53 |
+
"root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
|
| 54 |
+
"host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
|
| 55 |
+
"executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
|
| 56 |
+
"cpu_count": 96,
|
| 57 |
+
"cpu_count_logical": 192,
|
| 58 |
+
"gpu": "NVIDIA H100 80GB HBM3",
|
| 59 |
+
"gpu_count": 4,
|
| 60 |
+
"disk": {
|
| 61 |
+
"/": {
|
| 62 |
+
"total": "3838880616448",
|
| 63 |
+
"used": "357556633600"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"memory": {
|
| 67 |
+
"total": "2164115296256"
|
| 68 |
+
},
|
| 69 |
+
"gpu_nvidia": [
|
| 70 |
+
{
|
| 71 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 72 |
+
"memoryTotal": "85520809984",
|
| 73 |
+
"cudaCores": 16896,
|
| 74 |
+
"architecture": "Hopper",
|
| 75 |
+
"uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 79 |
+
"memoryTotal": "85520809984",
|
| 80 |
+
"cudaCores": 16896,
|
| 81 |
+
"architecture": "Hopper",
|
| 82 |
+
"uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 86 |
+
"memoryTotal": "85520809984",
|
| 87 |
+
"cudaCores": 16896,
|
| 88 |
+
"architecture": "Hopper",
|
| 89 |
+
"uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 93 |
+
"memoryTotal": "85520809984",
|
| 94 |
+
"cudaCores": 16896,
|
| 95 |
+
"architecture": "Hopper",
|
| 96 |
+
"uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"cudaVersion": "13.0",
|
| 100 |
+
"writerId": "257k9ot60u1bv0aiwlacsvutj9c72h7y"
|
| 101 |
+
}
|
back/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"loss_mean_cls":1.5810688734054565,"_timestamp":1.7741602540511734e+09,"_runtime":5.247627056,"loss_mean":1.6244629621505737,"proj_loss":-0.04985573887825012,"grad_norm":1.2312908172607422,"_wandb":{"runtime":5},"_step":12,"loss_final":3.1556761264801025}
|
back/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-22T14:17:27.013311984+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-22T14:17:28.347732261+08:00","level":"INFO","msg":"stream: created new stream","id":"2yw08kz9"}
|
| 3 |
+
{"time":"2026-03-22T14:17:28.347960938+08:00","level":"INFO","msg":"handler: started","stream_id":"2yw08kz9"}
|
| 4 |
+
{"time":"2026-03-22T14:17:28.348671928+08:00","level":"INFO","msg":"stream: started","id":"2yw08kz9"}
|
| 5 |
+
{"time":"2026-03-22T14:17:28.348731034+08:00","level":"INFO","msg":"sender: started","stream_id":"2yw08kz9"}
|
| 6 |
+
{"time":"2026-03-22T14:17:28.348748525+08:00","level":"INFO","msg":"writer: started","stream_id":"2yw08kz9"}
|
| 7 |
+
{"time":"2026-03-22T14:17:34.316421629+08:00","level":"INFO","msg":"stream: closing","id":"2yw08kz9"}
|
back/wandb/run-20260322_141726-2yw08kz9/logs/debug.log
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Configure stats pid to 316313
|
| 3 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log
|
| 5 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log
|
| 6 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-22 14:17:26,994 INFO MainThread:316313 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-22 14:17:27,008 INFO MainThread:316313 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-22 14:17:27,011 INFO MainThread:316313 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-22 14:17:27,025 INFO MainThread:316313 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-22 14:17:29,067 INFO MainThread:316313 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-22 14:17:29,159 INFO MainThread:316313 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
|
| 21 |
+
2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [service_client.py:_forward_responses():134] Reached EOF.
|
| 22 |
+
2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
|
back/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb
ADDED
|
Binary file (7 Bytes). View file
|
|
|
back/wandb/run-20260322_141833-vm0y8t9t/files/output.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
back/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|
back/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
|
| 3 |
+
"python": "CPython 3.12.9",
|
| 4 |
+
"startedAt": "2026-03-22T06:18:33.208941Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--report-to",
|
| 7 |
+
"wandb",
|
| 8 |
+
"--allow-tf32",
|
| 9 |
+
"--mixed-precision",
|
| 10 |
+
"bf16",
|
| 11 |
+
"--seed",
|
| 12 |
+
"0",
|
| 13 |
+
"--path-type",
|
| 14 |
+
"linear",
|
| 15 |
+
"--prediction",
|
| 16 |
+
"v",
|
| 17 |
+
"--weighting",
|
| 18 |
+
"uniform",
|
| 19 |
+
"--model",
|
| 20 |
+
"SiT-XL/2",
|
| 21 |
+
"--enc-type",
|
| 22 |
+
"dinov2-vit-b",
|
| 23 |
+
"--encoder-depth",
|
| 24 |
+
"8",
|
| 25 |
+
"--proj-coeff",
|
| 26 |
+
"0.5",
|
| 27 |
+
"--output-dir",
|
| 28 |
+
"exps",
|
| 29 |
+
"--exp-name",
|
| 30 |
+
"jsflow-experiment",
|
| 31 |
+
"--batch-size",
|
| 32 |
+
"256",
|
| 33 |
+
"--data-dir",
|
| 34 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
|
| 35 |
+
"--semantic-features-dir",
|
| 36 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
|
| 37 |
+
"--learning-rate",
|
| 38 |
+
"0.00005",
|
| 39 |
+
"--t-c",
|
| 40 |
+
"0.5",
|
| 41 |
+
"--cls",
|
| 42 |
+
"0.2",
|
| 43 |
+
"--ot-cls"
|
| 44 |
+
],
|
| 45 |
+
"program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
|
| 46 |
+
"codePath": "train.py",
|
| 47 |
+
"codePathLocal": "train.py",
|
| 48 |
+
"git": {
|
| 49 |
+
"remote": "https://github.com/Martinser/REG.git",
|
| 50 |
+
"commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
|
| 51 |
+
},
|
| 52 |
+
"email": "2365972933@qq.com",
|
| 53 |
+
"root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
|
| 54 |
+
"host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
|
| 55 |
+
"executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
|
| 56 |
+
"cpu_count": 96,
|
| 57 |
+
"cpu_count_logical": 192,
|
| 58 |
+
"gpu": "NVIDIA H100 80GB HBM3",
|
| 59 |
+
"gpu_count": 4,
|
| 60 |
+
"disk": {
|
| 61 |
+
"/": {
|
| 62 |
+
"total": "3838880616448",
|
| 63 |
+
"used": "357556703232"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"memory": {
|
| 67 |
+
"total": "2164115296256"
|
| 68 |
+
},
|
| 69 |
+
"gpu_nvidia": [
|
| 70 |
+
{
|
| 71 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 72 |
+
"memoryTotal": "85520809984",
|
| 73 |
+
"cudaCores": 16896,
|
| 74 |
+
"architecture": "Hopper",
|
| 75 |
+
"uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 79 |
+
"memoryTotal": "85520809984",
|
| 80 |
+
"cudaCores": 16896,
|
| 81 |
+
"architecture": "Hopper",
|
| 82 |
+
"uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 86 |
+
"memoryTotal": "85520809984",
|
| 87 |
+
"cudaCores": 16896,
|
| 88 |
+
"architecture": "Hopper",
|
| 89 |
+
"uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 93 |
+
"memoryTotal": "85520809984",
|
| 94 |
+
"cudaCores": 16896,
|
| 95 |
+
"architecture": "Hopper",
|
| 96 |
+
"uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"cudaVersion": "13.0",
|
| 100 |
+
"writerId": "gklxguwapb72cxij4696gj37bh1rbthi"
|
| 101 |
+
}
|
back/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-22T14:18:33.472940651+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-22T14:18:35.380852704+08:00","level":"INFO","msg":"stream: created new stream","id":"vm0y8t9t"}
|
| 3 |
+
{"time":"2026-03-22T14:18:35.381056887+08:00","level":"INFO","msg":"handler: started","stream_id":"vm0y8t9t"}
|
| 4 |
+
{"time":"2026-03-22T14:18:35.382108345+08:00","level":"INFO","msg":"writer: started","stream_id":"vm0y8t9t"}
|
| 5 |
+
{"time":"2026-03-22T14:18:35.382119604+08:00","level":"INFO","msg":"stream: started","id":"vm0y8t9t"}
|
| 6 |
+
{"time":"2026-03-22T14:18:35.382161533+08:00","level":"INFO","msg":"sender: started","stream_id":"vm0y8t9t"}
|
back/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Configure stats pid to 318585
|
| 3 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log
|
| 5 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log
|
| 6 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-22 14:18:33,460 INFO MainThread:318585 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-22 14:18:33,470 INFO MainThread:318585 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-22 14:18:33,472 INFO MainThread:318585 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-22 14:18:33,485 INFO MainThread:318585 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-22 14:18:36,829 INFO MainThread:318585 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
|
back/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_wandb:
|
| 2 |
+
value:
|
| 3 |
+
cli_version: 0.25.0
|
| 4 |
+
e:
|
| 5 |
+
ucanic8s891x6sl28vnbha78lzoecw66:
|
| 6 |
+
args:
|
| 7 |
+
- --report-to
|
| 8 |
+
- wandb
|
| 9 |
+
- --allow-tf32
|
| 10 |
+
- --mixed-precision
|
| 11 |
+
- bf16
|
| 12 |
+
- --seed
|
| 13 |
+
- "0"
|
| 14 |
+
- --path-type
|
| 15 |
+
- linear
|
| 16 |
+
- --prediction
|
| 17 |
+
- v
|
| 18 |
+
- --weighting
|
| 19 |
+
- uniform
|
| 20 |
+
- --model
|
| 21 |
+
- SiT-XL/2
|
| 22 |
+
- --enc-type
|
| 23 |
+
- dinov2-vit-b
|
| 24 |
+
- --encoder-depth
|
| 25 |
+
- "8"
|
| 26 |
+
- --proj-coeff
|
| 27 |
+
- "0.5"
|
| 28 |
+
- --output-dir
|
| 29 |
+
- exps
|
| 30 |
+
- --exp-name
|
| 31 |
+
- jsflow-experiment
|
| 32 |
+
- --batch-size
|
| 33 |
+
- "256"
|
| 34 |
+
- --data-dir
|
| 35 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 36 |
+
- --semantic-features-dir
|
| 37 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 38 |
+
- --learning-rate
|
| 39 |
+
- "0.00005"
|
| 40 |
+
- --t-c
|
| 41 |
+
- "0.5"
|
| 42 |
+
- --cls
|
| 43 |
+
- "0.2"
|
| 44 |
+
- --ot-cls
|
| 45 |
+
codePath: train.py
|
| 46 |
+
codePathLocal: train.py
|
| 47 |
+
cpu_count: 96
|
| 48 |
+
cpu_count_logical: 192
|
| 49 |
+
cudaVersion: "13.0"
|
| 50 |
+
disk:
|
| 51 |
+
/:
|
| 52 |
+
total: "3838880616448"
|
| 53 |
+
used: "357557354496"
|
| 54 |
+
email: 2365972933@qq.com
|
| 55 |
+
executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
|
| 56 |
+
git:
|
| 57 |
+
commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
|
| 58 |
+
remote: https://github.com/Martinser/REG.git
|
| 59 |
+
gpu: NVIDIA H100 80GB HBM3
|
| 60 |
+
gpu_count: 4
|
| 61 |
+
gpu_nvidia:
|
| 62 |
+
- architecture: Hopper
|
| 63 |
+
cudaCores: 16896
|
| 64 |
+
memoryTotal: "85520809984"
|
| 65 |
+
name: NVIDIA H100 80GB HBM3
|
| 66 |
+
uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
|
| 67 |
+
- architecture: Hopper
|
| 68 |
+
cudaCores: 16896
|
| 69 |
+
memoryTotal: "85520809984"
|
| 70 |
+
name: NVIDIA H100 80GB HBM3
|
| 71 |
+
uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
|
| 72 |
+
- architecture: Hopper
|
| 73 |
+
cudaCores: 16896
|
| 74 |
+
memoryTotal: "85520809984"
|
| 75 |
+
name: NVIDIA H100 80GB HBM3
|
| 76 |
+
uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
|
| 77 |
+
- architecture: Hopper
|
| 78 |
+
cudaCores: 16896
|
| 79 |
+
memoryTotal: "85520809984"
|
| 80 |
+
name: NVIDIA H100 80GB HBM3
|
| 81 |
+
uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
|
| 82 |
+
host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
|
| 83 |
+
memory:
|
| 84 |
+
total: "2164115296256"
|
| 85 |
+
os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
|
| 86 |
+
program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
|
| 87 |
+
python: CPython 3.12.9
|
| 88 |
+
root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
|
| 89 |
+
startedAt: "2026-03-22T07:00:22.092510Z"
|
| 90 |
+
writerId: ucanic8s891x6sl28vnbha78lzoecw66
|
| 91 |
+
m: []
|
| 92 |
+
python_version: 3.12.9
|
| 93 |
+
t:
|
| 94 |
+
"1":
|
| 95 |
+
- 1
|
| 96 |
+
- 5
|
| 97 |
+
- 11
|
| 98 |
+
- 41
|
| 99 |
+
- 49
|
| 100 |
+
- 53
|
| 101 |
+
- 63
|
| 102 |
+
- 71
|
| 103 |
+
- 83
|
| 104 |
+
- 98
|
| 105 |
+
"2":
|
| 106 |
+
- 1
|
| 107 |
+
- 5
|
| 108 |
+
- 11
|
| 109 |
+
- 41
|
| 110 |
+
- 49
|
| 111 |
+
- 53
|
| 112 |
+
- 63
|
| 113 |
+
- 71
|
| 114 |
+
- 83
|
| 115 |
+
- 98
|
| 116 |
+
"3":
|
| 117 |
+
- 13
|
| 118 |
+
"4": 3.12.9
|
| 119 |
+
"5": 0.25.0
|
| 120 |
+
"6": 4.53.2
|
| 121 |
+
"12": 0.25.0
|
| 122 |
+
"13": linux-x86_64
|
| 123 |
+
adam_beta1:
|
| 124 |
+
value: 0.9
|
| 125 |
+
adam_beta2:
|
| 126 |
+
value: 0.999
|
| 127 |
+
adam_epsilon:
|
| 128 |
+
value: 1e-08
|
| 129 |
+
adam_weight_decay:
|
| 130 |
+
value: 0
|
| 131 |
+
allow_tf32:
|
| 132 |
+
value: true
|
| 133 |
+
batch_size:
|
| 134 |
+
value: 256
|
| 135 |
+
cfg_prob:
|
| 136 |
+
value: 0.1
|
| 137 |
+
checkpointing_steps:
|
| 138 |
+
value: 10000
|
| 139 |
+
cls:
|
| 140 |
+
value: 0.2
|
| 141 |
+
data_dir:
|
| 142 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 143 |
+
enc_type:
|
| 144 |
+
value: dinov2-vit-b
|
| 145 |
+
encoder_depth:
|
| 146 |
+
value: 8
|
| 147 |
+
epochs:
|
| 148 |
+
value: 1400
|
| 149 |
+
exp_name:
|
| 150 |
+
value: jsflow-experiment
|
| 151 |
+
fused_attn:
|
| 152 |
+
value: true
|
| 153 |
+
gradient_accumulation_steps:
|
| 154 |
+
value: 1
|
| 155 |
+
learning_rate:
|
| 156 |
+
value: 5e-05
|
| 157 |
+
legacy:
|
| 158 |
+
value: false
|
| 159 |
+
logging_dir:
|
| 160 |
+
value: logs
|
| 161 |
+
max_grad_norm:
|
| 162 |
+
value: 1
|
| 163 |
+
max_train_steps:
|
| 164 |
+
value: 1000000
|
| 165 |
+
mixed_precision:
|
| 166 |
+
value: bf16
|
| 167 |
+
model:
|
| 168 |
+
value: SiT-XL/2
|
| 169 |
+
num_classes:
|
| 170 |
+
value: 1000
|
| 171 |
+
num_workers:
|
| 172 |
+
value: 4
|
| 173 |
+
ops_head:
|
| 174 |
+
value: 16
|
| 175 |
+
ot_cls:
|
| 176 |
+
value: true
|
| 177 |
+
output_dir:
|
| 178 |
+
value: exps
|
| 179 |
+
path_type:
|
| 180 |
+
value: linear
|
| 181 |
+
prediction:
|
| 182 |
+
value: v
|
| 183 |
+
proj_coeff:
|
| 184 |
+
value: 0.5
|
| 185 |
+
qk_norm:
|
| 186 |
+
value: false
|
| 187 |
+
report_to:
|
| 188 |
+
value: wandb
|
| 189 |
+
resolution:
|
| 190 |
+
value: 256
|
| 191 |
+
resume_step:
|
| 192 |
+
value: 0
|
| 193 |
+
sampling_steps:
|
| 194 |
+
value: 2000
|
| 195 |
+
seed:
|
| 196 |
+
value: 0
|
| 197 |
+
semantic_features_dir:
|
| 198 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 199 |
+
t_c:
|
| 200 |
+
value: 0.5
|
| 201 |
+
weighting:
|
| 202 |
+
value: uniform
|
back/wandb/run-20260322_150022-yhxc5cgu/files/output.log
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Steps: 0%| | 1/1000000 [00:02<652:30:07, 2.35s/it][[34m2026-03-22 15:00:28[0m] Generating EMA samples for evaluation (t=1→0 and t=0.5)...
|
| 2 |
+
Traceback (most recent call last):
|
| 3 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
|
| 4 |
+
main(args)
|
| 5 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
|
| 6 |
+
cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
|
| 7 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 8 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
|
| 9 |
+
raise AttributeError(
|
| 10 |
+
AttributeError: 'SiT' object has no attribute 'semantic_channels'
|
| 11 |
+
[rank0]: Traceback (most recent call last):
|
| 12 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
|
| 13 |
+
[rank0]: main(args)
|
| 14 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
|
| 15 |
+
[rank0]: cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
|
| 16 |
+
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 17 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
|
| 18 |
+
[rank0]: raise AttributeError(
|
| 19 |
+
[rank0]: AttributeError: 'SiT' object has no attribute 'semantic_channels'
|
back/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|