Add files using upload-large-folder tool
Browse files- REG copy/train.sh +28 -0
- VIRTUAL_imagenet256_labeled/.gitattributes +35 -0
- VIRTUAL_imagenet256_labeled/README.md +11 -0
- __pycache__/evaluator.cpython-38.pyc +0 -0
- __pycache__/fid_custom.cpython-313.pyc +0 -0
- __pycache__/npz_health_check.cpython-313.pyc +0 -0
- back/dataset.py +149 -0
- back/generate.py +227 -0
- back/sample_from_checkpoint_ddp.py +416 -0
- back/samplers.py +840 -0
- back/samples_0.75.log +0 -0
- back/samples_0.75_new.log +11 -0
- conditional-flow-matching/.gitignore +175 -0
- conditional-flow-matching/.pre-commit-config.yaml +99 -0
- conditional-flow-matching/LICENSE +21 -0
- conditional-flow-matching/README.md +230 -0
- conditional-flow-matching/pyproject.toml +39 -0
- conditional-flow-matching/requirements.txt +15 -0
- conditional-flow-matching/runner-requirements.txt +52 -0
- conditional-flow-matching/setup.py +36 -0
REG copy/train.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NUM_GPUS=8
# NOTE(review): random_number is unused below — presumably intended as a
# rendezvous/main-process port for the launcher; confirm before removing.
random_number=$((RANDOM % 100 + 1200))

# Launch REG training on $NUM_GPUS GPUs via accelerate.
# --encoder-depth: SiT-L/XL use 8, SiT-B uses 4.
# (The comment must NOT sit on a continued line: `... \ #comment` escapes the
#  space instead of the newline and silently drops every following option.)
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
    --report-to="wandb" \
    --allow-tf32 \
    --mixed-precision="fp16" \
    --seed=0 \
    --path-type="linear" \
    --prediction="v" \
    --weighting="uniform" \
    --model="SiT-XL/2" \
    --enc-type="dinov2-vit-b" \
    --proj-coeff=0.5 \
    --encoder-depth=8 \
    --output-dir="your_path/reg_xlarge_dinov2_base_align_8_cls" \
    --exp-name="linear-dinov2-b-enc8" \
    --batch-size=256 \
    --data-dir="data_path/imagenet_vae" \
    --cls=0.03

# Dataset path
# For example: your_path/imagenet-vae
# This folder contains two folders:
# (1) The ImageNet RGB images: your_path/imagenet-vae/imagenet_256-vae/
# (2) The ImageNet VAE latents: your_path/imagenet-vae/vae-sd/
VIRTUAL_imagenet256_labeled/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
VIRTUAL_imagenet256_labeled/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
数据集文件元信息以及数据文件,请浏览“数据集文件”页面获取。
|
| 5 |
+
|
| 6 |
+
当前数据集卡片使用的是默认模版,数据集的贡献者未提供更加详细的数据集介绍,但是您可以通过如下GIT Clone命令,或者ModelScope SDK来下载数据集
|
| 7 |
+
|
| 8 |
+
#### 下载方法
|
| 9 |
+
:modelscope-code[]{type="sdk"}
|
| 10 |
+
:modelscope-code[]{type="git"}
|
| 11 |
+
|
__pycache__/evaluator.cpython-38.pyc
ADDED
|
Binary file (23.5 kB). View file
|
|
|
__pycache__/fid_custom.cpython-313.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
__pycache__/npz_health_check.cpython-313.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
back/dataset.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import PIL.Image
|
| 10 |
+
try:
|
| 11 |
+
import pyspng
|
| 12 |
+
except ImportError:
|
| 13 |
+
pyspng = None
|
| 14 |
+
|
| 15 |
+
|
class CustomDataset(Dataset):
    """Pairs VAE latents with class labels and, optionally, precomputed semantic features.

    Layout under ``data_dir``:
      * VAE latents live in ``imagenet_256_vae/``.
      * Without preprocessed semantics, the paired feature files and
        ``dataset.json`` live in ``vae-sd/`` (matching the original REG
        layout); ``__getitem__`` returns ``(image, features, label)``.
      * With preprocessed semantics (an explicit ``semantic_features_dir`` or
        the auto-detected ``imagenet_256_features/dinov2-vit-b_tmp/gpu0``),
        entries are indexed from that directory's ``dataset.json`` and the
        matching latent ``.npy`` under ``imagenet_256_vae`` is inferred from
        the feature file name; ``__getitem__`` returns
        ``(image, image, semantic_features, label)``.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        PIL.Image.init()  # populate PIL.Image.EXTENSION before reading it
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')

        if semantic_features_dir is None:
            # Auto-detect the default preprocessed-feature location.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")

        if self.use_preprocessed_semantic:
            # Index (feature filename, label) pairs from the features' dataset.json.
            label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
            if not os.path.exists(label_fname):
                raise FileNotFoundError(f"Label file not found: {label_fname}")

            print(f"Using {label_fname}.")
            with open(label_fname, 'rb') as f:
                data = json.load(f)
            labels_list = data.get('labels', None)
            if labels_list is None:
                raise ValueError(f"'labels' field is missing in {label_fname}")

            semantic_fnames = []
            labels = []
            for entry in labels_list:
                if entry is None:
                    continue
                fname, lab = entry
                semantic_fnames.append(fname)
                labels.append(0 if lab is None else lab)  # missing label defaults to class 0

            self.semantic_fnames = semantic_fnames
            self.labels = np.array(labels, dtype=np.int64)
            self.num_samples = len(self.semantic_fnames)
            print(f"Loaded {self.num_samples} semantic entries from dataset.json")
        else:
            # Original REG layout: features (and dataset.json) under vae-sd/.
            self.features_dir = os.path.join(data_dir, 'vae-sd')

            self._image_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.images_dir)
                for root, _dirs, files in os.walk(self.images_dir) for fname in files
            }
            self.image_fnames = sorted(
                fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
            )
            self._feature_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.features_dir)
                for root, _dirs, files in os.walk(self.features_dir) for fname in files
            }
            self.feature_fnames = sorted(
                fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
            )

            fname = os.path.join(self.features_dir, 'dataset.json')
            if not os.path.exists(fname):
                # Report the actual missing path (the old message claimed
                # "neither of the specified files" although only one is checked).
                raise FileNotFoundError(f"Label file not found: {fname}")
            print(f"Using {fname}.")

            with open(fname, 'rb') as f:
                labels = json.load(f)['labels']
            labels = dict(labels)
            # dataset.json stores POSIX-style paths; normalize Windows separators.
            labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
            labels = np.array(labels)
            # 1-D => integer class ids, 2-D => dense float label vectors.
            self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        """Return the lower-cased extension of *fname* (including the dot)."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        if self.use_preprocessed_semantic:
            # Feature files are named "...-<index>.npy"; the latent lives at
            # <index[:5]>/img-mean-std-<index>.npy under images_dir.
            semantic_fname = self.semantic_fnames[idx]
            basename = os.path.basename(semantic_fname)
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)

            with open(vae_path, 'rb') as f:
                image = np.load(f)

            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)

            # NOTE(review): the latent is intentionally returned twice here —
            # presumably to keep a 4-tuple interface; confirm against train.py.
            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )

        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # pyspng is a faster PNG decoder when available.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
back/generate.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 9 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 10 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 11 |
+
|
| 12 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from models.sit import SiT_models
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
from samplers import euler_maruyama_sampler
|
| 25 |
+
from utils import load_legacy_checkpoints, download_model
|
| 26 |
+
|
| 27 |
+
|
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Assemble ``num`` PNG samples from *sample_dir* into one ``.npz`` archive.

    Files are expected to be named ``000000.png``, ``000001.png``, ... and are
    stacked into a single uint8 array stored under key ``arr_0`` (the format
    the ADM evaluation suite consumes).
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
| 43 |
+
|
| 44 |
+
|
def main(args):
    """
    Run DDP sampling: each rank draws latent noise, runs the SDE sampler,
    decodes with the SD VAE and saves PNGs; rank 0 then packs everything into
    a single .npz for ADM-style FID evaluation.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank  # distinct RNG stream per rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load model:
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8  # SD VAE downsamples by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=True,
        z_dims=[int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)

    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt
    if ckpt_path is None:
        # The default pretrained weights are only valid for this exact config.
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']
    if args.legacy:
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":
                # ODE sampling is not wired up yet. Raising (instead of a bare
                # exit()) avoids one rank leaving with status 0 while the other
                # DDP ranks hang forever at the barrier below.
                raise NotImplementedError("ODE sampling is not supported yet")
            else:
                raise NotImplementedError()

        # Undo the SD-VAE latent scaling, then decode to pixel space.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        samples = (samples + 1) / 2.  # [-1, 1] -> [0, 1]
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files (globally unique, interleaved index per rank):
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
| 181 |
+
|
| 182 |
+
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Reproducibility
    parser.add_argument("--global-seed", type=int, default=0)

    # Precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # Logging / saving
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples")

    # Model architecture
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)

    # VAE decoder weights
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # Sample counts
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)

    # Sampling hyperparameters
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--cls-cfg-scale", type=float, default=1.5)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False)  # only meaningful for ODE sampling
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    parser.add_argument("--local-rank", default=-1, type=int)
    parser.add_argument("--cls", default=768, type=int)

    # Will be deprecated: legacy checkpoint conversion
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)

    args = parser.parse_args()
    main(args)
back/sample_from_checkpoint_ddp.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。
|
| 4 |
+
|
| 5 |
+
用法(4 卡示例):
|
| 6 |
+
torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \
|
| 7 |
+
--ckpt exps/jsflow-experiment/checkpoints/0290000.pt \
|
| 8 |
+
--out-dir ./my_samples_ddp \
|
| 9 |
+
--num-images 50000 \
|
| 10 |
+
--batch-size 16 \
|
| 11 |
+
--t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \
|
| 12 |
+
--sampler em_image_noise_before_tc
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import types
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
from diffusers.models import AutoencoderKL
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from models.sit import SiT_models
|
| 31 |
+
from samplers import (
|
| 32 |
+
euler_maruyama_image_noise_before_tc_sampler,
|
| 33 |
+
euler_maruyama_image_noise_sampler,
|
| 34 |
+
euler_maruyama_sampler,
|
| 35 |
+
euler_ode_sampler,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
def create_npz_from_sample_folder(sample_dir: str, num: int):
    """Pack ``000000.png`` ... under *sample_dir* into one ``.npz`` (key ``arr_0``)."""
    frames = [
        np.asarray(
            Image.open(os.path.join(sample_dir, f"{i:06d}.png"))
        ).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(frames)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
| 53 |
+
|
| 54 |
+
|
def semantic_dim_from_enc_type(enc_type):
    """Map an encoder-type string to its semantic feature width.

    Recognizes ViT-G (1536), ViT-L (1024) and ViT-S (384) substrings,
    checked in that order; anything else (including None) falls back to
    768, the ViT-B width.
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    for tokens, dim in ((("vit-g", "vitg"), 1536),
                        (("vit-l", "vitl"), 1024),
                        (("vit-s", "vits"), 384)):
        if any(tok in name for tok in tokens):
            return dim
    return 768
| 66 |
+
|
| 67 |
+
|
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Recover the training-time ``argparse.Namespace`` stored in a checkpoint.

    Accepts Namespace, plain dict, or ``types.SimpleNamespace`` payloads under
    the ``"args"`` key; returns ``None`` when the key is absent or the payload
    has an unrecognized type.
    """
    stored = ckpt.get("args")
    if isinstance(stored, argparse.Namespace):
        return stored
    if isinstance(stored, dict):
        return argparse.Namespace(**stored)
    if isinstance(stored, types.SimpleNamespace):
        return argparse.Namespace(**vars(stored))
    return None
| 79 |
+
|
| 80 |
+
|
def load_vae(device: torch.device):
    """Load the SD VAE (``sd-vae-ft-mse``), preferring local caches over the Hub.

    Resolution order:
      1. the project dnnlib cache with ``local_files_only=True``;
      2. a scan of common cache directories for an existing snapshot
         (a directory containing ``config.json`` with ``sd-vae-ft-mse`` in
         its path);
      3. a plain ``from_pretrained`` download as the last resort.
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            pass  # not in the dnnlib cache; scan other locations below

        home = os.path.expanduser("~")
        search_roots = [
            cache_dir,
            os.path.join(home, ".cache", "dnnlib", "diffusers"),
            os.path.join(home, ".cache", "diffusers"),
            os.path.join(home, ".cache", "huggingface", "hub"),
        ]
        snapshot = None
        for root_dir in search_roots:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    snapshot = root
                    break
            if snapshot is not None:
                break
        if snapshot is not None:
            vae = AutoencoderKL.from_pretrained(snapshot, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        pass  # any failure above falls through to a regular Hub download
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by training args `ta` on `device`.

    Missing attributes on `ta` fall back to the training defaults.
    Returns (model, semantic_dim) where semantic_dim is the cls-latent width
    implied by the encoder type.
    """
    resolution = int(getattr(ta, "resolution", 256))
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    semantic_dim = semantic_dim_from_enc_type(enc_type)
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=resolution // 8,  # VAE downsamples the image by 8x
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=cfg_prob > 0,
        z_dims=[semantic_dim],
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        fused_attn=getattr(ta, "fused_attn", True),
        qk_norm=getattr(ta, "qk_norm", False),
    ).to(device)
    return model, semantic_dim
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def resolve_tc_schedule(cli, ta):
    """Resolve the segmented-sampling schedule (t_c, steps_before, steps_after).

    Returns (None, None, None) when neither split flag was given. When only
    one of the two step counts is given, or t_c cannot be resolved from the
    CLI or the checkpoint args, prints an error and exits the process.
    """
    before = cli.steps_before_tc
    after = cli.steps_after_tc
    t_c = cli.t_c
    if before is None and after is None:
        # No split requested: caller uses the uniform grid.
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    if t_c is None:
        # Fall back to the t_c recorded in the checkpoint's training args.
        t_c = getattr(ta, "t_c", None) if ta is not None else None
    if t_c is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(t_c), int(before), int(after)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def parse_cli():
    """Parse the command-line arguments for DDP checkpoint sampling.

    Architecture flags default to None so that values stored in the
    checkpoint's `args` take precedence unless explicitly overridden on the
    command line (see `main`).
    """
    p = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    # Required I/O and run setup.
    p.add_argument("--ckpt", type=str, required=True)
    p.add_argument("--out-dir", type=str, required=True)
    p.add_argument("--num-images", type=int, required=True)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--weights", type=str, choices=("ema", "model"), default="ema")
    p.add_argument("--device", type=str, default="cuda")
    # Time grid: either uniform (--num-steps) or split at --t-c, in which case
    # both --steps-before-tc and --steps-after-tc are required.
    p.add_argument("--num-steps", type=int, default=50)
    p.add_argument("--t-c", type=float, default=None)
    p.add_argument("--steps-before-tc", type=int, default=None)
    p.add_argument("--steps-after-tc", type=int, default=None)
    # Classifier-free guidance configuration.
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0)
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"])
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Architecture overrides; None means "use the checkpoint's training args".
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    p.add_argument("--num-classes", type=int, default=1000)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    p.add_argument(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return p.parse_args()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
|
| 213 |
+
imgs = vae.decode((latents - latents_bias) / latents_scale).sample
|
| 214 |
+
imgs = (imgs + 1) / 2.0
|
| 215 |
+
imgs = torch.clamp(imgs, 0, 1)
|
| 216 |
+
return (
|
| 217 |
+
(imgs * 255.0)
|
| 218 |
+
.round()
|
| 219 |
+
.to(torch.uint8)
|
| 220 |
+
.permute(0, 2, 3, 1)
|
| 221 |
+
.cpu()
|
| 222 |
+
.numpy()
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def init_ddp():
    """Initialise torch.distributed from torchrun-style environment variables.

    Returns (use_ddp, rank, world_size, local_rank). When RANK/WORLD_SIZE are
    absent, returns the single-process configuration (False, 0, 1, 0) without
    touching the process group.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
    """Entry point: sample images from a REG checkpoint, optionally under DDP.

    Loads the checkpoint and its training args, builds the model and VAE,
    runs the selected sampler in fixed-size batches across ranks, writes PNGs
    named by global index, and (rank 0, optional) bundles them into an .npz.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()

    if torch.cuda.is_available():
        # Under DDP each process pins its own GPU; otherwise honour --device.
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")

    try:
        # weights_only=False is needed to unpickle the stored argparse args;
        # older torch versions don't accept the keyword, hence the fallback.
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # No args in the checkpoint: the essential architecture flags must
        # come from the command line.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # CLI flags, when given, override the checkpoint's training args.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    tc_split = resolve_tc_schedule(cli, ta)  # (t_c, steps_before, steps_after) or (None, None, None)

    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Select the sampler implementation; the fall-through default is
    # euler_maruyama_image_noise_sampler (for --sampler=em_image_noise).
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights  # "ema" or "model"
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # Standard SD VAE latent scaling factor (0.18215), zero bias.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    if cli.save_fixed_trajectory and cli.sampler != "em":
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)

    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    global_batch_size = b * world_size
    # Round the total up to a multiple of the global batch; extra samples are
    # generated but not written (see the gidx < n_total guard below).
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)

    # Per-rank seed so ranks draw disjoint noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)

    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0
    written_local = 0
    for _ in pbar:
        cur = b
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]

            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches below are identical; the split
                # looks like a leftover from a sampler-specific call signature.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None

        latents = latents.to(torch.float32)
        imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
        for i, img in enumerate(imgs):
            # Interleaved global index: sample i of this batch on this rank.
            gidx = i * world_size + rank + total
            if gidx < n_total:
                Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                written_local += 1

        if traj is not None and traj_dir is not None:
            # Save the trajectory of the first sample of this batch only.
            traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
            first_idx = rank + total
            if first_idx < n_total:
                np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)

        total += global_batch_size
        if use_ddp:
            dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()

    if use_ddp:
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
| 416 |
+
|
back/samplers.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def expand_t_like_x(t, x_cur):
    """Reshape a [batch] time vector so it broadcasts against x_cur.

    Args:
        t: [batch_dim,] time vector
        x_cur: [batch_dim, ...] data point

    Returns:
        t viewed as [batch_dim, 1, ..., 1] with one singleton axis per
        trailing dimension of x_cur.
    """
    trailing = x_cur.dim() - 1
    return t.view(t.size(0), *([1] * trailing))
|
| 14 |
+
|
| 15 |
+
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Wrapper function: transform a velocity-model prediction into a score.

    Args:
        vt: [batch_dim, ...] shaped tensor; velocity model output
        xt: [batch_dim, ...] shaped tensor; x_t data point
        t: [batch_dim,] time tensor
        path_type: interpolation path, "linear" or "cosine"

    Returns:
        Score estimate with the same shape as `xt`.

    Note: the variance term vanishes at t == 0, so callers must not evaluate
    this at exactly t = 0 (the samplers stop at a small t_floor instead).
    """
    # Inline of expand_t_like_x: view t as [B, 1, ..., 1] for broadcasting.
    t = t.view(t.size(0), *([1] * (xt.dim() - 1)))
    if path_type == "linear":
        # alpha_t = 1 - t, sigma_t = t; their derivatives are the constants
        # -1 and 1, so plain scalars broadcast without allocating xt-sized
        # tensors of ones (the original did torch.ones_like(xt) * -1).
        alpha_t, d_alpha_t = 1 - t, -1.0
        sigma_t, d_sigma_t = t, 1.0
    elif path_type == "cosine":
        alpha_t = torch.cos(t * np.pi / 2)
        sigma_t = torch.sin(t * np.pi / 2)
        d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
        d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
    else:
        raise NotImplementedError(f"unknown path_type: {path_type!r}")

    mean = xt
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    score = (reverse_alpha_ratio * vt - mean) / var

    return score
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def compute_diffusion(t_cur):
    """Diffusion coefficient w(t) = 2t used by the SDE samplers."""
    return t_cur * 2
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build the sampling time grid running from t=1 down to t=0.

    Default mode: a uniform linspace(1, t_floor, num_steps) with a final 0
    appended (the last segment to 0 keeps t_floor as its start).

    Segmented mode (all of t_c / num_steps_before_tc / num_steps_after_tc
    given): num_steps_before_tc steps from 1 to t_c, then num_steps_after_tc
    steps from t_c to t_floor, with a final 0 appended.
    """
    t_floor = float(t_floor)
    terminal_zero = torch.tensor([0.0], dtype=torch.float64)

    if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
        ns = int(num_steps)
        if ns < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, t_floor, ns, dtype=torch.float64)
        return torch.cat([grid, terminal_zero])

    t_c = float(t_c)
    nb = int(num_steps_before_tc)
    na = int(num_steps_after_tc)
    if nb < 1 or na < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not 0.0 < t_c < 1.0:
        raise ValueError("t_c must be in (0, 1)")
    if t_c <= t_floor:
        raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})")

    head = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64)
    tail = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64)
    # tail[0] duplicates t_c, so drop it before concatenating.
    return torch.cat([head, tail[1:], terminal_zero])
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc):
|
| 85 |
+
"""仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。"""
|
| 86 |
+
return (
|
| 87 |
+
t_c is not None
|
| 88 |
+
and num_steps_before_tc is not None
|
| 89 |
+
and num_steps_after_tc is not None
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _cls_effective_and_freeze(
|
| 94 |
+
cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
|
| 95 |
+
):
|
| 96 |
+
"""
|
| 97 |
+
时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。
|
| 98 |
+
返回 (用于前向的 cls, 更新后的 cls_frozen)。
|
| 99 |
+
"""
|
| 100 |
+
if not freeze_after_tc or t_c_v is None:
|
| 101 |
+
return cls_x_cur, cls_frozen
|
| 102 |
+
if float(t_cur) <= float(t_c_v) + 1e-9:
|
| 103 |
+
if cls_frozen is None:
|
| 104 |
+
cls_frozen = cls_x_cur.clone()
|
| 105 |
+
return cls_frozen, cls_frozen
|
| 106 |
+
return cls_x_cur, cls_frozen
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _build_euler_sampler_time_steps(
|
| 110 |
+
num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
|
| 111 |
+
):
|
| 112 |
+
"""
|
| 113 |
+
euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。
|
| 114 |
+
"""
|
| 115 |
+
if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
|
| 116 |
+
ns = int(num_steps)
|
| 117 |
+
if ns < 1:
|
| 118 |
+
raise ValueError("num_steps must be >= 1")
|
| 119 |
+
return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device)
|
| 120 |
+
t_c = float(t_c)
|
| 121 |
+
nb = int(num_steps_before_tc)
|
| 122 |
+
na = int(num_steps_after_tc)
|
| 123 |
+
if nb < 1 or na < 1:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c"
|
| 126 |
+
)
|
| 127 |
+
if not (0.0 < t_c < 1.0):
|
| 128 |
+
raise ValueError("t_c must be in (0, 1)")
|
| 129 |
+
p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device)
|
| 130 |
+
p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device)
|
| 131 |
+
return torch.cat([p1, p2[1:]])
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """Euler–Maruyama SDE sampler for image latents and cls latents jointly.

    The drift and the score/velocity transform match euler_ode_sampler
    (euler_sampler); deterministic=True disables the diffusion noise term.
    The ODE sampler uses euler_sampler's linspace(1→0) / t_c segmented grid
    (no t_floor), whereas this function uses build_sampling_time_steps
    (with t_floor) to stay aligned with the EM/SDE convention.

    When the segmented grid (t_c plus both step counts) is active, the cls
    latents are frozen at their value on arrival at t_c for the remainder
    of the trajectory.

    Returns mean_x, optionally with (z_mid, cls_mid) when return_mid_state
    and/or the list of intermediate states when return_trajectory.
    """
    # setup conditioning: the null class id (1000) is only needed under CFG.
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
        # [1000, 1000]
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Stochastic EM steps over all but the final interval.
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur  # negative: time runs 1 → 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the state nearest t_mid (pre-step side) when requested.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Classifier-free guidance: double the batch (cond + uncond)
            # only inside the [guidance_low, guidance_high] time window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increments (zeroed under deterministic mode); the cls
            # noise is sized on the un-doubled batch.
            if deterministic:
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # compute drift
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: d = v - 0.5 * w(t) * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                # cls guidance uses its own scale; 0 means "conditional only".
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if return_trajectory:
                traj.append(x_next.clone())

            # Capture the state nearest t_mid (post-step side) when requested.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # last step: deterministic mean step (no noise) down to t = 0.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)


        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur  # d_cur [b, 4, 32 ,32]

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

        if return_trajectory:
            traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, kept for API compatibility with sibling samplers
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    Euler–Maruyama sampling variant: only the image latent receives stochastic
    diffusion noise; the cls/token channel is integrated deterministically
    (drift only).

    Parameters mirror the other samplers in this module:
      model            : SiT-style model returning (v, aux, cls_v).
      latents          : initial image latent noise, shape [B, C, H, W].
      y                : class labels, shape [B]; label 1000 is the null class for CFG.
      cfg_scale        : classifier-free guidance scale (>1 enables CFG within
                         [guidance_low, guidance_high]).
      cls_latents      : initial cls-token latent (required by REG SiT).
      args             : optional namespace; `args.cls_cfg_scale` controls cls CFG.
      return_mid_state : also return the state closest to t_mid.
      t_c / num_steps_before_tc / num_steps_after_tc : optional segmented
                         time grid 1→t_c→0; when active, cls is frozen for t<=t_c.
      return_trajectory: also return the list of intermediate image states.

    Returns mean_x, optionally followed by (z_mid, cls_mid) and/or traj
    depending on the return_* flags.
    """
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # All steps except the last are stochastic (EM); the final step is
        # deterministic, matching the sibling samplers.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Effective cls input; freezes (and caches) cls once t <= t_c.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                # Capture the pre-step state only when t_cur is the closer endpoint.
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for conditional/unconditional CFG passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # FIX: the original code gated this on an undefined name
            # `add_img_noise` (never assigned in this function; it exists only
            # in euler_maruyama_image_noise_before_tc_sampler), raising a
            # NameError on the first step. This sampler injects image noise on
            # every step, so the score-corrected SDE drift is always used.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            # Image latent gets the stochastic diffusion term; cls is drift-only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # ---- Final deterministic step (no stochastic term, d = v like the ODE).
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    # NOTE(review): the final mean_x is NOT appended to traj here, unlike some
    # sibling samplers — confirm whether trajectory consumers expect it.
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """
    Euler–Maruyama sampling variant with t_c-gated image noise:
    - the image latent receives stochastic diffusion noise while t > t_c;
    - the image latent is integrated deterministically (drift only) for t <= t_c;
    - the cls/token channel never receives a stochastic term.

    Optionally returns the mid-trajectory state nearest t_mid
    (return_mid_state), the final cls state (return_cls_final), and/or the
    full list of image states (return_trajectory).
    """
    if cfg_scale > 1.0:
        # Label 1000 is the null class used for the unconditional CFG pass.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    # cls CFG scale; 0 (or missing args) means "use conditional cls drift only".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    # cls is frozen for t <= t_c only when the segmented grid is active.
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_freeze = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    # t_c_v gates the image-noise switch; unlike t_c_freeze it is set whenever
    # t_c is given, even without the segmented grid.
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # All steps except the last; the final step is handled separately below
        # and is always deterministic.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Effective cls input; freezes (and caches) cls once t <= t_c.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                # Capture the pre-step state only when t_cur is the closer endpoint;
                # otherwise the post-step capture at the bottom of the loop fires.
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for conditional/unconditional CFG passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Disable image stochasticity once the step lands in t <= t_c;
            # keep image noise while t > t_c.
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False

            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            if add_img_noise:
                # Stochastic segment: use the score-corrected SDE drift.
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur

                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                # Deterministic segment (t <= t_c): explicit Euler with the raw
                # velocity drift, matching the ODE logic (no corrected drift term).
                d_cur = v_cur
                cls_d_cur = cls_v_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            # Image latent: drift + (possibly zero) stochastic term; cls: drift only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # ---- Final step (always deterministic).
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        # Last step has no stochastic term; use d = v, consistent with the ODE.
        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    # Return shape depends on the combination of return_* flags.
    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    ODE entry point for REG: decoupled from the SDE samplers, it simply
    delegates to ``euler_sampler`` (linspace 1->0, or the segmented 1->t_c->0
    grid; no t_floor).
    """
    sampler_kwargs = {
        "num_steps": num_steps,
        "heun": False,  # placeholder only; euler_sampler ignores it
        "cfg_scale": cfg_scale,
        "guidance_low": guidance_low,
        "guidance_high": guidance_high,
        "path_type": path_type,
        "cls_latents": cls_latents,
        "args": args,
        "return_mid_state": return_mid_state,
        "t_mid": t_mid,
        "t_c": t_c,
        "num_steps_before_tc": num_steps_before_tc,
        "num_steps_after_tc": num_steps_after_tc,
        "return_trajectory": return_trajectory,
    }
    return euler_sampler(model, latents, y, **sampler_kwargs)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    Lightweight deterministic (drift-only) sampler. Its leading parameters
    (model, latents, y, num_steps, heun, cfg, guidance, path_type,
    cls_latents, args) are kept name-compatible with the glflow sampler.

    - Default time grid: linspace(1, 0, num_steps + 1) with no t_floor
      (matching the original standalone ODE sampler).
    - Optional: when t_c, num_steps_before_tc, and num_steps_after_tc are all
      given, the grid is 1 -> t_c -> 0 and, consistent with the EM samplers,
      the cls channel is frozen for t <= t_c.
    - Optional: return_mid_state / return_trajectory, used by train.py and
      sample_from_checkpoint.

    The REG SiT model requires a cls_token, so cls_latents must not be None.
    `heun` is accepted but unused.
    """
    if cls_latents is None:
        raise ValueError(
            "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
        )
    if cfg_scale > 1.0:
        # Label 1000 is the null class for the unconditional CFG pass.
        y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
    else:
        y_null = None
    _dtype = latents.dtype
    # cls CFG scale; 0 (or missing args) means "use conditional cls drift only".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    device = latents.device

    t_steps = _build_euler_sampler_time_steps(
        num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
    )
    # cls is frozen for t <= t_c only when the segmented grid is active.
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None

    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Effective cls input; freezes (and caches) cls once t <= t_c.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                # Capture the pre-step state only when t_cur is the closer endpoint;
                # otherwise the post-step capture at the loop bottom fires.
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for conditional/unconditional CFG passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype),
                time_input.to(dtype=_dtype),
                y_cur,
                cls_token=cls_model_input.to(dtype=_dtype),
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # ODE: follow velocity parameterization directly (d/dt x_t = v_t).
            # This aligns with velocity training target and avoids extra v->score->drift conversion.
            d_cur = v_cur
            cls_d_cur = cls_v_cur

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            # Explicit Euler step; cls is held at its frozen value for t <= t_c.
            x_next = x_cur + dt * d_cur
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + dt * cls_d_cur

            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

    # Return shape depends on the combination of return_* flags.
    if return_trajectory and return_mid_state:
        return x_next, z_mid, cls_mid, traj
    if return_trajectory:
        return x_next, traj
    if return_mid_state:
        return x_next, z_mid, cls_mid
    return x_next
|
back/samples_0.75.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
back/samples_0.75_new.log
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
时间网格:t_c=0.75, 步数 (1→t_c)=100, (t_c→0)=50
|
| 6 |
+
Total number of images that will be sampled: 40192
|
| 7 |
+
|
| 8 |
+
[rank3]:[W325 16:57:00.799344818 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 9 |
+
[rank1]:[W325 16:57:00.847229448 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 10 |
+
[rank0]:[W325 16:57:01.326116049 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 11 |
+
|
conditional-flow-matching/.gitignore
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.venv
|
| 106 |
+
env/
|
| 107 |
+
venv/
|
| 108 |
+
ENV/
|
| 109 |
+
env.bak/
|
| 110 |
+
venv.bak/
|
| 111 |
+
|
| 112 |
+
# Spyder project settings
|
| 113 |
+
.spyderproject
|
| 114 |
+
.spyproject
|
| 115 |
+
|
| 116 |
+
# Rope project settings
|
| 117 |
+
.ropeproject
|
| 118 |
+
|
| 119 |
+
# mkdocs documentation
|
| 120 |
+
/site
|
| 121 |
+
|
| 122 |
+
# mypy
|
| 123 |
+
.mypy_cache/
|
| 124 |
+
.dmypy.json
|
| 125 |
+
dmypy.json
|
| 126 |
+
|
| 127 |
+
# Pyre type checker
|
| 128 |
+
.pyre/
|
| 129 |
+
|
| 130 |
+
### VisualStudioCode
|
| 131 |
+
.vscode/*
|
| 132 |
+
!.vscode/settings.json
|
| 133 |
+
!.vscode/tasks.json
|
| 134 |
+
!.vscode/launch.json
|
| 135 |
+
!.vscode/extensions.json
|
| 136 |
+
*.code-workspace
|
| 137 |
+
**/.vscode
|
| 138 |
+
|
| 139 |
+
# JetBrains
|
| 140 |
+
.idea/
|
| 141 |
+
|
| 142 |
+
# Lightning-Hydra-Template
|
| 143 |
+
configs/local/default.yaml
|
| 144 |
+
data/
|
| 145 |
+
logs/
|
| 146 |
+
wandb/
|
| 147 |
+
.env
|
| 148 |
+
.autoenv
|
| 149 |
+
|
| 150 |
+
#Vim
|
| 151 |
+
*.sw?
|
| 152 |
+
|
| 153 |
+
# Slurm
|
| 154 |
+
slurm*.out
|
| 155 |
+
|
| 156 |
+
# Data and models
|
| 157 |
+
*.pt
|
| 158 |
+
*.h5
|
| 159 |
+
*.h5ad
|
| 160 |
+
*.tar
|
| 161 |
+
*.tar.gz
|
| 162 |
+
*.pkl
|
| 163 |
+
*.npy
|
| 164 |
+
*.npz
|
| 165 |
+
*.csv
|
| 166 |
+
|
| 167 |
+
# Images
|
| 168 |
+
*.png
|
| 169 |
+
*.svg
|
| 170 |
+
*.gif
|
| 171 |
+
*.jpg
|
| 172 |
+
|
| 173 |
+
notebooks/figures/
|
| 174 |
+
|
| 175 |
+
.DS_Store
|
conditional-flow-matching/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.6.0
|
| 4 |
+
hooks:
|
| 5 |
+
# list of supported hooks: https://pre-commit.com/hooks.html
|
| 6 |
+
- id: trailing-whitespace
|
| 7 |
+
exclude: .svg$
|
| 8 |
+
require_serial: true
|
| 9 |
+
- id: end-of-file-fixer
|
| 10 |
+
require_serial: true
|
| 11 |
+
- id: check-docstring-first
|
| 12 |
+
require_serial: true
|
| 13 |
+
- id: check-yaml
|
| 14 |
+
require_serial: true
|
| 15 |
+
- id: debug-statements
|
| 16 |
+
require_serial: true
|
| 17 |
+
- id: detect-private-key
|
| 18 |
+
require_serial: true
|
| 19 |
+
- id: check-executables-have-shebangs
|
| 20 |
+
require_serial: true
|
| 21 |
+
- id: check-toml
|
| 22 |
+
require_serial: true
|
| 23 |
+
- id: check-case-conflict
|
| 24 |
+
require_serial: true
|
| 25 |
+
- id: check-added-large-files
|
| 26 |
+
require_serial: true
|
| 27 |
+
|
| 28 |
+
# python upgrading syntax to newer version
|
| 29 |
+
- repo: https://github.com/asottile/pyupgrade
|
| 30 |
+
rev: v3.17.0
|
| 31 |
+
hooks:
|
| 32 |
+
- id: pyupgrade
|
| 33 |
+
require_serial: true
|
| 34 |
+
args: [--py38-plus]
|
| 35 |
+
|
| 36 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 37 |
+
rev: v0.5.5
|
| 38 |
+
hooks:
|
| 39 |
+
- id: ruff
|
| 40 |
+
args: [--fix]
|
| 41 |
+
- id: ruff-format
|
| 42 |
+
|
| 43 |
+
# python security linter
|
| 44 |
+
- repo: https://github.com/gitleaks/gitleaks
|
| 45 |
+
rev: v8.18.2
|
| 46 |
+
hooks:
|
| 47 |
+
- id: gitleaks
|
| 48 |
+
|
| 49 |
+
# yaml formatting
|
| 50 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
| 51 |
+
rev: v3.0.0
|
| 52 |
+
hooks:
|
| 53 |
+
- id: prettier
|
| 54 |
+
require_serial: true
|
| 55 |
+
types: [yaml]
|
| 56 |
+
|
| 57 |
+
# shell scripts linter
|
| 58 |
+
- repo: https://github.com/shellcheck-py/shellcheck-py
|
| 59 |
+
rev: v0.9.0.5
|
| 60 |
+
hooks:
|
| 61 |
+
- id: shellcheck
|
| 62 |
+
require_serial: true
|
| 63 |
+
args: ["-e", "SC2102"]
|
| 64 |
+
|
| 65 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
| 66 |
+
rev: v4.0.0-alpha.8
|
| 67 |
+
hooks:
|
| 68 |
+
- id: prettier
|
| 69 |
+
# To avoid conflicts, tell prettier to ignore file types
|
| 70 |
+
# that ruff already handles.
|
| 71 |
+
exclude_types: [python]
|
| 72 |
+
|
| 73 |
+
# word spelling linter
|
| 74 |
+
- repo: https://github.com/codespell-project/codespell
|
| 75 |
+
rev: v2.2.5
|
| 76 |
+
hooks:
|
| 77 |
+
- id: codespell
|
| 78 |
+
require_serial: true
|
| 79 |
+
args:
|
| 80 |
+
- --skip=logs/**,data/**,*.ipynb
|
| 81 |
+
- --ignore-words-list=ot,hist
|
| 82 |
+
|
| 83 |
+
# jupyter notebook linting
|
| 84 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
| 85 |
+
rev: 1.7.0
|
| 86 |
+
hooks:
|
| 87 |
+
- id: nbqa-black
|
| 88 |
+
args: ["--line-length=99"]
|
| 89 |
+
require_serial: true
|
| 90 |
+
- id: nbqa-isort
|
| 91 |
+
args: ["--profile=black"]
|
| 92 |
+
require_serial: true
|
| 93 |
+
- id: nbqa-flake8
|
| 94 |
+
args:
|
| 95 |
+
[
|
| 96 |
+
"--extend-ignore=E203,E402,E501,F401,F841,F821,F403,F405,F811",
|
| 97 |
+
"--exclude=logs/*,data/*,notebooks/*",
|
| 98 |
+
]
|
| 99 |
+
require_serial: true
|
conditional-flow-matching/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Alexander Tong
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
conditional-flow-matching/README.md
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# TorchCFM: a Conditional Flow Matching library
|
| 4 |
+
|
| 5 |
+
<!---[](https://papers.nips.cc/paper/2020) -->
|
| 6 |
+
|
| 7 |
+
<!---[](https://github.com/atong01/conditional-flow-matching/graphs/contributors) -->
|
| 8 |
+
|
| 9 |
+
[](https://arxiv.org/abs/2302.00482)
|
| 10 |
+
[](https://arxiv.org/abs/2307.03672)
|
| 11 |
+
[](https://pytorch.org/get-started/locally/)
|
| 12 |
+
[](https://pytorchlightning.ai/)
|
| 13 |
+
[](https://hydra.cc/)
|
| 14 |
+
[](https://black.readthedocs.io/en/stable/)
|
| 15 |
+
[](https://github.com/pre-commit/pre-commit)
|
| 16 |
+
[](https://github.com/atong01/conditional-flow-matching/actions/workflows/test.yaml)
|
| 17 |
+
[](https://codecov.io/gh/atong01/conditional-flow-matching/)
|
| 18 |
+
[](https://github.com/atong01/conditional-flow-matching/actions/workflows/code-quality-main.yaml)
|
| 19 |
+
[](https://github.com/atong01/conditional-flow-matching#license)
|
| 20 |
+
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a>
|
| 21 |
+
[](https://pepy.tech/project/torchcfm)
|
| 22 |
+
[](https://pepy.tech/project/torchcfm)
|
| 23 |
+
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
## Description
|
| 27 |
+
|
| 28 |
+
Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference. CFM's performance closes the gap between CNFs and diffusion models. To spread its use within the machine learning community, we have built a library focused on Flow Matching methods: TorchCFM. TorchCFM is a library showing how Flow Matching methods can be trained and used to deal with image generation, single-cell dynamics, tabular data and soon SO(3) data.
|
| 29 |
+
|
| 30 |
+
<p align="center">
|
| 31 |
+
<img src="assets/169_generated_samples_otcfm.png" width="600"/>
|
| 32 |
+
<img src="assets/8gaussians-to-moons.gif" />
|
| 33 |
+
</p>
|
| 34 |
+
|
| 35 |
+
The density, vector field, and trajectories of simulation-free CNF training schemes: mapping 8 Gaussians to two moons (above) and a single Gaussian to two moons (below). Action matching with the same architecture (3x64 MLP with SeLU activations) underfits with the ReLU, SeLU, and SiLU activations as suggested in the [example code](https://github.com/necludov/jam), but it seems to fit better under our training setup (Action-Matching (Swish)).
|
| 36 |
+
|
| 37 |
+
The models to produce the GIFs are stored in `examples/models` and can be visualized with this notebook: [](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/model-comparison-plotting.ipynb).
|
| 38 |
+
|
| 39 |
+
We also have included an example of unconditional MNIST generation in `examples/notebooks/mnist_example.ipynb` for both deterministic and stochastic generation. [](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/mnist_example.ipynb).
|
| 40 |
+
|
| 41 |
+
## The torchcfm Package
|
| 42 |
+
|
| 43 |
+
In our version 1 update we have extracted implementations of the relevant flow matching variants into a package `torchcfm`. This allows abstraction of the choice of the conditional distribution `q(z)`. `torchcfm` supplies the following loss functions:
|
| 44 |
+
|
| 45 |
+
- `ConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = q(x_0) q(x_1)$
|
| 46 |
+
- `ExactOptimalTransportConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi(x_0, x_1)$ where $\\pi$ is an exact optimal transport joint. This is used in \[Tong et al. 2023a\] and \[Pooladian et al. 2023\] as "OT-CFM" and "Multisample FM with Batch OT" respectively.
|
| 47 |
+
- `TargetConditionalFlowMatcher`: $z = x_1$, $q(z) = q(x_1)$ as defined in Lipman et al. 2023, learns a flow from a standard normal Gaussian to data using conditional flows which optimally transport the Gaussian to the datapoint (Note that this does not result in the marginal flow being optimal transport).
|
| 48 |
+
- `SchrodingerBridgeConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi\_\\epsilon(x_0, x_1)$ where $\\pi\_\\epsilon$ is an entropically regularized OT plan, although in practice this is often approximated by a minibatch OT plan (See Tong et al. 2023b). The flow-matching variant of this where the marginals are equivalent to the Schrodinger Bridge marginals is known as `SB-CFM` \[Tong et al. 2023a\]. When the score is also known and the bridge is stochastic is called \[SF\]2M \[Tong et al. 2023b\]
|
| 49 |
+
- `VariancePreservingConditionalFlowMatcher`: $z = (x_0, x_1)$ $q(z) = q(x_0) q(x_1)$ but with conditional Gaussian probability paths which preserve variance over time using a trigonometric interpolation as presented in \[Albergo et al. 2023a\].
|
| 50 |
+
|
| 51 |
+
## How to cite
|
| 52 |
+
|
| 53 |
+
This repository contains the code to reproduce the main experiments and illustrations of two preprints:
|
| 54 |
+
|
| 55 |
+
- [Improving and generalizing flow-based generative models with minibatch optimal transport](https://arxiv.org/abs/2302.00482). We introduce **Optimal Transport Conditional Flow Matching** (OT-CFM), a CFM variant that approximates the dynamical formulation of optimal transport (OT). Based on OT theory, OT-CFM leverages the static optimal transport plan as well as the optimal probability paths and vector fields to approximate dynamic OT.
|
| 56 |
+
- [Simulation-free Schrödinger bridges via score and flow matching](https://arxiv.org/abs/2307.03672). We propose **Simulation-Free Score and Flow Matching** (\[SF\]<sup>2</sup>M). \[SF\]<sup>2</sup>M leverages OT-CFM as well as score-based methods to approximate Schrödinger bridges, a stochastic version of optimal transport.
|
| 57 |
+
|
| 58 |
+
If you find this code useful in your research, please cite the following papers (expand for BibTeX):
|
| 59 |
+
|
| 60 |
+
<details>
|
| 61 |
+
<summary>
|
| 62 |
+
A. Tong, N. Malkin, G. Huguet, Y. Zhang, J. Rector-Brooks, K. Fatras, G. Wolf, Y. Bengio. Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport, 2023.
|
| 63 |
+
</summary>
|
| 64 |
+
|
| 65 |
+
```bibtex
|
| 66 |
+
@article{tong2024improving,
|
| 67 |
+
title={Improving and generalizing flow-based generative models with minibatch optimal transport},
|
| 68 |
+
author={Alexander Tong and Kilian FATRAS and Nikolay Malkin and Guillaume Huguet and Yanlei Zhang and Jarrid Rector-Brooks and Guy Wolf and Yoshua Bengio},
|
| 69 |
+
journal={Transactions on Machine Learning Research},
|
| 70 |
+
issn={2835-8856},
|
| 71 |
+
year={2024},
|
| 72 |
+
url={https://openreview.net/forum?id=CD9Snc73AW},
|
| 73 |
+
note={Expert Certification}
|
| 74 |
+
}
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
</details>
|
| 78 |
+
|
| 79 |
+
<details>
|
| 80 |
+
<summary>
|
| 81 |
+
A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y. Bengio. Simulation-Free Schrödinger Bridges via Score and Flow Matching, 2023.
|
| 82 |
+
</summary>
|
| 83 |
+
|
| 84 |
+
```bibtex
|
| 85 |
+
@article{tong2023simulation,
|
| 86 |
+
title={Simulation-Free Schr{\"o}dinger Bridges via Score and Flow Matching},
|
| 87 |
+
author={Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua},
|
| 88 |
+
year={2023},
|
| 89 |
+
journal={arXiv preprint 2307.03672}
|
| 90 |
+
}
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
</details>
|
| 94 |
+
|
| 95 |
+
## V0 -> V1
|
| 96 |
+
|
| 97 |
+
Major Changes:
|
| 98 |
+
|
| 99 |
+
- **Added cifar10 examples with an FID of 3.5**
|
| 100 |
+
- Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint
|
| 101 |
+
- Created `torchcfm` pip installable package
|
| 102 |
+
- Moved `pytorch-lightning` implementation and experiments to `runner` directory
|
| 103 |
+
- Moved `notebooks` -> `examples`
|
| 104 |
+
- Added image generation implementation in both lightning and a notebook in `examples`
|
| 105 |
+
|
| 106 |
+
## Implemented papers
|
| 107 |
+
|
| 108 |
+
List of implemented papers:
|
| 109 |
+
|
| 110 |
+
- Flow Matching for Generative Modeling (Lipman et al. 2023) [Paper](https://openreview.net/forum?id=PqvMRDCJT9t)
|
| 111 |
+
- Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow (Liu et al. 2023) [Paper](https://openreview.net/forum?id=XVjTT1nw5z) [Code](https://github.com/gnobitab/RectifiedFlow.git)
|
| 112 |
+
- Building Normalizing Flows with Stochastic Interpolants (Albergo et al. 2023a) [Paper](https://openreview.net/forum?id=li7qeBbCR1t)
|
| 113 |
+
- Action Matching: Learning Stochastic Dynamics From Samples (Neklyudov et al. 2022) [Paper](https://arxiv.org/abs/2210.06662) [Code](https://github.com/necludov/jam)
|
| 114 |
+
- Concurrent work to our OT-CFM method: Multisample Flow Matching: Straightening Flows with Minibatch Couplings (Pooladian et al. 2023) [Paper](https://arxiv.org/abs/2304.14772)
|
| 115 |
+
- Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (Jolicoeur-Martineau et al.) [Paper](https://arxiv.org/abs/2309.09968) [Code](https://github.com/SamsungSAILMontreal/ForestDiffusion)
|
| 116 |
+
- Soon: SE(3)-Stochastic Flow Matching for Protein Backbone Generation (Bose et al.) [Paper](https://arxiv.org/abs/2310.02391)
|
| 117 |
+
|
| 118 |
+
## How to run
|
| 119 |
+
|
| 120 |
+
Run a simple minimal example here [](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/training-8gaussians-to-moons.ipynb). Or install the more efficient code locally with these steps.
|
| 121 |
+
|
| 122 |
+
TorchCFM is now on [pypi](https://pypi.org/project/torchcfm/)! You can install it with:
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
pip install torchcfm
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
To use the full library with the different examples, you can install dependencies:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
# clone project
|
| 132 |
+
git clone https://github.com/atong01/conditional-flow-matching.git
|
| 133 |
+
cd conditional-flow-matching
|
| 134 |
+
|
| 135 |
+
# [OPTIONAL] create conda environment
|
| 136 |
+
conda create -n torchcfm python=3.10
|
| 137 |
+
conda activate torchcfm
|
| 138 |
+
|
| 139 |
+
# install pytorch according to instructions
|
| 140 |
+
# https://pytorch.org/get-started/
|
| 141 |
+
|
| 142 |
+
# install requirements
|
| 143 |
+
pip install -r requirements.txt
|
| 144 |
+
|
| 145 |
+
# install torchcfm
|
| 146 |
+
pip install -e .
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
To run our jupyter notebooks, use the following commands after installing our package.
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
# install ipykernel
|
| 153 |
+
conda install -c anaconda ipykernel
|
| 154 |
+
|
| 155 |
+
# install conda env in jupyter notebook
|
| 156 |
+
python -m ipykernel install --user --name=torchcfm
|
| 157 |
+
|
| 158 |
+
# launch our notebooks with the torchcfm kernel
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
## Project Structure
|
| 162 |
+
|
| 163 |
+
The directory structure looks like this:
|
| 164 |
+
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
│
|
| 168 |
+
├── examples <- Jupyter notebooks
|
| 169 |
+
| ├── cifar10 <- Cifar10 experiments
|
| 170 |
+
│ ├── notebooks <- Diverse examples with notebooks
|
| 171 |
+
│
|
| 172 |
+
│── runner <- Everything related to the original version (V0) of the library
|
| 173 |
+
│
|
| 174 |
+
|── torchcfm <- Code base of our Flow Matching methods
|
| 175 |
+
| ├── conditional_flow_matching.py <- CFM classes
|
| 176 |
+
│ ├── models <- Model architectures
|
| 177 |
+
│ │ ├── models <- Models for 2D examples
|
| 178 |
+
│ │ ├── Unet <- Unet models for image examples
|
| 179 |
+
|
|
| 180 |
+
├── .gitignore <- List of files ignored by git
|
| 181 |
+
├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
|
| 182 |
+
├── pyproject.toml <- Configuration options for testing and linting
|
| 183 |
+
├── requirements.txt <- File for installing python dependencies
|
| 184 |
+
├── setup.py <- File for installing project as a package
|
| 185 |
+
└── README.md
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
## ❤️ Code Contributions
|
| 189 |
+
|
| 190 |
+
This toolbox has been created and is maintained by
|
| 191 |
+
|
| 192 |
+
- [Alexander Tong](http://alextong.net)
|
| 193 |
+
- [Kilian Fatras](http://kilianfatras.github.io)
|
| 194 |
+
|
| 195 |
+
It was initiated from a larger private codebase which loses the original commit history which contains work from other authors of the papers.
|
| 196 |
+
|
| 197 |
+
Before making an issue, please verify that:
|
| 198 |
+
|
| 199 |
+
- The problem still exists on the current `main` branch.
|
| 200 |
+
- Your python dependencies are updated to recent versions.
|
| 201 |
+
|
| 202 |
+
Suggestions for improvements are always welcome!
|
| 203 |
+
|
| 204 |
+
## License
|
| 205 |
+
|
| 206 |
+
Conditional-Flow-Matching is licensed under the MIT License.
|
| 207 |
+
|
| 208 |
+
```
|
| 209 |
+
MIT License
|
| 210 |
+
|
| 211 |
+
Copyright (c) 2023 Alexander Tong
|
| 212 |
+
|
| 213 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 214 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 215 |
+
in the Software without restriction, including without limitation the rights
|
| 216 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 217 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 218 |
+
furnished to do so, subject to the following conditions:
|
| 219 |
+
|
| 220 |
+
The above copyright notice and this permission notice shall be included in all
|
| 221 |
+
copies or substantial portions of the Software.
|
| 222 |
+
|
| 223 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 224 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 225 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 226 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 227 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 228 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 229 |
+
SOFTWARE.
|
| 230 |
+
```
|
conditional-flow-matching/pyproject.toml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.pytest.ini_options]
|
| 2 |
+
addopts = [
|
| 3 |
+
"--color=yes",
|
| 4 |
+
"--durations=0",
|
| 5 |
+
"--strict-markers",
|
| 6 |
+
"--doctest-modules",
|
| 7 |
+
]
|
| 8 |
+
filterwarnings = [
|
| 9 |
+
"ignore::DeprecationWarning",
|
| 10 |
+
"ignore::UserWarning",
|
| 11 |
+
]
|
| 12 |
+
log_cli = "True"
|
| 13 |
+
markers = [
|
| 14 |
+
"slow: slow tests",
|
| 15 |
+
]
|
| 16 |
+
minversion = "6.0"
|
| 17 |
+
testpaths = "tests/"
|
| 18 |
+
|
| 19 |
+
[tool.coverage.report]
|
| 20 |
+
exclude_lines = [
|
| 21 |
+
"pragma: nocover",
|
| 22 |
+
"raise NotImplementedError",
|
| 23 |
+
"raise NotImplementedError()",
|
| 24 |
+
"if __name__ == .__main__.:",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[tool.ruff]
|
| 28 |
+
line-length = 99
|
| 29 |
+
|
| 30 |
+
[tool.ruff.lint]
|
| 31 |
+
ignore = ["C901", "E501", "E741", "W605", "C408", "E402"]
|
| 32 |
+
select = ["C", "E", "F", "I", "W"]
|
| 33 |
+
|
| 34 |
+
[tool.ruff.lint.per-file-ignores]
|
| 35 |
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
| 36 |
+
|
| 37 |
+
[tool.ruff.lint.isort]
|
| 38 |
+
known-first-party = ["src"]
|
| 39 |
+
known-third-party = ["torch", "transformers", "wandb"]
|
conditional-flow-matching/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.11.0
|
| 2 |
+
torchvision>=0.11.0
|
| 3 |
+
|
| 4 |
+
lightning-bolts
|
| 5 |
+
matplotlib
|
| 6 |
+
numpy
|
| 7 |
+
scipy
|
| 8 |
+
scikit-learn
|
| 9 |
+
scprep
|
| 10 |
+
scanpy
|
| 11 |
+
torchdyn>=1.0.6 # 1.0.4 is broken on pypi
|
| 12 |
+
pot
|
| 13 |
+
torchdiffeq
|
| 14 |
+
absl-py
|
| 15 |
+
clean-fid
|
conditional-flow-matching/runner-requirements.txt
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Note if using Conda it is recommended to install torch separately.
|
| 2 |
+
# For most of testing the following commands were run to set up the environment
|
| 3 |
+
# This was tested with torch==1.12.1
|
| 4 |
+
# conda create -n ti-env python=3.10
|
| 5 |
+
# conda activate ti-env
|
| 6 |
+
# pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
|
| 7 |
+
# pip install -r requirements.txt
|
| 8 |
+
# --------- pytorch --------- #
|
| 9 |
+
torch>=1.11.0,<2.0.0
|
| 10 |
+
torchvision>=0.11.0
|
| 11 |
+
pytorch-lightning==1.8.3.post2
|
| 12 |
+
torchmetrics==0.11.0
|
| 13 |
+
|
| 14 |
+
# --------- hydra --------- #
|
| 15 |
+
hydra-core==1.2.0
|
| 16 |
+
hydra-colorlog==1.2.0
|
| 17 |
+
hydra-optuna-sweeper==1.2.0
|
| 18 |
+
# hydra-submitit-launcher
|
| 19 |
+
|
| 20 |
+
# --------- loggers --------- #
|
| 21 |
+
wandb
|
| 22 |
+
# neptune-client
|
| 23 |
+
# mlflow
|
| 24 |
+
# comet-ml
|
| 25 |
+
|
| 26 |
+
# --------- others --------- #
|
| 27 |
+
black
|
| 28 |
+
isort
|
| 29 |
+
flake8
|
| 30 |
+
Flake8-pyproject # for configuration via pyproject
|
| 31 |
+
pyrootutils # standardizing the project root setup
|
| 32 |
+
pre-commit # hooks for applying linters on commit
|
| 33 |
+
rich # beautiful text formatting in terminal
|
| 34 |
+
pytest # tests
|
| 35 |
+
# sh # for running bash commands in some tests (linux/macos only)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# --------- pkg reqs -------- #
|
| 39 |
+
lightning-bolts
|
| 40 |
+
matplotlib
|
| 41 |
+
numpy
|
| 42 |
+
scipy
|
| 43 |
+
scikit-learn
|
| 44 |
+
scprep
|
| 45 |
+
scanpy
|
| 46 |
+
timm
|
| 47 |
+
torchdyn>=1.0.5 # 1.0.4 is broken on pypi
|
| 48 |
+
pot
|
| 49 |
+
|
| 50 |
+
# --------- notebook reqs -------- #
|
| 51 |
+
seaborn>=0.12.2
|
| 52 |
+
pandas>=2.2.2
|
conditional-flow-matching/setup.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from setuptools import find_packages, setup
|
| 6 |
+
|
| 7 |
+
install_requires = [
|
| 8 |
+
"torch>=1.11.0",
|
| 9 |
+
"matplotlib",
|
| 10 |
+
"numpy", # Due to pandas incompatibility
|
| 11 |
+
"scipy",
|
| 12 |
+
"scikit-learn",
|
| 13 |
+
"torchdyn>=1.0.6",
|
| 14 |
+
"pot",
|
| 15 |
+
"torchdiffeq",
|
| 16 |
+
"absl-py",
|
| 17 |
+
"pandas>=2.2.2",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
version_py = os.path.join(os.path.dirname(__file__), "torchcfm", "version.py")
|
| 21 |
+
version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()
|
| 22 |
+
readme = open("README.md", encoding="utf8").read()
|
| 23 |
+
setup(
|
| 24 |
+
name="torchcfm",
|
| 25 |
+
version=version,
|
| 26 |
+
description="Conditional Flow Matching for Fast Continuous Normalizing Flow Training.",
|
| 27 |
+
author="Alexander Tong, Kilian Fatras",
|
| 28 |
+
author_email="alexandertongdev@gmail.com",
|
| 29 |
+
url="https://github.com/atong01/conditional-flow-matching",
|
| 30 |
+
install_requires=install_requires,
|
| 31 |
+
license="MIT",
|
| 32 |
+
long_description=readme,
|
| 33 |
+
long_description_content_type="text/markdown",
|
| 34 |
+
packages=find_packages(exclude=["tests", "tests.*"]),
|
| 35 |
+
extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]},
|
| 36 |
+
)
|