xiangzai commited on
Commit
5484dca
·
verified ·
1 Parent(s): 590a56c

Add files using upload-large-folder tool

Browse files
REG copy/train.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch multi-GPU REG training (SiT + DINOv2 representation alignment) via accelerate.

NUM_GPUS=8
# Random value in [1200, 1299]; intended for a launch port (currently unused) —
# pass it as --main_process_port if port collisions occur.
random_number=$((RANDOM % 100 + 1200))

# NOTE: --encoder-depth — SiT-L/XL use 8, SiT-B uses 4.
# (The comment must NOT sit after a trailing backslash: `\ #comment` escapes the
#  space, ends the line continuation, and every following --flag would run as a
#  separate shell command.)
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="fp16" \
  --seed=0 \
  --path-type="linear" \
  --prediction="v" \
  --weighting="uniform" \
  --model="SiT-XL/2" \
  --enc-type="dinov2-vit-b" \
  --proj-coeff=0.5 \
  --encoder-depth=8 \
  --output-dir="your_path/reg_xlarge_dinov2_base_align_8_cls" \
  --exp-name="linear-dinov2-b-enc8" \
  --batch-size=256 \
  --data-dir="data_path/imagenet_vae" \
  --cls=0.03

# Dataset Path
# For example: your_path/imagenet-vae
# This folder contains two folders
# (1) The imagenet's RGB image: your_path/imagenet-vae/imagenet_256-vae/
# (2) The imagenet's VAE latent: your_path/imagenet-vae/vae-sd/
VIRTUAL_imagenet256_labeled/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
VIRTUAL_imagenet256_labeled/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: Apache License 2.0
3
+ ---
4
+ 数据集文件元信息以及数据文件,请浏览“数据集文件”页面获取。
5
+
6
+ 当前数据集卡片使用的是默认模版,数据集的贡献者未提供更加详细的数据集介绍,但是您可以通过如下GIT Clone命令,或者ModelScope SDK来下载数据集
7
+
8
+ #### 下载方法
9
+ :modelscope-code[]{type="sdk"}
10
+ :modelscope-code[]{type="git"}
11
+
__pycache__/evaluator.cpython-38.pyc ADDED
Binary file (23.5 kB). View file
 
__pycache__/fid_custom.cpython-313.pyc ADDED
Binary file (16.5 kB). View file
 
__pycache__/npz_health_check.cpython-313.pyc ADDED
Binary file (13 kB). View file
 
back/dataset.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ import PIL.Image
10
+ try:
11
+ import pyspng
12
+ except ImportError:
13
+ pyspng = None
14
+
15
+
16
class CustomDataset(Dataset):
    """Paired VAE-latent / semantic-feature dataset for REG training.

    Layout under ``data_dir``:
      * VAE latents live in ``imagenet_256_vae/``.
      * Without preprocessed semantic features: VAE statistics / pairing files
        live in ``vae-sd/`` (same layout as the original REG repo).
      * With ``semantic_features_dir`` (explicit or auto-detected): samples are
        indexed from that directory's ``dataset.json`` (as in the main repo's
        dataset), and the matching npy inside ``imagenet_256_vae`` is inferred
        from each semantic-feature file name.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        # Populate PIL's extension registry so files can be filtered by extension.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')

        if semantic_features_dir is None:
            # Auto-detect preprocessed DINOv2-B features at the default location.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")

        if self.use_preprocessed_semantic:
            # Preprocessed-semantic mode: dataset.json inside the feature dir is
            # the single source of truth for (feature file name, class label).
            label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
            if not os.path.exists(label_fname):
                raise FileNotFoundError(f"Label file not found: {label_fname}")

            print(f"Using {label_fname}.")
            with open(label_fname, 'rb') as f:
                data = json.load(f)
            labels_list = data.get('labels', None)
            if labels_list is None:
                raise ValueError(f"'labels' field is missing in {label_fname}")

            semantic_fnames = []
            labels = []
            for entry in labels_list:
                # Null entries are skipped; a null label is coerced to class 0.
                if entry is None:
                    continue
                fname, lab = entry
                semantic_fnames.append(fname)
                labels.append(0 if lab is None else lab)

            self.semantic_fnames = semantic_fnames
            self.labels = np.array(labels, dtype=np.int64)
            self.num_samples = len(self.semantic_fnames)
            print(f"Loaded {self.num_samples} semantic entries from dataset.json")
        else:
            # Original REG mode: walk both trees and pair images with VAE
            # features; dataset.json inside vae-sd/ supplies the labels.
            self.features_dir = os.path.join(data_dir, 'vae-sd')

            self._image_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.images_dir)
                for root, _dirs, files in os.walk(self.images_dir) for fname in files
            }
            self.image_fnames = sorted(
                fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
            )
            self._feature_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.features_dir)
                for root, _dirs, files in os.walk(self.features_dir) for fname in files
            }
            self.feature_fnames = sorted(
                fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
            )

            fname = os.path.join(self.features_dir, 'dataset.json')
            if os.path.exists(fname):
                print(f"Using {fname}.")
            else:
                raise FileNotFoundError("Neither of the specified files exists.")

            with open(fname, 'rb') as f:
                labels = json.load(f)['labels']
            labels = dict(labels)
            # Normalize Windows separators before the lookup.
            labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
            labels = np.array(labels)
            # 1-D labels -> class indices (int64); 2-D -> dense label vectors (float32).
            self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        # Lower-cased extension including the leading dot, e.g. '.npy'.
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        if self.use_preprocessed_semantic:
            # Derive the VAE-latent path from the feature file name:
            # '...-<idx>.npy' -> '<idx[:5]>/img-mean-std-<idx>.npy'.
            semantic_fname = self.semantic_fnames[idx]
            basename = os.path.basename(semantic_fname)
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)

            with open(vae_path, 'rb') as f:
                image = np.load(f)

            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)

            # NOTE(review): the VAE latent fills both the "image" and "feature"
            # slots here (4-tuple), while the other branch returns a 3-tuple —
            # confirm the training loop unpacks each mode accordingly.
            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )

        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # pyspng is a faster PNG decoder; fall back to PIL when absent.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
back/generate.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Samples a large number of images from a pre-trained SiT model using DDP.
9
+ Subsequently saves a .npz file that can be used to compute FID and other
10
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11
+
12
+ For a simple single-GPU/CPU sampling script, see sample.py.
13
+ """
14
+ import torch
15
+ import torch.distributed as dist
16
+ from models.sit import SiT_models
17
+ from diffusers.models import AutoencoderKL
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ from samplers import euler_maruyama_sampler
25
+ from utils import load_legacy_checkpoints, download_model
26
+
27
+
28
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Bundle `num` PNGs named 000000.png… under `sample_dir` into one .npz.

    The array is stored under key 'arr_0' (the layout the ADM evaluation
    suite expects); returns the path of the written .npz file.
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # Every sample must be an HxWx3 RGB image.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
43
+
44
+
45
def main(args):
    """
    Run DDP sampling: generate `args.num_fid_samples` images with a SiT model,
    decode them through the SD VAE, save per-image PNGs, and (on rank 0)
    assemble them into a single .npz for FID evaluation.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set up DDP; each rank gets a distinct seed derived from the global seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Build the SiT model (with CFG enabled and REPA-style projector heads).
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt

    # No --ckpt: fall back to the default pretrained SiT-XL/2 weights
    # (only valid for that config); otherwise load the EMA weights.
    if ckpt_path is None:
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap keys from the older checkpoint format.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    #vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path="your_local_path/weight/").to(device)

    # Create folder to save samples (name encodes the sampling configuration):
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs: latent noise, random class labels, and the semantic
        # ("cls") latent of width args.cls.
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":  # ODE sampling not supported yet in this script
                exit()
                #samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # Undo SD-VAE latent scaling (0.18215, zero bias) and decode to pixels.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files; index interleaves ranks
        # so the union over all ranks is a contiguous range.
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
181
+
182
+
183
+ if __name__ == "__main__":
184
+ parser = argparse.ArgumentParser()
185
+ # seed
186
+ parser.add_argument("--global-seed", type=int, default=0)
187
+
188
+ # precision
189
+ parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
190
+ help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
191
+
192
+ # logging/saving:
193
+ parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
194
+ parser.add_argument("--sample-dir", type=str, default="samples")
195
+
196
+ # model
197
+ parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
198
+ parser.add_argument("--num-classes", type=int, default=1000)
199
+ parser.add_argument("--encoder-depth", type=int, default=8)
200
+ parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
201
+ parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
202
+ parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
203
+ # vae
204
+ parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
205
+
206
+ # number of samples
207
+ parser.add_argument("--per-proc-batch-size", type=int, default=32)
208
+ parser.add_argument("--num-fid-samples", type=int, default=50_000)
209
+
210
+ # sampling related hyperparameters
211
+ parser.add_argument("--mode", type=str, default="ode")
212
+ parser.add_argument("--cfg-scale", type=float, default=1.5)
213
+ parser.add_argument("--cls-cfg-scale", type=float, default=1.5)
214
+ parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
215
+ parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
216
+ parser.add_argument("--num-steps", type=int, default=50)
217
+ parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
218
+ parser.add_argument("--guidance-low", type=float, default=0.)
219
+ parser.add_argument("--guidance-high", type=float, default=1.)
220
+ parser.add_argument('--local-rank', default=-1, type=int)
221
+ parser.add_argument('--cls', default=768, type=int)
222
+ # will be deprecated
223
+ parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode
224
+
225
+
226
+ args = parser.parse_args()
227
+ main(args)
back/sample_from_checkpoint_ddp.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。
4
+
5
+ 用法(4 卡示例):
6
+ torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \
7
+ --ckpt exps/jsflow-experiment/checkpoints/0290000.pt \
8
+ --out-dir ./my_samples_ddp \
9
+ --num-images 50000 \
10
+ --batch-size 16 \
11
+ --t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \
12
+ --sampler em_image_noise_before_tc
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import math
19
+ import os
20
+ import sys
21
+ import types
22
+ import numpy as np
23
+
24
+ import torch
25
+ import torch.distributed as dist
26
+ from diffusers.models import AutoencoderKL
27
+ from PIL import Image
28
+ from tqdm import tqdm
29
+
30
+ from models.sit import SiT_models
31
+ from samplers import (
32
+ euler_maruyama_image_noise_before_tc_sampler,
33
+ euler_maruyama_image_noise_sampler,
34
+ euler_maruyama_sampler,
35
+ euler_ode_sampler,
36
+ )
37
+
38
+
39
def create_npz_from_sample_folder(sample_dir: str, num: int):
    """Assemble 000000.png … under ``sample_dir`` into a single .npz (arr_0).

    Returns the path of the written archive.
    """
    stacked = np.stack([
        np.asarray(Image.open(os.path.join(sample_dir, f"{i:06d}.png"))).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
53
+
54
+
55
def semantic_dim_from_enc_type(enc_type):
    """Map an encoder-type string to its semantic feature width.

    Recognizes ViT-G (1536), ViT-L (1024) and ViT-S (384) substrings,
    case-insensitively; None or anything else (including ViT-B) yields 768.
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    table = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for markers, dim in table:
        if any(marker in name for marker in markers):
            return dim
    return 768
66
+
67
+
68
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Recover the training-time arguments stored under ``ckpt['args']``.

    Accepts a Namespace (returned as-is), a plain dict, or a SimpleNamespace
    (both converted to argparse.Namespace). Returns None when the key is
    absent or holds an unrecognized type.
    """
    stored = ckpt.get("args")
    if isinstance(stored, argparse.Namespace):
        return stored
    if isinstance(stored, dict):
        return argparse.Namespace(**stored)
    if isinstance(stored, types.SimpleNamespace):
        return argparse.Namespace(**vars(stored))
    return None
79
+
80
+
81
def load_vae(device: torch.device):
    """Load the SD VAE (stabilityai/sd-vae-ft-mse) in eval mode on `device`.

    Resolution order:
      1. project dnnlib cache dir, offline (`local_files_only=True`);
      2. scan several well-known local cache roots for an already-downloaded
         copy (a directory containing config.json with 'sd-vae-ft-mse' in its
         path);
      3. plain `from_pretrained` (may hit the network).
    Any failure in steps 1-2 falls through silently to the next step.
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        try:
            # Step 1: offline load straight from the project cache.
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            pass
        # Step 2: search common cache roots for a cached snapshot directory.
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        # NOTE(review): deliberately broad best-effort fallback — a failed
        # `preprocessing.dnnlib` import or cache scan ends up here.
        pass
    # Step 3: last resort — regular (possibly online) load.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
123
+
124
+
125
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by training args `ta` on `device`.

    Returns (model, semantic_dim) where semantic_dim is the single projector
    width inferred from `ta.enc_type`. Missing attributes fall back to the
    training defaults (resolution 256, DINOv2-B encoder, depth 8, 1000
    classes, cfg_prob 0.1, fused attention on, qk-norm off).
    """
    res = int(getattr(ta, "resolution", 256))
    latent_size = res // 8  # SD-VAE downsamples by 8x
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    z_dims = [semantic_dim_from_enc_type(enc_type)]
    block_kwargs = {
        "fused_attn": getattr(ta, "fused_attn", True),
        "qk_norm": getattr(ta, "qk_norm", False),
    }
    # CFG heads are only built when label dropout was used during training.
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=latent_size,
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        **block_kwargs,
    ).to(device)
    return model, z_dims[0]
146
+
147
+
148
def resolve_tc_schedule(cli, ta):
    """Resolve the (t_c, steps_before, steps_after) split-sampling schedule.

    Returns (None, None, None) when neither per-segment step count was given
    on the CLI; exits with an error message when only one of the two step
    counts is present, or when no t_c can be found on the CLI or in the
    checkpoint's training args.
    """
    before = cli.steps_before_tc
    after = cli.steps_after_tc
    t_c = cli.t_c
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    if t_c is None and ta is not None:
        # Fall back to the t_c recorded in the checkpoint's training args.
        t_c = getattr(ta, "t_c", None)
    if t_c is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(t_c), int(before), int(after)
163
+
164
+
165
def parse_cli():
    """Build and parse the command-line options for DDP checkpoint sampling.

    Options default to None where the value should be taken from the
    checkpoint's stored training args; explicit CLI values override them.
    """
    p = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    # Required I/O.
    p.add_argument("--ckpt", type=str, required=True)
    p.add_argument("--out-dir", type=str, required=True)
    p.add_argument("--num-images", type=int, required=True)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    # Which weights to load from the checkpoint dict.
    p.add_argument("--weights", type=str, choices=("ema", "model"), default="ema")
    p.add_argument("--device", type=str, default="cuda")
    # Sampling schedule (t_c splits the trajectory into two segments).
    p.add_argument("--num-steps", type=int, default=50)
    p.add_argument("--t-c", type=float, default=None)
    p.add_argument("--steps-before-tc", type=int, default=None)
    p.add_argument("--steps-after-tc", type=int, default=None)
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0)
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"])
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Model overrides (None -> use the checkpoint's training args).
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    p.add_argument("--num-classes", type=int, default=1000)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    p.add_argument(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return p.parse_args()
210
+
211
+
212
+ def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
213
+ imgs = vae.decode((latents - latents_bias) / latents_scale).sample
214
+ imgs = (imgs + 1) / 2.0
215
+ imgs = torch.clamp(imgs, 0, 1)
216
+ return (
217
+ (imgs * 255.0)
218
+ .round()
219
+ .to(torch.uint8)
220
+ .permute(0, 2, 3, 1)
221
+ .cpu()
222
+ .numpy()
223
+ )
224
+
225
+
226
def init_ddp():
    """Initialize torch.distributed from torchrun-style environment variables.

    Returns (use_ddp, rank, world_size, local_rank). When RANK/WORLD_SIZE are
    absent the process runs standalone as (False, 0, 1, 0) and no process
    group is created.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        dist.init_process_group(backend="nccl", init_method="env://")
        # Pin this process to its local GPU before any CUDA work happens.
        torch.cuda.set_device(local_rank)
        return True, rank, world_size, local_rank
    return False, 0, 1, 0
235
+
236
+
237
def main():
    """DDP sampling entry point: load a checkpoint, sample `--num-images`
    images with the selected sampler, decode via the SD VAE, save per-image
    PNGs under `--out-dir`, and optionally bundle them into an .npz on rank 0.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")

    # weights_only is only accepted by newer torch versions; fall back otherwise.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # No stored training args: reconstruct a minimal Namespace from the CLI.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # Stored training args exist: explicit CLI values override them.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    # tc_split = (t_c, steps 1->t_c, steps t_c->0) or (None, None, None).
    tc_split = resolve_tc_schedule(cli, ta)

    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Select the sampler implementation by name.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # SD-VAE latent normalization: scale 0.18215, zero bias.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    if cli.save_fixed_trajectory and cli.sampler != "em":
        # Per-rank trajectory dumps (the plain "em" sampler has no trajectory).
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)

    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    global_batch_size = b * world_size
    # Round up to a full global batch; extras beyond n_total are discarded.
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)

    # Per-rank seeding for distinct but reproducible noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)

    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0
    written_local = 0
    for _ in pbar:
        cur = b
        # Latent noise, random class labels, and the semantic ("cls") latent.
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]

            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches are currently identical; kept as
                # a split point in case samplers diverge in signature.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None

        latents = latents.to(torch.float32)
        imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
        # Interleave global indices across ranks; drop the rounding overshoot.
        for i, img in enumerate(imgs):
            gidx = i * world_size + rank + total
            if gidx < n_total:
                Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                written_local += 1

        if traj is not None and traj_dir is not None:
            # Save the trajectory of this batch's first sample only.
            traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
            first_idx = rank + total
            if first_idx < n_total:
                np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)

        total += global_batch_size
    if use_ddp:
        dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()

    # All ranks must finish writing PNGs before rank 0 aggregates them.
    if use_ddp:
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()
412
+
413
+
414
+ if __name__ == "__main__":
415
+ main()
416
+
back/samplers.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def expand_t_like_x(t, x_cur):
    """Reshape a time vector so it broadcasts against a data tensor.

    Args:
        t: [batch_dim,] time vector.
        x_cur: [batch_dim, ...] data point whose trailing dims we broadcast over.

    Returns:
        ``t`` viewed as [batch_dim, 1, 1, ...] with one singleton axis per
        trailing dimension of ``x_cur``.
    """
    trailing = x_cur.dim() - 1
    return t.view(t.size(0), *([1] * trailing))
14
+
15
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Convert a velocity-model prediction into a score estimate.

    Args:
        vt: [batch_dim, ...] velocity model output.
        xt: [batch_dim, ...] noisy data point x_t.
        t: [batch_dim,] time tensor.
        path_type: interpolation path, "linear" or "cosine".

    Returns:
        Score tensor with the same shape as ``xt``.

    Raises:
        NotImplementedError: for an unknown ``path_type``.
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = -torch.ones_like(xt, device=xt.device)
        d_sigma_t = torch.ones_like(xt, device=xt.device)
    elif path_type == "cosine":
        half_pi = np.pi / 2
        alpha_t = torch.cos(t * half_pi)
        sigma_t = torch.sin(t * half_pi)
        d_alpha_t = -half_pi * torch.sin(t * half_pi)
        d_sigma_t = half_pi * torch.cos(t * half_pi)
    else:
        raise NotImplementedError

    # score = (ratio * v - x) / var with ratio = alpha / alpha' and
    # var = sigma^2 - ratio * sigma' * sigma (standard v->score identity).
    ratio = alpha_t / d_alpha_t
    var = sigma_t ** 2 - ratio * d_sigma_t * sigma_t
    return (ratio * vt - xt) / var
40
+
41
+
42
def compute_diffusion(t_cur):
    """Diffusion coefficient used by the SDE samplers: linear schedule g(t)^2 = 2 t."""
    return t_cur * 2
44
+
45
+
46
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build a descending t = 1 -> t = 0 time grid for the SDE samplers.

    As in the original schedule, the grid stops at a small positive floor
    ``t_floor`` before the final jump to 0.

    - Default: uniform ``linspace(1, t_floor, num_steps)`` followed by 0.
    - Segmented (``t_c``, ``num_steps_before_tc`` and ``num_steps_after_tc`` all
      given): ``num_steps_before_tc`` intervals on (t_c, 1], then
      ``num_steps_after_tc`` intervals from t_c down to t_floor, then 0.

    Returns:
        1-D float64 tensor of times, descending from 1.0 to 0.0.

    Raises:
        ValueError: on invalid step counts or an out-of-range ``t_c``.
    """
    t_floor = float(t_floor)
    zero = torch.tensor([0.0], dtype=torch.float64)

    segmented = not (
        t_c is None or num_steps_before_tc is None or num_steps_after_tc is None
    )
    if not segmented:
        ns = int(num_steps)
        if ns < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, t_floor, ns, dtype=torch.float64)
        return torch.cat([grid, zero])

    t_c = float(t_c)
    nb = int(num_steps_before_tc)
    na = int(num_steps_after_tc)
    if nb < 1 or na < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not (0.0 < t_c < 1.0):
        raise ValueError("t_c must be in (0, 1)")
    if t_c <= t_floor:
        raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})")

    head = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64)
    tail = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64)
    # tail[0] duplicates head[-1] (both equal t_c), so drop it before joining.
    return torch.cat([head, tail[1:], zero])
82
+
83
+
84
+ def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc):
85
+ """仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。"""
86
+ return (
87
+ t_c is not None
88
+ and num_steps_before_tc is not None
89
+ and num_steps_after_tc is not None
90
+ )
91
+
92
+
93
+ def _cls_effective_and_freeze(
94
+ cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
95
+ ):
96
+ """
97
+ 时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。
98
+ 返回 (用于前向的 cls, 更新后的 cls_frozen)。
99
+ """
100
+ if not freeze_after_tc or t_c_v is None:
101
+ return cls_x_cur, cls_frozen
102
+ if float(t_cur) <= float(t_c_v) + 1e-9:
103
+ if cls_frozen is None:
104
+ cls_frozen = cls_x_cur.clone()
105
+ return cls_frozen, cls_frozen
106
+ return cls_x_cur, cls_frozen
107
+
108
+
109
+ def _build_euler_sampler_time_steps(
110
+ num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
111
+ ):
112
+ """
113
+ euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。
114
+ """
115
+ if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
116
+ ns = int(num_steps)
117
+ if ns < 1:
118
+ raise ValueError("num_steps must be >= 1")
119
+ return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device)
120
+ t_c = float(t_c)
121
+ nb = int(num_steps_before_tc)
122
+ na = int(num_steps_after_tc)
123
+ if nb < 1 or na < 1:
124
+ raise ValueError(
125
+ "num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c"
126
+ )
127
+ if not (0.0 < t_c < 1.0):
128
+ raise ValueError("t_c must be in (0, 1)")
129
+ p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device)
130
+ p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device)
131
+ return torch.cat([p1, p2[1:]])
132
+
133
+
134
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # unused; kept only for signature compatibility with other samplers
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """Euler–Maruyama SDE sampler for the REG SiT model (image latent + cls token).

    The drift and the velocity->score conversion match euler_ode_sampler /
    euler_sampler; with ``deterministic=True`` the stochastic diffusion term is
    zeroed.  Unlike the ODE path (plain linspace(1 -> 0) or the t_c-segmented
    grid without a floor), this function builds its grid with
    build_sampling_time_steps, i.e. the grid stops at t_floor before the final
    jump to 0 — consistent with the other EM/SDE samplers here.

    Args:
        model: callable ``model(x, t, y=..., cls_token=...)`` returning a tuple
            whose first and third entries are the image and cls velocities.
        latents: [B, ...] initial image latents at t=1.
        y: [B] class labels.
        num_steps: grid size when no t_c segmentation is used.
        cfg_scale: classifier-free-guidance scale; CFG runs only when
            cfg_scale > 1 and t is inside [guidance_low, guidance_high].
        path_type: interpolation path ("linear" or "cosine").
        cls_latents: initial cls-token latents (must not be None).
        args: optional namespace; only ``args.cls_cfg_scale`` is read.
        return_mid_state: also return the state closest to ``t_mid``.
        t_c / num_steps_before_tc / num_steps_after_tc: enable the segmented
            1 -> t_c -> 0 grid; in the t <= t_c leg the cls token is frozen.
        deterministic: drop the random diffusion increments (pure drift).
        return_trajectory: also return the list of intermediate image states.

    Returns:
        mean_x, optionally followed by (z_mid, cls_mid) and/or traj depending
        on the flags.
    """
    # Class id 1000 is the null label for the unconditional CFG branch.
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to the model's dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Interior EM steps; the final (t_floor -> 0) step is handled separately
        # below as a deterministic mean update.
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur  # negative: time runs 1 -> 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # If this interval brackets t_mid and t_cur is the closer endpoint,
            # capture the pre-step state as the mid-trajectory snapshot.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # CFG: batch the conditional and null branches into one forward pass.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increments scaled by sqrt(|dt|); zeroed in deterministic mode.
            if deterministic:
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # Model forward: image velocity and cls-token velocity.
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: d = v - (g^2 / 2) * score, for both channels.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            # Combine cond/uncond halves; the cls token has its own guidance
            # scale (cls_cfg) and falls back to the conditional branch if unset.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            # In the frozen segment the cls token stays pinned to its snapshot.
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if return_trajectory:
                traj.append(x_next.clone())

            # Fallback mid capture: t_next is the closer endpoint of the bracket.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Last step (t_floor -> 0): deterministic mean update, no noise term.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # Final model forward; same drift construction as the interior steps.
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        # cls_mean_x is computed for parity with the other variants but is not
        # returned by this sampler.
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

        if return_trajectory:
            traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
317
+
318
+
319
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # unused; kept only for signature compatibility with other samplers
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """EM sampler variant: only the image latent receives stochastic diffusion
    noise; the cls/token channel is integrated deterministically (drift only).

    Bug fix vs. the previous revision: the interior loop branched on an
    undefined ``add_img_noise`` flag (copied from
    euler_maruyama_image_noise_before_tc_sampler), so every call raised
    NameError.  In this variant image noise is *always* injected in the
    interior steps, so the score-corrected reverse-SDE drift
    ``d = v - (g^2 / 2) * score`` is always used there; the final
    (t_floor -> 0) step stays noise-free and uses the plain velocity drift,
    matching the ODE samplers.

    Args mirror euler_maruyama_sampler (minus ``deterministic``); see there
    for details.

    Returns:
        mean_x, optionally followed by (z_mid, cls_mid) and/or traj depending
        on the flags.
    """
    if cfg_scale > 1.0:
        # Class id 1000 is the null label for the unconditional CFG branch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to the model's dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur  # negative: time runs 1 -> 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the state closest to t_mid once, if requested.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # CFG: batch the conditional and null branches into one forward pass.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increment for the image latent only (cls gets no noise).
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Image noise is always on in the interior steps of this variant, so
            # the score-corrected reverse-SDE drift is always used (fix for the
            # previously undefined `add_img_noise` branch).
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            # Combine cond/uncond halves; cls has its own guidance scale.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Image latent: stochastic EM update; cls token: drift only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Last step (t_floor -> 0): no randomness; plain velocity drift to stay
        # aligned with the ODE samplers.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        # cls_mean_x is kept for parity with the other variants; this sampler
        # does not return it.  NOTE: traj deliberately excludes mean_x here,
        # matching the original behavior of this variant.
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
487
+
488
+
489
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # unused; kept only for signature compatibility with other samplers
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """EM sampler variant with noise only before t_c:

    - the image latent receives stochastic diffusion noise while t > t_c;
    - the image latent gets no random term (drift only) once t <= t_c;
    - the cls/token channel never receives a random term.

    Args mirror euler_maruyama_sampler; ``return_cls_final`` additionally
    returns the final cls state.  Returns mean_x, optionally followed by
    (z_mid, cls_mid), cls_mean_x and/or traj depending on the flags.
    """
    # Class id 1000 is the null label for the unconditional CFG branch.
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    # t_c_freeze gates cls freezing (needs the full segmented-grid config);
    # t_c_v below gates the image-noise cutoff and only needs t_c itself.
    t_c_freeze = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur  # negative: time runs 1 -> 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )

            # Capture the state closest to t_mid once, if requested.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # CFG: batch the conditional and null branches into one forward pass.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Image randomness is switched off once the step crosses into
            # t <= t_c; while t > t_c image noise stays on.
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False

            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            if add_img_noise:
                # Stochastic segment: reverse-SDE drift with score correction.
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur

                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                # De-randomized segment (t <= t_c): explicit Euler with the raw
                # velocity drift, matching the ODE samplers (no score correction).
                d_cur = v_cur
                cls_d_cur = cls_v_cur

            # Combine cond/uncond halves; cls has its own guidance scale.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Image latent: stochastic (or zero-noise) EM update; cls: drift only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Last step (t_floor -> 0): no random term; velocity drift, matching the ODE.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    # Return combinations are ordered most-specific first.
    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
674
+
675
+
676
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """ODE entry point for REG.

    Decoupled from the SDE samplers: simply delegates to euler_sampler, which
    uses a linspace(1 -> 0) grid (or the segmented 1 -> t_c -> 0 grid) with no
    t_floor.
    """
    forwarded = dict(
        num_steps=num_steps,
        heun=False,
        cfg_scale=cfg_scale,
        guidance_low=guidance_low,
        guidance_high=guidance_high,
        path_type=path_type,
        cls_latents=cls_latents,
        args=args,
        return_mid_state=return_mid_state,
        t_mid=t_mid,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
        return_trajectory=return_trajectory,
    )
    return euler_sampler(model, latents, y, **forwarded)
716
+
717
+
718
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # unused placeholder, kept for signature compatibility
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """Lightweight deterministic (drift-only) Euler sampler.

    Signature prefix is compatible with the glflow sampler of the same name
    (model, latents, y, num_steps, heun, cfg, guidance, path_type,
    cls_latents, args).

    - Default grid: ``linspace(1, 0, num_steps + 1)``, no t_floor (matching
      the original stand-alone ODE sampler).
    - Optional: passing all of ``t_c``, ``num_steps_before_tc`` and
      ``num_steps_after_tc`` switches to the segmented 1 -> t_c -> 0 grid and,
      consistent with the EM samplers, freezes the cls token for t <= t_c.
    - Optional: ``return_mid_state`` / ``return_trajectory`` are used by
      train.py and sample_from_checkpoint.

    The REG SiT model requires a cls_token, so ``cls_latents`` must not be
    None.  ``path_type`` is accepted for interface parity but is not used on
    this pure-velocity ODE path.

    Returns:
        x_next, optionally followed by (z_mid, cls_mid) and/or traj depending
        on the flags.
    """
    if cls_latents is None:
        raise ValueError(
            "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
        )
    # Class id 1000 is the null label for the unconditional CFG branch.
    if cfg_scale > 1.0:
        y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
    else:
        y_null = None
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    device = latents.device

    t_steps = _build_euler_sampler_time_steps(
        num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None

    # Integrate in float64 for numerical stability; cast to the model's dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            dt = t_next - t_cur  # negative: time runs 1 -> 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the state closest to t_mid once, if requested.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # CFG: batch the conditional and null branches into one forward pass.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype),
                time_input.to(dtype=_dtype),
                y_cur,
                cls_token=cls_model_input.to(dtype=_dtype),
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # ODE: follow velocity parameterization directly (d/dt x_t = v_t).
            # This aligns with velocity training target and avoids extra v->score->drift conversion.
            d_cur = v_cur
            cls_d_cur = cls_v_cur

            # Combine cond/uncond halves; cls has its own guidance scale
            # (cls_cfg) and falls back to the conditional branch if unset.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Explicit Euler update; cls stays pinned inside the frozen segment.
            x_next = x_cur + dt * d_cur
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + dt * cls_d_cur

            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

    if return_trajectory and return_mid_state:
        return x_next, z_mid, cls_mid, traj
    if return_trajectory:
        return x_next, traj
    if return_mid_state:
        return x_next, z_mid, cls_mid
    return x_next
back/samples_0.75.log ADDED
The diff for this file is too large to render. See raw diff
 
back/samples_0.75_new.log ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793]
2
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
5
+ 时间网格:t_c=0.75, 步数 (1→t_c)=100, (t_c→0)=50
6
+ Total number of images that will be sampled: 40192
7
+
8
+ [rank3]:[W325 16:57:00.799344818 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
9
+ [rank1]:[W325 16:57:00.847229448 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
10
+ [rank0]:[W325 16:57:01.326116049 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
11
+
conditional-flow-matching/.gitignore ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ ### VisualStudioCode
131
+ .vscode/*
132
+ !.vscode/settings.json
133
+ !.vscode/tasks.json
134
+ !.vscode/launch.json
135
+ !.vscode/extensions.json
136
+ *.code-workspace
137
+ **/.vscode
138
+
139
+ # JetBrains
140
+ .idea/
141
+
142
+ # Lightning-Hydra-Template
143
+ configs/local/default.yaml
144
+ data/
145
+ logs/
146
+ wandb/
147
+ .env
148
+ .autoenv
149
+
150
+ #Vim
151
+ *.sw?
152
+
153
+ # Slurm
154
+ slurm*.out
155
+
156
+ # Data and models
157
+ *.pt
158
+ *.h5
159
+ *.h5ad
160
+ *.tar
161
+ *.tar.gz
162
+ *.pkl
163
+ *.npy
164
+ *.npz
165
+ *.csv
166
+
167
+ # Images
168
+ *.png
169
+ *.svg
170
+ *.gif
171
+ *.jpg
172
+
173
+ notebooks/figures/
174
+
175
+ .DS_Store
conditional-flow-matching/.pre-commit-config.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ # list of supported hooks: https://pre-commit.com/hooks.html
6
+ - id: trailing-whitespace
7
+ exclude: .svg$
8
+ require_serial: true
9
+ - id: end-of-file-fixer
10
+ require_serial: true
11
+ - id: check-docstring-first
12
+ require_serial: true
13
+ - id: check-yaml
14
+ require_serial: true
15
+ - id: debug-statements
16
+ require_serial: true
17
+ - id: detect-private-key
18
+ require_serial: true
19
+ - id: check-executables-have-shebangs
20
+ require_serial: true
21
+ - id: check-toml
22
+ require_serial: true
23
+ - id: check-case-conflict
24
+ require_serial: true
25
+ - id: check-added-large-files
26
+ require_serial: true
27
+
28
+ # python upgrading syntax to newer version
29
+ - repo: https://github.com/asottile/pyupgrade
30
+ rev: v3.17.0
31
+ hooks:
32
+ - id: pyupgrade
33
+ require_serial: true
34
+ args: [--py38-plus]
35
+
36
+ - repo: https://github.com/astral-sh/ruff-pre-commit
37
+ rev: v0.5.5
38
+ hooks:
39
+ - id: ruff
40
+ args: [--fix]
41
+ - id: ruff-format
42
+
43
+ # python security linter
44
+ - repo: https://github.com/gitleaks/gitleaks
45
+ rev: v8.18.2
46
+ hooks:
47
+ - id: gitleaks
48
+
49
+ # yaml formatting
50
+ - repo: https://github.com/pre-commit/mirrors-prettier
51
+ rev: v3.0.0
52
+ hooks:
53
+ - id: prettier
54
+ require_serial: true
55
+ types: [yaml]
56
+
57
+ # shell scripts linter
58
+ - repo: https://github.com/shellcheck-py/shellcheck-py
59
+ rev: v0.9.0.5
60
+ hooks:
61
+ - id: shellcheck
62
+ require_serial: true
63
+ args: ["-e", "SC2102"]
64
+
65
+ - repo: https://github.com/pre-commit/mirrors-prettier
66
+ rev: v4.0.0-alpha.8
67
+ hooks:
68
+ - id: prettier
69
+ # To avoid conflicts, tell prettier to ignore file types
70
+ # that ruff already handles.
71
+ exclude_types: [python]
72
+
73
+ # word spelling linter
74
+ - repo: https://github.com/codespell-project/codespell
75
+ rev: v2.2.5
76
+ hooks:
77
+ - id: codespell
78
+ require_serial: true
79
+ args:
80
+ - --skip=logs/**,data/**,*.ipynb
81
+ - --ignore-words-list=ot,hist
82
+
83
+ # jupyter notebook linting
84
+ - repo: https://github.com/nbQA-dev/nbQA
85
+ rev: 1.7.0
86
+ hooks:
87
+ - id: nbqa-black
88
+ args: ["--line-length=99"]
89
+ require_serial: true
90
+ - id: nbqa-isort
91
+ args: ["--profile=black"]
92
+ require_serial: true
93
+ - id: nbqa-flake8
94
+ args:
95
+ [
96
+ "--extend-ignore=E203,E402,E501,F401,F841,F821,F403,F405,F811",
97
+ "--exclude=logs/*,data/*,notebooks/*",
98
+ ]
99
+ require_serial: true
conditional-flow-matching/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Alexander Tong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
conditional-flow-matching/README.md ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # TorchCFM: a Conditional Flow Matching library
4
+
5
+ <!---[![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020) -->
6
+
7
+ <!---[![contributors](https://img.shields.io/github/contributors/atong01/conditional-flow-matching.svg)](https://github.com/atong01/conditional-flow-matching/graphs/contributors) -->
8
+
9
+ [![OT-CFM Preprint](http://img.shields.io/badge/paper-arxiv.2302.00482-B31B1B.svg)](https://arxiv.org/abs/2302.00482)
10
+ [![SF2M Preprint](http://img.shields.io/badge/paper-arxiv.2307.03672-B31B1B.svg)](https://arxiv.org/abs/2307.03672)
11
+ [![pytorch](https://img.shields.io/badge/PyTorch_1.8+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
12
+ [![lightning](https://img.shields.io/badge/-Lightning_1.6+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
13
+ [![hydra](https://img.shields.io/badge/Config-Hydra_1.2-89b8cd)](https://hydra.cc/)
14
+ [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
15
+ [![pre-commit](https://img.shields.io/badge/Pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
16
+ [![tests](https://github.com/atong01/conditional-flow-matching/actions/workflows/test.yaml/badge.svg)](https://github.com/atong01/conditional-flow-matching/actions/workflows/test.yaml)
17
+ [![codecov](https://codecov.io/gh/atong01/conditional-flow-matching/branch/main/graph/badge.svg)](https://codecov.io/gh/atong01/conditional-flow-matching/)
18
+ [![code-quality](https://github.com/atong01/conditional-flow-matching/actions/workflows/code-quality-main.yaml/badge.svg)](https://github.com/atong01/conditional-flow-matching/actions/workflows/code-quality-main.yaml)
19
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/atong01/conditional-flow-matching#license)
20
+ <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a>
21
+ [![Downloads](https://static.pepy.tech/badge/torchcfm)](https://pepy.tech/project/torchcfm)
22
+ [![Downloads](https://static.pepy.tech/badge/torchcfm/month)](https://pepy.tech/project/torchcfm)
23
+
24
+ </div>
25
+
26
+ ## Description
27
+
28
+ Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference. CFM's performance closes the gap between CNFs and diffusion models. To spread its use within the machine learning community, we have built a library focused on Flow Matching methods: TorchCFM. TorchCFM is a library showing how Flow Matching methods can be trained and used to deal with image generation, single-cell dynamics, tabular data and soon SO(3) data.
29
+
30
+ <p align="center">
31
+ <img src="assets/169_generated_samples_otcfm.png" width="600"/>
32
+ <img src="assets/8gaussians-to-moons.gif" />
33
+ </p>
34
+
35
+ The density, vector field, and trajectories of simulation-free CNF training schemes: mapping 8 Gaussians to two moons (above) and a single Gaussian to two moons (below). Action matching with the same architecture (3x64 MLP with SeLU activations) underfits with the ReLU, SiLU, and SiLU activations as suggested in the [example code](https://github.com/necludov/jam), but it seems to fit better under our training setup (Action-Matching (Swish)).
36
+
37
+ The models to produce the GIFs are stored in `examples/models` and can be visualized with this notebook: [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/model-comparison-plotting.ipynb).
38
+
39
+ We also have included an example of unconditional MNIST generation in `examples/notebooks/mnist_example.ipynb` for both deterministic and stochastic generation. [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/mnist_example.ipynb).
40
+
41
+ ## The torchcfm Package
42
+
43
+ In our version 1 update we have extracted implementations of the relevant flow matching variants into a package `torchcfm`. This allows abstraction of the choice of the conditional distribution `q(z)`. `torchcfm` supplies the following loss functions:
44
+
45
+ - `ConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = q(x_0) q(x_1)$
46
+ - `ExactOptimalTransportConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi(x_0, x_1)$ where $\\pi$ is an exact optimal transport joint. This is used in \[Tong et al. 2023a\] and \[Poolidan et al. 2023\] as "OT-CFM" and "Multisample FM with Batch OT" respectively.
47
+ - `TargetConditionalFlowMatcher`: $z = x_1$, $q(z) = q(x_1)$ as defined in Lipman et al. 2023, learns a flow from a standard normal Gaussian to data using conditional flows which optimally transport the Gaussian to the datapoint (Note that this does not result in the marginal flow being optimal transport).
48
+ - `SchrodingerBridgeConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi\_\\epsilon(x_0, x_1)$ where $\\pi\_\\epsilon$ is an entropically regularized OT plan, although in practice this is often approximated by a minibatch OT plan (See Tong et al. 2023b). The flow-matching variant of this where the marginals are equivalent to the Schrodinger Bridge marginals is known as `SB-CFM` \[Tong et al. 2023a\]. When the score is also known and the bridge is stochastic is called \[SF\]2M \[Tong et al. 2023b\]
49
+ - `VariancePreservingConditionalFlowMatcher`: $z = (x_0, x_1)$ $q(z) = q(x_0) q(x_1)$ but with conditional Gaussian probability paths which preserve variance over time using a trigonometric interpolation as presented in \[Albergo et al. 2023a\].
50
+
51
+ ## How to cite
52
+
53
+ This repository contains the code to reproduce the main experiments and illustrations of two preprints:
54
+
55
+ - [Improving and generalizing flow-based generative models with minibatch optimal transport](https://arxiv.org/abs/2302.00482). We introduce **Optimal Transport Conditional Flow Matching** (OT-CFM), a CFM variant that approximates the dynamical formulation of optimal transport (OT). Based on OT theory, OT-CFM leverages the static optimal transport plan as well as the optimal probability paths and vector fields to approximate dynamic OT.
56
+ - [Simulation-free Schrödinger bridges via score and flow matching](https://arxiv.org/abs/2307.03672). We propose **Simulation-Free Score and Flow Matching** (\[SF\]<sup>2</sup>M). \[SF\]<sup>2</sup>M leverages OT-CFM as well as score-based methods to approximate Schrödinger bridges, a stochastic version of optimal transport.
57
+
58
+ If you find this code useful in your research, please cite the following papers (expand for BibTeX):
59
+
60
+ <details>
61
+ <summary>
62
+ A. Tong, N. Malkin, G. Huguet, Y. Zhang, J. Rector-Brooks, K. Fatras, G. Wolf, Y. Bengio. Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport, 2023.
63
+ </summary>
64
+
65
+ ```bibtex
66
+ @article{tong2024improving,
67
+ title={Improving and generalizing flow-based generative models with minibatch optimal transport},
68
+ author={Alexander Tong and Kilian FATRAS and Nikolay Malkin and Guillaume Huguet and Yanlei Zhang and Jarrid Rector-Brooks and Guy Wolf and Yoshua Bengio},
69
+ journal={Transactions on Machine Learning Research},
70
+ issn={2835-8856},
71
+ year={2024},
72
+ url={https://openreview.net/forum?id=CD9Snc73AW},
73
+ note={Expert Certification}
74
+ }
75
+ ```
76
+
77
+ </details>
78
+
79
+ <details>
80
+ <summary>
81
+ A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y. Bengio. Simulation-Free Schrödinger Bridges via Score and Flow Matching, 2023.
82
+ </summary>
83
+
84
+ ```bibtex
85
+ @article{tong2023simulation,
86
+ title={Simulation-Free Schr{\"o}dinger Bridges via Score and Flow Matching},
87
+ author={Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua},
88
+ year={2023},
89
+ journal={arXiv preprint 2307.03672}
90
+ }
91
+ ```
92
+
93
+ </details>
94
+
95
+ ## V0 -> V1
96
+
97
+ Major Changes:
98
+
99
+ - **Added cifar10 examples with an FID of 3.5**
100
+ - Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint
101
+ - Created `torchcfm` pip installable package
102
+ - Moved `pytorch-lightning` implementation and experiments to `runner` directory
103
+ - Moved `notebooks` -> `examples`
104
+ - Added image generation implementation in both lightning and a notebook in `examples`
105
+
106
+ ## Implemented papers
107
+
108
+ List of implemented papers:
109
+
110
+ - Flow Matching for Generative Modeling (Lipman et al. 2023) [Paper](https://openreview.net/forum?id=PqvMRDCJT9t)
111
+ - Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow (Liu et al. 2023) [Paper](https://openreview.net/forum?id=XVjTT1nw5z) [Code](https://github.com/gnobitab/RectifiedFlow.git)
112
+ - Building Normalizing Flows with Stochastic Interpolants (Albergo et al. 2023a) [Paper](https://openreview.net/forum?id=li7qeBbCR1t)
113
+ - Action Matching: Learning Stochastic Dynamics From Samples (Neklyudov et al. 2022) [Paper](https://arxiv.org/abs/2210.06662) [Code](https://github.com/necludov/jam)
114
+ - Concurrent work to our OT-CFM method: Multisample Flow Matching: Straightening Flows with Minibatch Couplings (Pooladian et al. 2023) [Paper](https://arxiv.org/abs/2304.14772)
115
+ - Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (Jolicoeur-Martineau et al.) [Paper](https://arxiv.org/abs/2309.09968) [Code](https://github.com/SamsungSAILMontreal/ForestDiffusion)
116
+ - Soon: SE(3)-Stochastic Flow Matching for Protein Backbone Generation (Bose et al.) [Paper](https://arxiv.org/abs/2310.02391)
117
+
118
+ ## How to run
119
+
120
+ Run a simple minimal example here [![Run in Google Colab](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/training-8gaussians-to-moons.ipynb). Or install the more efficient code locally with these steps.
121
+
122
+ TorchCFM is now on [pypi](https://pypi.org/project/torchcfm/)! You can install it with:
123
+
124
+ ```bash
125
+ pip install torchcfm
126
+ ```
127
+
128
+ To use the full library with the different examples, you can install dependencies:
129
+
130
+ ```bash
131
+ # clone project
132
+ git clone https://github.com/atong01/conditional-flow-matching.git
133
+ cd conditional-flow-matching
134
+
135
+ # [OPTIONAL] create conda environment
136
+ conda create -n torchcfm python=3.10
137
+ conda activate torchcfm
138
+
139
+ # install pytorch according to instructions
140
+ # https://pytorch.org/get-started/
141
+
142
+ # install requirements
143
+ pip install -r requirements.txt
144
+
145
+ # install torchcfm
146
+ pip install -e .
147
+ ```
148
+
149
+ To run our jupyter notebooks, use the following commands after installing our package.
150
+
151
+ ```bash
152
+ # install ipykernel
153
+ conda install -c anaconda ipykernel
154
+
155
+ # install conda env in jupyter notebook
156
+ python -m ipykernel install --user --name=torchcfm
157
+
158
+ # launch our notebooks with the torchcfm kernel
159
+ ```
160
+
161
+ ## Project Structure
162
+
163
+ The directory structure looks like this:
164
+
165
+ ```
166
+
167
+
168
+ ├── examples <- Jupyter notebooks
169
+ | ├── cifar10 <- Cifar10 experiments
170
+ │ ├── notebooks <- Diverse examples with notebooks
171
+
172
+ │── runner <- Everything related to the original version (V0) of the library
173
+
174
+ |── torchcfm <- Code base of our Flow Matching methods
175
+ | ├── conditional_flow_matching.py <- CFM classes
176
+ │ ├── models <- Model architectures
177
+ │ │ ├── models <- Models for 2D examples
178
+ │ │ ├── Unet <- Unet models for image examples
179
+ |
180
+ ├── .gitignore <- List of files ignored by git
181
+ ├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
182
+ ├── pyproject.toml <- Configuration options for testing and linting
183
+ ├── requirements.txt <- File for installing python dependencies
184
+ ├── setup.py <- File for installing project as a package
185
+ └── README.md
186
+ ```
187
+
188
+ ## ❤️  Code Contributions
189
+
190
+ This toolbox has been created and is maintained by
191
+
192
+ - [Alexander Tong](http://alextong.net)
193
+ - [Kilian Fatras](http://kilianfatras.github.io)
194
+
195
+ It was initiated from a larger private codebase which loses the original commit history which contains work from other authors of the papers.
196
+
197
+ Before making an issue, please verify that:
198
+
199
+ - The problem still exists on the current `main` branch.
200
+ - Your python dependencies are updated to recent versions.
201
+
202
+ Suggestions for improvements are always welcome!
203
+
204
+ ## License
205
+
206
+ Conditional-Flow-Matching is licensed under the MIT License.
207
+
208
+ ```
209
+ MIT License
210
+
211
+ Copyright (c) 2023 Alexander Tong
212
+
213
+ Permission is hereby granted, free of charge, to any person obtaining a copy
214
+ of this software and associated documentation files (the "Software"), to deal
215
+ in the Software without restriction, including without limitation the rights
216
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
217
+ copies of the Software, and to permit persons to whom the Software is
218
+ furnished to do so, subject to the following conditions:
219
+
220
+ The above copyright notice and this permission notice shall be included in all
221
+ copies or substantial portions of the Software.
222
+
223
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
224
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
225
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
226
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
227
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
228
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
229
+ SOFTWARE.
230
+ ```
conditional-flow-matching/pyproject.toml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.pytest.ini_options]
2
+ addopts = [
3
+ "--color=yes",
4
+ "--durations=0",
5
+ "--strict-markers",
6
+ "--doctest-modules",
7
+ ]
8
+ filterwarnings = [
9
+ "ignore::DeprecationWarning",
10
+ "ignore::UserWarning",
11
+ ]
12
+ log_cli = "True"
13
+ markers = [
14
+ "slow: slow tests",
15
+ ]
16
+ minversion = "6.0"
17
+ testpaths = "tests/"
18
+
19
+ [tool.coverage.report]
20
+ exclude_lines = [
21
+ "pragma: nocover",
22
+ "raise NotImplementedError",
23
+ "raise NotImplementedError()",
24
+ "if __name__ == .__main__.:",
25
+ ]
26
+
27
+ [tool.ruff]
28
+ line-length = 99
29
+
30
+ [tool.ruff.lint]
31
+ ignore = ["C901", "E501", "E741", "W605", "C408", "E402"]
32
+ select = ["C", "E", "F", "I", "W"]
33
+
34
+ [tool.ruff.lint.per-file-ignores]
35
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
36
+
37
+ [tool.ruff.lint.isort]
38
+ known-first-party = ["src"]
39
+ known-third-party = ["torch", "transformers", "wandb"]
conditional-flow-matching/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.11.0
2
+ torchvision>=0.11.0
3
+
4
+ lightning-bolts
5
+ matplotlib
6
+ numpy
7
+ scipy
8
+ scikit-learn
9
+ scprep
10
+ scanpy
11
+ torchdyn>=1.0.6 # 1.0.4 is broken on pypi
12
+ pot
13
+ torchdiffeq
14
+ absl-py
15
+ clean-fid
conditional-flow-matching/runner-requirements.txt ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Note if using Conda it is recommended to install torch separately.
2
+ # For most of testing the following commands were run to set up the environment
3
+ # This was tested with torch==1.12.1
4
+ # conda create -n ti-env python=3.10
5
+ # conda activate ti-env
6
+ # pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
7
+ # pip install -r requirements.txt
8
+ # --------- pytorch --------- #
9
+ torch>=1.11.0,<2.0.0
10
+ torchvision>=0.11.0
11
+ pytorch-lightning==1.8.3.post2
12
+ torchmetrics==0.11.0
13
+
14
+ # --------- hydra --------- #
15
+ hydra-core==1.2.0
16
+ hydra-colorlog==1.2.0
17
+ hydra-optuna-sweeper==1.2.0
18
+ # hydra-submitit-launcher
19
+
20
+ # --------- loggers --------- #
21
+ wandb
22
+ # neptune-client
23
+ # mlflow
24
+ # comet-ml
25
+
26
+ # --------- others --------- #
27
+ black
28
+ isort
29
+ flake8
30
+ Flake8-pyproject # for configuration via pyproject
31
+ pyrootutils # standardizing the project root setup
32
+ pre-commit # hooks for applying linters on commit
33
+ rich # beautiful text formatting in terminal
34
+ pytest # tests
35
+ # sh # for running bash commands in some tests (linux/macos only)
36
+
37
+
38
+ # --------- pkg reqs -------- #
39
+ lightning-bolts
40
+ matplotlib
41
+ numpy
42
+ scipy
43
+ scikit-learn
44
+ scprep
45
+ scanpy
46
+ timm
47
+ torchdyn>=1.0.5 # 1.0.4 is broken on pypi
48
+ pot
49
+
50
+ # --------- notebook reqs -------- #
51
+ seaborn>=0.12.2
52
+ pandas>=2.2.2
conditional-flow-matching/setup.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+
5
+ from setuptools import find_packages, setup
6
+
7
+ install_requires = [
8
+ "torch>=1.11.0",
9
+ "matplotlib",
10
+ "numpy", # Due to pandas incompatibility
11
+ "scipy",
12
+ "scikit-learn",
13
+ "torchdyn>=1.0.6",
14
+ "pot",
15
+ "torchdiffeq",
16
+ "absl-py",
17
+ "pandas>=2.2.2",
18
+ ]
19
+
20
+ version_py = os.path.join(os.path.dirname(__file__), "torchcfm", "version.py")
21
+ version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()
22
+ readme = open("README.md", encoding="utf8").read()
23
+ setup(
24
+ name="torchcfm",
25
+ version=version,
26
+ description="Conditional Flow Matching for Fast Continuous Normalizing Flow Training.",
27
+ author="Alexander Tong, Kilian Fatras",
28
+ author_email="alexandertongdev@gmail.com",
29
+ url="https://github.com/atong01/conditional-flow-matching",
30
+ install_requires=install_requires,
31
+ license="MIT",
32
+ long_description=readme,
33
+ long_description_content_type="text/markdown",
34
+ packages=find_packages(exclude=["tests", "tests.*"]),
35
+ extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]},
36
+ )