xiangzai commited on
Commit
3b7386d
·
verified ·
1 Parent(s): b2fa289

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. GVP/Baseline/DeltaFM/assets/deltafm.png +3 -0
  3. GVP/Baseline/DeltaFM/interference_vectors/imnet256_interference_vector.pt +3 -0
  4. GVP/Baseline/W_No.log +0 -0
  5. GVP/Baseline/classify_image_graph_def.pb +3 -0
  6. GVP/Baseline/compare_samples.sh +21 -0
  7. GVP/Baseline/compare_sampling.log +0 -0
  8. GVP/Baseline/download.py +41 -0
  9. GVP/Baseline/environment.yml +16 -0
  10. GVP/Baseline/evaluate_samples.sh +65 -0
  11. GVP/Baseline/evaluator.py +690 -0
  12. GVP/Baseline/gvp_sampling.log +51 -0
  13. GVP/Baseline/models.py +647 -0
  14. GVP/Baseline/nohup.out +180 -0
  15. GVP/Baseline/pic_npz.py +168 -0
  16. GVP/Baseline/run.sh +15 -0
  17. GVP/Baseline/sample_compare_ddp_rectified.py +274 -0
  18. GVP/Baseline/sample_ddp.py +233 -0
  19. GVP/Baseline/sample_rectified_noise.py +380 -0
  20. GVP/Baseline/samples.sh +16 -0
  21. GVP/Baseline/samples_ddp.sh +14 -0
  22. GVP/Baseline/transport/__pycache__/ot_plan.cpython-311.pyc +0 -0
  23. GVP/Baseline/transport/__pycache__/path.cpython-310.pyc +0 -0
  24. GVP/Baseline/transport/__pycache__/path.cpython-311.pyc +0 -0
  25. GVP/Baseline/transport/__pycache__/path.cpython-312.pyc +0 -0
  26. GVP/Baseline/transport/__pycache__/path.cpython-38.pyc +0 -0
  27. GVP/Baseline/transport/__pycache__/transport.cpython-310.pyc +0 -0
  28. GVP/Baseline/transport/__pycache__/transport.cpython-311.pyc +0 -0
  29. GVP/Baseline/transport/__pycache__/transport.cpython-312.pyc +0 -0
  30. GVP/Baseline/transport/__pycache__/transport.cpython-38.pyc +0 -0
  31. GVP/Baseline/transport/__pycache__/utils.cpython-310.pyc +0 -0
  32. GVP/Baseline/transport/__pycache__/utils.cpython-311.pyc +0 -0
  33. GVP/Baseline/transport/__pycache__/utils.cpython-312.pyc +0 -0
  34. GVP/Baseline/transport/__pycache__/utils.cpython-38.pyc +0 -0
  35. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0020000.pt +3 -0
  36. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0040000.pt +3 -0
  37. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0060000.pt +3 -0
  38. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0080000.pt +3 -0
  39. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0100000.pt +3 -0
  40. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0120000.pt +3 -0
  41. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0140000.pt +3 -0
  42. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0160000.pt +3 -0
  43. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0180000.pt +3 -0
  44. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0200000.pt +3 -0
  45. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0220000.pt +3 -0
  46. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0240000.pt +3 -0
  47. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0260000.pt +3 -0
  48. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0280000.pt +3 -0
  49. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt +3 -0
  50. VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0320000.pt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ GVP/Baseline/DeltaFM/assets/deltafm.png filter=lfs diff=lfs merge=lfs -text
GVP/Baseline/DeltaFM/assets/deltafm.png ADDED

Git LFS Details

  • SHA256: ea6d5e6c5d64fecceb307ee87bfdc6c9fc97db6bf2fd37390b0b189018156336
  • Pointer size: 131 Bytes
  • Size of remote file: 405 kB
GVP/Baseline/DeltaFM/interference_vectors/imnet256_interference_vector.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f34c2a05a7abe7c22e32d6cc06e65f8435e9271f7b4aa0cc7044dce1b7727b
3
+ size 17629
GVP/Baseline/W_No.log ADDED
The diff for this file is too large to render. See raw diff
 
GVP/Baseline/classify_image_graph_def.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:009d6814d1bc560d4e7b236e170e9b2d5ca6f4b57bd8037f6db05776204415c6
3
+ size 95673916
GVP/Baseline/compare_samples.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch distributed (4-GPU) comparison sampling in the background.
# All stdout/stderr goes to compare_sampling.log; the job survives logout via nohup.

CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
  --nnodes=1 \
  --nproc_per_node=4 \
  --rdzv_endpoint=localhost:29166 \
  sample_compare_ddp_rectified.py SDE \
  --model SiT-XL/2 \
  --sample-dir compare_samples \
  --num-fid-samples 50000 \
  --num-classes 1000 \
  --global-seed 1 \
  --cfg-scale 1.0 \
  --num-sampling-steps 250 \
  --depth 6 \
  --use-sitf2 True \
  --sitf2-threshold 0.5 \
  --ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
  --sitf2-ckpt /gemini/space/gzy_new/models/Baseline/results_256_gvp_disp/depth-mu-6-007-SiT-XL-2-GVP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt \
  > compare_sampling.log 2>&1 &
GVP/Baseline/compare_sampling.log ADDED
The diff for this file is too large to render. See raw diff
 
GVP/Baseline/download.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Functions for downloading pre-trained SiT models
6
+ """
7
+ from torchvision.datasets.utils import download_url
8
+ import torch
9
+ import os
10
+
11
+
12
+ pretrained_models = {'SiT-XL-2-256x256.pt'}
13
+
14
+
15
+ def find_model(model_name):
16
+ """
17
+ Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
18
+ """
19
+ if model_name in pretrained_models:
20
+ return download_model(model_name)
21
+ else:
22
+ assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
23
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage, weights_only=False)
24
+ if "ema" in checkpoint: # supports checkpoints from train.py
25
+ checkpoint = checkpoint["ema"]
26
+ return checkpoint
27
+
28
+
29
+ def download_model(model_name):
30
+ """
31
+ Downloads a pre-trained SiT model from the web.
32
+ """
33
+ assert model_name in pretrained_models
34
+ local_path = f'pretrained_models/{model_name}'
35
+ if not os.path.isfile(local_path):
36
+ os.makedirs('pretrained_models', exist_ok=True)
37
+ web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
38
+ download_url(web_path, 'pretrained_models', filename=model_name)
39
+ model = torch.load(local_path, map_location=lambda storage, loc: storage, weights_only=False)
40
+ return model
41
+
GVP/Baseline/environment.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Conda environment for this baseline (created with `conda env create -f environment.yml`).
name: RN
channels:
  - pytorch
  - nvidia
dependencies:
  - python >= 3.8
  - pytorch >= 1.13
  - torchvision
  - pytorch-cuda >=11.7
  - pip
  # Packages installed via pip (not available / not pinned from the conda channels):
  - pip:
    - timm
    - diffusers
    - accelerate
    - torchdiffeq
    - wandb
GVP/Baseline/evaluate_samples.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Execute all evaluation tasks in parallel.
# Each threshold gets its own evaluator.py run, pinned to one GPU via
# CUDA_VISIBLE_DEVICES; GPUs 0 and 1 are reused for the last two jobs.

echo "Starting all evaluation tasks in parallel..."

# Reference batch path
REF_BATCH="/gemini/space/zhaozy/zhy/dataset/VIRTUAL_imagenet256_labeled.npz"

# Base directory for sample files
SAMPLE_DIR="/gemini/space/zhaozy/zhy/gzy_new/Noise_Matching/Rectified-Noise/last_samples_depth_2_gvp_0.5"

# Change to the project root directory
cd /gemini/space/zhaozy/zhy/gzy_new/Noise_Matching

# Parallel arrays: threshold value -> GPU assignment.
THRESHOLDS=(0.0 0.15 0.25 0.5 0.75 1.0)
GPUS=(0 1 2 3 0 1)

for i in "${!THRESHOLDS[@]}"; do
    t="${THRESHOLDS[$i]}"
    gpu="${GPUS[$i]}"
    CUDA_VISIBLE_DEVICES=${gpu} nohup python evaluator.py \
        --ref_batch ${REF_BATCH} \
        --sample_batch ${SAMPLE_DIR}/depth-mu-2-threshold-${t}-0550000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04.npz \
        > eval_threshold_${t}.log 2>&1 &
done

# Wait for all background jobs to complete
echo "All evaluation tasks started. Waiting for completion..."
wait

echo "All evaluation tasks completed!"
echo ""
echo "Results saved in:"
for t in "${THRESHOLDS[@]}"; do
    echo " - eval_threshold_${t}.log"
done
GVP/Baseline/evaluator.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
5
+ import random
6
+ import warnings
7
+ import zipfile
8
+ from abc import ABC, abstractmethod
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+ from multiprocessing import cpu_count
12
+ from multiprocessing.pool import ThreadPool
13
+ from typing import Iterable, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import requests
17
+ import tensorflow.compat.v1 as tf
18
+ from scipy import linalg
19
+ from tqdm.auto import tqdm
20
+ from datetime import timedelta
21
+ import torch
22
+
23
+
24
+
25
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
26
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
27
+
28
+ FID_POOL_NAME = "pool_3:0"
29
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
30
+
31
+
32
def main():
    """Compute Inception Score, FID, sFID, and precision/recall for a sample
    batch of images against a reference batch (both .npz files)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", default='/gemini/space/gzy_new/models/reference/VIRTUAL_imagenet256_labeled.npz',help="path to reference batch npz file")
    parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Baseline/GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04.npz', help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    # Grow GPU memory on demand rather than grabbing it all at session start.
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    # acts[0] is the 2048-d pool feature; acts[1] the spatial feature (see compute_activations).
    print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
    print("FID:", sample_stats.frechet_distance(ref_stats))
    print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)
66
+
67
+
68
class InvalidFIDException(Exception):
    """Raised when an FID computation cannot produce a valid result."""
    pass
70
+
71
+
72
class FIDStatistics:
    """First and second moments (mean and covariance) of Inception pool
    activations for one batch of images."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu        # mean activation vector
        self.sigma = sigma  # activation covariance matrix

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1 = np.atleast_1d(self.mu)
        sigma1 = np.atleast_2d(self.sigma)
        mu2 = np.atleast_1d(other.mu)
        sigma2 = np.atleast_2d(other.sigma)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        diff = mu1 - mu2

        # The product of the covariances can be nearly singular; retry with a
        # small diagonal offset if sqrtm produced non-finite values.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a slight imaginary component; discard it
        # when negligible, otherwise fail loudly.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1):
                m = np.max(np.abs(covmean.imag))
                print(f"Real component: {covmean.real}")
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
123
+
124
+
125
class Evaluator:
    """Computes Inception activations, FID statistics, Inception Score, and
    precision/recall using a TF1 session holding the Inception graph."""

    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size                  # images per activation batch
        self.softmax_batch_size = softmax_batch_size  # pool features per softmax batch
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            # NHWC float images; 2048-d pool features feed the softmax head.
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            # _create_feature_graph / _create_softmax_graph are defined elsewhere in this file.
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        # A [1, 8, 64, 64, 3] array iterates as a single [8, 64, 64, 3] batch,
        # priming graph compilation before real work starts.
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: Union[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Compute (pool, spatial) activations for an npz file path or an in-memory array."""
        if isinstance(npz_path, str):
            # If npz_path is a string, treat it as a file path and read the .npz file
            with open_npz_array(npz_path, "arr_0") as reader:
                return self.compute_activations(reader.read_batches(self.batch_size))
        elif isinstance(npz_path, np.ndarray):
            # If npz_path is a numpy array, split it into batches manually
            print("--------line 140-----------")
            batches = np.array_split(npz_path, range(self.batch_size, npz_path.shape[0], self.batch_size))
            print("--------line 143-----------")
            return self.compute_activations(batches)
        else:
            raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)")


    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: a iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            # print("--------line 164-----------")

            # # Identify the current process info
            # if 'RANK' in os.environ:
            #     rank = int(os.environ['RANK'])
            #     local_rank = int(os.environ.get('LOCAL_RANK', rank % torch.cuda.device_count()))
            #     print(f"Distributed training - Global Rank: {rank}, Local Rank: {local_rank}")
            #     print(f"Current GPU device: {torch.cuda.current_device()}" if torch.cuda.is_available() else "No CUDA")
            # else:
            #     print("Single process mode")

            # print(f"Process PID: {os.getpid()}")

            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            # print("--------line 169-----------")
            # Flatten per-image features to one row each before accumulating.
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: Union[str, np.ndarray], activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """Return (pool, spatial) FID statistics: loaded from the npz when it
        carries precomputed "mu"/"sigma" arrays, otherwise derived from
        `activations`."""
        if isinstance(npz_path, str):
            obj = np.load(npz_path)
            if "mu" in list(obj.keys()):
                return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                    obj["mu_s"], obj["sigma_s"]
                )
        elif isinstance(npz_path, np.ndarray):
            obj = npz_path
        else:
            raise ValueError("npz_path must be either a file path (str) or a numpy array (np.ndarray)")
        # NOTE(review): when a str path lacks "mu", `obj` goes unused and we fall
        # through to computing statistics from the activations.
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        # Mean and covariance over the batch dimension (rows are samples).
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """Run the softmax head over pool activations and average the IS over splits."""
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            # KL divergence between each sample's class distribution and the split marginal.
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """Improved precision/recall between reference and sample activations."""
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        # pr is (precision array, recall array), one entry per neighborhood size.
        return (float(pr[0][0]), float(pr[1][0]))
240
+
241
+
242
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        # Run one tiny evaluate_pr pass so graph execution is primed.
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """Return [N x num_nhoods] k-NN radii: the hypersphere radius around each sample."""
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            # (_numpy_partition is a helper defined elsewhere in this file.)
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            # Zero out (prune) hyperspheres larger than the requested percentile.
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            # NOTE: the "realisim" spelling is a historical typo preserved for
            # backward compatibility with existing callers.
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # Fix: the deprecated alias np.bool was removed in NumPy 1.24; the builtin
        # `bool` is exactly what the alias stood for, so behavior is unchanged.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size  # may exceed len; numpy slicing clamps
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
397
+
398
+
399
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # First attempt the computation in float16 (faster, half the memory)...
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            # ...and fall back to full float32 if any fp16 entry overflowed to inf/nan.
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans: boolean membership of each point inside
            # the other batch's hyperspheres (consumed by evaluate_pr).
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        # Returns two boolean masks: rows of batch_1 falling inside radii_2's
        # spheres, and rows of batch_2 falling inside radii_1's spheres.
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
449
+
450
+
451
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.

    Returns an [|U| x |V|] tensor of squared Euclidean distances, clamped at 0
    to absorb tiny negative values caused by floating-point cancellation.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances via ||u||^2 - 2 u.v + ||v||^2.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

    return D
468
+
469
+
470
class NpzArrayReader(ABC):
    """Abstract reader that yields row batches from a (possibly huge) npz array."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return the next batch of at most `batch_size` rows, or None when exhausted."""
        pass

    @abstractmethod
    def remaining(self) -> int:
        """Number of rows not yet consumed."""
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Wrap the reader in a length-aware batch iterable (so tqdm shows progress)."""
        def generate():
            while (batch := self.read_batch(batch_size)) is not None:
                yield batch

        # Ceiling division: how many batches are left.
        num_batches = -(-self.remaining() // batch_size)
        return BatchIterator(generate, num_batches)
490
+
491
+
492
class BatchIterator:
    """An iterable with a known length, pairing a generator factory with a batch count."""

    def __init__(self, gen_fn, length):
        self.gen_fn = gen_fn    # zero-argument callable producing a fresh iterator
        self.length = length    # number of batches that iterator will yield

    def __iter__(self):
        # Delegate to the factory so each iteration starts from the beginning.
        return self.gen_fn()

    def __len__(self):
        return self.length
502
+
503
+
504
class StreamingNpzArrayReader(NpzArrayReader):
    """NpzArrayReader that streams rows directly from an open .npy file handle,
    avoiding loading the whole array into memory."""

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f   # open binary file positioned at the start of the array data
        self.shape = shape   # full array shape; shape[0] is the row count
        self.dtype = dtype   # numpy dtype of the stored elements
        self.idx = 0         # index of the first unread row

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Read up to `batch_size` rows from the file; None once exhausted."""
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        # Zero-itemsize dtypes carry no data bytes; synthesize the array directly.
        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        # _read_bytes is a helper defined elsewhere in this file -- presumably it
        # reads exactly read_size bytes or raises; confirm against full source.
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
528
+
529
+
530
+ class MemoryNpzArrayReader(NpzArrayReader):
531
+ def __init__(self, arr):
532
+ self.arr = arr
533
+ self.idx = 0
534
+
535
+ @classmethod
536
+ def load(cls, path: str, arr_name: str):
537
+ with open(path, "rb") as f:
538
+ arr = np.load(f)[arr_name]
539
+ return cls(arr)
540
+
541
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
542
+ if self.idx >= self.arr.shape[0]:
543
+ return None
544
+
545
+ res = self.arr[self.idx : self.idx + batch_size]
546
+ self.idx += batch_size
547
+ return res
548
+
549
+ def remaining(self) -> int:
550
+ return max(0, self.arr.shape[0] - self.idx)
551
+
552
+
553
+ @contextmanager
554
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
555
+ with _open_npy_file(path, arr_name) as arr_f:
556
+ version = np.lib.format.read_magic(arr_f)
557
+ if version == (1, 0):
558
+ header = np.lib.format.read_array_header_1_0(arr_f)
559
+ elif version == (2, 0):
560
+ header = np.lib.format.read_array_header_2_0(arr_f)
561
+ else:
562
+ yield MemoryNpzArrayReader.load(path, arr_name)
563
+ return
564
+ shape, fortran, dtype = header
565
+ if fortran or dtype.hasobject:
566
+ yield MemoryNpzArrayReader.load(path, arr_name)
567
+ else:
568
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
569
+
570
+
571
+ def _read_bytes(fp, size, error_template="ran out of data"):
572
+ """
573
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
574
+
575
+ Read from file-like object until size bytes are read.
576
+ Raises ValueError if not EOF is encountered before size bytes are read.
577
+ Non-blocking objects only supported if they derive from io objects.
578
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
579
+ requested.
580
+ """
581
+ data = bytes()
582
+ while True:
583
+ # io files (default in python3) return None or raise on
584
+ # would-block, python2 file will truncate, probably nothing can be
585
+ # done about that. note that regular files can't be non-blocking
586
+ try:
587
+ r = fp.read(size - len(data))
588
+ data += r
589
+ if len(r) == 0 or len(data) == size:
590
+ break
591
+ except io.BlockingIOError:
592
+ pass
593
+ if len(data) != size:
594
+ msg = "EOF: reading %s, expected %d bytes got %d"
595
+ raise ValueError(msg % (error_template, size, len(data)))
596
+ else:
597
+ return data
598
+
599
+
600
+ @contextmanager
601
+ def _open_npy_file(path: str, arr_name: str):
602
+ with open(path, "rb") as f:
603
+ with zipfile.ZipFile(f, "r") as zip_f:
604
+ if f"{arr_name}.npy" not in zip_f.namelist():
605
+ raise ValueError(f"missing {arr_name} in npz file")
606
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
607
+ yield arr_f
608
+
609
+
610
+ def _download_inception_model():
611
+ if os.path.exists(INCEPTION_V3_PATH):
612
+ return
613
+ print("downloading InceptionV3 model...")
614
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
615
+ r.raise_for_status()
616
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
617
+ with open(tmp_path, "wb") as f:
618
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
619
+ f.write(chunk)
620
+ os.rename(tmp_path, INCEPTION_V3_PATH)
621
+
622
+
623
+ def _create_feature_graph(input_batch):
624
+ _download_inception_model()
625
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
626
+ with open(INCEPTION_V3_PATH, "rb") as f:
627
+ graph_def = tf.GraphDef()
628
+ graph_def.ParseFromString(f.read())
629
+ pool3, spatial = tf.import_graph_def(
630
+ graph_def,
631
+ input_map={f"ExpandDims:0": input_batch},
632
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
633
+ name=prefix,
634
+ )
635
+ _update_shapes(pool3)
636
+ spatial = spatial[..., :7]
637
+ return pool3, spatial
638
+
639
+
640
+ def _create_softmax_graph(input_batch):
641
+ _download_inception_model()
642
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
643
+ with open(INCEPTION_V3_PATH, "rb") as f:
644
+ graph_def = tf.GraphDef()
645
+ graph_def.ParseFromString(f.read())
646
+ (matmul,) = tf.import_graph_def(
647
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
648
+ )
649
+ w = matmul.inputs[1]
650
+ logits = tf.matmul(input_batch, w)
651
+ return tf.nn.softmax(logits)
652
+
653
+
654
+ def _update_shapes(pool3):
655
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
656
+ ops = pool3.graph.get_operations()
657
+ for op in ops:
658
+ for o in op.outputs:
659
+ shape = o.get_shape()
660
+ if shape._dims is not None: # pylint: disable=protected-access
661
+ # shape = [s.value for s in shape] TF 1.x
662
+ shape = [s for s in shape] # TF 2.x
663
+ new_shape = []
664
+ for j, s in enumerate(shape):
665
+ if s == 1 and j == 0:
666
+ new_shape.append(None)
667
+ else:
668
+ new_shape.append(s)
669
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
670
+ return pool3
671
+
672
+
673
+ def _numpy_partition(arr, kth, **kwargs):
674
+ num_workers = min(cpu_count(), len(arr))
675
+ chunk_size = len(arr) // num_workers
676
+ extra = len(arr) % num_workers
677
+
678
+ start_idx = 0
679
+ batches = []
680
+ for i in range(num_workers):
681
+ size = chunk_size + (1 if i < extra else 0)
682
+ batches.append(arr[start_idx : start_idx + size])
683
+ start_idx += size
684
+
685
+ with ThreadPool(num_workers) as pool:
686
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
687
+
688
+
689
+ if __name__ == "__main__":
690
+ main()
GVP/Baseline/gvp_sampling.log ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/1042 [00:00<?, ?it/s]
1
  0%| | 1/1042 [00:59<17:12:21, 59.50s/it]
2
  0%| | 2/1042 [01:58<17:01:08, 58.91s/it]
3
  0%| | 3/1042 [02:57<17:02:36, 59.05s/it]
4
  0%| | 4/1042 [03:56<17:00:31, 58.99s/it]
5
  0%| | 5/1042 [04:52<16:42:43, 58.02s/it]
6
  1%| | 6/1042 [05:51<16:47:28, 58.35s/it]
7
  1%| | 7/1042 [06:50<16:53:25, 58.75s/it]
8
  1%| | 8/1042 [07:49<16:49:35, 58.58s/it]
9
  1%| | 9/1042 [08:48<16:51:57, 58.78s/it]
10
  1%| | 10/1042 [09:46<16:49:39, 58.70s/it]
11
  1%| | 11/1042 [10:45<16:48:10, 58.67s/it]
12
  1%| | 12/1042 [11:45<16:53:10, 59.02s/it]
13
  1%| | 13/1042 [12:44<16:52:36, 59.04s/it]
14
  1%|▏ | 14/1042 [13:43<16:50:11, 58.96s/it]W0407 16:34:53.638000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2845 closing signal SIGTERM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793]
2
+ W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W0407 16:17:44.645000 2760 site-packages/torch/distributed/run.py:793] *****************************************
5
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
6
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
7
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
8
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
9
+ Starting rank=0, seed=0, world_size=4.
10
+ Starting rank=2, seed=2, world_size=4.
11
+ Starting rank=1, seed=1, world_size=4.
12
+ Starting rank=3, seed=3, world_size=4.
13
+ [rank1]:[W407 16:20:17.131912166 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
14
+ [rank3]:[W407 16:20:17.153628536 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
15
+ Saving .png samples at baseline_gvp_/SiT-XL-2-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04
16
+ [rank0]:[W407 16:20:17.306737681 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
17
+ [rank2]:[W407 16:20:18.780347929 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
18
+ Total number of images that will be sampled: 50016
19
+
20
  0%| | 0/1042 [00:00<?, ?it/s]
21
  0%| | 1/1042 [00:59<17:12:21, 59.50s/it]
22
  0%| | 2/1042 [01:58<17:01:08, 58.91s/it]
23
  0%| | 3/1042 [02:57<17:02:36, 59.05s/it]
24
  0%| | 4/1042 [03:56<17:00:31, 58.99s/it]
25
  0%| | 5/1042 [04:52<16:42:43, 58.02s/it]
26
  1%| | 6/1042 [05:51<16:47:28, 58.35s/it]
27
  1%| | 7/1042 [06:50<16:53:25, 58.75s/it]
28
  1%| | 8/1042 [07:49<16:49:35, 58.58s/it]
29
  1%| | 9/1042 [08:48<16:51:57, 58.78s/it]
30
  1%| | 10/1042 [09:46<16:49:39, 58.70s/it]
31
  1%| | 11/1042 [10:45<16:48:10, 58.67s/it]
32
  1%| | 12/1042 [11:45<16:53:10, 59.02s/it]
33
  1%| | 13/1042 [12:44<16:52:36, 59.04s/it]
34
  1%|▏ | 14/1042 [13:43<16:50:11, 58.96s/it]W0407 16:34:53.638000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2845 closing signal SIGTERM
35
+ W0407 16:34:53.639000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2847 closing signal SIGTERM
36
+ W0407 16:34:53.639000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2848 closing signal SIGTERM
37
+ E0407 16:34:53.854000 2760 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -9) local_rank: 1 (pid: 2846) of binary: /root/miniconda3/envs/SiT/bin/python3.10
38
+ Traceback (most recent call last):
39
+ File "/root/miniconda3/envs/SiT/bin/torchrun", line 6, in <module>
40
+ sys.exit(main())
41
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
42
+ return f(*args, **kwargs)
43
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
44
+ run(args)
45
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
46
+ elastic_launch(
47
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
48
+ return launch_agent(self._config, self._entrypoint, list(args))
49
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
50
+ raise ChildFailedError(
51
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
52
+ ==========================================================
53
+ sample_ddp.py FAILED
54
+ ----------------------------------------------------------
55
+ Failures:
56
+ <NO_OTHER_FAILURES>
57
+ ----------------------------------------------------------
58
+ Root Cause (first observed failure):
59
+ [0]:
60
+ time : 2026-04-07_16:34:53
61
+ host : 280c8972fe62c4ab251b3c74bd05a546-taskrole1-0
62
+ rank : 1 (local_rank: 1)
63
+ exitcode : -9 (pid: 2846)
64
+ error_file: <N/A>
65
+ traceback : Signal 9 (SIGKILL) received by PID 2846
66
+ ==========================================================
GVP/Baseline/models.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import math
13
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14
+
15
+
16
+ def modulate(x, shift, scale):
17
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
18
+
19
+
20
+ #################################################################################
21
+ # Embedding Layers for Timesteps and Class Labels #
22
+ #################################################################################
23
+
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def timestep_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ class LabelEmbedder(nn.Module):
65
+ """
66
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
67
+ """
68
+ def __init__(self, num_classes, hidden_size, dropout_prob):
69
+ super().__init__()
70
+ use_cfg_embedding = dropout_prob > 0
71
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
72
+ self.num_classes = num_classes
73
+ self.dropout_prob = dropout_prob
74
+
75
+ def token_drop(self, labels, force_drop_ids=None):
76
+ """
77
+ Drops labels to enable classifier-free guidance.
78
+ """
79
+ if force_drop_ids is None:
80
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
81
+ else:
82
+ drop_ids = force_drop_ids == 1
83
+ labels = torch.where(drop_ids, self.num_classes, labels)
84
+ return labels
85
+
86
+ def forward(self, labels, train, force_drop_ids=None):
87
+ use_dropout = self.dropout_prob > 0
88
+ if (train and use_dropout) or (force_drop_ids is not None):
89
+ labels = self.token_drop(labels, force_drop_ids)
90
+ embeddings = self.embedding_table(labels)
91
+ return embeddings
92
+
93
+
94
+ #################################################################################
95
+ # Core SiT Model #
96
+ #################################################################################
97
+
98
+ class SiTBlock(nn.Module):
99
+ """
100
+ A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
101
+ """
102
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
103
+ super().__init__()
104
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
105
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
106
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
107
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
108
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
109
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
110
+ self.adaLN_modulation = nn.Sequential(
111
+ nn.SiLU(),
112
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
113
+ )
114
+
115
+ def forward(self, x, c):
116
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
117
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
118
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
119
+ return x
120
+
121
+
122
+ class FinalLayer(nn.Module):
123
+ """
124
+ The final layer of SiT.
125
+ """
126
+ def __init__(self, hidden_size, patch_size, out_channels):
127
+ super().__init__()
128
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
129
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
130
+ self.adaLN_modulation = nn.Sequential(
131
+ nn.SiLU(),
132
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
133
+ )
134
+
135
+ def forward(self, x, c):
136
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
137
+ x = modulate(self.norm_final(x), shift, scale)
138
+ x = self.linear(x)
139
+ return x
140
+
141
+
142
+ class SiT(nn.Module):
143
+ """
144
+ Diffusion model with a Transformer backbone.
145
+ """
146
+ def __init__(
147
+ self,
148
+ input_size=32,
149
+ patch_size=2,
150
+ in_channels=4,
151
+ hidden_size=1152,
152
+ depth=28,
153
+ num_heads=16,
154
+ mlp_ratio=4.0,
155
+ class_dropout_prob=0.1,
156
+ num_classes=1000,
157
+ learn_sigma=True,
158
+ ):
159
+ super().__init__()
160
+ self.learn_sigma = learn_sigma
161
+ self.learn_sigma = True
162
+ self.in_channels = in_channels
163
+ self.out_channels = in_channels * 2
164
+ self.patch_size = patch_size
165
+ self.num_heads = num_heads
166
+
167
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
168
+ self.t_embedder = TimestepEmbedder(hidden_size)
169
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
170
+ num_patches = self.x_embedder.num_patches
171
+ # Will use fixed sin-cos embedding:
172
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
173
+
174
+ self.blocks = nn.ModuleList([
175
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
176
+ ])
177
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
178
+ self.initialize_weights()
179
+
180
+ def initialize_weights(self):
181
+ # Initialize transformer layers:
182
+ def _basic_init(module):
183
+ if isinstance(module, nn.Linear):
184
+ torch.nn.init.xavier_uniform_(module.weight)
185
+ if module.bias is not None:
186
+ nn.init.constant_(module.bias, 0)
187
+ self.apply(_basic_init)
188
+
189
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
190
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
191
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
192
+
193
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
194
+ w = self.x_embedder.proj.weight.data
195
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
196
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
197
+
198
+ # Initialize label embedding table:
199
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
200
+
201
+ # Initialize timestep embedding MLP:
202
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
203
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
204
+
205
+ # Zero-out adaLN modulation layers in SiT blocks:
206
+ for block in self.blocks:
207
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
208
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
209
+
210
+ # Zero-out output layers:
211
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
213
+ nn.init.constant_(self.final_layer.linear.weight, 0)
214
+ nn.init.constant_(self.final_layer.linear.bias, 0)
215
+
216
+ def unpatchify(self, x):
217
+ """
218
+ x: (N, T, patch_size**2 * C)
219
+ imgs: (N, H, W, C)
220
+ """
221
+ c = self.out_channels
222
+ p = self.x_embedder.patch_size[0]
223
+ h = w = int(x.shape[1] ** 0.5)
224
+ assert h * w == x.shape[1]
225
+
226
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
229
+ return imgs
230
+
231
+ def forward(self, x, t, y, return_act=False):
232
+ """
233
+ Forward pass of SiT.
234
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
235
+ t: (N,) tensor of diffusion timesteps
236
+ y: (N,) tensor of class labels
237
+ return_act: if True, return activations from transformer blocks
238
+ """
239
+ act = []
240
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
241
+ t = self.t_embedder(t) # (N, D)
242
+ y = self.y_embedder(y, self.training) # (N, D)
243
+ c = t + y # (N, D)
244
+ for block in self.blocks:
245
+ x = block(x, c) # (N, T, D)
246
+ if return_act:
247
+ act.append(x)
248
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249
+ x = self.unpatchify(x) # (N, out_channels, H, W)
250
+ if self.learn_sigma:
251
+ x, _ = x.chunk(2, dim=1)
252
+ if return_act:
253
+ return x, act
254
+ return x
255
+
256
+ def forward_with_cfg(self, x, t, y, cfg_scale):
257
+ """
258
+ Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance.
259
+ """
260
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
261
+ half = x[: len(x) // 2]
262
+ combined = torch.cat([half, half], dim=0)
263
+ model_out = self.forward(combined, t, y)
264
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
265
+ # three channels by default. The standard approach to cfg applies it to all channels.
266
+ # This can be done by uncommenting the following line and commenting-out the line following that.
267
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
268
+ eps, rest = model_out[:, :3], model_out[:, 3:]
269
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
270
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
271
+ eps = torch.cat([half_eps, half_eps], dim=0)
272
+ return torch.cat([eps, rest], dim=1)
273
+
274
+
275
+ #################################################################################
276
+ # Sine/Cosine Positional Embedding Functions #
277
+ #################################################################################
278
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
279
+
280
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
281
+ """
282
+ grid_size: int of the grid height and width
283
+ return:
284
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
285
+ """
286
+ grid_h = np.arange(grid_size, dtype=np.float32)
287
+ grid_w = np.arange(grid_size, dtype=np.float32)
288
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
289
+ grid = np.stack(grid, axis=0)
290
+
291
+ grid = grid.reshape([2, 1, grid_size, grid_size])
292
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
293
+ if cls_token and extra_tokens > 0:
294
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
295
+ return pos_embed
296
+
297
+
298
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
299
+ assert embed_dim % 2 == 0
300
+
301
+ # use half of dimensions to encode grid_h
302
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
303
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
304
+
305
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
306
+ return emb
307
+
308
+
309
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
310
+ """
311
+ embed_dim: output dimension for each position
312
+ pos: a list of positions to be encoded: size (M,)
313
+ out: (M, D)
314
+ """
315
+ assert embed_dim % 2 == 0
316
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
317
+ omega /= embed_dim / 2.
318
+ omega = 1. / 10000**omega # (D/2,)
319
+
320
+ pos = pos.reshape(-1) # (M,)
321
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
322
+
323
+ emb_sin = np.sin(out) # (M, D/2)
324
+ emb_cos = np.cos(out) # (M, D/2)
325
+
326
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
327
+ return emb
328
+
329
+
330
+ #################################################################################
331
+ # SiT Configs #
332
+ #################################################################################
333
+
334
+ def SiT_XL_2(**kwargs):
335
+ return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
336
+
337
+ def SiT_XL_4(**kwargs):
338
+ return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
339
+
340
+ def SiT_XL_8(**kwargs):
341
+ return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
342
+
343
+ def SiT_L_2(**kwargs):
344
+ return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
345
+
346
+ def SiT_L_4(**kwargs):
347
+ return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
348
+
349
+ def SiT_L_8(**kwargs):
350
+ return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
351
+
352
+ def SiT_B_2(**kwargs):
353
+ return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
354
+
355
+ def SiT_B_4(**kwargs):
356
+ return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
357
+
358
+ def SiT_B_8(**kwargs):
359
+ return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
360
+
361
+ def SiT_S_2(**kwargs):
362
+ return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
363
+
364
+ def SiT_S_4(**kwargs):
365
+ return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
366
+
367
+ def SiT_S_8(**kwargs):
368
+ return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
369
+
370
+
371
+ SiT_models = {
372
+ 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
373
+ 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
374
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
375
+ 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
376
+ }
377
+
378
+ #################################################################################
379
+ # SiTF1, SiTF2, CombinedModel #
380
+ #################################################################################
381
+
382
+ class SiTF1(nn.Module):
383
+ """
384
+ SiTF1 Model
385
+ """
386
+ def __init__(
387
+ self,
388
+ input_size=32,
389
+ patch_size=2,
390
+ in_channels=4,
391
+ hidden_size=1152,
392
+ depth=28,
393
+ num_heads=16,
394
+ mlp_ratio=4.0,
395
+ class_dropout_prob=0.1,
396
+ num_classes=1000,
397
+ learn_sigma=True,
398
+ final_layer=None,
399
+ ):
400
+ super().__init__()
401
+ self.input_size = input_size
402
+ self.patch_size= patch_size
403
+ self.hidden_size = hidden_size
404
+ self.in_channels = in_channels
405
+ self.out_channels = in_channels * 2
406
+ self.patch_size = patch_size
407
+ self.num_heads = num_heads
408
+ self.learn_sigma = learn_sigma
409
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
410
+ self.t_embedder = TimestepEmbedder(hidden_size)
411
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
412
+ num_patches = self.x_embedder.num_patches
413
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
414
+ self.blocks = nn.ModuleList([
415
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
416
+ ])
417
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
418
+ self.initialize_weights()
419
+
420
+ def unpatchify(self, x):
421
+ """
422
+ x: (N, T, patch_size**2 * C)
423
+ imgs: (N, H, W, C)
424
+ """
425
+ c = self.out_channels
426
+ p = self.x_embedder.patch_size[0]
427
+ h = w = int(x.shape[1] ** 0.5)
428
+ assert h * w == x.shape[1]
429
+
430
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
431
+ x = torch.einsum('nhwpqc->nchpwq', x)
432
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
433
+ return imgs
434
+
435
+ def initialize_weights(self):
436
+ def _basic_init(module):
437
+ if isinstance(module, nn.Linear):
438
+ torch.nn.init.xavier_uniform_(module.weight)
439
+ if module.bias is not None:
440
+ nn.init.constant_(module.bias, 0)
441
+ self.apply(_basic_init)
442
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
443
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
444
+ w = self.x_embedder.proj.weight.data
445
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
446
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
447
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
448
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
449
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
450
+ for block in self.blocks:
451
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
452
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
453
+
454
+ def forward(self, x, t, y):
455
+ x = self.x_embedder(x) + self.pos_embed
456
+ t = self.t_embedder(t)
457
+ y = self.y_embedder(y, self.training)
458
+ c = t + y
459
+ for block in self.blocks:
460
+ x = block(x, c)
461
+ x_now = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
462
+ x_now = self.unpatchify(x_now) # (N, out_channels, H, W)
463
+ x_now, _ = x_now.chunk(2, dim=1)
464
+ return x,x_now # patch token (N, T, D)
465
+
466
+ def forward_with_cfg(self, x, t, y, cfg_scale):
467
+ """
468
+ Forward pass with classifier-free guidance for SiTF1.
469
+ Applies guidance consistently to both patch tokens and image output (x_now).
470
+ """
471
+ # Take the first half (conditional inputs) and duplicate it so that
472
+ # it can be paired with conditional and unconditional labels in `y`.
473
+ half = x[: len(x) // 2]
474
+ combined = torch.cat([half, half], dim=0)
475
+ patch_tokens, x_now = self.forward(combined, t, y)
476
+
477
+ # Apply CFG on the image output channels (first 3 channels by default)
478
+ eps, rest = x_now[:, :3, ...], x_now[:, 3:, ...]
479
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
480
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
481
+ eps = torch.cat([half_eps, half_eps], dim=0)
482
+ x_now = torch.cat([eps, rest], dim=1)
483
+
484
+ # Apply same guidance logic to patch tokens so downstream modules see
485
+ # a consistent guided representation.
486
+ cond_tok, uncond_tok = torch.split(patch_tokens, len(patch_tokens) // 2, dim=0)
487
+ half_tok = uncond_tok + cfg_scale * (cond_tok - uncond_tok)
488
+ patch_tokens = torch.cat([half_tok, half_tok], dim=0)
489
+
490
+ return patch_tokens, x_now
491
+
492
+
493
+ class SiTF2(nn.Module):
494
+ """
495
+ SiTF2:
496
+ """
497
+ def __init__(
498
+ self,
499
+ input_size=32,
500
+ hidden_size=1152,
501
+ out_channels=8,
502
+ patch_size=2,
503
+ num_heads=16,
504
+ mlp_ratio=4.0,
505
+ depth=4,
506
+ learn_sigma=True,
507
+ final_layer=None,
508
+ num_classes=1000,
509
+ class_dropout_prob=0.1,
510
+ learn_mu=False,
511
+ ):
512
+ super().__init__()
513
+ self.learn_sigma = learn_sigma
514
+ self.learn_mu = learn_mu
515
+ self.out_channels = out_channels
516
+ self.in_channels = 4
517
+ self.patch_size = patch_size
518
+ self.num_heads = num_heads
519
+ self.blocks = nn.ModuleList([
520
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
521
+ ])
522
+ self.x_embedder = PatchEmbed(input_size, patch_size, self.in_channels, hidden_size, bias=True)
523
+ self.t_embedder = TimestepEmbedder(hidden_size)
524
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
525
+ num_patches = self.x_embedder.num_patches
526
+ self.num_patches = num_patches # Save original num_patches for unpatchify
527
+ # pos_embed needs to support 2*num_patches for concatenated input
528
+ self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
529
+ # Initialize pos_embed with sin-cos embedding
530
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
531
+ # Repeat the pos_embed for both halves (or could use different embeddings)
532
+ pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
533
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))
534
+
535
+ if final_layer is not None:
536
+ self.final_layer = final_layer
537
+ else:
538
+ self.final_layer = FinalLayer(hidden_size, patch_size, out_channels)
539
+ if depth !=0:
540
+ for p in self.final_layer.parameters():
541
+ if p is not None:
542
+ torch.nn.init.constant_(p, 0)
543
+
544
+ def unpatchify(self, x, patch_size, out_channels):
545
+ c = out_channels
546
+ p = patch_size
547
+ # x.shape[1] might be 2*num_patches when using concatenated input
548
+ # Use original num_patches to calculate h and w
549
+ h = w = int(self.num_patches ** 0.5)
550
+ # If input has 2*num_patches, we need to handle it
551
+ if x.shape[1] == 2 * self.num_patches:
552
+ # Take only the first half (or average, or other strategy)
553
+ # For now, we'll take the first half
554
+ x = x[:, :self.num_patches, :]
555
+ assert h * w == x.shape[1], f"Expected {h * w} patches, got {x.shape[1]}"
556
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
557
+ x = torch.einsum('nhwpqc->nchpwq', x)
558
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
559
+ return imgs
560
+
561
+ def forward(self, x, c, t, return_act=False):
562
+ act = []
563
+ for block in self.blocks:
564
+ x = block(x, c)
565
+ if return_act:
566
+ act.append(x)
567
+ x = self.final_layer(x, c)
568
+ x = self.unpatchify(x, self.patch_size, self.out_channels)
569
+ if self.learn_sigma:
570
+ mean_pred, log_var_pred = x.chunk(2, dim=1)
571
+ variance_pred = torch.exp(log_var_pred)
572
+ std_dev_pred = torch.sqrt(variance_pred)
573
+ noise = torch.randn_like(mean_pred)
574
+ #uniform_noise = torch.rand_like(mean_pred)
575
+ #uniform_noise = uniform_noise.clamp(min=1e-5, max=1-1e-5)
576
+ #gumbel_noise = -torch.log(-torch.log(uniform_noise))
577
+
578
+ if self.learn_mu==True:
579
+ resampled_x = mean_pred + std_dev_pred * noise
580
+ else:
581
+ resampled_x = std_dev_pred * noise
582
+ x = resampled_x
583
+ else:
584
+ x, _ = x.chunk(2, dim=1)
585
+ if return_act:
586
+ return x, act
587
+ return x
588
+
589
+ def forward_noise(self, x, c):
590
+ for block in self.blocks:
591
+ x = block(x, c)
592
+ x = self.final_layer(x, c)
593
+ x = self.unpatchify(x, self.patch_size, self.out_channels)
594
+ if self.learn_sigma:
595
+ mean_pred, log_var_pred = x.chunk(2, dim=1)
596
+ variance_pred = torch.exp(log_var_pred)
597
+ std_dev_pred = torch.sqrt(variance_pred)
598
+ noise = torch.randn_like(mean_pred)
599
+ if self.learn_mu==True:
600
+ resampled_x = mean_pred + std_dev_pred * noise
601
+ else:
602
+ resampled_x = std_dev_pred * noise
603
+ x = resampled_x
604
+ else:
605
+ x, _ = x.chunk(2, dim=1)
606
+ return x
607
+
608
+ #有两种写法,一种是拿理想的,一种是拿真实的,一种是拼接,一种是加和
609
+ class CombinedModel(nn.Module):
610
+ """
611
+ CombinedModel。
612
+ """
613
+ def __init__(self, sitf1: SiTF1, sitf2: SiTF2):
614
+ super().__init__()
615
+ self.sitf1 = sitf1
616
+ self.sitf2 = sitf2
617
+ input_size=self.sitf1.input_size
618
+ patch_size=self.sitf1.patch_size
619
+ hidden_size=self.sitf1.hidden_size
620
+ self.x_embedder = PatchEmbed(input_size, patch_size, 4, hidden_size, bias=True)
621
+ num_patches = self.x_embedder.num_patches
622
+ # pos_embed needs to support 2*num_patches for concatenated input
623
+ self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
624
+ # Initialize pos_embed with sin-cos embedding
625
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
626
+ # Repeat the pos_embed for both halves (or could use different embeddings)
627
+ pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
628
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))
629
+
630
+ def forward(self, x, t, y, return_act=False):
631
+ patch_tokens,x_now = self.sitf1(x, t, y)
632
+ # Interpolate between x_now and x using timestep t: (1-t)*x_now + t*x
633
+ # t shape is (N,), need to broadcast to (N, 1, 1, 1) for broadcasting with image (N, C, H, W)
634
+ t_broadcast = t.view(-1, 1, 1, 1) # (N, 1, 1, 1)
635
+ # Compute interpolated input: (1-t)*x_now + t*x
636
+ x_interpolated = (1 - t_broadcast) * x_now + x
637
+ # Convert interpolated input (image format) back to patch token format (without pos_embed, will add later)
638
+ x_now_patches = self.x_embedder(x_interpolated)
639
+ # Concatenate patch_tokens and x_now_patches along the sequence dimension
640
+ concatenated_input = torch.cat([patch_tokens, x_now_patches], dim=1) # (N, 2*T, D)
641
+ # Add position embedding for the concatenated input
642
+ # Use the same pos_embed for both halves (or could use different embeddings)
643
+ concatenated_input = concatenated_input + self.pos_embed
644
+ t_emb = self.sitf1.t_embedder(t)
645
+ y_emb = self.sitf1.y_embedder(y, self.training)
646
+ c = t_emb + y_emb
647
+ return self.sitf2(concatenated_input, c, t, return_act=return_act)
GVP/Baseline/nohup.out ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/1042 [00:00<?, ?it/s]
1
  0%| | 1/1042 [00:13<4:01:42, 13.93s/it]
2
  0%| | 2/1042 [00:26<3:44:31, 12.95s/it]
3
  0%| | 3/1042 [00:39<3:43:49, 12.93s/it]W0317 10:30:53.664000 11774 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received Signals.SIGINT death signal, shutting down workers
 
 
 
 
4
  0%| | 3/1042 [00:39<3:46:17, 13.07s/it]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793]
2
+ W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W0317 10:27:10.803000 11774 site-packages/torch/distributed/run.py:793] *****************************************
5
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
6
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
7
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
8
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
9
+ Starting rank=0, seed=0, world_size=4.
10
+ Starting rank=1, seed=1, world_size=4.
11
+ Starting rank=3, seed=3, world_size=4.
12
+ Starting rank=2, seed=2, world_size=4.
13
+ Saving .png samples at GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-100-Euler-sigma-Mean-0.04
14
+ Total number of images that will be sampled: 50016
15
+
16
  0%| | 0/1042 [00:00<?, ?it/s]
17
  0%| | 1/1042 [00:13<4:01:42, 13.93s/it]
18
  0%| | 2/1042 [00:26<3:44:31, 12.95s/it]
19
  0%| | 3/1042 [00:39<3:43:49, 12.93s/it]W0317 10:30:53.664000 11774 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received Signals.SIGINT death signal, shutting down workers
20
+ W0317 10:30:53.667000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11854 closing signal SIGINT
21
+ W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11855 closing signal SIGINT
22
+ W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11856 closing signal SIGINT
23
+
24
  0%| | 3/1042 [00:39<3:46:17, 13.07s/it]
25
+ W0317 10:30:53.668000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11857 closing signal SIGINT
26
+ [rank3]: Traceback (most recent call last):
27
+ [rank3]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
28
+ [rank3]: main(mode, args)
29
+ [rank3]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
30
+ [rank3]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
31
+ [rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
32
+ [rank3]: xs = _sde.sample(init, model, **model_kwargs)
33
+ [rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
34
+ [rank3]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
35
+ [rank3]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 30, in __Euler_Maruyama_step
36
+ [rank3]: w_cur = th.randn(x.size()).to(x)
37
+ [rank3]: KeyboardInterrupt
38
+ [rank0]: Traceback (most recent call last):
39
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
40
+ [rank0]: main(mode, args)
41
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
42
+ [rank0]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
43
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
44
+ [rank0]: xs = _sde.sample(init, model, **model_kwargs)
45
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
46
+ [rank0]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
47
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
48
+ [rank0]: drift = self.drift(x, t, model, **model_kwargs)
49
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
50
+ [rank0]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
51
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 247, in body_fn
52
+ [rank0]: model_output = drift_fn(x, t, model, **model_kwargs)
53
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 236, in velocity_ode
54
+ [rank0]: model_output = model(x, t, **model_kwargs)
55
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
56
+ [rank0]: sit_out = base_model.forward(x, t, y)
57
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 245, in forward
58
+ [rank0]: x = block(x, c) # (N, T, D)
59
+ [rank0]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
60
+ [rank0]: return self._call_impl(*args, **kwargs)
61
+ [rank0]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
62
+ [rank0]: return forward_call(*args, **kwargs)
63
+ [rank0]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 117, in forward
64
+ [rank0]: x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
65
+ [rank0]: KeyboardInterrupt
66
+ [rank1]: Traceback (most recent call last):
67
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
68
+ [rank1]: main(mode, args)
69
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
70
+ [rank1]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
71
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
72
+ [rank1]: xs = _sde.sample(init, model, **model_kwargs)
73
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
74
+ [rank1]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
75
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
76
+ [rank1]: drift = self.drift(x, t, model, **model_kwargs)
77
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
78
+ [rank1]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
79
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 264, in <lambda>
80
+ [rank1]: score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
81
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
82
+ [rank1]: sit_out = base_model.forward(x, t, y)
83
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 241, in forward
84
+ [rank1]: t = self.t_embedder(t) # (N, D)
85
+ [rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
86
+ [rank1]: return self._call_impl(*args, **kwargs)
87
+ [rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
88
+ [rank1]: return forward_call(*args, **kwargs)
89
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 59, in forward
90
+ [rank1]: t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
91
+ [rank1]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 49, in timestep_embedding
92
+ [rank1]: freqs = torch.exp(
93
+ [rank1]: KeyboardInterrupt
94
+ [rank2]: Traceback (most recent call last):
95
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 380, in <module>
96
+ [rank2]: main(mode, args)
97
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 312, in main
98
+ [rank2]: samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
99
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 388, in _sample
100
+ [rank2]: xs = _sde.sample(init, model, **model_kwargs)
101
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 72, in sample
102
+ [rank2]: x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
103
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/integrators.py", line 33, in __Euler_Maruyama_step
104
+ [rank2]: drift = self.drift(x, t, model, **model_kwargs)
105
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 299, in <lambda>
106
+ [rank2]: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
107
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/transport/transport.py", line 264, in <lambda>
108
+ [rank2]: score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
109
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/sample_rectified_noise.py", line 194, in combined_sampling_model
110
+ [rank2]: sit_out = base_model.forward(x, t, y)
111
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 241, in forward
112
+ [rank2]: t = self.t_embedder(t) # (N, D)
113
+ [rank2]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
114
+ [rank2]: return self._call_impl(*args, **kwargs)
115
+ [rank2]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
116
+ [rank2]: return forward_call(*args, **kwargs)
117
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 59, in forward
118
+ [rank2]: t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
119
+ [rank2]: File "/gemini/space/gzy_new/models/Baseline/models.py", line 49, in timestep_embedding
120
+ [rank2]: freqs = torch.exp(
121
+ [rank2]: KeyboardInterrupt
122
+ W0317 10:30:53.820000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11854 closing signal SIGTERM
123
+ W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11855 closing signal SIGTERM
124
+ W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11856 closing signal SIGTERM
125
+ W0317 10:30:53.895000 11774 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 11857 closing signal SIGTERM
126
+ Traceback (most recent call last):
127
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
128
+ result = self._invoke_run(role)
129
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
130
+ time.sleep(monitor_interval)
131
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
132
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
133
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
134
+
135
+ During handling of the above exception, another exception occurred:
136
+
137
+ Traceback (most recent call last):
138
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 705, in run
139
+ self._shutdown(e.sigval)
140
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
141
+ self._pcontext.close(death_sig)
142
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
143
+ self._close(death_sig=death_sig, timeout=timeout)
144
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
145
+ handler.proc.wait(time_to_wait)
146
+ File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1209, in wait
147
+ return self._wait(timeout=timeout)
148
+ File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1953, in _wait
149
+ time.sleep(delay)
150
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
151
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
152
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
153
+
154
+ During handling of the above exception, another exception occurred:
155
+
156
+ Traceback (most recent call last):
157
+ File "/root/miniconda3/envs/SiT/bin/torchrun", line 6, in <module>
158
+ sys.exit(main())
159
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
160
+ return f(*args, **kwargs)
161
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
162
+ run(args)
163
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
164
+ elastic_launch(
165
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
166
+ return launch_agent(self._config, self._entrypoint, list(args))
167
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
168
+ result = agent.run()
169
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
170
+ result = f(*args, **kwargs)
171
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py", line 710, in run
172
+ self._shutdown()
173
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
174
+ self._pcontext.close(death_sig)
175
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
176
+ self._close(death_sig=death_sig, timeout=timeout)
177
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
178
+ handler.proc.wait(time_to_wait)
179
+ File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1209, in wait
180
+ return self._wait(timeout=timeout)
181
+ File "/root/miniconda3/envs/SiT/lib/python3.10/subprocess.py", line 1953, in _wait
182
+ time.sleep(delay)
183
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
184
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
185
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 11774 got signal: 2
GVP/Baseline/pic_npz.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
4
+ 基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
5
+ 支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
6
+ 支持从 metadata.jsonl 文件读取图片路径
7
+ """
8
+
9
+ import os
10
+ import argparse
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ import glob
15
+ import json
16
+
17
+
18
+ def create_npz_from_metadata(*args, **kwargs):
19
+ """
20
+ 占位函数:已废弃 metadata.jsonl 功能,保留空壳避免旧脚本导入时报错。
21
+ """
22
+ raise RuntimeError("metadata.jsonl 功能已移除,请仅使用 --image_folder 方式生成 npz。")
23
+
24
+
25
+
26
+ def main():
27
+ """
28
+ 主函数:解析命令行参数并执行图片到npz的转换
29
+ """
30
+ parser = argparse.ArgumentParser(
31
+ description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
32
+ formatter_class=argparse.RawDescriptionHelpFormatter,
33
+ epilog="""
34
+ 使用示例:
35
+ python pic_npz.py /path/to/image/folder
36
+ python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
37
+ """
38
+ )
39
+
40
+ parser.add_argument(
41
+ "--image_folder",
42
+ type=str,
43
+ default="/gemini/space/gzy_new/models/Baseline/GVP_samples/depth-mu-6-0300000-base-cfg-1.0-12-SDE-250-Euler-sigma-Mean-0.04",
44
+ help="包含PNG或JPG图片文件的文件夹路径"
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--output-dir",
49
+ type=str,
50
+ default=None,
51
+ help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)"
52
+ )
53
+
54
+ args = parser.parse_args()
55
+
56
+ try:
57
+ # 仅使用图片文件夹,不再支持 metadata.jsonl
58
+ image_folder_path = os.path.abspath(args.image_folder)
59
+
60
+ if args.output_dir:
61
+ # 如果指定了输出目录,修改生成逻辑
62
+ folder_name = os.path.basename(image_folder_path.rstrip('/'))
63
+ custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
64
+
65
+ # 创建输出目录(如果不存在)
66
+ os.makedirs(args.output_dir, exist_ok=True)
67
+
68
+ # 使用自定义输出路径版本
69
+ npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
70
+ else:
71
+ npz_path = create_npz_from_image_folder(image_folder_path)
72
+
73
+ print(f"转换完成!NPZ文件已保存至: {npz_path}")
74
+
75
+ except Exception as e:
76
+ print(f"错误: {e}")
77
+ return 1
78
+
79
+ return 0
80
+
81
+
82
+ def create_npz_from_image_folder_custom(image_folder_path, output_path):
83
+ """
84
+ 从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
85
+
86
+ Args:
87
+ image_folder_path (str): 包含图片文件的文件夹路径
88
+ output_path (str): 输出npz文件的完整路径
89
+
90
+ Returns:
91
+ str: 生成的 npz 文件路径
92
+ """
93
+ # 确保路径存在
94
+ if not os.path.exists(image_folder_path):
95
+ raise ValueError(f"文件夹路径不存在: {image_folder_path}")
96
+
97
+ # 获取所有支持的图片文件
98
+ supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
99
+ image_files = []
100
+
101
+ for extension in supported_extensions:
102
+ pattern = os.path.join(image_folder_path, extension)
103
+ image_files.extend(glob.glob(pattern))
104
+
105
+ # 按文件名排序确保一致性
106
+ image_files.sort()
107
+
108
+ if len(image_files) == 0:
109
+ raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
110
+
111
+ print(f"找到 {len(image_files)} 张图片文件")
112
+
113
+ # 读取所有图片
114
+ samples = []
115
+ for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
116
+ try:
117
+ # 打开图片并转换为RGB格式(确保一致性)
118
+ with Image.open(img_path) as img:
119
+ # 转换为RGB,确保所有图片都是3通道
120
+ if img.mode != 'RGB':
121
+ img = img.convert('RGB')
122
+
123
+ # 将图片resize到512x512
124
+ img = img.resize((512, 512), Image.LANCZOS)
125
+
126
+ sample_np = np.asarray(img).astype(np.uint8)
127
+
128
+ # 确保图片是3通道
129
+ if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
130
+ print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
131
+ continue
132
+
133
+ samples.append(sample_np)
134
+
135
+ except Exception as e:
136
+ print(f"警告: 无法读取图片 {img_path}: {e}")
137
+ continue
138
+
139
+ if len(samples) == 0:
140
+ raise ValueError("没有成功读取任何有效的图片文件")
141
+
142
+ # 转换为numpy数组
143
+ samples = np.stack(samples)
144
+ print(f"成功��取 {len(samples)} 张图片,形状: {samples.shape}")
145
+
146
+ # 验证数据形状
147
+ assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
148
+ assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
149
+
150
+ # 保存为npz文件
151
+ np.savez(output_path, arr_0=samples)
152
+ print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
153
+
154
+ return output_path
155
+
156
+
157
+ def create_npz_from_image_folder(image_folder_path):
158
+ """
159
+ 从图片文件夹构建 .npz,输出到该文件夹的父目录,文件名为 <文件夹名>.npz
160
+ """
161
+ parent_dir = os.path.dirname(os.path.abspath(image_folder_path))
162
+ folder_name = os.path.basename(os.path.abspath(image_folder_path).rstrip("/"))
163
+ output_path = os.path.join(parent_dir, f"{folder_name}.npz")
164
+ return create_npz_from_image_folder_custom(image_folder_path, output_path)
165
+
166
+
167
+ if __name__ == "__main__":
168
+ exit(main())
GVP/Baseline/run.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nohup torchrun \
2
+ --nnodes=1 \
3
+ --nproc_per_node=4 \
4
+ --rdzv_endpoint=localhost:29739 \
5
+ train_rectified_noise.py \
6
+ --depth 6 \
7
+ --results-dir results_256_gvp_disp \
8
+ --data-path /gemini/space/gzy_new/Imagenet256/train \
9
+ --ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
10
+ --num-classes 1000 \
11
+ --path-type GVP \
12
+ --prediction velocity \
13
+ --use-ot \
14
+ --use-contrastive \
15
+ > w_training1.log 2>&1 &
GVP/Baseline/sample_compare_ddp_rectified.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from diffusers.models import AutoencoderKL
10
+ from PIL import Image
11
+ from torch.nn.parallel import DistributedDataParallel as DDP
12
+ from tqdm import tqdm
13
+
14
+ from download import find_model
15
+ from models import SiT_models
16
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
17
+ from transport import Sampler, create_transport
18
+
19
+
20
def fix_state_dict_for_ddp(state_dict):
    """Normalize a checkpoint so it can be loaded into a DDP-wrapped model.

    Full training checkpoints store "model"/"ema" sub-dicts; unwrap them
    (preferring the EMA weights), then ensure every parameter key carries
    the "module." prefix that DistributedDataParallel expects.
    """
    if isinstance(state_dict, dict):
        if "ema" in state_dict:
            state_dict = state_dict["ema"]
        elif "model" in state_dict:
            state_dict = state_dict["model"]
    return {
        (key if key.startswith("module.") else "module." + key): value
        for key, value in state_dict.items()
    }
30
+
31
+
32
def save_png_batch(samples, out_dir, rank, total_offset):
    """Write a batch of uint8 HWC image arrays as PNGs.

    File indices interleave across ranks (rank-strided by world size) and
    are offset by `total_offset`, so names are globally unique across the
    whole DDP job.
    """
    world_size = dist.get_world_size()
    for local_idx, arr in enumerate(samples):
        global_idx = total_offset + rank + local_idx * world_size
        Image.fromarray(arr).save(f"{out_dir}/{global_idx:06d}.png")
36
+
37
+
38
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Collect `num` PNGs named 000000.png, 000001.png, ... into one .npz.

    The archive is written next to the folder as `<sample_dir>.npz` with
    the stacked uint8 array stored under the key "arr_0".
    """
    desc = f"Building .npz from {os.path.basename(sample_dir)}"
    images = [
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc=desc)
    ]
    stacked = np.stack(images)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz to {npz_path} [shape={stacked.shape}]")
47
+
48
+
49
def main(mode, args):
    """
    Side-by-side DDP sampling of a base SiT model and a rectified
    (SiTF1+SiTF2 CombinedModel) variant from the *same* initial noise,
    saving per-model PNG folders plus concatenated comparison pairs, and
    bundling both sets into .npz archives on rank 0.

    Args:
        mode: "ODE" or "SDE" -- selects the transport sampler.
        args: parsed command-line arguments (see the __main__ block).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32
    assert torch.cuda.is_available(), "This script requires at least one GPU."
    torch.set_grad_enabled(False)

    # One process per GPU; each rank gets a distinct deterministic seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)

    # SD-VAE latents are 8x downsampled relative to pixel space.
    latent_size = args.image_size // 8
    assert args.cfg_scale >= 1.0
    using_cfg = args.cfg_scale > 1.0

    # ---------------- Base model (sample_ddp style) ----------------
    base_model = SiT_models[args.model](
        input_size=latent_size, num_classes=args.num_classes, learn_sigma=False
    ).to(device)
    base_state = find_model(args.ckpt)
    # Unwrap full training checkpoints that store weights under "model".
    if isinstance(base_state, dict) and "model" in base_state:
        base_state = base_state["model"]
    base_model.load_state_dict(base_state, strict=False)
    base_model.eval()

    # ---------------- Rectified model (sample_rectified style) ----------------
    from models import CombinedModel, SiTF1, SiTF2

    # Derive the transformer configuration from the model-size tag in the
    # name; order matters because "XL" also contains "L".
    model_name = args.model
    if "XL" in model_name:
        hidden_size, depth, num_heads = 1152, 28, 16
    elif "L" in model_name:
        hidden_size, depth, num_heads = 1024, 24, 16
    elif "B" in model_name:
        hidden_size, depth, num_heads = 768, 12, 12
    elif "S" in model_name:
        hidden_size, depth, num_heads = 384, 12, 6
    else:
        hidden_size, depth, num_heads = 768, 12, 12
    # E.g. "SiT-XL/2" -> patch size 2.
    patch_size = int(model_name.split("/")[-1])

    # SiTF1 shares the base model's architecture and weights.
    sitf1 = SiTF1(
        input_size=latent_size,
        patch_size=patch_size,
        in_channels=4,
        hidden_size=hidden_size,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=args.num_classes,
        learn_sigma=False,
    ).to(device)
    sitf1.load_state_dict(base_state, strict=False)
    sitf1.eval()

    # SiTF2 is the (shallow, depth=args.depth) correction head; it was
    # trained under DDP, so wrap it before loading the "module."-prefixed
    # checkpoint.
    sitf2 = SiTF2(
        input_size=latent_size,
        hidden_size=hidden_size,
        out_channels=8,
        patch_size=patch_size,
        num_heads=num_heads,
        mlp_ratio=4.0,
        depth=args.depth,
        learn_sigma=True,
        num_classes=args.num_classes,
        learn_mu=args.learn_mu,
    ).to(device)
    sitf2 = DDP(sitf2, device_ids=[device])
    sitf2_state = fix_state_dict_for_ddp(find_model(args.sitf2_ckpt))
    sitf2.load_state_dict(sitf2_state, strict=False)
    sitf2.eval()
    rectified_model = CombinedModel(sitf1, sitf2).to(device)
    rectified_model.eval()

    def model_base(x, t, y=None, **kwargs):
        # Plain base-model velocity field, with CFG when requested.
        if using_cfg and "cfg_scale" in kwargs:
            return base_model.forward_with_cfg(x, t, y, kwargs["cfg_scale"])
        return base_model.forward(x, t, y)

    def model_rectified(x, t, y=None, **kwargs):
        # Base output plus the CombinedModel correction term.
        if using_cfg and "cfg_scale" in kwargs:
            sit_out = base_model.forward_with_cfg(x, t, y, kwargs["cfg_scale"])
        else:
            sit_out = base_model.forward(x, t, y)
        if not args.use_sitf2:
            return sit_out
        # NOTE(review): under CFG the rectified model receives the
        # CFG-duplicated batch without a cfg_scale — assumes CombinedModel
        # handles the doubled batch correctly; confirm against its forward.
        out = rectified_model.forward(x, t, y)
        if args.use_sitf2_before_t05:
            # Zero the correction for samples with t >= threshold.
            mask = (t < args.sitf2_threshold).float()
            while len(mask.shape) < len(out.shape):
                mask = mask.unsqueeze(-1)
            out = out * mask.expand_as(out)
        return sit_out + out

    transport = create_transport(
        args.path_type, args.prediction, args.loss_weight, args.train_eps, args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        sample_fn = sampler.sample_ode(
            sampling_method=args.sampling_method,
            num_steps=args.num_sampling_steps,
            atol=args.atol,
            rtol=args.rtol,
            reverse=args.reverse,
        )
    else:
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )

    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Output layout: <sample_dir>/<exp_name>/{base,rectified,pair}.
    exp_name = f"compare-{args.model.replace('/', '-')}-cfg-{args.cfg_scale}-{mode}-{args.num_sampling_steps}"
    root_out = os.path.join(args.sample_dir, exp_name)
    out_base = os.path.join(root_out, "base")
    out_rect = os.path.join(root_out, "rectified")
    out_pair = os.path.join(root_out, "pair")
    if rank == 0:
        os.makedirs(out_base, exist_ok=True)
        os.makedirs(out_rect, exist_ok=True)
        os.makedirs(out_pair, exist_ok=True)
    dist.barrier()

    # Round the request up so it divides evenly across ranks and batches.
    n = args.per_proc_batch_size
    global_batch = n * dist.get_world_size()
    total_samples = int(math.ceil(args.num_fid_samples / global_batch) * global_batch)
    iters = total_samples // global_batch
    pbar = tqdm(range(iters)) if rank == 0 else range(iters)
    total = 0

    for _ in pbar:
        z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        if using_cfg:
            # Duplicate the batch; the second half uses the null class
            # label (index == num_classes) for the unconditional pass.
            z_in = torch.cat([z, z], 0)
            y_null = torch.full((n,), args.num_classes, device=device, dtype=y.dtype)
            y_in = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y_in, cfg_scale=args.cfg_scale)
        else:
            z_in = z
            y_in = y
            model_kwargs = dict(y=y_in)

        # Ensure SDE process noise is identical between base and rectified:
        # save RNG state -> run base -> restore RNG state -> run rectified.
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state(device)

        x_base = sample_fn(z_in, model_base, **model_kwargs)[-1]

        torch.set_rng_state(cpu_rng_state)
        torch.cuda.set_rng_state(cuda_rng_state, device=device)
        x_rect = sample_fn(z_in, model_rectified, **model_kwargs)[-1]
        if using_cfg:
            # Drop the null-class half of each batch.
            x_base, _ = x_base.chunk(2, dim=0)
            x_rect, _ = x_rect.chunk(2, dim=0)

        # Decode latents (SD-VAE scaling factor 0.18215) to uint8 HWC images.
        img_base = vae.decode(x_base / 0.18215).sample
        img_rect = vae.decode(x_rect / 0.18215).sample
        img_base = torch.clamp(127.5 * img_base + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        img_rect = torch.clamp(127.5 * img_rect + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        save_png_batch(img_base, out_base, rank, total)
        save_png_batch(img_rect, out_rect, rank, total)

        # Side-by-side comparison image: base on the left, rectified right.
        for i in range(img_base.shape[0]):
            index = i * dist.get_world_size() + rank + total
            pair = np.concatenate([img_base[i], img_rect[i]], axis=1)
            Image.fromarray(pair).save(f"{out_pair}/{index:06d}.png")

        total += global_batch
        dist.barrier()

    # All ranks must finish writing PNGs before rank 0 builds the archives.
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(out_base, args.num_fid_samples)
        create_npz_from_sample_folder(out_rect, args.num_fid_samples)
        print(f"Done. Output root: {root_out}")
    dist.barrier()
    dist.destroy_process_group()
+ dist.destroy_process_group()
237
+
238
+
239
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The first positional CLI token selects the sampler family and decides
    # which extra argument group (ODE vs SDE) gets registered.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="compare_samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=12)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=1)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ckpt", type=str, required=True, help="Base SiT checkpoint")
    parser.add_argument("--sitf2-ckpt", type=str, required=True, help="SiTF2 checkpoint")
    parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--depth", type=int, default=1, help="SiTF2 depth")
    parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--sitf2-threshold", type=float, default=0.5)

    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
    else:
        parse_sde_args(parser)

    # parse_known_args tolerates the leading <mode> token (and any
    # unregistered flags) instead of erroring out.
    args = parser.parse_known_args()[0]
    main(mode, args)
GVP/Baseline/sample_ddp.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Samples a large number of images from a pre-trained SiT model using DDP.
6
+ Subsequently saves a .npz file that can be used to compute FID and other
7
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
8
+
9
+ For a simple single-GPU/CPU sampling script, see sample.py.
10
+ """
11
+ import torch
12
+ import torch.distributed as dist
13
+ from models import SiT_models
14
+ from download import find_model
15
+ from transport import create_transport, Sampler
16
+ from diffusers.models import AutoencoderKL
17
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ import sys
25
+
26
+
27
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Bundle `num` sequentially named PNG samples into one .npz archive.

    Reads `<sample_dir>/000000.png` onward and writes `<sample_dir>.npz`
    holding a single uint8 array under the key "arr_0".

    Returns:
        str: path of the written .npz file.
    """
    images = []
    for idx in tqdm(range(num), desc="Building .npz file from samples"):
        img = Image.open(f"{sample_dir}/{idx:06d}.png")
        images.append(np.asarray(img).astype(np.uint8))
    stacked = np.stack(images)
    # Sanity check: exactly `num` images and 3 (RGB) channels.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
+
43
+
44
def main(mode, args):
    """
    Run DDP sampling.

    Samples ``args.num_fid_samples`` images from a pre-trained SiT model,
    writes each as a PNG under ``args.sample_dir``, and on rank 0 bundles
    them into a single .npz for ADM-style FID evaluation.

    Args:
        mode: "ODE" or "SDE" -- which transport sampler to use.
        args: parsed CLI arguments (see the __main__ block).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP: one process per GPU, each with its own deterministic seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path only covers the released SiT-XL/2 256x256 model.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model (latents are 8x downsampled relative to pixel space):
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            # Likelihood evaluation is only defined without guidance.
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples; the name encodes the sampling config
    # so different runs never collide.
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples each GPU generates and how many iterations
    # to run. We sample a bit more than requested so the total divides evenly
    # across ranks/batches; the .npz step takes the first num_fid_samples.
    # NOTE: a previous revision counted existing PNGs for resume support but
    # never used the count; that dead bookkeeping has been removed.
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for _ in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance: duplicate the batch and pair the
        # copy with the "null" class label for the unconditional pass.
        if using_cfg:
            z = torch.cat([z, z], 0)
            # Fixed: the null label is index args.num_classes (matching the
            # embedding table and the compare script), not a hard-coded 1000,
            # so CFG works for any label-space size.
            y_null = torch.full((n,), args.num_classes, device=device, dtype=y.dtype)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        # Decode latents (SD-VAE scaling factor 0.18215) to uint8 HWC images.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files with globally unique,
        # rank-interleaved indices.
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
+
195
+
196
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The first positional CLI token selects ODE vs SDE and decides which
    # extra argument group gets registered below.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    # Guard against a flag being mistaken for the mode token.
    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=12)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=1)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates the leading <mode> token instead of erroring.
    args = parser.parse_known_args()[0]
    main(mode, args)
GVP/Baseline/sample_rectified_noise.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ from models import SiT_models
5
+ from download import find_model
6
+ from transport import create_transport, Sampler
7
+ from diffusers.models import AutoencoderKL
8
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
9
+ from tqdm import tqdm
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ import math
14
+ import argparse
15
+ import sys
16
+
17
+
18
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Pack `num` PNGs (000000.png, 000001.png, ...) into `<sample_dir>.npz`.

    The stacked uint8 array is stored under the key "arr_0".

    Returns:
        str: path to the written archive.
    """
    arrays = [
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ]
    batch = np.stack(arrays)
    # Expect exactly `num` RGB images (last axis == 3).
    assert batch.shape == (num, batch.shape[1], batch.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=batch)
    print(f"Saved .npz file to {npz_path} [shape={batch.shape}].")
    return npz_path
33
+
34
+
35
def fix_state_dict_for_ddp(state_dict):
    """
    Adapt a checkpoint state dict for a DistributedDataParallel model.

    Training checkpoints may be full dicts holding "model"/"ema"/"opt"
    entries; unwrap them (preferring the EMA weights over the raw model
    weights), then add the "module." prefix to every parameter key that
    lacks it, matching DDP's wrapped-module naming.

    Args:
        state_dict: either a raw parameter mapping or a full checkpoint dict.

    Returns:
        dict: parameter mapping whose keys all start with "module.".
    """
    if isinstance(state_dict, dict):
        # Prefer "ema", then "model". A checkpoint containing only other
        # keys (e.g. just "opt") falls through unchanged — this replaces a
        # previous dead-code branch that reassigned state_dict to itself.
        if "ema" in state_dict:
            state_dict = state_dict["ema"]
        elif "model" in state_dict:
            state_dict = state_dict["model"]

    return {
        (key if key.startswith("module.") else "module." + key): value
        for key, value in state_dict.items()
    }
61
+
62
+ def main(mode, args):
63
+ """
64
+ Run sampling.
65
+ """
66
+ torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
67
+ assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
68
+ torch.set_grad_enabled(False)
69
+ learn_mu = args.learn_mu
70
+ sitf2_depth = args.depth # Save SiTF2 depth before it gets overwritten
71
+
72
+ # Setup DDP:
73
+ dist.init_process_group("nccl")
74
+ rank = dist.get_rank()
75
+ device = rank % torch.cuda.device_count()
76
+ seed = args.global_seed * dist.get_world_size() + rank
77
+ torch.manual_seed(seed)
78
+ torch.cuda.set_device(device)
79
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
80
+
81
+ if args.ckpt is None:
82
+ assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
83
+ assert args.image_size in [256, 512]
84
+ assert args.num_classes == 1000
85
+ assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
86
+ learn_sigma = args.image_size == 256
87
+ else:
88
+ learn_sigma = False
89
+
90
+ # Load SiTF1 and SiTF2 models and create CombinedModel
91
+ from models import SiTF1, SiTF2, CombinedModel
92
+ latent_size = args.image_size // 8
93
+
94
+ # Get model configuration based on args.model
95
+ model_name = args.model
96
+ if 'XL' in model_name:
97
+ hidden_size, depth, num_heads = 1152, 28, 16
98
+ elif 'L' in model_name:
99
+ hidden_size, depth, num_heads = 1024, 24, 16
100
+ elif 'B' in model_name:
101
+ hidden_size, depth, num_heads = 768, 12, 12
102
+ elif 'S' in model_name:
103
+ hidden_size, depth, num_heads = 384, 12, 6
104
+ else:
105
+ # Default fallback
106
+ hidden_size, depth, num_heads = 768, 12, 12
107
+
108
+ # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2
109
+ patch_size = int(model_name.split('/')[-1])
110
+
111
+ # Load SiTF1
112
+ sitf1 = SiTF1(
113
+ input_size=latent_size,
114
+ patch_size=patch_size,
115
+ in_channels=4,
116
+ hidden_size=hidden_size,
117
+ depth=depth,
118
+ num_heads=num_heads,
119
+ mlp_ratio=4.0,
120
+ class_dropout_prob=0.1,
121
+ num_classes=args.num_classes,
122
+ learn_sigma=False
123
+ ).to(device)
124
+ sitf1_state_raw = find_model(args.ckpt)
125
+ # find_model now returns ema if available, or the full checkpoint
126
+ # Extract the actual state_dict to use for both sitf1 and base_model
127
+ if isinstance(sitf1_state_raw, dict) and "model" in sitf1_state_raw:
128
+ sitf1_state = sitf1_state_raw["model"]
129
+ else:
130
+ # sitf1_state_raw is already a state_dict (either ema or direct model state)
131
+ sitf1_state = sitf1_state_raw
132
+ sitf1.load_state_dict(sitf1_state)
133
+ sitf1.eval()
134
+
135
+ # For sampling, we can use sitf1 directly instead of creating a separate sit model
136
+ # since sitf1 and sit have the same architecture and weights
137
+
138
+ # Load SiTF2 with the same architecture parameters as SiTF1 for compatibility
139
+ sitf2 = SiTF2(
140
+ input_size=latent_size,
141
+ hidden_size=hidden_size, # Use the same hidden_size as SiTF1
142
+ out_channels=8,
143
+ patch_size=patch_size, # Use the same patch_size as SiTF1
144
+ num_heads=num_heads, # Use the same num_heads as SiTF1
145
+ mlp_ratio=4.0,
146
+ depth=sitf2_depth, # Use the depth specified by command line argument (not the model's default depth)
147
+ learn_sigma=True,
148
+ num_classes=args.num_classes,
149
+ learn_mu=learn_mu
150
+ ).to(device)
151
+ sitf2 = DDP(sitf2, device_ids=[device])
152
+ sitf2_state = find_model(args.sitf2_ckpt)
153
+ # Fix state dict keys to match DDP model
154
+ sitf2_state_fixed = fix_state_dict_for_ddp(sitf2_state)
155
+ try:
156
+ sitf2.load_state_dict(sitf2_state_fixed)
157
+ except Exception as e:
158
+ print(f"Error loading state dict: {e}")
159
+ # Try loading with strict=False as fallback
160
+ sitf2.load_state_dict(sitf2_state_fixed, strict=False)
161
+ sitf2.eval()
162
+ # CombinedModel
163
+
164
+ combined_model = CombinedModel(sitf1, sitf2).to(device)
165
+ sitf2.eval()
166
+ combined_model.eval()
167
+
168
+ # Use SiT_models factory function to create the base model, same as in SiT_clean
169
+ # This ensures correct model configuration
170
+ # Use learn_sigma=False to match sitf1 configuration
171
+ base_model = SiT_models[args.model](
172
+ input_size=latent_size,
173
+ num_classes=args.num_classes,
174
+ learn_sigma=False, # Match sitf1's learn_sigma=False
175
+ ).to(device)
176
+ # Load the checkpoint (same as sitf1) - use the exact same state_dict
177
+ base_model.load_state_dict(sitf1_state)
178
+ base_model.eval()
179
+
180
+ # Determine if CFG will be used (needed for combined_sampling_model function)
181
+ assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
182
+ using_cfg = args.cfg_scale > 1.0
183
+
184
+ # There are repeated calculations in the middle,
185
+ # which will cause Flops to double. A simplified version will be released later
186
+ def combined_sampling_model(x, t, y=None, **kwargs):
187
+ with torch.no_grad():
188
+ # Handle CFG same as in SiT_clean/sample_ddp.py
189
+ if using_cfg and 'cfg_scale' in kwargs:
190
+ # Use forward_with_cfg when CFG is enabled
191
+ sit_out = base_model.forward_with_cfg(x, t, y, kwargs['cfg_scale'])
192
+ else:
193
+ # Use regular forward when CFG is disabled
194
+ sit_out = base_model.forward(x, t, y)
195
+ # If use_sitf2_before_t05 is True, only use sitf2 when t < threshold
196
+ if args.use_sitf2:
197
+ if args.use_sitf2_before_t05:
198
+ # t is a tensor, check which samples have t < threshold
199
+ # Create a mask: 1.0 where t < threshold, 0.0 otherwise
200
+ mask = (t < args.sitf2_threshold).float()
201
+ # Compute sitf2 output for all samples
202
+ combined_out = combined_model.forward(x, t, y)
203
+ # Expand mask to match the spatial dimensions of combined_out
204
+ # combined_out shape is (batch, channels, height, width)
205
+ while len(mask.shape) < len(combined_out.shape):
206
+ mask = mask.unsqueeze(-1)
207
+ # Broadcast mask to match combined_out shape
208
+ mask = mask.expand_as(combined_out)
209
+ # Only use sitf2 output where t < threshold
210
+ combined_out = combined_out * mask
211
+ # Combine sit_out and masked combined_out
212
+ return sit_out + combined_out
213
+ else:
214
+ # Default behavior: only use base model output
215
+ return sit_out
216
+ else:
217
+ # Default behavior: only use base model output
218
+ return sit_out
219
+
220
+ transport = create_transport(
221
+ args.path_type,
222
+ args.prediction,
223
+ args.loss_weight,
224
+ args.train_eps,
225
+ args.sample_eps
226
+ )
227
+ sampler = Sampler(transport)
228
+ if mode == "ODE":
229
+ if args.likelihood:
230
+ assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
231
+ sample_fn = sampler.sample_ode_likelihood(
232
+ sampling_method=args.sampling_method,
233
+ num_steps=args.num_sampling_steps,
234
+ atol=args.atol,
235
+ rtol=args.rtol,
236
+ )
237
+ else:
238
+ sample_fn = sampler.sample_ode(
239
+ sampling_method=args.sampling_method,
240
+ num_steps=args.num_sampling_steps,
241
+ atol=args.atol,
242
+ rtol=args.rtol,
243
+ reverse=args.reverse
244
+ )
245
+ elif mode == "SDE":
246
+ sample_fn = sampler.sample_sde(
247
+ sampling_method=args.sampling_method,
248
+ diffusion_form=args.diffusion_form,
249
+ diffusion_norm=args.diffusion_norm,
250
+ last_step=args.last_step,
251
+ last_step_size=args.last_step_size,
252
+ num_steps=args.num_sampling_steps,
253
+ )
254
+ vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
255
+
256
+ # Create folder to save samples:
257
+ model_string_name = args.model.replace("/", "-")
258
+ ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
259
+ sitf2_ckpt_string_name = os.path.basename(args.sitf2_ckpt).replace(".pt", "") if args.ckpt else "pretrained"
260
+ if mode == "ODE":
261
+ folder_name = f"{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
262
+ f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
263
+ f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
264
+ elif mode == "SDE":
265
+ # Add threshold info to folder name if use_sitf2_before_t05 is enabled
266
+ threshold_suffix = f"-threshold-{args.sitf2_threshold}" if args.use_sitf2_before_t05 else ""
267
+ if learn_mu:
268
+ folder_name = f"depth-mu-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
269
+ f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
270
+ f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
271
+ f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
272
+ else:
273
+ folder_name = f"depth-sigma-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
274
+ f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
275
+ f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
276
+ f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
277
+ sample_folder_dir = f"{args.sample_dir}/{folder_name}"
278
+ if rank == 0:
279
+ os.makedirs(sample_folder_dir, exist_ok=True)
280
+ print(f"Saving .png samples at {sample_folder_dir}")
281
+ dist.barrier()
282
+
283
+ # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
284
+ n = args.per_proc_batch_size
285
+ global_batch_size = n * dist.get_world_size()
286
+ # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
287
+ num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
288
+ total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
289
+ if rank == 0:
290
+ print(f"Total number of images that will be sampled: {total_samples}")
291
+ assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
292
+ samples_needed_this_gpu = int(total_samples // dist.get_world_size())
293
+ assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
294
+ iterations = int(samples_needed_this_gpu // n)
295
+ done_iterations = int( int(num_samples // dist.get_world_size()) // n)
296
+ pbar = range(iterations)
297
+ pbar = tqdm(pbar) if rank == 0 else pbar
298
+ total = 0
299
+
300
+ for i in pbar:
301
+ # Sample inputs:
302
+ z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
303
+ y = torch.randint(0, args.num_classes, (n,), device=device)
304
+ # Setup classifier-free guidance:
305
+ if using_cfg:
306
+ z = torch.cat([z, z], 0)
307
+ y_null = torch.tensor([1000] * n, device=device)
308
+ y = torch.cat([y, y_null], 0)
309
+ model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
310
+ else:
311
+ model_kwargs = dict(y=y)
312
+ samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
313
+ if using_cfg:
314
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
315
+ samples = vae.decode(samples / 0.18215).sample
316
+ samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
317
+ # Save samples to disk as individual .png files
318
+ for i, sample in enumerate(samples):
319
+ index = i * dist.get_world_size() + rank + total
320
+ Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
321
+ total += global_batch_size
322
+ dist.barrier()
323
+
324
+ # Make sure all processes have finished saving their samples before attempting to convert to .npz
325
+ dist.barrier()
326
+ if rank == 0:
327
+ create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
328
+ print("Done.")
329
+ dist.barrier()
330
+ dist.destroy_process_group()
331
+
332
+
333
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The first positional CLI token selects the sampler family (ODE vs SDE).
    # It must be read before argparse runs, because the chosen mode decides
    # which extra argument group (parse_ode_args / parse_sde_args) is added.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    # Validate the mode explicitly instead of with `assert`, which is
    # silently stripped when Python runs with -O and would let a bad mode
    # fall through to argparse with confusing errors.
    if mode[:2] == "--":
        print("Usage: program.py <mode> [options]")
        sys.exit(1)
    if mode not in ("ODE", "SDE"):
        print("Invalid mode. Please choose 'ODE' or 'SDE'")
        sys.exit(1)

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=12)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sitf2-ckpt", type=str, required=True, help="Path to SiTF2 checkpoint")
    parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to learn mu parameter")
    parser.add_argument("--depth", type=int, default=1,
                        help="Depth parameter for SiTF2 model")
    # Help text fixed: it was a copy-paste of --use-sitf2-before-t05's help and
    # misdescribed this flag, which toggles the SiTF2 correction model overall.
    parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True,
                        help="Add the SiTF2 correction model's output to the base SiT output")
    parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False,
                        help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    parser.add_argument("--sitf2-threshold", type=float, default=0.5,
                        help="Time threshold for using SiTF2 output (default: 0.5). Only effective when --use-sitf2-before-t05 is True")
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates flags that belong to the other mode's
    # argument group; only the recognized namespace is passed on.
    args = parser.parse_known_args()[0]
    main(mode, args)
GVP/Baseline/samples.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
2
+ --nnodes=1 \
3
+ --nproc_per_node=4 \
4
+ --rdzv_endpoint=localhost:29110 \
5
+ sample_rectified_noise.py SDE \
6
+ --depth 6 \
7
+ --sample-dir GVP_samples \
8
+ --model SiT-XL/2 \
9
+ --num-fid-samples 50000 \
10
+ --num-classes 1000 \
11
+ --global-seed 0 \
12
+ --use-sitf2 True \
13
+ --sitf2-threshold 1 \
14
+ --ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
15
+ --sitf2-ckpt /gemini/space/gzy_new/models/Baseline/results_256_gvp_disp/depth-mu-6-007-SiT-XL-2-GVP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt \
16
+ > W_No.log 2>&1 &
GVP/Baseline/samples_ddp.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
2
+ --nnodes=1 \
3
+ --nproc_per_node=4 \
4
+ --rdzv_endpoint=localhost:29111 \
5
+ sample_ddp.py SDE \
6
+ --sample-dir baseline_gvp_ \
7
+ --model SiT-XL/2 \
8
+ --num-fid-samples 50000 \
9
+ --num-classes 1000 \
10
+ --global-seed 0 \
11
+ --path-type GVP \
12
+ --prediction velocity \
13
+ --ckpt /gemini/space/gzy_new/models/xiangzai_Back/GVP_check/base.pt \
14
+ > gvp_sampling.log 2>&1 &
GVP/Baseline/transport/__pycache__/ot_plan.cpython-311.pyc ADDED
Binary file (5.71 kB). View file
 
GVP/Baseline/transport/__pycache__/path.cpython-310.pyc ADDED
Binary file (7.9 kB). View file
 
GVP/Baseline/transport/__pycache__/path.cpython-311.pyc ADDED
Binary file (12.1 kB). View file
 
GVP/Baseline/transport/__pycache__/path.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
GVP/Baseline/transport/__pycache__/path.cpython-38.pyc ADDED
Binary file (7.93 kB). View file
 
GVP/Baseline/transport/__pycache__/transport.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
GVP/Baseline/transport/__pycache__/transport.cpython-311.pyc ADDED
Binary file (23.8 kB). View file
 
GVP/Baseline/transport/__pycache__/transport.cpython-312.pyc ADDED
Binary file (22.8 kB). View file
 
GVP/Baseline/transport/__pycache__/transport.cpython-38.pyc ADDED
Binary file (13.2 kB). View file
 
GVP/Baseline/transport/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.24 kB). View file
 
GVP/Baseline/transport/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.17 kB). View file
 
GVP/Baseline/transport/__pycache__/utils.cpython-312.pyc ADDED
Binary file (1.9 kB). View file
 
GVP/Baseline/transport/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.26 kB). View file
 
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0020000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de6ee3c5ee036be85b216daddafb56f16df64e9c3f7d3060b31f5cb301a345b1
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0040000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30fe96a8ac31b2744debe3861b4c92e631462eb27401c26f88042b6e65ac287d
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0060000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f82060d1339e53cc05c4863d39dc083795f05ad282cabc830889d8b6f72911b
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0080000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:606a5b7dceacc351bf68867736eed793e3e02c29a6e27aa0e8cf6e2ef3382c06
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0100000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ecb79b4530e3503370cefb16c38706a8b7b823e221f30f621c902e2e174aaec
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0120000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d03dae6043ff1543a412b1c44c02a21aa299014e8b3b3209e21cfd73b2c5d0f
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0140000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c630ada8db890d9fc41eaf8d30b3bb1d95b1e61d380d8f7e97f639ec3b307233
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0160000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9aa31b63d4138dec6b8d4c55a58fadf60096ac8f1d4ec83651ea724e7d842e17
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0180000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4b9c5725fd5492e8e4979b190b04809ba03514ff57da6b5ce92e3e99df9c0db
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0200000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49d27a75f329c0170b0b6bbc0028881a40b6d8cb2a34495fba8ba73e49a395a4
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0220000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c59a8b6bca3ade4b6547951082b5dc67ad805861c611a13bccf1995cf8f0ca7d
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0240000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b47c2b4f5baef28f5f1033b57e5ad7cb2c1b52246fcaa0ff4214781bfa305604
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0260000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf7f7746b191c29efd3ff33ba7cdb83f67d242ed03a14228478943a3cb484907
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0280000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f254fdaea4c2c6875493b4a94846b5a8a54cf90210d610d297a7bbceadd84101
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0300000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f198d183997904bfef6ae3222f59022cbb86c40ae625d2196d659bd7c1c3c121
3
+ size 1193384322
VP/depth-mu-4-001-SiT-XL-2-VP-velocity-None-OT-Contrastive0.05/checkpoints/0320000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e4aca9e6138967eea8540f58f7a7b1d09b749ea4b471e62a0052b8c04e9a8d4
3
+ size 1193384322