Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- GVP/Baseline/DeltaFM/assets/deltafm.png +3 -0
- GVP/Baseline/DeltaFM/interference_vectors/imnet256_interference_vector.pt +3 -0
- GVP/Baseline/W_No.log +0 -0
- GVP/Baseline/classify_image_graph_def.pb +3 -0
- GVP/Baseline/compare_samples.sh +21 -0
- GVP/Baseline/compare_sampling.log +0 -0
- GVP/Baseline/download.py +41 -0
- GVP/Baseline/environment.yml +16 -0
- GVP/Baseline/evaluate_samples.sh +65 -0
- GVP/Baseline/evaluator.py +690 -0
- GVP/Baseline/gvp_sampling.log +51 -0
- GVP/Baseline/models.py +647 -0
- GVP/Baseline/nohup.out +180 -0
- GVP/Baseline/pic_npz.py +168 -0
- GVP/Baseline/run.sh +15 -0
- GVP/Baseline/sample_compare_ddp_rectified.py +274 -0
- GVP/Baseline/sample_ddp.py +233 -0
- GVP/Baseline/sample_rectified_noise.py +380 -0
- GVP/Baseline/samples.sh +16 -0
- GVP/Baseline/samples_ddp.sh +14 -0
- GVP/Baseline/transport/__pycache__/ot_plan.cpython-311.pyc +0 -0
- GVP/Baseline/transport/__pycache__/path.cpython-310.pyc +0 -0
- GVP/Baseline/transport/__pycache__/path.cpython-311.pyc +0 -0
- GVP/Baseline/transport/__pycache__/path.cpython-312.pyc +0 -0
- GVP/Baseline/transport/__pycache__/path.cpython-38.pyc +0 -0
- GVP/Baseline/transport/__pycache__/transport.cpython-310.pyc +0 -0
- GVP/Baseline/transport/__pycache__/transport.cpython-311.pyc +0 -0
- GVP/Baseline/transport/__pycache__/transport.cpython-312.pyc +0 -0
- GVP/Baseline/transport/__pycache__/transport.cpython-38.pyc +0 -0
- GVP/Baseline/transport/__pycache__/utils.cpython-310.pyc +0 -0
- GVP/Baseline/transport/__pycache__/utils.cpython-311.pyc +0 -0
- GVP/Baseline/transport/__pycache__/utils.cpython-312.pyc +0 -0
- GVP/Baseline/transport/__pycache__/utils.cpython-38.pyc +0 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0020000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0040000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0060000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0080000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0100000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0120000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0140000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0160000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0180000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0200000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0220000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0240000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0260000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0280000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt +3 -0
- VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0320000.pt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
GVP/Baseline/DeltaFM/assets/deltafm.png filter=lfs diff=lfs merge=lfs -text
|
GVP/Baseline/DeltaFM/assets/deltafm.png
ADDED
|
Git LFS Details
|
GVP/Baseline/DeltaFM/interference_vectors/imnet256_interference_vector.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2f34c2a05a7abe7c22e32d6cc06e65f8435e9271f7b4aa0cc7044dce1b7727b
|
| 3 |
+
size 17629
|
GVP/Baseline/W_No.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GVP/Baseline/classify_image_graph_def.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:009d6814d1bc560d4e7b236e170e9b2d5ca6f4b57bd8037f6db05776204415c6
|
| 3 |
+
size 95673916
|
GVP/Baseline/compare_samples.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
|
| 4 |
+
--nnodes=1 \
|
| 5 |
+
--nproc_per_node=4 \
|
| 6 |
+
--rdzv_endpoint=localhost:29166 \
|
| 7 |
+
sample_compare_ddp_rectified.py SDE \
|
| 8 |
+
--model SiT-XL/2 \
|
| 9 |
+
--sample-dir compare_samples \
|
| 10 |
+
--num-fid-samples 50000 \
|
| 11 |
+
--num-classes 1000 \
|
| 12 |
+
--global-seed 1 \
|
| 13 |
+
--cfg-scale 1.0 \
|
| 14 |
+
--num-sampling-steps 250 \
|
| 15 |
+
--depth 6 \
|
| 16 |
+
--use-sitf2 True \
|
| 17 |
+
--sitf2-threshold 0.5 \
|
| 18 |
+
--ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
|
| 19 |
+
--sitf2-ckpt /gemini/space/gzy_new/models/Baseline/results_256_gvp_disp/depth-mu-6-007-SiT-XL-2-GVP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt \
|
| 20 |
+
> compare_sampling.log 2>&1 &
|
| 21 |
+
|
GVP/Baseline/compare_sampling.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GVP/Baseline/download.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Functions for downloading pre-trained SiT models
|
| 6 |
+
"""
|
| 7 |
+
from torchvision.datasets.utils import download_url
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
pretrained_models = {'SiT-XL-2-256x256.pt'}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def find_model(model_name):
|
| 16 |
+
"""
|
| 17 |
+
Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
|
| 18 |
+
"""
|
| 19 |
+
if model_name in pretrained_models:
|
| 20 |
+
return download_model(model_name)
|
| 21 |
+
else:
|
| 22 |
+
assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
|
| 23 |
+
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage, weights_only=False)
|
| 24 |
+
if "ema" in checkpoint: # supports checkpoints from train.py
|
| 25 |
+
checkpoint = checkpoint["ema"]
|
| 26 |
+
return checkpoint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def download_model(model_name):
|
| 30 |
+
"""
|
| 31 |
+
Downloads a pre-trained SiT model from the web.
|
| 32 |
+
"""
|
| 33 |
+
assert model_name in pretrained_models
|
| 34 |
+
local_path = f'pretrained_models/{model_name}'
|
| 35 |
+
if not os.path.isfile(local_path):
|
| 36 |
+
os.makedirs('pretrained_models', exist_ok=True)
|
| 37 |
+
web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
|
| 38 |
+
download_url(web_path, 'pretrained_models', filename=model_name)
|
| 39 |
+
model = torch.load(local_path, map_location=lambda storage, loc: storage, weights_only=False)
|
| 40 |
+
return model
|
| 41 |
+
|
GVP/Baseline/environment.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RN
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
dependencies:
|
| 6 |
+
- python >= 3.8
|
| 7 |
+
- pytorch >= 1.13
|
| 8 |
+
- torchvision
|
| 9 |
+
- pytorch-cuda >=11.7
|
| 10 |
+
- pip
|
| 11 |
+
- pip:
|
| 12 |
+
- timm
|
| 13 |
+
- diffusers
|
| 14 |
+
- accelerate
|
| 15 |
+
- torchdiffeq
|
| 16 |
+
- wandb
|
GVP/Baseline/evaluate_samples.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Execute all evaluation tasks in parallel
|
| 4 |
+
# Each command runs in the background using &
|
| 5 |
+
|
| 6 |
+
echo "Starting all evaluation tasks in parallel..."
|
| 7 |
+
|
| 8 |
+
# Reference batch path
|
| 9 |
+
REF_BATCH="/gemini/space/zhaozy/zhy/dataset/VIRTUAL_imagenet256_labeled.npz"
|
| 10 |
+
|
| 11 |
+
# Base directory for sample files
|
| 12 |
+
SAMPLE_DIR="/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/last_samples_depth_2_gvp_0.5"
|
| 13 |
+
|
| 14 |
+
# Change to the project root directory
|
| 15 |
+
cd /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching
|
| 16 |
+
|
| 17 |
+
# Evaluate threshold 0.0 on GPU 0
|
| 18 |
+
CUDA_VISIBLE_DEVICES=0 nohup python evaluator.py \
|
| 19 |
+
--ref_batch ${REF_BATCH} \
|
| 20 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.0-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 21 |
+
> eval_threshold_0.0.log 2>&1 &
|
| 22 |
+
|
| 23 |
+
# Evaluate threshold 0.15 on GPU 1
|
| 24 |
+
CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \
|
| 25 |
+
--ref_batch ${REF_BATCH} \
|
| 26 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.15-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 27 |
+
> eval_threshold_0.15.log 2>&1 &
|
| 28 |
+
|
| 29 |
+
# Evaluate threshold 0.25 on GPU 2
|
| 30 |
+
CUDA_VISIBLE_DEVICES=2 nohup python evaluator.py \
|
| 31 |
+
--ref_batch ${REF_BATCH} \
|
| 32 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.25-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 33 |
+
> eval_threshold_0.25.log 2>&1 &
|
| 34 |
+
|
| 35 |
+
# Evaluate threshold 0.5 on GPU 3
|
| 36 |
+
CUDA_VISIBLE_DEVICES=3 nohup python evaluator.py \
|
| 37 |
+
--ref_batch ${REF_BATCH} \
|
| 38 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.5-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 39 |
+
> eval_threshold_0.5.log 2>&1 &
|
| 40 |
+
|
| 41 |
+
# Evaluate threshold 0.75 on GPU 4
|
| 42 |
+
CUDA_VISIBLE_DEVICES=0 nohup python evaluator.py \
|
| 43 |
+
--ref_batch ${REF_BATCH} \
|
| 44 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-0.75-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 45 |
+
> eval_threshold_0.75.log 2>&1 &
|
| 46 |
+
|
| 47 |
+
# Evaluate threshold 1.0 on GPU 5
|
| 48 |
+
CUDA_VISIBLE_DEVICES=1 nohup python evaluator.py \
|
| 49 |
+
--ref_batch ${REF_BATCH} \
|
| 50 |
+
--sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-1.0-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
|
| 51 |
+
> eval_threshold_1.0.log 2>&1 &
|
| 52 |
+
|
| 53 |
+
# Wait for all background jobs to complete
|
| 54 |
+
echo "All evaluation tasks started. Waiting for completion..."
|
| 55 |
+
wait
|
| 56 |
+
|
| 57 |
+
echo "All evaluation tasks completed!"
|
| 58 |
+
echo ""
|
| 59 |
+
echo "Results saved in:"
|
| 60 |
+
echo " - eval_threshold_0.0.log"
|
| 61 |
+
echo " - eval_threshold_0.15.log"
|
| 62 |
+
echo " - eval_threshold_0.25.log"
|
| 63 |
+
echo " - eval_threshold_0.5.log"
|
| 64 |
+
echo " - eval_threshold_0.75.log"
|
| 65 |
+
echo " - eval_threshold_1.0.log"
|
GVP/Baseline/evaluator.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 5 |
+
import random
|
| 6 |
+
import warnings
|
| 7 |
+
import zipfile
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
from multiprocessing import cpu_count
|
| 12 |
+
from multiprocessing.pool import ThreadPool
|
| 13 |
+
from typing import Iterable, Optional, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import requests
|
| 17 |
+
import tensorflow.compat.v1 as tf
|
| 18 |
+
from scipy import linalg
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
from datetime import timedelta
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 26 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 27 |
+
|
| 28 |
+
FID_POOL_NAME = "pool_3:0"
|
| 29 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def main():
|
| 33 |
+
parser = argparse.ArgumentParser()
|
| 34 |
+
parser.add_argument("--ref_batch", default='/gemini/space/gzy_new/models/reference/VIRTUAL_imagenet256_labeled.npz',help="path to reference batch npz file")
|
| 35 |
+
parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Baseline/GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04.npz', help="path to sample batch npz file")
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
config = tf.ConfigProto(
|
| 39 |
+
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
|
| 40 |
+
)
|
| 41 |
+
config.gpu_options.allow_growth = True
|
| 42 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 43 |
+
|
| 44 |
+
print("warming up TensorFlow...")
|
| 45 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 46 |
+
# than after the next print(), to help prevent confusion.
|
| 47 |
+
evaluator.warmup()
|
| 48 |
+
|
| 49 |
+
print("computing reference batch activations...")
|
| 50 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 51 |
+
print("computing/reading reference batch statistics...")
|
| 52 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 53 |
+
|
| 54 |
+
print("computing sample batch activations...")
|
| 55 |
+
sample_acts = evaluator.read_activations(args.sample_batch)
|
| 56 |
+
print("computing/reading sample batch statistics...")
|
| 57 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
|
| 58 |
+
|
| 59 |
+
print("Computing evaluations...")
|
| 60 |
+
print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
|
| 61 |
+
print("FID:", sample_stats.frechet_distance(ref_stats))
|
| 62 |
+
print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
|
| 63 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 64 |
+
print("Precision:", prec)
|
| 65 |
+
print("Recall:", recall)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class InvalidFIDException(Exception):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class FIDStatistics:
|
| 73 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 74 |
+
self.mu = mu
|
| 75 |
+
self.sigma = sigma
|
| 76 |
+
|
| 77 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 78 |
+
"""
|
| 79 |
+
Compute the Frechet distance between two sets of statistics.
|
| 80 |
+
"""
|
| 81 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 82 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 83 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 84 |
+
|
| 85 |
+
mu1 = np.atleast_1d(mu1)
|
| 86 |
+
mu2 = np.atleast_1d(mu2)
|
| 87 |
+
|
| 88 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 89 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 90 |
+
|
| 91 |
+
assert (
|
| 92 |
+
mu1.shape == mu2.shape
|
| 93 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 94 |
+
assert (
|
| 95 |
+
sigma1.shape == sigma2.shape
|
| 96 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 97 |
+
|
| 98 |
+
diff = mu1 - mu2
|
| 99 |
+
|
| 100 |
+
# product might be almost singular
|
| 101 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 102 |
+
if not np.isfinite(covmean).all():
|
| 103 |
+
msg = (
|
| 104 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 105 |
+
% eps
|
| 106 |
+
)
|
| 107 |
+
warnings.warn(msg)
|
| 108 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 109 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 110 |
+
|
| 111 |
+
# numerical error might give slight imaginary component
|
| 112 |
+
#虚部报错部分
|
| 113 |
+
if np.iscomplexobj(covmean):
|
| 114 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1):
|
| 115 |
+
m = np.max(np.abs(covmean.imag))
|
| 116 |
+
print(f"Real component: {covmean.real}")
|
| 117 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 118 |
+
covmean = covmean.real
|
| 119 |
+
|
| 120 |
+
tr_covmean = np.trace(covmean)
|
| 121 |
+
|
| 122 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Evaluator:
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
session,
|
| 129 |
+
batch_size=64,
|
| 130 |
+
softmax_batch_size=512,
|
| 131 |
+
):
|
| 132 |
+
self.sess = session
|
| 133 |
+
self.batch_size = batch_size
|
| 134 |
+
self.softmax_batch_size = softmax_batch_size
|
| 135 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 136 |
+
with self.sess.graph.as_default():
|
| 137 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 138 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 139 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 140 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 141 |
+
|
| 142 |
+
def warmup(self):
|
| 143 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 144 |
+
|
| 145 |
+
def read_activations(self, npz_path: Union[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 146 |
+
if isinstance(npz_path, str):
|
| 147 |
+
# If npz_path is a string, treat it as a file path and read the .npz file
|
| 148 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 149 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 150 |
+
elif isinstance(npz_path, np.ndarray):
|
| 151 |
+
# If npz_path is a numpy array, split it into batches manually
|
| 152 |
+
print("--------line 140-----------")
|
| 153 |
+
batches = np.array_split(npz_path, range(self.batch_size, npz_path.shape[0], self.batch_size))
|
| 154 |
+
print("--------line 143-----------")
|
| 155 |
+
return self.compute_activations(batches)
|
| 156 |
+
else:
|
| 157 |
+
raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 161 |
+
"""
|
| 162 |
+
Compute image features for downstream evals.
|
| 163 |
+
|
| 164 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 165 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 166 |
+
dimension. The tuple is (pool_3, spatial).
|
| 167 |
+
"""
|
| 168 |
+
preds = []
|
| 169 |
+
spatial_preds = []
|
| 170 |
+
for batch in tqdm(batches):
|
| 171 |
+
# print("--------line 164-----------")
|
| 172 |
+
|
| 173 |
+
# # 识别当前进程信息
|
| 174 |
+
# if 'RANK' in os.environ:
|
| 175 |
+
# rank = int(os.environ['RANK'])
|
| 176 |
+
# local_rank = int(os.environ.get('LOCAL_RANK', rank % torch.cuda.device_count()))
|
| 177 |
+
# print(f"Distributed training - Global Rank: {rank}, Local Rank: {local_rank}")
|
| 178 |
+
# print(f"Current GPU device: {torch.cuda.current_device()}" if torch.cuda.is_available() else "No CUDA")
|
| 179 |
+
# else:
|
| 180 |
+
# print("Single process mode")
|
| 181 |
+
|
| 182 |
+
# print(f"Process PID: {os.getpid()}")
|
| 183 |
+
|
| 184 |
+
batch = batch.astype(np.float32)
|
| 185 |
+
pred, spatial_pred = self.sess.run(
|
| 186 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 187 |
+
)
|
| 188 |
+
# print("--------line 169-----------")
|
| 189 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 190 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 191 |
+
return (
|
| 192 |
+
np.concatenate(preds, axis=0),
|
| 193 |
+
np.concatenate(spatial_preds, axis=0),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def read_statistics(
|
| 197 |
+
self, npz_path: Union[str, np.ndarray], activations: Tuple[np.ndarray, np.ndarray]
|
| 198 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 199 |
+
if isinstance(npz_path, str):
|
| 200 |
+
obj = np.load(npz_path)
|
| 201 |
+
if "mu" in list(obj.keys()):
|
| 202 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 203 |
+
obj["mu_s"], obj["sigma_s"]
|
| 204 |
+
)
|
| 205 |
+
elif isinstance(npz_path, np.ndarray):
|
| 206 |
+
obj = npz_path
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)")
|
| 209 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 210 |
+
|
| 211 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 212 |
+
mu = np.mean(activations, axis=0)
|
| 213 |
+
sigma = np.cov(activations, rowvar=False)
|
| 214 |
+
return FIDStatistics(mu, sigma)
|
| 215 |
+
|
| 216 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 217 |
+
softmax_out = []
|
| 218 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 219 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 220 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 221 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 222 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 223 |
+
scores = []
|
| 224 |
+
for i in range(0, len(preds), split_size):
|
| 225 |
+
part = preds[i : i + split_size]
|
| 226 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 227 |
+
kl = np.mean(np.sum(kl, 1))
|
| 228 |
+
scores.append(np.exp(kl))
|
| 229 |
+
return float(np.mean(scores))
|
| 230 |
+
|
| 231 |
+
def compute_prec_recall(
|
| 232 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 233 |
+
) -> Tuple[float, float]:
|
| 234 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 235 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 236 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 237 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 238 |
+
)
|
| 239 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class ManifoldEstimator:
|
| 243 |
+
"""
|
| 244 |
+
A helper for comparing manifolds of feature vectors.
|
| 245 |
+
|
| 246 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
def __init__(
|
| 250 |
+
self,
|
| 251 |
+
session,
|
| 252 |
+
row_batch_size=10000,
|
| 253 |
+
col_batch_size=10000,
|
| 254 |
+
nhood_sizes=(3,),
|
| 255 |
+
clamp_to_percentile=None,
|
| 256 |
+
eps=1e-5,
|
| 257 |
+
):
|
| 258 |
+
"""
|
| 259 |
+
Estimate the manifold of given feature vectors.
|
| 260 |
+
|
| 261 |
+
:param session: the TensorFlow session.
|
| 262 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 263 |
+
(parameter to trade-off between memory usage and performance).
|
| 264 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 265 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 266 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 267 |
+
the given percentile.
|
| 268 |
+
:param eps: small number for numerical stability.
|
| 269 |
+
"""
|
| 270 |
+
self.distance_block = DistanceBlock(session)
|
| 271 |
+
self.row_batch_size = row_batch_size
|
| 272 |
+
self.col_batch_size = col_batch_size
|
| 273 |
+
self.nhood_sizes = nhood_sizes
|
| 274 |
+
self.num_nhoods = len(nhood_sizes)
|
| 275 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 276 |
+
self.eps = eps
|
| 277 |
+
|
| 278 |
+
def warmup(self):
|
| 279 |
+
feats, radii = (
|
| 280 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 281 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 282 |
+
)
|
| 283 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 284 |
+
|
| 285 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 286 |
+
num_images = len(features)
|
| 287 |
+
|
| 288 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 289 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 290 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 291 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 292 |
+
|
| 293 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 294 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 295 |
+
row_batch = features[begin1:end1]
|
| 296 |
+
|
| 297 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 298 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 299 |
+
col_batch = features[begin2:end2]
|
| 300 |
+
|
| 301 |
+
# Compute distances between batches.
|
| 302 |
+
distance_batch[
|
| 303 |
+
0 : end1 - begin1, begin2:end2
|
| 304 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 305 |
+
|
| 306 |
+
# Find the k-nearest neighbor from the current batch.
|
| 307 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 308 |
+
[
|
| 309 |
+
x[:, self.nhood_sizes]
|
| 310 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 311 |
+
],
|
| 312 |
+
axis=0,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if self.clamp_to_percentile is not None:
|
| 316 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 317 |
+
radii[radii > max_distances] = 0
|
| 318 |
+
return radii
|
| 319 |
+
|
| 320 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 321 |
+
"""
|
| 322 |
+
Evaluate if new feature vectors are at the manifold.
|
| 323 |
+
"""
|
| 324 |
+
num_eval_images = eval_features.shape[0]
|
| 325 |
+
num_ref_images = radii.shape[0]
|
| 326 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 327 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 328 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 329 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 330 |
+
|
| 331 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 332 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 333 |
+
feature_batch = eval_features[begin1:end1]
|
| 334 |
+
|
| 335 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 336 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 337 |
+
ref_batch = features[begin2:end2]
|
| 338 |
+
|
| 339 |
+
distance_batch[
|
| 340 |
+
0 : end1 - begin1, begin2:end2
|
| 341 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 342 |
+
|
| 343 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 344 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 345 |
+
# the new sample lies at the estimated manifold.
|
| 346 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 347 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 348 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 349 |
+
|
| 350 |
+
max_realism_score[begin1:end1] = np.max(
|
| 351 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 352 |
+
)
|
| 353 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 354 |
+
|
| 355 |
+
return {
|
| 356 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 357 |
+
"batch_predictions": batch_predictions,
|
| 358 |
+
"max_realisim_score": max_realism_score,
|
| 359 |
+
"nearest_indices": nearest_indices,
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
def evaluate_pr(
    self,
    features_1: np.ndarray,
    radii_1: np.ndarray,
    features_2: np.ndarray,
    radii_2: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate precision and recall efficiently.

    :param features_1: [N1 x D] feature vectors for reference batch.
    :param radii_1: [N1 x K1] radii for reference vectors.
    :param features_2: [N2 x D] feature vectors for the other batch.
    :param radii_2: [N2 x K2] radii for other vectors.
    :return: a tuple of arrays for (precision, recall):
             - precision: an np.ndarray of length K1
             - recall: an np.ndarray of length K2
    """
    # Fix: np.bool was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin `bool` is the documented replacement and what np.bool aliased.
    features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
    features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
    # Tile the pairwise work so only row_batch_size x col_batch_size
    # distance blocks are ever materialized at once.
    for begin_1 in range(0, len(features_1), self.row_batch_size):
        end_1 = begin_1 + self.row_batch_size
        batch_1 = features_1[begin_1:end_1]
        for begin_2 in range(0, len(features_2), self.col_batch_size):
            end_2 = begin_2 + self.col_batch_size
            batch_2 = features_2[begin_2:end_2]
            # less_thans returns, per radius column, which rows of each
            # batch fall inside some hypersphere of the other batch.
            batch_1_in, batch_2_in = self.distance_block.less_thans(
                batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
            )
            # OR-accumulate: a row is "in" if any tile said so.
            features_1_status[begin_1:end_1] |= batch_1_in
            features_2_status[begin_2:end_2] |= batch_2_in
    return (
        np.mean(features_2_status.astype(np.float64), axis=0),
        np.mean(features_1_status.astype(np.float64), axis=0),
    )
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        # TF1-style session whose graph hosts the distance computation.
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Try the computation in float16 first (faster, less memory);
            # if any resulting entry is non-finite, recompute in float32.
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            # radii placeholders: [N1 x K1] and [N2 x K2] per-row thresholds.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            # [N1 x K2]: batch-1 row falls within some batch-2 radius column.
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            # [N2 x K1]: batch-2 row falls within some batch-1 radius column.
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        # Run both membership tests in a single session call; returns the
        # pair of boolean arrays built in __init__ (_batch_1_in, _batch_2_in).
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.

    Returns D with D[i, j] = ||U[i] - V[j]||^2 — note these are *squared*
    Euclidean distances, clamped at zero to absorb negative rounding error.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances via ||u||^2 - 2 u.v + ||v||^2.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

        return D
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class NpzArrayReader(ABC):
    """Abstract reader that serves an .npy array in row batches."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        pass

    @abstractmethod
    def remaining(self) -> int:
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Return a length-aware iterator over every remaining batch."""

        def _produce():
            while (chunk := self.read_batch(batch_size)) is not None:
                yield chunk

        left = self.remaining()
        # Ceiling division: a partial remainder still counts as one batch.
        total = -(-left // batch_size)
        return BatchIterator(_produce, total)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class BatchIterator:
    """Iterable with a known length, backed by a generator factory."""

    def __init__(self, gen_fn, length):
        # gen_fn: zero-argument callable returning a fresh iterator per call.
        self.gen_fn = gen_fn
        # length: number of items gen_fn's iterator is expected to yield.
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class StreamingNpzArrayReader(NpzArrayReader):
    """Reads rows lazily from an open .npy file object, batch by batch."""

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f  # file-like, positioned at the raw array data
        self.shape = shape  # full array shape, rows first
        self.dtype = dtype  # numpy dtype of the stored values
        self.idx = 0        # rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to batch_size rows, or None once exhausted."""
        total_rows = self.shape[0]
        if self.idx >= total_rows:
            return None

        take = min(batch_size, total_rows - self.idx)
        self.idx += take

        # Zero-sized dtypes carry no bytes; hand back a container directly.
        if self.dtype.itemsize == 0:
            return np.ndarray([take, *self.shape[1:]], dtype=self.dtype)

        n_items = take * np.prod(self.shape[1:])
        n_bytes = int(n_items * self.dtype.itemsize)
        raw = _read_bytes(self.arr_f, n_bytes, "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape([take, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class MemoryNpzArrayReader(NpzArrayReader):
    """Serves batches from an array loaded fully into memory."""

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0  # rows handed out so far

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Read array arr_name out of the npz file at path."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        start = self.idx
        self.idx = start + batch_size
        return self.arr[start : start + batch_size]

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """
    Context manager yielding an NpzArrayReader for one array of an npz file.

    Streams the raw data when the header allows it; falls back to loading
    the whole array in memory for unrecognized format versions, Fortran
    order, or object dtypes (which cannot be read as a flat byte stream).
    """
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header_readers = {
            (1, 0): np.lib.format.read_array_header_1_0,
            (2, 0): np.lib.format.read_array_header_2_0,
        }
        read_header = header_readers.get(version)
        if read_header is None:
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = read_header(arr_f)
        if fortran or dtype.hasobject:
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 572 |
+
"""
|
| 573 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 574 |
+
|
| 575 |
+
Read from file-like object until size bytes are read.
|
| 576 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 577 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 578 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 579 |
+
requested.
|
| 580 |
+
"""
|
| 581 |
+
data = bytes()
|
| 582 |
+
while True:
|
| 583 |
+
# io files (default in python3) return None or raise on
|
| 584 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 585 |
+
# done about that. note that regular files can't be non-blocking
|
| 586 |
+
try:
|
| 587 |
+
r = fp.read(size - len(data))
|
| 588 |
+
data += r
|
| 589 |
+
if len(r) == 0 or len(data) == size:
|
| 590 |
+
break
|
| 591 |
+
except io.BlockingIOError:
|
| 592 |
+
pass
|
| 593 |
+
if len(data) != size:
|
| 594 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 595 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 596 |
+
else:
|
| 597 |
+
return data
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
@contextmanager
|
| 601 |
+
def _open_npy_file(path: str, arr_name: str):
|
| 602 |
+
with open(path, "rb") as f:
|
| 603 |
+
with zipfile.ZipFile(f, "r") as zip_f:
|
| 604 |
+
if f"{arr_name}.npy" not in zip_f.namelist():
|
| 605 |
+
raise ValueError(f"missing {arr_name} in npz file")
|
| 606 |
+
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
| 607 |
+
yield arr_f
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def _download_inception_model():
    """Download the InceptionV3 frozen graph to INCEPTION_V3_PATH if missing."""
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    with requests.get(INCEPTION_V3_URL, stream=True) as r:
        r.raise_for_status()
        # Stream into a temp file and rename at the end so an interrupted
        # download never leaves a truncated file at the final path.
        tmp_path = INCEPTION_V3_PATH + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                f.write(chunk)
        os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def _create_feature_graph(input_batch):
    """
    Import the InceptionV3 graph and return (pool3, spatial) feature tensors
    computed from input_batch.

    A random name prefix keeps repeated imports into the same graph from
    colliding; spatial features are truncated to their first 7 channels.
    """
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        # Fix: dropped the stray f-prefix on a literal with no placeholders.
        input_map={"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    # Rewrite static shapes so the batch dimension is flexible.
    _update_shapes(pool3)
    spatial = spatial[..., :7]
    return pool3, spatial
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def _create_softmax_graph(input_batch):
    """
    Import the InceptionV3 classifier weight matrix and return a softmax
    tensor over input_batch (pooled features) using those weights.
    """
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        # Fix: dropped the stray f-prefix on a literal with no placeholders.
        graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
    )
    # Reuse only the imported weight matrix; build a fresh matmul + softmax.
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def _update_shapes(pool3):
    """
    Rewrite static shapes in pool3's graph so any leading dimension pinned
    to 1 becomes None, allowing arbitrary batch sizes at run time.
    """
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None: # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape] # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    # Un-pin only the leading (batch) dimension.
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # HACK: writes the private shape cache directly; there is no
                # public API for overriding an imported tensor's static shape.
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 674 |
+
num_workers = min(cpu_count(), len(arr))
|
| 675 |
+
chunk_size = len(arr) // num_workers
|
| 676 |
+
extra = len(arr) % num_workers
|
| 677 |
+
|
| 678 |
+
start_idx = 0
|
| 679 |
+
batches = []
|
| 680 |
+
for i in range(num_workers):
|
| 681 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 682 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 683 |
+
start_idx += size
|
| 684 |
+
|
| 685 |
+
with ThreadPool(num_workers) as pool:
|
| 686 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
if __name__ == "__main__":
    # CLI entry point: run the evaluator defined above.
    main()
|
GVP/Baseline/gvp_sampling.log
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/1042 [00:00<?, ?it/s]
|
| 1 |
0%| | 1/1042 [00:59<17:12:21, 59.50s/it]
|
| 2 |
0%| | 2/1042 [01:58<17:01:08, 58.91s/it]
|
| 3 |
0%| | 3/1042 [02:57<17:02:36, 59.05s/it]
|
| 4 |
0%| | 4/1042 [03:56<17:00:31, 58.99s/it]
|
| 5 |
0%| | 5/1042 [04:52<16:42:43, 58.02s/it]
|
| 6 |
1%| | 6/1042 [05:51<16:47:28, 58.35s/it]
|
| 7 |
1%| | 7/1042 [06:50<16:53:25, 58.75s/it]
|
| 8 |
1%| | 8/1042 [07:49<16:49:35, 58.58s/it]
|
| 9 |
1%| | 9/1042 [08:48<16:51:57, 58.78s/it]
|
| 10 |
1%| | 10/1042 [09:46<16:49:39, 58.70s/it]
|
| 11 |
1%| | 11/1042 [10:45<16:48:10, 58.67s/it]
|
| 12 |
1%| | 12/1042 [11:45<16:53:10, 59.02s/it]
|
| 13 |
1%| | 13/1042 [12:44<16:52:36, 59.04s/it]
|
| 14 |
1%|▏ | 14/1042 [13:43<16:50:11, 58.96s/it]W0407 16:34:53.638000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2845 closing signal SIGTERM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 6 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 7 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 8 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 9 |
+
Starting rank=0, seed=0, world_size=4.
|
| 10 |
+
Starting rank=2, seed=2, world_size=4.
|
| 11 |
+
Starting rank=1, seed=1, world_size=4.
|
| 12 |
+
Starting rank=3, seed=3, world_size=4.
|
| 13 |
+
[rank1]:[W407 16:20:17.131912166 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 14 |
+
[rank3]:[W407 16:20:17.153628536 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 15 |
+
Saving .png samples at baseline_gvp_/SiT-XL-2-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04
|
| 16 |
+
[rank0]:[W407 16:20:17.306737681 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 17 |
+
[rank2]:[W407 16:20:18.780347929 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 18 |
+
Total number of images that will be sampled: 50016
|
| 19 |
+
|
| 20 |
0%| | 0/1042 [00:00<?, ?it/s]
|
| 21 |
0%| | 1/1042 [00:59<17:12:21, 59.50s/it]
|
| 22 |
0%| | 2/1042 [01:58<17:01:08, 58.91s/it]
|
| 23 |
0%| | 3/1042 [02:57<17:02:36, 59.05s/it]
|
| 24 |
0%| | 4/1042 [03:56<17:00:31, 58.99s/it]
|
| 25 |
0%| | 5/1042 [04:52<16:42:43, 58.02s/it]
|
| 26 |
1%| | 6/1042 [05:51<16:47:28, 58.35s/it]
|
| 27 |
1%| | 7/1042 [06:50<16:53:25, 58.75s/it]
|
| 28 |
1%| | 8/1042 [07:49<16:49:35, 58.58s/it]
|
| 29 |
1%| | 9/1042 [08:48<16:51:57, 58.78s/it]
|
| 30 |
1%| | 10/1042 [09:46<16:49:39, 58.70s/it]
|
| 31 |
1%| | 11/1042 [10:45<16:48:10, 58.67s/it]
|
| 32 |
1%| | 12/1042 [11:45<16:53:10, 59.02s/it]
|
| 33 |
1%| | 13/1042 [12:44<16:52:36, 59.04s/it]
|
| 34 |
1%|▏ | 14/1042 [13:43<16:50:11, 58.96s/it]W0407 16:34:53.638000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2845 closing signal SIGTERM
|
| 35 |
+
W0407 16:34:53.639000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2847 closing signal SIGTERM
|
| 36 |
+
W0407 16:34:53.639000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2848 closing signal SIGTERM
|
| 37 |
+
E0407 16:34:53.854000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -9) local_rank: 1 (pid: 2846) of binary: /root/miniconda3/envs/SiT/bin/python3.10
|
| 38 |
+
Traceback (most recent call last):
|
| 39 |
+
File "/root/miniconda3/envs/SiT/bin/torchrun", line 6, in <module>
|
| 40 |
+
sys.exit(main())
|
| 41 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 42 |
+
return f(*args, **kwargs)
|
| 43 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
|
| 44 |
+
run(args)
|
| 45 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
|
| 46 |
+
elastic_launch(
|
| 47 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 48 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 49 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
|
| 50 |
+
raise ChildFailedError(
|
| 51 |
+
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
| 52 |
+
==========================================================
|
| 53 |
+
sample_ddp.py FAILED
|
| 54 |
+
----------------------------------------------------------
|
| 55 |
+
Failures:
|
| 56 |
+
<NO_OTHER_FAILURES>
|
| 57 |
+
----------------------------------------------------------
|
| 58 |
+
Root Cause (first observed failure):
|
| 59 |
+
[0]:
|
| 60 |
+
time : 2026-04-07_16:34:53
|
| 61 |
+
host : 280c8972fe62c4ab251b3c74bd05a546-taskrole1-0
|
| 62 |
+
rank : 1 (local_rank: 1)
|
| 63 |
+
exitcode : -9 (pid: 2846)
|
| 64 |
+
error_file: <N/A>
|
| 65 |
+
traceback : Signal 9 (SIGKILL) received by PID 2846
|
| 66 |
+
==========================================================
|
GVP/Baseline/models.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
# References:
|
| 5 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 6 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def modulate(x, shift, scale):
    """Apply adaLN modulation: scale then shift x.

    shift and scale are per-sample (N, D); unsqueezing inserts the token
    axis so they broadcast over x's middle dimension.
    """
    scaled = x * (1 + scale.unsqueeze(1))
    return scaled + shift.unsqueeze(1)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#################################################################################
|
| 21 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 22 |
+
#################################################################################
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            # Odd target dims get one zero column of padding.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra embedding row as the "null" label whenever
        # dropout is enabled (classifier-free guidance).
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        # Dropped entries are remapped to the dedicated null-label index.
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        dropout_active = self.dropout_prob > 0
        if (train and dropout_active) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
#################################################################################
|
| 95 |
+
# Core SiT Model #
|
| 96 |
+
#################################################################################
|
| 97 |
+
|
| 98 |
+
class SiTBlock(nn.Module):
    """
    A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # The conditioning vector c yields six per-sample modulation signals:
        # shift/scale/gate for the attention branch and for the MLP branch.
        mods = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class FinalLayer(nn.Module):
    """
    The final layer of SiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # Condition the final norm on c, then project tokens to patch pixels.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class SiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        """
        Args:
            input_size: spatial size of the (latent) input.
            patch_size: side length of each square patch.
            in_channels: number of input channels.
            hidden_size: transformer embedding dimension.
            depth: number of SiT blocks.
            num_heads: attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of hidden_size.
            class_dropout_prob: label-dropout probability (classifier-free guidance).
            num_classes: number of class labels.
            learn_sigma: if True, the head predicts both mean and variance,
                doubling the output channels.
        """
        super().__init__()
        # BUGFIX: `learn_sigma` was assigned from the argument and then
        # immediately overwritten with True, silently ignoring the caller's
        # choice. Honor the argument and size the output head accordingly.
        # The default (True) preserves the previous behavior and checkpoint
        # compatibility.
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed (frozen) sin-cos positional embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize weights: xavier linears, sin-cos pos-embed, zeroed adaLN/output head."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in SiT blocks so each block starts
        # as an identity mapping:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)   # note: channel-first, not (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y, return_act=False):
        """
        Forward pass of SiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        return_act: if True, also return per-block activations
        """
        act = []
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
            if return_act:
                act.append(x)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            # Drop the predicted variance channels; keep only the mean.
            x, _ = x.chunk(2, dim=1)
        if return_act:
            return x, act
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of SiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
#################################################################################
|
| 276 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 277 |
+
#################################################################################
|
| 278 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 279 |
+
|
| 280 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    coords_h = np.arange(grid_size, dtype=np.float32)
    coords_w = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the reference MAE implementation.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        # Prepend zero rows for the extra (e.g. class) tokens.
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build a 2-D sin-cos embedding: first half of the channels encodes the
    H coordinates (grid[0]), second half the W coordinates (grid[1])."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half, axis_positions)  # (H*W, D/2)
        for axis_positions in (grid[0], grid[1])
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(2i/D) for i in [0, D/2).
    freqs = 1.0 / 10000 ** (np.arange(half_dim, dtype=np.float64) / half_dim)  # (D/2,)
    angles = np.outer(np.reshape(pos, -1), freqs)  # (M, D/2) outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
#################################################################################
|
| 331 |
+
# SiT Configs #
|
| 332 |
+
#################################################################################
|
| 333 |
+
|
| 334 |
+
# Named constructors for the standard SiT sizes. Naming scheme: SiT-<size>/<patch>.
# XL: depth 28 / width 1152 / 16 heads; L: 24 / 1024 / 16;
# B: 12 / 768 / 12; S: 12 / 384 / 6.

def SiT_XL_2(**kwargs):
    """SiT-XL with patch size 2."""
    return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def SiT_XL_4(**kwargs):
    """SiT-XL with patch size 4."""
    return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def SiT_XL_8(**kwargs):
    """SiT-XL with patch size 8."""
    return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def SiT_L_2(**kwargs):
    """SiT-L with patch size 2."""
    return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def SiT_L_4(**kwargs):
    """SiT-L with patch size 4."""
    return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def SiT_L_8(**kwargs):
    """SiT-L with patch size 8."""
    return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def SiT_B_2(**kwargs):
    """SiT-B with patch size 2."""
    return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def SiT_B_4(**kwargs):
    """SiT-B with patch size 4."""
    return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def SiT_B_8(**kwargs):
    """SiT-B with patch size 8."""
    return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def SiT_S_2(**kwargs):
    """SiT-S with patch size 2."""
    return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def SiT_S_4(**kwargs):
    """SiT-S with patch size 4."""
    return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def SiT_S_8(**kwargs):
    """SiT-S with patch size 8."""
    return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


# Registry mapping model-name strings (as used on the command line) to constructors.
SiT_models = {
    'SiT-XL/2': SiT_XL_2,
    'SiT-XL/4': SiT_XL_4,
    'SiT-XL/8': SiT_XL_8,
    'SiT-L/2': SiT_L_2,
    'SiT-L/4': SiT_L_4,
    'SiT-L/8': SiT_L_8,
    'SiT-B/2': SiT_B_2,
    'SiT-B/4': SiT_B_4,
    'SiT-B/8': SiT_B_8,
    'SiT-S/2': SiT_S_2,
    'SiT-S/4': SiT_S_4,
    'SiT-S/8': SiT_S_8,
}
|
| 377 |
+
|
| 378 |
+
#################################################################################
|
| 379 |
+
# SiTF1, SiTF2, CombinedModel #
|
| 380 |
+
#################################################################################
|
| 381 |
+
|
| 382 |
+
class SiTF1(nn.Module):
    """
    SiTF1 Model: a SiT-style backbone whose forward pass returns BOTH the final
    patch-token sequence and a decoded image-space prediction (x_now).
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        final_layer=None,
    ):
        """
        Args mirror SiT. `final_layer` is accepted for interface compatibility
        but is currently ignored: a fresh FinalLayer is always constructed.
        """
        super().__init__()
        self.input_size = input_size
        # Fix: `patch_size` was previously assigned twice; keep a single assignment.
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.out_channels = in_channels * 2
        self.num_heads = num_heads
        self.learn_sigma = learn_sigma
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (frozen) sin-cos positional embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # NOTE: the `final_layer` argument is not used here (see docstring).
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)   # note: channel-first, not (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def initialize_weights(self):
        """Initialize weights like SiT, except the final layer is NOT zeroed here."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Frozen sin-cos positional embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero adaLN modulation in every block so blocks start as identity.
        # Unlike SiT.initialize_weights, the final layer is left at its default
        # initialization here.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, y):
        """
        Returns:
            x: final patch tokens, shape (N, T, D)
            x_now: decoded mean image prediction, shape (N, in_channels, H, W)
        """
        x = self.x_embedder(x) + self.pos_embed
        t = self.t_embedder(t)
        y = self.y_embedder(y, self.training)
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x_now = self.final_layer(x, c)       # (N, T, patch_size ** 2 * out_channels)
        x_now = self.unpatchify(x_now)       # (N, out_channels, H, W)
        # Keep only the mean channels; drop the variance half.
        x_now, _ = x_now.chunk(2, dim=1)
        return x, x_now                      # patch tokens (N, T, D), image (N, C, H, W)

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass with classifier-free guidance for SiTF1.
        Applies guidance consistently to both patch tokens and image output (x_now).
        """
        # Take the first half (conditional inputs) and duplicate it so that
        # it can be paired with conditional and unconditional labels in `y`.
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        patch_tokens, x_now = self.forward(combined, t, y)

        # Apply CFG on the image output channels (first 3 channels by default)
        eps, rest = x_now[:, :3, ...], x_now[:, 3:, ...]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        x_now = torch.cat([eps, rest], dim=1)

        # Apply same guidance logic to patch tokens so downstream modules see
        # a consistent guided representation.
        cond_tok, uncond_tok = torch.split(patch_tokens, len(patch_tokens) // 2, dim=0)
        half_tok = uncond_tok + cfg_scale * (cond_tok - uncond_tok)
        patch_tokens = torch.cat([half_tok, half_tok], dim=0)

        return patch_tokens, x_now
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class SiTF2(nn.Module):
    """
    SiTF2: a shallow transformer head operating on a concatenated token
    sequence of length 2*num_patches.

    When `learn_sigma` is True the final layer predicts a mean and a
    log-variance, and the output is re-sampled stochastically:
    `mean + std * noise` if `learn_mu` is True, else `std * noise` alone.
    """
    def __init__(
        self,
        input_size=32,
        hidden_size=1152,
        out_channels=8,
        patch_size=2,
        num_heads=16,
        mlp_ratio=4.0,
        depth=4,
        learn_sigma=True,
        final_layer=None,
        num_classes=1000,
        class_dropout_prob=0.1,
        learn_mu=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.learn_mu = learn_mu
        self.out_channels = out_channels
        # Input latents are hard-coded to 4 channels -- TODO confirm against callers.
        self.in_channels = 4
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.x_embedder = PatchEmbed(input_size, patch_size, self.in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        self.num_patches = num_patches  # Save original num_patches for unpatchify
        # pos_embed needs to support 2*num_patches for concatenated input
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        # Repeat the pos_embed for both halves (or could use different embeddings)
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

        if final_layer is not None:
            self.final_layer = final_layer
        else:
            self.final_layer = FinalLayer(hidden_size, patch_size, out_channels)
        # NOTE(review): when depth != 0, every parameter of the final layer is
        # zeroed -- including an externally supplied `final_layer`. Confirm this
        # is intended for provided layers as well.
        if depth !=0:
            for p in self.final_layer.parameters():
                if p is not None:
                    torch.nn.init.constant_(p, 0)

    def unpatchify(self, x, patch_size, out_channels):
        """Convert patch tokens back to an image; if the sequence holds
        2*num_patches tokens, only the first half is decoded.

        x: (N, T, patch_size**2 * out_channels) -> imgs: (N, C, H, W)
        """
        c = out_channels
        p = patch_size
        # x.shape[1] might be 2*num_patches when using concatenated input
        # Use original num_patches to calculate h and w
        h = w = int(self.num_patches ** 0.5)
        # If input has 2*num_patches, we need to handle it
        if x.shape[1] == 2 * self.num_patches:
            # Take only the first half (or average, or other strategy)
            # For now, we'll take the first half
            x = x[:, :self.num_patches, :]
        assert h * w == x.shape[1], f"Expected {h * w} patches, got {x.shape[1]}"
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, c, t, return_act=False):
        """
        x: (N, 2*T, D) pre-embedded token sequence; c: (N, D) conditioning.
        t is accepted but not used in this method.
        Returns the (possibly re-sampled) image, plus block activations if
        return_act is True.
        """
        act = []
        for block in self.blocks:
            x = block(x, c)
            if return_act:
                act.append(x)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        if self.learn_sigma:
            # Split predicted mean and log-variance, then re-sample.
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            variance_pred = torch.exp(log_var_pred)
            std_dev_pred = torch.sqrt(variance_pred)
            noise = torch.randn_like(mean_pred)
            #uniform_noise = torch.rand_like(mean_pred)
            #uniform_noise = uniform_noise.clamp(min=1e-5, max=1-1e-5)
            #gumbel_noise = -torch.log(-torch.log(uniform_noise))

            if self.learn_mu==True:
                resampled_x = mean_pred + std_dev_pred * noise
            else:
                # Mean is discarded: output is pure scaled noise.
                resampled_x = std_dev_pred * noise
            x = resampled_x
        else:
            x, _ = x.chunk(2, dim=1)
        if return_act:
            return x, act
        return x

    def forward_noise(self, x, c):
        """Same as forward() but without the timestep argument or activation
        collection; used for noise-only sampling paths."""
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        if self.learn_sigma:
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            variance_pred = torch.exp(log_var_pred)
            std_dev_pred = torch.sqrt(variance_pred)
            noise = torch.randn_like(mean_pred)
            if self.learn_mu==True:
                resampled_x = mean_pred + std_dev_pred * noise
            else:
                resampled_x = std_dev_pred * noise
            x = resampled_x
        else:
            x, _ = x.chunk(2, dim=1)
        return x
|
| 607 |
+
|
| 608 |
+
# Two design axes here: condition on the ideal prediction vs. the real input,
# and combine them by concatenation vs. by summation.
|
| 609 |
+
class CombinedModel(nn.Module):
    """
    CombinedModel: chains a SiTF1 backbone with a SiTF2 head.

    SiTF1 produces patch tokens and an image prediction x_now; an interpolation
    of x_now with the input is re-embedded and concatenated with the tokens,
    then fed to SiTF2 under SiTF1's own time/label conditioning.
    """
    def __init__(self, sitf1: SiTF1, sitf2: SiTF2):
        super().__init__()
        self.sitf1 = sitf1
        self.sitf2 = sitf2
        # Own a separate patch embedder (in addition to sitf1's) for re-embedding
        # the interpolated image; channel count is fixed at 4.
        input_size=self.sitf1.input_size
        patch_size=self.sitf1.patch_size
        hidden_size=self.sitf1.hidden_size
        self.x_embedder = PatchEmbed(input_size, patch_size, 4, hidden_size, bias=True)
        num_patches = self.x_embedder.num_patches
        # pos_embed needs to support 2*num_patches for concatenated input
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        # Repeat the pos_embed for both halves (or could use different embeddings)
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

    def forward(self, x, t, y, return_act=False):
        """
        x: (N, C, H, W) input; t: (N,) timesteps; y: (N,) class labels.
        Returns SiTF2's output (plus activations if return_act is True).
        """
        patch_tokens,x_now = self.sitf1(x, t, y)
        # t shape is (N,), need to broadcast to (N, 1, 1, 1) for broadcasting with image (N, C, H, W)
        t_broadcast = t.view(-1, 1, 1, 1)  # (N, 1, 1, 1)
        # NOTE(review): this computes (1 - t)*x_now + x, but the surrounding
        # design suggests the interpolation (1 - t)*x_now + t*x may have been
        # intended -- confirm which is correct before changing, since trained
        # checkpoints depend on the current behavior.
        x_interpolated = (1 - t_broadcast) * x_now + x
        # Convert interpolated input (image format) back to patch token format (without pos_embed, will add later)
        x_now_patches = self.x_embedder(x_interpolated)
        # Concatenate patch_tokens and x_now_patches along the sequence dimension
        concatenated_input = torch.cat([patch_tokens, x_now_patches], dim=1)  # (N, 2*T, D)
        # Add position embedding for the concatenated input
        # Use the same pos_embed for both halves (or could use different embeddings)
        concatenated_input = concatenated_input + self.pos_embed
        # Reuse SiTF1's embedders so conditioning is shared between the two stages.
        t_emb = self.sitf1.t_embedder(t)
        y_emb = self.sitf1.y_embedder(y, self.training)
        c = t_emb + y_emb
        return self.sitf2(concatenated_input, c, t, return_act=return_act)
|
GVP/Baseline/nohup.out
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/1042 [00:00<?, ?it/s]
|
| 1 |
0%| | 1/1042 [00:13<4:01:42, 13.93s/it]
|
| 2 |
0%| | 2/1042 [00:26<3:44:31, 12.95s/it]
|
| 3 |
0%| | 3/1042 [00:39<3:43:49, 12.93s/it]W0317 10:30:53.664000 11774 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received Signals.SIGINT death signal, shutting down workers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
0%| | 3/1042 [00:39<3:46:17, 13.07s/it]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 6 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 7 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 8 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 9 |
+
Starting rank=0, seed=0, world_size=4.
|
| 10 |
+
Starting rank=1, seed=1, world_size=4.
|
| 11 |
+
Starting rank=3, seed=3, world_size=4.
|
| 12 |
+
Starting rank=2, seed=2, world_size=4.
|
| 13 |
+
Saving .png samples at GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-100-Euler-sigma-Mean-0.04
|
| 14 |
+
Total number of images that will be sampled: 50016
|
| 15 |
+
|
| 16 |
0%| | 0/1042 [00:00<?, ?it/s]
|
| 17 |
0%| | 1/1042 [00:13<4:01:42, 13.93s/it]
|
| 18 |
0%| | 2/1042 [00:26<3:44:31, 12.95s/it]
|
| 19 |
0%| | 3/1042 [00:39<3:43:49, 12.93s/it]W0317 10:30:53.664000 11774 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received Signals.SIGINT death signal, shutting down workers
|
| 20 |
+
W0317 10:30:53.667000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11854 closing signal SIGINT
|
| 21 |
+
W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11855 closing signal SIGINT
|
| 22 |
+
W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11856 closing signal SIGINT
|
| 23 |
+
|
| 24 |
0%| | 3/1042 [00:39<3:46:17, 13.07s/it]
|
| 25 |
+
W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11857 closing signal SIGINT
|
| 26 |
+
[rank3]: Traceback (most recent call last):
|
| 27 |
+
[rank3]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
|
| 28 |
+
[rank3]: main(mode, args)
|
| 29 |
+
[rank3]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
|
| 30 |
+
[rank3]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
|
| 31 |
+
[rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
|
| 32 |
+
[rank3]: xs = _sde.sample(init, model, **model_kwargs)
|
| 33 |
+
[rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
|
| 34 |
+
[rank3]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
| 35 |
+
[rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 30, in __Euler_Maruyama_step
|
| 36 |
+
[rank3]: w_cur = th.randn(x.size()).to(x)
|
| 37 |
+
[rank3]: KeyboardInterrupt
|
| 38 |
+
[rank0]: Traceback (most recent call last):
|
| 39 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
|
| 40 |
+
[rank0]: main(mode, args)
|
| 41 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
|
| 42 |
+
[rank0]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
|
| 43 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
|
| 44 |
+
[rank0]: xs = _sde.sample(init, model, **model_kwargs)
|
| 45 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
|
| 46 |
+
[rank0]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
| 47 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
|
| 48 |
+
[rank0]: drift = self.drift(x, t, model, **model_kwargs)
|
| 49 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
|
| 50 |
+
[rank0]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
| 51 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 247, in body_fn
|
| 52 |
+
[rank0]: model_output = drift_fn(x, t, model, **model_kwargs)
|
| 53 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 236, in velocity_ode
|
| 54 |
+
[rank0]: model_output = model(x, t, **model_kwargs)
|
| 55 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
|
| 56 |
+
[rank0]: sit_out = base_model.forward(x, t, y)
|
| 57 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 245, in forward
|
| 58 |
+
[rank0]: x = block(x, c) # (N, T, D)
|
| 59 |
+
[rank0]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
|
| 60 |
+
[rank0]: return self._call_impl(*args, **kwargs)
|
| 61 |
+
[rank0]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
|
| 62 |
+
[rank0]: return forward_call(*args, **kwargs)
|
| 63 |
+
[rank0]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 117, in forward
|
| 64 |
+
[rank0]: x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 65 |
+
[rank0]: KeyboardInterrupt
|
| 66 |
+
[rank1]: Traceback (most recent call last):
|
| 67 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
|
| 68 |
+
[rank1]: main(mode, args)
|
| 69 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
|
| 70 |
+
[rank1]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
|
| 71 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
|
| 72 |
+
[rank1]: xs = _sde.sample(init, model, **model_kwargs)
|
| 73 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
|
| 74 |
+
[rank1]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
| 75 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
|
| 76 |
+
[rank1]: drift = self.drift(x, t, model, **model_kwargs)
|
| 77 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
|
| 78 |
+
[rank1]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
| 79 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 264, in <lambda>
|
| 80 |
+
[rank1]: score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
|
| 81 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
|
| 82 |
+
[rank1]: sit_out = base_model.forward(x, t, y)
|
| 83 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 241, in forward
|
| 84 |
+
[rank1]: t = self.t_embedder(t) # (N, D)
|
| 85 |
+
[rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
|
| 86 |
+
[rank1]: return self._call_impl(*args, **kwargs)
|
| 87 |
+
[rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
|
| 88 |
+
[rank1]: return forward_call(*args, **kwargs)
|
| 89 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 59, in forward
|
| 90 |
+
[rank1]: t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 91 |
+
[rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 49, in timestep_embedding
|
| 92 |
+
[rank1]: freqs = torch.exp(
|
| 93 |
+
[rank1]: KeyboardInterrupt
|
| 94 |
+
[rank2]: Traceback (most recent call last):
|
| 95 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
|
| 96 |
+
[rank2]: main(mode, args)
|
| 97 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
|
| 98 |
+
[rank2]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
|
| 99 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
|
| 100 |
+
[rank2]: xs = _sde.sample(init, model, **model_kwargs)
|
| 101 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
|
| 102 |
+
[rank2]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
| 103 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
|
| 104 |
+
[rank2]: drift = self.drift(x, t, model, **model_kwargs)
|
| 105 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
|
| 106 |
+
[rank2]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
| 107 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 264, in <lambda>
|
| 108 |
+
[rank2]: score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
|
| 109 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
|
| 110 |
+
[rank2]: sit_out = base_model.forward(x, t, y)
|
| 111 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 241, in forward
|
| 112 |
+
[rank2]: t = self.t_embedder(t) # (N, D)
|
| 113 |
+
[rank2]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
|
| 114 |
+
[rank2]: return self._call_impl(*args, **kwargs)
|
| 115 |
+
[rank2]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
|
| 116 |
+
[rank2]: return forward_call(*args, **kwargs)
|
| 117 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 59, in forward
|
| 118 |
+
[rank2]: t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 119 |
+
[rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 49, in timestep_embedding
|
| 120 |
+
[rank2]: freqs = torch.exp(
|
| 121 |
+
[rank2]: KeyboardInterrupt
|
| 122 |
+
W0317 10:30:53.820000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11854 closing signal SIGTERM
|
| 123 |
+
W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11855 closing signal SIGTERM
|
| 124 |
+
W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11856 closing signal SIGTERM
|
| 125 |
+
W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11857 closing signal SIGTERM
|
| 126 |
+
Traceback (most recent call last):
|
| 127 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
|
| 128 |
+
result = self._invoke_run(role)
|
| 129 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
|
| 130 |
+
time.sleep(monitor_interval)
|
| 131 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 132 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 133 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
|
| 134 |
+
|
| 135 |
+
During handling of the above exception, another exception occurred:
|
| 136 |
+
|
| 137 |
+
Traceback (most recent call last):
|
| 138 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 705, in run
|
| 139 |
+
self._shutdown(e.sigval)
|
| 140 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
|
| 141 |
+
self._pcontext.close(death_sig)
|
| 142 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
|
| 143 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 144 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
|
| 145 |
+
handler.proc.wait(time_to_wait)
|
| 146 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1209, in wait
|
| 147 |
+
return self._wait(timeout=timeout)
|
| 148 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1953, in _wait
|
| 149 |
+
time.sleep(delay)
|
| 150 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 151 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 152 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
|
| 153 |
+
|
| 154 |
+
During handling of the above exception, another exception occurred:
|
| 155 |
+
|
| 156 |
+
Traceback (most recent call last):
|
| 157 |
+
File "/root/miniconda3/envs/SiT/bin/torchrun", line 6, in <module>
|
| 158 |
+
sys.exit(main())
|
| 159 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 160 |
+
return f(*args, **kwargs)
|
| 161 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
|
| 162 |
+
run(args)
|
| 163 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
|
| 164 |
+
elastic_launch(
|
| 165 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 166 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 167 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
|
| 168 |
+
result = agent.run()
|
| 169 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
|
| 170 |
+
result = f(*args, **kwargs)
|
| 171 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 710, in run
|
| 172 |
+
self._shutdown()
|
| 173 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
|
| 174 |
+
self._pcontext.close(death_sig)
|
| 175 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
|
| 176 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 177 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
|
| 178 |
+
handler.proc.wait(time_to_wait)
|
| 179 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1209, in wait
|
| 180 |
+
return self._wait(timeout=timeout)
|
| 181 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1953, in _wait
|
| 182 |
+
time.sleep(delay)
|
| 183 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 184 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 185 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
|
GVP/Baseline/pic_npz.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
|
| 4 |
+
基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
|
| 5 |
+
支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
|
| 6 |
+
支持从 metadata.jsonl 文件读取图片路径
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import glob
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_npz_from_metadata(*args, **kwargs):
|
| 19 |
+
"""
|
| 20 |
+
占位函数:已废弃 metadata.jsonl 功能,保留空壳避免旧脚本导入时报错。
|
| 21 |
+
"""
|
| 22 |
+
raise RuntimeError("metadata.jsonl 功能已移除,请仅使用 --image_folder 方式生成 npz。")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
|
| 27 |
+
"""
|
| 28 |
+
主函数:解析命令行参数并执行图片到npz的转换
|
| 29 |
+
"""
|
| 30 |
+
parser = argparse.ArgumentParser(
|
| 31 |
+
description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
|
| 32 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 33 |
+
epilog="""
|
| 34 |
+
使用示例:
|
| 35 |
+
python pic_npz.py /path/to/image/folder
|
| 36 |
+
python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
|
| 37 |
+
"""
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--image_folder",
|
| 42 |
+
type=str,
|
| 43 |
+
default="/gemini/space/gzy_new/models/Baseline/GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04",
|
| 44 |
+
help="包含PNG或JPG图片文件的文件夹路径"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--output-dir",
|
| 49 |
+
type=str,
|
| 50 |
+
default=None,
|
| 51 |
+
help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
args = parser.parse_args()
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
# 仅使用图片文件夹,不再支持 metadata.jsonl
|
| 58 |
+
image_folder_path = os.path.abspath(args.image_folder)
|
| 59 |
+
|
| 60 |
+
if args.output_dir:
|
| 61 |
+
# 如果指定了输出目录,修改生成逻辑
|
| 62 |
+
folder_name = os.path.basename(image_folder_path.rstrip('/'))
|
| 63 |
+
custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
|
| 64 |
+
|
| 65 |
+
# 创建输出目录(如果不存在)
|
| 66 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
# 使用自定义输出路径版本
|
| 69 |
+
npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
|
| 70 |
+
else:
|
| 71 |
+
npz_path = create_npz_from_image_folder(image_folder_path)
|
| 72 |
+
|
| 73 |
+
print(f"转换完成!NPZ文件已保存至: {npz_path}")
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"错误: {e}")
|
| 77 |
+
return 1
|
| 78 |
+
|
| 79 |
+
return 0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_npz_from_image_folder_custom(image_folder_path, output_path):
|
| 83 |
+
"""
|
| 84 |
+
从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
image_folder_path (str): 包含图片文件的文件夹路径
|
| 88 |
+
output_path (str): 输出npz文件的完整路径
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
str: 生成的 npz 文件路径
|
| 92 |
+
"""
|
| 93 |
+
# 确保路径存在
|
| 94 |
+
if not os.path.exists(image_folder_path):
|
| 95 |
+
raise ValueError(f"文件夹路径不存在: {image_folder_path}")
|
| 96 |
+
|
| 97 |
+
# 获取所有支持的图片文件
|
| 98 |
+
supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
|
| 99 |
+
image_files = []
|
| 100 |
+
|
| 101 |
+
for extension in supported_extensions:
|
| 102 |
+
pattern = os.path.join(image_folder_path, extension)
|
| 103 |
+
image_files.extend(glob.glob(pattern))
|
| 104 |
+
|
| 105 |
+
# 按文件名排序确保一致性
|
| 106 |
+
image_files.sort()
|
| 107 |
+
|
| 108 |
+
if len(image_files) == 0:
|
| 109 |
+
raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
|
| 110 |
+
|
| 111 |
+
print(f"找到 {len(image_files)} 张图片文件")
|
| 112 |
+
|
| 113 |
+
# 读取所有图片
|
| 114 |
+
samples = []
|
| 115 |
+
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
|
| 116 |
+
try:
|
| 117 |
+
# 打开图片并转换为RGB格式(确保一致性)
|
| 118 |
+
with Image.open(img_path) as img:
|
| 119 |
+
# 转换为RGB,确保所有图片都是3通道
|
| 120 |
+
if img.mode != 'RGB':
|
| 121 |
+
img = img.convert('RGB')
|
| 122 |
+
|
| 123 |
+
# 将图片resize到512x512
|
| 124 |
+
img = img.resize((512, 512), Image.LANCZOS)
|
| 125 |
+
|
| 126 |
+
sample_np = np.asarray(img).astype(np.uint8)
|
| 127 |
+
|
| 128 |
+
# 确保图片是3通道
|
| 129 |
+
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
|
| 130 |
+
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
samples.append(sample_np)
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"警告: 无法读取图片 {img_path}: {e}")
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
if len(samples) == 0:
|
| 140 |
+
raise ValueError("没有成功读取任何有效的图片文件")
|
| 141 |
+
|
| 142 |
+
# 转换为numpy数组
|
| 143 |
+
samples = np.stack(samples)
|
| 144 |
+
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
|
| 145 |
+
|
| 146 |
+
# 验证数据形状
|
| 147 |
+
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
|
| 148 |
+
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
|
| 149 |
+
|
| 150 |
+
# 保存为npz文件
|
| 151 |
+
np.savez(output_path, arr_0=samples)
|
| 152 |
+
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
|
| 153 |
+
|
| 154 |
+
return output_path
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def create_npz_from_image_folder(image_folder_path):
|
| 158 |
+
"""
|
| 159 |
+
从图片文件夹构建 .npz,输出到该文件夹的父目录,文件名为 <文件夹名>.npz
|
| 160 |
+
"""
|
| 161 |
+
parent_dir = os.path.dirname(os.path.abspath(image_folder_path))
|
| 162 |
+
folder_name = os.path.basename(os.path.abspath(image_folder_path).rstrip("/"))
|
| 163 |
+
output_path = os.path.join(parent_dir, f"{folder_name}.npz")
|
| 164 |
+
return create_npz_from_image_folder_custom(image_folder_path, output_path)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
exit(main())
|
GVP/Baseline/run.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup torchrun \
|
| 2 |
+
--nnodes=1 \
|
| 3 |
+
--nproc_per_node=4 \
|
| 4 |
+
--rdzv_endpoint=localhost:29739 \
|
| 5 |
+
train_rectified_noise.py \
|
| 6 |
+
--depth 6 \
|
| 7 |
+
--results-dir results_256_gvp_disp \
|
| 8 |
+
--data-path /gemini/space/gzy_new/Imagenet256/train \
|
| 9 |
+
--ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
|
| 10 |
+
--num-classes 1000 \
|
| 11 |
+
--path-type GVP \
|
| 12 |
+
--prediction velocity \
|
| 13 |
+
--use-ot \
|
| 14 |
+
--use-contrastive \
|
| 15 |
+
> w_training1.log 2>&1 &
|
GVP/Baseline/sample_compare_ddp_rectified.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from diffusers.models import AutoencoderKL
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from download import find_model
|
| 15 |
+
from models import SiT_models
|
| 16 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 17 |
+
from transport import Sampler, create_transport
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def fix_state_dict_for_ddp(state_dict):
|
| 21 |
+
if isinstance(state_dict, dict) and ("model" in state_dict or "ema" in state_dict):
|
| 22 |
+
if "ema" in state_dict:
|
| 23 |
+
state_dict = state_dict["ema"]
|
| 24 |
+
elif "model" in state_dict:
|
| 25 |
+
state_dict = state_dict["model"]
|
| 26 |
+
fixed_state_dict = {}
|
| 27 |
+
for key, value in state_dict.items():
|
| 28 |
+
fixed_state_dict[key if key.startswith("module.") else f"module.{key}"] = value
|
| 29 |
+
return fixed_state_dict
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def save_png_batch(samples, out_dir, rank, total_offset):
|
| 33 |
+
for i, sample in enumerate(samples):
|
| 34 |
+
index = i * dist.get_world_size() + rank + total_offset
|
| 35 |
+
Image.fromarray(sample).save(f"{out_dir}/{index:06d}.png")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
|
| 39 |
+
samples = []
|
| 40 |
+
for i in tqdm(range(num), desc=f"Building .npz from {os.path.basename(sample_dir)}"):
|
| 41 |
+
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
|
| 42 |
+
samples.append(np.asarray(sample_pil).astype(np.uint8))
|
| 43 |
+
samples = np.stack(samples)
|
| 44 |
+
npz_path = f"{sample_dir}.npz"
|
| 45 |
+
np.savez(npz_path, arr_0=samples)
|
| 46 |
+
print(f"Saved .npz to {npz_path} [shape={samples.shape}]")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main(mode, args):
|
| 50 |
+
torch.backends.cuda.matmul.allow_tf32 = args.tf32
|
| 51 |
+
assert torch.cuda.is_available(), "This script requires at least one GPU."
|
| 52 |
+
torch.set_grad_enabled(False)
|
| 53 |
+
|
| 54 |
+
dist.init_process_group("nccl")
|
| 55 |
+
rank = dist.get_rank()
|
| 56 |
+
device = rank % torch.cuda.device_count()
|
| 57 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
| 58 |
+
torch.manual_seed(seed)
|
| 59 |
+
torch.cuda.set_device(device)
|
| 60 |
+
|
| 61 |
+
latent_size = args.image_size // 8
|
| 62 |
+
assert args.cfg_scale >= 1.0
|
| 63 |
+
using_cfg = args.cfg_scale > 1.0
|
| 64 |
+
|
| 65 |
+
# ---------------- Base model (sample_ddp style) ----------------
|
| 66 |
+
base_model = SiT_models[args.model](
|
| 67 |
+
input_size=latent_size, num_classes=args.num_classes, learn_sigma=False
|
| 68 |
+
).to(device)
|
| 69 |
+
base_state = find_model(args.ckpt)
|
| 70 |
+
if isinstance(base_state, dict) and "model" in base_state:
|
| 71 |
+
base_state = base_state["model"]
|
| 72 |
+
base_model.load_state_dict(base_state, strict=False)
|
| 73 |
+
base_model.eval()
|
| 74 |
+
|
| 75 |
+
# ---------------- Rectified model (sample_rectified style) ----------------
|
| 76 |
+
from models import CombinedModel, SiTF1, SiTF2
|
| 77 |
+
|
| 78 |
+
model_name = args.model
|
| 79 |
+
if "XL" in model_name:
|
| 80 |
+
hidden_size, depth, num_heads = 1152, 28, 16
|
| 81 |
+
elif "L" in model_name:
|
| 82 |
+
hidden_size, depth, num_heads = 1024, 24, 16
|
| 83 |
+
elif "B" in model_name:
|
| 84 |
+
hidden_size, depth, num_heads = 768, 12, 12
|
| 85 |
+
elif "S" in model_name:
|
| 86 |
+
hidden_size, depth, num_heads = 384, 12, 6
|
| 87 |
+
else:
|
| 88 |
+
hidden_size, depth, num_heads = 768, 12, 12
|
| 89 |
+
patch_size = int(model_name.split("/")[-1])
|
| 90 |
+
|
| 91 |
+
sitf1 = SiTF1(
|
| 92 |
+
input_size=latent_size,
|
| 93 |
+
patch_size=patch_size,
|
| 94 |
+
in_channels=4,
|
| 95 |
+
hidden_size=hidden_size,
|
| 96 |
+
depth=depth,
|
| 97 |
+
num_heads=num_heads,
|
| 98 |
+
mlp_ratio=4.0,
|
| 99 |
+
class_dropout_prob=0.1,
|
| 100 |
+
num_classes=args.num_classes,
|
| 101 |
+
learn_sigma=False,
|
| 102 |
+
).to(device)
|
| 103 |
+
sitf1.load_state_dict(base_state, strict=False)
|
| 104 |
+
sitf1.eval()
|
| 105 |
+
|
| 106 |
+
sitf2 = SiTF2(
|
| 107 |
+
input_size=latent_size,
|
| 108 |
+
hidden_size=hidden_size,
|
| 109 |
+
out_channels=8,
|
| 110 |
+
patch_size=patch_size,
|
| 111 |
+
num_heads=num_heads,
|
| 112 |
+
mlp_ratio=4.0,
|
| 113 |
+
depth=args.depth,
|
| 114 |
+
learn_sigma=True,
|
| 115 |
+
num_classes=args.num_classes,
|
| 116 |
+
learn_mu=args.learn_mu,
|
| 117 |
+
).to(device)
|
| 118 |
+
sitf2 = DDP(sitf2, device_ids=[device])
|
| 119 |
+
sitf2_state = fix_state_dict_for_ddp(find_model(args.sitf2_ckpt))
|
| 120 |
+
sitf2.load_state_dict(sitf2_state, strict=False)
|
| 121 |
+
sitf2.eval()
|
| 122 |
+
rectified_model = CombinedModel(sitf1, sitf2).to(device)
|
| 123 |
+
rectified_model.eval()
|
| 124 |
+
|
| 125 |
+
def model_base(x, t, y=None, **kwargs):
|
| 126 |
+
if using_cfg and "cfg_scale" in kwargs:
|
| 127 |
+
return base_model.forward_with_cfg(x, t, y, kwargs["cfg_scale"])
|
| 128 |
+
return base_model.forward(x, t, y)
|
| 129 |
+
|
| 130 |
+
def model_rectified(x, t, y=None, **kwargs):
|
| 131 |
+
if using_cfg and "cfg_scale" in kwargs:
|
| 132 |
+
sit_out = base_model.forward_with_cfg(x, t, y, kwargs["cfg_scale"])
|
| 133 |
+
else:
|
| 134 |
+
sit_out = base_model.forward(x, t, y)
|
| 135 |
+
if not args.use_sitf2:
|
| 136 |
+
return sit_out
|
| 137 |
+
out = rectified_model.forward(x, t, y)
|
| 138 |
+
if args.use_sitf2_before_t05:
|
| 139 |
+
mask = (t < args.sitf2_threshold).float()
|
| 140 |
+
while len(mask.shape) < len(out.shape):
|
| 141 |
+
mask = mask.unsqueeze(-1)
|
| 142 |
+
out = out * mask.expand_as(out)
|
| 143 |
+
return sit_out + out
|
| 144 |
+
|
| 145 |
+
transport = create_transport(
|
| 146 |
+
args.path_type, args.prediction, args.loss_weight, args.train_eps, args.sample_eps
|
| 147 |
+
)
|
| 148 |
+
sampler = Sampler(transport)
|
| 149 |
+
if mode == "ODE":
|
| 150 |
+
sample_fn = sampler.sample_ode(
|
| 151 |
+
sampling_method=args.sampling_method,
|
| 152 |
+
num_steps=args.num_sampling_steps,
|
| 153 |
+
atol=args.atol,
|
| 154 |
+
rtol=args.rtol,
|
| 155 |
+
reverse=args.reverse,
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
sample_fn = sampler.sample_sde(
|
| 159 |
+
sampling_method=args.sampling_method,
|
| 160 |
+
diffusion_form=args.diffusion_form,
|
| 161 |
+
diffusion_norm=args.diffusion_norm,
|
| 162 |
+
last_step=args.last_step,
|
| 163 |
+
last_step_size=args.last_step_size,
|
| 164 |
+
num_steps=args.num_sampling_steps,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
|
| 168 |
+
|
| 169 |
+
exp_name = f"compare-{args.model.replace('/', '-')}-cfg-{args.cfg_scale}-{mode}-{args.num_sampling_steps}"
|
| 170 |
+
root_out = os.path.join(args.sample_dir, exp_name)
|
| 171 |
+
out_base = os.path.join(root_out, "base")
|
| 172 |
+
out_rect = os.path.join(root_out, "rectified")
|
| 173 |
+
out_pair = os.path.join(root_out, "pair")
|
| 174 |
+
if rank == 0:
|
| 175 |
+
os.makedirs(out_base, exist_ok=True)
|
| 176 |
+
os.makedirs(out_rect, exist_ok=True)
|
| 177 |
+
os.makedirs(out_pair, exist_ok=True)
|
| 178 |
+
dist.barrier()
|
| 179 |
+
|
| 180 |
+
n = args.per_proc_batch_size
|
| 181 |
+
global_batch = n * dist.get_world_size()
|
| 182 |
+
total_samples = int(math.ceil(args.num_fid_samples / global_batch) * global_batch)
|
| 183 |
+
iters = total_samples // global_batch
|
| 184 |
+
pbar = tqdm(range(iters)) if rank == 0 else range(iters)
|
| 185 |
+
total = 0
|
| 186 |
+
|
| 187 |
+
for _ in pbar:
|
| 188 |
+
z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
|
| 189 |
+
y = torch.randint(0, args.num_classes, (n,), device=device)
|
| 190 |
+
if using_cfg:
|
| 191 |
+
z_in = torch.cat([z, z], 0)
|
| 192 |
+
y_null = torch.full((n,), args.num_classes, device=device, dtype=y.dtype)
|
| 193 |
+
y_in = torch.cat([y, y_null], 0)
|
| 194 |
+
model_kwargs = dict(y=y_in, cfg_scale=args.cfg_scale)
|
| 195 |
+
else:
|
| 196 |
+
z_in = z
|
| 197 |
+
y_in = y
|
| 198 |
+
model_kwargs = dict(y=y_in)
|
| 199 |
+
|
| 200 |
+
# Ensure SDE process noise is identical between base and rectified:
|
| 201 |
+
# save RNG state -> run base -> restore RNG state -> run rectified.
|
| 202 |
+
cpu_rng_state = torch.get_rng_state()
|
| 203 |
+
cuda_rng_state = torch.cuda.get_rng_state(device)
|
| 204 |
+
|
| 205 |
+
x_base = sample_fn(z_in, model_base, **model_kwargs)[-1]
|
| 206 |
+
|
| 207 |
+
torch.set_rng_state(cpu_rng_state)
|
| 208 |
+
torch.cuda.set_rng_state(cuda_rng_state, device=device)
|
| 209 |
+
x_rect = sample_fn(z_in, model_rectified, **model_kwargs)[-1]
|
| 210 |
+
if using_cfg:
|
| 211 |
+
x_base, _ = x_base.chunk(2, dim=0)
|
| 212 |
+
x_rect, _ = x_rect.chunk(2, dim=0)
|
| 213 |
+
|
| 214 |
+
img_base = vae.decode(x_base / 0.18215).sample
|
| 215 |
+
img_rect = vae.decode(x_rect / 0.18215).sample
|
| 216 |
+
img_base = torch.clamp(127.5 * img_base + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 217 |
+
img_rect = torch.clamp(127.5 * img_rect + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 218 |
+
|
| 219 |
+
save_png_batch(img_base, out_base, rank, total)
|
| 220 |
+
save_png_batch(img_rect, out_rect, rank, total)
|
| 221 |
+
|
| 222 |
+
for i in range(img_base.shape[0]):
|
| 223 |
+
index = i * dist.get_world_size() + rank + total
|
| 224 |
+
pair = np.concatenate([img_base[i], img_rect[i]], axis=1)
|
| 225 |
+
Image.fromarray(pair).save(f"{out_pair}/{index:06d}.png")
|
| 226 |
+
|
| 227 |
+
total += global_batch
|
| 228 |
+
dist.barrier()
|
| 229 |
+
|
| 230 |
+
dist.barrier()
|
| 231 |
+
if rank == 0:
|
| 232 |
+
create_npz_from_sample_folder(out_base, args.num_fid_samples)
|
| 233 |
+
create_npz_from_sample_folder(out_rect, args.num_fid_samples)
|
| 234 |
+
print(f"Done. Output root: {root_out}")
|
| 235 |
+
dist.barrier()
|
| 236 |
+
dist.destroy_process_group()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
if __name__ == "__main__":
|
| 240 |
+
parser = argparse.ArgumentParser()
|
| 241 |
+
if len(sys.argv) < 2:
|
| 242 |
+
print("Usage: program.py <mode> [options]")
|
| 243 |
+
sys.exit(1)
|
| 244 |
+
|
| 245 |
+
mode = sys.argv[1]
|
| 246 |
+
assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"
|
| 247 |
+
|
| 248 |
+
parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
|
| 249 |
+
parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
|
| 250 |
+
parser.add_argument("--sample-dir", type=str, default="compare_samples")
|
| 251 |
+
parser.add_argument("--per-proc-batch-size", type=int, default=12)
|
| 252 |
+
parser.add_argument("--num-fid-samples", type=int, default=50_000)
|
| 253 |
+
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
|
| 254 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
| 255 |
+
parser.add_argument("--cfg-scale", type=float, default=1.0)
|
| 256 |
+
parser.add_argument("--num-sampling-steps", type=int, default=250)
|
| 257 |
+
parser.add_argument("--global-seed", type=int, default=1)
|
| 258 |
+
parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True)
|
| 259 |
+
parser.add_argument("--ckpt", type=str, required=True, help="Base SiT checkpoint")
|
| 260 |
+
parser.add_argument("--sitf2-ckpt", type=str, required=True, help="SiTF2 checkpoint")
|
| 261 |
+
parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True)
|
| 262 |
+
parser.add_argument("--depth", type=int, default=1, help="SiTF2 depth")
|
| 263 |
+
parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True)
|
| 264 |
+
parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False)
|
| 265 |
+
parser.add_argument("--sitf2-threshold", type=float, default=0.5)
|
| 266 |
+
|
| 267 |
+
parse_transport_args(parser)
|
| 268 |
+
if mode == "ODE":
|
| 269 |
+
parse_ode_args(parser)
|
| 270 |
+
else:
|
| 271 |
+
parse_sde_args(parser)
|
| 272 |
+
|
| 273 |
+
args = parser.parse_known_args()[0]
|
| 274 |
+
main(mode, args)
|
GVP/Baseline/sample_ddp.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 6 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 7 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 8 |
+
|
| 9 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 10 |
+
"""
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from models import SiT_models
|
| 14 |
+
from download import find_model
|
| 15 |
+
from transport import create_transport, Sampler
|
| 16 |
+
from diffusers.models import AutoencoderKL
|
| 17 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Build a single .npz file from a folder of .png samples.

    Expects images named ``000000.png`` .. ``{num-1:06d}.png`` inside
    ``sample_dir``. Writes ``{sample_dir}.npz`` containing one array
    ``arr_0`` of shape (num, H, W, 3), dtype uint8 — the layout expected
    by the ADM FID evaluation scripts.

    Args:
        sample_dir: Directory containing the numbered .png samples.
        num: Number of samples to pack (default 50,000 for FID).

    Returns:
        Path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        # Force 3 channels so a grayscale or RGBA PNG cannot break np.stack
        # or trip the channel assertion below.
        sample_np = np.asarray(sample_pil.convert("RGB")).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main(mode, args):
    """
    Run DDP sampling.

    Launched under torchrun: every rank samples latents with the SiT model,
    decodes them with the Stable Diffusion VAE, and writes individual .png
    files into a shared folder; rank 0 finally packs them into one .npz
    for FID evaluation.

    Args:
        mode: "ODE" or "SDE" — selects the transport sampler family.
        args: Parsed argparse namespace (see the __main__ block).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Distinct per-rank seed so each rank draws different noise and classes.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Auto-download path only supports the published SiT-XL/2 256x256 model.
    if args.ckpt is None:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model (latents are 8x downsampled relative to pixel space):
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!

    # Build the transport problem and the corresponding sampler.
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            # Likelihood evaluation requires an unguided model.
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    # VAE decoder used to map latents back to pixel space.
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    # cfg_scale == 1.0 means no guidance; only > 1.0 doubles the batch.
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples (name encodes the sampling configuration):
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed from existing .png files but
    # never used — resuming a partial run will regenerate from scratch.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance: duplicate the noise and append the
        # null class (index 1000) so conditional/unconditional run in one batch.
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        # sample_fn returns the full trajectory; keep only the final state.
        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        # 0.18215 is the SD VAE latent scaling factor; then map to uint8 pixels.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        # NOTE(review): the inner `i` shadows the outer loop variable; harmless
        # here because the outer `i` is not read afterwards, but fragile.
        for i, sample in enumerate(samples):
            # Interleaved global index so ranks never collide on filenames.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The first positional token is the sampler family; it is consumed
    # manually (not via argparse) because it decides which extra argument
    # groups get registered below.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    # Guard against the user passing a flag where the mode is expected.
    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=12)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=1)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    # Shared transport flags plus mode-specific solver flags.
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates the leading positional <mode> token.
    args = parser.parse_known_args()[0]
    main(mode, args)
|
GVP/Baseline/sample_rectified_noise.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 4 |
+
from models import SiT_models
|
| 5 |
+
from download import find_model
|
| 6 |
+
from transport import create_transport, Sampler
|
| 7 |
+
from diffusers.models import AutoencoderKL
|
| 8 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import os
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
import argparse
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Build a single .npz file from a folder of .png samples.

    Expects images named ``000000.png`` .. ``{num-1:06d}.png`` inside
    ``sample_dir``. Writes ``{sample_dir}.npz`` containing one array
    ``arr_0`` of shape (num, H, W, 3), dtype uint8 — the layout expected
    by the ADM FID evaluation scripts.

    Args:
        sample_dir: Directory containing the numbered .png samples.
        num: Number of samples to pack (default 50,000 for FID).

    Returns:
        Path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        # Force 3 channels so a grayscale or RGBA PNG cannot break np.stack
        # or trip the channel assertion below.
        sample_np = np.asarray(sample_pil.convert("RGB")).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fix_state_dict_for_ddp(state_dict):
    """
    Adapt a checkpoint state dict to DistributedDataParallel key naming.

    If given a full training checkpoint (a dict holding "model"/"ema"/"opt"
    entries), the EMA weights are preferred, then the raw model weights;
    when neither is present the mapping is used as-is. Every key is then
    prefixed with "module." unless it already carries that prefix, matching
    the parameter names a DDP-wrapped module expects.
    """
    is_full_checkpoint = isinstance(state_dict, dict) and any(
        section in state_dict for section in ("model", "ema", "opt")
    )
    if is_full_checkpoint:
        # Prefer EMA weights, then raw model weights; if only other sections
        # (e.g. "opt") exist, fall through with the original mapping.
        for preferred in ("ema", "model"):
            if preferred in state_dict:
                state_dict = state_dict[preferred]
                break
    # Prefix keys for the DDP wrapper, leaving already-prefixed keys alone.
    return {
        key if key.startswith("module.") else f"module.{key}": value
        for key, value in state_dict.items()
    }
|
| 61 |
+
|
| 62 |
+
def main(mode, args):
    """
    Run DDP sampling with the rectified-noise (SiTF1 + SiTF2) setup.

    Loads a base SiT model (SiTF1 weights) plus an auxiliary SiTF2 network,
    combines their outputs inside ``combined_sampling_model``, samples
    latents under torchrun, decodes them with the SD VAE, and writes .png
    files; rank 0 finally packs them into one .npz for FID.

    Fixes relative to the previous revision:
      * ``sitf2_ckpt_string_name`` guarded on ``args.ckpt`` instead of
        ``args.sitf2_ckpt``.
      * ``sitf2.eval()`` was called twice.

    Args:
        mode: "ODE" or "SDE" — selects the transport sampler family.
        args: Parsed argparse namespace (see the __main__ block).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)
    learn_mu = args.learn_mu
    sitf2_depth = args.depth  # Save SiTF2 depth before it gets overwritten

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Distinct per-rank seed so each rank draws different noise and classes.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load SiTF1 and SiTF2 models and create CombinedModel
    from models import SiTF1, SiTF2, CombinedModel
    latent_size = args.image_size // 8

    # Derive the transformer configuration from the model-size tag.
    model_name = args.model
    if 'XL' in model_name:
        hidden_size, depth, num_heads = 1152, 28, 16
    elif 'L' in model_name:
        hidden_size, depth, num_heads = 1024, 24, 16
    elif 'B' in model_name:
        hidden_size, depth, num_heads = 768, 12, 12
    elif 'S' in model_name:
        hidden_size, depth, num_heads = 384, 12, 6
    else:
        # Default fallback
        hidden_size, depth, num_heads = 768, 12, 12

    # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2
    patch_size = int(model_name.split('/')[-1])

    # Load SiTF1
    sitf1 = SiTF1(
        input_size=latent_size,
        patch_size=patch_size,
        in_channels=4,
        hidden_size=hidden_size,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=args.num_classes,
        learn_sigma=False
    ).to(device)
    sitf1_state_raw = find_model(args.ckpt)
    # find_model may return either a full checkpoint dict or a plain
    # state_dict (e.g. EMA weights); normalize to a state_dict here and
    # reuse it for both sitf1 and base_model below.
    if isinstance(sitf1_state_raw, dict) and "model" in sitf1_state_raw:
        sitf1_state = sitf1_state_raw["model"]
    else:
        sitf1_state = sitf1_state_raw
    sitf1.load_state_dict(sitf1_state)
    sitf1.eval()

    # Load SiTF2 with the same architecture parameters as SiTF1 for
    # compatibility, except for depth which comes from the command line.
    sitf2 = SiTF2(
        input_size=latent_size,
        hidden_size=hidden_size,
        out_channels=8,
        patch_size=patch_size,
        num_heads=num_heads,
        mlp_ratio=4.0,
        depth=sitf2_depth,
        learn_sigma=True,
        num_classes=args.num_classes,
        learn_mu=learn_mu
    ).to(device)
    # SiTF2 checkpoints were saved from a DDP-wrapped model, so wrap before
    # loading and rewrite keys to carry the "module." prefix.
    sitf2 = DDP(sitf2, device_ids=[device])
    sitf2_state = find_model(args.sitf2_ckpt)
    sitf2_state_fixed = fix_state_dict_for_ddp(sitf2_state)
    try:
        sitf2.load_state_dict(sitf2_state_fixed)
    except Exception as e:
        print(f"Error loading state dict: {e}")
        # Try loading with strict=False as fallback
        sitf2.load_state_dict(sitf2_state_fixed, strict=False)
    sitf2.eval()

    # CombinedModel couples SiTF1 and SiTF2 outputs for the correction term.
    combined_model = CombinedModel(sitf1, sitf2).to(device)
    combined_model.eval()

    # Use the SiT_models factory to create the base model so its
    # configuration matches training; learn_sigma=False matches sitf1.
    base_model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=False,  # Match sitf1's learn_sigma=False
    ).to(device)
    # Load the checkpoint (same as sitf1) - use the exact same state_dict
    base_model.load_state_dict(sitf1_state)
    base_model.eval()

    # Determine if CFG will be used (needed for combined_sampling_model)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # There are repeated calculations in the middle, which will cause Flops
    # to double. A simplified version will be released later.
    def combined_sampling_model(x, t, y=None, **kwargs):
        """Velocity function handed to the sampler: base SiT output plus an
        optional SiTF2 correction, gated by --use-sitf2 flags."""
        with torch.no_grad():
            # Handle CFG the same way as sample_ddp.py.
            if using_cfg and 'cfg_scale' in kwargs:
                sit_out = base_model.forward_with_cfg(x, t, y, kwargs['cfg_scale'])
            else:
                sit_out = base_model.forward(x, t, y)
            # NOTE(review): the correction is only added when BOTH
            # --use-sitf2 and --use-sitf2-before-t05 are set; --use-sitf2
            # alone falls through to the plain base-model output.
            if args.use_sitf2:
                if args.use_sitf2_before_t05:
                    # Per-sample gate: 1.0 where t < threshold, else 0.0.
                    mask = (t < args.sitf2_threshold).float()
                    combined_out = combined_model.forward(x, t, y)
                    # Broadcast the (batch,) mask over the spatial dims of
                    # combined_out (batch, channels, height, width).
                    while len(mask.shape) < len(combined_out.shape):
                        mask = mask.unsqueeze(-1)
                    mask = mask.expand_as(combined_out)
                    combined_out = combined_out * mask
                    return sit_out + combined_out
                else:
                    # Default behavior: only use base model output
                    return sit_out
            else:
                # Default behavior: only use base model output
                return sit_out

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples (name encodes the sampling configuration):
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    # Fixed: previously guarded on args.ckpt, mislabeling the SiTF2 ckpt.
    sitf2_ckpt_string_name = os.path.basename(args.sitf2_ckpt).replace(".pt", "") if args.sitf2_ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        # Add threshold info to folder name if use_sitf2_before_t05 is enabled
        threshold_suffix = f"-threshold-{args.sitf2_threshold}" if args.use_sitf2_before_t05 else ""
        if learn_mu:
            folder_name = f"depth-mu-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
        else:
            folder_name = f"depth-sigma-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed but never used — a partial
    # run is regenerated from scratch.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        # Setup classifier-free guidance: duplicate the noise and append the
        # null class (index 1000) so conditional/unconditional run together.
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        else:
            model_kwargs = dict(y=y)
        # sample_fn returns the trajectory; keep only the final state.
        samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        # 0.18215 is the SD VAE latent scaling factor; then map to uint8.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        # Save samples to disk as individual .png files, with an interleaved
        # global index so ranks never collide on filenames.
        for j, sample in enumerate(samples):
            index = j * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The first positional token is the sampler family; it is consumed
    # manually (not via argparse) because it decides which extra argument
    # groups get registered below.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    # Guard against the user passing a flag where the mode is expected.
    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=12)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sitf2-ckpt", type=str, required=True, help="Path to SiTF2 checkpoint")
    parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to learn mu parameter")
    parser.add_argument("--depth", type=int, default=1,
                        help="Depth parameter for SiTF2 model")
    # NOTE(review): the help text below is copy-pasted from
    # --use-sitf2-before-t05 and does not describe what --use-sitf2 does.
    parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True,
                        help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False,
                        help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    parser.add_argument("--sitf2-threshold", type=float, default=0.5,
                        help="Time threshold for using SiTF2 output (default: 0.5). Only effective when --use-sitf2-before-t05 is True")
    # Shared transport flags plus mode-specific solver flags.
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates the leading positional <mode> token.
    args = parser.parse_known_args()[0]
    main(mode, args)
|
GVP/Baseline/samples.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
|
| 2 |
+
--nnodes=1 \
|
| 3 |
+
--nproc_per_node=4 \
|
| 4 |
+
--rdzv_endpoint=localhost:29110 \
|
| 5 |
+
sample_rectified_noise.py SDE \
|
| 6 |
+
--depth 6 \
|
| 7 |
+
--sample-dir GVP_samples \
|
| 8 |
+
--model SiT-XL/2 \
|
| 9 |
+
--num-fid-samples 50000 \
|
| 10 |
+
--num-classes 1000 \
|
| 11 |
+
--global-seed 0 \
|
| 12 |
+
--use-sitf2 True \
|
| 13 |
+
--sitf2-threshold 1 \
|
| 14 |
+
--ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
|
| 15 |
+
--sitf2-ckpt /gemini/space/gzy_new/models/Baseline/results_256_gvp_disp/depth-mu-6-007-SiT-XL-2-GVP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt \
|
| 16 |
+
> W_No.log 2>&1 &
|
GVP/Baseline/samples_ddp.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
|
| 2 |
+
--nnodes=1 \
|
| 3 |
+
--nproc_per_node=4 \
|
| 4 |
+
--rdzv_endpoint=localhost:29111 \
|
| 5 |
+
sample_ddp.py SDE \
|
| 6 |
+
--sample-dir baseline_gvp_ \
|
| 7 |
+
--model SiT-XL/2 \
|
| 8 |
+
--num-fid-samples 50000 \
|
| 9 |
+
--num-classes 1000 \
|
| 10 |
+
--global-seed 0 \
|
| 11 |
+
--path-type GVP \
|
| 12 |
+
--prediction velocity \
|
| 13 |
+
--ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
|
| 14 |
+
> gvp_sampling.log 2>&1 &
|
GVP/Baseline/transport/__pycache__/ot_plan.cpython-311.pyc
ADDED
|
Binary file (5.71 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/path.cpython-310.pyc
ADDED
|
Binary file (7.9 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/path.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/path.cpython-312.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/path.cpython-38.pyc
ADDED
|
Binary file (7.93 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/transport.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/transport.cpython-311.pyc
ADDED
|
Binary file (23.8 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/transport.cpython-312.pyc
ADDED
|
Binary file (22.8 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/transport.cpython-38.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.17 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (1.9 kB). View file
|
|
|
GVP/Baseline/transport/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0020000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de6ee3c5ee036be85b216daddafb56f16df64e9c3f7d3060b31f5cb301a345b1
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0040000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30fe96a8ac31b2744debe3861b4c92e631462eb27401c26f88042b6e65ac287d
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0060000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f82060d1339e53cc05c4863d39dc083795f05ad282cabc830889d8b6f72911b
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0080000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:606a5b7dceacc351bf68867736eed793e3e02c29a6e27aa0e8cf6e2ef3382c06
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0100000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ecb79b4530e3503370cefb16c38706a8b7b823e221f30f621c902e2e174aaec
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0120000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d03dae6043ff1543a412b1c44c02a21aa299014e8b3b3209e21cfd73b2c5d0f
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0140000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c630ada8db890d9fc41eaf8d30b3bb1d95b1e61d380d8f7e97f639ec3b307233
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0160000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9aa31b63d4138dec6b8d4c55a58fadf60096ac8f1d4ec83651ea724e7d842e17
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0180000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4b9c5725fd5492e8e4979b190b04809ba03514ff57da6b5ce92e3e99df9c0db
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0200000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49d27a75f329c0170b0b6bbc0028881a40b6d8cb2a34495fba8ba73e49a395a4
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0220000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c59a8b6bca3ade4b6547951082b5dc67ad805861c611a13bccf1995cf8f0ca7d
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0240000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b47c2b4f5baef28f5f1033b57e5ad7cb2c1b52246fcaa0ff4214781bfa305604
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0260000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf7f7746b191c29efd3ff33ba7cdb83f67d242ed03a14228478943a3cb484907
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0280000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f254fdaea4c2c6875493b4a94846b5a8a54cf90210d610d297a7bbceadd84101
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f198d183997904bfef6ae3222f59022cbb86c40ae625d2196d659bd7c1c3c121
|
| 3 |
+
size 1193384322
|
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0320000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e4aca9e6138967eea8540f58f7a7b1d09b749ea4b471e62a0052b8c04e9a8d4
|
| 3 |
+
size 1193384322
|