xiangzai commited on
Commit
7803bdf
·
verified ·
1 Parent(s): 3c1ccbd

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. __pycache__/pic_npz.cpython-311.pyc +0 -0
  2. __pycache__/pipeline_stable_diffusion_3.cpython-310.pyc +0 -0
  3. __pycache__/sample_sd3_lora_rn_pair_ddp.cpython-311.pyc +0 -0
  4. __pycache__/train_rectified_noise.cpython-310.pyc +0 -0
  5. __pycache__/visualize_lora_rn_4x8.cpython-311.pyc +0 -0
  6. accelerate_config.yaml +16 -0
  7. cc3m_render.log +0 -0
  8. cc3m_render.py +155 -0
  9. download.log +0 -0
  10. download_sd3_models.py +71 -0
  11. eval_baseline.log +24 -0
  12. eval_rectified_noise_new_batch_2.log +24 -0
  13. evaluate.sh +11 -0
  14. evaluator_base copy.py +680 -0
  15. evaluator_base.log +5 -0
  16. evaluator_base.py +685 -0
  17. evaluator_rf.py +685 -0
  18. evaluator_rf_iter22.log +25 -0
  19. pic_npz copy.py +259 -0
  20. pic_npz.py +157 -0
  21. pipeline_stable_diffusion_3.py +1378 -0
  22. rectified-noise-batch-2/checkpoint-100000/sit_weights/sit_config.json +10 -0
  23. rectified-noise-batch-2/checkpoint-120000/sit_weights/sit_config.json +10 -0
  24. rectified-noise-batch-2/checkpoint-140000/sit_weights/sit_config.json +10 -0
  25. rectified-noise-batch-2/checkpoint-160000/sit_weights/sit_config.json +10 -0
  26. rectified-noise-batch-2/checkpoint-180000/sit_weights/sit_config.json +10 -0
  27. rectified-noise-batch-2/checkpoint-200000/sit_weights/sit_config.json +10 -0
  28. run_sd3_lora_rn_pair_sampling.sh +50 -0
  29. run_sd3_lora_sampling.log +0 -0
  30. run_sd3_lora_sampling.sh +94 -0
  31. run_sd3_rectified_sampling.sh +55 -0
  32. run_sd3_rectified_sampling_old.sh +72 -0
  33. sample_sd3_lora_checkpoint_ddp.py +818 -0
  34. sample_sd3_lora_ddp.py +675 -0
  35. sample_sd3_lora_rn_pair_ddp.py +417 -0
  36. sample_sd3_rectified_ddp.py +1316 -0
  37. sample_sd3_rectified_ddp_old.py +1317 -0
  38. sd3_rectified_samples_batch2_2200005011.01.01.0cfg_cond_true.txt +5 -0
  39. train_lora_sd3.py +1597 -0
  40. train_lora_sd3_new.py +1422 -0
  41. train_rectified_noise.py +0 -0
  42. train_rectified_noise.sh +104 -0
  43. train_rectified_noise2.py +0 -0
  44. train_sd3_lora.log +27 -0
  45. train_sd3_lora.sh +109 -0
  46. train_sd3_lora2.log +216 -0
  47. train_sd3_lora2.sh +107 -0
  48. visual.sh +78 -0
  49. visualize_lora_rn_4x8.py +406 -0
  50. visualize_sitf2_noise_evolution.py +169 -0
__pycache__/pic_npz.cpython-311.pyc ADDED
Binary file (7.6 kB). View file
 
__pycache__/pipeline_stable_diffusion_3.cpython-310.pyc ADDED
Binary file (39.5 kB). View file
 
__pycache__/sample_sd3_lora_rn_pair_ddp.cpython-311.pyc ADDED
Binary file (27.1 kB). View file
 
__pycache__/train_rectified_noise.cpython-310.pyc ADDED
Binary file (58.7 kB). View file
 
__pycache__/visualize_lora_rn_4x8.cpython-311.pyc ADDED
Binary file (23.9 kB). View file
 
accelerate_config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: 'no'
9
+ num_machines: 1
10
+ num_processes: 4
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
cc3m_render.log ADDED
The diff for this file is too large to render. See raw diff
 
cc3m_render.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ """
5
+ 在 `data_root/` 下已经有 `train/` 和 `validation/` 两个文件夹时:
6
+ 分别在这两个文件夹内生成对应的 `metadata.jsonl`,不复制任何图片。
7
+
8
+ `metadata.jsonl` 每行格式:
9
+ {"file_name": "subdir/000026831.jpg", "caption": "..."}
10
+
11
+ 其中 `file_name` 是相对当前 split 目录(train/ 或 validation/)的路径。
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ from concurrent.futures import ThreadPoolExecutor
18
+ from itertools import islice
19
+ from pathlib import Path
20
+ from typing import Optional, Tuple
21
+
22
+ from tqdm import tqdm
23
+
24
+
25
+ def parse_args() -> argparse.Namespace:
26
+ parser = argparse.ArgumentParser(description="Generate per-split metadata.jsonl for imagefolder (no copy)")
27
+ parser.add_argument(
28
+ "--data_root",
29
+ type=str,
30
+ default="/gemini/space/hsd/project/dataset/cc3m-wds",
31
+ help="数据根目录(必须包含 train/ 和 validation/)",
32
+ )
33
+ parser.add_argument(
34
+ "--jsonl_name",
35
+ type=str,
36
+ default="metadata.jsonl",
37
+ help="每个 split 下生成的 jsonl 文件名(默认 metadata.jsonl)",
38
+ )
39
+ parser.add_argument(
40
+ "--use_txt_caption",
41
+ action="store_true",
42
+ default=True,
43
+ help="优先读取同名 .txt 作为 caption(默认开启),否则回落到 .json",
44
+ )
45
+ parser.add_argument(
46
+ "--num_workers",
47
+ type=int,
48
+ default=32,
49
+ help="线程数(I/O 密集型建议 8~64 之间按机器调整)",
50
+ )
51
+ parser.add_argument(
52
+ "--max_images",
53
+ type=int,
54
+ default=None,
55
+ help="每个 split 最多处理多少张图片(None 表示全部,调试可用)",
56
+ )
57
+ return parser.parse_args()
58
+
59
+
60
+ def read_caption_from_txt(txt_path: Path) -> Optional[str]:
61
+ if not txt_path.exists():
62
+ return None
63
+ try:
64
+ with txt_path.open("r", encoding="utf-8") as f:
65
+ caption = f.read().strip()
66
+ return caption or None
67
+ except Exception:
68
+ return None
69
+
70
+
71
+ def read_caption_from_json(json_path: Path) -> Optional[str]:
72
+ if not json_path.exists():
73
+ return None
74
+ try:
75
+ with json_path.open("r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+ for key in ["caption", "text", "description"]:
78
+ if key in data and isinstance(data[key], str) and data[key].strip():
79
+ return data[key].strip()
80
+ except Exception:
81
+ return None
82
+ return None
83
+
84
+
85
+ def main() -> None:
86
+ args = parse_args()
87
+
88
+ data_root = Path(args.data_root).resolve()
89
+ if not data_root.exists():
90
+ raise FileNotFoundError(f"数据根目录不存在:{data_root}")
91
+
92
+ splits = [("train", data_root / "train"), ("validation", data_root / "validation")]
93
+ for split_name, split_dir in splits:
94
+ if not split_dir.exists():
95
+ raise FileNotFoundError(f"缺少目录:{split_dir}(需要 train/ 和 validation/)")
96
+
97
+ def iter_images(split_dir: Path):
98
+ for root, _dirs, files in os.walk(split_dir):
99
+ for name in files:
100
+ if name.lower().endswith((".jpg", ".jpeg", ".png")):
101
+ yield Path(root) / name
102
+
103
+ def process_one(img_path: Path, split_dir: Path) -> Optional[Tuple[str, str]]:
104
+ txt_path = img_path.with_suffix(".txt")
105
+ json_path = img_path.with_suffix(".json")
106
+
107
+ caption = None
108
+ if args.use_txt_caption:
109
+ caption = read_caption_from_txt(txt_path)
110
+ if caption is None:
111
+ caption = read_caption_from_json(json_path)
112
+ else:
113
+ caption = read_caption_from_json(json_path)
114
+ if caption is None:
115
+ caption = read_caption_from_txt(txt_path)
116
+
117
+ if caption is None:
118
+ return None
119
+
120
+ rel = img_path.relative_to(split_dir)
121
+ return str(rel).replace(os.sep, "/"), caption
122
+
123
+ for split_name, split_dir in splits:
124
+ jsonl_path = split_dir / args.jsonl_name
125
+
126
+ img_iter = iter_images(split_dir)
127
+ if args.max_images is not None:
128
+ img_iter = islice(img_iter, args.max_images)
129
+
130
+ # tqdm 需要可迭代对象,这里不预先收集列表以节省内存
131
+ # 进度条显示 processed 数量(total 可能未知)
132
+ def _task_iter():
133
+ for p in img_iter:
134
+ yield p
135
+
136
+ written = 0
137
+ with jsonl_path.open("w", encoding="utf-8") as f, ThreadPoolExecutor(max_workers=args.num_workers) as ex:
138
+ # executor.map 保持输入顺序;tqdm 显示处理进度
139
+ for result in tqdm(
140
+ ex.map(lambda p: process_one(p, split_dir), _task_iter()),
141
+ desc=f"[{split_name}] Processing",
142
+ ):
143
+ if result is None:
144
+ continue
145
+ file_name, caption = result
146
+ f.write(json.dumps({"file_name": file_name, "caption": caption}, ensure_ascii=False) + "\n")
147
+ written += 1
148
+
149
+ print(f"{split_name}: 写入 {written} 条 -> {jsonl_path}")
150
+
151
+
152
+ if __name__ == "__main__":
153
+ main()
154
+
155
+ # nohup python cc3m_render.py > cc3m_render.log 2>&1 &
download.log ADDED
File without changes
download_sd3_models.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ """
5
+ 只负责“按 train_lora_sd3.py 相同的方式”下载 SD3 及相关组件到默认 HF cache:
6
+ 通过依次调用 `from_pretrained(..., subfolder=...)` 来触发下载。
7
+
8
+ 会下载的子目录(与 train_lora_sd3.py 一致):
9
+ tokenizer, tokenizer_2, tokenizer_3,
10
+ text_encoder, text_encoder_2, text_encoder_3,
11
+ scheduler, vae, transformer
12
+
13
+ 用法:
14
+ python download_sd3_models.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers
15
+
16
+ 下载完成后训练可离线:
17
+ export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1
18
+ python train_lora_sd3.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers ...
19
+ 或直接指向你已经下载好的本地 repo 目录。
20
+ """
21
+
22
+ import argparse
23
+ import gc
24
+ from typing import Optional
25
+
26
+ import torch
27
+ from diffusers import StableDiffusion3Pipeline
28
+
29
+
30
+ def parse_args() -> argparse.Namespace:
31
+ p = argparse.ArgumentParser(description="Download SD3 via a single from_pretrained call (cache warmup only).")
32
+ p.add_argument(
33
+ "--pretrained_model_name_or_path",
34
+ type=str,
35
+ default="stabilityai/stable-diffusion-3-medium-diffusers",
36
+ help="模型 repo id 或本地路径(与 train_lora_sd3.py 参数一致",
37
+ )
38
+ p.add_argument("--revision", type=str, default=None, help="可选:下载特定 revision/branch/tag")
39
+ p.add_argument("--variant", type=str, default=None, help="可选:如 fp16 等 variant(若仓库提供)")
40
+ p.add_argument("--cache_dir", type=str, default=None, help="可选:自定义 HF cache_dir;默认用系统/用户默认")
41
+ return p.parse_args()
42
+
43
+
44
+ def main() -> None:
45
+ args = parse_args()
46
+
47
+ model = args.pretrained_model_name_or_path
48
+
49
+ # 最简单:直接加载整条 pipeline,触发把其依赖的全部组件下载进默认 HF cache
50
+ # 注意:from_pretrained 下载的是 pipeline 会用到的文件;本脚本目的就是让训练时不再联网。
51
+ pipe = StableDiffusion3Pipeline.from_pretrained(
52
+ model,
53
+ revision=args.revision,
54
+ variant=args.variant,
55
+ cache_dir=args.cache_dir,
56
+ low_cpu_mem_usage=True,
57
+ )
58
+
59
+ # 释放内存(下载已完成)
60
+ del pipe
61
+ gc.collect()
62
+ if torch.cuda.is_available():
63
+ torch.cuda.empty_cache()
64
+
65
+ print("下载/缓存预热完成。后续训练可设置离线环境变量避免联网:")
66
+ print(" export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1")
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
71
+
eval_baseline.log ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/1 [00:00<?, ?it/s]2026-03-21 21:11:45.919794: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
 
 
 
1
  0%| | 0/211 [00:00<?, ?it/s]
2
  0%| | 1/211 [00:02<07:04, 2.02s/it]
3
  1%| | 2/211 [00:03<06:56, 2.00s/it]
4
  1%|▏ | 3/211 [00:05<06:50, 1.98s/it]
5
  2%|▏ | 4/211 [00:07<06:42, 1.95s/it]
6
  2%|▏ | 5/211 [00:09<06:25, 1.87s/it]
7
  3%|▎ | 6/211 [00:11<06:20, 1.85s/it]
8
  3%|▎ | 7/211 [00:13<06:46, 1.99s/it]
9
  4%|▍ | 8/211 [00:15<06:56, 2.05s/it]
10
  4%|▍ | 9/211 [00:30<20:28, 6.08s/it]
11
  5%|▍ | 10/211 [00:32<15:51, 4.73s/it]
12
  5%|▌ | 11/211 [00:34<12:42, 3.81s/it]
13
  6%|▌ | 12/211 [00:35<10:26, 3.15s/it]
14
  6%|▌ | 13/211 [00:37<08:55, 2.71s/it]
15
  7%|▋ | 14/211 [00:39<08:09, 2.49s/it]
16
  7%|▋ | 15/211 [00:41<07:32, 2.31s/it]
17
  8%|▊ | 16/211 [00:43<06:52, 2.12s/it]
18
  8%|▊ | 17/211 [00:44<06:35, 2.04s/it]
19
  9%|▊ | 18/211 [00:46<06:18, 1.96s/it]
20
  9%|▉ | 19/211 [00:48<06:16, 1.96s/it]
21
  9%|▉ | 20/211 [00:50<06:20, 1.99s/it]
22
  10%|▉ | 21/211 [00:52<06:01, 1.91s/it]
23
  10%|█ | 22/211 [00:54<05:51, 1.86s/it]
24
  11%|█ | 23/211 [00:56<05:48, 1.85s/it]
25
  11%|█▏ | 24/211 [00:57<05:44, 1.84s/it]
26
  12%|█▏ | 25/211 [00:59<05:44, 1.85s/it]
27
  12%|█▏ | 26/211 [01:01<05:38, 1.83s/it]
28
  13%|█▎ | 27/211 [01:03<05:33, 1.81s/it]
29
  13%|█▎ | 28/211 [01:05<06:02, 1.98s/it]
30
  14%|█▎ | 29/211 [01:07<05:49, 1.92s/it]
31
  14%|█▍ | 30/211 [01:09<05:33, 1.84s/it]
32
  15%|█▍ | 31/211 [01:11<05:38, 1.88s/it]
33
  15%|█▌ | 32/211 [01:12<05:32, 1.86s/it]
34
  16%|█▌ | 33/211 [01:14<05:26, 1.83s/it]
35
  16%|█▌ | 34/211 [01:16<05:20, 1.81s/it]
36
  17%|█▋ | 35/211 [01:18<05:16, 1.80s/it]
37
  17%|█▋ | 36/211 [01:19<05:11, 1.78s/it]
38
  18%|█▊ | 37/211 [01:21<05:06, 1.76s/it]
39
  18%|█▊ | 38/211 [01:23<05:22, 1.86s/it]
40
  18%|█▊ | 39/211 [01:25<05:12, 1.82s/it]
41
  19%|█▉ | 40/211 [01:35<11:51, 4.16s/it]
42
  19%|█▉ | 41/211 [01:36<09:45, 3.44s/it]
43
  20%|█▉ | 42/211 [01:38<08:27, 3.00s/it]
44
  20%|██ | 43/211 [01:40<07:25, 2.65s/it]
45
  21%|██ | 44/211 [01:42<06:46, 2.43s/it]
46
  21%|██▏ | 45/211 [01:44<06:28, 2.34s/it]
47
  22%|██▏ | 46/211 [01:46<05:56, 2.16s/it]
48
  22%|██▏ | 47/211 [01:48<05:55, 2.17s/it]
49
  23%|██▎ | 48/211 [01:50<05:30, 2.03s/it]
50
  23%|██▎ | 49/211 [01:52<05:22, 1.99s/it]
51
  24%|██▎ | 50/211 [01:54<05:11, 1.94s/it]
52
  24%|██▍ | 51/211 [01:56<05:18, 1.99s/it]
53
  25%|██▍ | 52/211 [01:57<05:02, 1.90s/it]
54
  25%|██▌ | 53/211 [01:59<04:52, 1.85s/it]
55
  26%|██▌ | 54/211 [02:01<04:42, 1.80s/it]
56
  26%|██▌ | 55/211 [02:03<04:40, 1.80s/it]
57
  27%|██▋ | 56/211 [02:05<04:47, 1.85s/it]
58
  27%|██▋ | 57/211 [02:06<04:38, 1.81s/it]
59
  27%|██▋ | 58/211 [02:08<04:40, 1.83s/it]
60
  28%|██▊ | 59/211 [02:10<04:34, 1.81s/it]
61
  28%|██▊ | 60/211 [02:12<04:31, 1.80s/it]
62
  29%|██▉ | 61/211 [02:13<04:24, 1.76s/it]
63
  29%|██▉ | 62/211 [02:15<04:19, 1.74s/it]
64
  30%|██▉ | 63/211 [02:17<04:25, 1.79s/it]
65
  30%|███ | 64/211 [02:19<04:22, 1.78s/it]
66
  31%|███ | 65/211 [02:21<04:21, 1.79s/it]
67
  31%|███▏ | 66/211 [02:22<04:16, 1.77s/it]
68
  32%|███▏ | 67/211 [02:24<04:27, 1.85s/it]
69
  32%|███▏ | 68/211 [02:26<04:23, 1.84s/it]
70
  33%|███▎ | 69/211 [02:28<04:21, 1.84s/it]
71
  33%|███▎ | 70/211 [02:30<04:15, 1.81s/it]
72
  34%|███▎ | 71/211 [02:32<04:13, 1.81s/it]
73
  34%|███▍ | 72/211 [02:33<04:09, 1.79s/it]
74
  35%|███▍ | 73/211 [02:35<04:04, 1.77s/it]
75
  35%|███▌ | 74/211 [02:53<15:24, 6.75s/it]
76
  36%|███▌ | 75/211 [02:55<11:58, 5.28s/it]
77
  36%|███▌ | 76/211 [02:57<09:31, 4.24s/it]
78
  36%|███▋ | 77/211 [02:59<07:46, 3.48s/it]
79
  37%|███▋ | 78/211 [03:00<06:33, 2.96s/it]
80
  37%|███▋ | 79/211 [03:02<05:47, 2.64s/it]
81
  38%|███▊ | 80/211 [03:04<05:19, 2.44s/it]
82
  38%|███▊ | 81/211 [03:06<04:47, 2.22s/it]
83
  39%|███▉ | 82/211 [03:08<04:27, 2.07s/it]
84
  39%|███▉ | 83/211 [03:09<04:12, 1.97s/it]
85
  40%|███▉ | 84/211 [03:11<04:02, 1.91s/it]
86
  40%|████ | 85/211 [03:13<03:53, 1.85s/it]
87
  41%|████ | 86/211 [03:15<03:48, 1.83s/it]
88
  41%|████ | 87/211 [03:16<03:42, 1.79s/it]
89
  42%|████▏ | 88/211 [03:18<03:40, 1.79s/it]
90
  42%|████▏ | 89/211 [03:20<03:35, 1.77s/it]
91
  43%|████▎ | 90/211 [03:22<03:33, 1.77s/it]
92
  43%|████▎ | 91/211 [03:23<03:30, 1.76s/it]
93
  44%|████▎ | 92/211 [03:25<03:35, 1.81s/it]
94
  44%|████▍ | 93/211 [03:27<03:28, 1.76s/it]
95
  45%|████▍ | 94/211 [03:29<03:25, 1.75s/it]
96
  45%|████▌ | 95/211 [03:31<03:22, 1.74s/it]
97
  45%|████▌ | 96/211 [03:32<03:24, 1.77s/it]
98
  46%|████▌ | 97/211 [03:34<03:21, 1.77s/it]
99
  46%|████▋ | 98/211 [03:36<03:19, 1.77s/it]
100
  47%|████▋ | 99/211 [03:38<03:18, 1.77s/it]
101
  47%|████▋ | 100/211 [03:39<03:14, 1.75s/it]
102
  48%|████▊ | 101/211 [03:41<03:13, 1.75s/it]
103
  48%|████▊ | 102/211 [03:43<03:13, 1.77s/it]
104
  49%|████▉ | 103/211 [03:45<03:08, 1.75s/it]
105
  49%|████▉ | 104/211 [03:47<03:12, 1.79s/it]
106
  50%|████▉ | 105/211 [03:48<03:07, 1.77s/it]
107
  50%|█████ | 106/211 [04:10<13:22, 7.64s/it]
108
  51%|█████ | 107/211 [04:12<10:18, 5.94s/it]
109
  51%|█████ | 108/211 [04:13<08:02, 4.68s/it]
110
  52%|█████▏ | 109/211 [04:15<06:26, 3.79s/it]
111
  52%|█████▏ | 110/211 [04:17<05:18, 3.16s/it]
112
  53%|█████▎ | 111/211 [04:18<04:33, 2.74s/it]
113
  53%|█████▎ | 112/211 [04:20<04:09, 2.52s/it]
114
  54%|█████▎ | 113/211 [04:22<03:50, 2.35s/it]
115
  54%|█████▍ | 114/211 [04:24<03:30, 2.17s/it]
116
  55%|█████▍ | 115/211 [04:26<03:22, 2.11s/it]
117
  55%|█████▍ | 116/211 [04:46<11:49, 7.47s/it]
118
  55%|█████▌ | 117/211 [04:48<08:58, 5.73s/it]
119
  56%|█████▌ | 118/211 [04:50<07:08, 4.61s/it]
120
  56%|█████▋ | 119/211 [04:51<05:41, 3.71s/it]
121
  57%|█████▋ | 120/211 [04:53<04:42, 3.10s/it]
122
  57%|█████▋ | 121/211 [04:55<04:01, 2.69s/it]
123
  58%|█████▊ | 122/211 [04:57<03:33, 2.40s/it]
124
  58%|█████▊ | 123/211 [04:58<03:13, 2.20s/it]
125
  59%|█████▉ | 124/211 [05:00<03:01, 2.08s/it]
126
  59%|█████▉ | 125/211 [05:02<02:50, 1.98s/it]
127
  60%|█████▉ | 126/211 [05:03<02:38, 1.87s/it]
128
  60%|██████ | 127/211 [05:05<02:31, 1.80s/it]
129
  61%|██████ | 128/211 [05:07<02:25, 1.75s/it]
130
  61%|██████ | 129/211 [05:09<02:29, 1.82s/it]
131
  62%|██████▏ | 130/211 [05:10<02:24, 1.79s/it]
132
  62%|██████▏ | 131/211 [05:12<02:20, 1.76s/it]
133
  63%|██████▎ | 132/211 [05:14<02:17, 1.74s/it]
134
  63%|██████▎ | 133/211 [05:15<02:14, 1.73s/it]
135
  64%|██████▎ | 134/211 [05:18<02:23, 1.87s/it]
136
  64%|██████▍ | 135/211 [05:20<02:22, 1.87s/it]
137
  64%|██████▍ | 136/211 [05:21<02:15, 1.80s/it]
138
  65%|██████▍ | 137/211 [05:23<02:13, 1.81s/it]
139
  65%|██████▌ | 138/211 [05:25<02:10, 1.79s/it]
140
  66%|██████▌ | 139/211 [05:26<02:06, 1.75s/it]
141
  66%|██████▋ | 140/211 [05:28<02:03, 1.73s/it]
142
  67%|██████▋ | 141/211 [05:30<02:04, 1.78s/it]
143
  67%|██████▋ | 142/211 [05:32<02:05, 1.82s/it]
144
  68%|██████▊ | 143/211 [05:34<02:04, 1.83s/it]
145
  68%|██��███▊ | 144/211 [05:35<01:59, 1.78s/it]
146
  69%|██████▊ | 145/211 [05:37<01:57, 1.78s/it]
147
  69%|██████▉ | 146/211 [05:55<07:10, 6.63s/it]
148
  70%|██████▉ | 147/211 [05:57<05:30, 5.17s/it]
149
  70%|███████ | 148/211 [05:59<04:23, 4.17s/it]
150
  71%|███████ | 149/211 [06:00<03:31, 3.42s/it]
151
  71%|███████ | 150/211 [06:02<02:58, 2.92s/it]
152
  72%|███████▏ | 151/211 [06:04<02:33, 2.56s/it]
153
  72%|███████▏ | 152/211 [06:06<02:15, 2.29s/it]
154
  73%|███████▎ | 153/211 [06:07<02:02, 2.12s/it]
155
  73%|███████▎ | 154/211 [06:10<02:15, 2.38s/it]
156
  73%|███████▎ | 155/211 [06:12<02:00, 2.15s/it]
157
  74%|███████▍ | 156/211 [06:14<01:50, 2.00s/it]
158
  74%|███████▍ | 157/211 [06:15<01:43, 1.91s/it]
159
  75%|███████▍ | 158/211 [06:17<01:38, 1.86s/it]
160
  75%|███████▌ | 159/211 [06:19<01:37, 1.88s/it]
161
  76%|███████▌ | 160/211 [06:21<01:32, 1.82s/it]
162
  76%|███████▋ | 161/211 [06:22<01:29, 1.79s/it]
163
  77%|███████▋ | 162/211 [06:24<01:31, 1.87s/it]
164
  77%|███████▋ | 163/211 [06:26<01:27, 1.83s/it]
165
  78%|███████▊ | 164/211 [06:29<01:36, 2.05s/it]
166
  78%|███████▊ | 165/211 [06:31<01:32, 2.01s/it]
167
  79%|███████▊ | 166/211 [06:32<01:27, 1.94s/it]
168
  79%|███████▉ | 167/211 [06:34<01:21, 1.85s/it]
169
  80%|███████▉ | 168/211 [06:36<01:17, 1.79s/it]
170
  80%|████████ | 169/211 [06:38<01:19, 1.89s/it]
171
  81%|████████ | 170/211 [06:40<01:17, 1.88s/it]
172
  81%|████████ | 171/211 [06:41<01:14, 1.85s/it]
173
  82%|████████▏ | 172/211 [06:43<01:11, 1.84s/it]
174
  82%|████████▏ | 173/211 [06:45<01:08, 1.81s/it]
175
  82%|████████▏ | 174/211 [06:47<01:07, 1.84s/it]
176
  83%|████████▎ | 175/211 [06:49<01:04, 1.78s/it]
177
  83%|████████▎ | 176/211 [06:50<01:02, 1.79s/it]
178
  84%|████████▍ | 177/211 [06:52<01:00, 1.77s/it]
179
  84%|████████▍ | 178/211 [06:54<01:03, 1.92s/it]
180
  85%|████████▍ | 179/211 [06:56<00:59, 1.87s/it]
181
  85%|████████▌ | 180/211 [06:58<00:56, 1.81s/it]
182
  86%|████████▌ | 181/211 [06:59<00:53, 1.77s/it]
183
  86%|████████▋ | 182/211 [07:01<00:53, 1.84s/it]
184
  87%|████████▋ | 183/211 [07:03<00:50, 1.81s/it]
185
  87%|████████▋ | 184/211 [07:05<00:49, 1.83s/it]
186
  88%|████████▊ | 185/211 [07:07<00:48, 1.86s/it]
187
  88%|████████▊ | 186/211 [07:09<00:45, 1.83s/it]
188
  89%|████████▊ | 187/211 [07:11<00:43, 1.82s/it]
189
  89%|████████▉ | 188/211 [07:12<00:42, 1.83s/it]
190
  90%|████████▉ | 189/211 [07:14<00:39, 1.77s/it]
191
  90%|█████████ | 190/211 [07:16<00:37, 1.77s/it]
192
  91%|█████████ | 191/211 [07:17<00:34, 1.73s/it]
193
  91%|█████████ | 192/211 [07:19<00:32, 1.71s/it]
194
  91%|█████████▏| 193/211 [07:21<00:30, 1.72s/it]
195
  92%|█████████▏| 194/211 [07:23<00:29, 1.72s/it]
196
  92%|█████████▏| 195/211 [07:24<00:27, 1.72s/it]
197
  93%|█████████▎| 196/211 [07:26<00:25, 1.70s/it]
198
  93%|█████████▎| 197/211 [07:28<00:24, 1.75s/it]
199
  94%|█████████▍| 198/211 [07:30<00:23, 1.78s/it]
200
  94%|█████████▍| 199/211 [07:31<00:21, 1.76s/it]
201
  95%|█████████▍| 200/211 [07:33<00:19, 1.78s/it]
202
  95%|█████████▌| 201/211 [07:35<00:17, 1.76s/it]
203
  96%|█████████▌| 202/211 [07:37<00:16, 1.80s/it]
204
  96%|█████████▌| 203/211 [07:39<00:14, 1.79s/it]
205
  97%|█████████▋| 204/211 [07:40<00:12, 1.76s/it]
206
  97%|█████████▋| 205/211 [07:42<00:10, 1.74s/it]
207
  98%|█████████▊| 206/211 [07:44<00:08, 1.75s/it]
208
  98%|█████████▊| 207/211 [07:45<00:06, 1.72s/it]
209
  99%|█████████▊| 208/211 [07:47<00:05, 1.70s/it]
210
  99%|█████████▉| 209/211 [07:49<00:03, 1.68s/it]
 
 
 
211
  0%| | 0/469 [00:00<?, ?it/s]
212
  0%| | 1/469 [00:02<18:09, 2.33s/it]
213
  0%| | 2/469 [00:04<15:10, 1.95s/it]
214
  1%| | 3/469 [00:05<14:14, 1.83s/it]
215
  1%| | 4/469 [00:07<13:52, 1.79s/it]
216
  1%| | 5/469 [00:09<13:23, 1.73s/it]
217
  1%|▏ | 6/469 [00:12<17:01, 2.21s/it]
218
  1%|▏ | 7/469 [00:14<16:26, 2.14s/it]
219
  2%|▏ | 8/469 [00:15<15:33, 2.02s/it]
220
  2%|▏ | 9/469 [00:17<15:18, 2.00s/it]
221
  2%|▏ | 10/469 [00:19<14:43, 1.92s/it]
222
  2%|▏ | 11/469 [00:21<14:04, 1.84s/it]
223
  3%|▎ | 12/469 [00:23<14:21, 1.89s/it]
224
  3%|▎ | 13/469 [00:25<14:16, 1.88s/it]
225
  3%|▎ | 14/469 [00:27<14:15, 1.88s/it]
226
  3%|▎ | 15/469 [00:28<14:05, 1.86s/it]
227
  3%|▎ | 16/469 [00:30<14:24, 1.91s/it]
228
  4%|▎ | 17/469 [00:32<14:05, 1.87s/it]
229
  4%|▍ | 18/469 [00:34<14:13, 1.89s/it]
230
  4%|▍ | 19/469 [00:36<13:44, 1.83s/it]
231
  4%|▍ | 20/469 [00:37<13:14, 1.77s/it]
232
  4%|▍ | 21/469 [00:39<13:07, 1.76s/it]
233
  5%|▍ | 22/469 [00:41<12:51, 1.72s/it]
234
  5%|▍ | 23/469 [00:43<12:56, 1.74s/it]
235
  5%|▌ | 24/469 [00:44<12:53, 1.74s/it]
236
  5%|▌ | 25/469 [00:46<12:43, 1.72s/it]
237
  6%|▌ | 26/469 [00:48<12:47, 1.73s/it]
238
  6%|▌ | 27/469 [00:49<12:34, 1.71s/it]
239
  6%|▌ | 28/469 [00:51<12:53, 1.75s/it]
240
  6%|▌ | 29/469 [00:53<12:38, 1.72s/it]
241
  6%|▋ | 30/469 [00:55<12:33, 1.72s/it]
242
  7%|▋ | 31/469 [00:56<12:17, 1.68s/it]
243
  7%|▋ | 32/469 [00:58<12:27, 1.71s/it]
244
  7%|▋ | 33/469 [01:00<12:22, 1.70s/it]
245
  7%|▋ | 34/469 [01:03<16:09, 2.23s/it]
246
  7%|▋ | 35/469 [01:05<15:00, 2.08s/it]
247
  8%|▊ | 36/469 [01:07<14:10, 1.96s/it]
248
  8%|▊ | 37/469 [01:08<13:48, 1.92s/it]
249
  8%|▊ | 38/469 [01:10<13:13, 1.84s/it]
250
  8%|▊ | 39/469 [01:12<13:28, 1.88s/it]
251
  9%|▊ | 40/469 [01:14<13:04, 1.83s/it]
252
  9%|▊ | 41/469 [01:16<13:04, 1.83s/it]
253
  9%|▉ | 42/469 [01:19<15:53, 2.23s/it]
254
  9%|▉ | 43/469 [01:21<15:36, 2.20s/it]
255
  9%|▉ | 44/469 [01:23<14:48, 2.09s/it]
256
  10%|▉ | 45/469 [01:24<13:55, 1.97s/it]
257
  10%|▉ | 46/469 [01:26<13:20, 1.89s/it]
258
  10%|█ | 47/469 [01:28<13:24, 1.91s/it]
259
  10%|█ | 48/469 [01:30<13:15, 1.89s/it]
260
  10%|█ | 49/469 [01:31<12:39, 1.81s/it]
261
  11%|█ | 50/469 [01:33<12:26, 1.78s/it]
262
  11%|█ | 51/469 [01:35<12:47, 1.84s/it]
263
  11%|█ | 52/469 [01:37<12:28, 1.79s/it]
264
  11%|█▏ | 53/469 [01:39<12:09, 1.75s/it]
265
  12%|█▏ | 54/469 [01:40<12:13, 1.77s/it]
266
  12%|█▏ | 55/469 [01:42<12:03, 1.75s/it]
267
  12%|█▏ | 56/469 [01:44<12:13, 1.78s/it]
268
  12%|█▏ | 57/469 [01:46<12:27, 1.81s/it]
269
  12%|█▏ | 58/469 [01:48<12:53, 1.88s/it]
270
  13%|█▎ | 59/469 [01:50<12:46, 1.87s/it]
271
  13%|█▎ | 60/469 [01:51<12:25, 1.82s/it]
272
  13%|█▎ | 61/469 [01:53<12:01, 1.77s/it]
273
  13%|█▎ | 62/469 [01:55<11:49, 1.74s/it]
274
  13%|█▎ | 63/469 [01:56<11:43, 1.73s/it]
275
  14%|█▎ | 64/469 [01:58<11:37, 1.72s/it]
276
  14%|█▍ | 65/469 [02:00<11:38, 1.73s/it]
277
  14%|█▍ | 66/469 [02:02<11:28, 1.71s/it]
278
  14%|█▍ | 67/469 [02:03<11:30, 1.72s/it]
279
  14%|█▍ | 68/469 [02:05<11:30, 1.72s/it]
280
  15%|█▍ | 69/469 [02:07<12:03, 1.81s/it]
281
  15%|█▍ | 70/469 [02:09<11:46, 1.77s/it]
282
  15%|█▌ | 71/469 [02:10<11:36, 1.75s/it]
283
  15%|█▌ | 72/469 [02:12<11:34, 1.75s/it]
284
  16%|█▌ | 73/469 [02:14<12:02, 1.82s/it]
285
  16%|█▌ | 74/469 [02:16<11:41, 1.78s/it]
286
  16%|█▌ | 75/469 [02:18<12:10, 1.85s/it]
287
  16%|█▌ | 76/469 [02:20<12:04, 1.84s/it]
288
  16%|█▋ | 77/469 [02:21<11:44, 1.80s/it]
289
  17%|█▋ | 78/469 [02:23<11:44, 1.80s/it]
290
  17%|█▋ | 79/469 [02:25<11:44, 1.81s/it]
291
  17%|█▋ | 80/469 [02:27<11:43, 1.81s/it]
292
  17%|█▋ | 81/469 [02:28<11:22, 1.76s/it]
293
  17%|█▋ | 82/469 [02:30<11:10, 1.73s/it]
294
  18%|█▊ | 83/469 [02:32<11:12, 1.74s/it]
295
  18%|█▊ | 84/469 [02:34<11:11, 1.75s/it]
296
  18%|█▊ | 85/469 [02:36<11:32, 1.80s/it]
297
  18%|█▊ | 86/469 [02:37<11:32, 1.81s/it]
298
  19%|█▊ | 87/469 [02:39<11:16, 1.77s/it]
299
  19%|█▉ | 88/469 [02:41<11:07, 1.75s/it]
300
  19%|█▉ | 89/469 [02:42<11:00, 1.74s/it]
301
  19%|█▉ | 90/469 [02:44<11:09, 1.77s/it]
302
  19%|█▉ | 91/469 [02:46<11:15, 1.79s/it]
303
  20%|█▉ | 92/469 [02:48<11:09, 1.78s/it]
304
  20%|█▉ | 93/469 [02:50<10:57, 1.75s/it]
305
  20%|██ | 94/469 [02:51<10:55, 1.75s/it]
306
  20%|██ | 95/469 [02:53<10:50, 1.74s/it]
307
  20%|██ | 96/469 [02:55<10:42, 1.72s/it]
308
  21%|██ | 97/469 [02:56<10:37, 1.71s/it]
309
  21%|██ | 98/469 [02:58<10:51, 1.76s/it]
310
  21%|██ | 99/469 [03:00<10:52, 1.76s/it]
311
  21%|██▏ | 100/469 [03:02<11:08, 1.81s/it]
312
  22%|██▏ | 101/469 [03:04<10:52, 1.77s/it]
313
  22%|██▏ | 102/469 [03:05<10:48, 1.77s/it]
314
  22%|██▏ | 103/469 [03:07<10:35, 1.74s/it]
315
  22%|██▏ | 104/469 [03:09<10:28, 1.72s/it]
316
  22%|██▏ | 105/469 [03:11<10:38, 1.75s/it]
317
  23%|██▎ | 106/469 [03:12<10:29, 1.73s/it]
318
  23%|██▎ | 107/469 [03:14<10:22, 1.72s/it]
319
  23%|██▎ | 108/469 [03:16<10:41, 1.78s/it]
320
  23%|██▎ | 109/469 [03:18<10:29, 1.75s/it]
321
  23%|██▎ | 110/469 [03:19<10:32, 1.76s/it]
322
  24%|██▎ | 111/469 [03:21<10:25, 1.75s/it]
323
  24%|██▍ | 112/469 [03:23<10:24, 1.75s/it]
324
  24%|██▍ | 113/469 [03:25<10:37, 1.79s/it]
325
  24%|██▍ | 114/469 [03:26<10:24, 1.76s/it]
326
  25%|██▍ | 115/469 [03:28<10:27, 1.77s/it]
327
  25%|██▍ | 116/469 [03:30<10:22, 1.76s/it]
328
  25%|██▍ | 117/469 [03:32<10:28, 1.79s/it]
329
  25%|██▌ | 118/469 [03:34<10:52, 1.86s/it]
330
  25%|██▌ | 119/469 [03:36<11:00, 1.89s/it]
331
  26%|██▌ | 120/469 [03:37<10:42, 1.84s/it]
332
  26%|██▌ | 121/469 [03:39<10:37, 1.83s/it]
333
  26%|██▌ | 122/469 [03:41<10:17, 1.78s/it]
334
  26%|██▌ | 123/469 [03:43<10:21, 1.80s/it]
335
  26%|██▋ | 124/469 [03:45<10:14, 1.78s/it]
336
  27%|██▋ | 125/469 [03:46<10:15, 1.79s/it]
337
  27%|██▋ | 126/469 [03:48<10:12, 1.79s/it]
338
  27%|██▋ | 127/469 [03:50<09:59, 1.75s/it]
339
  27%|██▋ | 128/469 [03:51<09:49, 1.73s/it]
340
  28%|██▊ | 129/469 [03:53<09:53, 1.75s/it]
341
  28%|██▊ | 130/469 [03:55<09:42, 1.72s/it]
342
  28%|██▊ | 131/469 [03:57<09:37, 1.71s/it]
343
  28%|██▊ | 132/469 [03:58<09:49, 1.75s/it]
344
  28%|██▊ | 133/469 [04:00<09:59, 1.78s/it]
345
  29%|██▊ | 134/469 [04:02<10:22, 1.86s/it]
346
  29%|██▉ | 135/469 [04:04<10:20, 1.86s/it]
347
  29%|██▉ | 136/469 [04:06<10:26, 1.88s/it]
348
  29%|██▉ | 137/469 [04:08<10:02, 1.81s/it]
349
  29%|██▉ | 138/469 [04:10<10:16, 1.86s/it]
350
  30%|██▉ | 139/469 [04:11<09:58, 1.81s/it]
351
  30%|██▉ | 140/469 [04:13<09:59, 1.82s/it]
352
  30%|███ | 141/469 [04:15<09:40, 1.77s/it]
353
  30%|███ | 142/469 [04:17<09:28, 1.74s/it]
354
  30%|███ | 143/469 [04:18<09:16, 1.71s/it]
355
  31%|███ | 144/469 [04:20<09:13, 1.70s/it]
356
  31%|███ | 145/469 [04:22<09:28, 1.75s/it]
357
  31%|███ | 146/469 [04:24<09:18, 1.73s/it]
358
  31%|███▏ | 147/469 [04:25<09:35, 1.79s/it]
359
  32%|███▏ | 148/469 [04:27<09:28, 1.77s/it]
360
  32%|███▏ | 149/469 [04:29<09:12, 1.73s/it]
361
  32%|███▏ | 150/469 [04:31<09:11, 1.73s/it]
362
  32%|███▏ | 151/469 [04:32<09:18, 1.76s/it]
363
  32%|███▏ | 152/469 [04:34<09:23, 1.78s/it]
364
  33%|███▎ | 153/469 [04:36<09:21, 1.78s/it]
365
  33%|███▎ | 154/469 [04:38<09:11, 1.75s/it]
366
  33%|███▎ | 155/469 [04:39<09:09, 1.75s/it]
367
  33%|███▎ | 156/469 [04:42<10:07, 1.94s/it]
368
  33%|███▎ | 157/469 [04:44<10:39, 2.05s/it]
369
  34%|███▎ | 158/469 [04:46<10:12, 1.97s/it]
370
  34%|███▍ | 159/469 [04:48<09:42, 1.88s/it]
371
  34%|███▍ | 160/469 [04:49<09:42, 1.88s/it]
372
  34%|███▍ | 161/469 [04:51<09:24, 1.83s/it]
373
  35%|███▍ | 162/469 [04:53<09:08, 1.79s/it]
374
  35%|███▍ | 163/469 [04:55<09:01, 1.77s/it]
375
  35%|███▍ | 164/469 [04:56<09:08, 1.80s/it]
376
  35%|███▌ | 165/469 [04:58<09:07, 1.80s/it]
377
  35%|███▌ | 166/469 [05:00<09:06, 1.80s/it]
378
  36%|███▌ | 167/469 [05:02<09:00, 1.79s/it]
379
  36%|███▌ | 168/469 [05:03<08:47, 1.75s/it]
380
  36%|███▌ | 169/469 [05:05<08:51, 1.77s/it]
381
  36%|███▌ | 170/469 [05:07<08:35, 1.73s/it]
382
  36%|███▋ | 171/469 [05:09<08:25, 1.70s/it]
383
  37%|███▋ | 172/469 [05:10<08:24, 1.70s/it]
384
  37%|███▋ | 173/469 [05:12<08:20, 1.69s/it]
385
  37%|███▋ | 174/469 [05:14<08:20, 1.70s/it]
386
  37%|███▋ | 175/469 [05:16<08:42, 1.78s/it]
387
  38%|███▊ | 176/469 [05:17<08:34, 1.76s/it]
388
  38%|███▊ | 177/469 [05:19<08:31, 1.75s/it]
389
  38%|███▊ | 178/469 [05:21<08:34, 1.77s/it]
390
  38%|███▊ | 179/469 [05:22<08:22, 1.73s/it]
391
  38%|███▊ | 180/469 [05:24<08:22, 1.74s/it]
392
  39%|███▊ | 181/469 [05:26<08:14, 1.72s/it]
393
  39%|███▉ | 182/469 [05:28<08:19, 1.74s/it]
394
  39%|███▉ | 183/469 [05:29<08:17, 1.74s/it]
395
  39%|███▉ | 184/469 [05:31<08:29, 1.79s/it]
396
  39%|███▉ | 185/469 [05:33<08:36, 1.82s/it]
397
  40%|███▉ | 186/469 [05:35<08:36, 1.82s/it]
398
  40%|███▉ | 187/469 [05:37<08:29, 1.81s/it]
399
  40%|████ | 188/469 [05:39<08:45, 1.87s/it]
400
  40%|████ | 189/469 [05:41<08:29, 1.82s/it]
401
  41%|████ | 190/469 [05:42<08:28, 1.82s/it]
402
  41%|████ | 191/469 [05:44<08:18, 1.79s/it]
403
  41%|████ | 192/469 [05:46<08:10, 1.77s/it]
404
  41%|████ | 193/469 [05:47<07:58, 1.73s/it]
405
  41%|████▏ | 194/469 [05:49<07:52, 1.72s/it]
406
  42%|████▏ | 195/469 [05:51<08:08, 1.78s/it]
407
  42%|████▏ | 196/469 [05:53<08:19, 1.83s/it]
408
  42%|████▏ | 197/469 [05:55<08:35, 1.90s/it]
409
  42%|████▏ | 198/469 [05:57<08:28, 1.88s/it]
410
  42%|████▏ | 199/469 [05:59<08:17, 1.84s/it]
411
  43%|████▎ | 200/469 [06:01<08:22, 1.87s/it]
412
  43%|████▎ | 201/469 [06:02<08:09, 1.83s/it]
413
  43%|████▎ | 202/469 [06:04<08:08, 1.83s/it]
414
  43%|████▎ | 203/469 [06:06<08:21, 1.88s/it]
415
  43%|████▎ | 204/469 [06:08<07:59, 1.81s/it]
416
  44%|████▎ | 205/469 [06:09<07:46, 1.77s/it]
417
  44%|████▍ | 206/469 [06:11<07:38, 1.74s/it]
418
  44%|████▍ | 207/469 [06:13<07:27, 1.71s/it]
419
  44%|████▍ | 208/469 [06:15<07:40, 1.76s/it]
420
  45%|████▍ | 209/469 [06:16<07:35, 1.75s/it]
421
  45%|████▍ | 210/469 [06:18<07:30, 1.74s/it]
422
  45%|████▍ | 211/469 [06:20<07:35, 1.77s/it]
423
  45%|████▌ | 212/469 [06:22<07:32, 1.76s/it]
424
  45%|████▌ | 213/469 [06:23<07:29, 1.76s/it]
425
  46%|████▌ | 214/469 [06:25<07:38, 1.80s/it]
426
  46%|████▌ | 215/469 [06:27<07:34, 1.79s/it]
427
  46%|████▌ | 216/469 [06:29<07:28, 1.77s/it]
428
  46%|████▋ | 217/469 [06:30<07:20, 1.75s/it]
429
  46%|████▋ | 218/469 [06:32<07:08, 1.71s/it]
430
  47%|████▋ | 219/469 [06:34<07:04, 1.70s/it]
431
  47%|████▋ | 220/469 [06:35<07:02, 1.70s/it]
432
  47%|████▋ | 221/469 [06:38<07:54, 1.91s/it]
433
  47%|████▋ | 222/469 [06:40<07:49, 1.90s/it]
434
  48%|████▊ | 223/469 [06:42<07:37, 1.86s/it]
435
  48%|████▊ | 224/469 [06:43<07:18, 1.79s/it]
436
  48%|████▊ | 225/469 [06:45<07:42, 1.90s/it]
437
  48%|████▊ | 226/469 [06:48<09:05, 2.25s/it]
438
  48%|████▊ | 227/469 [06:50<08:33, 2.12s/it]
439
  49%|████▊ | 228/469 [06:52<07:55, 1.97s/it]
440
  49%|████▉ | 229/469 [06:54<07:40, 1.92s/it]
441
  49%|████▉ | 230/469 [06:55<07:33, 1.90s/it]
442
  49%|████▉ | 231/469 [06:57<07:17, 1.84s/it]
443
  49%|████▉ | 232/469 [06:59<07:05, 1.79s/it]
444
  50%|████▉ | 233/469 [07:01<07:03, 1.80s/it]
445
  50%|████▉ | 234/469 [07:02<06:59, 1.78s/it]
446
  50%|█████ | 235/469 [07:05<08:13, 2.11s/it]
447
  50%|█████ | 236/469 [07:07<07:39, 1.97s/it]
448
  51%|█████ | 237/469 [07:09<07:15, 1.88s/it]
449
  51%|█████ | 238/469 [07:10<07:11, 1.87s/it]
450
  51%|█████ | 239/469 [07:12<06:55, 1.81s/it]
451
  51%|█████ | 240/469 [07:14<07:03, 1.85s/it]
452
  51%|█████▏ | 241/469 [07:16<06:54, 1.82s/it]
453
  52%|█████▏ | 242/469 [07:17<06:41, 1.77s/it]
454
  52%|█████▏ | 243/469 [07:19<06:33, 1.74s/it]
455
  52%|█████▏ | 244/469 [07:21<06:28, 1.73s/it]
456
  52%|█████▏ | 245/469 [07:24<07:39, 2.05s/it]
457
  52%|█████▏ | 246/469 [07:25<07:19, 1.97s/it]
458
  53%|█████▎ | 247/469 [07:27<07:08, 1.93s/it]
459
  53%|█████▎ | 248/469 [07:29<06:57, 1.89s/it]
460
  53%|█████▎ | 249/469 [07:31<06:40, 1.82s/it]
461
  53%|█████▎ | 250/469 [07:33<06:38, 1.82s/it]
462
  54%|█████▎ | 251/469 [07:34<06:25, 1.77s/it]
463
  54%|█████▎ | 252/469 [07:36<06:15, 1.73s/it]
464
  54%|█████▍ | 253/469 [07:37<06:09, 1.71s/it]
465
  54%|█████▍ | 254/469 [07:39<06:05, 1.70s/it]
466
  54%|█████▍ | 255/469 [07:41<06:02, 1.70s/it]
467
  55%|█████▍ | 256/469 [07:42<05:58, 1.68s/it]
468
  55%|█████▍ | 257/469 [07:44<06:01, 1.71s/it]
469
  55%|█████▌ | 258/469 [07:47<06:54, 1.96s/it]
470
  55%|█████▌ | 259/469 [07:49<06:35, 1.88s/it]
471
  55%|█████▌ | 260/469 [07:52<08:26, 2.42s/it]
472
  56%|█████▌ | 261/469 [07:54<07:35, 2.19s/it]
473
  56%|█████▌ | 262/469 [07:56<07:07, 2.07s/it]
474
  56%|█████▌ | 263/469 [07:57<06:47, 1.98s/it]
475
  56%|█████▋ | 264/469 [07:59<06:25, 1.88s/it]
476
  57%|█████▋ | 265/469 [08:01<06:21, 1.87s/it]
477
  57%|█████▋ | 266/469 [08:03<06:20, 1.87s/it]
478
  57%|█████▋ | 267/469 [08:04<06:08, 1.82s/it]
479
  57%|█████▋ | 268/469 [08:06<06:07, 1.83s/it]
480
  57%|█████▋ | 269/469 [08:08<05:54, 1.77s/it]
481
  58%|█████▊ | 270/469 [08:10<05:59, 1.80s/it]
482
  58%|█████▊ | 271/469 [08:12<05:51, 1.77s/it]
483
  58%|█████▊ | 272/469 [08:13<05:58, 1.82s/it]
484
  58%|█████▊ | 273/469 [08:15<06:06, 1.87s/it]
485
  58%|█████▊ | 274/469 [08:17<05:52, 1.81s/it]
486
  59%|█████▊ | 275/469 [08:19<05:50, 1.81s/it]
487
  59%|█████▉ | 276/469 [08:21<05:42, 1.77s/it]
488
  59%|█████▉ | 277/469 [08:22<05:37, 1.76s/it]
489
  59%|█████▉ | 278/469 [08:24<05:34, 1.75s/it]
490
  59%|█████▉ | 279/469 [08:26<05:38, 1.78s/it]
491
  60%|█████▉ | 280/469 [08:28<05:38, 1.79s/it]
492
  60%|█████▉ | 281/469 [08:30<05:39, 1.81s/it]
493
  60%|██████ | 282/469 [08:31<05:38, 1.81s/it]
494
  60%|██████ | 283/469 [08:33<05:30, 1.78s/it]
495
  61%|██████ | 284/469 [08:35<05:26, 1.77s/it]
496
  61%|██████ | 285/469 [08:36<05:17, 1.73s/it]
497
  61%|██████ | 286/469 [08:38<05:14, 1.72s/it]
498
  61%|██████ | 287/469 [08:40<05:10, 1.71s/it]
499
  61%|██████▏ | 288/469 [08:41<05:05, 1.69s/it]
500
  62%|██████▏ | 289/469 [08:43<05:04, 1.69s/it]
501
  62%|██████▏ | 290/469 [08:45<05:11, 1.74s/it]
502
  62%|██████▏ | 291/469 [08:47<05:11, 1.75s/it]
503
  62%|██████▏ | 292/469 [08:48<05:04, 1.72s/it]
504
  62%|██████▏ | 293/469 [08:50<05:04, 1.73s/it]
505
  63%|██████▎ | 294/469 [08:52<05:19, 1.83s/it]
506
  63%|██████▎ | 295/469 [08:54<05:14, 1.81s/it]
507
  63%|██████▎ | 296/469 [08:56<05:04, 1.76s/it]
508
  63%|██████▎ | 297/469 [08:57<04:58, 1.74s/it]
509
  64%|██████▎ | 298/469 [08:59<05:01, 1.76s/it]
510
  64%|██████▍ | 299/469 [09:01<04:56, 1.74s/it]
511
  64%|██████▍ | 300/469 [09:03<04:52, 1.73s/it]
512
  64%|██████▍ | 301/469 [09:04<04:48, 1.72s/it]
513
  64%|██████▍ | 302/469 [09:06<04:43, 1.70s/it]
514
  65%|██████▍ | 303/469 [09:08<04:50, 1.75s/it]
515
  65%|██████▍ | 304/469 [09:10<04:46, 1.74s/it]
516
  65%|██████▌ | 305/469 [09:11<04:47, 1.75s/it]
517
  65%|██████▌ | 306/469 [09:13<04:42, 1.73s/it]
518
  65%|██████▌ | 307/469 [09:15<04:58, 1.84s/it]
519
  66%|██████▌ | 308/469 [09:17<04:50, 1.81s/it]
520
  66%|██████▌ | 309/469 [09:19<04:46, 1.79s/it]
521
  66%|██████▌ | 310/469 [09:20<04:39, 1.76s/it]
522
  66%|██████▋ | 311/469 [09:22<04:34, 1.74s/it]
523
  67%|██████▋ | 312/469 [09:24<04:36, 1.76s/it]
524
  67%|██████▋ | 313/469 [09:26<04:52, 1.88s/it]
525
  67%|██████▋ | 314/469 [09:28<04:55, 1.91s/it]
526
  67%|██████▋ | 315/469 [09:30<04:43, 1.84s/it]
527
  67%|██████▋ | 316/469 [09:31<04:36, 1.81s/it]
528
  68%|██████▊ | 317/469 [09:33<04:42, 1.86s/it]
529
  68%|██████▊ | 318/469 [09:36<05:38, 2.24s/it]
530
  68%|██████▊ | 319/469 [09:38<05:18, 2.13s/it]
531
  68%|██████▊ | 320/469 [09:40<04:59, 2.01s/it]
532
  68%|██████▊ | 321/469 [09:42<04:49, 1.95s/it]
533
  69%|██████▊ | 322/469 [09:43<04:32, 1.86s/it]
534
  69%|██████▉ | 323/469 [09:47<05:51, 2.41s/it]
535
  69%|██████▉ | 324/469 [09:49<05:17, 2.19s/it]
536
  69%|██████▉ | 325/469 [09:51<04:58, 2.08s/it]
537
  70%|██████▉ | 326/469 [09:52<04:42, 1.97s/it]
538
  70%|██████▉ | 327/469 [09:54<04:29, 1.89s/it]
539
  70%|██████▉ | 328/469 [09:56<04:17, 1.82s/it]
540
  70%|███████ | 329/469 [09:58<04:12, 1.80s/it]
541
  70%|███████ | 330/469 [09:59<04:07, 1.78s/it]
542
  71%|███████ | 331/469 [10:01<04:12, 1.83s/it]
543
  71%|███████ | 332/469 [10:03<04:06, 1.80s/it]
544
  71%|███████ | 333/469 [10:05<03:57, 1.75s/it]
545
  71%|███████ | 334/469 [10:06<03:51, 1.71s/it]
546
  71%|███████▏ | 335/469 [10:08<03:47, 1.70s/it]
547
  72%|███████▏ | 336/469 [10:09<03:44, 1.69s/it]
548
  72%|███████▏ | 337/469 [10:11<03:42, 1.68s/it]
549
  72%|███████▏ | 338/469 [10:13<03:45, 1.72s/it]
550
  72%|███████▏ | 339/469 [10:15<03:41, 1.70s/it]
551
  72%|███████▏ | 340/469 [10:16<03:38, 1.70s/it]
552
  73%|███████▎ | 341/469 [10:18<03:37, 1.70s/it]
553
  73%|███████▎ | 342/469 [10:20<03:47, 1.79s/it]
554
  73%|███████▎ | 343/469 [10:22<03:42, 1.77s/it]
555
  73%|███████▎ | 344/469 [10:24<03:40, 1.77s/it]
556
  74%|███████▎ | 345/469 [10:25<03:39, 1.77s/it]
557
  74%|███████▍ | 346/469 [10:27<03:36, 1.76s/it]
558
  74%|███████▍ | 347/469 [10:29<03:31, 1.73s/it]
559
  74%|███████▍ | 348/469 [10:31<03:35, 1.78s/it]
560
  74%|███████▍ | 349/469 [10:32<03:34, 1.78s/it]
561
  75%|███████▍ | 350/469 [10:34<03:30, 1.76s/it]
562
  75%|███████▍ | 351/469 [10:36<03:23, 1.72s/it]
563
  75%|███████▌ | 352/469 [10:38<03:33, 1.82s/it]
564
  75%|███████▌ | 353/469 [10:40<03:29, 1.81s/it]
565
  75%|███████▌ | 354/469 [10:41<03:27, 1.80s/it]
566
  76%|███████▌ | 355/469 [10:43<03:21, 1.77s/it]
567
  76%|███████▌ | 356/469 [10:45<03:27, 1.83s/it]
568
  76%|███████▌ | 357/469 [10:47<03:18, 1.77s/it]
569
  76%|███████▋ | 358/469 [10:48<03:14, 1.75s/it]
570
  77%|██��████▋ | 359/469 [10:50<03:09, 1.73s/it]
571
  77%|███████▋ | 360/469 [10:52<03:07, 1.72s/it]
572
  77%|███████▋ | 361/469 [10:53<03:06, 1.72s/it]
573
  77%|███████▋ | 362/469 [10:55<03:07, 1.75s/it]
574
  77%|███████▋ | 363/469 [10:57<03:10, 1.79s/it]
575
  78%|███████▊ | 364/469 [10:59<03:06, 1.77s/it]
576
  78%|███████▊ | 365/469 [11:01<03:00, 1.73s/it]
577
  78%|███████▊ | 366/469 [11:02<02:55, 1.71s/it]
578
  78%|███████▊ | 367/469 [11:04<02:53, 1.70s/it]
579
  78%|███████▊ | 368/469 [11:06<02:53, 1.72s/it]
580
  79%|███████▊ | 369/469 [11:07<02:51, 1.71s/it]
581
  79%|███████▉ | 370/469 [11:09<02:48, 1.70s/it]
582
  79%|███████▉ | 371/469 [11:11<02:48, 1.71s/it]
583
  79%|███████▉ | 372/469 [11:13<02:51, 1.77s/it]
584
  80%|███████▉ | 373/469 [11:14<02:46, 1.73s/it]
585
  80%|███████▉ | 374/469 [11:16<02:45, 1.75s/it]
586
  80%|███████▉ | 375/469 [11:18<02:43, 1.74s/it]
587
  80%|████████ | 376/469 [11:20<02:42, 1.74s/it]
588
  80%|████████ | 377/469 [11:21<02:39, 1.74s/it]
589
  81%|████████ | 378/469 [11:23<02:36, 1.72s/it]
590
  81%|████████ | 379/469 [11:25<02:35, 1.72s/it]
591
  81%|████████ | 380/469 [11:26<02:32, 1.72s/it]
592
  81%|████████ | 381/469 [11:28<02:34, 1.76s/it]
593
  81%|████████▏ | 382/469 [11:30<02:43, 1.88s/it]
594
  82%|████████▏ | 383/469 [11:32<02:35, 1.81s/it]
595
  82%|████████▏ | 384/469 [11:34<02:29, 1.76s/it]
596
  82%|████████▏ | 385/469 [11:36<02:30, 1.80s/it]
597
  82%|████████▏ | 386/469 [11:37<02:26, 1.77s/it]
598
  83%|████████▎ | 387/469 [11:39<02:26, 1.78s/it]
599
  83%|████████▎ | 388/469 [11:41<02:23, 1.77s/it]
600
  83%|████████▎ | 389/469 [11:42<02:18, 1.73s/it]
601
  83%|████████▎ | 390/469 [11:44<02:18, 1.75s/it]
602
  83%|████████▎ | 391/469 [11:46<02:16, 1.76s/it]
603
  84%|████████▎ | 392/469 [11:49<02:33, 1.99s/it]
604
  84%|████████▍ | 393/469 [11:50<02:24, 1.90s/it]
605
  84%|████████▍ | 394/469 [11:52<02:18, 1.85s/it]
606
  84%|████████▍ | 395/469 [11:54<02:12, 1.79s/it]
607
  84%|████████▍ | 396/469 [11:58<03:02, 2.50s/it]
608
  85%|████████▍ | 397/469 [12:00<02:44, 2.29s/it]
609
  85%|████████▍ | 398/469 [12:01<02:31, 2.13s/it]
610
  85%|████████▌ | 399/469 [12:03<02:21, 2.02s/it]
611
  85%|████████▌ | 400/469 [12:05<02:15, 1.96s/it]
612
  86%|████████▌ | 401/469 [12:07<02:10, 1.92s/it]
613
  86%|████████▌ | 402/469 [12:08<02:02, 1.83s/it]
614
  86%|████████▌ | 403/469 [12:10<01:57, 1.78s/it]
615
  86%|████████▌ | 404/469 [12:12<01:53, 1.74s/it]
616
  86%|████████▋ | 405/469 [12:13<01:49, 1.72s/it]
617
  87%|████████▋ | 406/469 [12:15<01:50, 1.76s/it]
618
  87%|████████▋ | 407/469 [12:17<01:48, 1.74s/it]
619
  87%|████████▋ | 408/469 [12:19<01:44, 1.72s/it]
620
  87%|████████▋ | 409/469 [12:20<01:41, 1.70s/it]
621
  87%|████████▋ | 410/469 [12:22<01:45, 1.79s/it]
622
  88%|████████▊ | 411/469 [12:24<01:40, 1.74s/it]
623
  88%|████████▊ | 412/469 [12:26<01:41, 1.78s/it]
624
  88%|████████▊ | 413/469 [12:28<01:39, 1.78s/it]
625
  88%|████████▊ | 414/469 [12:29<01:35, 1.74s/it]
626
  88%|████████▊ | 415/469 [12:31<01:38, 1.83s/it]
627
  89%|████████▊ | 416/469 [12:33<01:36, 1.82s/it]
628
  89%|████████▉ | 417/469 [12:35<01:31, 1.76s/it]
629
  89%|████████▉ | 418/469 [12:36<01:31, 1.79s/it]
630
  89%|████████▉ | 419/469 [12:38<01:26, 1.74s/it]
631
  90%|████████▉ | 420/469 [12:40<01:25, 1.75s/it]
632
  90%|████████▉ | 421/469 [12:42<01:23, 1.75s/it]
633
  90%|████████▉ | 422/469 [12:43<01:20, 1.71s/it]
634
  90%|█████████ | 423/469 [12:45<01:20, 1.74s/it]
635
  90%|█████████ | 424/469 [12:47<01:18, 1.74s/it]
636
  91%|█████████ | 425/469 [12:49<01:17, 1.75s/it]
637
  91%|█████████ | 426/469 [12:50<01:15, 1.75s/it]
638
  91%|█████████ | 427/469 [12:52<01:14, 1.78s/it]
639
  91%|█████████▏| 428/469 [12:54<01:12, 1.77s/it]
640
  91%|█████████▏| 429/469 [12:56<01:09, 1.75s/it]
641
  92%|█████████▏| 430/469 [12:57<01:07, 1.72s/it]
642
  92%|█████████▏| 431/469 [12:59<01:04, 1.70s/it]
643
  92%|█████████▏| 432/469 [13:01<01:04, 1.74s/it]
644
  92%|█████████▏| 433/469 [13:03<01:07, 1.88s/it]
645
  93%|█████████▎| 434/469 [13:05<01:04, 1.86s/it]
646
  93%|█████████▎| 435/469 [13:07<01:03, 1.87s/it]
647
  93%|█████████▎| 436/469 [13:08<00:59, 1.81s/it]
648
  93%|█████████▎| 437/469 [13:10<00:57, 1.81s/it]
649
  93%|█████████▎| 438/469 [13:12<00:57, 1.86s/it]
650
  94%|█████████▎| 439/469 [13:14<00:54, 1.82s/it]
651
  94%|█████████▍| 440/469 [13:16<00:52, 1.82s/it]
652
  94%|█████████▍| 441/469 [13:18<00:51, 1.84s/it]
653
  94%|█████████▍| 442/469 [13:19<00:48, 1.80s/it]
654
  94%|█████████▍| 443/469 [13:21<00:48, 1.86s/it]
655
  95%|█████████▍| 444/469 [13:23<00:45, 1.80s/it]
656
  95%|█████████▍| 445/469 [13:25<00:42, 1.79s/it]
657
  95%|█████████▌| 446/469 [13:27<00:42, 1.85s/it]
658
  95%|█████████▌| 447/469 [13:29<00:41, 1.86s/it]
659
  96%|█████████▌| 448/469 [13:30<00:39, 1.88s/it]
660
  96%|█████████▌| 449/469 [13:32<00:37, 1.86s/it]
661
  96%|█████████▌| 450/469 [13:34<00:33, 1.79s/it]
662
  96%|█████████▌| 451/469 [13:36<00:31, 1.74s/it]
663
  96%|█████████▋| 452/469 [13:37<00:29, 1.73s/it]
664
  97%|█████████▋| 453/469 [13:39<00:28, 1.76s/it]
665
  97%|█████████▋| 454/469 [13:41<00:26, 1.74s/it]
666
  97%|█████████▋| 455/469 [13:42<00:23, 1.70s/it]
667
  97%|█████████▋| 456/469 [13:44<00:22, 1.71s/it]
668
  97%|█████████▋| 457/469 [13:46<00:20, 1.71s/it]
669
  98%|█████████▊| 458/469 [13:48<00:18, 1.71s/it]
670
  98%|█████████▊| 459/469 [13:49<00:16, 1.69s/it]
671
  98%|█████████▊| 460/469 [13:51<00:15, 1.68s/it]
672
  98%|█████████▊| 461/469 [13:53<00:14, 1.76s/it]
673
  99%|█████████▊| 462/469 [13:55<00:12, 1.78s/it]
674
  99%|█████████▊| 463/469 [13:57<00:10, 1.83s/it]
675
  99%|█████████▉| 464/469 [13:58<00:09, 1.86s/it]
676
  99%|█████████▉| 465/469 [14:00<00:07, 1.81s/it]
677
  99%|█████████▉| 466/469 [14:02<00:05, 1.78s/it]
 
 
 
 
 
 
 
 
1
+ 2026-03-21 21:11:32.431141: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2
+ 2026-03-21 21:11:44.041066: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
3
+ To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
4
+ 2026-03-21 21:11:44.095212: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
5
+ 2026-03-21 21:11:44.095324: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: d700126ec97dc07d69688b0430c49a6a-taskrole1-0
6
+ 2026-03-21 21:11:44.095377: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: d700126ec97dc07d69688b0430c49a6a-taskrole1-0
7
+ 2026-03-21 21:11:44.095515: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
8
+ 2026-03-21 21:11:44.095581: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 550.127.8
9
+ 2026-03-21 21:11:45.185895: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
10
+ warming up TensorFlow...
11
+
12
  0%| | 0/1 [00:00<?, ?it/s]2026-03-21 21:11:45.919794: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
13
+
14
+ computing reference batch activations...
15
+
16
  0%| | 0/211 [00:00<?, ?it/s]
17
  0%| | 1/211 [00:02<07:04, 2.02s/it]
18
  1%| | 2/211 [00:03<06:56, 2.00s/it]
19
  1%|▏ | 3/211 [00:05<06:50, 1.98s/it]
20
  2%|▏ | 4/211 [00:07<06:42, 1.95s/it]
21
  2%|▏ | 5/211 [00:09<06:25, 1.87s/it]
22
  3%|▎ | 6/211 [00:11<06:20, 1.85s/it]
23
  3%|▎ | 7/211 [00:13<06:46, 1.99s/it]
24
  4%|▍ | 8/211 [00:15<06:56, 2.05s/it]
25
  4%|▍ | 9/211 [00:30<20:28, 6.08s/it]
26
  5%|▍ | 10/211 [00:32<15:51, 4.73s/it]
27
  5%|▌ | 11/211 [00:34<12:42, 3.81s/it]
28
  6%|▌ | 12/211 [00:35<10:26, 3.15s/it]
29
  6%|▌ | 13/211 [00:37<08:55, 2.71s/it]
30
  7%|▋ | 14/211 [00:39<08:09, 2.49s/it]
31
  7%|▋ | 15/211 [00:41<07:32, 2.31s/it]
32
  8%|▊ | 16/211 [00:43<06:52, 2.12s/it]
33
  8%|▊ | 17/211 [00:44<06:35, 2.04s/it]
34
  9%|▊ | 18/211 [00:46<06:18, 1.96s/it]
35
  9%|▉ | 19/211 [00:48<06:16, 1.96s/it]
36
  9%|▉ | 20/211 [00:50<06:20, 1.99s/it]
37
  10%|▉ | 21/211 [00:52<06:01, 1.91s/it]
38
  10%|█ | 22/211 [00:54<05:51, 1.86s/it]
39
  11%|█ | 23/211 [00:56<05:48, 1.85s/it]
40
  11%|█▏ | 24/211 [00:57<05:44, 1.84s/it]
41
  12%|█▏ | 25/211 [00:59<05:44, 1.85s/it]
42
  12%|█▏ | 26/211 [01:01<05:38, 1.83s/it]
43
  13%|█▎ | 27/211 [01:03<05:33, 1.81s/it]
44
  13%|█▎ | 28/211 [01:05<06:02, 1.98s/it]
45
  14%|█▎ | 29/211 [01:07<05:49, 1.92s/it]
46
  14%|█▍ | 30/211 [01:09<05:33, 1.84s/it]
47
  15%|█▍ | 31/211 [01:11<05:38, 1.88s/it]
48
  15%|█▌ | 32/211 [01:12<05:32, 1.86s/it]
49
  16%|█▌ | 33/211 [01:14<05:26, 1.83s/it]
50
  16%|█▌ | 34/211 [01:16<05:20, 1.81s/it]
51
  17%|█▋ | 35/211 [01:18<05:16, 1.80s/it]
52
  17%|█▋ | 36/211 [01:19<05:11, 1.78s/it]
53
  18%|█▊ | 37/211 [01:21<05:06, 1.76s/it]
54
  18%|█▊ | 38/211 [01:23<05:22, 1.86s/it]
55
  18%|█▊ | 39/211 [01:25<05:12, 1.82s/it]
56
  19%|█▉ | 40/211 [01:35<11:51, 4.16s/it]
57
  19%|█▉ | 41/211 [01:36<09:45, 3.44s/it]
58
  20%|█▉ | 42/211 [01:38<08:27, 3.00s/it]
59
  20%|██ | 43/211 [01:40<07:25, 2.65s/it]
60
  21%|██ | 44/211 [01:42<06:46, 2.43s/it]
61
  21%|██▏ | 45/211 [01:44<06:28, 2.34s/it]
62
  22%|██▏ | 46/211 [01:46<05:56, 2.16s/it]
63
  22%|██▏ | 47/211 [01:48<05:55, 2.17s/it]
64
  23%|██▎ | 48/211 [01:50<05:30, 2.03s/it]
65
  23%|██▎ | 49/211 [01:52<05:22, 1.99s/it]
66
  24%|██▎ | 50/211 [01:54<05:11, 1.94s/it]
67
  24%|██▍ | 51/211 [01:56<05:18, 1.99s/it]
68
  25%|██▍ | 52/211 [01:57<05:02, 1.90s/it]
69
  25%|██▌ | 53/211 [01:59<04:52, 1.85s/it]
70
  26%|██▌ | 54/211 [02:01<04:42, 1.80s/it]
71
  26%|██▌ | 55/211 [02:03<04:40, 1.80s/it]
72
  27%|██▋ | 56/211 [02:05<04:47, 1.85s/it]
73
  27%|██▋ | 57/211 [02:06<04:38, 1.81s/it]
74
  27%|██▋ | 58/211 [02:08<04:40, 1.83s/it]
75
  28%|██▊ | 59/211 [02:10<04:34, 1.81s/it]
76
  28%|██▊ | 60/211 [02:12<04:31, 1.80s/it]
77
  29%|██▉ | 61/211 [02:13<04:24, 1.76s/it]
78
  29%|██▉ | 62/211 [02:15<04:19, 1.74s/it]
79
  30%|██▉ | 63/211 [02:17<04:25, 1.79s/it]
80
  30%|███ | 64/211 [02:19<04:22, 1.78s/it]
81
  31%|███ | 65/211 [02:21<04:21, 1.79s/it]
82
  31%|███▏ | 66/211 [02:22<04:16, 1.77s/it]
83
  32%|███▏ | 67/211 [02:24<04:27, 1.85s/it]
84
  32%|███▏ | 68/211 [02:26<04:23, 1.84s/it]
85
  33%|███▎ | 69/211 [02:28<04:21, 1.84s/it]
86
  33%|███▎ | 70/211 [02:30<04:15, 1.81s/it]
87
  34%|███▎ | 71/211 [02:32<04:13, 1.81s/it]
88
  34%|███▍ | 72/211 [02:33<04:09, 1.79s/it]
89
  35%|███▍ | 73/211 [02:35<04:04, 1.77s/it]
90
  35%|███▌ | 74/211 [02:53<15:24, 6.75s/it]
91
  36%|███▌ | 75/211 [02:55<11:58, 5.28s/it]
92
  36%|███▌ | 76/211 [02:57<09:31, 4.24s/it]
93
  36%|███▋ | 77/211 [02:59<07:46, 3.48s/it]
94
  37%|███▋ | 78/211 [03:00<06:33, 2.96s/it]
95
  37%|███▋ | 79/211 [03:02<05:47, 2.64s/it]
96
  38%|███▊ | 80/211 [03:04<05:19, 2.44s/it]
97
  38%|███▊ | 81/211 [03:06<04:47, 2.22s/it]
98
  39%|███▉ | 82/211 [03:08<04:27, 2.07s/it]
99
  39%|███▉ | 83/211 [03:09<04:12, 1.97s/it]
100
  40%|███▉ | 84/211 [03:11<04:02, 1.91s/it]
101
  40%|████ | 85/211 [03:13<03:53, 1.85s/it]
102
  41%|████ | 86/211 [03:15<03:48, 1.83s/it]
103
  41%|████ | 87/211 [03:16<03:42, 1.79s/it]
104
  42%|████▏ | 88/211 [03:18<03:40, 1.79s/it]
105
  42%|████▏ | 89/211 [03:20<03:35, 1.77s/it]
106
  43%|████▎ | 90/211 [03:22<03:33, 1.77s/it]
107
  43%|████▎ | 91/211 [03:23<03:30, 1.76s/it]
108
  44%|████▎ | 92/211 [03:25<03:35, 1.81s/it]
109
  44%|████▍ | 93/211 [03:27<03:28, 1.76s/it]
110
  45%|████▍ | 94/211 [03:29<03:25, 1.75s/it]
111
  45%|████▌ | 95/211 [03:31<03:22, 1.74s/it]
112
  45%|████▌ | 96/211 [03:32<03:24, 1.77s/it]
113
  46%|████▌ | 97/211 [03:34<03:21, 1.77s/it]
114
  46%|████▋ | 98/211 [03:36<03:19, 1.77s/it]
115
  47%|████▋ | 99/211 [03:38<03:18, 1.77s/it]
116
  47%|████▋ | 100/211 [03:39<03:14, 1.75s/it]
117
  48%|████▊ | 101/211 [03:41<03:13, 1.75s/it]
118
  48%|████▊ | 102/211 [03:43<03:13, 1.77s/it]
119
  49%|████▉ | 103/211 [03:45<03:08, 1.75s/it]
120
  49%|████▉ | 104/211 [03:47<03:12, 1.79s/it]
121
  50%|████▉ | 105/211 [03:48<03:07, 1.77s/it]
122
  50%|█████ | 106/211 [04:10<13:22, 7.64s/it]
123
  51%|█████ | 107/211 [04:12<10:18, 5.94s/it]
124
  51%|█████ | 108/211 [04:13<08:02, 4.68s/it]
125
  52%|█████▏ | 109/211 [04:15<06:26, 3.79s/it]
126
  52%|█████▏ | 110/211 [04:17<05:18, 3.16s/it]
127
  53%|█████▎ | 111/211 [04:18<04:33, 2.74s/it]
128
  53%|█████▎ | 112/211 [04:20<04:09, 2.52s/it]
129
  54%|█████▎ | 113/211 [04:22<03:50, 2.35s/it]
130
  54%|█████▍ | 114/211 [04:24<03:30, 2.17s/it]
131
  55%|█████▍ | 115/211 [04:26<03:22, 2.11s/it]
132
  55%|█████▍ | 116/211 [04:46<11:49, 7.47s/it]
133
  55%|█████▌ | 117/211 [04:48<08:58, 5.73s/it]
134
  56%|█████▌ | 118/211 [04:50<07:08, 4.61s/it]
135
  56%|█████▋ | 119/211 [04:51<05:41, 3.71s/it]
136
  57%|█████▋ | 120/211 [04:53<04:42, 3.10s/it]
137
  57%|█████▋ | 121/211 [04:55<04:01, 2.69s/it]
138
  58%|█████▊ | 122/211 [04:57<03:33, 2.40s/it]
139
  58%|█████▊ | 123/211 [04:58<03:13, 2.20s/it]
140
  59%|█████▉ | 124/211 [05:00<03:01, 2.08s/it]
141
  59%|█████▉ | 125/211 [05:02<02:50, 1.98s/it]
142
  60%|█████▉ | 126/211 [05:03<02:38, 1.87s/it]
143
  60%|██████ | 127/211 [05:05<02:31, 1.80s/it]
144
  61%|██████ | 128/211 [05:07<02:25, 1.75s/it]
145
  61%|██████ | 129/211 [05:09<02:29, 1.82s/it]
146
  62%|██████▏ | 130/211 [05:10<02:24, 1.79s/it]
147
  62%|██████▏ | 131/211 [05:12<02:20, 1.76s/it]
148
  63%|██████▎ | 132/211 [05:14<02:17, 1.74s/it]
149
  63%|██████▎ | 133/211 [05:15<02:14, 1.73s/it]
150
  64%|██████▎ | 134/211 [05:18<02:23, 1.87s/it]
151
  64%|██████▍ | 135/211 [05:20<02:22, 1.87s/it]
152
  64%|██████▍ | 136/211 [05:21<02:15, 1.80s/it]
153
  65%|██████▍ | 137/211 [05:23<02:13, 1.81s/it]
154
  65%|██████▌ | 138/211 [05:25<02:10, 1.79s/it]
155
  66%|██████▌ | 139/211 [05:26<02:06, 1.75s/it]
156
  66%|██████▋ | 140/211 [05:28<02:03, 1.73s/it]
157
  67%|██████▋ | 141/211 [05:30<02:04, 1.78s/it]
158
  67%|██████▋ | 142/211 [05:32<02:05, 1.82s/it]
159
  68%|██████▊ | 143/211 [05:34<02:04, 1.83s/it]
160
  68%|██��███▊ | 144/211 [05:35<01:59, 1.78s/it]
161
  69%|██████▊ | 145/211 [05:37<01:57, 1.78s/it]
162
  69%|██████▉ | 146/211 [05:55<07:10, 6.63s/it]
163
  70%|██████▉ | 147/211 [05:57<05:30, 5.17s/it]
164
  70%|███████ | 148/211 [05:59<04:23, 4.17s/it]
165
  71%|███████ | 149/211 [06:00<03:31, 3.42s/it]
166
  71%|███████ | 150/211 [06:02<02:58, 2.92s/it]
167
  72%|███████▏ | 151/211 [06:04<02:33, 2.56s/it]
168
  72%|███████▏ | 152/211 [06:06<02:15, 2.29s/it]
169
  73%|███████▎ | 153/211 [06:07<02:02, 2.12s/it]
170
  73%|███████▎ | 154/211 [06:10<02:15, 2.38s/it]
171
  73%|███████▎ | 155/211 [06:12<02:00, 2.15s/it]
172
  74%|███████▍ | 156/211 [06:14<01:50, 2.00s/it]
173
  74%|███████▍ | 157/211 [06:15<01:43, 1.91s/it]
174
  75%|███████▍ | 158/211 [06:17<01:38, 1.86s/it]
175
  75%|███████▌ | 159/211 [06:19<01:37, 1.88s/it]
176
  76%|███████▌ | 160/211 [06:21<01:32, 1.82s/it]
177
  76%|███████▋ | 161/211 [06:22<01:29, 1.79s/it]
178
  77%|███████▋ | 162/211 [06:24<01:31, 1.87s/it]
179
  77%|███████▋ | 163/211 [06:26<01:27, 1.83s/it]
180
  78%|███████▊ | 164/211 [06:29<01:36, 2.05s/it]
181
  78%|███████▊ | 165/211 [06:31<01:32, 2.01s/it]
182
  79%|███████▊ | 166/211 [06:32<01:27, 1.94s/it]
183
  79%|███████▉ | 167/211 [06:34<01:21, 1.85s/it]
184
  80%|███████▉ | 168/211 [06:36<01:17, 1.79s/it]
185
  80%|████████ | 169/211 [06:38<01:19, 1.89s/it]
186
  81%|████████ | 170/211 [06:40<01:17, 1.88s/it]
187
  81%|████████ | 171/211 [06:41<01:14, 1.85s/it]
188
  82%|████████▏ | 172/211 [06:43<01:11, 1.84s/it]
189
  82%|████████▏ | 173/211 [06:45<01:08, 1.81s/it]
190
  82%|████████▏ | 174/211 [06:47<01:07, 1.84s/it]
191
  83%|████████▎ | 175/211 [06:49<01:04, 1.78s/it]
192
  83%|████████▎ | 176/211 [06:50<01:02, 1.79s/it]
193
  84%|████████▍ | 177/211 [06:52<01:00, 1.77s/it]
194
  84%|████████▍ | 178/211 [06:54<01:03, 1.92s/it]
195
  85%|████████▍ | 179/211 [06:56<00:59, 1.87s/it]
196
  85%|████████▌ | 180/211 [06:58<00:56, 1.81s/it]
197
  86%|████████▌ | 181/211 [06:59<00:53, 1.77s/it]
198
  86%|████████▋ | 182/211 [07:01<00:53, 1.84s/it]
199
  87%|████████▋ | 183/211 [07:03<00:50, 1.81s/it]
200
  87%|████████▋ | 184/211 [07:05<00:49, 1.83s/it]
201
  88%|████████▊ | 185/211 [07:07<00:48, 1.86s/it]
202
  88%|████████▊ | 186/211 [07:09<00:45, 1.83s/it]
203
  89%|████████▊ | 187/211 [07:11<00:43, 1.82s/it]
204
  89%|████████▉ | 188/211 [07:12<00:42, 1.83s/it]
205
  90%|████████▉ | 189/211 [07:14<00:39, 1.77s/it]
206
  90%|█████████ | 190/211 [07:16<00:37, 1.77s/it]
207
  91%|█████████ | 191/211 [07:17<00:34, 1.73s/it]
208
  91%|█████████ | 192/211 [07:19<00:32, 1.71s/it]
209
  91%|█████████▏| 193/211 [07:21<00:30, 1.72s/it]
210
  92%|█████████▏| 194/211 [07:23<00:29, 1.72s/it]
211
  92%|█████████▏| 195/211 [07:24<00:27, 1.72s/it]
212
  93%|█████████▎| 196/211 [07:26<00:25, 1.70s/it]
213
  93%|█████████▎| 197/211 [07:28<00:24, 1.75s/it]
214
  94%|█████████▍| 198/211 [07:30<00:23, 1.78s/it]
215
  94%|█████████▍| 199/211 [07:31<00:21, 1.76s/it]
216
  95%|█████████▍| 200/211 [07:33<00:19, 1.78s/it]
217
  95%|█████████▌| 201/211 [07:35<00:17, 1.76s/it]
218
  96%|█████████▌| 202/211 [07:37<00:16, 1.80s/it]
219
  96%|█████████▌| 203/211 [07:39<00:14, 1.79s/it]
220
  97%|█████████▋| 204/211 [07:40<00:12, 1.76s/it]
221
  97%|█████████▋| 205/211 [07:42<00:10, 1.74s/it]
222
  98%|█████████▊| 206/211 [07:44<00:08, 1.75s/it]
223
  98%|█████████▊| 207/211 [07:45<00:06, 1.72s/it]
224
  99%|█████████▊| 208/211 [07:47<00:05, 1.70s/it]
225
  99%|█████████▉| 209/211 [07:49<00:03, 1.68s/it]
226
+ computing/reading reference batch statistics...
227
+ computing sample batch activations...
228
+
229
  0%| | 0/469 [00:00<?, ?it/s]
230
  0%| | 1/469 [00:02<18:09, 2.33s/it]
231
  0%| | 2/469 [00:04<15:10, 1.95s/it]
232
  1%| | 3/469 [00:05<14:14, 1.83s/it]
233
  1%| | 4/469 [00:07<13:52, 1.79s/it]
234
  1%| | 5/469 [00:09<13:23, 1.73s/it]
235
  1%|▏ | 6/469 [00:12<17:01, 2.21s/it]
236
  1%|▏ | 7/469 [00:14<16:26, 2.14s/it]
237
  2%|▏ | 8/469 [00:15<15:33, 2.02s/it]
238
  2%|▏ | 9/469 [00:17<15:18, 2.00s/it]
239
  2%|▏ | 10/469 [00:19<14:43, 1.92s/it]
240
  2%|▏ | 11/469 [00:21<14:04, 1.84s/it]
241
  3%|▎ | 12/469 [00:23<14:21, 1.89s/it]
242
  3%|▎ | 13/469 [00:25<14:16, 1.88s/it]
243
  3%|▎ | 14/469 [00:27<14:15, 1.88s/it]
244
  3%|▎ | 15/469 [00:28<14:05, 1.86s/it]
245
  3%|▎ | 16/469 [00:30<14:24, 1.91s/it]
246
  4%|▎ | 17/469 [00:32<14:05, 1.87s/it]
247
  4%|▍ | 18/469 [00:34<14:13, 1.89s/it]
248
  4%|▍ | 19/469 [00:36<13:44, 1.83s/it]
249
  4%|▍ | 20/469 [00:37<13:14, 1.77s/it]
250
  4%|▍ | 21/469 [00:39<13:07, 1.76s/it]
251
  5%|▍ | 22/469 [00:41<12:51, 1.72s/it]
252
  5%|▍ | 23/469 [00:43<12:56, 1.74s/it]
253
  5%|▌ | 24/469 [00:44<12:53, 1.74s/it]
254
  5%|▌ | 25/469 [00:46<12:43, 1.72s/it]
255
  6%|▌ | 26/469 [00:48<12:47, 1.73s/it]
256
  6%|▌ | 27/469 [00:49<12:34, 1.71s/it]
257
  6%|▌ | 28/469 [00:51<12:53, 1.75s/it]
258
  6%|▌ | 29/469 [00:53<12:38, 1.72s/it]
259
  6%|▋ | 30/469 [00:55<12:33, 1.72s/it]
260
  7%|▋ | 31/469 [00:56<12:17, 1.68s/it]
261
  7%|▋ | 32/469 [00:58<12:27, 1.71s/it]
262
  7%|▋ | 33/469 [01:00<12:22, 1.70s/it]
263
  7%|▋ | 34/469 [01:03<16:09, 2.23s/it]
264
  7%|▋ | 35/469 [01:05<15:00, 2.08s/it]
265
  8%|▊ | 36/469 [01:07<14:10, 1.96s/it]
266
  8%|▊ | 37/469 [01:08<13:48, 1.92s/it]
267
  8%|▊ | 38/469 [01:10<13:13, 1.84s/it]
268
  8%|▊ | 39/469 [01:12<13:28, 1.88s/it]
269
  9%|▊ | 40/469 [01:14<13:04, 1.83s/it]
270
  9%|▊ | 41/469 [01:16<13:04, 1.83s/it]
271
  9%|▉ | 42/469 [01:19<15:53, 2.23s/it]
272
  9%|▉ | 43/469 [01:21<15:36, 2.20s/it]
273
  9%|▉ | 44/469 [01:23<14:48, 2.09s/it]
274
  10%|▉ | 45/469 [01:24<13:55, 1.97s/it]
275
  10%|▉ | 46/469 [01:26<13:20, 1.89s/it]
276
  10%|█ | 47/469 [01:28<13:24, 1.91s/it]
277
  10%|█ | 48/469 [01:30<13:15, 1.89s/it]
278
  10%|█ | 49/469 [01:31<12:39, 1.81s/it]
279
  11%|█ | 50/469 [01:33<12:26, 1.78s/it]
280
  11%|█ | 51/469 [01:35<12:47, 1.84s/it]
281
  11%|█ | 52/469 [01:37<12:28, 1.79s/it]
282
  11%|█▏ | 53/469 [01:39<12:09, 1.75s/it]
283
  12%|█▏ | 54/469 [01:40<12:13, 1.77s/it]
284
  12%|█▏ | 55/469 [01:42<12:03, 1.75s/it]
285
  12%|█▏ | 56/469 [01:44<12:13, 1.78s/it]
286
  12%|█▏ | 57/469 [01:46<12:27, 1.81s/it]
287
  12%|█▏ | 58/469 [01:48<12:53, 1.88s/it]
288
  13%|█▎ | 59/469 [01:50<12:46, 1.87s/it]
289
  13%|█▎ | 60/469 [01:51<12:25, 1.82s/it]
290
  13%|█▎ | 61/469 [01:53<12:01, 1.77s/it]
291
  13%|█▎ | 62/469 [01:55<11:49, 1.74s/it]
292
  13%|█▎ | 63/469 [01:56<11:43, 1.73s/it]
293
  14%|█▎ | 64/469 [01:58<11:37, 1.72s/it]
294
  14%|█▍ | 65/469 [02:00<11:38, 1.73s/it]
295
  14%|█▍ | 66/469 [02:02<11:28, 1.71s/it]
296
  14%|█▍ | 67/469 [02:03<11:30, 1.72s/it]
297
  14%|█▍ | 68/469 [02:05<11:30, 1.72s/it]
298
  15%|█▍ | 69/469 [02:07<12:03, 1.81s/it]
299
  15%|█▍ | 70/469 [02:09<11:46, 1.77s/it]
300
  15%|█▌ | 71/469 [02:10<11:36, 1.75s/it]
301
  15%|█▌ | 72/469 [02:12<11:34, 1.75s/it]
302
  16%|█▌ | 73/469 [02:14<12:02, 1.82s/it]
303
  16%|█▌ | 74/469 [02:16<11:41, 1.78s/it]
304
  16%|█▌ | 75/469 [02:18<12:10, 1.85s/it]
305
  16%|█▌ | 76/469 [02:20<12:04, 1.84s/it]
306
  16%|█▋ | 77/469 [02:21<11:44, 1.80s/it]
307
  17%|█▋ | 78/469 [02:23<11:44, 1.80s/it]
308
  17%|█▋ | 79/469 [02:25<11:44, 1.81s/it]
309
  17%|█▋ | 80/469 [02:27<11:43, 1.81s/it]
310
  17%|█▋ | 81/469 [02:28<11:22, 1.76s/it]
311
  17%|█▋ | 82/469 [02:30<11:10, 1.73s/it]
312
  18%|█▊ | 83/469 [02:32<11:12, 1.74s/it]
313
  18%|█▊ | 84/469 [02:34<11:11, 1.75s/it]
314
  18%|█▊ | 85/469 [02:36<11:32, 1.80s/it]
315
  18%|█▊ | 86/469 [02:37<11:32, 1.81s/it]
316
  19%|█▊ | 87/469 [02:39<11:16, 1.77s/it]
317
  19%|█▉ | 88/469 [02:41<11:07, 1.75s/it]
318
  19%|█▉ | 89/469 [02:42<11:00, 1.74s/it]
319
  19%|█▉ | 90/469 [02:44<11:09, 1.77s/it]
320
  19%|█▉ | 91/469 [02:46<11:15, 1.79s/it]
321
  20%|█▉ | 92/469 [02:48<11:09, 1.78s/it]
322
  20%|█▉ | 93/469 [02:50<10:57, 1.75s/it]
323
  20%|██ | 94/469 [02:51<10:55, 1.75s/it]
324
  20%|██ | 95/469 [02:53<10:50, 1.74s/it]
325
  20%|██ | 96/469 [02:55<10:42, 1.72s/it]
326
  21%|██ | 97/469 [02:56<10:37, 1.71s/it]
327
  21%|██ | 98/469 [02:58<10:51, 1.76s/it]
328
  21%|██ | 99/469 [03:00<10:52, 1.76s/it]
329
  21%|██▏ | 100/469 [03:02<11:08, 1.81s/it]
330
  22%|██▏ | 101/469 [03:04<10:52, 1.77s/it]
331
  22%|██▏ | 102/469 [03:05<10:48, 1.77s/it]
332
  22%|██▏ | 103/469 [03:07<10:35, 1.74s/it]
333
  22%|██▏ | 104/469 [03:09<10:28, 1.72s/it]
334
  22%|██▏ | 105/469 [03:11<10:38, 1.75s/it]
335
  23%|██▎ | 106/469 [03:12<10:29, 1.73s/it]
336
  23%|██▎ | 107/469 [03:14<10:22, 1.72s/it]
337
  23%|██▎ | 108/469 [03:16<10:41, 1.78s/it]
338
  23%|██▎ | 109/469 [03:18<10:29, 1.75s/it]
339
  23%|██▎ | 110/469 [03:19<10:32, 1.76s/it]
340
  24%|██▎ | 111/469 [03:21<10:25, 1.75s/it]
341
  24%|██▍ | 112/469 [03:23<10:24, 1.75s/it]
342
  24%|██▍ | 113/469 [03:25<10:37, 1.79s/it]
343
  24%|██▍ | 114/469 [03:26<10:24, 1.76s/it]
344
  25%|██▍ | 115/469 [03:28<10:27, 1.77s/it]
345
  25%|██▍ | 116/469 [03:30<10:22, 1.76s/it]
346
  25%|██▍ | 117/469 [03:32<10:28, 1.79s/it]
347
  25%|██▌ | 118/469 [03:34<10:52, 1.86s/it]
348
  25%|██▌ | 119/469 [03:36<11:00, 1.89s/it]
349
  26%|██▌ | 120/469 [03:37<10:42, 1.84s/it]
350
  26%|██▌ | 121/469 [03:39<10:37, 1.83s/it]
351
  26%|██▌ | 122/469 [03:41<10:17, 1.78s/it]
352
  26%|██▌ | 123/469 [03:43<10:21, 1.80s/it]
353
  26%|██▋ | 124/469 [03:45<10:14, 1.78s/it]
354
  27%|██▋ | 125/469 [03:46<10:15, 1.79s/it]
355
  27%|██▋ | 126/469 [03:48<10:12, 1.79s/it]
356
  27%|██▋ | 127/469 [03:50<09:59, 1.75s/it]
357
  27%|██▋ | 128/469 [03:51<09:49, 1.73s/it]
358
  28%|██▊ | 129/469 [03:53<09:53, 1.75s/it]
359
  28%|██▊ | 130/469 [03:55<09:42, 1.72s/it]
360
  28%|██▊ | 131/469 [03:57<09:37, 1.71s/it]
361
  28%|██▊ | 132/469 [03:58<09:49, 1.75s/it]
362
  28%|██▊ | 133/469 [04:00<09:59, 1.78s/it]
363
  29%|██▊ | 134/469 [04:02<10:22, 1.86s/it]
364
  29%|██▉ | 135/469 [04:04<10:20, 1.86s/it]
365
  29%|██▉ | 136/469 [04:06<10:26, 1.88s/it]
366
  29%|██▉ | 137/469 [04:08<10:02, 1.81s/it]
367
  29%|██▉ | 138/469 [04:10<10:16, 1.86s/it]
368
  30%|██▉ | 139/469 [04:11<09:58, 1.81s/it]
369
  30%|██▉ | 140/469 [04:13<09:59, 1.82s/it]
370
  30%|███ | 141/469 [04:15<09:40, 1.77s/it]
371
  30%|███ | 142/469 [04:17<09:28, 1.74s/it]
372
  30%|███ | 143/469 [04:18<09:16, 1.71s/it]
373
  31%|███ | 144/469 [04:20<09:13, 1.70s/it]
374
  31%|███ | 145/469 [04:22<09:28, 1.75s/it]
375
  31%|███ | 146/469 [04:24<09:18, 1.73s/it]
376
  31%|███▏ | 147/469 [04:25<09:35, 1.79s/it]
377
  32%|███▏ | 148/469 [04:27<09:28, 1.77s/it]
378
  32%|███▏ | 149/469 [04:29<09:12, 1.73s/it]
379
  32%|███▏ | 150/469 [04:31<09:11, 1.73s/it]
380
  32%|███▏ | 151/469 [04:32<09:18, 1.76s/it]
381
  32%|███▏ | 152/469 [04:34<09:23, 1.78s/it]
382
  33%|███▎ | 153/469 [04:36<09:21, 1.78s/it]
383
  33%|███▎ | 154/469 [04:38<09:11, 1.75s/it]
384
  33%|███▎ | 155/469 [04:39<09:09, 1.75s/it]
385
  33%|███▎ | 156/469 [04:42<10:07, 1.94s/it]
386
  33%|███▎ | 157/469 [04:44<10:39, 2.05s/it]
387
  34%|███▎ | 158/469 [04:46<10:12, 1.97s/it]
388
  34%|███▍ | 159/469 [04:48<09:42, 1.88s/it]
389
  34%|███▍ | 160/469 [04:49<09:42, 1.88s/it]
390
  34%|███▍ | 161/469 [04:51<09:24, 1.83s/it]
391
  35%|███▍ | 162/469 [04:53<09:08, 1.79s/it]
392
  35%|███▍ | 163/469 [04:55<09:01, 1.77s/it]
393
  35%|███▍ | 164/469 [04:56<09:08, 1.80s/it]
394
  35%|███▌ | 165/469 [04:58<09:07, 1.80s/it]
395
  35%|███▌ | 166/469 [05:00<09:06, 1.80s/it]
396
  36%|███▌ | 167/469 [05:02<09:00, 1.79s/it]
397
  36%|███▌ | 168/469 [05:03<08:47, 1.75s/it]
398
  36%|███▌ | 169/469 [05:05<08:51, 1.77s/it]
399
  36%|███▌ | 170/469 [05:07<08:35, 1.73s/it]
400
  36%|███▋ | 171/469 [05:09<08:25, 1.70s/it]
401
  37%|███▋ | 172/469 [05:10<08:24, 1.70s/it]
402
  37%|███▋ | 173/469 [05:12<08:20, 1.69s/it]
403
  37%|███▋ | 174/469 [05:14<08:20, 1.70s/it]
404
  37%|███▋ | 175/469 [05:16<08:42, 1.78s/it]
405
  38%|███▊ | 176/469 [05:17<08:34, 1.76s/it]
406
  38%|███▊ | 177/469 [05:19<08:31, 1.75s/it]
407
  38%|███▊ | 178/469 [05:21<08:34, 1.77s/it]
408
  38%|███▊ | 179/469 [05:22<08:22, 1.73s/it]
409
  38%|███▊ | 180/469 [05:24<08:22, 1.74s/it]
410
  39%|███▊ | 181/469 [05:26<08:14, 1.72s/it]
411
  39%|███▉ | 182/469 [05:28<08:19, 1.74s/it]
412
  39%|███▉ | 183/469 [05:29<08:17, 1.74s/it]
413
  39%|███▉ | 184/469 [05:31<08:29, 1.79s/it]
414
  39%|███▉ | 185/469 [05:33<08:36, 1.82s/it]
415
  40%|███▉ | 186/469 [05:35<08:36, 1.82s/it]
416
  40%|███▉ | 187/469 [05:37<08:29, 1.81s/it]
417
  40%|████ | 188/469 [05:39<08:45, 1.87s/it]
418
  40%|████ | 189/469 [05:41<08:29, 1.82s/it]
419
  41%|████ | 190/469 [05:42<08:28, 1.82s/it]
420
  41%|████ | 191/469 [05:44<08:18, 1.79s/it]
421
  41%|████ | 192/469 [05:46<08:10, 1.77s/it]
422
  41%|████ | 193/469 [05:47<07:58, 1.73s/it]
423
  41%|████▏ | 194/469 [05:49<07:52, 1.72s/it]
424
  42%|████▏ | 195/469 [05:51<08:08, 1.78s/it]
425
  42%|████▏ | 196/469 [05:53<08:19, 1.83s/it]
426
  42%|████▏ | 197/469 [05:55<08:35, 1.90s/it]
427
  42%|████▏ | 198/469 [05:57<08:28, 1.88s/it]
428
  42%|████▏ | 199/469 [05:59<08:17, 1.84s/it]
429
  43%|████▎ | 200/469 [06:01<08:22, 1.87s/it]
430
  43%|████▎ | 201/469 [06:02<08:09, 1.83s/it]
431
  43%|████▎ | 202/469 [06:04<08:08, 1.83s/it]
432
  43%|████▎ | 203/469 [06:06<08:21, 1.88s/it]
433
  43%|████▎ | 204/469 [06:08<07:59, 1.81s/it]
434
  44%|████▎ | 205/469 [06:09<07:46, 1.77s/it]
435
  44%|████▍ | 206/469 [06:11<07:38, 1.74s/it]
436
  44%|████▍ | 207/469 [06:13<07:27, 1.71s/it]
437
  44%|████▍ | 208/469 [06:15<07:40, 1.76s/it]
438
  45%|████▍ | 209/469 [06:16<07:35, 1.75s/it]
439
  45%|████▍ | 210/469 [06:18<07:30, 1.74s/it]
440
  45%|████▍ | 211/469 [06:20<07:35, 1.77s/it]
441
  45%|████▌ | 212/469 [06:22<07:32, 1.76s/it]
442
  45%|████▌ | 213/469 [06:23<07:29, 1.76s/it]
443
  46%|████▌ | 214/469 [06:25<07:38, 1.80s/it]
444
  46%|████▌ | 215/469 [06:27<07:34, 1.79s/it]
445
  46%|████▌ | 216/469 [06:29<07:28, 1.77s/it]
446
  46%|████▋ | 217/469 [06:30<07:20, 1.75s/it]
447
  46%|████▋ | 218/469 [06:32<07:08, 1.71s/it]
448
  47%|████▋ | 219/469 [06:34<07:04, 1.70s/it]
449
  47%|████▋ | 220/469 [06:35<07:02, 1.70s/it]
450
  47%|████▋ | 221/469 [06:38<07:54, 1.91s/it]
451
  47%|████▋ | 222/469 [06:40<07:49, 1.90s/it]
452
  48%|████▊ | 223/469 [06:42<07:37, 1.86s/it]
453
  48%|████▊ | 224/469 [06:43<07:18, 1.79s/it]
454
  48%|████▊ | 225/469 [06:45<07:42, 1.90s/it]
455
  48%|████▊ | 226/469 [06:48<09:05, 2.25s/it]
456
  48%|████▊ | 227/469 [06:50<08:33, 2.12s/it]
457
  49%|████▊ | 228/469 [06:52<07:55, 1.97s/it]
458
  49%|████▉ | 229/469 [06:54<07:40, 1.92s/it]
459
  49%|████▉ | 230/469 [06:55<07:33, 1.90s/it]
460
  49%|████▉ | 231/469 [06:57<07:17, 1.84s/it]
461
  49%|████▉ | 232/469 [06:59<07:05, 1.79s/it]
462
  50%|████▉ | 233/469 [07:01<07:03, 1.80s/it]
463
  50%|████▉ | 234/469 [07:02<06:59, 1.78s/it]
464
  50%|█████ | 235/469 [07:05<08:13, 2.11s/it]
465
  50%|█████ | 236/469 [07:07<07:39, 1.97s/it]
466
  51%|█████ | 237/469 [07:09<07:15, 1.88s/it]
467
  51%|█████ | 238/469 [07:10<07:11, 1.87s/it]
468
  51%|█████ | 239/469 [07:12<06:55, 1.81s/it]
469
  51%|█████ | 240/469 [07:14<07:03, 1.85s/it]
470
  51%|█████▏ | 241/469 [07:16<06:54, 1.82s/it]
471
  52%|█████▏ | 242/469 [07:17<06:41, 1.77s/it]
472
  52%|█████▏ | 243/469 [07:19<06:33, 1.74s/it]
473
  52%|█████▏ | 244/469 [07:21<06:28, 1.73s/it]
474
  52%|█████▏ | 245/469 [07:24<07:39, 2.05s/it]
475
  52%|█████▏ | 246/469 [07:25<07:19, 1.97s/it]
476
  53%|█████▎ | 247/469 [07:27<07:08, 1.93s/it]
477
  53%|█████▎ | 248/469 [07:29<06:57, 1.89s/it]
478
  53%|█████▎ | 249/469 [07:31<06:40, 1.82s/it]
479
  53%|█████▎ | 250/469 [07:33<06:38, 1.82s/it]
480
  54%|█████▎ | 251/469 [07:34<06:25, 1.77s/it]
481
  54%|█████▎ | 252/469 [07:36<06:15, 1.73s/it]
482
  54%|█████▍ | 253/469 [07:37<06:09, 1.71s/it]
483
  54%|█████▍ | 254/469 [07:39<06:05, 1.70s/it]
484
  54%|█████▍ | 255/469 [07:41<06:02, 1.70s/it]
485
  55%|█████▍ | 256/469 [07:42<05:58, 1.68s/it]
486
  55%|█████▍ | 257/469 [07:44<06:01, 1.71s/it]
487
  55%|█████▌ | 258/469 [07:47<06:54, 1.96s/it]
488
  55%|█████▌ | 259/469 [07:49<06:35, 1.88s/it]
489
  55%|█████▌ | 260/469 [07:52<08:26, 2.42s/it]
490
  56%|█████▌ | 261/469 [07:54<07:35, 2.19s/it]
491
  56%|█████▌ | 262/469 [07:56<07:07, 2.07s/it]
492
  56%|█████▌ | 263/469 [07:57<06:47, 1.98s/it]
493
  56%|█████▋ | 264/469 [07:59<06:25, 1.88s/it]
494
  57%|█████▋ | 265/469 [08:01<06:21, 1.87s/it]
495
  57%|█████▋ | 266/469 [08:03<06:20, 1.87s/it]
496
  57%|█████▋ | 267/469 [08:04<06:08, 1.82s/it]
497
  57%|█████▋ | 268/469 [08:06<06:07, 1.83s/it]
498
  57%|█████▋ | 269/469 [08:08<05:54, 1.77s/it]
499
  58%|█████▊ | 270/469 [08:10<05:59, 1.80s/it]
500
  58%|█████▊ | 271/469 [08:12<05:51, 1.77s/it]
501
  58%|█████▊ | 272/469 [08:13<05:58, 1.82s/it]
502
  58%|█████▊ | 273/469 [08:15<06:06, 1.87s/it]
503
  58%|█████▊ | 274/469 [08:17<05:52, 1.81s/it]
504
  59%|█████▊ | 275/469 [08:19<05:50, 1.81s/it]
505
  59%|█████▉ | 276/469 [08:21<05:42, 1.77s/it]
506
  59%|█████▉ | 277/469 [08:22<05:37, 1.76s/it]
507
  59%|█████▉ | 278/469 [08:24<05:34, 1.75s/it]
508
  59%|█████▉ | 279/469 [08:26<05:38, 1.78s/it]
509
  60%|█████▉ | 280/469 [08:28<05:38, 1.79s/it]
510
  60%|█████▉ | 281/469 [08:30<05:39, 1.81s/it]
511
  60%|██████ | 282/469 [08:31<05:38, 1.81s/it]
512
  60%|██████ | 283/469 [08:33<05:30, 1.78s/it]
513
  61%|██████ | 284/469 [08:35<05:26, 1.77s/it]
514
  61%|██████ | 285/469 [08:36<05:17, 1.73s/it]
515
  61%|██████ | 286/469 [08:38<05:14, 1.72s/it]
516
  61%|██████ | 287/469 [08:40<05:10, 1.71s/it]
517
  61%|██████▏ | 288/469 [08:41<05:05, 1.69s/it]
518
  62%|██████▏ | 289/469 [08:43<05:04, 1.69s/it]
519
  62%|██████▏ | 290/469 [08:45<05:11, 1.74s/it]
520
  62%|██████▏ | 291/469 [08:47<05:11, 1.75s/it]
521
  62%|██████▏ | 292/469 [08:48<05:04, 1.72s/it]
522
  62%|██████▏ | 293/469 [08:50<05:04, 1.73s/it]
523
  63%|██████▎ | 294/469 [08:52<05:19, 1.83s/it]
524
  63%|██████▎ | 295/469 [08:54<05:14, 1.81s/it]
525
  63%|██████▎ | 296/469 [08:56<05:04, 1.76s/it]
526
  63%|██████▎ | 297/469 [08:57<04:58, 1.74s/it]
527
  64%|██████▎ | 298/469 [08:59<05:01, 1.76s/it]
528
  64%|██████▍ | 299/469 [09:01<04:56, 1.74s/it]
529
  64%|██████▍ | 300/469 [09:03<04:52, 1.73s/it]
530
  64%|██████▍ | 301/469 [09:04<04:48, 1.72s/it]
531
  64%|██████▍ | 302/469 [09:06<04:43, 1.70s/it]
532
  65%|██████▍ | 303/469 [09:08<04:50, 1.75s/it]
533
  65%|██████▍ | 304/469 [09:10<04:46, 1.74s/it]
534
  65%|██████▌ | 305/469 [09:11<04:47, 1.75s/it]
535
  65%|██████▌ | 306/469 [09:13<04:42, 1.73s/it]
536
  65%|██████▌ | 307/469 [09:15<04:58, 1.84s/it]
537
  66%|██████▌ | 308/469 [09:17<04:50, 1.81s/it]
538
  66%|██████▌ | 309/469 [09:19<04:46, 1.79s/it]
539
  66%|██████▌ | 310/469 [09:20<04:39, 1.76s/it]
540
  66%|██████▋ | 311/469 [09:22<04:34, 1.74s/it]
541
  67%|██████▋ | 312/469 [09:24<04:36, 1.76s/it]
542
  67%|██████▋ | 313/469 [09:26<04:52, 1.88s/it]
543
  67%|██████▋ | 314/469 [09:28<04:55, 1.91s/it]
544
  67%|██████▋ | 315/469 [09:30<04:43, 1.84s/it]
545
  67%|██████▋ | 316/469 [09:31<04:36, 1.81s/it]
546
  68%|██████▊ | 317/469 [09:33<04:42, 1.86s/it]
547
  68%|██████▊ | 318/469 [09:36<05:38, 2.24s/it]
548
  68%|██████▊ | 319/469 [09:38<05:18, 2.13s/it]
549
  68%|██████▊ | 320/469 [09:40<04:59, 2.01s/it]
550
  68%|██████▊ | 321/469 [09:42<04:49, 1.95s/it]
551
  69%|██████▊ | 322/469 [09:43<04:32, 1.86s/it]
552
  69%|██████▉ | 323/469 [09:47<05:51, 2.41s/it]
553
  69%|██████▉ | 324/469 [09:49<05:17, 2.19s/it]
554
  69%|██████▉ | 325/469 [09:51<04:58, 2.08s/it]
555
  70%|██████▉ | 326/469 [09:52<04:42, 1.97s/it]
556
  70%|██████▉ | 327/469 [09:54<04:29, 1.89s/it]
557
  70%|██████▉ | 328/469 [09:56<04:17, 1.82s/it]
558
  70%|███████ | 329/469 [09:58<04:12, 1.80s/it]
559
  70%|███████ | 330/469 [09:59<04:07, 1.78s/it]
560
  71%|███████ | 331/469 [10:01<04:12, 1.83s/it]
561
  71%|███████ | 332/469 [10:03<04:06, 1.80s/it]
562
  71%|███████ | 333/469 [10:05<03:57, 1.75s/it]
563
  71%|███████ | 334/469 [10:06<03:51, 1.71s/it]
564
  71%|███████▏ | 335/469 [10:08<03:47, 1.70s/it]
565
  72%|███████▏ | 336/469 [10:09<03:44, 1.69s/it]
566
  72%|███████▏ | 337/469 [10:11<03:42, 1.68s/it]
567
  72%|███████▏ | 338/469 [10:13<03:45, 1.72s/it]
568
  72%|███████▏ | 339/469 [10:15<03:41, 1.70s/it]
569
  72%|███████▏ | 340/469 [10:16<03:38, 1.70s/it]
570
  73%|███████▎ | 341/469 [10:18<03:37, 1.70s/it]
571
  73%|███████▎ | 342/469 [10:20<03:47, 1.79s/it]
572
  73%|███████▎ | 343/469 [10:22<03:42, 1.77s/it]
573
  73%|███████▎ | 344/469 [10:24<03:40, 1.77s/it]
574
  74%|███████▎ | 345/469 [10:25<03:39, 1.77s/it]
575
  74%|███████▍ | 346/469 [10:27<03:36, 1.76s/it]
576
  74%|███████▍ | 347/469 [10:29<03:31, 1.73s/it]
577
  74%|███████▍ | 348/469 [10:31<03:35, 1.78s/it]
578
  74%|███████▍ | 349/469 [10:32<03:34, 1.78s/it]
579
  75%|███████▍ | 350/469 [10:34<03:30, 1.76s/it]
580
  75%|███████▍ | 351/469 [10:36<03:23, 1.72s/it]
581
  75%|███████▌ | 352/469 [10:38<03:33, 1.82s/it]
582
  75%|███████▌ | 353/469 [10:40<03:29, 1.81s/it]
583
  75%|███████▌ | 354/469 [10:41<03:27, 1.80s/it]
584
  76%|███████▌ | 355/469 [10:43<03:21, 1.77s/it]
585
  76%|███████▌ | 356/469 [10:45<03:27, 1.83s/it]
586
  76%|███████▌ | 357/469 [10:47<03:18, 1.77s/it]
587
  76%|███████▋ | 358/469 [10:48<03:14, 1.75s/it]
588
  77%|██��████▋ | 359/469 [10:50<03:09, 1.73s/it]
589
  77%|███████▋ | 360/469 [10:52<03:07, 1.72s/it]
590
  77%|███████▋ | 361/469 [10:53<03:06, 1.72s/it]
591
  77%|███████▋ | 362/469 [10:55<03:07, 1.75s/it]
592
  77%|███████▋ | 363/469 [10:57<03:10, 1.79s/it]
593
  78%|███████▊ | 364/469 [10:59<03:06, 1.77s/it]
594
  78%|███████▊ | 365/469 [11:01<03:00, 1.73s/it]
595
  78%|███████▊ | 366/469 [11:02<02:55, 1.71s/it]
596
  78%|███████▊ | 367/469 [11:04<02:53, 1.70s/it]
597
  78%|███████▊ | 368/469 [11:06<02:53, 1.72s/it]
598
  79%|███████▊ | 369/469 [11:07<02:51, 1.71s/it]
599
  79%|███████▉ | 370/469 [11:09<02:48, 1.70s/it]
600
  79%|███████▉ | 371/469 [11:11<02:48, 1.71s/it]
601
  79%|███████▉ | 372/469 [11:13<02:51, 1.77s/it]
602
  80%|███████▉ | 373/469 [11:14<02:46, 1.73s/it]
603
  80%|███████▉ | 374/469 [11:16<02:45, 1.75s/it]
604
  80%|███████▉ | 375/469 [11:18<02:43, 1.74s/it]
605
  80%|████████ | 376/469 [11:20<02:42, 1.74s/it]
606
  80%|████████ | 377/469 [11:21<02:39, 1.74s/it]
607
  81%|████████ | 378/469 [11:23<02:36, 1.72s/it]
608
  81%|████████ | 379/469 [11:25<02:35, 1.72s/it]
609
  81%|████████ | 380/469 [11:26<02:32, 1.72s/it]
610
  81%|████████ | 381/469 [11:28<02:34, 1.76s/it]
611
  81%|████████▏ | 382/469 [11:30<02:43, 1.88s/it]
612
  82%|████████▏ | 383/469 [11:32<02:35, 1.81s/it]
613
  82%|████████▏ | 384/469 [11:34<02:29, 1.76s/it]
614
  82%|████████▏ | 385/469 [11:36<02:30, 1.80s/it]
615
  82%|████████▏ | 386/469 [11:37<02:26, 1.77s/it]
616
  83%|████████▎ | 387/469 [11:39<02:26, 1.78s/it]
617
  83%|████████▎ | 388/469 [11:41<02:23, 1.77s/it]
618
  83%|████████▎ | 389/469 [11:42<02:18, 1.73s/it]
619
  83%|████████▎ | 390/469 [11:44<02:18, 1.75s/it]
620
  83%|████████▎ | 391/469 [11:46<02:16, 1.76s/it]
621
  84%|████████▎ | 392/469 [11:49<02:33, 1.99s/it]
622
  84%|████████▍ | 393/469 [11:50<02:24, 1.90s/it]
623
  84%|████████▍ | 394/469 [11:52<02:18, 1.85s/it]
624
  84%|████████▍ | 395/469 [11:54<02:12, 1.79s/it]
625
  84%|████████▍ | 396/469 [11:58<03:02, 2.50s/it]
626
  85%|████████▍ | 397/469 [12:00<02:44, 2.29s/it]
627
  85%|████████▍ | 398/469 [12:01<02:31, 2.13s/it]
628
  85%|████████▌ | 399/469 [12:03<02:21, 2.02s/it]
629
  85%|████████▌ | 400/469 [12:05<02:15, 1.96s/it]
630
  86%|████████▌ | 401/469 [12:07<02:10, 1.92s/it]
631
  86%|████████▌ | 402/469 [12:08<02:02, 1.83s/it]
632
  86%|████████▌ | 403/469 [12:10<01:57, 1.78s/it]
633
  86%|████████▌ | 404/469 [12:12<01:53, 1.74s/it]
634
  86%|████████▋ | 405/469 [12:13<01:49, 1.72s/it]
635
  87%|████████▋ | 406/469 [12:15<01:50, 1.76s/it]
636
  87%|████████▋ | 407/469 [12:17<01:48, 1.74s/it]
637
  87%|████████▋ | 408/469 [12:19<01:44, 1.72s/it]
638
  87%|████████▋ | 409/469 [12:20<01:41, 1.70s/it]
639
  87%|████████▋ | 410/469 [12:22<01:45, 1.79s/it]
640
  88%|████████▊ | 411/469 [12:24<01:40, 1.74s/it]
641
  88%|████████▊ | 412/469 [12:26<01:41, 1.78s/it]
642
  88%|████████▊ | 413/469 [12:28<01:39, 1.78s/it]
643
  88%|████████▊ | 414/469 [12:29<01:35, 1.74s/it]
644
  88%|████████▊ | 415/469 [12:31<01:38, 1.83s/it]
645
  89%|████████▊ | 416/469 [12:33<01:36, 1.82s/it]
646
  89%|████████▉ | 417/469 [12:35<01:31, 1.76s/it]
647
  89%|████████▉ | 418/469 [12:36<01:31, 1.79s/it]
648
  89%|████████▉ | 419/469 [12:38<01:26, 1.74s/it]
649
  90%|████████▉ | 420/469 [12:40<01:25, 1.75s/it]
650
  90%|████████▉ | 421/469 [12:42<01:23, 1.75s/it]
651
  90%|████████▉ | 422/469 [12:43<01:20, 1.71s/it]
652
  90%|█████████ | 423/469 [12:45<01:20, 1.74s/it]
653
  90%|█████████ | 424/469 [12:47<01:18, 1.74s/it]
654
  91%|█████████ | 425/469 [12:49<01:17, 1.75s/it]
655
  91%|█████████ | 426/469 [12:50<01:15, 1.75s/it]
656
  91%|█████████ | 427/469 [12:52<01:14, 1.78s/it]
657
  91%|█████████▏| 428/469 [12:54<01:12, 1.77s/it]
658
  91%|█████████▏| 429/469 [12:56<01:09, 1.75s/it]
659
  92%|█████████▏| 430/469 [12:57<01:07, 1.72s/it]
660
  92%|█████████▏| 431/469 [12:59<01:04, 1.70s/it]
661
  92%|█████████▏| 432/469 [13:01<01:04, 1.74s/it]
662
  92%|█████████▏| 433/469 [13:03<01:07, 1.88s/it]
663
  93%|█████████▎| 434/469 [13:05<01:04, 1.86s/it]
664
  93%|█████████▎| 435/469 [13:07<01:03, 1.87s/it]
665
  93%|█████████▎| 436/469 [13:08<00:59, 1.81s/it]
666
  93%|█████████▎| 437/469 [13:10<00:57, 1.81s/it]
667
  93%|█████████▎| 438/469 [13:12<00:57, 1.86s/it]
668
  94%|█████████▎| 439/469 [13:14<00:54, 1.82s/it]
669
  94%|█████████▍| 440/469 [13:16<00:52, 1.82s/it]
670
  94%|█████████▍| 441/469 [13:18<00:51, 1.84s/it]
671
  94%|█████████▍| 442/469 [13:19<00:48, 1.80s/it]
672
  94%|█████████▍| 443/469 [13:21<00:48, 1.86s/it]
673
  95%|█████████▍| 444/469 [13:23<00:45, 1.80s/it]
674
  95%|█████████▍| 445/469 [13:25<00:42, 1.79s/it]
675
  95%|█████████▌| 446/469 [13:27<00:42, 1.85s/it]
676
  95%|█████████▌| 447/469 [13:29<00:41, 1.86s/it]
677
  96%|█████████▌| 448/469 [13:30<00:39, 1.88s/it]
678
  96%|█████████▌| 449/469 [13:32<00:37, 1.86s/it]
679
  96%|█████████▌| 450/469 [13:34<00:33, 1.79s/it]
680
  96%|█████████▌| 451/469 [13:36<00:31, 1.74s/it]
681
  96%|█████████▋| 452/469 [13:37<00:29, 1.73s/it]
682
  97%|█████████▋| 453/469 [13:39<00:28, 1.76s/it]
683
  97%|█████████▋| 454/469 [13:41<00:26, 1.74s/it]
684
  97%|█████████▋| 455/469 [13:42<00:23, 1.70s/it]
685
  97%|█████████▋| 456/469 [13:44<00:22, 1.71s/it]
686
  97%|█████████▋| 457/469 [13:46<00:20, 1.71s/it]
687
  98%|█████████▊| 458/469 [13:48<00:18, 1.71s/it]
688
  98%|█████████▊| 459/469 [13:49<00:16, 1.69s/it]
689
  98%|█████████▊| 460/469 [13:51<00:15, 1.68s/it]
690
  98%|█████████▊| 461/469 [13:53<00:14, 1.76s/it]
691
  99%|█████████▊| 462/469 [13:55<00:12, 1.78s/it]
692
  99%|█████████▊| 463/469 [13:57<00:10, 1.83s/it]
693
  99%|█████████▉| 464/469 [13:58<00:09, 1.86s/it]
694
  99%|█████████▉| 465/469 [14:00<00:07, 1.81s/it]
695
  99%|█████████▉| 466/469 [14:02<00:05, 1.78s/it]
696
+ computing/reading sample batch statistics...
697
+ Computing evaluations...
698
+ Inception Score: 38.328826904296875
699
+ FID: 21.82574123258769
700
+ sFID: 70.92829349483634
701
+ Precision: 0.6937
702
+ Recall: 0.3517815963698579
eval_rectified_noise_new_batch_2.log ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/1 [00:00<?, ?it/s]2026-03-23 16:28:27.563125: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
 
 
 
1
  0%| | 0/211 [00:00<?, ?it/s]
2
  0%| | 1/211 [00:02<10:16, 2.94s/it]
3
  1%| | 2/211 [00:04<08:15, 2.37s/it]
4
  1%|▏ | 3/211 [00:06<07:38, 2.21s/it]
5
  2%|▏ | 4/211 [00:08<07:20, 2.13s/it]
6
  2%|▏ | 5/211 [00:11<07:25, 2.16s/it]
7
  3%|▎ | 6/211 [00:12<06:59, 2.04s/it]
8
  3%|▎ | 7/211 [00:15<07:15, 2.13s/it]
9
  4%|▍ | 8/211 [00:17<06:55, 2.05s/it]
10
  4%|▍ | 9/211 [00:19<06:53, 2.05s/it]
11
  5%|▍ | 10/211 [00:21<06:37, 1.98s/it]
12
  5%|▌ | 11/211 [00:22<06:18, 1.89s/it]
13
  6%|▌ | 12/211 [00:24<06:13, 1.88s/it]
14
  6%|▌ | 13/211 [00:26<06:08, 1.86s/it]
15
  7%|▋ | 14/211 [00:28<06:02, 1.84s/it]
16
  7%|▋ | 15/211 [00:29<05:57, 1.82s/it]
17
  8%|▊ | 16/211 [00:31<05:55, 1.82s/it]
18
  8%|▊ | 17/211 [00:33<06:00, 1.86s/it]
19
  9%|▊ | 18/211 [00:35<06:01, 1.87s/it]
20
  9%|▉ | 19/211 [00:37<05:59, 1.87s/it]
21
  9%|▉ | 20/211 [00:39<05:51, 1.84s/it]
22
  10%|▉ | 21/211 [00:41<05:45, 1.82s/it]
23
  10%|█ | 22/211 [00:42<05:43, 1.82s/it]
24
  11%|█ | 23/211 [00:44<05:41, 1.82s/it]
25
  11%|█▏ | 24/211 [00:47<06:23, 2.05s/it]
26
  12%|█▏ | 25/211 [00:50<07:07, 2.30s/it]
27
  12%|█▏ | 26/211 [00:51<06:37, 2.15s/it]
28
  13%|█▎ | 27/211 [00:53<06:14, 2.04s/it]
29
  13%|█▎ | 28/211 [00:55<05:56, 1.95s/it]
30
  14%|█▎ | 29/211 [00:57<05:56, 1.96s/it]
31
  14%|█▍ | 30/211 [00:59<05:49, 1.93s/it]
32
  15%|█▍ | 31/211 [01:01<05:49, 1.94s/it]
33
  15%|█▌ | 32/211 [01:03<05:41, 1.91s/it]
34
  16%|█▌ | 33/211 [01:05<05:42, 1.93s/it]
35
  16%|█▌ | 34/211 [01:07<05:49, 1.97s/it]
36
  17%|█▋ | 35/211 [01:09<06:11, 2.11s/it]
37
  17%|█▋ | 36/211 [01:11<06:00, 2.06s/it]
38
  18%|█▊ | 37/211 [01:13<05:46, 1.99s/it]
39
  18%|█▊ | 38/211 [01:15<05:57, 2.07s/it]
40
  18%|█▊ | 39/211 [01:17<05:44, 2.01s/it]
41
  19%|█▉ | 40/211 [01:19<05:31, 1.94s/it]
42
  19%|█▉ | 41/211 [01:21<05:24, 1.91s/it]
43
  20%|█▉ | 42/211 [01:22<05:19, 1.89s/it]
44
  20%|██ | 43/211 [01:24<05:20, 1.91s/it]
45
  21%|██ | 44/211 [01:26<05:09, 1.86s/it]
46
  21%|██▏ | 45/211 [01:29<05:50, 2.11s/it]
47
  22%|██▏ | 46/211 [01:31<05:46, 2.10s/it]
48
  22%|██▏ | 47/211 [01:33<05:33, 2.03s/it]
49
  23%|██▎ | 48/211 [01:35<05:17, 1.95s/it]
50
  23%|██▎ | 49/211 [01:36<05:14, 1.94s/it]
51
  24%|██▎ | 50/211 [01:38<05:03, 1.88s/it]
52
  24%|██▍ | 51/211 [01:40<05:19, 2.00s/it]
53
  25%|██▍ | 52/211 [01:42<05:13, 1.97s/it]
54
  25%|██▌ | 53/211 [01:45<05:19, 2.02s/it]
55
  26%|██▌ | 54/211 [01:46<05:05, 1.95s/it]
56
  26%|██▌ | 55/211 [01:48<04:55, 1.89s/it]
57
  27%|██▋ | 56/211 [01:50<04:56, 1.92s/it]
58
  27%|██▋ | 57/211 [01:52<04:48, 1.88s/it]
59
  27%|██▋ | 58/211 [01:54<04:44, 1.86s/it]
60
  28%|██▊ | 59/211 [01:56<04:47, 1.89s/it]
61
  28%|██▊ | 60/211 [01:58<04:57, 1.97s/it]
62
  29%|██▉ | 61/211 [02:00<04:46, 1.91s/it]
63
  29%|██▉ | 62/211 [02:01<04:38, 1.87s/it]
64
  30%|██▉ | 63/211 [02:04<05:05, 2.07s/it]
65
  30%|███ | 64/211 [02:06<04:59, 2.04s/it]
66
  31%|███ | 65/211 [02:08<05:05, 2.09s/it]
67
  31%|███▏ | 66/211 [02:10<04:48, 1.99s/it]
68
  32%|███▏ | 67/211 [02:12<04:37, 1.93s/it]
69
  32%|███▏ | 68/211 [02:13<04:27, 1.87s/it]
70
  33%|███▎ | 69/211 [02:15<04:25, 1.87s/it]
71
  33%|███▎ | 70/211 [02:17<04:18, 1.83s/it]
72
  34%|███▎ | 71/211 [02:19<04:15, 1.83s/it]
73
  34%|███▍ | 72/211 [02:20<04:10, 1.80s/it]
74
  35%|███▍ | 73/211 [02:22<04:08, 1.80s/it]
75
  35%|███▌ | 74/211 [02:24<04:12, 1.84s/it]
76
  36%|███▌ | 75/211 [02:26<04:10, 1.84s/it]
77
  36%|███▌ | 76/211 [02:28<04:06, 1.83s/it]
78
  36%|███▋ | 77/211 [02:30<04:11, 1.88s/it]
79
  37%|███▋ | 78/211 [02:32<04:07, 1.86s/it]
80
  37%|███▋ | 79/211 [02:34<04:08, 1.88s/it]
81
  38%|███▊ | 80/211 [02:35<04:06, 1.88s/it]
82
  38%|███▊ | 81/211 [02:37<04:11, 1.93s/it]
83
  39%|███▉ | 82/211 [02:39<04:04, 1.89s/it]
84
  39%|███▉ | 83/211 [02:41<04:06, 1.92s/it]
85
  40%|███▉ | 84/211 [02:43<04:03, 1.92s/it]
86
  40%|████ | 85/211 [02:45<03:55, 1.87s/it]
87
  41%|████ | 86/211 [02:47<03:50, 1.84s/it]
88
  41%|████ | 87/211 [02:48<03:45, 1.82s/it]
89
  42%|████▏ | 88/211 [02:50<03:41, 1.80s/it]
90
  42%|████▏ | 89/211 [02:52<03:36, 1.78s/it]
91
  43%|████▎ | 90/211 [02:54<03:51, 1.91s/it]
92
  43%|████▎ | 91/211 [02:56<03:45, 1.88s/it]
93
  44%|████▎ | 92/211 [02:58<03:38, 1.84s/it]
94
  44%|████▍ | 93/211 [03:00<03:37, 1.85s/it]
95
  45%|████▍ | 94/211 [03:02<03:46, 1.94s/it]
96
  45%|████▌ | 95/211 [03:03<03:36, 1.87s/it]
97
  45%|████▌ | 96/211 [03:05<03:34, 1.87s/it]
98
  46%|████▌ | 97/211 [03:08<03:58, 2.09s/it]
99
  46%|████▋ | 98/211 [03:10<03:45, 1.99s/it]
100
  47%|████▋ | 99/211 [03:11<03:34, 1.91s/it]
101
  47%|████▋ | 100/211 [03:13<03:31, 1.91s/it]
102
  48%|████▊ | 101/211 [03:15<03:28, 1.89s/it]
103
  48%|████▊ | 102/211 [03:21<05:46, 3.18s/it]
104
  49%|████▉ | 103/211 [03:23<04:58, 2.76s/it]
105
  49%|████▉ | 104/211 [03:25<04:22, 2.45s/it]
106
  50%|████▉ | 105/211 [03:27<04:03, 2.30s/it]
107
  50%|█████ | 106/211 [03:29<03:45, 2.15s/it]
108
  51%|█████ | 107/211 [03:31<03:35, 2.07s/it]
109
  51%|█████ | 108/211 [03:32<03:29, 2.03s/it]
110
  52%|█████▏ | 109/211 [03:34<03:21, 1.97s/it]
111
  52%|█████▏ | 110/211 [03:37<03:36, 2.15s/it]
112
  53%|█████▎ | 111/211 [03:39<03:23, 2.03s/it]
113
  53%|█████▎ | 112/211 [03:41<03:17, 1.99s/it]
114
  54%|█████▎ | 113/211 [03:42<03:10, 1.95s/it]
115
  54%|█████▍ | 114/211 [03:44<03:03, 1.90s/it]
116
  55%|█████▍ | 115/211 [03:46<03:02, 1.91s/it]
117
  55%|█████▍ | 116/211 [03:48<02:59, 1.89s/it]
118
  55%|█████▌ | 117/211 [03:50<03:03, 1.96s/it]
119
  56%|█████▌ | 118/211 [03:52<02:56, 1.90s/it]
120
  56%|█████▋ | 119/211 [03:54<02:51, 1.86s/it]
121
  57%|█████▋ | 120/211 [03:55<02:49, 1.86s/it]
122
  57%|█████▋ | 121/211 [03:57<02:52, 1.91s/it]
123
  58%|█████▊ | 122/211 [03:59<02:47, 1.89s/it]
124
  58%|█████▊ | 123/211 [04:01<02:51, 1.94s/it]
125
  59%|█████▉ | 124/211 [04:03<02:47, 1.93s/it]
126
  59%|█████▉ | 125/211 [04:05<02:40, 1.87s/it]
127
  60%|█████▉ | 126/211 [04:07<02:46, 1.96s/it]
128
  60%|██████ | 127/211 [04:09<02:43, 1.95s/it]
129
  61%|██████ | 128/211 [04:11<02:42, 1.95s/it]
130
  61%|██████ | 129/211 [04:13<02:47, 2.04s/it]
131
  62%|██████▏ | 130/211 [04:15<02:41, 2.00s/it]
132
  62%|██████▏ | 131/211 [04:17<02:39, 2.00s/it]
133
  63%|██████▎ | 132/211 [04:19<02:36, 1.98s/it]
134
  63%|██████▎ | 133/211 [04:21<02:31, 1.94s/it]
135
  64%|██████▎ | 134/211 [04:23<02:26, 1.91s/it]
136
  64%|██████▍ | 135/211 [04:25<02:21, 1.86s/it]
137
  64%|██████▍ | 136/211 [04:26<02:20, 1.87s/it]
138
  65%|██████▍ | 137/211 [04:29<02:23, 1.94s/it]
139
  65%|██████▌ | 138/211 [04:30<02:18, 1.90s/it]
140
  66%|██████▌ | 139/211 [04:32<02:13, 1.86s/it]
141
  66%|██████▋ | 140/211 [04:34<02:10, 1.84s/it]
142
  67%|██████▋ | 141/211 [04:36<02:07, 1.82s/it]
143
  67%|██████▋ | 142/211 [04:38<02:06, 1.83s/it]
144
  68%|██████▊ | 143/211 [04:39<02:06, 1.86s/it]
145
  68%|██��███▊ | 144/211 [04:41<02:02, 1.83s/it]
146
  69%|██████▊ | 145/211 [04:43<02:00, 1.82s/it]
147
  69%|██████▉ | 146/211 [04:45<01:58, 1.82s/it]
148
  70%|██████▉ | 147/211 [04:47<02:06, 1.98s/it]
149
  70%|███████ | 148/211 [04:49<02:07, 2.02s/it]
150
  71%|███████ | 149/211 [04:51<02:02, 1.97s/it]
151
  71%|███████ | 150/211 [04:53<01:59, 1.96s/it]
152
  72%|███████▏ | 151/211 [04:55<01:56, 1.95s/it]
153
  72%|███████▏ | 152/211 [04:57<01:54, 1.94s/it]
154
  73%|███████▎ | 153/211 [04:59<01:49, 1.88s/it]
155
  73%|███████▎ | 154/211 [05:01<01:56, 2.04s/it]
156
  73%|███████▎ | 155/211 [05:03<01:50, 1.97s/it]
157
  74%|███████▍ | 156/211 [05:05<01:48, 1.97s/it]
158
  74%|███████▍ | 157/211 [05:07<01:46, 1.96s/it]
159
  75%|███████▍ | 158/211 [05:09<01:41, 1.92s/it]
160
  75%|███████▌ | 159/211 [05:10<01:38, 1.90s/it]
161
  76%|███████▌ | 160/211 [05:12<01:35, 1.88s/it]
162
  76%|███████▋ | 161/211 [05:14<01:32, 1.85s/it]
163
  77%|███████▋ | 162/211 [05:16<01:32, 1.89s/it]
164
  77%|███████▋ | 163/211 [05:18<01:31, 1.90s/it]
165
  78%|███████▊ | 164/211 [05:20<01:28, 1.88s/it]
166
  78%|███████▊ | 165/211 [05:22<01:26, 1.88s/it]
167
  79%|███████▊ | 166/211 [05:24<01:25, 1.89s/it]
168
  79%|███████▉ | 167/211 [05:25<01:22, 1.88s/it]
169
  80%|███████▉ | 168/211 [05:27<01:22, 1.92s/it]
170
  80%|████████ | 169/211 [05:29<01:21, 1.94s/it]
171
  81%|████████ | 170/211 [05:31<01:18, 1.92s/it]
172
  81%|████████ | 171/211 [05:34<01:26, 2.15s/it]
173
  82%|████████▏ | 172/211 [05:36<01:23, 2.14s/it]
174
  82%|████████▏ | 173/211 [05:38<01:17, 2.05s/it]
175
  82%|████████▏ | 174/211 [05:40<01:13, 1.99s/it]
176
  83%|████████▎ | 175/211 [05:42<01:09, 1.93s/it]
177
  83%|████████▎ | 176/211 [05:44<01:10, 2.02s/it]
178
  84%|████████▍ | 177/211 [05:46<01:10, 2.06s/it]
179
  84%|████████▍ | 178/211 [05:49<01:16, 2.33s/it]
180
  85%|████████▍ | 179/211 [05:51<01:09, 2.16s/it]
181
  85%|████████▌ | 180/211 [05:53<01:03, 2.05s/it]
182
  86%|████████▌ | 181/211 [05:54<00:58, 1.96s/it]
183
  86%|████████▋ | 182/211 [05:56<00:56, 1.95s/it]
184
  87%|████████▋ | 183/211 [05:59<00:57, 2.05s/it]
185
  87%|████████▋ | 184/211 [06:00<00:54, 2.01s/it]
186
  88%|████████▊ | 185/211 [06:02<00:50, 1.95s/it]
187
  88%|████████▊ | 186/211 [06:04<00:47, 1.89s/it]
188
  89%|████████▊ | 187/211 [06:06<00:45, 1.88s/it]
189
  89%|████████▉ | 188/211 [06:08<00:42, 1.86s/it]
190
  90%|████████▉ | 189/211 [06:09<00:40, 1.83s/it]
191
  90%|█████████ | 190/211 [06:11<00:38, 1.85s/it]
192
  91%|█████████ | 191/211 [06:13<00:36, 1.85s/it]
193
  91%|█████████ | 192/211 [06:15<00:34, 1.84s/it]
194
  91%|█████████▏| 193/211 [06:17<00:34, 1.92s/it]
195
  92%|█████████▏| 194/211 [06:19<00:32, 1.91s/it]
196
  92%|█████████▏| 195/211 [06:21<00:31, 1.96s/it]
197
  93%|█████████▎| 196/211 [06:23<00:28, 1.92s/it]
198
  93%|█████████▎| 197/211 [06:25<00:27, 1.96s/it]
199
  94%|█████████▍| 198/211 [06:27<00:25, 1.93s/it]
200
  94%|█████████▍| 199/211 [06:29<00:22, 1.88s/it]
201
  95%|█████████▍| 200/211 [06:30<00:20, 1.89s/it]
202
  95%|█████████▌| 201/211 [06:32<00:18, 1.88s/it]
203
  96%|█████████▌| 202/211 [06:34<00:16, 1.86s/it]
204
  96%|█████████▌| 203/211 [06:36<00:15, 1.88s/it]
205
  97%|█████████▋| 204/211 [06:38<00:13, 1.89s/it]
206
  97%|█████████▋| 205/211 [06:40<00:11, 1.91s/it]
207
  98%|█████████▊| 206/211 [06:42<00:10, 2.10s/it]
208
  98%|█████████▊| 207/211 [06:44<00:08, 2.07s/it]
209
  99%|█████████▊| 208/211 [06:46<00:05, 1.98s/it]
210
  99%|█████████▉| 209/211 [06:49<00:04, 2.12s/it]
 
 
 
211
  0%| | 0/469 [00:00<?, ?it/s]
212
  0%| | 1/469 [00:01<14:40, 1.88s/it]
213
  0%| | 2/469 [00:04<16:02, 2.06s/it]
214
  1%| | 3/469 [00:05<15:06, 1.95s/it]
215
  1%| | 4/469 [00:07<14:44, 1.90s/it]
216
  1%| | 5/469 [00:09<14:34, 1.88s/it]
217
  1%|▏ | 6/469 [00:11<14:24, 1.87s/it]
218
  1%|▏ | 7/469 [00:13<14:05, 1.83s/it]
219
  2%|▏ | 8/469 [00:15<14:19, 1.87s/it]
220
  2%|▏ | 9/469 [00:16<14:13, 1.85s/it]
221
  2%|▏ | 10/469 [00:18<14:09, 1.85s/it]
222
  2%|▏ | 11/469 [00:20<14:04, 1.84s/it]
223
  3%|▎ | 12/469 [00:22<13:59, 1.84s/it]
224
  3%|▎ | 13/469 [00:24<14:09, 1.86s/it]
225
  3%|▎ | 14/469 [00:26<14:18, 1.89s/it]
226
  3%|▎ | 15/469 [00:28<14:01, 1.85s/it]
227
  3%|▎ | 16/469 [00:30<14:36, 1.94s/it]
228
  4%|▎ | 17/469 [00:31<14:12, 1.89s/it]
229
  4%|▍ | 18/469 [00:33<13:58, 1.86s/it]
230
  4%|▍ | 19/469 [00:36<14:55, 1.99s/it]
231
  4%|▍ | 20/469 [00:38<14:53, 1.99s/it]
232
  4%|▍ | 21/469 [00:40<14:50, 1.99s/it]
233
  5%|▍ | 22/469 [00:41<14:43, 1.98s/it]
234
  5%|▍ | 23/469 [00:44<15:04, 2.03s/it]
235
  5%|▌ | 24/469 [00:45<14:34, 1.97s/it]
236
  5%|▌ | 25/469 [00:47<14:03, 1.90s/it]
237
  6%|▌ | 26/469 [00:49<13:42, 1.86s/it]
238
  6%|▌ | 27/469 [00:51<13:31, 1.84s/it]
239
  6%|▌ | 28/469 [00:53<13:35, 1.85s/it]
240
  6%|▌ | 29/469 [00:54<13:24, 1.83s/it]
241
  6%|▋ | 30/469 [00:56<13:30, 1.85s/it]
242
  7%|▋ | 31/469 [00:58<13:10, 1.80s/it]
243
  7%|▋ | 32/469 [01:00<13:43, 1.88s/it]
244
  7%|▋ | 33/469 [01:02<13:23, 1.84s/it]
245
  7%|▋ | 34/469 [01:05<15:15, 2.10s/it]
246
  7%|▋ | 35/469 [01:06<14:40, 2.03s/it]
247
  8%|▊ | 36/469 [01:08<14:23, 1.99s/it]
248
  8%|▊ | 37/469 [01:10<13:51, 1.93s/it]
249
  8%|▊ | 38/469 [01:12<13:26, 1.87s/it]
250
  8%|▊ | 39/469 [01:14<13:13, 1.85s/it]
251
  9%|▊ | 40/469 [01:15<13:15, 1.85s/it]
252
  9%|▊ | 41/469 [01:17<13:14, 1.86s/it]
253
  9%|▉ | 42/469 [01:19<12:58, 1.82s/it]
254
  9%|▉ | 43/469 [01:21<13:04, 1.84s/it]
255
  9%|▉ | 44/469 [01:23<12:55, 1.82s/it]
256
  10%|▉ | 45/469 [01:25<12:56, 1.83s/it]
257
  10%|▉ | 46/469 [01:26<12:50, 1.82s/it]
258
  10%|█ | 47/469 [01:28<12:47, 1.82s/it]
259
  10%|█ | 48/469 [01:30<12:41, 1.81s/it]
260
  10%|█ | 49/469 [01:32<12:52, 1.84s/it]
261
  11%|█ | 50/469 [01:34<13:00, 1.86s/it]
262
  11%|█ | 51/469 [01:36<12:46, 1.83s/it]
263
  11%|█ | 52/469 [01:37<12:41, 1.83s/it]
264
  11%|█▏ | 53/469 [01:39<12:40, 1.83s/it]
265
  12%|█▏ | 54/469 [01:41<12:46, 1.85s/it]
266
  12%|█▏ | 55/469 [01:44<15:33, 2.25s/it]
267
  12%|█▏ | 56/469 [01:46<14:24, 2.09s/it]
268
  12%|█▏ | 57/469 [01:48<13:41, 1.99s/it]
269
  12%|█▏ | 58/469 [01:50<13:22, 1.95s/it]
270
  13%|█▎ | 59/469 [01:51<12:54, 1.89s/it]
271
  13%|█▎ | 60/469 [01:53<12:33, 1.84s/it]
272
  13%|█▎ | 61/469 [01:55<12:23, 1.82s/it]
273
  13%|█▎ | 62/469 [01:57<12:29, 1.84s/it]
274
  13%|█▎ | 63/469 [01:59<12:42, 1.88s/it]
275
  14%|█▎ | 64/469 [02:01<12:43, 1.89s/it]
276
  14%|█▍ | 65/469 [02:02<12:29, 1.86s/it]
277
  14%|█▍ | 66/469 [02:04<12:49, 1.91s/it]
278
  14%|█▍ | 67/469 [02:06<12:50, 1.92s/it]
279
  14%|█▍ | 68/469 [02:08<12:34, 1.88s/it]
280
  15%|█▍ | 69/469 [02:10<12:22, 1.86s/it]
281
  15%|█▍ | 70/469 [02:12<12:20, 1.86s/it]
282
  15%|█▌ | 71/469 [02:14<12:10, 1.84s/it]
283
  15%|█▌ | 72/469 [02:15<12:06, 1.83s/it]
284
  16%|█▌ | 73/469 [02:17<12:00, 1.82s/it]
285
  16%|█▌ | 74/469 [02:19<12:09, 1.85s/it]
286
  16%|█▌ | 75/469 [02:21<12:18, 1.87s/it]
287
  16%|█▌ | 76/469 [02:23<12:13, 1.87s/it]
288
  16%|█▋ | 77/469 [02:25<12:16, 1.88s/it]
289
  17%|█▋ | 78/469 [02:27<12:45, 1.96s/it]
290
  17%|█▋ | 79/469 [02:29<12:17, 1.89s/it]
291
  17%|█▋ | 80/469 [02:31<12:07, 1.87s/it]
292
  17%|█▋ | 81/469 [02:32<11:53, 1.84s/it]
293
  17%|█▋ | 82/469 [02:34<12:09, 1.89s/it]
294
  18%|█▊ | 83/469 [02:36<12:01, 1.87s/it]
295
  18%|█▊ | 84/469 [02:38<11:45, 1.83s/it]
296
  18%|█▊ | 85/469 [02:40<12:13, 1.91s/it]
297
  18%|█▊ | 86/469 [02:42<12:12, 1.91s/it]
298
  19%|█▊ | 87/469 [02:44<11:53, 1.87s/it]
299
  19%|█▉ | 88/469 [02:45<11:32, 1.82s/it]
300
  19%|█▉ | 89/469 [02:47<11:28, 1.81s/it]
301
  19%|█▉ | 90/469 [02:49<11:52, 1.88s/it]
302
  19%|█▉ | 91/469 [02:51<11:34, 1.84s/it]
303
  20%|█▉ | 92/469 [02:53<11:36, 1.85s/it]
304
  20%|█▉ | 93/469 [02:55<11:36, 1.85s/it]
305
  20%|██ | 94/469 [02:56<11:27, 1.83s/it]
306
  20%|██ | 95/469 [02:58<11:24, 1.83s/it]
307
  20%|██ | 96/469 [03:00<11:19, 1.82s/it]
308
  21%|██ | 97/469 [03:02<11:25, 1.84s/it]
309
  21%|██ | 98/469 [03:04<11:45, 1.90s/it]
310
  21%|██ | 99/469 [03:06<11:45, 1.91s/it]
311
  21%|██▏ | 100/469 [03:08<11:30, 1.87s/it]
312
  22%|██▏ | 101/469 [03:10<12:35, 2.05s/it]
313
  22%|██▏ | 102/469 [03:12<12:40, 2.07s/it]
314
  22%|██▏ | 103/469 [03:14<12:10, 2.00s/it]
315
  22%|██▏ | 104/469 [03:16<11:49, 1.94s/it]
316
  22%|██▏ | 105/469 [03:18<11:38, 1.92s/it]
317
  23%|██▎ | 106/469 [03:20<11:50, 1.96s/it]
318
  23%|██▎ | 107/469 [03:22<11:29, 1.91s/it]
319
  23%|██▎ | 108/469 [03:24<12:19, 2.05s/it]
320
  23%|██▎ | 109/469 [03:26<11:50, 1.97s/it]
321
  23%|██▎ | 110/469 [03:28<12:21, 2.06s/it]
322
  24%|██▎ | 111/469 [03:30<12:00, 2.01s/it]
323
  24%|██▍ | 112/469 [03:32<11:37, 1.95s/it]
324
  24%|██▍ | 113/469 [03:34<11:13, 1.89s/it]
325
  24%|██▍ | 114/469 [03:36<11:26, 1.93s/it]
326
  25%|██▍ | 115/469 [03:37<11:06, 1.88s/it]
327
  25%|██▍ | 116/469 [03:39<11:27, 1.95s/it]
328
  25%|██▍ | 117/469 [03:42<12:08, 2.07s/it]
329
  25%|██▌ | 118/469 [03:44<11:37, 1.99s/it]
330
  25%|██▌ | 119/469 [03:45<11:14, 1.93s/it]
331
  26%|██▌ | 120/469 [03:47<10:53, 1.87s/it]
332
  26%|██▌ | 121/469 [03:49<10:39, 1.84s/it]
333
  26%|██▌ | 122/469 [03:51<10:28, 1.81s/it]
334
  26%|██▌ | 123/469 [03:52<10:23, 1.80s/it]
335
  26%|██▋ | 124/469 [03:54<10:16, 1.79s/it]
336
  27%|██▋ | 125/469 [03:56<10:26, 1.82s/it]
337
  27%|██▋ | 126/469 [03:58<10:19, 1.81s/it]
338
  27%|██▋ | 127/469 [04:00<10:05, 1.77s/it]
339
  27%|██▋ | 128/469 [04:01<10:13, 1.80s/it]
340
  28%|██▊ | 129/469 [04:04<10:47, 1.91s/it]
341
  28%|██▊ | 130/469 [04:05<10:30, 1.86s/it]
342
  28%|██▊ | 131/469 [04:08<11:02, 1.96s/it]
343
  28%|██▊ | 132/469 [04:10<11:07, 1.98s/it]
344
  28%|██▊ | 133/469 [04:11<11:00, 1.97s/it]
345
  29%|██▊ | 134/469 [04:13<10:36, 1.90s/it]
346
  29%|██▉ | 135/469 [04:15<10:33, 1.90s/it]
347
  29%|██▉ | 136/469 [04:17<10:28, 1.89s/it]
348
  29%|██▉ | 137/469 [04:23<16:46, 3.03s/it]
349
  29%|██▉ | 138/469 [04:24<14:39, 2.66s/it]
350
  30%|██▉ | 139/469 [04:27<13:43, 2.49s/it]
351
  30%|██▉ | 140/469 [04:29<13:38, 2.49s/it]
352
  30%|███ | 141/469 [04:31<12:21, 2.26s/it]
353
  30%|███ | 142/469 [04:33<11:27, 2.10s/it]
354
  30%|███ | 143/469 [04:34<10:47, 1.98s/it]
355
  31%|███ | 144/469 [04:36<10:23, 1.92s/it]
356
  31%|███ | 145/469 [04:38<10:18, 1.91s/it]
357
  31%|███ | 146/469 [04:40<10:36, 1.97s/it]
358
  31%|███▏ | 147/469 [04:42<10:39, 1.99s/it]
359
  32%|███▏ | 148/469 [04:44<10:50, 2.03s/it]
360
  32%|███▏ | 149/469 [04:46<10:30, 1.97s/it]
361
  32%|███▏ | 150/469 [04:48<10:12, 1.92s/it]
362
  32%|███▏ | 151/469 [04:50<09:55, 1.87s/it]
363
  32%|███▏ | 152/469 [04:51<10:00, 1.89s/it]
364
  33%|███▎ | 153/469 [04:53<10:00, 1.90s/it]
365
  33%|███▎ | 154/469 [04:55<09:56, 1.89s/it]
366
  33%|███▎ | 155/469 [04:57<09:55, 1.90s/it]
367
  33%|███▎ | 156/469 [04:59<09:38, 1.85s/it]
368
  33%|███▎ | 157/469 [05:01<09:29, 1.83s/it]
369
  34%|███▎ | 158/469 [05:03<09:30, 1.84s/it]
370
  34%|███▍ | 159/469 [05:04<09:18, 1.80s/it]
371
  34%|███▍ | 160/469 [05:07<09:59, 1.94s/it]
372
  34%|███▍ | 161/469 [05:08<09:52, 1.93s/it]
373
  35%|███▍ | 162/469 [05:10<09:53, 1.93s/it]
374
  35%|███▍ | 163/469 [05:12<09:41, 1.90s/it]
375
  35%|███▍ | 164/469 [05:14<09:34, 1.88s/it]
376
  35%|███▌ | 165/469 [05:16<09:32, 1.88s/it]
377
  35%|███▌ | 166/469 [05:18<09:23, 1.86s/it]
378
  36%|███▌ | 167/469 [05:20<09:39, 1.92s/it]
379
  36%|███▌ | 168/469 [05:22<09:30, 1.90s/it]
380
  36%|███▌ | 169/469 [05:23<09:21, 1.87s/it]
381
  36%|███▌ | 170/469 [05:25<09:10, 1.84s/it]
382
  36%|███▋ | 171/469 [05:27<09:20, 1.88s/it]
383
  37%|███▋ | 172/469 [05:29<09:07, 1.84s/it]
384
  37%|███▋ | 173/469 [05:31<09:01, 1.83s/it]
385
  37%|███▋ | 174/469 [05:33<09:00, 1.83s/it]
386
  37%|███▋ | 175/469 [05:34<08:49, 1.80s/it]
387
  38%|███▊ | 176/469 [05:36<09:09, 1.88s/it]
388
  38%|███▊ | 177/469 [05:38<09:08, 1.88s/it]
389
  38%|███▊ | 178/469 [05:40<09:00, 1.86s/it]
390
  38%|███▊ | 179/469 [05:42<08:51, 1.83s/it]
391
  38%|███▊ | 180/469 [05:44<08:51, 1.84s/it]
392
  39%|███▊ | 181/469 [05:46<10:03, 2.10s/it]
393
  39%|███▉ | 182/469 [05:48<09:44, 2.04s/it]
394
  39%|███▉ | 183/469 [05:50<09:21, 1.96s/it]
395
  39%|███▉ | 184/469 [05:52<09:23, 1.98s/it]
396
  39%|███▉ | 185/469 [05:54<09:03, 1.92s/it]
397
  40%|███▉ | 186/469 [05:56<08:53, 1.88s/it]
398
  40%|███▉ | 187/469 [05:57<08:40, 1.85s/it]
399
  40%|████ | 188/469 [05:59<08:58, 1.92s/it]
400
  40%|████ | 189/469 [06:01<08:50, 1.89s/it]
401
  41%|████ | 190/469 [06:03<08:55, 1.92s/it]
402
  41%|████ | 191/469 [06:05<08:49, 1.91s/it]
403
  41%|████ | 192/469 [06:07<08:48, 1.91s/it]
404
  41%|████ | 193/469 [06:09<08:43, 1.90s/it]
405
  41%|████▏ | 194/469 [06:11<08:31, 1.86s/it]
406
  42%|████▏ | 195/469 [06:12<08:18, 1.82s/it]
407
  42%|████▏ | 196/469 [06:14<08:06, 1.78s/it]
408
  42%|████▏ | 197/469 [06:16<08:05, 1.79s/it]
409
  42%|████▏ | 198/469 [06:18<08:14, 1.83s/it]
410
  42%|████▏ | 199/469 [06:20<08:11, 1.82s/it]
411
  43%|████▎ | 200/469 [06:22<08:36, 1.92s/it]
412
  43%|████▎ | 201/469 [06:24<08:25, 1.89s/it]
413
  43%|████▎ | 202/469 [06:26<08:28, 1.91s/it]
414
  43%|████▎ | 203/469 [06:27<08:24, 1.90s/it]
415
  43%|████▎ | 204/469 [06:29<08:31, 1.93s/it]
416
  44%|████▎ | 205/469 [06:31<08:20, 1.89s/it]
417
  44%|████▍ | 206/469 [06:33<08:09, 1.86s/it]
418
  44%|████▍ | 207/469 [06:35<08:13, 1.89s/it]
419
  44%|████▍ | 208/469 [06:37<08:21, 1.92s/it]
420
  45%|████▍ | 209/469 [06:39<08:55, 2.06s/it]
421
  45%|████▍ | 210/469 [06:41<08:32, 1.98s/it]
422
  45%|████▍ | 211/469 [06:43<08:21, 1.94s/it]
423
  45%|████▌ | 212/469 [06:45<08:19, 1.94s/it]
424
  45%|████▌ | 213/469 [06:47<08:13, 1.93s/it]
425
  46%|████▌ | 214/469 [06:49<08:23, 1.98s/it]
426
  46%|████▌ | 215/469 [06:51<08:15, 1.95s/it]
427
  46%|████▌ | 216/469 [06:53<08:02, 1.91s/it]
428
  46%|████▋ | 217/469 [06:55<07:57, 1.90s/it]
429
  46%|████▋ | 218/469 [06:56<07:51, 1.88s/it]
430
  47%|████▋ | 219/469 [06:58<07:40, 1.84s/it]
431
  47%|████▋ | 220/469 [07:00<07:52, 1.90s/it]
432
  47%|████▋ | 221/469 [07:02<07:44, 1.87s/it]
433
  47%|████▋ | 222/469 [07:04<08:09, 1.98s/it]
434
  48%|████▊ | 223/469 [07:07<09:19, 2.27s/it]
435
  48%|████▊ | 224/469 [07:09<08:42, 2.13s/it]
436
  48%|████▊ | 225/469 [07:11<08:22, 2.06s/it]
437
  48%|████▊ | 226/469 [07:13<08:01, 1.98s/it]
438
  48%|████▊ | 227/469 [07:15<07:51, 1.95s/it]
439
  49%|████▊ | 228/469 [07:16<07:43, 1.92s/it]
440
  49%|████▉ | 229/469 [07:18<07:35, 1.90s/it]
441
  49%|████▉ | 230/469 [07:20<07:28, 1.88s/it]
442
  49%|████▉ | 231/469 [07:22<07:44, 1.95s/it]
443
  49%|████▉ | 232/469 [07:24<07:32, 1.91s/it]
444
  50%|████▉ | 233/469 [07:26<07:24, 1.88s/it]
445
  50%|████▉ | 234/469 [07:28<07:26, 1.90s/it]
446
  50%|█████ | 235/469 [07:30<07:18, 1.88s/it]
447
  50%|█████ | 236/469 [07:32<07:21, 1.89s/it]
448
  51%|█████ | 237/469 [07:33<07:15, 1.88s/it]
449
  51%|█████ | 238/469 [07:35<07:31, 1.95s/it]
450
  51%|█████ | 239/469 [07:37<07:25, 1.94s/it]
451
  51%|█████ | 240/469 [07:39<07:26, 1.95s/it]
452
  51%|█████▏ | 241/469 [07:42<08:39, 2.28s/it]
453
  52%|█████▏ | 242/469 [07:44<08:17, 2.19s/it]
454
  52%|█████▏ | 243/469 [07:47<08:19, 2.21s/it]
455
  52%|█████▏ | 244/469 [07:49<07:57, 2.12s/it]
456
  52%|█████▏ | 245/469 [07:50<07:35, 2.03s/it]
457
  52%|█████▏ | 246/469 [07:52<07:26, 2.00s/it]
458
  53%|█████▎ | 247/469 [07:54<07:19, 1.98s/it]
459
  53%|█████▎ | 248/469 [07:56<07:11, 1.95s/it]
460
  53%|█████▎ | 249/469 [07:58<07:15, 1.98s/it]
461
  53%|█████▎ | 250/469 [08:00<06:58, 1.91s/it]
462
  54%|█████▎ | 251/469 [08:02<06:51, 1.89s/it]
463
  54%|█████▎ | 252/469 [08:04<06:42, 1.85s/it]
464
  54%|█████▍ | 253/469 [08:05<06:33, 1.82s/it]
465
  54%|█████▍ | 254/469 [08:07<06:29, 1.81s/it]
466
  54%|█████▍ | 255/469 [08:09<06:32, 1.83s/it]
467
  55%|█████▍ | 256/469 [08:11<06:24, 1.80s/it]
468
  55%|█████▍ | 257/469 [08:12<06:18, 1.79s/it]
469
  55%|█████▌ | 258/469 [08:14<06:25, 1.83s/it]
470
  55%|█████▌ | 259/469 [08:16<06:25, 1.83s/it]
471
  55%|█████▌ | 260/469 [08:18<06:19, 1.82s/it]
472
  56%|█████▌ | 261/469 [08:20<06:42, 1.94s/it]
473
  56%|█████▌ | 262/469 [08:22<06:32, 1.90s/it]
474
  56%|█████▌ | 263/469 [08:24<06:25, 1.87s/it]
475
  56%|█████▋ | 264/469 [08:26<06:17, 1.84s/it]
476
  57%|█████▋ | 265/469 [08:28<06:20, 1.87s/it]
477
  57%|█████▋ | 266/469 [08:29<06:17, 1.86s/it]
478
  57%|█████▋ | 267/469 [08:31<06:16, 1.86s/it]
479
  57%|█████▋ | 268/469 [08:33<06:12, 1.86s/it]
480
  57%|█████▋ | 269/469 [08:35<06:41, 2.01s/it]
481
  58%|█████▊ | 270/469 [08:37<06:24, 1.93s/it]
482
  58%|█████▊ | 271/469 [08:39<06:10, 1.87s/it]
483
  58%|█████▊ | 272/469 [08:41<06:09, 1.87s/it]
484
  58%|█████▊ | 273/469 [08:43<06:24, 1.96s/it]
485
  58%|█████▊ | 274/469 [08:45<06:11, 1.90s/it]
486
  59%|█████▊ | 275/469 [08:47<06:10, 1.91s/it]
487
  59%|█████▉ | 276/469 [08:48<06:03, 1.88s/it]
488
  59%|█████▉ | 277/469 [08:50<06:01, 1.88s/it]
489
  59%|█████▉ | 278/469 [08:52<06:08, 1.93s/it]
490
  59%|█████▉ | 279/469 [08:54<05:57, 1.88s/it]
491
  60%|█████▉ | 280/469 [08:56<06:11, 1.97s/it]
492
  60%|█████▉ | 281/469 [08:58<06:07, 1.95s/it]
493
  60%|██████ | 282/469 [09:00<05:57, 1.91s/it]
494
  60%|██████ | 283/469 [09:02<05:48, 1.87s/it]
495
  61%|██████ | 284/469 [09:04<05:44, 1.86s/it]
496
  61%|██████ | 285/469 [09:06<05:45, 1.88s/it]
497
  61%|██████ | 286/469 [09:08<06:11, 2.03s/it]
498
  61%|██████ | 287/469 [09:10<06:16, 2.07s/it]
499
  61%|██████▏ | 288/469 [09:12<05:59, 1.99s/it]
500
  62%|██████▏ | 289/469 [09:14<05:50, 1.95s/it]
501
  62%|██████▏ | 290/469 [09:16<05:39, 1.90s/it]
502
  62%|██████▏ | 291/469 [09:17<05:28, 1.85s/it]
503
  62%|██████▏ | 292/469 [09:19<05:25, 1.84s/it]
504
  62%|██████▏ | 293/469 [09:21<05:23, 1.84s/it]
505
  63%|██████▎ | 294/469 [09:23<05:31, 1.89s/it]
506
  63%|██████▎ | 295/469 [09:25<05:20, 1.84s/it]
507
  63%|██████▎ | 296/469 [09:26<05:13, 1.81s/it]
508
  63%|██████▎ | 297/469 [09:28<05:08, 1.80s/it]
509
  64%|██████▎ | 298/469 [09:30<05:04, 1.78s/it]
510
  64%|██████▍ | 299/469 [09:32<05:05, 1.80s/it]
511
  64%|██████▍ | 300/469 [09:34<05:08, 1.82s/it]
512
  64%|██████▍ | 301/469 [09:36<05:07, 1.83s/it]
513
  64%|██████▍ | 302/469 [09:37<05:04, 1.82s/it]
514
  65%|██████▍ | 303/469 [09:39<05:06, 1.84s/it]
515
  65%|██████▍ | 304/469 [09:41<05:06, 1.86s/it]
516
  65%|██████▌ | 305/469 [09:43<05:02, 1.84s/it]
517
  65%|██████▌ | 306/469 [09:45<05:01, 1.85s/it]
518
  65%|██████▌ | 307/469 [09:47<05:00, 1.85s/it]
519
  66%|██████▌ | 308/469 [09:49<05:02, 1.88s/it]
520
  66%|██████▌ | 309/469 [09:50<05:00, 1.88s/it]
521
  66%|██████▌ | 310/469 [09:52<04:53, 1.85s/it]
522
  66%|██████▋ | 311/469 [09:54<04:54, 1.86s/it]
523
  67%|██████▋ | 312/469 [09:56<04:58, 1.90s/it]
524
  67%|██████▋ | 313/469 [09:58<04:50, 1.86s/it]
525
  67%|██████▋ | 314/469 [10:00<04:51, 1.88s/it]
526
  67%|██████▋ | 315/469 [10:02<04:52, 1.90s/it]
527
  67%|██████▋ | 316/469 [10:04<04:53, 1.92s/it]
528
  68%|██████▊ | 317/469 [10:06<04:45, 1.88s/it]
529
  68%|██████▊ | 318/469 [10:08<04:57, 1.97s/it]
530
  68%|██████▊ | 319/469 [10:09<04:46, 1.91s/it]
531
  68%|██████▊ | 320/469 [10:11<04:46, 1.92s/it]
532
  68%|██████▊ | 321/469 [10:14<05:08, 2.09s/it]
533
  69%|██████▊ | 322/469 [10:16<05:13, 2.14s/it]
534
  69%|██████▉ | 323/469 [10:18<05:02, 2.07s/it]
535
  69%|██████▉ | 324/469 [10:20<04:52, 2.02s/it]
536
  69%|██████▉ | 325/469 [10:22<04:42, 1.96s/it]
537
  70%|██████▉ | 326/469 [10:24<04:34, 1.92s/it]
538
  70%|██████▉ | 327/469 [10:26<04:41, 1.98s/it]
539
  70%|██████▉ | 328/469 [10:28<04:31, 1.93s/it]
540
  70%|███████ | 329/469 [10:29<04:24, 1.89s/it]
541
  70%|███████ | 330/469 [10:31<04:19, 1.87s/it]
542
  71%|███████ | 331/469 [10:33<04:13, 1.84s/it]
543
  71%|███████ | 332/469 [10:35<04:24, 1.93s/it]
544
  71%|███████ | 333/469 [10:37<04:16, 1.89s/it]
545
  71%|███████ | 334/469 [10:42<06:23, 2.84s/it]
546
  71%|███████▏ | 335/469 [10:44<05:35, 2.51s/it]
547
  72%|███████▏ | 336/469 [10:45<05:04, 2.29s/it]
548
  72%|███████▏ | 337/469 [10:47<04:42, 2.14s/it]
549
  72%|███████▏ | 338/469 [10:49<04:41, 2.15s/it]
550
  72%|███████▏ | 339/469 [10:51<04:33, 2.11s/it]
551
  72%|███████▏ | 340/469 [10:53<04:21, 2.03s/it]
552
  73%|███████▎ | 341/469 [10:55<04:17, 2.01s/it]
553
  73%|███████▎ | 342/469 [10:57<04:19, 2.04s/it]
554
  73%|███████▎ | 343/469 [10:59<04:18, 2.05s/it]
555
  73%|███████▎ | 344/469 [11:01<04:04, 1.96s/it]
556
  74%|███████▎ | 345/469 [11:03<04:00, 1.94s/it]
557
  74%|███████▍ | 346/469 [11:05<03:52, 1.89s/it]
558
  74%|███████▍ | 347/469 [11:07<03:45, 1.85s/it]
559
  74%|███████▍ | 348/469 [11:08<03:39, 1.82s/it]
560
  74%|███████▍ | 349/469 [11:10<03:38, 1.82s/it]
561
  75%|███████▍ | 350/469 [11:12<03:42, 1.87s/it]
562
  75%|███████▍ | 351/469 [11:14<03:44, 1.90s/it]
563
  75%|███████▌ | 352/469 [11:16<03:40, 1.88s/it]
564
  75%|███████▌ | 353/469 [11:25<07:38, 3.95s/it]
565
  75%|███████▌ | 354/469 [11:27<06:26, 3.36s/it]
566
  76%|███████▌ | 355/469 [11:29<05:30, 2.90s/it]
567
  76%|███████▌ | 356/469 [11:31<05:08, 2.73s/it]
568
  76%|███████▌ | 357/469 [11:33<04:45, 2.55s/it]
569
  76%|███████▋ | 358/469 [11:35<04:23, 2.37s/it]
570
  77%|██��████▋ | 359/469 [11:37<04:10, 2.28s/it]
571
  77%|███████▋ | 360/469 [11:39<03:51, 2.12s/it]
572
  77%|███████▋ | 361/469 [11:41<03:37, 2.02s/it]
573
  77%|███████▋ | 362/469 [12:00<13:03, 7.32s/it]
574
  77%|███████▋ | 363/469 [12:02<10:03, 5.69s/it]
575
  78%|███████▊ | 364/469 [12:04<07:55, 4.53s/it]
576
  78%|███████▊ | 365/469 [12:06<06:27, 3.73s/it]
577
  78%|███████▊ | 366/469 [12:08<05:26, 3.17s/it]
578
  78%|███████▊ | 367/469 [12:09<04:40, 2.75s/it]
579
  78%|███████▊ | 368/469 [12:12<04:21, 2.59s/it]
580
  79%|███████▊ | 369/469 [12:13<03:55, 2.36s/it]
581
  79%|███████▉ | 370/469 [12:15<03:43, 2.25s/it]
582
  79%|███████▉ | 371/469 [12:17<03:28, 2.13s/it]
583
  79%|███████▉ | 372/469 [12:19<03:21, 2.07s/it]
584
  80%|███████▉ | 373/469 [12:21<03:10, 1.99s/it]
585
  80%|███████▉ | 374/469 [12:23<03:02, 1.92s/it]
586
  80%|███████▉ | 375/469 [12:25<02:55, 1.86s/it]
587
  80%|████████ | 376/469 [12:26<02:51, 1.84s/it]
588
  80%|████████ | 377/469 [12:28<02:51, 1.86s/it]
589
  81%|████████ | 378/469 [12:30<02:47, 1.84s/it]
590
  81%|████████ | 379/469 [12:32<02:43, 1.81s/it]
591
  81%|████████ | 380/469 [12:34<02:48, 1.89s/it]
592
  81%|████████ | 381/469 [12:36<02:45, 1.88s/it]
593
  81%|████████▏ | 382/469 [12:38<02:43, 1.88s/it]
594
  82%|████████▏ | 383/469 [12:39<02:38, 1.85s/it]
595
  82%|████████▏ | 384/469 [12:41<02:36, 1.84s/it]
596
  82%|████████▏ | 385/469 [12:43<02:37, 1.88s/it]
597
  82%|████████▏ | 386/469 [12:45<02:35, 1.87s/it]
598
  83%|████████▎ | 387/469 [12:47<02:45, 2.02s/it]
599
  83%|████████▎ | 388/469 [12:49<02:37, 1.94s/it]
600
  83%|████████▎ | 389/469 [12:51<02:30, 1.88s/it]
601
  83%|████████▎ | 390/469 [12:53<02:32, 1.93s/it]
602
  83%|████████▎ | 391/469 [12:55<02:29, 1.92s/it]
603
  84%|████████▎ | 392/469 [12:57<02:23, 1.87s/it]
604
  84%|████████▍ | 393/469 [13:00<03:00, 2.38s/it]
605
  84%|████████▍ | 394/469 [13:02<02:47, 2.23s/it]
606
  84%|████████▍ | 395/469 [13:04<02:36, 2.12s/it]
607
  84%|████████▍ | 396/469 [13:06<02:27, 2.02s/it]
608
  85%|████████▍ | 397/469 [13:08<02:24, 2.00s/it]
609
  85%|████████▍ | 398/469 [13:09<02:19, 1.96s/it]
610
  85%|████████▌ | 399/469 [13:11<02:14, 1.92s/it]
611
  85%|████████▌ | 400/469 [13:13<02:08, 1.86s/it]
612
  86%|████████▌ | 401/469 [13:15<02:05, 1.85s/it]
613
  86%|████████▌ | 402/469 [13:17<02:02, 1.84s/it]
614
  86%|████████▌ | 403/469 [13:18<01:59, 1.81s/it]
615
  86%|████████▌ | 404/469 [13:20<02:02, 1.89s/it]
616
  86%|████████▋ | 405/469 [13:22<01:59, 1.87s/it]
617
  87%|████████▋ | 406/469 [13:24<01:56, 1.85s/it]
618
  87%|████████▋ | 407/469 [13:26<01:52, 1.81s/it]
619
  87%|████████▋ | 408/469 [13:28<01:50, 1.81s/it]
620
  87%|████████▋ | 409/469 [13:30<01:50, 1.84s/it]
621
  87%|████████▋ | 410/469 [13:32<01:55, 1.95s/it]
622
  88%|████████▊ | 411/469 [13:33<01:49, 1.89s/it]
623
  88%|████████▊ | 412/469 [13:35<01:48, 1.90s/it]
624
  88%|████████▊ | 413/469 [13:37<01:45, 1.89s/it]
625
  88%|████████▊ | 414/469 [13:39<01:42, 1.86s/it]
626
  88%|████████▊ | 415/469 [13:41<01:38, 1.82s/it]
627
  89%|████████▊ | 416/469 [13:43<01:39, 1.87s/it]
628
  89%|████████▉ | 417/469 [13:45<01:38, 1.90s/it]
629
  89%|████████▉ | 418/469 [13:50<02:30, 2.96s/it]
630
  89%|████████▉ | 419/469 [13:52<02:12, 2.65s/it]
631
  90%|████████▉ | 420/469 [13:54<02:01, 2.47s/it]
632
  90%|████████▉ | 421/469 [13:56<01:48, 2.26s/it]
633
  90%|████████▉ | 422/469 [13:58<01:39, 2.12s/it]
634
  90%|█████████ | 423/469 [14:00<01:33, 2.02s/it]
635
  90%|█████████ | 424/469 [14:01<01:29, 2.00s/it]
636
  91%|█████████ | 425/469 [14:03<01:24, 1.93s/it]
637
  91%|█████████ | 426/469 [14:05<01:22, 1.93s/it]
638
  91%|█████████ | 427/469 [14:07<01:19, 1.90s/it]
639
  91%|█████████▏| 428/469 [14:10<01:30, 2.21s/it]
640
  91%|█████████▏| 429/469 [14:12<01:23, 2.09s/it]
641
  92%|█████████▏| 430/469 [14:13<01:17, 2.00s/it]
642
  92%|█████████▏| 431/469 [14:16<01:17, 2.05s/it]
643
  92%|█████████▏| 432/469 [14:17<01:12, 1.96s/it]
644
  92%|█████████▏| 433/469 [14:19<01:09, 1.93s/it]
645
  93%|█████████▎| 434/469 [14:21<01:06, 1.89s/it]
646
  93%|█████████▎| 435/469 [14:23<01:03, 1.88s/it]
647
  93%|█████████▎| 436/469 [14:25<01:06, 2.01s/it]
648
  93%|█████████▎| 437/469 [14:27<01:02, 1.95s/it]
649
  93%|█████████▎| 438/469 [14:29<00:59, 1.93s/it]
650
  94%|█████████▎| 439/469 [14:31<00:57, 1.91s/it]
651
  94%|█████████▍| 440/469 [14:33<00:59, 2.04s/it]
652
  94%|█████████▍| 441/469 [14:35<00:55, 1.99s/it]
653
  94%|█████████▍| 442/469 [14:37<00:53, 1.97s/it]
654
  94%|█████████▍| 443/469 [14:39<00:49, 1.92s/it]
655
  95%|█████████▍| 444/469 [14:41<00:47, 1.89s/it]
656
  95%|█████████▍| 445/469 [14:42<00:44, 1.87s/it]
657
  95%|█████████▌| 446/469 [14:44<00:44, 1.94s/it]
658
  95%|█████████▌| 447/469 [14:46<00:41, 1.88s/it]
659
  96%|█████████▌| 448/469 [14:48<00:38, 1.83s/it]
660
  96%|█████████▌| 449/469 [14:50<00:37, 1.89s/it]
661
  96%|█████████▌| 450/469 [14:52<00:34, 1.84s/it]
662
  96%|█████████▌| 451/469 [14:53<00:32, 1.81s/it]
663
  96%|█████████▋| 452/469 [14:55<00:30, 1.80s/it]
664
  97%|█████████▋| 453/469 [14:57<00:28, 1.77s/it]
665
  97%|█████████▋| 454/469 [14:59<00:27, 1.85s/it]
666
  97%|█████████▋| 455/469 [15:01<00:25, 1.84s/it]
667
  97%|█████████▋| 456/469 [15:03<00:24, 1.92s/it]
668
  97%|█████████▋| 457/469 [15:05<00:22, 1.88s/it]
669
  98%|█████████▊| 458/469 [15:07<00:20, 1.88s/it]
670
  98%|█████████▊| 459/469 [15:09<00:19, 1.92s/it]
671
  98%|█████████▊| 460/469 [15:10<00:17, 1.92s/it]
672
  98%|█████████▊| 461/469 [15:12<00:15, 1.90s/it]
673
  99%|█████████▊| 462/469 [15:15<00:14, 2.11s/it]
674
  99%|█████████▊| 463/469 [15:17<00:13, 2.18s/it]
675
  99%|█████████▉| 464/469 [15:19<00:10, 2.09s/it]
676
  99%|█████████▉| 465/469 [15:21<00:08, 2.07s/it]
677
  99%|█████████▉| 466/469 [15:23<00:06, 2.02s/it]
 
 
 
 
 
 
 
 
1
+ 2026-03-23 16:28:17.115582: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2
+ 2026-03-23 16:28:26.392554: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
3
+ To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
4
+ 2026-03-23 16:28:26.438747: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
5
+ 2026-03-23 16:28:26.438805: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: d9a147cbff28e342e3a570e4cd1afa4e-taskrole1-0
6
+ 2026-03-23 16:28:26.438830: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: d9a147cbff28e342e3a570e4cd1afa4e-taskrole1-0
7
+ 2026-03-23 16:28:26.438915: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
8
+ 2026-03-23 16:28:26.438959: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 535.154.5
9
+ 2026-03-23 16:28:26.776048: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
10
+ warming up TensorFlow...
11
+
12
  0%| | 0/1 [00:00<?, ?it/s]2026-03-23 16:28:27.563125: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
13
+
14
+ computing reference batch activations...
15
+
16
  0%| | 0/211 [00:00<?, ?it/s]
17
  0%| | 1/211 [00:02<10:16, 2.94s/it]
18
  1%| | 2/211 [00:04<08:15, 2.37s/it]
19
  1%|▏ | 3/211 [00:06<07:38, 2.21s/it]
20
  2%|▏ | 4/211 [00:08<07:20, 2.13s/it]
21
  2%|▏ | 5/211 [00:11<07:25, 2.16s/it]
22
  3%|▎ | 6/211 [00:12<06:59, 2.04s/it]
23
  3%|▎ | 7/211 [00:15<07:15, 2.13s/it]
24
  4%|▍ | 8/211 [00:17<06:55, 2.05s/it]
25
  4%|▍ | 9/211 [00:19<06:53, 2.05s/it]
26
  5%|▍ | 10/211 [00:21<06:37, 1.98s/it]
27
  5%|▌ | 11/211 [00:22<06:18, 1.89s/it]
28
  6%|▌ | 12/211 [00:24<06:13, 1.88s/it]
29
  6%|▌ | 13/211 [00:26<06:08, 1.86s/it]
30
  7%|▋ | 14/211 [00:28<06:02, 1.84s/it]
31
  7%|▋ | 15/211 [00:29<05:57, 1.82s/it]
32
  8%|▊ | 16/211 [00:31<05:55, 1.82s/it]
33
  8%|▊ | 17/211 [00:33<06:00, 1.86s/it]
34
  9%|▊ | 18/211 [00:35<06:01, 1.87s/it]
35
  9%|▉ | 19/211 [00:37<05:59, 1.87s/it]
36
  9%|▉ | 20/211 [00:39<05:51, 1.84s/it]
37
  10%|▉ | 21/211 [00:41<05:45, 1.82s/it]
38
  10%|█ | 22/211 [00:42<05:43, 1.82s/it]
39
  11%|█ | 23/211 [00:44<05:41, 1.82s/it]
40
  11%|█▏ | 24/211 [00:47<06:23, 2.05s/it]
41
  12%|█▏ | 25/211 [00:50<07:07, 2.30s/it]
42
  12%|█▏ | 26/211 [00:51<06:37, 2.15s/it]
43
  13%|█▎ | 27/211 [00:53<06:14, 2.04s/it]
44
  13%|█▎ | 28/211 [00:55<05:56, 1.95s/it]
45
  14%|█▎ | 29/211 [00:57<05:56, 1.96s/it]
46
  14%|█▍ | 30/211 [00:59<05:49, 1.93s/it]
47
  15%|█▍ | 31/211 [01:01<05:49, 1.94s/it]
48
  15%|█▌ | 32/211 [01:03<05:41, 1.91s/it]
49
  16%|█▌ | 33/211 [01:05<05:42, 1.93s/it]
50
  16%|█▌ | 34/211 [01:07<05:49, 1.97s/it]
51
  17%|█▋ | 35/211 [01:09<06:11, 2.11s/it]
52
  17%|█▋ | 36/211 [01:11<06:00, 2.06s/it]
53
  18%|█▊ | 37/211 [01:13<05:46, 1.99s/it]
54
  18%|█▊ | 38/211 [01:15<05:57, 2.07s/it]
55
  18%|█▊ | 39/211 [01:17<05:44, 2.01s/it]
56
  19%|█▉ | 40/211 [01:19<05:31, 1.94s/it]
57
  19%|█▉ | 41/211 [01:21<05:24, 1.91s/it]
58
  20%|█▉ | 42/211 [01:22<05:19, 1.89s/it]
59
  20%|██ | 43/211 [01:24<05:20, 1.91s/it]
60
  21%|██ | 44/211 [01:26<05:09, 1.86s/it]
61
  21%|██▏ | 45/211 [01:29<05:50, 2.11s/it]
62
  22%|██▏ | 46/211 [01:31<05:46, 2.10s/it]
63
  22%|██▏ | 47/211 [01:33<05:33, 2.03s/it]
64
  23%|██▎ | 48/211 [01:35<05:17, 1.95s/it]
65
  23%|██▎ | 49/211 [01:36<05:14, 1.94s/it]
66
  24%|██▎ | 50/211 [01:38<05:03, 1.88s/it]
67
  24%|██▍ | 51/211 [01:40<05:19, 2.00s/it]
68
  25%|██▍ | 52/211 [01:42<05:13, 1.97s/it]
69
  25%|██▌ | 53/211 [01:45<05:19, 2.02s/it]
70
  26%|██▌ | 54/211 [01:46<05:05, 1.95s/it]
71
  26%|██▌ | 55/211 [01:48<04:55, 1.89s/it]
72
  27%|██▋ | 56/211 [01:50<04:56, 1.92s/it]
73
  27%|██▋ | 57/211 [01:52<04:48, 1.88s/it]
74
  27%|██▋ | 58/211 [01:54<04:44, 1.86s/it]
75
  28%|██▊ | 59/211 [01:56<04:47, 1.89s/it]
76
  28%|██▊ | 60/211 [01:58<04:57, 1.97s/it]
77
  29%|██▉ | 61/211 [02:00<04:46, 1.91s/it]
78
  29%|██▉ | 62/211 [02:01<04:38, 1.87s/it]
79
  30%|██▉ | 63/211 [02:04<05:05, 2.07s/it]
80
  30%|███ | 64/211 [02:06<04:59, 2.04s/it]
81
  31%|███ | 65/211 [02:08<05:05, 2.09s/it]
82
  31%|███▏ | 66/211 [02:10<04:48, 1.99s/it]
83
  32%|███▏ | 67/211 [02:12<04:37, 1.93s/it]
84
  32%|███▏ | 68/211 [02:13<04:27, 1.87s/it]
85
  33%|███▎ | 69/211 [02:15<04:25, 1.87s/it]
86
  33%|███▎ | 70/211 [02:17<04:18, 1.83s/it]
87
  34%|███▎ | 71/211 [02:19<04:15, 1.83s/it]
88
  34%|███▍ | 72/211 [02:20<04:10, 1.80s/it]
89
  35%|███▍ | 73/211 [02:22<04:08, 1.80s/it]
90
  35%|███▌ | 74/211 [02:24<04:12, 1.84s/it]
91
  36%|███▌ | 75/211 [02:26<04:10, 1.84s/it]
92
  36%|███▌ | 76/211 [02:28<04:06, 1.83s/it]
93
  36%|███▋ | 77/211 [02:30<04:11, 1.88s/it]
94
  37%|███▋ | 78/211 [02:32<04:07, 1.86s/it]
95
  37%|███▋ | 79/211 [02:34<04:08, 1.88s/it]
96
  38%|███▊ | 80/211 [02:35<04:06, 1.88s/it]
97
  38%|███▊ | 81/211 [02:37<04:11, 1.93s/it]
98
  39%|███▉ | 82/211 [02:39<04:04, 1.89s/it]
99
  39%|███▉ | 83/211 [02:41<04:06, 1.92s/it]
100
  40%|███▉ | 84/211 [02:43<04:03, 1.92s/it]
101
  40%|████ | 85/211 [02:45<03:55, 1.87s/it]
102
  41%|████ | 86/211 [02:47<03:50, 1.84s/it]
103
  41%|████ | 87/211 [02:48<03:45, 1.82s/it]
104
  42%|████▏ | 88/211 [02:50<03:41, 1.80s/it]
105
  42%|████▏ | 89/211 [02:52<03:36, 1.78s/it]
106
  43%|████▎ | 90/211 [02:54<03:51, 1.91s/it]
107
  43%|████▎ | 91/211 [02:56<03:45, 1.88s/it]
108
  44%|████▎ | 92/211 [02:58<03:38, 1.84s/it]
109
  44%|████▍ | 93/211 [03:00<03:37, 1.85s/it]
110
  45%|████▍ | 94/211 [03:02<03:46, 1.94s/it]
111
  45%|████▌ | 95/211 [03:03<03:36, 1.87s/it]
112
  45%|████▌ | 96/211 [03:05<03:34, 1.87s/it]
113
  46%|████▌ | 97/211 [03:08<03:58, 2.09s/it]
114
  46%|████▋ | 98/211 [03:10<03:45, 1.99s/it]
115
  47%|████▋ | 99/211 [03:11<03:34, 1.91s/it]
116
  47%|████▋ | 100/211 [03:13<03:31, 1.91s/it]
117
  48%|████▊ | 101/211 [03:15<03:28, 1.89s/it]
118
  48%|████▊ | 102/211 [03:21<05:46, 3.18s/it]
119
  49%|████▉ | 103/211 [03:23<04:58, 2.76s/it]
120
  49%|████▉ | 104/211 [03:25<04:22, 2.45s/it]
121
  50%|████▉ | 105/211 [03:27<04:03, 2.30s/it]
122
  50%|█████ | 106/211 [03:29<03:45, 2.15s/it]
123
  51%|█████ | 107/211 [03:31<03:35, 2.07s/it]
124
  51%|█████ | 108/211 [03:32<03:29, 2.03s/it]
125
  52%|█████▏ | 109/211 [03:34<03:21, 1.97s/it]
126
  52%|█████▏ | 110/211 [03:37<03:36, 2.15s/it]
127
  53%|█████▎ | 111/211 [03:39<03:23, 2.03s/it]
128
  53%|█████▎ | 112/211 [03:41<03:17, 1.99s/it]
129
  54%|█████▎ | 113/211 [03:42<03:10, 1.95s/it]
130
  54%|█████▍ | 114/211 [03:44<03:03, 1.90s/it]
131
  55%|█████▍ | 115/211 [03:46<03:02, 1.91s/it]
132
  55%|█████▍ | 116/211 [03:48<02:59, 1.89s/it]
133
  55%|█████▌ | 117/211 [03:50<03:03, 1.96s/it]
134
  56%|█████▌ | 118/211 [03:52<02:56, 1.90s/it]
135
  56%|█████▋ | 119/211 [03:54<02:51, 1.86s/it]
136
  57%|█████▋ | 120/211 [03:55<02:49, 1.86s/it]
137
  57%|█████▋ | 121/211 [03:57<02:52, 1.91s/it]
138
  58%|█████▊ | 122/211 [03:59<02:47, 1.89s/it]
139
  58%|█████▊ | 123/211 [04:01<02:51, 1.94s/it]
140
  59%|█████▉ | 124/211 [04:03<02:47, 1.93s/it]
141
  59%|█████▉ | 125/211 [04:05<02:40, 1.87s/it]
142
  60%|█████▉ | 126/211 [04:07<02:46, 1.96s/it]
143
  60%|██████ | 127/211 [04:09<02:43, 1.95s/it]
144
  61%|██████ | 128/211 [04:11<02:42, 1.95s/it]
145
  61%|██████ | 129/211 [04:13<02:47, 2.04s/it]
146
  62%|██████▏ | 130/211 [04:15<02:41, 2.00s/it]
147
  62%|██████▏ | 131/211 [04:17<02:39, 2.00s/it]
148
  63%|██████▎ | 132/211 [04:19<02:36, 1.98s/it]
149
  63%|██████▎ | 133/211 [04:21<02:31, 1.94s/it]
150
  64%|██████▎ | 134/211 [04:23<02:26, 1.91s/it]
151
  64%|██████▍ | 135/211 [04:25<02:21, 1.86s/it]
152
  64%|██████▍ | 136/211 [04:26<02:20, 1.87s/it]
153
  65%|██████▍ | 137/211 [04:29<02:23, 1.94s/it]
154
  65%|██████▌ | 138/211 [04:30<02:18, 1.90s/it]
155
  66%|██████▌ | 139/211 [04:32<02:13, 1.86s/it]
156
  66%|██████▋ | 140/211 [04:34<02:10, 1.84s/it]
157
  67%|██████▋ | 141/211 [04:36<02:07, 1.82s/it]
158
  67%|██████▋ | 142/211 [04:38<02:06, 1.83s/it]
159
  68%|██████▊ | 143/211 [04:39<02:06, 1.86s/it]
160
  68%|██��███▊ | 144/211 [04:41<02:02, 1.83s/it]
161
  69%|██████▊ | 145/211 [04:43<02:00, 1.82s/it]
162
  69%|██████▉ | 146/211 [04:45<01:58, 1.82s/it]
163
  70%|██████▉ | 147/211 [04:47<02:06, 1.98s/it]
164
  70%|███████ | 148/211 [04:49<02:07, 2.02s/it]
165
  71%|███████ | 149/211 [04:51<02:02, 1.97s/it]
166
  71%|███████ | 150/211 [04:53<01:59, 1.96s/it]
167
  72%|███████▏ | 151/211 [04:55<01:56, 1.95s/it]
168
  72%|███████▏ | 152/211 [04:57<01:54, 1.94s/it]
169
  73%|███████▎ | 153/211 [04:59<01:49, 1.88s/it]
170
  73%|███████▎ | 154/211 [05:01<01:56, 2.04s/it]
171
  73%|███████▎ | 155/211 [05:03<01:50, 1.97s/it]
172
  74%|███████▍ | 156/211 [05:05<01:48, 1.97s/it]
173
  74%|███████▍ | 157/211 [05:07<01:46, 1.96s/it]
174
  75%|███████▍ | 158/211 [05:09<01:41, 1.92s/it]
175
  75%|███████▌ | 159/211 [05:10<01:38, 1.90s/it]
176
  76%|███████▌ | 160/211 [05:12<01:35, 1.88s/it]
177
  76%|███████▋ | 161/211 [05:14<01:32, 1.85s/it]
178
  77%|███████▋ | 162/211 [05:16<01:32, 1.89s/it]
179
  77%|███████▋ | 163/211 [05:18<01:31, 1.90s/it]
180
  78%|███████▊ | 164/211 [05:20<01:28, 1.88s/it]
181
  78%|███████▊ | 165/211 [05:22<01:26, 1.88s/it]
182
  79%|███████▊ | 166/211 [05:24<01:25, 1.89s/it]
183
  79%|███████▉ | 167/211 [05:25<01:22, 1.88s/it]
184
  80%|███████▉ | 168/211 [05:27<01:22, 1.92s/it]
185
  80%|████████ | 169/211 [05:29<01:21, 1.94s/it]
186
  81%|████████ | 170/211 [05:31<01:18, 1.92s/it]
187
  81%|████████ | 171/211 [05:34<01:26, 2.15s/it]
188
  82%|████████▏ | 172/211 [05:36<01:23, 2.14s/it]
189
  82%|████████▏ | 173/211 [05:38<01:17, 2.05s/it]
190
  82%|████████▏ | 174/211 [05:40<01:13, 1.99s/it]
191
  83%|████████▎ | 175/211 [05:42<01:09, 1.93s/it]
192
  83%|████████▎ | 176/211 [05:44<01:10, 2.02s/it]
193
  84%|████████▍ | 177/211 [05:46<01:10, 2.06s/it]
194
  84%|████████▍ | 178/211 [05:49<01:16, 2.33s/it]
195
  85%|████████▍ | 179/211 [05:51<01:09, 2.16s/it]
196
  85%|████████▌ | 180/211 [05:53<01:03, 2.05s/it]
197
  86%|████████▌ | 181/211 [05:54<00:58, 1.96s/it]
198
  86%|████████▋ | 182/211 [05:56<00:56, 1.95s/it]
199
  87%|████████▋ | 183/211 [05:59<00:57, 2.05s/it]
200
  87%|████████▋ | 184/211 [06:00<00:54, 2.01s/it]
201
  88%|████████▊ | 185/211 [06:02<00:50, 1.95s/it]
202
  88%|████████▊ | 186/211 [06:04<00:47, 1.89s/it]
203
  89%|████████▊ | 187/211 [06:06<00:45, 1.88s/it]
204
  89%|████████▉ | 188/211 [06:08<00:42, 1.86s/it]
205
  90%|████████▉ | 189/211 [06:09<00:40, 1.83s/it]
206
  90%|█████████ | 190/211 [06:11<00:38, 1.85s/it]
207
  91%|█████████ | 191/211 [06:13<00:36, 1.85s/it]
208
  91%|█████████ | 192/211 [06:15<00:34, 1.84s/it]
209
  91%|█████████▏| 193/211 [06:17<00:34, 1.92s/it]
210
  92%|█████████▏| 194/211 [06:19<00:32, 1.91s/it]
211
  92%|█████████▏| 195/211 [06:21<00:31, 1.96s/it]
212
  93%|█████████▎| 196/211 [06:23<00:28, 1.92s/it]
213
  93%|█████████▎| 197/211 [06:25<00:27, 1.96s/it]
214
  94%|█████████▍| 198/211 [06:27<00:25, 1.93s/it]
215
  94%|█████████▍| 199/211 [06:29<00:22, 1.88s/it]
216
  95%|█████████▍| 200/211 [06:30<00:20, 1.89s/it]
217
  95%|█████████▌| 201/211 [06:32<00:18, 1.88s/it]
218
  96%|█████████▌| 202/211 [06:34<00:16, 1.86s/it]
219
  96%|█████████▌| 203/211 [06:36<00:15, 1.88s/it]
220
  97%|█████████▋| 204/211 [06:38<00:13, 1.89s/it]
221
  97%|█████████▋| 205/211 [06:40<00:11, 1.91s/it]
222
  98%|█████████▊| 206/211 [06:42<00:10, 2.10s/it]
223
  98%|█████████▊| 207/211 [06:44<00:08, 2.07s/it]
224
  99%|█████████▊| 208/211 [06:46<00:05, 1.98s/it]
225
  99%|█████████▉| 209/211 [06:49<00:04, 2.12s/it]
226
+ computing/reading reference batch statistics...
227
+ computing sample batch activations...
228
+
229
  0%| | 0/469 [00:00<?, ?it/s]
230
  0%| | 1/469 [00:01<14:40, 1.88s/it]
231
  0%| | 2/469 [00:04<16:02, 2.06s/it]
232
  1%| | 3/469 [00:05<15:06, 1.95s/it]
233
  1%| | 4/469 [00:07<14:44, 1.90s/it]
234
  1%| | 5/469 [00:09<14:34, 1.88s/it]
235
  1%|▏ | 6/469 [00:11<14:24, 1.87s/it]
236
  1%|▏ | 7/469 [00:13<14:05, 1.83s/it]
237
  2%|▏ | 8/469 [00:15<14:19, 1.87s/it]
238
  2%|▏ | 9/469 [00:16<14:13, 1.85s/it]
239
  2%|▏ | 10/469 [00:18<14:09, 1.85s/it]
240
  2%|▏ | 11/469 [00:20<14:04, 1.84s/it]
241
  3%|▎ | 12/469 [00:22<13:59, 1.84s/it]
242
  3%|▎ | 13/469 [00:24<14:09, 1.86s/it]
243
  3%|▎ | 14/469 [00:26<14:18, 1.89s/it]
244
  3%|▎ | 15/469 [00:28<14:01, 1.85s/it]
245
  3%|▎ | 16/469 [00:30<14:36, 1.94s/it]
246
  4%|▎ | 17/469 [00:31<14:12, 1.89s/it]
247
  4%|▍ | 18/469 [00:33<13:58, 1.86s/it]
248
  4%|▍ | 19/469 [00:36<14:55, 1.99s/it]
249
  4%|▍ | 20/469 [00:38<14:53, 1.99s/it]
250
  4%|▍ | 21/469 [00:40<14:50, 1.99s/it]
251
  5%|▍ | 22/469 [00:41<14:43, 1.98s/it]
252
  5%|▍ | 23/469 [00:44<15:04, 2.03s/it]
253
  5%|▌ | 24/469 [00:45<14:34, 1.97s/it]
254
  5%|▌ | 25/469 [00:47<14:03, 1.90s/it]
255
  6%|▌ | 26/469 [00:49<13:42, 1.86s/it]
256
  6%|▌ | 27/469 [00:51<13:31, 1.84s/it]
257
  6%|▌ | 28/469 [00:53<13:35, 1.85s/it]
258
  6%|▌ | 29/469 [00:54<13:24, 1.83s/it]
259
  6%|▋ | 30/469 [00:56<13:30, 1.85s/it]
260
  7%|▋ | 31/469 [00:58<13:10, 1.80s/it]
261
  7%|▋ | 32/469 [01:00<13:43, 1.88s/it]
262
  7%|▋ | 33/469 [01:02<13:23, 1.84s/it]
263
  7%|▋ | 34/469 [01:05<15:15, 2.10s/it]
264
  7%|▋ | 35/469 [01:06<14:40, 2.03s/it]
265
  8%|▊ | 36/469 [01:08<14:23, 1.99s/it]
266
  8%|▊ | 37/469 [01:10<13:51, 1.93s/it]
267
  8%|▊ | 38/469 [01:12<13:26, 1.87s/it]
268
  8%|▊ | 39/469 [01:14<13:13, 1.85s/it]
269
  9%|▊ | 40/469 [01:15<13:15, 1.85s/it]
270
  9%|▊ | 41/469 [01:17<13:14, 1.86s/it]
271
  9%|▉ | 42/469 [01:19<12:58, 1.82s/it]
272
  9%|▉ | 43/469 [01:21<13:04, 1.84s/it]
273
  9%|▉ | 44/469 [01:23<12:55, 1.82s/it]
274
  10%|▉ | 45/469 [01:25<12:56, 1.83s/it]
275
  10%|▉ | 46/469 [01:26<12:50, 1.82s/it]
276
  10%|█ | 47/469 [01:28<12:47, 1.82s/it]
277
  10%|█ | 48/469 [01:30<12:41, 1.81s/it]
278
  10%|█ | 49/469 [01:32<12:52, 1.84s/it]
279
  11%|█ | 50/469 [01:34<13:00, 1.86s/it]
280
  11%|█ | 51/469 [01:36<12:46, 1.83s/it]
281
  11%|█ | 52/469 [01:37<12:41, 1.83s/it]
282
  11%|█▏ | 53/469 [01:39<12:40, 1.83s/it]
283
  12%|█▏ | 54/469 [01:41<12:46, 1.85s/it]
284
  12%|█▏ | 55/469 [01:44<15:33, 2.25s/it]
285
  12%|█▏ | 56/469 [01:46<14:24, 2.09s/it]
286
  12%|█▏ | 57/469 [01:48<13:41, 1.99s/it]
287
  12%|█▏ | 58/469 [01:50<13:22, 1.95s/it]
288
  13%|█▎ | 59/469 [01:51<12:54, 1.89s/it]
289
  13%|█▎ | 60/469 [01:53<12:33, 1.84s/it]
290
  13%|█▎ | 61/469 [01:55<12:23, 1.82s/it]
291
  13%|█▎ | 62/469 [01:57<12:29, 1.84s/it]
292
  13%|█▎ | 63/469 [01:59<12:42, 1.88s/it]
293
  14%|█▎ | 64/469 [02:01<12:43, 1.89s/it]
294
  14%|█▍ | 65/469 [02:02<12:29, 1.86s/it]
295
  14%|█▍ | 66/469 [02:04<12:49, 1.91s/it]
296
  14%|█▍ | 67/469 [02:06<12:50, 1.92s/it]
297
  14%|█▍ | 68/469 [02:08<12:34, 1.88s/it]
298
  15%|█▍ | 69/469 [02:10<12:22, 1.86s/it]
299
  15%|█▍ | 70/469 [02:12<12:20, 1.86s/it]
300
  15%|█▌ | 71/469 [02:14<12:10, 1.84s/it]
301
  15%|█▌ | 72/469 [02:15<12:06, 1.83s/it]
302
  16%|█▌ | 73/469 [02:17<12:00, 1.82s/it]
303
  16%|█▌ | 74/469 [02:19<12:09, 1.85s/it]
304
  16%|█▌ | 75/469 [02:21<12:18, 1.87s/it]
305
  16%|█▌ | 76/469 [02:23<12:13, 1.87s/it]
306
  16%|█▋ | 77/469 [02:25<12:16, 1.88s/it]
307
  17%|█▋ | 78/469 [02:27<12:45, 1.96s/it]
308
  17%|█▋ | 79/469 [02:29<12:17, 1.89s/it]
309
  17%|█▋ | 80/469 [02:31<12:07, 1.87s/it]
310
  17%|█▋ | 81/469 [02:32<11:53, 1.84s/it]
311
  17%|█▋ | 82/469 [02:34<12:09, 1.89s/it]
312
  18%|█▊ | 83/469 [02:36<12:01, 1.87s/it]
313
  18%|█▊ | 84/469 [02:38<11:45, 1.83s/it]
314
  18%|█▊ | 85/469 [02:40<12:13, 1.91s/it]
315
  18%|█▊ | 86/469 [02:42<12:12, 1.91s/it]
316
  19%|█▊ | 87/469 [02:44<11:53, 1.87s/it]
317
  19%|█▉ | 88/469 [02:45<11:32, 1.82s/it]
318
  19%|█▉ | 89/469 [02:47<11:28, 1.81s/it]
319
  19%|█▉ | 90/469 [02:49<11:52, 1.88s/it]
320
  19%|█▉ | 91/469 [02:51<11:34, 1.84s/it]
321
  20%|█▉ | 92/469 [02:53<11:36, 1.85s/it]
322
  20%|█▉ | 93/469 [02:55<11:36, 1.85s/it]
323
  20%|██ | 94/469 [02:56<11:27, 1.83s/it]
324
  20%|██ | 95/469 [02:58<11:24, 1.83s/it]
325
  20%|██ | 96/469 [03:00<11:19, 1.82s/it]
326
  21%|██ | 97/469 [03:02<11:25, 1.84s/it]
327
  21%|██ | 98/469 [03:04<11:45, 1.90s/it]
328
  21%|██ | 99/469 [03:06<11:45, 1.91s/it]
329
  21%|██▏ | 100/469 [03:08<11:30, 1.87s/it]
330
  22%|██▏ | 101/469 [03:10<12:35, 2.05s/it]
331
  22%|██▏ | 102/469 [03:12<12:40, 2.07s/it]
332
  22%|██▏ | 103/469 [03:14<12:10, 2.00s/it]
333
  22%|██▏ | 104/469 [03:16<11:49, 1.94s/it]
334
  22%|██▏ | 105/469 [03:18<11:38, 1.92s/it]
335
  23%|██▎ | 106/469 [03:20<11:50, 1.96s/it]
336
  23%|██▎ | 107/469 [03:22<11:29, 1.91s/it]
337
  23%|██▎ | 108/469 [03:24<12:19, 2.05s/it]
338
  23%|██▎ | 109/469 [03:26<11:50, 1.97s/it]
339
  23%|██▎ | 110/469 [03:28<12:21, 2.06s/it]
340
  24%|██▎ | 111/469 [03:30<12:00, 2.01s/it]
341
  24%|██▍ | 112/469 [03:32<11:37, 1.95s/it]
342
  24%|██▍ | 113/469 [03:34<11:13, 1.89s/it]
343
  24%|██▍ | 114/469 [03:36<11:26, 1.93s/it]
344
  25%|██▍ | 115/469 [03:37<11:06, 1.88s/it]
345
  25%|██▍ | 116/469 [03:39<11:27, 1.95s/it]
346
  25%|██▍ | 117/469 [03:42<12:08, 2.07s/it]
347
  25%|██▌ | 118/469 [03:44<11:37, 1.99s/it]
348
  25%|██▌ | 119/469 [03:45<11:14, 1.93s/it]
349
  26%|██▌ | 120/469 [03:47<10:53, 1.87s/it]
350
  26%|██▌ | 121/469 [03:49<10:39, 1.84s/it]
351
  26%|██▌ | 122/469 [03:51<10:28, 1.81s/it]
352
  26%|██▌ | 123/469 [03:52<10:23, 1.80s/it]
353
  26%|██▋ | 124/469 [03:54<10:16, 1.79s/it]
354
  27%|██▋ | 125/469 [03:56<10:26, 1.82s/it]
355
  27%|██▋ | 126/469 [03:58<10:19, 1.81s/it]
356
  27%|██▋ | 127/469 [04:00<10:05, 1.77s/it]
357
  27%|██▋ | 128/469 [04:01<10:13, 1.80s/it]
358
  28%|██▊ | 129/469 [04:04<10:47, 1.91s/it]
359
  28%|██▊ | 130/469 [04:05<10:30, 1.86s/it]
360
  28%|██▊ | 131/469 [04:08<11:02, 1.96s/it]
361
  28%|██▊ | 132/469 [04:10<11:07, 1.98s/it]
362
  28%|██▊ | 133/469 [04:11<11:00, 1.97s/it]
363
  29%|██▊ | 134/469 [04:13<10:36, 1.90s/it]
364
  29%|██▉ | 135/469 [04:15<10:33, 1.90s/it]
365
  29%|██▉ | 136/469 [04:17<10:28, 1.89s/it]
366
  29%|██▉ | 137/469 [04:23<16:46, 3.03s/it]
367
  29%|██▉ | 138/469 [04:24<14:39, 2.66s/it]
368
  30%|██▉ | 139/469 [04:27<13:43, 2.49s/it]
369
  30%|██▉ | 140/469 [04:29<13:38, 2.49s/it]
370
  30%|███ | 141/469 [04:31<12:21, 2.26s/it]
371
  30%|███ | 142/469 [04:33<11:27, 2.10s/it]
372
  30%|███ | 143/469 [04:34<10:47, 1.98s/it]
373
  31%|███ | 144/469 [04:36<10:23, 1.92s/it]
374
  31%|███ | 145/469 [04:38<10:18, 1.91s/it]
375
  31%|███ | 146/469 [04:40<10:36, 1.97s/it]
376
  31%|███▏ | 147/469 [04:42<10:39, 1.99s/it]
377
  32%|███▏ | 148/469 [04:44<10:50, 2.03s/it]
378
  32%|███▏ | 149/469 [04:46<10:30, 1.97s/it]
379
  32%|███▏ | 150/469 [04:48<10:12, 1.92s/it]
380
  32%|███▏ | 151/469 [04:50<09:55, 1.87s/it]
381
  32%|███▏ | 152/469 [04:51<10:00, 1.89s/it]
382
  33%|███▎ | 153/469 [04:53<10:00, 1.90s/it]
383
  33%|███▎ | 154/469 [04:55<09:56, 1.89s/it]
384
  33%|███▎ | 155/469 [04:57<09:55, 1.90s/it]
385
  33%|███▎ | 156/469 [04:59<09:38, 1.85s/it]
386
  33%|███▎ | 157/469 [05:01<09:29, 1.83s/it]
387
  34%|███▎ | 158/469 [05:03<09:30, 1.84s/it]
388
  34%|███▍ | 159/469 [05:04<09:18, 1.80s/it]
389
  34%|███▍ | 160/469 [05:07<09:59, 1.94s/it]
390
  34%|███▍ | 161/469 [05:08<09:52, 1.93s/it]
391
  35%|███▍ | 162/469 [05:10<09:53, 1.93s/it]
392
  35%|███▍ | 163/469 [05:12<09:41, 1.90s/it]
393
  35%|███▍ | 164/469 [05:14<09:34, 1.88s/it]
394
  35%|███▌ | 165/469 [05:16<09:32, 1.88s/it]
395
  35%|███▌ | 166/469 [05:18<09:23, 1.86s/it]
396
  36%|███▌ | 167/469 [05:20<09:39, 1.92s/it]
397
  36%|███▌ | 168/469 [05:22<09:30, 1.90s/it]
398
  36%|███▌ | 169/469 [05:23<09:21, 1.87s/it]
399
  36%|███▌ | 170/469 [05:25<09:10, 1.84s/it]
400
  36%|███▋ | 171/469 [05:27<09:20, 1.88s/it]
401
  37%|███▋ | 172/469 [05:29<09:07, 1.84s/it]
402
  37%|███▋ | 173/469 [05:31<09:01, 1.83s/it]
403
  37%|███▋ | 174/469 [05:33<09:00, 1.83s/it]
404
  37%|███▋ | 175/469 [05:34<08:49, 1.80s/it]
405
  38%|███▊ | 176/469 [05:36<09:09, 1.88s/it]
406
  38%|███▊ | 177/469 [05:38<09:08, 1.88s/it]
407
  38%|███▊ | 178/469 [05:40<09:00, 1.86s/it]
408
  38%|███▊ | 179/469 [05:42<08:51, 1.83s/it]
409
  38%|███▊ | 180/469 [05:44<08:51, 1.84s/it]
410
  39%|███▊ | 181/469 [05:46<10:03, 2.10s/it]
411
  39%|███▉ | 182/469 [05:48<09:44, 2.04s/it]
412
  39%|███▉ | 183/469 [05:50<09:21, 1.96s/it]
413
  39%|███▉ | 184/469 [05:52<09:23, 1.98s/it]
414
  39%|███▉ | 185/469 [05:54<09:03, 1.92s/it]
415
  40%|███▉ | 186/469 [05:56<08:53, 1.88s/it]
416
  40%|███▉ | 187/469 [05:57<08:40, 1.85s/it]
417
  40%|████ | 188/469 [05:59<08:58, 1.92s/it]
418
  40%|████ | 189/469 [06:01<08:50, 1.89s/it]
419
  41%|████ | 190/469 [06:03<08:55, 1.92s/it]
420
  41%|████ | 191/469 [06:05<08:49, 1.91s/it]
421
  41%|████ | 192/469 [06:07<08:48, 1.91s/it]
422
  41%|████ | 193/469 [06:09<08:43, 1.90s/it]
423
  41%|████▏ | 194/469 [06:11<08:31, 1.86s/it]
424
  42%|████▏ | 195/469 [06:12<08:18, 1.82s/it]
425
  42%|████▏ | 196/469 [06:14<08:06, 1.78s/it]
426
  42%|████▏ | 197/469 [06:16<08:05, 1.79s/it]
427
  42%|████▏ | 198/469 [06:18<08:14, 1.83s/it]
428
  42%|████▏ | 199/469 [06:20<08:11, 1.82s/it]
429
  43%|████▎ | 200/469 [06:22<08:36, 1.92s/it]
430
  43%|████▎ | 201/469 [06:24<08:25, 1.89s/it]
431
  43%|████▎ | 202/469 [06:26<08:28, 1.91s/it]
432
  43%|████▎ | 203/469 [06:27<08:24, 1.90s/it]
433
  43%|████▎ | 204/469 [06:29<08:31, 1.93s/it]
434
  44%|████▎ | 205/469 [06:31<08:20, 1.89s/it]
435
  44%|████▍ | 206/469 [06:33<08:09, 1.86s/it]
436
  44%|████▍ | 207/469 [06:35<08:13, 1.89s/it]
437
  44%|████▍ | 208/469 [06:37<08:21, 1.92s/it]
438
  45%|████▍ | 209/469 [06:39<08:55, 2.06s/it]
439
  45%|████▍ | 210/469 [06:41<08:32, 1.98s/it]
440
  45%|████▍ | 211/469 [06:43<08:21, 1.94s/it]
441
  45%|████▌ | 212/469 [06:45<08:19, 1.94s/it]
442
  45%|████▌ | 213/469 [06:47<08:13, 1.93s/it]
443
  46%|████▌ | 214/469 [06:49<08:23, 1.98s/it]
444
  46%|████▌ | 215/469 [06:51<08:15, 1.95s/it]
445
  46%|████▌ | 216/469 [06:53<08:02, 1.91s/it]
446
  46%|████▋ | 217/469 [06:55<07:57, 1.90s/it]
447
  46%|████▋ | 218/469 [06:56<07:51, 1.88s/it]
448
  47%|████▋ | 219/469 [06:58<07:40, 1.84s/it]
449
  47%|████▋ | 220/469 [07:00<07:52, 1.90s/it]
450
  47%|████▋ | 221/469 [07:02<07:44, 1.87s/it]
451
  47%|████▋ | 222/469 [07:04<08:09, 1.98s/it]
452
  48%|████▊ | 223/469 [07:07<09:19, 2.27s/it]
453
  48%|████▊ | 224/469 [07:09<08:42, 2.13s/it]
454
  48%|████▊ | 225/469 [07:11<08:22, 2.06s/it]
455
  48%|████▊ | 226/469 [07:13<08:01, 1.98s/it]
456
  48%|████▊ | 227/469 [07:15<07:51, 1.95s/it]
457
  49%|████▊ | 228/469 [07:16<07:43, 1.92s/it]
458
  49%|████▉ | 229/469 [07:18<07:35, 1.90s/it]
459
  49%|████▉ | 230/469 [07:20<07:28, 1.88s/it]
460
  49%|████▉ | 231/469 [07:22<07:44, 1.95s/it]
461
  49%|████▉ | 232/469 [07:24<07:32, 1.91s/it]
462
  50%|████▉ | 233/469 [07:26<07:24, 1.88s/it]
463
  50%|████▉ | 234/469 [07:28<07:26, 1.90s/it]
464
  50%|█████ | 235/469 [07:30<07:18, 1.88s/it]
465
  50%|█████ | 236/469 [07:32<07:21, 1.89s/it]
466
  51%|█████ | 237/469 [07:33<07:15, 1.88s/it]
467
  51%|█████ | 238/469 [07:35<07:31, 1.95s/it]
468
  51%|█████ | 239/469 [07:37<07:25, 1.94s/it]
469
  51%|█████ | 240/469 [07:39<07:26, 1.95s/it]
470
  51%|█████▏ | 241/469 [07:42<08:39, 2.28s/it]
471
  52%|█████▏ | 242/469 [07:44<08:17, 2.19s/it]
472
  52%|█████▏ | 243/469 [07:47<08:19, 2.21s/it]
473
  52%|█████▏ | 244/469 [07:49<07:57, 2.12s/it]
474
  52%|█████▏ | 245/469 [07:50<07:35, 2.03s/it]
475
  52%|█████▏ | 246/469 [07:52<07:26, 2.00s/it]
476
  53%|█████▎ | 247/469 [07:54<07:19, 1.98s/it]
477
  53%|█████▎ | 248/469 [07:56<07:11, 1.95s/it]
478
  53%|█████▎ | 249/469 [07:58<07:15, 1.98s/it]
479
  53%|█████▎ | 250/469 [08:00<06:58, 1.91s/it]
480
  54%|█████▎ | 251/469 [08:02<06:51, 1.89s/it]
481
  54%|█████▎ | 252/469 [08:04<06:42, 1.85s/it]
482
  54%|█████▍ | 253/469 [08:05<06:33, 1.82s/it]
483
  54%|█████▍ | 254/469 [08:07<06:29, 1.81s/it]
484
  54%|█████▍ | 255/469 [08:09<06:32, 1.83s/it]
485
  55%|█████▍ | 256/469 [08:11<06:24, 1.80s/it]
486
  55%|█████▍ | 257/469 [08:12<06:18, 1.79s/it]
487
  55%|█████▌ | 258/469 [08:14<06:25, 1.83s/it]
488
  55%|█████▌ | 259/469 [08:16<06:25, 1.83s/it]
489
  55%|█████▌ | 260/469 [08:18<06:19, 1.82s/it]
490
  56%|█████▌ | 261/469 [08:20<06:42, 1.94s/it]
491
  56%|█████▌ | 262/469 [08:22<06:32, 1.90s/it]
492
  56%|█████▌ | 263/469 [08:24<06:25, 1.87s/it]
493
  56%|█████▋ | 264/469 [08:26<06:17, 1.84s/it]
494
  57%|█████▋ | 265/469 [08:28<06:20, 1.87s/it]
495
  57%|█████▋ | 266/469 [08:29<06:17, 1.86s/it]
496
  57%|█████▋ | 267/469 [08:31<06:16, 1.86s/it]
497
  57%|█████▋ | 268/469 [08:33<06:12, 1.86s/it]
498
  57%|█████▋ | 269/469 [08:35<06:41, 2.01s/it]
499
  58%|█████▊ | 270/469 [08:37<06:24, 1.93s/it]
500
  58%|█████▊ | 271/469 [08:39<06:10, 1.87s/it]
501
  58%|█████▊ | 272/469 [08:41<06:09, 1.87s/it]
502
  58%|█████▊ | 273/469 [08:43<06:24, 1.96s/it]
503
  58%|█████▊ | 274/469 [08:45<06:11, 1.90s/it]
504
  59%|█████▊ | 275/469 [08:47<06:10, 1.91s/it]
505
  59%|█████▉ | 276/469 [08:48<06:03, 1.88s/it]
506
  59%|█████▉ | 277/469 [08:50<06:01, 1.88s/it]
507
  59%|█████▉ | 278/469 [08:52<06:08, 1.93s/it]
508
  59%|█████▉ | 279/469 [08:54<05:57, 1.88s/it]
509
  60%|█████▉ | 280/469 [08:56<06:11, 1.97s/it]
510
  60%|█████▉ | 281/469 [08:58<06:07, 1.95s/it]
511
  60%|██████ | 282/469 [09:00<05:57, 1.91s/it]
512
  60%|██████ | 283/469 [09:02<05:48, 1.87s/it]
513
  61%|██████ | 284/469 [09:04<05:44, 1.86s/it]
514
  61%|██████ | 285/469 [09:06<05:45, 1.88s/it]
515
  61%|██████ | 286/469 [09:08<06:11, 2.03s/it]
516
  61%|██████ | 287/469 [09:10<06:16, 2.07s/it]
517
  61%|██████▏ | 288/469 [09:12<05:59, 1.99s/it]
518
  62%|██████▏ | 289/469 [09:14<05:50, 1.95s/it]
519
  62%|██████▏ | 290/469 [09:16<05:39, 1.90s/it]
520
  62%|██████▏ | 291/469 [09:17<05:28, 1.85s/it]
521
  62%|██████▏ | 292/469 [09:19<05:25, 1.84s/it]
522
  62%|██████▏ | 293/469 [09:21<05:23, 1.84s/it]
523
  63%|██████▎ | 294/469 [09:23<05:31, 1.89s/it]
524
  63%|██████▎ | 295/469 [09:25<05:20, 1.84s/it]
525
  63%|██████▎ | 296/469 [09:26<05:13, 1.81s/it]
526
  63%|██████▎ | 297/469 [09:28<05:08, 1.80s/it]
527
  64%|██████▎ | 298/469 [09:30<05:04, 1.78s/it]
528
  64%|██████▍ | 299/469 [09:32<05:05, 1.80s/it]
529
  64%|██████▍ | 300/469 [09:34<05:08, 1.82s/it]
530
  64%|██████▍ | 301/469 [09:36<05:07, 1.83s/it]
531
  64%|██████▍ | 302/469 [09:37<05:04, 1.82s/it]
532
  65%|██████▍ | 303/469 [09:39<05:06, 1.84s/it]
533
  65%|██████▍ | 304/469 [09:41<05:06, 1.86s/it]
534
  65%|██████▌ | 305/469 [09:43<05:02, 1.84s/it]
535
  65%|██████▌ | 306/469 [09:45<05:01, 1.85s/it]
536
  65%|██████▌ | 307/469 [09:47<05:00, 1.85s/it]
537
  66%|██████▌ | 308/469 [09:49<05:02, 1.88s/it]
538
  66%|██████▌ | 309/469 [09:50<05:00, 1.88s/it]
539
  66%|██████▌ | 310/469 [09:52<04:53, 1.85s/it]
540
  66%|██████▋ | 311/469 [09:54<04:54, 1.86s/it]
541
  67%|██████▋ | 312/469 [09:56<04:58, 1.90s/it]
542
  67%|██████▋ | 313/469 [09:58<04:50, 1.86s/it]
543
  67%|██████▋ | 314/469 [10:00<04:51, 1.88s/it]
544
  67%|██████▋ | 315/469 [10:02<04:52, 1.90s/it]
545
  67%|██████▋ | 316/469 [10:04<04:53, 1.92s/it]
546
  68%|██████▊ | 317/469 [10:06<04:45, 1.88s/it]
547
  68%|██████▊ | 318/469 [10:08<04:57, 1.97s/it]
548
  68%|██████▊ | 319/469 [10:09<04:46, 1.91s/it]
549
  68%|██████▊ | 320/469 [10:11<04:46, 1.92s/it]
550
  68%|██████▊ | 321/469 [10:14<05:08, 2.09s/it]
551
  69%|██████▊ | 322/469 [10:16<05:13, 2.14s/it]
552
  69%|██████▉ | 323/469 [10:18<05:02, 2.07s/it]
553
  69%|██████▉ | 324/469 [10:20<04:52, 2.02s/it]
554
  69%|██████▉ | 325/469 [10:22<04:42, 1.96s/it]
555
  70%|██████▉ | 326/469 [10:24<04:34, 1.92s/it]
556
  70%|██████▉ | 327/469 [10:26<04:41, 1.98s/it]
557
  70%|██████▉ | 328/469 [10:28<04:31, 1.93s/it]
558
  70%|███████ | 329/469 [10:29<04:24, 1.89s/it]
559
  70%|███████ | 330/469 [10:31<04:19, 1.87s/it]
560
  71%|███████ | 331/469 [10:33<04:13, 1.84s/it]
561
  71%|███████ | 332/469 [10:35<04:24, 1.93s/it]
562
  71%|███████ | 333/469 [10:37<04:16, 1.89s/it]
563
  71%|███████ | 334/469 [10:42<06:23, 2.84s/it]
564
  71%|███████▏ | 335/469 [10:44<05:35, 2.51s/it]
565
  72%|███████▏ | 336/469 [10:45<05:04, 2.29s/it]
566
  72%|███████▏ | 337/469 [10:47<04:42, 2.14s/it]
567
  72%|███████▏ | 338/469 [10:49<04:41, 2.15s/it]
568
  72%|███████▏ | 339/469 [10:51<04:33, 2.11s/it]
569
  72%|███████▏ | 340/469 [10:53<04:21, 2.03s/it]
570
  73%|███████▎ | 341/469 [10:55<04:17, 2.01s/it]
571
  73%|███████▎ | 342/469 [10:57<04:19, 2.04s/it]
572
  73%|███████▎ | 343/469 [10:59<04:18, 2.05s/it]
573
  73%|███████▎ | 344/469 [11:01<04:04, 1.96s/it]
574
  74%|███████▎ | 345/469 [11:03<04:00, 1.94s/it]
575
  74%|███████▍ | 346/469 [11:05<03:52, 1.89s/it]
576
  74%|███████▍ | 347/469 [11:07<03:45, 1.85s/it]
577
  74%|███████▍ | 348/469 [11:08<03:39, 1.82s/it]
578
  74%|███████▍ | 349/469 [11:10<03:38, 1.82s/it]
579
  75%|███████▍ | 350/469 [11:12<03:42, 1.87s/it]
580
  75%|███████▍ | 351/469 [11:14<03:44, 1.90s/it]
581
  75%|███████▌ | 352/469 [11:16<03:40, 1.88s/it]
582
  75%|███████▌ | 353/469 [11:25<07:38, 3.95s/it]
583
  75%|███████▌ | 354/469 [11:27<06:26, 3.36s/it]
584
  76%|███████▌ | 355/469 [11:29<05:30, 2.90s/it]
585
  76%|███████▌ | 356/469 [11:31<05:08, 2.73s/it]
586
  76%|███████▌ | 357/469 [11:33<04:45, 2.55s/it]
587
  76%|███████▋ | 358/469 [11:35<04:23, 2.37s/it]
588
  77%|██��████▋ | 359/469 [11:37<04:10, 2.28s/it]
589
  77%|███████▋ | 360/469 [11:39<03:51, 2.12s/it]
590
  77%|███████▋ | 361/469 [11:41<03:37, 2.02s/it]
591
  77%|███████▋ | 362/469 [12:00<13:03, 7.32s/it]
592
  77%|███████▋ | 363/469 [12:02<10:03, 5.69s/it]
593
  78%|███████▊ | 364/469 [12:04<07:55, 4.53s/it]
594
  78%|███████▊ | 365/469 [12:06<06:27, 3.73s/it]
595
  78%|███████▊ | 366/469 [12:08<05:26, 3.17s/it]
596
  78%|███████▊ | 367/469 [12:09<04:40, 2.75s/it]
597
  78%|███████▊ | 368/469 [12:12<04:21, 2.59s/it]
598
  79%|███████▊ | 369/469 [12:13<03:55, 2.36s/it]
599
  79%|███████▉ | 370/469 [12:15<03:43, 2.25s/it]
600
  79%|███████▉ | 371/469 [12:17<03:28, 2.13s/it]
601
  79%|███████▉ | 372/469 [12:19<03:21, 2.07s/it]
602
  80%|███████▉ | 373/469 [12:21<03:10, 1.99s/it]
603
  80%|███████▉ | 374/469 [12:23<03:02, 1.92s/it]
604
  80%|███████▉ | 375/469 [12:25<02:55, 1.86s/it]
605
  80%|████████ | 376/469 [12:26<02:51, 1.84s/it]
606
  80%|████████ | 377/469 [12:28<02:51, 1.86s/it]
607
  81%|████████ | 378/469 [12:30<02:47, 1.84s/it]
608
  81%|████████ | 379/469 [12:32<02:43, 1.81s/it]
609
  81%|████████ | 380/469 [12:34<02:48, 1.89s/it]
610
  81%|████████ | 381/469 [12:36<02:45, 1.88s/it]
611
  81%|████████▏ | 382/469 [12:38<02:43, 1.88s/it]
612
  82%|████████▏ | 383/469 [12:39<02:38, 1.85s/it]
613
  82%|████████▏ | 384/469 [12:41<02:36, 1.84s/it]
614
  82%|████████▏ | 385/469 [12:43<02:37, 1.88s/it]
615
  82%|████████▏ | 386/469 [12:45<02:35, 1.87s/it]
616
  83%|████████▎ | 387/469 [12:47<02:45, 2.02s/it]
617
  83%|████████▎ | 388/469 [12:49<02:37, 1.94s/it]
618
  83%|████████▎ | 389/469 [12:51<02:30, 1.88s/it]
619
  83%|████████▎ | 390/469 [12:53<02:32, 1.93s/it]
620
  83%|████████▎ | 391/469 [12:55<02:29, 1.92s/it]
621
  84%|████████▎ | 392/469 [12:57<02:23, 1.87s/it]
622
  84%|████████▍ | 393/469 [13:00<03:00, 2.38s/it]
623
  84%|████████▍ | 394/469 [13:02<02:47, 2.23s/it]
624
  84%|████████▍ | 395/469 [13:04<02:36, 2.12s/it]
625
  84%|████████▍ | 396/469 [13:06<02:27, 2.02s/it]
626
  85%|████████▍ | 397/469 [13:08<02:24, 2.00s/it]
627
  85%|████████▍ | 398/469 [13:09<02:19, 1.96s/it]
628
  85%|████████▌ | 399/469 [13:11<02:14, 1.92s/it]
629
  85%|████████▌ | 400/469 [13:13<02:08, 1.86s/it]
630
  86%|████████▌ | 401/469 [13:15<02:05, 1.85s/it]
631
  86%|████████▌ | 402/469 [13:17<02:02, 1.84s/it]
632
  86%|████████▌ | 403/469 [13:18<01:59, 1.81s/it]
633
  86%|████████▌ | 404/469 [13:20<02:02, 1.89s/it]
634
  86%|████████▋ | 405/469 [13:22<01:59, 1.87s/it]
635
  87%|████████▋ | 406/469 [13:24<01:56, 1.85s/it]
636
  87%|████████▋ | 407/469 [13:26<01:52, 1.81s/it]
637
  87%|████████▋ | 408/469 [13:28<01:50, 1.81s/it]
638
  87%|████████▋ | 409/469 [13:30<01:50, 1.84s/it]
639
  87%|████████▋ | 410/469 [13:32<01:55, 1.95s/it]
640
  88%|████████▊ | 411/469 [13:33<01:49, 1.89s/it]
641
  88%|████████▊ | 412/469 [13:35<01:48, 1.90s/it]
642
  88%|████████▊ | 413/469 [13:37<01:45, 1.89s/it]
643
  88%|████████▊ | 414/469 [13:39<01:42, 1.86s/it]
644
  88%|████████▊ | 415/469 [13:41<01:38, 1.82s/it]
645
  89%|████████▊ | 416/469 [13:43<01:39, 1.87s/it]
646
  89%|████████▉ | 417/469 [13:45<01:38, 1.90s/it]
647
  89%|████████▉ | 418/469 [13:50<02:30, 2.96s/it]
648
  89%|████████▉ | 419/469 [13:52<02:12, 2.65s/it]
649
  90%|████████▉ | 420/469 [13:54<02:01, 2.47s/it]
650
  90%|████████▉ | 421/469 [13:56<01:48, 2.26s/it]
651
  90%|████████▉ | 422/469 [13:58<01:39, 2.12s/it]
652
  90%|█████████ | 423/469 [14:00<01:33, 2.02s/it]
653
  90%|█████████ | 424/469 [14:01<01:29, 2.00s/it]
654
  91%|█████████ | 425/469 [14:03<01:24, 1.93s/it]
655
  91%|█████████ | 426/469 [14:05<01:22, 1.93s/it]
656
  91%|█████████ | 427/469 [14:07<01:19, 1.90s/it]
657
  91%|█████████▏| 428/469 [14:10<01:30, 2.21s/it]
658
  91%|█████████▏| 429/469 [14:12<01:23, 2.09s/it]
659
  92%|█████████▏| 430/469 [14:13<01:17, 2.00s/it]
660
  92%|█████████▏| 431/469 [14:16<01:17, 2.05s/it]
661
  92%|█████████▏| 432/469 [14:17<01:12, 1.96s/it]
662
  92%|█████████▏| 433/469 [14:19<01:09, 1.93s/it]
663
  93%|█████████▎| 434/469 [14:21<01:06, 1.89s/it]
664
  93%|█████████▎| 435/469 [14:23<01:03, 1.88s/it]
665
  93%|█████████▎| 436/469 [14:25<01:06, 2.01s/it]
666
  93%|█████████▎| 437/469 [14:27<01:02, 1.95s/it]
667
  93%|█████████▎| 438/469 [14:29<00:59, 1.93s/it]
668
  94%|█████████▎| 439/469 [14:31<00:57, 1.91s/it]
669
  94%|█████████▍| 440/469 [14:33<00:59, 2.04s/it]
670
  94%|█████████▍| 441/469 [14:35<00:55, 1.99s/it]
671
  94%|█████████▍| 442/469 [14:37<00:53, 1.97s/it]
672
  94%|█████████▍| 443/469 [14:39<00:49, 1.92s/it]
673
  95%|█████████▍| 444/469 [14:41<00:47, 1.89s/it]
674
  95%|█████████▍| 445/469 [14:42<00:44, 1.87s/it]
675
  95%|█████████▌| 446/469 [14:44<00:44, 1.94s/it]
676
  95%|█████████▌| 447/469 [14:46<00:41, 1.88s/it]
677
  96%|█████████▌| 448/469 [14:48<00:38, 1.83s/it]
678
  96%|█████████▌| 449/469 [14:50<00:37, 1.89s/it]
679
  96%|█████████▌| 450/469 [14:52<00:34, 1.84s/it]
680
  96%|█████████▌| 451/469 [14:53<00:32, 1.81s/it]
681
  96%|█████████▋| 452/469 [14:55<00:30, 1.80s/it]
682
  97%|█████████▋| 453/469 [14:57<00:28, 1.77s/it]
683
  97%|█████████▋| 454/469 [14:59<00:27, 1.85s/it]
684
  97%|█████████▋| 455/469 [15:01<00:25, 1.84s/it]
685
  97%|█████████▋| 456/469 [15:03<00:24, 1.92s/it]
686
  97%|█████████▋| 457/469 [15:05<00:22, 1.88s/it]
687
  98%|█████████▊| 458/469 [15:07<00:20, 1.88s/it]
688
  98%|█████████▊| 459/469 [15:09<00:19, 1.92s/it]
689
  98%|█████████▊| 460/469 [15:10<00:17, 1.92s/it]
690
  98%|█████████▊| 461/469 [15:12<00:15, 1.90s/it]
691
  99%|█████████▊| 462/469 [15:15<00:14, 2.11s/it]
692
  99%|█████████▊| 463/469 [15:17<00:13, 2.18s/it]
693
  99%|█████████▉| 464/469 [15:19<00:10, 2.09s/it]
694
  99%|█████████▉| 465/469 [15:21<00:08, 2.07s/it]
695
  99%|█████████▉| 466/469 [15:23<00:06, 2.02s/it]
696
+ computing/reading sample batch statistics...
697
+ Computing evaluations...
698
+ Inception Score: 37.95753860473633
699
+ FID: 21.04736987276152
700
+ sFID: 71.53442942455422
701
+ Precision: 0.6907333333333333
702
+ Recall: 0.35639366212898904
evaluate.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ REF_BATCH="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz"
4
+ CUDA_VISIBLE_DEVICES=1 nohup python evaluator_rf.py \
5
+ --ref_batch ${REF_BATCH} \
6
+ --sample_batch /gemini/space/gzy_new/models/Sida/sd3_rectified_samples_new_batch_2.npz \
7
+ > eval_rectified_noise_new_batch_2.log 2>&1 &
8
+ # CUDA_VISIBLE_DEVICES=0 nohup python evaluator_rf.py \
9
+ # --ref_batch ${REF_BATCH} \
10
+ # --sample_batch "/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512.npz" \
11
+ # > eval_baseline.log 2>&1 &
evaluator_base copy.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ import random
5
+ import warnings
6
+ import zipfile
7
+ from abc import ABC, abstractmethod
8
+ from contextlib import contextmanager
9
+ from functools import partial
10
+ from multiprocessing import cpu_count
11
+ from multiprocessing.pool import ThreadPool
12
+ from typing import Iterable, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import requests
16
+ import tensorflow.compat.v1 as tf
17
+ from scipy import linalg
18
+ from tqdm.auto import tqdm
19
+
20
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22
+
23
+ FID_POOL_NAME = "pool_3:0"
24
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser()
29
+ #/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
30
+ parser.add_argument("--ref_batch", default='/gemini/space/dataset/coco/coco_train_3w.npz',help="path to reference batch npz file")
31
+ parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_lora_samples/batch32-rank64-last-sd3-sd3-lora-finetuned-batch-32-guidance-7.0-steps-20-size-512x512.npz', help="path to sample batch npz file")
32
+ parser.add_argument("--save_path", default='/gemini/space/gzy/w_w_last/w_w_sit_last1/temp/',help="path to sample batch npz file")
33
+ parser.add_argument("--cfg_cond", default=1, type=int)
34
+ parser.add_argument("--step", default=1, type=int)
35
+ parser.add_argument("--cfg", default=1.0, type=float)
36
+ parser.add_argument("--cls_cfg", default=1.0, type=float)
37
+ parser.add_argument("--gh", default=1.0, type=float)
38
+ parser.add_argument("--num_steps", default=250, type=int)
39
+ args = parser.parse_args()
40
+
41
+ if not os.path.exists(args.save_path):
42
+ os.mkdir(args.save_path)
43
+
44
+
45
+ config = tf.ConfigProto(
46
+ allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
47
+ )
48
+ config.gpu_options.allow_growth = True
49
+ evaluator = Evaluator(tf.Session(config=config))
50
+
51
+ print("warming up TensorFlow...")
52
+ # This will cause TF to print a bunch of verbose stuff now rather
53
+ # than after the next print(), to help prevent confusion.
54
+ evaluator.warmup()
55
+
56
+ print("computing reference batch activations...")
57
+ ref_acts = evaluator.read_activations(args.ref_batch)
58
+ print("computing/reading reference batch statistics...")
59
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
60
+
61
+ print("computing sample batch activations...")
62
+ sample_acts = evaluator.read_activations(args.sample_batch)
63
+ print("computing/reading sample batch statistics...")
64
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
65
+
66
+ print("Computing evaluations...")
67
+ Inception_Score = evaluator.compute_inception_score(sample_acts[0])
68
+ FID = sample_stats.frechet_distance(ref_stats)
69
+ sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
70
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
71
+
72
+ print("Inception Score:", Inception_Score)
73
+ print("FID:", FID)
74
+ print("sFID:", sFID)
75
+ print("Precision:", prec)
76
+ print("Recall:", recall)
77
+
78
+ if args.cfg_cond:
79
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
80
+ else:
81
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
82
+ with open(file_path, "w") as file:
83
+ file.write("Inception Score: {}\n".format(Inception_Score))
84
+ file.write("FID: {}\n".format(FID))
85
+ file.write("sFID: {}\n".format(sFID))
86
+ file.write("Precision: {}\n".format(prec))
87
+ file.write("Recall: {}\n".format(recall))
88
+
89
+
90
+ class InvalidFIDException(Exception):
91
+ pass
92
+
93
+
94
+ class FIDStatistics:
95
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
96
+ self.mu = mu
97
+ self.sigma = sigma
98
+
99
+ def frechet_distance(self, other, eps=1e-6):
100
+ """
101
+ Compute the Frechet distance between two sets of statistics.
102
+ """
103
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
104
+ mu1, sigma1 = self.mu, self.sigma
105
+ mu2, sigma2 = other.mu, other.sigma
106
+
107
+ mu1 = np.atleast_1d(mu1)
108
+ mu2 = np.atleast_1d(mu2)
109
+
110
+ sigma1 = np.atleast_2d(sigma1)
111
+ sigma2 = np.atleast_2d(sigma2)
112
+
113
+ assert (
114
+ mu1.shape == mu2.shape
115
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
116
+ assert (
117
+ sigma1.shape == sigma2.shape
118
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
119
+
120
+ diff = mu1 - mu2
121
+
122
+ # product might be almost singular
123
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
124
+ if not np.isfinite(covmean).all():
125
+ msg = (
126
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
127
+ % eps
128
+ )
129
+ warnings.warn(msg)
130
+ offset = np.eye(sigma1.shape[0]) * eps
131
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
132
+
133
+ # numerical error might give slight imaginary component
134
+ if np.iscomplexobj(covmean):
135
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
136
+ m = np.max(np.abs(covmean.imag))
137
+ raise ValueError("Imaginary component {}".format(m))
138
+ covmean = covmean.real
139
+
140
+ tr_covmean = np.trace(covmean)
141
+
142
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
143
+
144
+
145
+ class Evaluator:
146
+ def __init__(
147
+ self,
148
+ session,
149
+ batch_size=64,
150
+ softmax_batch_size=512,
151
+ ):
152
+ self.sess = session
153
+ self.batch_size = batch_size
154
+ self.softmax_batch_size = softmax_batch_size
155
+ self.manifold_estimator = ManifoldEstimator(session)
156
+ with self.sess.graph.as_default():
157
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
158
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
159
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
160
+ self.softmax = _create_softmax_graph(self.softmax_input)
161
+
162
+ def warmup(self):
163
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
164
+
165
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
166
+ with open_npz_array(npz_path, "arr_0") as reader:
167
+ return self.compute_activations(reader.read_batches(self.batch_size))
168
+
169
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
170
+ """
171
+ Compute image features for downstream evals.
172
+
173
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
174
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
175
+ dimension. The tuple is (pool_3, spatial).
176
+ """
177
+ preds = []
178
+ spatial_preds = []
179
+ for batch in tqdm(batches):
180
+ batch = batch.astype(np.float32)
181
+ pred, spatial_pred = self.sess.run(
182
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
183
+ )
184
+ preds.append(pred.reshape([pred.shape[0], -1]))
185
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
186
+ return (
187
+ np.concatenate(preds, axis=0),
188
+ np.concatenate(spatial_preds, axis=0),
189
+ )
190
+
191
+ def read_statistics(
192
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
193
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
194
+ obj = np.load(npz_path)
195
+ if "mu" in list(obj.keys()):
196
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
197
+ obj["mu_s"], obj["sigma_s"]
198
+ )
199
+ return tuple(self.compute_statistics(x) for x in activations)
200
+
201
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
202
+ mu = np.mean(activations, axis=0)
203
+ sigma = np.cov(activations, rowvar=False)
204
+ return FIDStatistics(mu, sigma)
205
+
206
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
207
+ softmax_out = []
208
+ for i in range(0, len(activations), self.softmax_batch_size):
209
+ acts = activations[i : i + self.softmax_batch_size]
210
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
211
+ preds = np.concatenate(softmax_out, axis=0)
212
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
213
+ scores = []
214
+ for i in range(0, len(preds), split_size):
215
+ part = preds[i : i + split_size]
216
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
217
+ kl = np.mean(np.sum(kl, 1))
218
+ scores.append(np.exp(kl))
219
+ return float(np.mean(scores))
220
+
221
+ def compute_prec_recall(
222
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
223
+ ) -> Tuple[float, float]:
224
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
225
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
226
+ pr = self.manifold_estimator.evaluate_pr(
227
+ activations_ref, radii_1, activations_sample, radii_2
228
+ )
229
+ return (float(pr[0][0]), float(pr[1][0]))
230
+
231
+
232
+ class ManifoldEstimator:
233
+ """
234
+ A helper for comparing manifolds of feature vectors.
235
+
236
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ session,
242
+ row_batch_size=10000,
243
+ col_batch_size=10000,
244
+ nhood_sizes=(3,),
245
+ clamp_to_percentile=None,
246
+ eps=1e-5,
247
+ ):
248
+ """
249
+ Estimate the manifold of given feature vectors.
250
+
251
+ :param session: the TensorFlow session.
252
+ :param row_batch_size: row batch size to compute pairwise distances
253
+ (parameter to trade-off between memory usage and performance).
254
+ :param col_batch_size: column batch size to compute pairwise distances.
255
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
256
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
257
+ the given percentile.
258
+ :param eps: small number for numerical stability.
259
+ """
260
+ self.distance_block = DistanceBlock(session)
261
+ self.row_batch_size = row_batch_size
262
+ self.col_batch_size = col_batch_size
263
+ self.nhood_sizes = nhood_sizes
264
+ self.num_nhoods = len(nhood_sizes)
265
+ self.clamp_to_percentile = clamp_to_percentile
266
+ self.eps = eps
267
+
268
+ def warmup(self):
269
+ feats, radii = (
270
+ np.zeros([1, 2048], dtype=np.float32),
271
+ np.zeros([1, 1], dtype=np.float32),
272
+ )
273
+ self.evaluate_pr(feats, radii, feats, radii)
274
+
275
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
276
+ num_images = len(features)
277
+
278
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
279
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
280
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
281
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
282
+
283
+ for begin1 in range(0, num_images, self.row_batch_size):
284
+ end1 = min(begin1 + self.row_batch_size, num_images)
285
+ row_batch = features[begin1:end1]
286
+
287
+ for begin2 in range(0, num_images, self.col_batch_size):
288
+ end2 = min(begin2 + self.col_batch_size, num_images)
289
+ col_batch = features[begin2:end2]
290
+
291
+ # Compute distances between batches.
292
+ distance_batch[
293
+ 0 : end1 - begin1, begin2:end2
294
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
295
+
296
+ # Find the k-nearest neighbor from the current batch.
297
+ radii[begin1:end1, :] = np.concatenate(
298
+ [
299
+ x[:, self.nhood_sizes]
300
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
301
+ ],
302
+ axis=0,
303
+ )
304
+
305
+ if self.clamp_to_percentile is not None:
306
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
307
+ radii[radii > max_distances] = 0
308
+ return radii
309
+
310
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
311
+ """
312
+ Evaluate if new feature vectors are at the manifold.
313
+ """
314
+ num_eval_images = eval_features.shape[0]
315
+ num_ref_images = radii.shape[0]
316
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
317
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
318
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
319
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
320
+
321
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
322
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
323
+ feature_batch = eval_features[begin1:end1]
324
+
325
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
326
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
327
+ ref_batch = features[begin2:end2]
328
+
329
+ distance_batch[
330
+ 0 : end1 - begin1, begin2:end2
331
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
332
+
333
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
334
+ # If a feature vector is inside a hypersphere of some reference sample, then
335
+ # the new sample lies at the estimated manifold.
336
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
337
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
338
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
339
+
340
+ max_realism_score[begin1:end1] = np.max(
341
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
342
+ )
343
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
344
+
345
+ return {
346
+ "fraction": float(np.mean(batch_predictions)),
347
+ "batch_predictions": batch_predictions,
348
+ "max_realisim_score": max_realism_score,
349
+ "nearest_indices": nearest_indices,
350
+ }
351
+
352
+ def evaluate_pr(
353
+ self,
354
+ features_1: np.ndarray,
355
+ radii_1: np.ndarray,
356
+ features_2: np.ndarray,
357
+ radii_2: np.ndarray,
358
+ ) -> Tuple[np.ndarray, np.ndarray]:
359
+ """
360
+ Evaluate precision and recall efficiently.
361
+
362
+ :param features_1: [N1 x D] feature vectors for reference batch.
363
+ :param radii_1: [N1 x K1] radii for reference vectors.
364
+ :param features_2: [N2 x D] feature vectors for the other batch.
365
+ :param radii_2: [N x K2] radii for other vectors.
366
+ :return: a tuple of arrays for (precision, recall):
367
+ - precision: an np.ndarray of length K1
368
+ - recall: an np.ndarray of length K2
369
+ """
370
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
371
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
372
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
373
+ end_1 = begin_1 + self.row_batch_size
374
+ batch_1 = features_1[begin_1:end_1]
375
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
376
+ end_2 = begin_2 + self.col_batch_size
377
+ batch_2 = features_2[begin_2:end_2]
378
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
379
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
380
+ )
381
+ features_1_status[begin_1:end_1] |= batch_1_in
382
+ features_2_status[begin_2:end_2] |= batch_2_in
383
+ return (
384
+ np.mean(features_2_status.astype(np.float64), axis=0),
385
+ np.mean(features_1_status.astype(np.float64), axis=0),
386
+ )
387
+
388
+
389
+ class DistanceBlock:
390
+ """
391
+ Calculate pairwise distances between vectors.
392
+
393
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
394
+ """
395
+
396
+ def __init__(self, session):
397
+ self.session = session
398
+
399
+ # Initialize TF graph to calculate pairwise distances.
400
+ with session.graph.as_default():
401
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
402
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
403
+ distance_block_16 = _batch_pairwise_distances(
404
+ tf.cast(self._features_batch1, tf.float16),
405
+ tf.cast(self._features_batch2, tf.float16),
406
+ )
407
+ self.distance_block = tf.cond(
408
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
409
+ lambda: tf.cast(distance_block_16, tf.float32),
410
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
411
+ )
412
+
413
+ # Extra logic for less thans.
414
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
415
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
416
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
417
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
418
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
419
+
420
+ def pairwise_distances(self, U, V):
421
+ """
422
+ Evaluate pairwise distances between two batches of feature vectors.
423
+ """
424
+ return self.session.run(
425
+ self.distance_block,
426
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
427
+ )
428
+
429
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
430
+ return self.session.run(
431
+ [self._batch_1_in, self._batch_2_in],
432
+ feed_dict={
433
+ self._features_batch1: batch_1,
434
+ self._features_batch2: batch_2,
435
+ self._radii1: radii_1,
436
+ self._radii2: radii_2,
437
+ },
438
+ )
439
+
440
+
441
+ def _batch_pairwise_distances(U, V):
442
+ """
443
+ Compute pairwise distances between two batches of feature vectors.
444
+ """
445
+ with tf.variable_scope("pairwise_dist_block"):
446
+ # Squared norms of each row in U and V.
447
+ norm_u = tf.reduce_sum(tf.square(U), 1)
448
+ norm_v = tf.reduce_sum(tf.square(V), 1)
449
+
450
+ # norm_u as a column and norm_v as a row vectors.
451
+ norm_u = tf.reshape(norm_u, [-1, 1])
452
+ norm_v = tf.reshape(norm_v, [1, -1])
453
+
454
+ # Pairwise squared Euclidean distances.
455
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
456
+
457
+ return D
458
+
459
+
460
+ class NpzArrayReader(ABC):
461
+ @abstractmethod
462
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
463
+ pass
464
+
465
+ @abstractmethod
466
+ def remaining(self) -> int:
467
+ pass
468
+
469
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
470
+ def gen_fn():
471
+ while True:
472
+ batch = self.read_batch(batch_size)
473
+ if batch is None:
474
+ break
475
+ yield batch
476
+
477
+ rem = self.remaining()
478
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
479
+ return BatchIterator(gen_fn, num_batches)
480
+
481
+
482
+ class BatchIterator:
483
+ def __init__(self, gen_fn, length):
484
+ self.gen_fn = gen_fn
485
+ self.length = length
486
+
487
+ def __len__(self):
488
+ return self.length
489
+
490
+ def __iter__(self):
491
+ return self.gen_fn()
492
+
493
+
494
+ class StreamingNpzArrayReader(NpzArrayReader):
495
+ def __init__(self, arr_f, shape, dtype):
496
+ self.arr_f = arr_f
497
+ self.shape = shape
498
+ self.dtype = dtype
499
+ self.idx = 0
500
+
501
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
502
+ if self.idx >= self.shape[0]:
503
+ return None
504
+
505
+ bs = min(batch_size, self.shape[0] - self.idx)
506
+ self.idx += bs
507
+
508
+ if self.dtype.itemsize == 0:
509
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
510
+
511
+ read_count = bs * np.prod(self.shape[1:])
512
+ read_size = int(read_count * self.dtype.itemsize)
513
+ data = _read_bytes(self.arr_f, read_size, "array data")
514
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
515
+
516
+ def remaining(self) -> int:
517
+ return max(0, self.shape[0] - self.idx)
518
+
519
+
520
+ class MemoryNpzArrayReader(NpzArrayReader):
521
+ def __init__(self, arr):
522
+ self.arr = arr
523
+ self.idx = 0
524
+
525
+ @classmethod
526
+ def load(cls, path: str, arr_name: str):
527
+ with open(path, "rb") as f:
528
+ arr = np.load(f)[arr_name]
529
+ return cls(arr)
530
+
531
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
532
+ if self.idx >= self.arr.shape[0]:
533
+ return None
534
+
535
+ res = self.arr[self.idx : self.idx + batch_size]
536
+ self.idx += batch_size
537
+ return res
538
+
539
+ def remaining(self) -> int:
540
+ return max(0, self.arr.shape[0] - self.idx)
541
+
542
+
543
+ @contextmanager
544
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
545
+ with _open_npy_file(path, arr_name) as arr_f:
546
+ version = np.lib.format.read_magic(arr_f)
547
+ if version == (1, 0):
548
+ header = np.lib.format.read_array_header_1_0(arr_f)
549
+ elif version == (2, 0):
550
+ header = np.lib.format.read_array_header_2_0(arr_f)
551
+ else:
552
+ yield MemoryNpzArrayReader.load(path, arr_name)
553
+ return
554
+ shape, fortran, dtype = header
555
+ if fortran or dtype.hasobject:
556
+ yield MemoryNpzArrayReader.load(path, arr_name)
557
+ else:
558
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
559
+
560
+
561
+ def _read_bytes(fp, size, error_template="ran out of data"):
562
+ """
563
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
564
+
565
+ Read from file-like object until size bytes are read.
566
+ Raises ValueError if not EOF is encountered before size bytes are read.
567
+ Non-blocking objects only supported if they derive from io objects.
568
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
569
+ requested.
570
+ """
571
+ data = bytes()
572
+ while True:
573
+ # io files (default in python3) return None or raise on
574
+ # would-block, python2 file will truncate, probably nothing can be
575
+ # done about that. note that regular files can't be non-blocking
576
+ try:
577
+ r = fp.read(size - len(data))
578
+ data += r
579
+ if len(r) == 0 or len(data) == size:
580
+ break
581
+ except io.BlockingIOError:
582
+ pass
583
+ if len(data) != size:
584
+ msg = "EOF: reading %s, expected %d bytes got %d"
585
+ raise ValueError(msg % (error_template, size, len(data)))
586
+ else:
587
+ return data
588
+
589
+
590
+ @contextmanager
591
+ def _open_npy_file(path: str, arr_name: str):
592
+ with open(path, "rb") as f:
593
+ with zipfile.ZipFile(f, "r") as zip_f:
594
+ if f"{arr_name}.npy" not in zip_f.namelist():
595
+ raise ValueError(f"missing {arr_name} in npz file")
596
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
597
+ yield arr_f
598
+
599
+
600
+ def _download_inception_model():
601
+ if os.path.exists(INCEPTION_V3_PATH):
602
+ return
603
+ print("downloading InceptionV3 model...")
604
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
605
+ r.raise_for_status()
606
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
607
+ with open(tmp_path, "wb") as f:
608
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
609
+ f.write(chunk)
610
+ os.rename(tmp_path, INCEPTION_V3_PATH)
611
+
612
+
613
+ def _create_feature_graph(input_batch):
614
+ _download_inception_model()
615
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
616
+ with open(INCEPTION_V3_PATH, "rb") as f:
617
+ graph_def = tf.GraphDef()
618
+ graph_def.ParseFromString(f.read())
619
+ pool3, spatial = tf.import_graph_def(
620
+ graph_def,
621
+ input_map={f"ExpandDims:0": input_batch},
622
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
623
+ name=prefix,
624
+ )
625
+ _update_shapes(pool3)
626
+ spatial = spatial[..., :7]
627
+ return pool3, spatial
628
+
629
+
630
+ def _create_softmax_graph(input_batch):
631
+ _download_inception_model()
632
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
633
+ with open(INCEPTION_V3_PATH, "rb") as f:
634
+ graph_def = tf.GraphDef()
635
+ graph_def.ParseFromString(f.read())
636
+ (matmul,) = tf.import_graph_def(
637
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
638
+ )
639
+ w = matmul.inputs[1]
640
+ logits = tf.matmul(input_batch, w)
641
+ return tf.nn.softmax(logits)
642
+
643
+
644
+ def _update_shapes(pool3):
645
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
646
+ ops = pool3.graph.get_operations()
647
+ for op in ops:
648
+ for o in op.outputs:
649
+ shape = o.get_shape()
650
+ if shape._dims is not None: # pylint: disable=protected-access
651
+ # shape = [s.value for s in shape] TF 1.x
652
+ shape = [s for s in shape] # TF 2.x
653
+ new_shape = []
654
+ for j, s in enumerate(shape):
655
+ if s == 1 and j == 0:
656
+ new_shape.append(None)
657
+ else:
658
+ new_shape.append(s)
659
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
660
+ return pool3
661
+
662
+
663
+ def _numpy_partition(arr, kth, **kwargs):
664
+ num_workers = min(cpu_count(), len(arr))
665
+ chunk_size = len(arr) // num_workers
666
+ extra = len(arr) % num_workers
667
+
668
+ start_idx = 0
669
+ batches = []
670
+ for i in range(num_workers):
671
+ size = chunk_size + (1 if i < extra else 0)
672
+ batches.append(arr[start_idx : start_idx + size])
673
+ start_idx += size
674
+
675
+ with ThreadPool(num_workers) as pool:
676
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
677
+
678
+
679
+ if __name__ == "__main__":
680
+ main()
evaluator_base.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ Traceback (most recent call last):
3
+ File "/gemini/space/gzy_new/models/Sida/evaluator_base.py", line 16, in <module>
4
+ import tensorflow.compat.v1 as tf
5
+ ModuleNotFoundError: No module named 'tensorflow'
evaluator_base.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ import random
5
+ import warnings
6
+ import zipfile
7
+ from abc import ABC, abstractmethod
8
+ from contextlib import contextmanager
9
+ from functools import partial
10
+ from multiprocessing import cpu_count
11
+ from multiprocessing.pool import ThreadPool
12
+ from typing import Iterable, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import requests
16
+ import tensorflow.compat.v1 as tf
17
+ from scipy import linalg
18
+ from tqdm.auto import tqdm
19
+
20
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22
+
23
+ FID_POOL_NAME = "pool_3:0"
24
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser()
29
+ #/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
30
+ parser.add_argument("--ref_batch", default='/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz',help="path to reference batch npz file")
31
+ parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512.npz', help="path to sample batch npz file")
32
+ parser.add_argument("--save_path", default='/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512/result',help="path to sample batch npz file")
33
+ parser.add_argument("--cfg_cond", default=1, type=int)
34
+ parser.add_argument("--step", default=1, type=int)
35
+ parser.add_argument("--cfg", default=1.0, type=float)
36
+ parser.add_argument("--cls_cfg", default=1.0, type=float)
37
+ parser.add_argument("--gh", default=1.0, type=float)
38
+ parser.add_argument("--num_steps", default=50, type=int)
39
+ args = parser.parse_args()
40
+
41
+ if not os.path.exists(args.save_path):
42
+ os.mkdir(args.save_path)
43
+
44
+ # NOTE: 当前环境中 TensorFlow 与 CUDA/cuDNN 可能版本不匹配(例如报 "No DNN in stream executor"),
45
+ # 这会导致 GPU 计算失败。这里强制使用 CPU 进行评估(会慢一些,但能保证运行)。
46
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
47
+
48
+ config = tf.ConfigProto(
49
+ allow_soft_placement=True, # allows DecodeJpeg to run on CPU in Inception graph
50
+ device_count={"GPU": 0},
51
+ )
52
+ evaluator = Evaluator(tf.Session(config=config))
53
+
54
+ print("warming up TensorFlow...")
55
+ # This will cause TF to print a bunch of verbose stuff now rather
56
+ # than after the next print(), to help prevent confusion.
57
+ evaluator.warmup()
58
+
59
+ print("computing reference batch activations...")
60
+ ref_acts = evaluator.read_activations(args.ref_batch)
61
+ print("computing/reading reference batch statistics...")
62
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
63
+
64
+ print("computing sample batch activations...")
65
+ sample_acts = evaluator.read_activations(args.sample_batch)
66
+ print("computing/reading sample batch statistics...")
67
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
68
+
69
+ print("Computing evaluations...")
70
+ Inception_Score = evaluator.compute_inception_score(sample_acts[0])
71
+ FID = sample_stats.frechet_distance(ref_stats)
72
+ sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
73
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
74
+
75
+ print("Inception Score:", Inception_Score)
76
+ print("FID:", FID)
77
+ print("sFID:", sFID)
78
+ print("Precision:", prec)
79
+ print("Recall:", recall)
80
+
81
+ if args.cfg_cond:
82
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
83
+ else:
84
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
85
+ with open(file_path, "w") as file:
86
+ file.write("Inception Score: {}\n".format(Inception_Score))
87
+ file.write("FID: {}\n".format(FID))
88
+ file.write("sFID: {}\n".format(sFID))
89
+ file.write("Precision: {}\n".format(prec))
90
+ file.write("Recall: {}\n".format(recall))
91
+
92
+
93
+ class InvalidFIDException(Exception):
94
+ pass
95
+
96
+
97
+ class FIDStatistics:
98
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
99
+ self.mu = mu
100
+ self.sigma = sigma
101
+
102
+ def frechet_distance(self, other, eps=1e-6):
103
+ """
104
+ Compute the Frechet distance between two sets of statistics.
105
+ """
106
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
107
+ mu1, sigma1 = self.mu, self.sigma
108
+ mu2, sigma2 = other.mu, other.sigma
109
+
110
+ mu1 = np.atleast_1d(mu1)
111
+ mu2 = np.atleast_1d(mu2)
112
+
113
+ sigma1 = np.atleast_2d(sigma1)
114
+ sigma2 = np.atleast_2d(sigma2)
115
+
116
+ assert (
117
+ mu1.shape == mu2.shape
118
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
119
+ assert (
120
+ sigma1.shape == sigma2.shape
121
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
122
+
123
+ diff = mu1 - mu2
124
+
125
+ # product might be almost singular
126
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
127
+ if not np.isfinite(covmean).all():
128
+ msg = (
129
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
130
+ % eps
131
+ )
132
+ warnings.warn(msg)
133
+ offset = np.eye(sigma1.shape[0]) * eps
134
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135
+
136
+ # numerical error might give slight imaginary component
137
+ if np.iscomplexobj(covmean):
138
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139
+ m = np.max(np.abs(covmean.imag))
140
+ raise ValueError("Imaginary component {}".format(m))
141
+ covmean = covmean.real
142
+
143
+ tr_covmean = np.trace(covmean)
144
+
145
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
146
+
147
+
148
+ class Evaluator:
149
+ def __init__(
150
+ self,
151
+ session,
152
+ batch_size=64,
153
+ softmax_batch_size=512,
154
+ ):
155
+ self.sess = session
156
+ self.batch_size = batch_size
157
+ self.softmax_batch_size = softmax_batch_size
158
+ self.manifold_estimator = ManifoldEstimator(session)
159
+ with self.sess.graph.as_default():
160
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
161
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
162
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
163
+ self.softmax = _create_softmax_graph(self.softmax_input)
164
+
165
+ def warmup(self):
166
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
167
+
168
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
169
+ with open_npz_array(npz_path, "arr_0") as reader:
170
+ return self.compute_activations(reader.read_batches(self.batch_size))
171
+
172
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
173
+ """
174
+ Compute image features for downstream evals.
175
+
176
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
177
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
178
+ dimension. The tuple is (pool_3, spatial).
179
+ """
180
+ preds = []
181
+ spatial_preds = []
182
+ for batch in tqdm(batches):
183
+ batch = batch.astype(np.float32)
184
+ pred, spatial_pred = self.sess.run(
185
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
186
+ )
187
+ preds.append(pred.reshape([pred.shape[0], -1]))
188
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
189
+ return (
190
+ np.concatenate(preds, axis=0),
191
+ np.concatenate(spatial_preds, axis=0),
192
+ )
193
+
194
+ def read_statistics(
195
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
196
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
197
+ obj = np.load(npz_path)
198
+ if "mu" in list(obj.keys()):
199
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
200
+ obj["mu_s"], obj["sigma_s"]
201
+ )
202
+ return tuple(self.compute_statistics(x) for x in activations)
203
+
204
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
205
+ mu = np.mean(activations, axis=0)
206
+ sigma = np.cov(activations, rowvar=False)
207
+ return FIDStatistics(mu, sigma)
208
+
209
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
210
+ softmax_out = []
211
+ for i in range(0, len(activations), self.softmax_batch_size):
212
+ acts = activations[i : i + self.softmax_batch_size]
213
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
214
+ preds = np.concatenate(softmax_out, axis=0)
215
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
216
+ scores = []
217
+ for i in range(0, len(preds), split_size):
218
+ part = preds[i : i + split_size]
219
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
220
+ kl = np.mean(np.sum(kl, 1))
221
+ scores.append(np.exp(kl))
222
+ return float(np.mean(scores))
223
+
224
+ def compute_prec_recall(
225
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
226
+ ) -> Tuple[float, float]:
227
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
228
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
229
+ pr = self.manifold_estimator.evaluate_pr(
230
+ activations_ref, radii_1, activations_sample, radii_2
231
+ )
232
+ return (float(pr[0][0]), float(pr[1][0]))
233
+
234
+
235
+ class ManifoldEstimator:
236
+ """
237
+ A helper for comparing manifolds of feature vectors.
238
+
239
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ session,
245
+ row_batch_size=10000,
246
+ col_batch_size=10000,
247
+ nhood_sizes=(3,),
248
+ clamp_to_percentile=None,
249
+ eps=1e-5,
250
+ ):
251
+ """
252
+ Estimate the manifold of given feature vectors.
253
+
254
+ :param session: the TensorFlow session.
255
+ :param row_batch_size: row batch size to compute pairwise distances
256
+ (parameter to trade-off between memory usage and performance).
257
+ :param col_batch_size: column batch size to compute pairwise distances.
258
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
259
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
260
+ the given percentile.
261
+ :param eps: small number for numerical stability.
262
+ """
263
+ self.distance_block = DistanceBlock(session)
264
+ self.row_batch_size = row_batch_size
265
+ self.col_batch_size = col_batch_size
266
+ self.nhood_sizes = nhood_sizes
267
+ self.num_nhoods = len(nhood_sizes)
268
+ self.clamp_to_percentile = clamp_to_percentile
269
+ self.eps = eps
270
+
271
+ def warmup(self):
272
+ feats, radii = (
273
+ np.zeros([1, 2048], dtype=np.float32),
274
+ np.zeros([1, 1], dtype=np.float32),
275
+ )
276
+ self.evaluate_pr(feats, radii, feats, radii)
277
+
278
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
279
+ num_images = len(features)
280
+
281
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
282
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
283
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
284
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
285
+
286
+ for begin1 in range(0, num_images, self.row_batch_size):
287
+ end1 = min(begin1 + self.row_batch_size, num_images)
288
+ row_batch = features[begin1:end1]
289
+
290
+ for begin2 in range(0, num_images, self.col_batch_size):
291
+ end2 = min(begin2 + self.col_batch_size, num_images)
292
+ col_batch = features[begin2:end2]
293
+
294
+ # Compute distances between batches.
295
+ distance_batch[
296
+ 0 : end1 - begin1, begin2:end2
297
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
298
+
299
+ # Find the k-nearest neighbor from the current batch.
300
+ radii[begin1:end1, :] = np.concatenate(
301
+ [
302
+ x[:, self.nhood_sizes]
303
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
304
+ ],
305
+ axis=0,
306
+ )
307
+
308
+ if self.clamp_to_percentile is not None:
309
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
310
+ radii[radii > max_distances] = 0
311
+ return radii
312
+
313
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
314
+ """
315
+ Evaluate if new feature vectors are at the manifold.
316
+ """
317
+ num_eval_images = eval_features.shape[0]
318
+ num_ref_images = radii.shape[0]
319
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
320
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
321
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
322
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
323
+
324
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
325
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
326
+ feature_batch = eval_features[begin1:end1]
327
+
328
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
329
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
330
+ ref_batch = features[begin2:end2]
331
+
332
+ distance_batch[
333
+ 0 : end1 - begin1, begin2:end2
334
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
335
+
336
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
337
+ # If a feature vector is inside a hypersphere of some reference sample, then
338
+ # the new sample lies at the estimated manifold.
339
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
340
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
341
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
342
+
343
+ max_realism_score[begin1:end1] = np.max(
344
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
345
+ )
346
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
347
+
348
+ return {
349
+ "fraction": float(np.mean(batch_predictions)),
350
+ "batch_predictions": batch_predictions,
351
+ "max_realisim_score": max_realism_score,
352
+ "nearest_indices": nearest_indices,
353
+ }
354
+
355
+ def evaluate_pr(
356
+ self,
357
+ features_1: np.ndarray,
358
+ radii_1: np.ndarray,
359
+ features_2: np.ndarray,
360
+ radii_2: np.ndarray,
361
+ ) -> Tuple[np.ndarray, np.ndarray]:
362
+ """
363
+ Evaluate precision and recall efficiently.
364
+
365
+ :param features_1: [N1 x D] feature vectors for reference batch.
366
+ :param radii_1: [N1 x K1] radii for reference vectors.
367
+ :param features_2: [N2 x D] feature vectors for the other batch.
368
+ :param radii_2: [N x K2] radii for other vectors.
369
+ :return: a tuple of arrays for (precision, recall):
370
+ - precision: an np.ndarray of length K1
371
+ - recall: an np.ndarray of length K2
372
+ """
373
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
374
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
375
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
376
+ end_1 = begin_1 + self.row_batch_size
377
+ batch_1 = features_1[begin_1:end_1]
378
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
379
+ end_2 = begin_2 + self.col_batch_size
380
+ batch_2 = features_2[begin_2:end_2]
381
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
382
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
383
+ )
384
+ features_1_status[begin_1:end_1] |= batch_1_in
385
+ features_2_status[begin_2:end_2] |= batch_2_in
386
+ return (
387
+ np.mean(features_2_status.astype(np.float64), axis=0),
388
+ np.mean(features_1_status.astype(np.float64), axis=0),
389
+ )
390
+
391
+
392
+ class DistanceBlock:
393
+ """
394
+ Calculate pairwise distances between vectors.
395
+
396
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
397
+ """
398
+
399
+ def __init__(self, session):
400
+ self.session = session
401
+
402
+ # Initialize TF graph to calculate pairwise distances.
403
+ with session.graph.as_default():
404
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
405
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
406
+ distance_block_16 = _batch_pairwise_distances(
407
+ tf.cast(self._features_batch1, tf.float16),
408
+ tf.cast(self._features_batch2, tf.float16),
409
+ )
410
+ self.distance_block = tf.cond(
411
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
412
+ lambda: tf.cast(distance_block_16, tf.float32),
413
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
414
+ )
415
+
416
+ # Extra logic for less thans.
417
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
418
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
419
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
420
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
421
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
422
+
423
+ def pairwise_distances(self, U, V):
424
+ """
425
+ Evaluate pairwise distances between two batches of feature vectors.
426
+ """
427
+ return self.session.run(
428
+ self.distance_block,
429
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
430
+ )
431
+
432
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
433
+ return self.session.run(
434
+ [self._batch_1_in, self._batch_2_in],
435
+ feed_dict={
436
+ self._features_batch1: batch_1,
437
+ self._features_batch2: batch_2,
438
+ self._radii1: radii_1,
439
+ self._radii2: radii_2,
440
+ },
441
+ )
442
+
443
+
444
+ def _batch_pairwise_distances(U, V):
445
+ """
446
+ Compute pairwise distances between two batches of feature vectors.
447
+ """
448
+ with tf.variable_scope("pairwise_dist_block"):
449
+ # Squared norms of each row in U and V.
450
+ norm_u = tf.reduce_sum(tf.square(U), 1)
451
+ norm_v = tf.reduce_sum(tf.square(V), 1)
452
+
453
+ # norm_u as a column and norm_v as a row vectors.
454
+ norm_u = tf.reshape(norm_u, [-1, 1])
455
+ norm_v = tf.reshape(norm_v, [1, -1])
456
+
457
+ # Pairwise squared Euclidean distances.
458
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
459
+
460
+ return D
461
+
462
+
463
+ class NpzArrayReader(ABC):
464
+ @abstractmethod
465
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
466
+ pass
467
+
468
+ @abstractmethod
469
+ def remaining(self) -> int:
470
+ pass
471
+
472
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
473
+ def gen_fn():
474
+ while True:
475
+ batch = self.read_batch(batch_size)
476
+ if batch is None:
477
+ break
478
+ yield batch
479
+
480
+ rem = self.remaining()
481
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
482
+ return BatchIterator(gen_fn, num_batches)
483
+
484
+
485
+ class BatchIterator:
486
+ def __init__(self, gen_fn, length):
487
+ self.gen_fn = gen_fn
488
+ self.length = length
489
+
490
+ def __len__(self):
491
+ return self.length
492
+
493
+ def __iter__(self):
494
+ return self.gen_fn()
495
+
496
+
497
+ class StreamingNpzArrayReader(NpzArrayReader):
498
+ def __init__(self, arr_f, shape, dtype):
499
+ self.arr_f = arr_f
500
+ self.shape = shape
501
+ self.dtype = dtype
502
+ self.idx = 0
503
+
504
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
505
+ if self.idx >= self.shape[0]:
506
+ return None
507
+
508
+ bs = min(batch_size, self.shape[0] - self.idx)
509
+ self.idx += bs
510
+
511
+ if self.dtype.itemsize == 0:
512
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
513
+
514
+ read_count = bs * np.prod(self.shape[1:])
515
+ read_size = int(read_count * self.dtype.itemsize)
516
+ data = _read_bytes(self.arr_f, read_size, "array data")
517
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
518
+
519
+ def remaining(self) -> int:
520
+ return max(0, self.shape[0] - self.idx)
521
+
522
+
523
+ class MemoryNpzArrayReader(NpzArrayReader):
524
+ def __init__(self, arr):
525
+ self.arr = arr
526
+ self.idx = 0
527
+
528
+ @classmethod
529
+ def load(cls, path: str, arr_name: str):
530
+ with open(path, "rb") as f:
531
+ arr = np.load(f)[arr_name]
532
+ return cls(arr)
533
+
534
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
535
+ if self.idx >= self.arr.shape[0]:
536
+ return None
537
+
538
+ res = self.arr[self.idx : self.idx + batch_size]
539
+ self.idx += batch_size
540
+ return res
541
+
542
+ def remaining(self) -> int:
543
+ return max(0, self.arr.shape[0] - self.idx)
544
+
545
+
546
+ @contextmanager
547
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
548
+ with _open_npy_file(path, arr_name) as arr_f:
549
+ version = np.lib.format.read_magic(arr_f)
550
+ if version == (1, 0):
551
+ header = np.lib.format.read_array_header_1_0(arr_f)
552
+ elif version == (2, 0):
553
+ header = np.lib.format.read_array_header_2_0(arr_f)
554
+ else:
555
+ yield MemoryNpzArrayReader.load(path, arr_name)
556
+ return
557
+ shape, fortran, dtype = header
558
+ if fortran or dtype.hasobject:
559
+ yield MemoryNpzArrayReader.load(path, arr_name)
560
+ else:
561
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
562
+
563
+
564
+ def _read_bytes(fp, size, error_template="ran out of data"):
565
+ """
566
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
567
+
568
+ Read from file-like object until size bytes are read.
569
+ Raises ValueError if not EOF is encountered before size bytes are read.
570
+ Non-blocking objects only supported if they derive from io objects.
571
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
572
+ requested.
573
+ """
574
+ data = bytes()
575
+ while True:
576
+ # io files (default in python3) return None or raise on
577
+ # would-block, python2 file will truncate, probably nothing can be
578
+ # done about that. note that regular files can't be non-blocking
579
+ try:
580
+ r = fp.read(size - len(data))
581
+ data += r
582
+ if len(r) == 0 or len(data) == size:
583
+ break
584
+ except io.BlockingIOError:
585
+ pass
586
+ if len(data) != size:
587
+ msg = "EOF: reading %s, expected %d bytes got %d"
588
+ raise ValueError(msg % (error_template, size, len(data)))
589
+ else:
590
+ return data
591
+
592
+
593
+ @contextmanager
594
+ def _open_npy_file(path: str, arr_name: str):
595
+ with open(path, "rb") as f:
596
+ with zipfile.ZipFile(f, "r") as zip_f:
597
+ if f"{arr_name}.npy" not in zip_f.namelist():
598
+ raise ValueError(f"missing {arr_name} in npz file")
599
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
600
+ yield arr_f
601
+
602
+
603
+ def _download_inception_model():
604
+ if os.path.exists(INCEPTION_V3_PATH):
605
+ return
606
+ print("downloading InceptionV3 model...")
607
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
608
+ r.raise_for_status()
609
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
610
+ with open(tmp_path, "wb") as f:
611
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
612
+ f.write(chunk)
613
+ os.rename(tmp_path, INCEPTION_V3_PATH)
614
+
615
+
616
+ def _create_feature_graph(input_batch):
617
+ _download_inception_model()
618
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
619
+ with open(INCEPTION_V3_PATH, "rb") as f:
620
+ graph_def = tf.GraphDef()
621
+ graph_def.ParseFromString(f.read())
622
+ pool3, spatial = tf.import_graph_def(
623
+ graph_def,
624
+ input_map={f"ExpandDims:0": input_batch},
625
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
626
+ name=prefix,
627
+ )
628
+ _update_shapes(pool3)
629
+ spatial = spatial[..., :7]
630
+ return pool3, spatial
631
+
632
+
633
+ def _create_softmax_graph(input_batch):
634
+ _download_inception_model()
635
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
636
+ with open(INCEPTION_V3_PATH, "rb") as f:
637
+ graph_def = tf.GraphDef()
638
+ graph_def.ParseFromString(f.read())
639
+ (matmul,) = tf.import_graph_def(
640
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
641
+ )
642
+ w = matmul.inputs[1]
643
+ logits = tf.matmul(input_batch, w)
644
+ return tf.nn.softmax(logits)
645
+
646
+
647
+ def _update_shapes(pool3):
648
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
649
+ ops = pool3.graph.get_operations()
650
+ for op in ops:
651
+ for o in op.outputs:
652
+ shape = o.get_shape()
653
+ if shape._dims is not None: # pylint: disable=protected-access
654
+ # shape = [s.value for s in shape] TF 1.x
655
+ shape = [s for s in shape] # TF 2.x
656
+ new_shape = []
657
+ for j, s in enumerate(shape):
658
+ if s == 1 and j == 0:
659
+ new_shape.append(None)
660
+ else:
661
+ new_shape.append(s)
662
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
663
+ return pool3
664
+
665
+
666
+ def _numpy_partition(arr, kth, **kwargs):
667
+ num_workers = min(cpu_count(), len(arr))
668
+ chunk_size = len(arr) // num_workers
669
+ extra = len(arr) % num_workers
670
+
671
+ start_idx = 0
672
+ batches = []
673
+ for i in range(num_workers):
674
+ size = chunk_size + (1 if i < extra else 0)
675
+ batches.append(arr[start_idx : start_idx + size])
676
+ start_idx += size
677
+
678
+ with ThreadPool(num_workers) as pool:
679
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
680
+
681
+
682
+ if __name__ == "__main__":
683
+ main()
684
+
685
+ # nohup python evaluator_base.py > evaluator_base.log 2>&1 &
evaluator_rf.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ import random
5
+ import warnings
6
+ import zipfile
7
+ from abc import ABC, abstractmethod
8
+ from contextlib import contextmanager
9
+ from functools import partial
10
+ from multiprocessing import cpu_count
11
+ from multiprocessing.pool import ThreadPool
12
+ from typing import Iterable, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import requests
16
+ import tensorflow.compat.v1 as tf
17
+ from scipy import linalg
18
+ from tqdm.auto import tqdm
19
+
20
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22
+
23
+ FID_POOL_NAME = "pool_3:0"
24
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser()
29
+ #/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
30
+ parser.add_argument("--ref_batch", default='/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz',help="path to reference batch npz file")
31
+ parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_batch2_220000.npz', help="path to sample batch npz file")
32
+ parser.add_argument("--save_path", default='/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_batch2_220000',help="path to sample batch npz file")
33
+ parser.add_argument("--cfg_cond", default=1, type=int)
34
+ parser.add_argument("--step", default=1, type=int)
35
+ parser.add_argument("--cfg", default=1.0, type=float)
36
+ parser.add_argument("--cls_cfg", default=1.0, type=float)
37
+ parser.add_argument("--gh", default=1.0, type=float)
38
+ parser.add_argument("--num_steps", default=50, type=int)
39
+ args = parser.parse_args()
40
+
41
+ if not os.path.exists(args.save_path):
42
+ os.mkdir(args.save_path)
43
+
44
+ # NOTE: 当前环境中 TensorFlow 与 CUDA/cuDNN 可能版本不匹配(例如报 "No DNN in stream executor"),
45
+ # 这会导致 GPU 计算失败。这里强制使用 CPU 进行评估(会慢一些,但能保证运行)。
46
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
47
+
48
+ config = tf.ConfigProto(
49
+ allow_soft_placement=True, # allows DecodeJpeg to run on CPU in Inception graph
50
+ device_count={"GPU": 0},
51
+ )
52
+ evaluator = Evaluator(tf.Session(config=config))
53
+
54
+ print("warming up TensorFlow...")
55
+ # This will cause TF to print a bunch of verbose stuff now rather
56
+ # than after the next print(), to help prevent confusion.
57
+ evaluator.warmup()
58
+
59
+ print("computing reference batch activations...")
60
+ ref_acts = evaluator.read_activations(args.ref_batch)
61
+ print("computing/reading reference batch statistics...")
62
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
63
+
64
+ print("computing sample batch activations...")
65
+ sample_acts = evaluator.read_activations(args.sample_batch)
66
+ print("computing/reading sample batch statistics...")
67
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
68
+
69
+ print("Computing evaluations...")
70
+ Inception_Score = evaluator.compute_inception_score(sample_acts[0])
71
+ FID = sample_stats.frechet_distance(ref_stats)
72
+ sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
73
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
74
+
75
+ print("Inception Score:", Inception_Score)
76
+ print("FID:", FID)
77
+ print("sFID:", sFID)
78
+ print("Precision:", prec)
79
+ print("Recall:", recall)
80
+
81
+ if args.cfg_cond:
82
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
83
+ else:
84
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
85
+ with open(file_path, "w") as file:
86
+ file.write("Inception Score: {}\n".format(Inception_Score))
87
+ file.write("FID: {}\n".format(FID))
88
+ file.write("sFID: {}\n".format(sFID))
89
+ file.write("Precision: {}\n".format(prec))
90
+ file.write("Recall: {}\n".format(recall))
91
+
92
+
93
+ class InvalidFIDException(Exception):
94
+ pass
95
+
96
+
97
+ class FIDStatistics:
98
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
99
+ self.mu = mu
100
+ self.sigma = sigma
101
+
102
+ def frechet_distance(self, other, eps=1e-6):
103
+ """
104
+ Compute the Frechet distance between two sets of statistics.
105
+ """
106
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
107
+ mu1, sigma1 = self.mu, self.sigma
108
+ mu2, sigma2 = other.mu, other.sigma
109
+
110
+ mu1 = np.atleast_1d(mu1)
111
+ mu2 = np.atleast_1d(mu2)
112
+
113
+ sigma1 = np.atleast_2d(sigma1)
114
+ sigma2 = np.atleast_2d(sigma2)
115
+
116
+ assert (
117
+ mu1.shape == mu2.shape
118
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
119
+ assert (
120
+ sigma1.shape == sigma2.shape
121
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
122
+
123
+ diff = mu1 - mu2
124
+
125
+ # product might be almost singular
126
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
127
+ if not np.isfinite(covmean).all():
128
+ msg = (
129
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
130
+ % eps
131
+ )
132
+ warnings.warn(msg)
133
+ offset = np.eye(sigma1.shape[0]) * eps
134
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135
+
136
+ # numerical error might give slight imaginary component
137
+ if np.iscomplexobj(covmean):
138
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139
+ m = np.max(np.abs(covmean.imag))
140
+ raise ValueError("Imaginary component {}".format(m))
141
+ covmean = covmean.real
142
+
143
+ tr_covmean = np.trace(covmean)
144
+
145
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
146
+
147
+
148
+ class Evaluator:
149
+ def __init__(
150
+ self,
151
+ session,
152
+ batch_size=64,
153
+ softmax_batch_size=512,
154
+ ):
155
+ self.sess = session
156
+ self.batch_size = batch_size
157
+ self.softmax_batch_size = softmax_batch_size
158
+ self.manifold_estimator = ManifoldEstimator(session)
159
+ with self.sess.graph.as_default():
160
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
161
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
162
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
163
+ self.softmax = _create_softmax_graph(self.softmax_input)
164
+
165
+ def warmup(self):
166
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
167
+
168
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
169
+ with open_npz_array(npz_path, "arr_0") as reader:
170
+ return self.compute_activations(reader.read_batches(self.batch_size))
171
+
172
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
173
+ """
174
+ Compute image features for downstream evals.
175
+
176
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
177
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
178
+ dimension. The tuple is (pool_3, spatial).
179
+ """
180
+ preds = []
181
+ spatial_preds = []
182
+ for batch in tqdm(batches):
183
+ batch = batch.astype(np.float32)
184
+ pred, spatial_pred = self.sess.run(
185
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
186
+ )
187
+ preds.append(pred.reshape([pred.shape[0], -1]))
188
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
189
+ return (
190
+ np.concatenate(preds, axis=0),
191
+ np.concatenate(spatial_preds, axis=0),
192
+ )
193
+
194
+ def read_statistics(
195
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
196
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
197
+ obj = np.load(npz_path)
198
+ if "mu" in list(obj.keys()):
199
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
200
+ obj["mu_s"], obj["sigma_s"]
201
+ )
202
+ return tuple(self.compute_statistics(x) for x in activations)
203
+
204
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
205
+ mu = np.mean(activations, axis=0)
206
+ sigma = np.cov(activations, rowvar=False)
207
+ return FIDStatistics(mu, sigma)
208
+
209
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
210
+ softmax_out = []
211
+ for i in range(0, len(activations), self.softmax_batch_size):
212
+ acts = activations[i : i + self.softmax_batch_size]
213
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
214
+ preds = np.concatenate(softmax_out, axis=0)
215
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
216
+ scores = []
217
+ for i in range(0, len(preds), split_size):
218
+ part = preds[i : i + split_size]
219
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
220
+ kl = np.mean(np.sum(kl, 1))
221
+ scores.append(np.exp(kl))
222
+ return float(np.mean(scores))
223
+
224
+ def compute_prec_recall(
225
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
226
+ ) -> Tuple[float, float]:
227
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
228
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
229
+ pr = self.manifold_estimator.evaluate_pr(
230
+ activations_ref, radii_1, activations_sample, radii_2
231
+ )
232
+ return (float(pr[0][0]), float(pr[1][0]))
233
+
234
+
235
+ class ManifoldEstimator:
236
+ """
237
+ A helper for comparing manifolds of feature vectors.
238
+
239
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ session,
245
+ row_batch_size=10000,
246
+ col_batch_size=10000,
247
+ nhood_sizes=(3,),
248
+ clamp_to_percentile=None,
249
+ eps=1e-5,
250
+ ):
251
+ """
252
+ Estimate the manifold of given feature vectors.
253
+
254
+ :param session: the TensorFlow session.
255
+ :param row_batch_size: row batch size to compute pairwise distances
256
+ (parameter to trade-off between memory usage and performance).
257
+ :param col_batch_size: column batch size to compute pairwise distances.
258
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
259
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
260
+ the given percentile.
261
+ :param eps: small number for numerical stability.
262
+ """
263
+ self.distance_block = DistanceBlock(session)
264
+ self.row_batch_size = row_batch_size
265
+ self.col_batch_size = col_batch_size
266
+ self.nhood_sizes = nhood_sizes
267
+ self.num_nhoods = len(nhood_sizes)
268
+ self.clamp_to_percentile = clamp_to_percentile
269
+ self.eps = eps
270
+
271
+ def warmup(self):
272
+ feats, radii = (
273
+ np.zeros([1, 2048], dtype=np.float32),
274
+ np.zeros([1, 1], dtype=np.float32),
275
+ )
276
+ self.evaluate_pr(feats, radii, feats, radii)
277
+
278
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
279
+ num_images = len(features)
280
+
281
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
282
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
283
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
284
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
285
+
286
+ for begin1 in range(0, num_images, self.row_batch_size):
287
+ end1 = min(begin1 + self.row_batch_size, num_images)
288
+ row_batch = features[begin1:end1]
289
+
290
+ for begin2 in range(0, num_images, self.col_batch_size):
291
+ end2 = min(begin2 + self.col_batch_size, num_images)
292
+ col_batch = features[begin2:end2]
293
+
294
+ # Compute distances between batches.
295
+ distance_batch[
296
+ 0 : end1 - begin1, begin2:end2
297
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
298
+
299
+ # Find the k-nearest neighbor from the current batch.
300
+ radii[begin1:end1, :] = np.concatenate(
301
+ [
302
+ x[:, self.nhood_sizes]
303
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
304
+ ],
305
+ axis=0,
306
+ )
307
+
308
+ if self.clamp_to_percentile is not None:
309
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
310
+ radii[radii > max_distances] = 0
311
+ return radii
312
+
313
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
314
+ """
315
+ Evaluate if new feature vectors are at the manifold.
316
+ """
317
+ num_eval_images = eval_features.shape[0]
318
+ num_ref_images = radii.shape[0]
319
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
320
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
321
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
322
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
323
+
324
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
325
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
326
+ feature_batch = eval_features[begin1:end1]
327
+
328
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
329
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
330
+ ref_batch = features[begin2:end2]
331
+
332
+ distance_batch[
333
+ 0 : end1 - begin1, begin2:end2
334
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
335
+
336
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
337
+ # If a feature vector is inside a hypersphere of some reference sample, then
338
+ # the new sample lies at the estimated manifold.
339
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
340
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
341
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
342
+
343
+ max_realism_score[begin1:end1] = np.max(
344
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
345
+ )
346
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
347
+
348
+ return {
349
+ "fraction": float(np.mean(batch_predictions)),
350
+ "batch_predictions": batch_predictions,
351
+ "max_realisim_score": max_realism_score,
352
+ "nearest_indices": nearest_indices,
353
+ }
354
+
355
+ def evaluate_pr(
356
+ self,
357
+ features_1: np.ndarray,
358
+ radii_1: np.ndarray,
359
+ features_2: np.ndarray,
360
+ radii_2: np.ndarray,
361
+ ) -> Tuple[np.ndarray, np.ndarray]:
362
+ """
363
+ Evaluate precision and recall efficiently.
364
+
365
+ :param features_1: [N1 x D] feature vectors for reference batch.
366
+ :param radii_1: [N1 x K1] radii for reference vectors.
367
+ :param features_2: [N2 x D] feature vectors for the other batch.
368
+ :param radii_2: [N x K2] radii for other vectors.
369
+ :return: a tuple of arrays for (precision, recall):
370
+ - precision: an np.ndarray of length K1
371
+ - recall: an np.ndarray of length K2
372
+ """
373
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
374
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
375
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
376
+ end_1 = begin_1 + self.row_batch_size
377
+ batch_1 = features_1[begin_1:end_1]
378
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
379
+ end_2 = begin_2 + self.col_batch_size
380
+ batch_2 = features_2[begin_2:end_2]
381
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
382
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
383
+ )
384
+ features_1_status[begin_1:end_1] |= batch_1_in
385
+ features_2_status[begin_2:end_2] |= batch_2_in
386
+ return (
387
+ np.mean(features_2_status.astype(np.float64), axis=0),
388
+ np.mean(features_1_status.astype(np.float64), axis=0),
389
+ )
390
+
391
+
392
+ class DistanceBlock:
393
+ """
394
+ Calculate pairwise distances between vectors.
395
+
396
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
397
+ """
398
+
399
+ def __init__(self, session):
400
+ self.session = session
401
+
402
+ # Initialize TF graph to calculate pairwise distances.
403
+ with session.graph.as_default():
404
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
405
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
406
+ distance_block_16 = _batch_pairwise_distances(
407
+ tf.cast(self._features_batch1, tf.float16),
408
+ tf.cast(self._features_batch2, tf.float16),
409
+ )
410
+ self.distance_block = tf.cond(
411
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
412
+ lambda: tf.cast(distance_block_16, tf.float32),
413
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
414
+ )
415
+
416
+ # Extra logic for less thans.
417
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
418
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
419
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
420
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
421
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
422
+
423
+ def pairwise_distances(self, U, V):
424
+ """
425
+ Evaluate pairwise distances between two batches of feature vectors.
426
+ """
427
+ return self.session.run(
428
+ self.distance_block,
429
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
430
+ )
431
+
432
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
433
+ return self.session.run(
434
+ [self._batch_1_in, self._batch_2_in],
435
+ feed_dict={
436
+ self._features_batch1: batch_1,
437
+ self._features_batch2: batch_2,
438
+ self._radii1: radii_1,
439
+ self._radii2: radii_2,
440
+ },
441
+ )
442
+
443
+
444
+ def _batch_pairwise_distances(U, V):
445
+ """
446
+ Compute pairwise distances between two batches of feature vectors.
447
+ """
448
+ with tf.variable_scope("pairwise_dist_block"):
449
+ # Squared norms of each row in U and V.
450
+ norm_u = tf.reduce_sum(tf.square(U), 1)
451
+ norm_v = tf.reduce_sum(tf.square(V), 1)
452
+
453
+ # norm_u as a column and norm_v as a row vectors.
454
+ norm_u = tf.reshape(norm_u, [-1, 1])
455
+ norm_v = tf.reshape(norm_v, [1, -1])
456
+
457
+ # Pairwise squared Euclidean distances.
458
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
459
+
460
+ return D
461
+
462
+
463
+ class NpzArrayReader(ABC):
464
+ @abstractmethod
465
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
466
+ pass
467
+
468
+ @abstractmethod
469
+ def remaining(self) -> int:
470
+ pass
471
+
472
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
473
+ def gen_fn():
474
+ while True:
475
+ batch = self.read_batch(batch_size)
476
+ if batch is None:
477
+ break
478
+ yield batch
479
+
480
+ rem = self.remaining()
481
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
482
+ return BatchIterator(gen_fn, num_batches)
483
+
484
+
485
+ class BatchIterator:
486
+ def __init__(self, gen_fn, length):
487
+ self.gen_fn = gen_fn
488
+ self.length = length
489
+
490
+ def __len__(self):
491
+ return self.length
492
+
493
+ def __iter__(self):
494
+ return self.gen_fn()
495
+
496
+
497
+ class StreamingNpzArrayReader(NpzArrayReader):
498
+ def __init__(self, arr_f, shape, dtype):
499
+ self.arr_f = arr_f
500
+ self.shape = shape
501
+ self.dtype = dtype
502
+ self.idx = 0
503
+
504
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
505
+ if self.idx >= self.shape[0]:
506
+ return None
507
+
508
+ bs = min(batch_size, self.shape[0] - self.idx)
509
+ self.idx += bs
510
+
511
+ if self.dtype.itemsize == 0:
512
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
513
+
514
+ read_count = bs * np.prod(self.shape[1:])
515
+ read_size = int(read_count * self.dtype.itemsize)
516
+ data = _read_bytes(self.arr_f, read_size, "array data")
517
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
518
+
519
+ def remaining(self) -> int:
520
+ return max(0, self.shape[0] - self.idx)
521
+
522
+
523
+ class MemoryNpzArrayReader(NpzArrayReader):
524
+ def __init__(self, arr):
525
+ self.arr = arr
526
+ self.idx = 0
527
+
528
+ @classmethod
529
+ def load(cls, path: str, arr_name: str):
530
+ with open(path, "rb") as f:
531
+ arr = np.load(f)[arr_name]
532
+ return cls(arr)
533
+
534
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
535
+ if self.idx >= self.arr.shape[0]:
536
+ return None
537
+
538
+ res = self.arr[self.idx : self.idx + batch_size]
539
+ self.idx += batch_size
540
+ return res
541
+
542
+ def remaining(self) -> int:
543
+ return max(0, self.arr.shape[0] - self.idx)
544
+
545
+
546
+ @contextmanager
547
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
548
+ with _open_npy_file(path, arr_name) as arr_f:
549
+ version = np.lib.format.read_magic(arr_f)
550
+ if version == (1, 0):
551
+ header = np.lib.format.read_array_header_1_0(arr_f)
552
+ elif version == (2, 0):
553
+ header = np.lib.format.read_array_header_2_0(arr_f)
554
+ else:
555
+ yield MemoryNpzArrayReader.load(path, arr_name)
556
+ return
557
+ shape, fortran, dtype = header
558
+ if fortran or dtype.hasobject:
559
+ yield MemoryNpzArrayReader.load(path, arr_name)
560
+ else:
561
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
562
+
563
+
564
+ def _read_bytes(fp, size, error_template="ran out of data"):
565
+ """
566
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
567
+
568
+ Read from file-like object until size bytes are read.
569
+ Raises ValueError if not EOF is encountered before size bytes are read.
570
+ Non-blocking objects only supported if they derive from io objects.
571
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
572
+ requested.
573
+ """
574
+ data = bytes()
575
+ while True:
576
+ # io files (default in python3) return None or raise on
577
+ # would-block, python2 file will truncate, probably nothing can be
578
+ # done about that. note that regular files can't be non-blocking
579
+ try:
580
+ r = fp.read(size - len(data))
581
+ data += r
582
+ if len(r) == 0 or len(data) == size:
583
+ break
584
+ except io.BlockingIOError:
585
+ pass
586
+ if len(data) != size:
587
+ msg = "EOF: reading %s, expected %d bytes got %d"
588
+ raise ValueError(msg % (error_template, size, len(data)))
589
+ else:
590
+ return data
591
+
592
+
593
+ @contextmanager
594
+ def _open_npy_file(path: str, arr_name: str):
595
+ with open(path, "rb") as f:
596
+ with zipfile.ZipFile(f, "r") as zip_f:
597
+ if f"{arr_name}.npy" not in zip_f.namelist():
598
+ raise ValueError(f"missing {arr_name} in npz file")
599
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
600
+ yield arr_f
601
+
602
+
603
+ def _download_inception_model():
604
+ if os.path.exists(INCEPTION_V3_PATH):
605
+ return
606
+ print("downloading InceptionV3 model...")
607
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
608
+ r.raise_for_status()
609
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
610
+ with open(tmp_path, "wb") as f:
611
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
612
+ f.write(chunk)
613
+ os.rename(tmp_path, INCEPTION_V3_PATH)
614
+
615
+
616
+ def _create_feature_graph(input_batch):
617
+ _download_inception_model()
618
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
619
+ with open(INCEPTION_V3_PATH, "rb") as f:
620
+ graph_def = tf.GraphDef()
621
+ graph_def.ParseFromString(f.read())
622
+ pool3, spatial = tf.import_graph_def(
623
+ graph_def,
624
+ input_map={f"ExpandDims:0": input_batch},
625
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
626
+ name=prefix,
627
+ )
628
+ _update_shapes(pool3)
629
+ spatial = spatial[..., :7]
630
+ return pool3, spatial
631
+
632
+
633
+ def _create_softmax_graph(input_batch):
634
+ _download_inception_model()
635
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
636
+ with open(INCEPTION_V3_PATH, "rb") as f:
637
+ graph_def = tf.GraphDef()
638
+ graph_def.ParseFromString(f.read())
639
+ (matmul,) = tf.import_graph_def(
640
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
641
+ )
642
+ w = matmul.inputs[1]
643
+ logits = tf.matmul(input_batch, w)
644
+ return tf.nn.softmax(logits)
645
+
646
+
647
+ def _update_shapes(pool3):
648
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
649
+ ops = pool3.graph.get_operations()
650
+ for op in ops:
651
+ for o in op.outputs:
652
+ shape = o.get_shape()
653
+ if shape._dims is not None: # pylint: disable=protected-access
654
+ # shape = [s.value for s in shape] TF 1.x
655
+ shape = [s for s in shape] # TF 2.x
656
+ new_shape = []
657
+ for j, s in enumerate(shape):
658
+ if s == 1 and j == 0:
659
+ new_shape.append(None)
660
+ else:
661
+ new_shape.append(s)
662
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
663
+ return pool3
664
+
665
+
666
+ def _numpy_partition(arr, kth, **kwargs):
667
+ num_workers = min(cpu_count(), len(arr))
668
+ chunk_size = len(arr) // num_workers
669
+ extra = len(arr) % num_workers
670
+
671
+ start_idx = 0
672
+ batches = []
673
+ for i in range(num_workers):
674
+ size = chunk_size + (1 if i < extra else 0)
675
+ batches.append(arr[start_idx : start_idx + size])
676
+ start_idx += size
677
+
678
+ with ThreadPool(num_workers) as pool:
679
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
680
+
681
+
682
+ if __name__ == "__main__":
683
+ main()
684
+
685
+ # nohup python evaluator_rf.py > evaluator_rf_iter22.log 2>&1 &
evaluator_rf_iter22.log ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/1 [00:00<?, ?it/s]2026-03-25 14:55:37.841840: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
 
 
 
1
  0%| | 0/211 [00:00<?, ?it/s]
2
  0%| | 1/211 [00:02<07:28, 2.13s/it]
3
  1%| | 2/211 [00:03<06:40, 1.92s/it]
4
  1%|▏ | 3/211 [00:06<08:10, 2.36s/it]
5
  2%|▏ | 4/211 [00:08<07:14, 2.10s/it]
6
  2%|▏ | 5/211 [00:10<06:43, 1.96s/it]
7
  3%|▎ | 6/211 [00:11<06:27, 1.89s/it]
8
  3%|▎ | 7/211 [00:13<06:15, 1.84s/it]
9
  4%|▍ | 8/211 [00:15<06:05, 1.80s/it]
10
  4%|▍ | 9/211 [00:17<06:00, 1.79s/it]
11
  5%|▍ | 10/211 [00:18<05:53, 1.76s/it]
12
  5%|▌ | 11/211 [00:20<05:47, 1.74s/it]
13
  6%|▌ | 12/211 [00:22<05:40, 1.71s/it]
14
  6%|▌ | 13/211 [00:23<05:34, 1.69s/it]
15
  7%|▋ | 14/211 [00:25<05:32, 1.69s/it]
16
  7%|▋ | 15/211 [00:27<05:39, 1.73s/it]
17
  8%|▊ | 16/211 [00:29<05:34, 1.71s/it]
18
  8%|▊ | 17/211 [00:30<05:34, 1.72s/it]
19
  9%|▊ | 18/211 [00:32<06:00, 1.87s/it]
20
  9%|▉ | 19/211 [00:34<05:48, 1.82s/it]
21
  9%|▉ | 20/211 [00:36<05:44, 1.80s/it]
22
  10%|▉ | 21/211 [00:38<05:39, 1.78s/it]
23
  10%|█ | 22/211 [00:40<05:40, 1.80s/it]
24
  11%|█ | 23/211 [00:41<05:45, 1.84s/it]
25
  11%|█▏ | 24/211 [00:43<05:43, 1.84s/it]
26
  12%|█▏ | 25/211 [00:45<05:35, 1.80s/it]
27
  12%|█▏ | 26/211 [00:47<05:41, 1.84s/it]
28
  13%|█▎ | 27/211 [00:49<05:42, 1.86s/it]
29
  13%|█▎ | 28/211 [00:51<05:36, 1.84s/it]
30
  14%|█▎ | 29/211 [00:52<05:27, 1.80s/it]
31
  14%|█▍ | 30/211 [00:54<05:18, 1.76s/it]
32
  15%|█▍ | 31/211 [00:56<05:28, 1.83s/it]
33
  15%|█▌ | 32/211 [00:58<05:24, 1.81s/it]
34
  16%|█▌ | 33/211 [01:00<05:28, 1.85s/it]
35
  16%|█▌ | 34/211 [01:01<05:20, 1.81s/it]
36
  17%|█▋ | 35/211 [01:03<05:19, 1.82s/it]
37
  17%|█▋ | 36/211 [01:05<05:14, 1.80s/it]
38
  18%|█▊ | 37/211 [01:07<05:07, 1.77s/it]
39
  18%|█▊ | 38/211 [01:08<05:04, 1.76s/it]
40
  18%|█▊ | 39/211 [01:10<05:17, 1.85s/it]
41
  19%|█▉ | 40/211 [01:12<05:10, 1.82s/it]
42
  19%|█▉ | 41/211 [01:14<05:23, 1.90s/it]
43
  20%|█▉ | 42/211 [01:16<05:15, 1.87s/it]
44
  20%|██ | 43/211 [01:18<05:06, 1.82s/it]
45
  21%|██ | 44/211 [01:20<05:01, 1.81s/it]
46
  21%|██▏ | 45/211 [01:21<04:55, 1.78s/it]
47
  22%|██▏ | 46/211 [01:23<04:54, 1.78s/it]
48
  22%|██▏ | 47/211 [01:25<04:49, 1.77s/it]
49
  23%|██▎ | 48/211 [01:27<04:54, 1.81s/it]
50
  23%|██▎ | 49/211 [01:28<04:48, 1.78s/it]
51
  24%|██▎ | 50/211 [01:30<04:42, 1.75s/it]
52
  24%|██▍ | 51/211 [01:32<04:36, 1.73s/it]
53
  25%|██▍ | 52/211 [01:34<04:44, 1.79s/it]
54
  25%|██▌ | 53/211 [01:36<05:01, 1.91s/it]
55
  26%|██▌ | 54/211 [01:38<04:49, 1.84s/it]
56
  26%|██▌ | 55/211 [01:40<05:11, 2.00s/it]
57
  27%|██▋ | 56/211 [01:42<04:57, 1.92s/it]
58
  27%|██▋ | 57/211 [01:43<04:43, 1.84s/it]
59
  27%|█��▋ | 58/211 [01:45<04:48, 1.88s/it]
60
  28%|██▊ | 59/211 [01:47<04:34, 1.81s/it]
61
  28%|██▊ | 60/211 [01:49<04:36, 1.83s/it]
62
  29%|██▉ | 61/211 [01:51<04:25, 1.77s/it]
63
  29%|██▉ | 62/211 [01:52<04:17, 1.73s/it]
64
  30%|██▉ | 63/211 [01:54<04:13, 1.71s/it]
65
  30%|███ | 64/211 [01:55<04:08, 1.69s/it]
66
  31%|███ | 65/211 [01:57<04:09, 1.71s/it]
67
  31%|███▏ | 66/211 [01:59<04:11, 1.73s/it]
68
  32%|███▏ | 67/211 [02:01<04:11, 1.75s/it]
69
  32%|███▏ | 68/211 [02:03<04:16, 1.79s/it]
70
  33%|███▎ | 69/211 [02:04<04:11, 1.77s/it]
71
  33%|███▎ | 70/211 [02:07<04:30, 1.92s/it]
72
  34%|███▎ | 71/211 [02:08<04:20, 1.86s/it]
73
  34%|███▍ | 72/211 [02:10<04:12, 1.82s/it]
74
  35%|███▍ | 73/211 [02:12<04:07, 1.79s/it]
75
  35%|███▌ | 74/211 [02:14<04:12, 1.84s/it]
76
  36%|███▌ | 75/211 [02:16<04:19, 1.91s/it]
77
  36%|███▌ | 76/211 [02:18<04:06, 1.83s/it]
78
  36%|███▋ | 77/211 [02:20<04:17, 1.92s/it]
79
  37%|███▋ | 78/211 [02:21<04:06, 1.86s/it]
80
  37%|███▋ | 79/211 [02:23<04:03, 1.84s/it]
81
  38%|███▊ | 80/211 [02:26<04:34, 2.10s/it]
82
  38%|███▊ | 81/211 [02:28<04:16, 1.97s/it]
83
  39%|███▉ | 82/211 [02:29<04:01, 1.88s/it]
84
  39%|███▉ | 83/211 [02:31<03:54, 1.84s/it]
85
  40%|███▉ | 84/211 [02:33<03:48, 1.80s/it]
86
  40%|████ | 85/211 [02:34<03:42, 1.77s/it]
87
  41%|████ | 86/211 [02:36<03:39, 1.75s/it]
88
  41%|████ | 87/211 [02:38<03:37, 1.76s/it]
89
  42%|████▏ | 88/211 [02:39<03:33, 1.73s/it]
90
  42%|████▏ | 89/211 [02:41<03:27, 1.70s/it]
91
  43%|████▎ | 90/211 [02:43<03:27, 1.71s/it]
92
  43%|████▎ | 91/211 [02:45<03:27, 1.73s/it]
93
  44%|████▎ | 92/211 [02:46<03:23, 1.71s/it]
94
  44%|████▍ | 93/211 [02:48<03:22, 1.71s/it]
95
  45%|████▍ | 94/211 [02:57<07:25, 3.81s/it]
96
  45%|████▌ | 95/211 [02:58<06:11, 3.20s/it]
97
  45%|████▌ | 96/211 [03:00<05:17, 2.76s/it]
98
  46%|████▌ | 97/211 [03:02<04:47, 2.52s/it]
99
  46%|████▋ | 98/211 [03:04<04:15, 2.27s/it]
100
  47%|████▋ | 99/211 [03:06<03:56, 2.11s/it]
101
  47%|████▋ | 100/211 [03:07<03:42, 2.00s/it]
102
  48%|████▊ | 101/211 [03:09<03:28, 1.90s/it]
103
  48%|████▊ | 102/211 [03:11<03:17, 1.81s/it]
104
  49%|████▉ | 103/211 [03:12<03:12, 1.78s/it]
105
  49%|████▉ | 104/211 [03:14<03:05, 1.74s/it]
106
  50%|████▉ | 105/211 [03:16<03:01, 1.71s/it]
107
  50%|█████ | 106/211 [03:17<03:05, 1.76s/it]
108
  51%|█████ | 107/211 [03:19<03:03, 1.76s/it]
109
  51%|█████ | 108/211 [03:21<02:58, 1.74s/it]
110
  52%|█████▏ | 109/211 [03:23<02:58, 1.75s/it]
111
  52%|█████▏ | 110/211 [03:25<02:58, 1.77s/it]
112
  53%|█████▎ | 111/211 [03:26<02:51, 1.72s/it]
113
  53%|█████▎ | 112/211 [03:28<02:49, 1.71s/it]
114
  54%|█████▎ | 113/211 [03:29<02:46, 1.70s/it]
115
  54%|█████▍ | 114/211 [03:31<02:44, 1.69s/it]
116
  55%|█████▍ | 115/211 [03:33<02:41, 1.68s/it]
117
  55%|█████▍ | 116/211 [03:35<02:42, 1.71s/it]
118
  55%|█████▌ | 117/211 [03:36<02:38, 1.69s/it]
119
  56%|█████▌ | 118/211 [03:38<02:38, 1.70s/it]
120
  56%|█████▋ | 119/211 [03:40<02:35, 1.69s/it]
121
  57%|█████▋ | 120/211 [03:41<02:36, 1.72s/it]
122
  57%|█████▋ | 121/211 [03:43<02:36, 1.74s/it]
123
  58%|█████▊ | 122/211 [03:45<02:31, 1.70s/it]
124
  58%|█████▊ | 123/211 [03:47<02:34, 1.76s/it]
125
  59%|█████▉ | 124/211 [03:48<02:30, 1.73s/it]
126
  59%|█████▉ | 125/211 [03:50<02:25, 1.70s/it]
127
  60%|█████▉ | 126/211 [03:52<02:22, 1.68s/it]
128
  60%|██████ | 127/211 [03:53<02:21, 1.69s/it]
129
  61%|██████ | 128/211 [03:55<02:19, 1.68s/it]
130
  61%|██████ | 129/211 [03:57<02:17, 1.68s/it]
131
  62%|██████▏ | 130/211 [03:59<02:21, 1.75s/it]
132
  62%|██████▏ | 131/211 [04:00<02:20, 1.75s/it]
133
  63%|██████▎ | 132/211 [04:02<02:18, 1.75s/it]
134
  63%|██████▎ | 133/211 [04:04<02:21, 1.81s/it]
135
  64%|██████▎ | 134/211 [04:06<02:15, 1.76s/it]
136
  64%|██████▍ | 135/211 [04:08<02:17, 1.81s/it]
137
  64%|██████▍ | 136/211 [04:10<02:20, 1.88s/it]
138
  65%|██████▍ | 137/211 [04:11<02:13, 1.80s/it]
139
  65%|██████▌ | 138/211 [04:13<02:07, 1.74s/it]
140
  66%|██████▌ | 139/211 [04:16<02:38, 2.20s/it]
141
  66%|██████▋ | 140/211 [04:18<02:27, 2.08s/it]
142
  67%|██████▋ | 141/211 [04:20<02:16, 1.95s/it]
143
  67%|██████▋ | 142/211 [04:21<02:11, 1.91s/it]
144
  68%|██████▊ | 143/211 [04:23<02:06, 1.86s/it]
145
  68%|██████▊ | 144/211 [04:25<02:02, 1.82s/it]
146
  69%|██████▊ | 145/211 [04:27<01:58, 1.80s/it]
147
  69%|██████▉ | 146/211 [04:28<01:53, 1.75s/it]
148
  70%|██████▉ | 147/211 [04:30<01:54, 1.78s/it]
149
  70%|███████ | 148/211 [04:32<01:53, 1.80s/it]
150
  71%|███████ | 149/211 [04:34<01:54, 1.85s/it]
151
  71%|███████ | 150/211 [04:36<01:49, 1.80s/it]
152
  72%|███████▏ | 151/211 [04:37<01:46, 1.77s/it]
153
  72%|███████▏ | 152/211 [04:39<01:45, 1.79s/it]
154
  73%|███████▎ | 153/211 [04:41<01:43, 1.78s/it]
155
  73%|███████▎ | 154/211 [04:43<01:41, 1.78s/it]
156
  73%|███████▎ | 155/211 [04:44<01:37, 1.73s/it]
157
  74%|███████▍ | 156/211 [04:46<01:35, 1.73s/it]
158
  74%|███████▍ | 157/211 [04:48<01:33, 1.73s/it]
159
  75%|███████▍ | 158/211 [04:50<01:33, 1.77s/it]
160
  75%|███████▌ | 159/211 [04:51<01:30, 1.74s/it]
161
  76%|███████▌ | 160/211 [04:53<01:27, 1.72s/it]
162
  76%|███████▋ | 161/211 [04:55<01:28, 1.77s/it]
163
  77%|███████▋ | 162/211 [04:57<01:28, 1.80s/it]
164
  77%|███████▋ | 163/211 [04:58<01:24, 1.75s/it]
165
  78%|███████▊ | 164/211 [05:00<01:26, 1.84s/it]
166
  78%|███████▊ | 165/211 [05:02<01:21, 1.77s/it]
167
  79%|███████▊ | 166/211 [05:04<01:17, 1.72s/it]
168
  79%|███████▉ | 167/211 [05:05<01:15, 1.72s/it]
169
  80%|███████▉ | 168/211 [05:07<01:16, 1.79s/it]
170
  80%|████████ | 169/211 [05:09<01:17, 1.84s/it]
171
  81%|████████ | 170/211 [05:11<01:15, 1.83s/it]
172
  81%|████████ | 171/211 [05:13<01:12, 1.81s/it]
173
  82%|████████▏ | 172/211 [05:15<01:09, 1.79s/it]
174
  82%|████████▏ | 173/211 [05:16<01:06, 1.74s/it]
175
  82%|████████▏ | 174/211 [05:18<01:04, 1.75s/it]
176
  83%|████████▎ | 175/211 [05:20<01:02, 1.74s/it]
177
  83%|████████▎ | 176/211 [05:41<04:21, 7.47s/it]
178
  84%|████████▍ | 177/211 [05:42<03:17, 5.81s/it]
179
  84%|████████▍ | 178/211 [05:44<02:32, 4.63s/it]
180
  85%|████████▍ | 179/211 [05:46<02:00, 3.75s/it]
181
  85%|████████▌ | 180/211 [05:48<01:36, 3.13s/it]
182
  86%|████████▌ | 181/211 [05:49<01:21, 2.72s/it]
183
  86%|████████▋ | 182/211 [05:51<01:10, 2.42s/it]
184
  87%|████████▋ | 183/211 [05:53<01:01, 2.19s/it]
185
  87%|████████▋ | 184/211 [05:55<00:54, 2.03s/it]
186
  88%|████████▊ | 185/211 [05:56<00:50, 1.94s/it]
187
  88%|████████▊ | 186/211 [05:58<00:46, 1.88s/it]
188
  89%|████████▊ | 187/211 [06:00<00:43, 1.82s/it]
189
  89%|████████▉ | 188/211 [06:01<00:41, 1.79s/it]
190
  90%|████████▉ | 189/211 [06:03<00:40, 1.86s/it]
191
  90%|█████████ | 190/211 [06:06<00:45, 2.14s/it]
192
  91%|█████████ | 191/211 [06:08<00:40, 2.02s/it]
193
  91%|█████████ | 192/211 [06:10<00:36, 1.91s/it]
194
  91%|█████████▏| 193/211 [06:11<00:33, 1.85s/it]
195
  92%|█████████▏| 194/211 [06:13<00:30, 1.82s/it]
196
  92%|█████████▏| 195/211 [06:15<00:28, 1.78s/it]
197
  93%|█████████▎| 196/211 [06:17<00:28, 1.92s/it]
198
  93%|█████████▎| 197/211 [06:19<00:25, 1.84s/it]
199
  94%|█████████▍| 198/211 [06:20<00:23, 1.80s/it]
200
  94%|█████████▍| 199/211 [06:22<00:22, 1.84s/it]
201
  95%|█████████▍| 200/211 [06:24<00:19, 1.80s/it]
202
  95%|█████████▌| 201/211 [06:26<00:17, 1.77s/it]
203
  96%|█████████▌| 202/211 [06:28<00:16, 1.82s/it]
204
  96%|█████████▌| 203/211 [06:29<00:14, 1.77s/it]
205
  97%|█████████▋| 204/211 [06:31<00:12, 1.73s/it]
206
  97%|█████████▋| 205/211 [06:33<00:10, 1.72s/it]
207
  98%|█████████▊| 206/211 [06:34<00:08, 1.73s/it]
208
  98%|█████████▊| 207/211 [06:36<00:06, 1.71s/it]
209
  99%|█████████▊| 208/211 [06:39<00:05, 1.97s/it]
210
  99%|█████████▉| 209/211 [06:40<00:03, 1.93s/it]
 
 
 
211
  0%| | 0/469 [00:00<?, ?it/s]
212
  0%| | 1/469 [00:02<15:45, 2.02s/it]
213
  0%| | 2/469 [00:03<14:20, 1.84s/it]
214
  1%| | 3/469 [00:05<13:51, 1.78s/it]
215
  1%| | 4/469 [00:07<13:37, 1.76s/it]
216
  1%| | 5/469 [00:08<13:23, 1.73s/it]
217
  1%|▏ | 6/469 [00:10<13:11, 1.71s/it]
218
  1%|▏ | 7/469 [00:12<13:03, 1.70s/it]
219
  2%|▏ | 8/469 [00:13<12:57, 1.69s/it]
220
  2%|▏ | 9/469 [00:15<12:53, 1.68s/it]
221
  2%|▏ | 10/469 [00:17<13:08, 1.72s/it]
222
  2%|▏ | 11/469 [00:26<31:14, 4.09s/it]
223
  3%|▎ | 12/469 [00:28<25:29, 3.35s/it]
224
  3%|▎ | 13/469 [00:30<21:32, 2.84s/it]
225
  3%|▎ | 14/469 [00:31<18:50, 2.49s/it]
226
  3%|▎ | 15/469 [00:33<16:50, 2.23s/it]
227
  3%|▎ | 16/469 [00:35<15:32, 2.06s/it]
228
  4%|▎ | 17/469 [00:36<14:57, 1.99s/it]
229
  4%|▍ | 18/469 [00:38<14:12, 1.89s/it]
230
  4%|▍ | 19/469 [00:40<13:45, 1.83s/it]
231
  4%|▍ | 20/469 [00:41<13:17, 1.78s/it]
232
  4%|▍ | 21/469 [00:43<13:08, 1.76s/it]
233
  5%|▍ | 22/469 [00:45<12:51, 1.73s/it]
234
  5%|▍ | 23/469 [00:47<12:51, 1.73s/it]
235
  5%|▌ | 24/469 [00:48<12:38, 1.70s/it]
236
  5%|▌ | 25/469 [00:50<12:33, 1.70s/it]
237
  6%|▌ | 26/469 [00:51<12:26, 1.68s/it]
238
  6%|▌ | 27/469 [00:53<12:26, 1.69s/it]
239
  6%|▌ | 28/469 [00:55<12:17, 1.67s/it]
240
  6%|▌ | 29/469 [00:57<12:18, 1.68s/it]
241
  6%|▋ | 30/469 [00:58<12:12, 1.67s/it]
242
  7%|▋ | 31/469 [01:00<12:15, 1.68s/it]
243
  7%|▋ | 32/469 [01:02<12:10, 1.67s/it]
244
  7%|▋ | 33/469 [01:04<12:55, 1.78s/it]
245
  7%|▋ | 34/469 [01:05<12:41, 1.75s/it]
246
  7%|▋ | 35/469 [01:07<12:24, 1.72s/it]
247
  8%|▊ | 36/469 [01:09<12:34, 1.74s/it]
248
  8%|▊ | 37/469 [01:10<12:19, 1.71s/it]
249
  8%|▊ | 38/469 [01:12<12:27, 1.73s/it]
250
  8%|▊ | 39/469 [01:14<12:14, 1.71s/it]
251
  9%|▊ | 40/469 [01:15<12:10, 1.70s/it]
252
  9%|▊ | 41/469 [01:17<12:07, 1.70s/it]
253
  9%|▉ | 42/469 [01:19<12:09, 1.71s/it]
254
  9%|▉ | 43/469 [01:21<12:18, 1.73s/it]
255
  9%|▉ | 44/469 [01:22<12:13, 1.72s/it]
256
  10%|▉ | 45/469 [01:24<12:06, 1.71s/it]
257
  10%|▉ | 46/469 [01:26<12:22, 1.75s/it]
258
  10%|█ | 47/469 [01:28<12:35, 1.79s/it]
259
  10%|█ | 48/469 [01:29<12:12, 1.74s/it]
260
  10%|█ | 49/469 [01:31<12:01, 1.72s/it]
261
  11%|█ | 50/469 [01:33<12:24, 1.78s/it]
262
  11%|█ | 51/469 [01:35<12:32, 1.80s/it]
263
  11%|█ | 52/469 [01:37<12:29, 1.80s/it]
264
  11%|█▏ | 53/469 [01:39<13:30, 1.95s/it]
265
  12%|█▏ | 54/469 [01:41<12:50, 1.86s/it]
266
  12%|█▏ | 55/469 [01:42<12:27, 1.81s/it]
267
  12%|█▏ | 56/469 [01:44<12:05, 1.76s/it]
268
  12%|█▏ | 57/469 [01:46<12:17, 1.79s/it]
269
  12%|█▏ | 58/469 [01:47<12:08, 1.77s/it]
270
  13%|█▎ | 59/469 [01:49<12:09, 1.78s/it]
271
  13%|█▎ | 60/469 [01:51<12:17, 1.80s/it]
272
  13%|█▎ | 61/469 [01:53<12:07, 1.78s/it]
273
  13%|█▎ | 62/469 [01:55<12:24, 1.83s/it]
274
  13%|█▎ | 63/469 [01:56<12:00, 1.77s/it]
275
  14%|█▎ | 64/469 [01:58<11:42, 1.73s/it]
276
  14%|█▍ | 65/469 [02:00<11:35, 1.72s/it]
277
  14%|█▍ | 66/469 [02:02<11:34, 1.72s/it]
278
  14%|█▍ | 67/469 [02:04<12:11, 1.82s/it]
279
  14%|█▍ | 68/469 [02:05<11:45, 1.76s/it]
280
  15%|█▍ | 69/469 [02:07<11:31, 1.73s/it]
281
  15%|█▍ | 70/469 [02:09<11:27, 1.72s/it]
282
  15%|█▌ | 71/469 [02:10<11:25, 1.72s/it]
283
  15%|█▌ | 72/469 [02:12<11:16, 1.70s/it]
284
  16%|█▌ | 73/469 [02:14<11:09, 1.69s/it]
285
  16%|█▌ | 74/469 [02:15<11:32, 1.75s/it]
286
  16%|█▌ | 75/469 [02:17<11:47, 1.80s/it]
287
  16%|█▌ | 76/469 [02:19<11:43, 1.79s/it]
288
  16%|█▋ | 77/469 [02:21<12:10, 1.86s/it]
289
  17%|█▋ | 78/469 [02:23<11:45, 1.80s/it]
290
  17%|█▋ | 79/469 [02:25<12:10, 1.87s/it]
291
  17%|█▋ | 80/469 [02:27<11:48, 1.82s/it]
292
  17%|█▋ | 81/469 [02:28<11:30, 1.78s/it]
293
  17%|█▋ | 82/469 [02:30<11:36, 1.80s/it]
294
  18%|█▊ | 83/469 [02:32<11:18, 1.76s/it]
295
  18%|█▊ | 84/469 [02:34<11:12, 1.75s/it]
296
  18%|█▊ | 85/469 [02:35<11:03, 1.73s/it]
297
  18%|█▊ | 86/469 [02:37<11:06, 1.74s/it]
298
  19%|█▊ | 87/469 [02:39<11:04, 1.74s/it]
299
  19%|█▉ | 88/469 [02:40<10:56, 1.72s/it]
300
  19%|█▉ | 89/469 [02:42<11:07, 1.76s/it]
301
  19%|█▉ | 90/469 [02:45<12:56, 2.05s/it]
302
  19%|█▉ | 91/469 [02:47<12:10, 1.93s/it]
303
  20%|█▉ | 92/469 [02:48<11:35, 1.84s/it]
304
  20%|█▉ | 93/469 [02:50<11:19, 1.81s/it]
305
  20%|██ | 94/469 [02:52<11:04, 1.77s/it]
306
  20%|██ | 95/469 [02:53<10:48, 1.73s/it]
307
  20%|██ | 96/469 [02:55<10:38, 1.71s/it]
308
  21%|██ | 97/469 [02:57<11:01, 1.78s/it]
309
  21%|██ | 98/469 [02:59<10:50, 1.75s/it]
310
  21%|██ | 99/469 [03:00<10:33, 1.71s/it]
311
  21%|██▏ | 100/469 [03:02<10:35, 1.72s/it]
312
  22%|██▏ | 101/469 [03:04<11:42, 1.91s/it]
313
  22%|██▏ | 102/469 [03:06<11:19, 1.85s/it]
314
  22%|██▏ | 103/469 [03:08<11:48, 1.94s/it]
315
  22%|██▏ | 104/469 [03:10<11:23, 1.87s/it]
316
  22%|██▏ | 105/469 [03:11<10:53, 1.80s/it]
317
  23%|██▎ | 106/469 [03:13<10:34, 1.75s/it]
318
  23%|██▎ | 107/469 [03:15<10:32, 1.75s/it]
319
  23%|██▎ | 108/469 [03:17<10:19, 1.71s/it]
320
  23%|██▎ | 109/469 [03:18<10:08, 1.69s/it]
321
  23%|██▎ | 110/469 [03:20<11:07, 1.86s/it]
322
  24%|██▎ | 111/469 [03:22<10:49, 1.81s/it]
323
  24%|██▍ | 112/469 [03:24<10:35, 1.78s/it]
324
  24%|██▍ | 113/469 [03:25<10:22, 1.75s/it]
325
  24%|██▍ | 114/469 [03:27<10:13, 1.73s/it]
326
  25%|██▍ | 115/469 [03:29<10:01, 1.70s/it]
327
  25%|██▍ | 116/469 [03:30<09:56, 1.69s/it]
328
  25%|██▍ | 117/469 [03:32<09:56, 1.69s/it]
329
  25%|██▌ | 118/469 [03:34<09:45, 1.67s/it]
330
  25%|██▌ | 119/469 [03:35<09:45, 1.67s/it]
331
  26%|██▌ | 120/469 [03:37<09:45, 1.68s/it]
332
  26%|██▌ | 121/469 [03:40<11:56, 2.06s/it]
333
  26%|██▌ | 122/469 [03:42<11:18, 1.95s/it]
334
  26%|██▌ | 123/469 [03:44<11:00, 1.91s/it]
335
  26%|██▋ | 124/469 [03:45<10:31, 1.83s/it]
336
  27%|██▋ | 125/469 [03:47<10:06, 1.76s/it]
337
  27%|██▋ | 126/469 [03:48<09:48, 1.72s/it]
338
  27%|██▋ | 127/469 [03:50<09:57, 1.75s/it]
339
  27%|██▋ | 128/469 [03:52<09:52, 1.74s/it]
340
  28%|██▊ | 129/469 [03:54<09:46, 1.72s/it]
341
  28%|██▊ | 130/469 [03:56<09:59, 1.77s/it]
342
  28%|██▊ | 131/469 [03:57<09:56, 1.76s/it]
343
  28%|██▊ | 132/469 [03:59<09:40, 1.72s/it]
344
  28%|██▊ | 133/469 [04:01<10:56, 1.95s/it]
345
  29%|██▊ | 134/469 [04:03<10:35, 1.90s/it]
346
  29%|██▉ | 135/469 [04:05<10:05, 1.81s/it]
347
  29%|██▉ | 136/469 [04:06<09:45, 1.76s/it]
348
  29%|██▉ | 137/469 [04:08<09:30, 1.72s/it]
349
  29%|██▉ | 138/469 [04:10<09:35, 1.74s/it]
350
  30%|██▉ | 139/469 [04:11<09:19, 1.69s/it]
351
  30%|██▉ | 140/469 [04:13<09:06, 1.66s/it]
352
  30%|███ | 141/469 [04:15<09:00, 1.65s/it]
353
  30%|███ | 142/469 [04:16<08:55, 1.64s/it]
354
  30%|███ | 143/469 [04:18<08:55, 1.64s/it]
355
  31%|███ | 144/469 [04:19<08:47, 1.62s/it]
356
  31%|███ | 145/469 [04:21<09:06, 1.69s/it]
357
  31%|███ | 146/469 [04:23<09:06, 1.69s/it]
358
  31%|███▏ | 147/469 [04:25<08:56, 1.67s/it]
359
  32%|███▏ | 148/469 [04:26<08:56, 1.67s/it]
360
  32%|███▏ | 149/469 [04:28<09:09, 1.72s/it]
361
  32%|███▏ | 150/469 [04:30<08:57, 1.69s/it]
362
  32%|███▏ | 151/469 [04:31<08:47, 1.66s/it]
363
  32%|███▏ | 152/469 [04:33<08:41, 1.64s/it]
364
  33%|███▎ | 153/469 [04:35<08:34, 1.63s/it]
365
  33%|███▎ | 154/469 [04:36<08:32, 1.63s/it]
366
  33%|███▎ | 155/469 [04:38<08:25, 1.61s/it]
367
  33%|███▎ | 156/469 [04:39<08:23, 1.61s/it]
368
  33%|███▎ | 157/469 [04:41<08:22, 1.61s/it]
369
  34%|███▎ | 158/469 [04:47<15:58, 3.08s/it]
370
  34%|███▍ | 159/469 [04:49<13:44, 2.66s/it]
371
  34%|███▍ | 160/469 [04:51<12:27, 2.42s/it]
372
  34%|███▍ | 161/469 [04:53<11:14, 2.19s/it]
373
  35%|███▍ | 162/469 [04:54<10:22, 2.03s/it]
374
  35%|███▍ | 163/469 [04:56<09:58, 1.96s/it]
375
  35%|███▍ | 164/469 [04:58<09:35, 1.89s/it]
376
  35%|███▌ | 165/469 [04:59<09:06, 1.80s/it]
377
  35%|███▌ | 166/469 [05:01<08:47, 1.74s/it]
378
  36%|███▌ | 167/469 [05:03<08:31, 1.69s/it]
379
  36%|███▌ | 168/469 [05:05<08:55, 1.78s/it]
380
  36%|███▌ | 169/469 [05:07<09:07, 1.82s/it]
381
  36%|███▌ | 170/469 [05:08<08:54, 1.79s/it]
382
  36%|███▋ | 171/469 [05:10<08:46, 1.77s/it]
383
  37%|███▋ | 172/469 [05:12<08:31, 1.72s/it]
384
  37%|███▋ | 173/469 [05:13<08:23, 1.70s/it]
385
  37%|███▋ | 174/469 [05:15<08:16, 1.68s/it]
386
  37%|███▋ | 175/469 [05:17<08:12, 1.67s/it]
387
  38%|███▊ | 176/469 [05:18<08:14, 1.69s/it]
388
  38%|███▊ | 177/469 [05:20<08:13, 1.69s/it]
389
  38%|███▊ | 178/469 [05:22<08:07, 1.68s/it]
390
  38%|███▊ | 179/469 [05:23<08:00, 1.66s/it]
391
  38%|███▊ | 180/469 [05:25<07:58, 1.66s/it]
392
  39%|███▊ | 181/469 [05:26<07:53, 1.64s/it]
393
  39%|███▉ | 182/469 [05:28<07:47, 1.63s/it]
394
  39%|███▉ | 183/469 [05:30<07:48, 1.64s/it]
395
  39%|███▉ | 184/469 [05:31<07:50, 1.65s/it]
396
  39%|███▉ | 185/469 [05:33<07:44, 1.64s/it]
397
  40%|███▉ | 186/469 [05:35<07:50, 1.66s/it]
398
  40%|███▉ | 187/469 [05:37<08:00, 1.70s/it]
399
  40%|████ | 188/469 [05:38<07:55, 1.69s/it]
400
  40%|████ | 189/469 [05:40<07:48, 1.67s/it]
401
  41%|████ | 190/469 [05:42<07:53, 1.70s/it]
402
  41%|████ | 191/469 [05:43<07:44, 1.67s/it]
403
  41%|████ | 192/469 [05:45<07:38, 1.65s/it]
404
  41%|████ | 193/469 [05:46<07:32, 1.64s/it]
405
  41%|████▏ | 194/469 [05:48<07:54, 1.73s/it]
406
  42%|███��▏ | 195/469 [05:50<07:45, 1.70s/it]
407
  42%|████▏ | 196/469 [05:52<07:44, 1.70s/it]
408
  42%|████▏ | 197/469 [05:53<07:48, 1.72s/it]
409
  42%|████▏ | 198/469 [05:55<07:40, 1.70s/it]
410
  42%|████▏ | 199/469 [05:57<08:02, 1.79s/it]
411
  43%|████▎ | 200/469 [05:59<07:49, 1.74s/it]
412
  43%|████▎ | 201/469 [06:00<07:41, 1.72s/it]
413
  43%|████▎ | 202/469 [06:02<07:38, 1.72s/it]
414
  43%|████▎ | 203/469 [06:04<07:37, 1.72s/it]
415
  43%|████▎ | 204/469 [06:05<07:29, 1.70s/it]
416
  44%|████▎ | 205/469 [06:07<07:31, 1.71s/it]
417
  44%|████▍ | 206/469 [06:09<07:21, 1.68s/it]
418
  44%|████▍ | 207/469 [06:10<07:16, 1.67s/it]
419
  44%|████▍ | 208/469 [06:12<07:09, 1.64s/it]
420
  45%|████▍ | 209/469 [06:14<07:07, 1.64s/it]
421
  45%|████▍ | 210/469 [06:15<07:11, 1.67s/it]
422
  45%|████▍ | 211/469 [06:17<07:21, 1.71s/it]
423
  45%|████▌ | 212/469 [06:19<07:27, 1.74s/it]
424
  45%|████▌ | 213/469 [06:21<07:20, 1.72s/it]
425
  46%|████▌ | 214/469 [06:22<07:09, 1.69s/it]
426
  46%|████▌ | 215/469 [06:24<07:07, 1.68s/it]
427
  46%|████▌ | 216/469 [06:26<07:00, 1.66s/it]
428
  46%|████▋ | 217/469 [06:27<06:56, 1.65s/it]
429
  46%|████▋ | 218/469 [06:29<06:48, 1.63s/it]
430
  47%|████▋ | 219/469 [06:31<07:08, 1.72s/it]
431
  47%|████▋ | 220/469 [06:32<06:57, 1.68s/it]
432
  47%|████▋ | 221/469 [06:34<06:56, 1.68s/it]
433
  47%|████▋ | 222/469 [06:36<06:54, 1.68s/it]
434
  48%|████▊ | 223/469 [06:37<06:47, 1.66s/it]
435
  48%|████▊ | 224/469 [06:39<06:44, 1.65s/it]
436
  48%|████▊ | 225/469 [06:41<06:42, 1.65s/it]
437
  48%|████▊ | 226/469 [06:42<06:35, 1.63s/it]
438
  48%|████▊ | 227/469 [06:44<06:31, 1.62s/it]
439
  49%|████▊ | 228/469 [06:46<06:57, 1.73s/it]
440
  49%|████▉ | 229/469 [06:47<06:47, 1.70s/it]
441
  49%|████▉ | 230/469 [06:49<06:46, 1.70s/it]
442
  49%|████▉ | 231/469 [06:51<06:40, 1.68s/it]
443
  49%|████▉ | 232/469 [06:52<06:33, 1.66s/it]
444
  50%|████▉ | 233/469 [06:54<06:46, 1.72s/it]
445
  50%|████▉ | 234/469 [06:56<07:17, 1.86s/it]
446
  50%|█████ | 235/469 [06:58<07:00, 1.80s/it]
447
  50%|█████ | 236/469 [07:00<06:47, 1.75s/it]
448
  51%|█████ | 237/469 [07:01<06:40, 1.72s/it]
449
  51%|█████ | 238/469 [07:03<06:31, 1.69s/it]
450
  51%|█████ | 239/469 [07:05<06:23, 1.67s/it]
451
  51%|█████ | 240/469 [07:06<06:18, 1.65s/it]
452
  51%|█████▏ | 241/469 [07:08<06:14, 1.64s/it]
453
  52%|█████▏ | 242/469 [07:09<06:10, 1.63s/it]
454
  52%|█████▏ | 243/469 [07:11<06:41, 1.77s/it]
455
  52%|█████▏ | 244/469 [07:13<06:27, 1.72s/it]
456
  52%|█████▏ | 245/469 [07:15<06:49, 1.83s/it]
457
  52%|█████▏ | 246/469 [07:17<06:33, 1.76s/it]
458
  53%|█████▎ | 247/469 [07:18<06:22, 1.72s/it]
459
  53%|█████▎ | 248/469 [07:20<06:19, 1.72s/it]
460
  53%|█████▎ | 249/469 [07:22<06:12, 1.69s/it]
461
  53%|█████▎ | 250/469 [07:24<06:16, 1.72s/it]
462
  54%|█████▎ | 251/469 [07:25<06:10, 1.70s/it]
463
  54%|█████▎ | 252/469 [07:27<06:11, 1.71s/it]
464
  54%|█████▍ | 253/469 [07:29<06:03, 1.68s/it]
465
  54%|█████▍ | 254/469 [07:30<06:04, 1.70s/it]
466
  54%|█████▍ | 255/469 [07:32<05:56, 1.67s/it]
467
  55%|█████▍ | 256/469 [07:34<06:02, 1.70s/it]
468
  55%|█████▍ | 257/469 [07:35<05:56, 1.68s/it]
469
  55%|█████▌ | 258/469 [07:37<06:02, 1.72s/it]
470
  55%|█████▌ | 259/469 [07:39<05:57, 1.70s/it]
471
  55%|█████▌ | 260/469 [07:40<05:50, 1.68s/it]
472
  56%|█████▌ | 261/469 [07:42<05:50, 1.69s/it]
473
  56%|█████▌ | 262/469 [07:44<05:52, 1.70s/it]
474
  56%|█████▌ | 263/469 [07:46<05:50, 1.70s/it]
475
  56%|█████▋ | 264/469 [07:47<05:48, 1.70s/it]
476
  57%|█████▋ | 265/469 [07:49<05:45, 1.69s/it]
477
  57%|█████▋ | 266/469 [07:50<05:37, 1.66s/it]
478
  57%|█████▋ | 267/469 [07:52<05:33, 1.65s/it]
479
  57%|█████▋ | 268/469 [07:58<09:27, 2.82s/it]
480
  57%|█████▋ | 269/469 [07:59<08:23, 2.52s/it]
481
  58%|█████▊ | 270/469 [08:01<07:26, 2.24s/it]
482
  58%|█████▊ | 271/469 [08:03<06:53, 2.09s/it]
483
  58%|█████▊ | 272/469 [08:04<06:26, 1.96s/it]
484
  58%|█████▊ | 273/469 [08:06<06:03, 1.86s/it]
485
  58%|█████▊ | 274/469 [08:08<05:44, 1.77s/it]
486
  59%|█████▊ | 275/469 [08:09<05:33, 1.72s/it]
487
  59%|█████▉ | 276/469 [08:11<05:25, 1.69s/it]
488
  59%|█████▉ | 277/469 [08:12<05:21, 1.68s/it]
489
  59%|█████▉ | 278/469 [08:14<05:21, 1.68s/it]
490
  59%|█████▉ | 279/469 [08:16<05:26, 1.72s/it]
491
  60%|█████▉ | 280/469 [08:18<05:16, 1.67s/it]
492
  60%|█████▉ | 281/469 [08:20<05:45, 1.84s/it]
493
  60%|██████ | 282/469 [08:21<05:31, 1.77s/it]
494
  60%|██████ | 283/469 [08:23<05:26, 1.75s/it]
495
  61%|██████ | 284/469 [08:25<05:17, 1.72s/it]
496
  61%|██████ | 285/469 [08:26<05:10, 1.69s/it]
497
  61%|██████ | 286/469 [08:28<05:21, 1.76s/it]
498
  61%|██████ | 287/469 [08:30<05:16, 1.74s/it]
499
  61%|██████▏ | 288/469 [08:32<05:13, 1.73s/it]
500
  62%|██████▏ | 289/469 [08:33<05:03, 1.69s/it]
501
  62%|██████▏ | 290/469 [08:35<05:10, 1.74s/it]
502
  62%|██████▏ | 291/469 [08:37<05:17, 1.78s/it]
503
  62%|██████▏ | 292/469 [08:39<05:17, 1.79s/it]
504
  62%|██████▏ | 293/469 [08:41<05:19, 1.81s/it]
505
  63%|██████▎ | 294/469 [08:42<05:06, 1.75s/it]
506
  63%|██████▎ | 295/469 [08:44<04:55, 1.70s/it]
507
  63%|██████▎ | 296/469 [08:46<04:53, 1.70s/it]
508
  63%|██████▎ | 297/469 [08:47<04:46, 1.67s/it]
509
  64%|██████▎ | 298/469 [08:49<04:40, 1.64s/it]
510
  64%|██████▍ | 299/469 [08:50<04:39, 1.65s/it]
511
  64%|██████▍ | 300/469 [08:52<04:37, 1.64s/it]
512
  64%|██████▍ | 301/469 [08:54<04:39, 1.66s/it]
513
  64%|██████▍ | 302/469 [08:55<04:35, 1.65s/it]
514
  65%|██████▍ | 303/469 [08:57<04:36, 1.66s/it]
515
  65%|██████▍ | 304/469 [08:59<04:36, 1.67s/it]
516
  65%|██████▌ | 305/469 [09:00<04:34, 1.68s/it]
517
  65%|██████▌ | 306/469 [09:02<04:32, 1.67s/it]
518
  65%|██████▌ | 307/469 [09:04<04:30, 1.67s/it]
519
  66%|██████▌ | 308/469 [09:05<04:25, 1.65s/it]
520
  66%|██████▌ | 309/469 [09:07<04:23, 1.65s/it]
521
  66%|██████▌ | 310/469 [09:09<04:20, 1.64s/it]
522
  66%|██████▋ | 311/469 [09:10<04:18, 1.64s/it]
523
  67%|██████▋ | 312/469 [09:12<04:19, 1.65s/it]
524
  67%|██████▋ | 313/469 [09:14<04:15, 1.64s/it]
525
  67%|██████▋ | 314/469 [09:15<04:10, 1.62s/it]
526
  67%|██████▋ | 315/469 [09:17<04:08, 1.61s/it]
527
  67%|██████▋ | 316/469 [09:19<04:21, 1.71s/it]
528
  68%|██████▊ | 317/469 [09:20<04:23, 1.73s/it]
529
  68%|██████▊ | 318/469 [09:22<04:18, 1.71s/it]
530
  68%|██████▊ | 319/469 [09:24<04:20, 1.74s/it]
531
  68%|██████▊ | 320/469 [09:26<04:15, 1.71s/it]
532
  68%|██████▊ | 321/469 [09:27<04:21, 1.77s/it]
533
  69%|██████▊ | 322/469 [09:29<04:12, 1.72s/it]
534
  69%|██████▉ | 323/469 [09:31<04:04, 1.67s/it]
535
  69%|██████▉ | 324/469 [09:32<04:01, 1.67s/it]
536
  69%|██████▉ | 325/469 [09:34<03:55, 1.64s/it]
537
  70%|██████▉ | 326/469 [09:35<03:52, 1.62s/it]
538
  70%|██████▉ | 327/469 [09:37<03:49, 1.61s/it]
539
  70%|██████▉ | 328/469 [09:39<03:52, 1.65s/it]
540
  70%|███████ | 329/469 [09:41<04:01, 1.72s/it]
541
  70%|███████ | 330/469 [09:43<04:06, 1.77s/it]
542
  71%|███████ | 331/469 [09:44<04:03, 1.76s/it]
543
  71%|███████ | 332/469 [09:46<03:55, 1.72s/it]
544
  71%|███████ | 333/469 [09:47<03:48, 1.68s/it]
545
  71%|███████ | 334/469 [09:49<03:46, 1.68s/it]
546
  71%|███████▏ | 335/469 [09:51<03:42, 1.66s/it]
547
  72%|███████▏ | 336/469 [09:52<03:37, 1.63s/it]
548
  72%|███████▏ | 337/469 [09:54<03:33, 1.62s/it]
549
  72%|███████▏ | 338/469 [09:56<03:33, 1.63s/it]
550
  72%|███████▏ | 339/469 [09:57<03:31, 1.62s/it]
551
  72%|███████▏ | 340/469 [09:59<03:27, 1.61s/it]
552
  73%|███████▎ | 341/469 [10:01<03:32, 1.66s/it]
553
  73%|███████▎ | 342/469 [10:03<03:41, 1.75s/it]
554
  73%|███████▎ | 343/469 [10:04<03:35, 1.71s/it]
555
  73%|███████▎ | 344/469 [10:06<03:31, 1.70s/it]
556
  74%|███████▎ | 345/469 [10:27<15:27, 7.48s/it]
557
  74%|███████▍ | 346/469 [10:29<11:48, 5.76s/it]
558
  74%|███████▍ | 347/469 [10:30<09:19, 4.59s/it]
559
  74%|███████▍ | 348/469 [10:32<07:29, 3.71s/it]
560
  74%|███████▍ | 349/469 [10:34<06:11, 3.10s/it]
561
  75%|███████▍ | 350/469 [10:36<05:49, 2.93s/it]
562
  75%|███████▍ | 351/469 [10:38<05:09, 2.63s/it]
563
  75%|███████▌ | 352/469 [10:40<04:39, 2.39s/it]
564
  75%|███████▌ | 353/469 [10:42<04:16, 2.21s/it]
565
  75%|███████▌ | 354/469 [10:43<03:53, 2.03s/it]
566
  76%|███████▌ | 355/469 [10:45<03:35, 1.89s/it]
567
  76%|███████▌ | 356/469 [10:47<03:30, 1.86s/it]
568
  76%|███████▌ | 357/469 [10:48<03:23, 1.82s/it]
569
  76%|███████▋ | 358/469 [10:50<03:19, 1.80s/it]
570
  77%|███████▋ | 359/469 [10:52<03:20, 1.82s/it]
571
  77%|███████▋ | 360/469 [10:54<03:23, 1.87s/it]
572
  77%|███████▋ | 361/469 [10:56<03:16, 1.82s/it]
573
  77%|███████▋ | 362/469 [10:58<03:24, 1.91s/it]
574
  77%|███████▋ | 363/469 [11:00<03:12, 1.82s/it]
575
  78%|███████▊ | 364/469 [11:01<03:05, 1.77s/it]
576
  78%|███████▊ | 365/469 [11:03<02:58, 1.72s/it]
577
  78%|███████▊ | 366/469 [11:04<02:54, 1.69s/it]
578
  78%|███████▊ | 367/469 [11:06<02:52, 1.69s/it]
579
  78%|███████▊ | 368/469 [11:08<02:50, 1.68s/it]
580
  79%|███████▊ | 369/469 [11:10<02:50, 1.71s/it]
581
  79%|███████▉ | 370/469 [11:11<02:49, 1.71s/it]
582
  79%|███████▉ | 371/469 [11:13<02:46, 1.70s/it]
583
  79%|███████▉ | 372/469 [11:15<02:49, 1.74s/it]
584
  80%|███████▉ | 373/469 [11:16<02:43, 1.70s/it]
585
  80%|███████▉ | 374/469 [11:18<02:42, 1.71s/it]
586
  80%|███████▉ | 375/469 [11:20<02:44, 1.75s/it]
587
  80%|████████ | 376/469 [11:22<02:47, 1.80s/it]
588
  80%|████████ | 377/469 [11:24<02:41, 1.76s/it]
589
  81%|████████ | 378/469 [11:25<02:34, 1.70s/it]
590
  81%|████████ | 379/469 [11:27<02:30, 1.67s/it]
591
  81%|████████ | 380/469 [11:29<02:33, 1.72s/it]
592
  81%|████████ | 381/469 [11:30<02:28, 1.69s/it]
593
  81%|████████▏ | 382/469 [11:32<02:24, 1.66s/it]
594
  82%|████████▏ | 383/469 [11:34<02:29, 1.73s/it]
595
  82%|████████▏ | 384/469 [11:36<02:33, 1.80s/it]
596
  82%|████████▏ | 385/469 [11:38<02:34, 1.84s/it]
597
  82%|████████▏ | 386/469 [11:39<02:27, 1.77s/it]
598
  83%|████████▎ | 387/469 [11:41<02:26, 1.79s/it]
599
  83%|████████▎ | 388/469 [11:43<02:21, 1.74s/it]
600
  83%|████████▎ | 389/469 [11:45<02:23, 1.79s/it]
601
  83%|████████▎ | 390/469 [11:46<02:16, 1.73s/it]
602
  83%|████████▎ | 391/469 [11:48<02:12, 1.70s/it]
603
  84%|████████▎ | 392/469 [11:49<02:09, 1.69s/it]
604
  84%|████████▍ | 393/469 [11:51<02:05, 1.66s/it]
605
  84%|████████▍ | 394/469 [11:53<02:03, 1.65s/it]
606
  84%|████████▍ | 395/469 [11:54<02:01, 1.65s/it]
607
  84%|████████▍ | 396/469 [11:56<01:59, 1.63s/it]
608
  85%|████████▍ | 397/469 [11:58<01:58, 1.64s/it]
609
  85%|████████▍ | 398/469 [11:59<02:00, 1.70s/it]
610
  85%|████████▌ | 399/469 [12:01<01:57, 1.68s/it]
611
  85%|████████▌ | 400/469 [12:03<01:54, 1.66s/it]
612
  86%|████████▌ | 401/469 [12:04<01:51, 1.64s/it]
613
  86%|████████▌ | 402/469 [12:25<08:21, 7.49s/it]
614
  86%|████████▌ | 403/469 [12:27<06:17, 5.72s/it]
615
  86%|████████▌ | 404/469 [12:29<04:51, 4.49s/it]
616
  86%|████████▋ | 405/469 [12:30<03:53, 3.65s/it]
617
  87%|████████▋ | 406/469 [12:32<03:12, 3.06s/it]
618
  87%|████████▋ | 407/469 [12:34<02:42, 2.63s/it]
619
  87%|████████▋ | 408/469 [12:35<02:21, 2.32s/it]
620
  87%|████████▋ | 409/469 [12:37<02:06, 2.10s/it]
621
  87%|████████▋ | 410/469 [12:38<01:54, 1.94s/it]
622
  88%|████████▊ | 411/469 [12:40<01:46, 1.83s/it]
623
  88%|████████▊ | 412/469 [12:41<01:39, 1.75s/it]
624
  88%|████████▊ | 413/469 [12:43<01:36, 1.73s/it]
625
  88%|████████▊ | 414/469 [12:45<01:33, 1.69s/it]
626
  88%|████████▊ | 415/469 [12:46<01:29, 1.65s/it]
627
  89%|████████▊ | 416/469 [12:48<01:28, 1.68s/it]
628
  89%|████████▉ | 417/469 [12:50<01:26, 1.67s/it]
629
  89%|████████▉ | 418/469 [12:51<01:24, 1.65s/it]
630
  89%|████████▉ | 419/469 [12:53<01:21, 1.64s/it]
631
  90%|████████▉ | 420/469 [12:55<01:21, 1.66s/it]
632
  90%|████████▉ | 421/469 [12:56<01:19, 1.66s/it]
633
  90%|████████▉ | 422/469 [12:58<01:18, 1.67s/it]
634
  90%|█████████ | 423/469 [13:00<01:18, 1.71s/it]
635
  90%|█████████ | 424/469 [13:02<01:18, 1.73s/it]
636
  91%|█████████ | 425/469 [13:03<01:14, 1.69s/it]
637
  91%|█████████ | 426/469 [13:05<01:11, 1.66s/it]
638
  91%|█████████ | 427/469 [13:07<01:13, 1.74s/it]
639
  91%|█████████▏| 428/469 [13:08<01:09, 1.70s/it]
640
  91%|█████████▏| 429/469 [13:10<01:07, 1.68s/it]
641
  92%|█████████▏| 430/469 [13:12<01:05, 1.68s/it]
642
  92%|█████████▏| 431/469 [13:13<01:05, 1.71s/it]
643
  92%|█████████▏| 432/469 [13:15<01:02, 1.68s/it]
644
  92%|█████████▏| 433/469 [13:17<01:01, 1.71s/it]
645
  93%|█████████▎| 434/469 [13:18<00:58, 1.69s/it]
646
  93%|█████████▎| 435/469 [13:20<00:57, 1.69s/it]
647
  93%|█████████▎| 436/469 [13:22<00:55, 1.68s/it]
648
  93%|█████████▎| 437/469 [13:23<00:53, 1.67s/it]
649
  93%|█████████▎| 438/469 [13:25<00:51, 1.66s/it]
650
  94%|█████████▎| 439/469 [13:27<00:49, 1.64s/it]
651
  94%|█████████▍| 440/469 [13:28<00:47, 1.65s/it]
652
  94%|█████████▍| 441/469 [13:30<00:46, 1.65s/it]
653
  94%|█████████▍| 442/469 [13:32<00:44, 1.63s/it]
654
  94%|█████████▍| 443/469 [13:33<00:43, 1.67s/it]
655
  95%|█████████▍| 444/469 [13:35<00:41, 1.65s/it]
656
  95%|█████████▍| 445/469 [13:36<00:39, 1.64s/it]
657
  95%|█████████▌| 446/469 [13:39<00:40, 1.77s/it]
658
  95%|█████████▌| 447/469 [13:40<00:37, 1.72s/it]
659
  96%|█████████▌| 448/469 [13:42<00:35, 1.68s/it]
660
  96%|█████████▌| 449/469 [13:44<00:36, 1.84s/it]
661
  96%|█████████▌| 450/469 [13:46<00:33, 1.78s/it]
662
  96%|█████████▌| 451/469 [13:47<00:30, 1.72s/it]
663
  96%|█████████▋| 452/469 [13:49<00:30, 1.78s/it]
664
  97%|█████████▋| 453/469 [13:51<00:27, 1.73s/it]
665
  97%|█████████▋| 454/469 [13:52<00:25, 1.70s/it]
666
  97%|█████████▋| 455/469 [13:54<00:23, 1.70s/it]
667
  97%|█████████▋| 456/469 [13:56<00:21, 1.67s/it]
668
  97%|█████████▋| 457/469 [13:57<00:19, 1.66s/it]
669
  98%|█████████▊| 458/469 [13:59<00:18, 1.65s/it]
670
  98%|█████████▊| 459/469 [14:01<00:16, 1.64s/it]
671
  98%|█████████▊| 460/469 [14:02<00:14, 1.65s/it]
672
  98%|█████████▊| 461/469 [14:04<00:13, 1.65s/it]
673
  99%|█████████▊| 462/469 [14:06<00:11, 1.67s/it]
674
  99%|█████████▊| 463/469 [14:07<00:10, 1.67s/it]
675
  99%|█████████▉| 464/469 [14:09<00:09, 1.83s/it]
676
  99%|█████████▉| 465/469 [14:11<00:07, 1.84s/it]
677
  99%|█████████▉| 466/469 [14:13<00:05, 1.80s/it]
 
 
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ 2026-03-25 14:55:31.459313: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
3
+ 2026-03-25 14:55:36.111147: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
4
+ To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
5
+ 2026-03-25 14:55:36.137507: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
6
+ 2026-03-25 14:55:36.137574: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: 66d2d54653616c6252364513da490658-taskrole1-0
7
+ 2026-03-25 14:55:36.137623: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: 66d2d54653616c6252364513da490658-taskrole1-0
8
+ 2026-03-25 14:55:36.137710: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
9
+ 2026-03-25 14:55:36.137756: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 535.154.5
10
+ 2026-03-25 14:55:37.175397: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
11
+ warming up TensorFlow...
12
+
13
  0%| | 0/1 [00:00<?, ?it/s]2026-03-25 14:55:37.841840: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
14
+
15
+ computing reference batch activations...
16
+
17
  0%| | 0/211 [00:00<?, ?it/s]
18
  0%| | 1/211 [00:02<07:28, 2.13s/it]
19
  1%| | 2/211 [00:03<06:40, 1.92s/it]
20
  1%|▏ | 3/211 [00:06<08:10, 2.36s/it]
21
  2%|▏ | 4/211 [00:08<07:14, 2.10s/it]
22
  2%|▏ | 5/211 [00:10<06:43, 1.96s/it]
23
  3%|▎ | 6/211 [00:11<06:27, 1.89s/it]
24
  3%|▎ | 7/211 [00:13<06:15, 1.84s/it]
25
  4%|▍ | 8/211 [00:15<06:05, 1.80s/it]
26
  4%|▍ | 9/211 [00:17<06:00, 1.79s/it]
27
  5%|▍ | 10/211 [00:18<05:53, 1.76s/it]
28
  5%|▌ | 11/211 [00:20<05:47, 1.74s/it]
29
  6%|▌ | 12/211 [00:22<05:40, 1.71s/it]
30
  6%|▌ | 13/211 [00:23<05:34, 1.69s/it]
31
  7%|▋ | 14/211 [00:25<05:32, 1.69s/it]
32
  7%|▋ | 15/211 [00:27<05:39, 1.73s/it]
33
  8%|▊ | 16/211 [00:29<05:34, 1.71s/it]
34
  8%|▊ | 17/211 [00:30<05:34, 1.72s/it]
35
  9%|▊ | 18/211 [00:32<06:00, 1.87s/it]
36
  9%|▉ | 19/211 [00:34<05:48, 1.82s/it]
37
  9%|▉ | 20/211 [00:36<05:44, 1.80s/it]
38
  10%|▉ | 21/211 [00:38<05:39, 1.78s/it]
39
  10%|█ | 22/211 [00:40<05:40, 1.80s/it]
40
  11%|█ | 23/211 [00:41<05:45, 1.84s/it]
41
  11%|█▏ | 24/211 [00:43<05:43, 1.84s/it]
42
  12%|█▏ | 25/211 [00:45<05:35, 1.80s/it]
43
  12%|█▏ | 26/211 [00:47<05:41, 1.84s/it]
44
  13%|█▎ | 27/211 [00:49<05:42, 1.86s/it]
45
  13%|█▎ | 28/211 [00:51<05:36, 1.84s/it]
46
  14%|█▎ | 29/211 [00:52<05:27, 1.80s/it]
47
  14%|█▍ | 30/211 [00:54<05:18, 1.76s/it]
48
  15%|█▍ | 31/211 [00:56<05:28, 1.83s/it]
49
  15%|█▌ | 32/211 [00:58<05:24, 1.81s/it]
50
  16%|█▌ | 33/211 [01:00<05:28, 1.85s/it]
51
  16%|█▌ | 34/211 [01:01<05:20, 1.81s/it]
52
  17%|█▋ | 35/211 [01:03<05:19, 1.82s/it]
53
  17%|█▋ | 36/211 [01:05<05:14, 1.80s/it]
54
  18%|█▊ | 37/211 [01:07<05:07, 1.77s/it]
55
  18%|█▊ | 38/211 [01:08<05:04, 1.76s/it]
56
  18%|█▊ | 39/211 [01:10<05:17, 1.85s/it]
57
  19%|█▉ | 40/211 [01:12<05:10, 1.82s/it]
58
  19%|█▉ | 41/211 [01:14<05:23, 1.90s/it]
59
  20%|█▉ | 42/211 [01:16<05:15, 1.87s/it]
60
  20%|██ | 43/211 [01:18<05:06, 1.82s/it]
61
  21%|██ | 44/211 [01:20<05:01, 1.81s/it]
62
  21%|██▏ | 45/211 [01:21<04:55, 1.78s/it]
63
  22%|██▏ | 46/211 [01:23<04:54, 1.78s/it]
64
  22%|██▏ | 47/211 [01:25<04:49, 1.77s/it]
65
  23%|██▎ | 48/211 [01:27<04:54, 1.81s/it]
66
  23%|██▎ | 49/211 [01:28<04:48, 1.78s/it]
67
  24%|██▎ | 50/211 [01:30<04:42, 1.75s/it]
68
  24%|██▍ | 51/211 [01:32<04:36, 1.73s/it]
69
  25%|██▍ | 52/211 [01:34<04:44, 1.79s/it]
70
  25%|██▌ | 53/211 [01:36<05:01, 1.91s/it]
71
  26%|██▌ | 54/211 [01:38<04:49, 1.84s/it]
72
  26%|██▌ | 55/211 [01:40<05:11, 2.00s/it]
73
  27%|██▋ | 56/211 [01:42<04:57, 1.92s/it]
74
  27%|██▋ | 57/211 [01:43<04:43, 1.84s/it]
75
  27%|█��▋ | 58/211 [01:45<04:48, 1.88s/it]
76
  28%|██▊ | 59/211 [01:47<04:34, 1.81s/it]
77
  28%|██▊ | 60/211 [01:49<04:36, 1.83s/it]
78
  29%|██▉ | 61/211 [01:51<04:25, 1.77s/it]
79
  29%|██▉ | 62/211 [01:52<04:17, 1.73s/it]
80
  30%|██▉ | 63/211 [01:54<04:13, 1.71s/it]
81
  30%|███ | 64/211 [01:55<04:08, 1.69s/it]
82
  31%|███ | 65/211 [01:57<04:09, 1.71s/it]
83
  31%|███▏ | 66/211 [01:59<04:11, 1.73s/it]
84
  32%|███▏ | 67/211 [02:01<04:11, 1.75s/it]
85
  32%|███▏ | 68/211 [02:03<04:16, 1.79s/it]
86
  33%|███▎ | 69/211 [02:04<04:11, 1.77s/it]
87
  33%|███▎ | 70/211 [02:07<04:30, 1.92s/it]
88
  34%|███▎ | 71/211 [02:08<04:20, 1.86s/it]
89
  34%|███▍ | 72/211 [02:10<04:12, 1.82s/it]
90
  35%|███▍ | 73/211 [02:12<04:07, 1.79s/it]
91
  35%|███▌ | 74/211 [02:14<04:12, 1.84s/it]
92
  36%|███▌ | 75/211 [02:16<04:19, 1.91s/it]
93
  36%|███▌ | 76/211 [02:18<04:06, 1.83s/it]
94
  36%|███▋ | 77/211 [02:20<04:17, 1.92s/it]
95
  37%|███▋ | 78/211 [02:21<04:06, 1.86s/it]
96
  37%|███▋ | 79/211 [02:23<04:03, 1.84s/it]
97
  38%|███▊ | 80/211 [02:26<04:34, 2.10s/it]
98
  38%|███▊ | 81/211 [02:28<04:16, 1.97s/it]
99
  39%|███▉ | 82/211 [02:29<04:01, 1.88s/it]
100
  39%|███▉ | 83/211 [02:31<03:54, 1.84s/it]
101
  40%|███▉ | 84/211 [02:33<03:48, 1.80s/it]
102
  40%|████ | 85/211 [02:34<03:42, 1.77s/it]
103
  41%|████ | 86/211 [02:36<03:39, 1.75s/it]
104
  41%|████ | 87/211 [02:38<03:37, 1.76s/it]
105
  42%|████▏ | 88/211 [02:39<03:33, 1.73s/it]
106
  42%|████▏ | 89/211 [02:41<03:27, 1.70s/it]
107
  43%|████▎ | 90/211 [02:43<03:27, 1.71s/it]
108
  43%|████▎ | 91/211 [02:45<03:27, 1.73s/it]
109
  44%|████▎ | 92/211 [02:46<03:23, 1.71s/it]
110
  44%|████▍ | 93/211 [02:48<03:22, 1.71s/it]
111
  45%|████▍ | 94/211 [02:57<07:25, 3.81s/it]
112
  45%|████▌ | 95/211 [02:58<06:11, 3.20s/it]
113
  45%|████▌ | 96/211 [03:00<05:17, 2.76s/it]
114
  46%|████▌ | 97/211 [03:02<04:47, 2.52s/it]
115
  46%|████▋ | 98/211 [03:04<04:15, 2.27s/it]
116
  47%|████▋ | 99/211 [03:06<03:56, 2.11s/it]
117
  47%|████▋ | 100/211 [03:07<03:42, 2.00s/it]
118
  48%|████▊ | 101/211 [03:09<03:28, 1.90s/it]
119
  48%|████▊ | 102/211 [03:11<03:17, 1.81s/it]
120
  49%|████▉ | 103/211 [03:12<03:12, 1.78s/it]
121
  49%|████▉ | 104/211 [03:14<03:05, 1.74s/it]
122
  50%|████▉ | 105/211 [03:16<03:01, 1.71s/it]
123
  50%|█████ | 106/211 [03:17<03:05, 1.76s/it]
124
  51%|█████ | 107/211 [03:19<03:03, 1.76s/it]
125
  51%|█████ | 108/211 [03:21<02:58, 1.74s/it]
126
  52%|█████▏ | 109/211 [03:23<02:58, 1.75s/it]
127
  52%|█████▏ | 110/211 [03:25<02:58, 1.77s/it]
128
  53%|█████▎ | 111/211 [03:26<02:51, 1.72s/it]
129
  53%|█████▎ | 112/211 [03:28<02:49, 1.71s/it]
130
  54%|█████▎ | 113/211 [03:29<02:46, 1.70s/it]
131
  54%|█████▍ | 114/211 [03:31<02:44, 1.69s/it]
132
  55%|█████▍ | 115/211 [03:33<02:41, 1.68s/it]
133
  55%|█████▍ | 116/211 [03:35<02:42, 1.71s/it]
134
  55%|█████▌ | 117/211 [03:36<02:38, 1.69s/it]
135
  56%|█████▌ | 118/211 [03:38<02:38, 1.70s/it]
136
  56%|█████▋ | 119/211 [03:40<02:35, 1.69s/it]
137
  57%|█████▋ | 120/211 [03:41<02:36, 1.72s/it]
138
  57%|█████▋ | 121/211 [03:43<02:36, 1.74s/it]
139
  58%|█████▊ | 122/211 [03:45<02:31, 1.70s/it]
140
  58%|█████▊ | 123/211 [03:47<02:34, 1.76s/it]
141
  59%|█████▉ | 124/211 [03:48<02:30, 1.73s/it]
142
  59%|█████▉ | 125/211 [03:50<02:25, 1.70s/it]
143
  60%|█████▉ | 126/211 [03:52<02:22, 1.68s/it]
144
  60%|██████ | 127/211 [03:53<02:21, 1.69s/it]
145
  61%|██████ | 128/211 [03:55<02:19, 1.68s/it]
146
  61%|██████ | 129/211 [03:57<02:17, 1.68s/it]
147
  62%|██████▏ | 130/211 [03:59<02:21, 1.75s/it]
148
  62%|██████▏ | 131/211 [04:00<02:20, 1.75s/it]
149
  63%|██████▎ | 132/211 [04:02<02:18, 1.75s/it]
150
  63%|██████▎ | 133/211 [04:04<02:21, 1.81s/it]
151
  64%|██████▎ | 134/211 [04:06<02:15, 1.76s/it]
152
  64%|██████▍ | 135/211 [04:08<02:17, 1.81s/it]
153
  64%|██████▍ | 136/211 [04:10<02:20, 1.88s/it]
154
  65%|██████▍ | 137/211 [04:11<02:13, 1.80s/it]
155
  65%|██████▌ | 138/211 [04:13<02:07, 1.74s/it]
156
  66%|██████▌ | 139/211 [04:16<02:38, 2.20s/it]
157
  66%|██████▋ | 140/211 [04:18<02:27, 2.08s/it]
158
  67%|██████▋ | 141/211 [04:20<02:16, 1.95s/it]
159
  67%|██████▋ | 142/211 [04:21<02:11, 1.91s/it]
160
  68%|██████▊ | 143/211 [04:23<02:06, 1.86s/it]
161
  68%|██████▊ | 144/211 [04:25<02:02, 1.82s/it]
162
  69%|██████▊ | 145/211 [04:27<01:58, 1.80s/it]
163
  69%|██████▉ | 146/211 [04:28<01:53, 1.75s/it]
164
  70%|██████▉ | 147/211 [04:30<01:54, 1.78s/it]
165
  70%|███████ | 148/211 [04:32<01:53, 1.80s/it]
166
  71%|███████ | 149/211 [04:34<01:54, 1.85s/it]
167
  71%|███████ | 150/211 [04:36<01:49, 1.80s/it]
168
  72%|███████▏ | 151/211 [04:37<01:46, 1.77s/it]
169
  72%|███████▏ | 152/211 [04:39<01:45, 1.79s/it]
170
  73%|███████▎ | 153/211 [04:41<01:43, 1.78s/it]
171
  73%|███████▎ | 154/211 [04:43<01:41, 1.78s/it]
172
  73%|███████▎ | 155/211 [04:44<01:37, 1.73s/it]
173
  74%|███████▍ | 156/211 [04:46<01:35, 1.73s/it]
174
  74%|███████▍ | 157/211 [04:48<01:33, 1.73s/it]
175
  75%|███████▍ | 158/211 [04:50<01:33, 1.77s/it]
176
  75%|███████▌ | 159/211 [04:51<01:30, 1.74s/it]
177
  76%|███████▌ | 160/211 [04:53<01:27, 1.72s/it]
178
  76%|███████▋ | 161/211 [04:55<01:28, 1.77s/it]
179
  77%|███████▋ | 162/211 [04:57<01:28, 1.80s/it]
180
  77%|███████▋ | 163/211 [04:58<01:24, 1.75s/it]
181
  78%|███████▊ | 164/211 [05:00<01:26, 1.84s/it]
182
  78%|███████▊ | 165/211 [05:02<01:21, 1.77s/it]
183
  79%|███████▊ | 166/211 [05:04<01:17, 1.72s/it]
184
  79%|███████▉ | 167/211 [05:05<01:15, 1.72s/it]
185
  80%|███████▉ | 168/211 [05:07<01:16, 1.79s/it]
186
  80%|████████ | 169/211 [05:09<01:17, 1.84s/it]
187
  81%|████████ | 170/211 [05:11<01:15, 1.83s/it]
188
  81%|████████ | 171/211 [05:13<01:12, 1.81s/it]
189
  82%|████████▏ | 172/211 [05:15<01:09, 1.79s/it]
190
  82%|████████▏ | 173/211 [05:16<01:06, 1.74s/it]
191
  82%|████████▏ | 174/211 [05:18<01:04, 1.75s/it]
192
  83%|████████▎ | 175/211 [05:20<01:02, 1.74s/it]
193
  83%|████████▎ | 176/211 [05:41<04:21, 7.47s/it]
194
  84%|████████▍ | 177/211 [05:42<03:17, 5.81s/it]
195
  84%|████████▍ | 178/211 [05:44<02:32, 4.63s/it]
196
  85%|████████▍ | 179/211 [05:46<02:00, 3.75s/it]
197
  85%|████████▌ | 180/211 [05:48<01:36, 3.13s/it]
198
  86%|████████▌ | 181/211 [05:49<01:21, 2.72s/it]
199
  86%|████████▋ | 182/211 [05:51<01:10, 2.42s/it]
200
  87%|████████▋ | 183/211 [05:53<01:01, 2.19s/it]
201
  87%|████████▋ | 184/211 [05:55<00:54, 2.03s/it]
202
  88%|████████▊ | 185/211 [05:56<00:50, 1.94s/it]
203
  88%|████████▊ | 186/211 [05:58<00:46, 1.88s/it]
204
  89%|████████▊ | 187/211 [06:00<00:43, 1.82s/it]
205
  89%|████████▉ | 188/211 [06:01<00:41, 1.79s/it]
206
  90%|████████▉ | 189/211 [06:03<00:40, 1.86s/it]
207
  90%|█████████ | 190/211 [06:06<00:45, 2.14s/it]
208
  91%|█████████ | 191/211 [06:08<00:40, 2.02s/it]
209
  91%|█████████ | 192/211 [06:10<00:36, 1.91s/it]
210
  91%|█████████▏| 193/211 [06:11<00:33, 1.85s/it]
211
  92%|█████████▏| 194/211 [06:13<00:30, 1.82s/it]
212
  92%|█████████▏| 195/211 [06:15<00:28, 1.78s/it]
213
  93%|█████████▎| 196/211 [06:17<00:28, 1.92s/it]
214
  93%|█████████▎| 197/211 [06:19<00:25, 1.84s/it]
215
  94%|█████████▍| 198/211 [06:20<00:23, 1.80s/it]
216
  94%|█████████▍| 199/211 [06:22<00:22, 1.84s/it]
217
  95%|█████████▍| 200/211 [06:24<00:19, 1.80s/it]
218
  95%|█████████▌| 201/211 [06:26<00:17, 1.77s/it]
219
  96%|█████████▌| 202/211 [06:28<00:16, 1.82s/it]
220
  96%|█████████▌| 203/211 [06:29<00:14, 1.77s/it]
221
  97%|█████████▋| 204/211 [06:31<00:12, 1.73s/it]
222
  97%|█████████▋| 205/211 [06:33<00:10, 1.72s/it]
223
  98%|█████████▊| 206/211 [06:34<00:08, 1.73s/it]
224
  98%|█████████▊| 207/211 [06:36<00:06, 1.71s/it]
225
  99%|█████████▊| 208/211 [06:39<00:05, 1.97s/it]
226
  99%|█████████▉| 209/211 [06:40<00:03, 1.93s/it]
227
+ computing/reading reference batch statistics...
228
+ computing sample batch activations...
229
+
230
  0%| | 0/469 [00:00<?, ?it/s]
231
  0%| | 1/469 [00:02<15:45, 2.02s/it]
232
  0%| | 2/469 [00:03<14:20, 1.84s/it]
233
  1%| | 3/469 [00:05<13:51, 1.78s/it]
234
  1%| | 4/469 [00:07<13:37, 1.76s/it]
235
  1%| | 5/469 [00:08<13:23, 1.73s/it]
236
  1%|▏ | 6/469 [00:10<13:11, 1.71s/it]
237
  1%|▏ | 7/469 [00:12<13:03, 1.70s/it]
238
  2%|▏ | 8/469 [00:13<12:57, 1.69s/it]
239
  2%|▏ | 9/469 [00:15<12:53, 1.68s/it]
240
  2%|▏ | 10/469 [00:17<13:08, 1.72s/it]
241
  2%|▏ | 11/469 [00:26<31:14, 4.09s/it]
242
  3%|▎ | 12/469 [00:28<25:29, 3.35s/it]
243
  3%|▎ | 13/469 [00:30<21:32, 2.84s/it]
244
  3%|▎ | 14/469 [00:31<18:50, 2.49s/it]
245
  3%|▎ | 15/469 [00:33<16:50, 2.23s/it]
246
  3%|▎ | 16/469 [00:35<15:32, 2.06s/it]
247
  4%|▎ | 17/469 [00:36<14:57, 1.99s/it]
248
  4%|▍ | 18/469 [00:38<14:12, 1.89s/it]
249
  4%|▍ | 19/469 [00:40<13:45, 1.83s/it]
250
  4%|▍ | 20/469 [00:41<13:17, 1.78s/it]
251
  4%|▍ | 21/469 [00:43<13:08, 1.76s/it]
252
  5%|▍ | 22/469 [00:45<12:51, 1.73s/it]
253
  5%|▍ | 23/469 [00:47<12:51, 1.73s/it]
254
  5%|▌ | 24/469 [00:48<12:38, 1.70s/it]
255
  5%|▌ | 25/469 [00:50<12:33, 1.70s/it]
256
  6%|▌ | 26/469 [00:51<12:26, 1.68s/it]
257
  6%|▌ | 27/469 [00:53<12:26, 1.69s/it]
258
  6%|▌ | 28/469 [00:55<12:17, 1.67s/it]
259
  6%|▌ | 29/469 [00:57<12:18, 1.68s/it]
260
  6%|▋ | 30/469 [00:58<12:12, 1.67s/it]
261
  7%|▋ | 31/469 [01:00<12:15, 1.68s/it]
262
  7%|▋ | 32/469 [01:02<12:10, 1.67s/it]
263
  7%|▋ | 33/469 [01:04<12:55, 1.78s/it]
264
  7%|▋ | 34/469 [01:05<12:41, 1.75s/it]
265
  7%|▋ | 35/469 [01:07<12:24, 1.72s/it]
266
  8%|▊ | 36/469 [01:09<12:34, 1.74s/it]
267
  8%|▊ | 37/469 [01:10<12:19, 1.71s/it]
268
  8%|▊ | 38/469 [01:12<12:27, 1.73s/it]
269
  8%|▊ | 39/469 [01:14<12:14, 1.71s/it]
270
  9%|▊ | 40/469 [01:15<12:10, 1.70s/it]
271
  9%|▊ | 41/469 [01:17<12:07, 1.70s/it]
272
  9%|▉ | 42/469 [01:19<12:09, 1.71s/it]
273
  9%|▉ | 43/469 [01:21<12:18, 1.73s/it]
274
  9%|▉ | 44/469 [01:22<12:13, 1.72s/it]
275
  10%|▉ | 45/469 [01:24<12:06, 1.71s/it]
276
  10%|▉ | 46/469 [01:26<12:22, 1.75s/it]
277
  10%|█ | 47/469 [01:28<12:35, 1.79s/it]
278
  10%|█ | 48/469 [01:29<12:12, 1.74s/it]
279
  10%|█ | 49/469 [01:31<12:01, 1.72s/it]
280
  11%|█ | 50/469 [01:33<12:24, 1.78s/it]
281
  11%|█ | 51/469 [01:35<12:32, 1.80s/it]
282
  11%|█ | 52/469 [01:37<12:29, 1.80s/it]
283
  11%|█▏ | 53/469 [01:39<13:30, 1.95s/it]
284
  12%|█▏ | 54/469 [01:41<12:50, 1.86s/it]
285
  12%|█▏ | 55/469 [01:42<12:27, 1.81s/it]
286
  12%|█▏ | 56/469 [01:44<12:05, 1.76s/it]
287
  12%|█▏ | 57/469 [01:46<12:17, 1.79s/it]
288
  12%|█▏ | 58/469 [01:47<12:08, 1.77s/it]
289
  13%|█▎ | 59/469 [01:49<12:09, 1.78s/it]
290
  13%|█▎ | 60/469 [01:51<12:17, 1.80s/it]
291
  13%|█▎ | 61/469 [01:53<12:07, 1.78s/it]
292
  13%|█▎ | 62/469 [01:55<12:24, 1.83s/it]
293
  13%|█▎ | 63/469 [01:56<12:00, 1.77s/it]
294
  14%|█▎ | 64/469 [01:58<11:42, 1.73s/it]
295
  14%|█▍ | 65/469 [02:00<11:35, 1.72s/it]
296
  14%|█▍ | 66/469 [02:02<11:34, 1.72s/it]
297
  14%|█▍ | 67/469 [02:04<12:11, 1.82s/it]
298
  14%|█▍ | 68/469 [02:05<11:45, 1.76s/it]
299
  15%|█▍ | 69/469 [02:07<11:31, 1.73s/it]
300
  15%|█▍ | 70/469 [02:09<11:27, 1.72s/it]
301
  15%|█▌ | 71/469 [02:10<11:25, 1.72s/it]
302
  15%|█▌ | 72/469 [02:12<11:16, 1.70s/it]
303
  16%|█▌ | 73/469 [02:14<11:09, 1.69s/it]
304
  16%|█▌ | 74/469 [02:15<11:32, 1.75s/it]
305
  16%|█▌ | 75/469 [02:17<11:47, 1.80s/it]
306
  16%|█▌ | 76/469 [02:19<11:43, 1.79s/it]
307
  16%|█▋ | 77/469 [02:21<12:10, 1.86s/it]
308
  17%|█▋ | 78/469 [02:23<11:45, 1.80s/it]
309
  17%|█▋ | 79/469 [02:25<12:10, 1.87s/it]
310
  17%|█▋ | 80/469 [02:27<11:48, 1.82s/it]
311
  17%|█▋ | 81/469 [02:28<11:30, 1.78s/it]
312
  17%|█▋ | 82/469 [02:30<11:36, 1.80s/it]
313
  18%|█▊ | 83/469 [02:32<11:18, 1.76s/it]
314
  18%|█▊ | 84/469 [02:34<11:12, 1.75s/it]
315
  18%|█▊ | 85/469 [02:35<11:03, 1.73s/it]
316
  18%|█▊ | 86/469 [02:37<11:06, 1.74s/it]
317
  19%|█▊ | 87/469 [02:39<11:04, 1.74s/it]
318
  19%|█▉ | 88/469 [02:40<10:56, 1.72s/it]
319
  19%|█▉ | 89/469 [02:42<11:07, 1.76s/it]
320
  19%|█▉ | 90/469 [02:45<12:56, 2.05s/it]
321
  19%|█▉ | 91/469 [02:47<12:10, 1.93s/it]
322
  20%|█▉ | 92/469 [02:48<11:35, 1.84s/it]
323
  20%|█▉ | 93/469 [02:50<11:19, 1.81s/it]
324
  20%|██ | 94/469 [02:52<11:04, 1.77s/it]
325
  20%|██ | 95/469 [02:53<10:48, 1.73s/it]
326
  20%|██ | 96/469 [02:55<10:38, 1.71s/it]
327
  21%|██ | 97/469 [02:57<11:01, 1.78s/it]
328
  21%|██ | 98/469 [02:59<10:50, 1.75s/it]
329
  21%|██ | 99/469 [03:00<10:33, 1.71s/it]
330
  21%|██▏ | 100/469 [03:02<10:35, 1.72s/it]
331
  22%|██▏ | 101/469 [03:04<11:42, 1.91s/it]
332
  22%|██▏ | 102/469 [03:06<11:19, 1.85s/it]
333
  22%|██▏ | 103/469 [03:08<11:48, 1.94s/it]
334
  22%|██▏ | 104/469 [03:10<11:23, 1.87s/it]
335
  22%|██▏ | 105/469 [03:11<10:53, 1.80s/it]
336
  23%|██▎ | 106/469 [03:13<10:34, 1.75s/it]
337
  23%|██▎ | 107/469 [03:15<10:32, 1.75s/it]
338
  23%|██▎ | 108/469 [03:17<10:19, 1.71s/it]
339
  23%|██▎ | 109/469 [03:18<10:08, 1.69s/it]
340
  23%|██▎ | 110/469 [03:20<11:07, 1.86s/it]
341
  24%|██▎ | 111/469 [03:22<10:49, 1.81s/it]
342
  24%|██▍ | 112/469 [03:24<10:35, 1.78s/it]
343
  24%|██▍ | 113/469 [03:25<10:22, 1.75s/it]
344
  24%|██▍ | 114/469 [03:27<10:13, 1.73s/it]
345
  25%|██▍ | 115/469 [03:29<10:01, 1.70s/it]
346
  25%|██▍ | 116/469 [03:30<09:56, 1.69s/it]
347
  25%|██▍ | 117/469 [03:32<09:56, 1.69s/it]
348
  25%|██▌ | 118/469 [03:34<09:45, 1.67s/it]
349
  25%|██▌ | 119/469 [03:35<09:45, 1.67s/it]
350
  26%|██▌ | 120/469 [03:37<09:45, 1.68s/it]
351
  26%|██▌ | 121/469 [03:40<11:56, 2.06s/it]
352
  26%|██▌ | 122/469 [03:42<11:18, 1.95s/it]
353
  26%|██▌ | 123/469 [03:44<11:00, 1.91s/it]
354
  26%|██▋ | 124/469 [03:45<10:31, 1.83s/it]
355
  27%|██▋ | 125/469 [03:47<10:06, 1.76s/it]
356
  27%|██▋ | 126/469 [03:48<09:48, 1.72s/it]
357
  27%|██▋ | 127/469 [03:50<09:57, 1.75s/it]
358
  27%|██▋ | 128/469 [03:52<09:52, 1.74s/it]
359
  28%|██▊ | 129/469 [03:54<09:46, 1.72s/it]
360
  28%|██▊ | 130/469 [03:56<09:59, 1.77s/it]
361
  28%|██▊ | 131/469 [03:57<09:56, 1.76s/it]
362
  28%|██▊ | 132/469 [03:59<09:40, 1.72s/it]
363
  28%|██▊ | 133/469 [04:01<10:56, 1.95s/it]
364
  29%|██▊ | 134/469 [04:03<10:35, 1.90s/it]
365
  29%|██▉ | 135/469 [04:05<10:05, 1.81s/it]
366
  29%|██▉ | 136/469 [04:06<09:45, 1.76s/it]
367
  29%|██▉ | 137/469 [04:08<09:30, 1.72s/it]
368
  29%|██▉ | 138/469 [04:10<09:35, 1.74s/it]
369
  30%|██▉ | 139/469 [04:11<09:19, 1.69s/it]
370
  30%|██▉ | 140/469 [04:13<09:06, 1.66s/it]
371
  30%|███ | 141/469 [04:15<09:00, 1.65s/it]
372
  30%|███ | 142/469 [04:16<08:55, 1.64s/it]
373
  30%|███ | 143/469 [04:18<08:55, 1.64s/it]
374
  31%|███ | 144/469 [04:19<08:47, 1.62s/it]
375
  31%|███ | 145/469 [04:21<09:06, 1.69s/it]
376
  31%|███ | 146/469 [04:23<09:06, 1.69s/it]
377
  31%|███▏ | 147/469 [04:25<08:56, 1.67s/it]
378
  32%|███▏ | 148/469 [04:26<08:56, 1.67s/it]
379
  32%|███▏ | 149/469 [04:28<09:09, 1.72s/it]
380
  32%|███▏ | 150/469 [04:30<08:57, 1.69s/it]
381
  32%|███▏ | 151/469 [04:31<08:47, 1.66s/it]
382
  32%|███▏ | 152/469 [04:33<08:41, 1.64s/it]
383
  33%|███▎ | 153/469 [04:35<08:34, 1.63s/it]
384
  33%|███▎ | 154/469 [04:36<08:32, 1.63s/it]
385
  33%|███▎ | 155/469 [04:38<08:25, 1.61s/it]
386
  33%|███▎ | 156/469 [04:39<08:23, 1.61s/it]
387
  33%|███▎ | 157/469 [04:41<08:22, 1.61s/it]
388
  34%|███▎ | 158/469 [04:47<15:58, 3.08s/it]
389
  34%|███▍ | 159/469 [04:49<13:44, 2.66s/it]
390
  34%|███▍ | 160/469 [04:51<12:27, 2.42s/it]
391
  34%|███▍ | 161/469 [04:53<11:14, 2.19s/it]
392
  35%|███▍ | 162/469 [04:54<10:22, 2.03s/it]
393
  35%|███▍ | 163/469 [04:56<09:58, 1.96s/it]
394
  35%|███▍ | 164/469 [04:58<09:35, 1.89s/it]
395
  35%|███▌ | 165/469 [04:59<09:06, 1.80s/it]
396
  35%|███▌ | 166/469 [05:01<08:47, 1.74s/it]
397
  36%|███▌ | 167/469 [05:03<08:31, 1.69s/it]
398
  36%|███▌ | 168/469 [05:05<08:55, 1.78s/it]
399
  36%|███▌ | 169/469 [05:07<09:07, 1.82s/it]
400
  36%|███▌ | 170/469 [05:08<08:54, 1.79s/it]
401
  36%|███▋ | 171/469 [05:10<08:46, 1.77s/it]
402
  37%|███▋ | 172/469 [05:12<08:31, 1.72s/it]
403
  37%|███▋ | 173/469 [05:13<08:23, 1.70s/it]
404
  37%|███▋ | 174/469 [05:15<08:16, 1.68s/it]
405
  37%|███▋ | 175/469 [05:17<08:12, 1.67s/it]
406
  38%|███▊ | 176/469 [05:18<08:14, 1.69s/it]
407
  38%|███▊ | 177/469 [05:20<08:13, 1.69s/it]
408
  38%|███▊ | 178/469 [05:22<08:07, 1.68s/it]
409
  38%|███▊ | 179/469 [05:23<08:00, 1.66s/it]
410
  38%|███▊ | 180/469 [05:25<07:58, 1.66s/it]
411
  39%|███▊ | 181/469 [05:26<07:53, 1.64s/it]
412
  39%|███▉ | 182/469 [05:28<07:47, 1.63s/it]
413
  39%|███▉ | 183/469 [05:30<07:48, 1.64s/it]
414
  39%|███▉ | 184/469 [05:31<07:50, 1.65s/it]
415
  39%|███▉ | 185/469 [05:33<07:44, 1.64s/it]
416
  40%|███▉ | 186/469 [05:35<07:50, 1.66s/it]
417
  40%|███▉ | 187/469 [05:37<08:00, 1.70s/it]
418
  40%|████ | 188/469 [05:38<07:55, 1.69s/it]
419
  40%|████ | 189/469 [05:40<07:48, 1.67s/it]
420
  41%|████ | 190/469 [05:42<07:53, 1.70s/it]
421
  41%|████ | 191/469 [05:43<07:44, 1.67s/it]
422
  41%|████ | 192/469 [05:45<07:38, 1.65s/it]
423
  41%|████ | 193/469 [05:46<07:32, 1.64s/it]
424
  41%|████▏ | 194/469 [05:48<07:54, 1.73s/it]
425
  42%|███��▏ | 195/469 [05:50<07:45, 1.70s/it]
426
  42%|████▏ | 196/469 [05:52<07:44, 1.70s/it]
427
  42%|████▏ | 197/469 [05:53<07:48, 1.72s/it]
428
  42%|████▏ | 198/469 [05:55<07:40, 1.70s/it]
429
  42%|████▏ | 199/469 [05:57<08:02, 1.79s/it]
430
  43%|████▎ | 200/469 [05:59<07:49, 1.74s/it]
431
  43%|████▎ | 201/469 [06:00<07:41, 1.72s/it]
432
  43%|████▎ | 202/469 [06:02<07:38, 1.72s/it]
433
  43%|████▎ | 203/469 [06:04<07:37, 1.72s/it]
434
  43%|████▎ | 204/469 [06:05<07:29, 1.70s/it]
435
  44%|████▎ | 205/469 [06:07<07:31, 1.71s/it]
436
  44%|████▍ | 206/469 [06:09<07:21, 1.68s/it]
437
  44%|████▍ | 207/469 [06:10<07:16, 1.67s/it]
438
  44%|████▍ | 208/469 [06:12<07:09, 1.64s/it]
439
  45%|████▍ | 209/469 [06:14<07:07, 1.64s/it]
440
  45%|████▍ | 210/469 [06:15<07:11, 1.67s/it]
441
  45%|████▍ | 211/469 [06:17<07:21, 1.71s/it]
442
  45%|████▌ | 212/469 [06:19<07:27, 1.74s/it]
443
  45%|████▌ | 213/469 [06:21<07:20, 1.72s/it]
444
  46%|████▌ | 214/469 [06:22<07:09, 1.69s/it]
445
  46%|████▌ | 215/469 [06:24<07:07, 1.68s/it]
446
  46%|████▌ | 216/469 [06:26<07:00, 1.66s/it]
447
  46%|████▋ | 217/469 [06:27<06:56, 1.65s/it]
448
  46%|████▋ | 218/469 [06:29<06:48, 1.63s/it]
449
  47%|████▋ | 219/469 [06:31<07:08, 1.72s/it]
450
  47%|████▋ | 220/469 [06:32<06:57, 1.68s/it]
451
  47%|████▋ | 221/469 [06:34<06:56, 1.68s/it]
452
  47%|████▋ | 222/469 [06:36<06:54, 1.68s/it]
453
  48%|████▊ | 223/469 [06:37<06:47, 1.66s/it]
454
  48%|████▊ | 224/469 [06:39<06:44, 1.65s/it]
455
  48%|████▊ | 225/469 [06:41<06:42, 1.65s/it]
456
  48%|████▊ | 226/469 [06:42<06:35, 1.63s/it]
457
  48%|████▊ | 227/469 [06:44<06:31, 1.62s/it]
458
  49%|████▊ | 228/469 [06:46<06:57, 1.73s/it]
459
  49%|████▉ | 229/469 [06:47<06:47, 1.70s/it]
460
  49%|████▉ | 230/469 [06:49<06:46, 1.70s/it]
461
  49%|████▉ | 231/469 [06:51<06:40, 1.68s/it]
462
  49%|████▉ | 232/469 [06:52<06:33, 1.66s/it]
463
  50%|████▉ | 233/469 [06:54<06:46, 1.72s/it]
464
  50%|████▉ | 234/469 [06:56<07:17, 1.86s/it]
465
  50%|█████ | 235/469 [06:58<07:00, 1.80s/it]
466
  50%|█████ | 236/469 [07:00<06:47, 1.75s/it]
467
  51%|█████ | 237/469 [07:01<06:40, 1.72s/it]
468
  51%|█████ | 238/469 [07:03<06:31, 1.69s/it]
469
  51%|█████ | 239/469 [07:05<06:23, 1.67s/it]
470
  51%|█████ | 240/469 [07:06<06:18, 1.65s/it]
471
  51%|█████▏ | 241/469 [07:08<06:14, 1.64s/it]
472
  52%|█████▏ | 242/469 [07:09<06:10, 1.63s/it]
473
  52%|█████▏ | 243/469 [07:11<06:41, 1.77s/it]
474
  52%|█████▏ | 244/469 [07:13<06:27, 1.72s/it]
475
  52%|█████▏ | 245/469 [07:15<06:49, 1.83s/it]
476
  52%|█████▏ | 246/469 [07:17<06:33, 1.76s/it]
477
  53%|█████▎ | 247/469 [07:18<06:22, 1.72s/it]
478
  53%|█████▎ | 248/469 [07:20<06:19, 1.72s/it]
479
  53%|█████▎ | 249/469 [07:22<06:12, 1.69s/it]
480
  53%|█████▎ | 250/469 [07:24<06:16, 1.72s/it]
481
  54%|█████▎ | 251/469 [07:25<06:10, 1.70s/it]
482
  54%|█████▎ | 252/469 [07:27<06:11, 1.71s/it]
483
  54%|█████▍ | 253/469 [07:29<06:03, 1.68s/it]
484
  54%|█████▍ | 254/469 [07:30<06:04, 1.70s/it]
485
  54%|█████▍ | 255/469 [07:32<05:56, 1.67s/it]
486
  55%|█████▍ | 256/469 [07:34<06:02, 1.70s/it]
487
  55%|█████▍ | 257/469 [07:35<05:56, 1.68s/it]
488
  55%|█████▌ | 258/469 [07:37<06:02, 1.72s/it]
489
  55%|█████▌ | 259/469 [07:39<05:57, 1.70s/it]
490
  55%|█████▌ | 260/469 [07:40<05:50, 1.68s/it]
491
  56%|█████▌ | 261/469 [07:42<05:50, 1.69s/it]
492
  56%|█████▌ | 262/469 [07:44<05:52, 1.70s/it]
493
  56%|█████▌ | 263/469 [07:46<05:50, 1.70s/it]
494
  56%|█████▋ | 264/469 [07:47<05:48, 1.70s/it]
495
  57%|█████▋ | 265/469 [07:49<05:45, 1.69s/it]
496
  57%|█████▋ | 266/469 [07:50<05:37, 1.66s/it]
497
  57%|█████▋ | 267/469 [07:52<05:33, 1.65s/it]
498
  57%|█████▋ | 268/469 [07:58<09:27, 2.82s/it]
499
  57%|█████▋ | 269/469 [07:59<08:23, 2.52s/it]
500
  58%|█████▊ | 270/469 [08:01<07:26, 2.24s/it]
501
  58%|█████▊ | 271/469 [08:03<06:53, 2.09s/it]
502
  58%|█████▊ | 272/469 [08:04<06:26, 1.96s/it]
503
  58%|█████▊ | 273/469 [08:06<06:03, 1.86s/it]
504
  58%|█████▊ | 274/469 [08:08<05:44, 1.77s/it]
505
  59%|█████▊ | 275/469 [08:09<05:33, 1.72s/it]
506
  59%|█████▉ | 276/469 [08:11<05:25, 1.69s/it]
507
  59%|█████▉ | 277/469 [08:12<05:21, 1.68s/it]
508
  59%|█████▉ | 278/469 [08:14<05:21, 1.68s/it]
509
  59%|█████▉ | 279/469 [08:16<05:26, 1.72s/it]
510
  60%|█████▉ | 280/469 [08:18<05:16, 1.67s/it]
511
  60%|█████▉ | 281/469 [08:20<05:45, 1.84s/it]
512
  60%|██████ | 282/469 [08:21<05:31, 1.77s/it]
513
  60%|██████ | 283/469 [08:23<05:26, 1.75s/it]
514
  61%|██████ | 284/469 [08:25<05:17, 1.72s/it]
515
  61%|██████ | 285/469 [08:26<05:10, 1.69s/it]
516
  61%|██████ | 286/469 [08:28<05:21, 1.76s/it]
517
  61%|██████ | 287/469 [08:30<05:16, 1.74s/it]
518
  61%|██████▏ | 288/469 [08:32<05:13, 1.73s/it]
519
  62%|██████▏ | 289/469 [08:33<05:03, 1.69s/it]
520
  62%|██████▏ | 290/469 [08:35<05:10, 1.74s/it]
521
  62%|██████▏ | 291/469 [08:37<05:17, 1.78s/it]
522
  62%|██████▏ | 292/469 [08:39<05:17, 1.79s/it]
523
  62%|██████▏ | 293/469 [08:41<05:19, 1.81s/it]
524
  63%|██████▎ | 294/469 [08:42<05:06, 1.75s/it]
525
  63%|██████▎ | 295/469 [08:44<04:55, 1.70s/it]
526
  63%|██████▎ | 296/469 [08:46<04:53, 1.70s/it]
527
  63%|██████▎ | 297/469 [08:47<04:46, 1.67s/it]
528
  64%|██████▎ | 298/469 [08:49<04:40, 1.64s/it]
529
  64%|██████▍ | 299/469 [08:50<04:39, 1.65s/it]
530
  64%|██████▍ | 300/469 [08:52<04:37, 1.64s/it]
531
  64%|██████▍ | 301/469 [08:54<04:39, 1.66s/it]
532
  64%|██████▍ | 302/469 [08:55<04:35, 1.65s/it]
533
  65%|██████▍ | 303/469 [08:57<04:36, 1.66s/it]
534
  65%|██████▍ | 304/469 [08:59<04:36, 1.67s/it]
535
  65%|██████▌ | 305/469 [09:00<04:34, 1.68s/it]
536
  65%|██████▌ | 306/469 [09:02<04:32, 1.67s/it]
537
  65%|██████▌ | 307/469 [09:04<04:30, 1.67s/it]
538
  66%|██████▌ | 308/469 [09:05<04:25, 1.65s/it]
539
  66%|██████▌ | 309/469 [09:07<04:23, 1.65s/it]
540
  66%|██████▌ | 310/469 [09:09<04:20, 1.64s/it]
541
  66%|██████▋ | 311/469 [09:10<04:18, 1.64s/it]
542
  67%|██████▋ | 312/469 [09:12<04:19, 1.65s/it]
543
  67%|██████▋ | 313/469 [09:14<04:15, 1.64s/it]
544
  67%|██████▋ | 314/469 [09:15<04:10, 1.62s/it]
545
  67%|██████▋ | 315/469 [09:17<04:08, 1.61s/it]
546
  67%|██████▋ | 316/469 [09:19<04:21, 1.71s/it]
547
  68%|██████▊ | 317/469 [09:20<04:23, 1.73s/it]
548
  68%|██████▊ | 318/469 [09:22<04:18, 1.71s/it]
549
  68%|██████▊ | 319/469 [09:24<04:20, 1.74s/it]
550
  68%|██████▊ | 320/469 [09:26<04:15, 1.71s/it]
551
  68%|██████▊ | 321/469 [09:27<04:21, 1.77s/it]
552
  69%|██████▊ | 322/469 [09:29<04:12, 1.72s/it]
553
  69%|██████▉ | 323/469 [09:31<04:04, 1.67s/it]
554
  69%|██████▉ | 324/469 [09:32<04:01, 1.67s/it]
555
  69%|██████▉ | 325/469 [09:34<03:55, 1.64s/it]
556
  70%|██████▉ | 326/469 [09:35<03:52, 1.62s/it]
557
  70%|██████▉ | 327/469 [09:37<03:49, 1.61s/it]
558
  70%|██████▉ | 328/469 [09:39<03:52, 1.65s/it]
559
  70%|███████ | 329/469 [09:41<04:01, 1.72s/it]
560
  70%|███████ | 330/469 [09:43<04:06, 1.77s/it]
561
  71%|███████ | 331/469 [09:44<04:03, 1.76s/it]
562
  71%|███████ | 332/469 [09:46<03:55, 1.72s/it]
563
  71%|███████ | 333/469 [09:47<03:48, 1.68s/it]
564
  71%|███████ | 334/469 [09:49<03:46, 1.68s/it]
565
  71%|███████▏ | 335/469 [09:51<03:42, 1.66s/it]
566
  72%|███████▏ | 336/469 [09:52<03:37, 1.63s/it]
567
  72%|███████▏ | 337/469 [09:54<03:33, 1.62s/it]
568
  72%|███████▏ | 338/469 [09:56<03:33, 1.63s/it]
569
  72%|███████▏ | 339/469 [09:57<03:31, 1.62s/it]
570
  72%|███████▏ | 340/469 [09:59<03:27, 1.61s/it]
571
  73%|███████▎ | 341/469 [10:01<03:32, 1.66s/it]
572
  73%|███████▎ | 342/469 [10:03<03:41, 1.75s/it]
573
  73%|███████▎ | 343/469 [10:04<03:35, 1.71s/it]
574
  73%|███████▎ | 344/469 [10:06<03:31, 1.70s/it]
575
  74%|███████▎ | 345/469 [10:27<15:27, 7.48s/it]
576
  74%|███████▍ | 346/469 [10:29<11:48, 5.76s/it]
577
  74%|███████▍ | 347/469 [10:30<09:19, 4.59s/it]
578
  74%|███████▍ | 348/469 [10:32<07:29, 3.71s/it]
579
  74%|███████▍ | 349/469 [10:34<06:11, 3.10s/it]
580
  75%|███████▍ | 350/469 [10:36<05:49, 2.93s/it]
581
  75%|███████▍ | 351/469 [10:38<05:09, 2.63s/it]
582
  75%|███████▌ | 352/469 [10:40<04:39, 2.39s/it]
583
  75%|███████▌ | 353/469 [10:42<04:16, 2.21s/it]
584
  75%|███████▌ | 354/469 [10:43<03:53, 2.03s/it]
585
  76%|███████▌ | 355/469 [10:45<03:35, 1.89s/it]
586
  76%|███████▌ | 356/469 [10:47<03:30, 1.86s/it]
587
  76%|███████▌ | 357/469 [10:48<03:23, 1.82s/it]
588
  76%|███████▋ | 358/469 [10:50<03:19, 1.80s/it]
589
  77%|███████▋ | 359/469 [10:52<03:20, 1.82s/it]
590
  77%|███████▋ | 360/469 [10:54<03:23, 1.87s/it]
591
  77%|███████▋ | 361/469 [10:56<03:16, 1.82s/it]
592
  77%|███████▋ | 362/469 [10:58<03:24, 1.91s/it]
593
  77%|███████▋ | 363/469 [11:00<03:12, 1.82s/it]
594
  78%|███████▊ | 364/469 [11:01<03:05, 1.77s/it]
595
  78%|███████▊ | 365/469 [11:03<02:58, 1.72s/it]
596
  78%|███████▊ | 366/469 [11:04<02:54, 1.69s/it]
597
  78%|███████▊ | 367/469 [11:06<02:52, 1.69s/it]
598
  78%|███████▊ | 368/469 [11:08<02:50, 1.68s/it]
599
  79%|███████▊ | 369/469 [11:10<02:50, 1.71s/it]
600
  79%|███████▉ | 370/469 [11:11<02:49, 1.71s/it]
601
  79%|███████▉ | 371/469 [11:13<02:46, 1.70s/it]
602
  79%|███████▉ | 372/469 [11:15<02:49, 1.74s/it]
603
  80%|███████▉ | 373/469 [11:16<02:43, 1.70s/it]
604
  80%|███████▉ | 374/469 [11:18<02:42, 1.71s/it]
605
  80%|███████▉ | 375/469 [11:20<02:44, 1.75s/it]
606
  80%|████████ | 376/469 [11:22<02:47, 1.80s/it]
607
  80%|████████ | 377/469 [11:24<02:41, 1.76s/it]
608
  81%|████████ | 378/469 [11:25<02:34, 1.70s/it]
609
  81%|████████ | 379/469 [11:27<02:30, 1.67s/it]
610
  81%|████████ | 380/469 [11:29<02:33, 1.72s/it]
611
  81%|████████ | 381/469 [11:30<02:28, 1.69s/it]
612
  81%|████████▏ | 382/469 [11:32<02:24, 1.66s/it]
613
  82%|████████▏ | 383/469 [11:34<02:29, 1.73s/it]
614
  82%|████████▏ | 384/469 [11:36<02:33, 1.80s/it]
615
  82%|████████▏ | 385/469 [11:38<02:34, 1.84s/it]
616
  82%|████████▏ | 386/469 [11:39<02:27, 1.77s/it]
617
  83%|████████▎ | 387/469 [11:41<02:26, 1.79s/it]
618
  83%|████████▎ | 388/469 [11:43<02:21, 1.74s/it]
619
  83%|████████▎ | 389/469 [11:45<02:23, 1.79s/it]
620
  83%|████████▎ | 390/469 [11:46<02:16, 1.73s/it]
621
  83%|████████▎ | 391/469 [11:48<02:12, 1.70s/it]
622
  84%|████████▎ | 392/469 [11:49<02:09, 1.69s/it]
623
  84%|████████▍ | 393/469 [11:51<02:05, 1.66s/it]
624
  84%|████████▍ | 394/469 [11:53<02:03, 1.65s/it]
625
  84%|████████▍ | 395/469 [11:54<02:01, 1.65s/it]
626
  84%|████████▍ | 396/469 [11:56<01:59, 1.63s/it]
627
  85%|████████▍ | 397/469 [11:58<01:58, 1.64s/it]
628
  85%|████████▍ | 398/469 [11:59<02:00, 1.70s/it]
629
  85%|████████▌ | 399/469 [12:01<01:57, 1.68s/it]
630
  85%|████████▌ | 400/469 [12:03<01:54, 1.66s/it]
631
  86%|████████▌ | 401/469 [12:04<01:51, 1.64s/it]
632
  86%|████████▌ | 402/469 [12:25<08:21, 7.49s/it]
633
  86%|████████▌ | 403/469 [12:27<06:17, 5.72s/it]
634
  86%|████████▌ | 404/469 [12:29<04:51, 4.49s/it]
635
  86%|████████▋ | 405/469 [12:30<03:53, 3.65s/it]
636
  87%|████████▋ | 406/469 [12:32<03:12, 3.06s/it]
637
  87%|████████▋ | 407/469 [12:34<02:42, 2.63s/it]
638
  87%|████████▋ | 408/469 [12:35<02:21, 2.32s/it]
639
  87%|████████▋ | 409/469 [12:37<02:06, 2.10s/it]
640
  87%|████████▋ | 410/469 [12:38<01:54, 1.94s/it]
641
  88%|████████▊ | 411/469 [12:40<01:46, 1.83s/it]
642
  88%|████████▊ | 412/469 [12:41<01:39, 1.75s/it]
643
  88%|████████▊ | 413/469 [12:43<01:36, 1.73s/it]
644
  88%|████████▊ | 414/469 [12:45<01:33, 1.69s/it]
645
  88%|████████▊ | 415/469 [12:46<01:29, 1.65s/it]
646
  89%|████████▊ | 416/469 [12:48<01:28, 1.68s/it]
647
  89%|████████▉ | 417/469 [12:50<01:26, 1.67s/it]
648
  89%|████████▉ | 418/469 [12:51<01:24, 1.65s/it]
649
  89%|████████▉ | 419/469 [12:53<01:21, 1.64s/it]
650
  90%|████████▉ | 420/469 [12:55<01:21, 1.66s/it]
651
  90%|████████▉ | 421/469 [12:56<01:19, 1.66s/it]
652
  90%|████████▉ | 422/469 [12:58<01:18, 1.67s/it]
653
  90%|█████████ | 423/469 [13:00<01:18, 1.71s/it]
654
  90%|█████████ | 424/469 [13:02<01:18, 1.73s/it]
655
  91%|█████████ | 425/469 [13:03<01:14, 1.69s/it]
656
  91%|█████████ | 426/469 [13:05<01:11, 1.66s/it]
657
  91%|█████████ | 427/469 [13:07<01:13, 1.74s/it]
658
  91%|█████████▏| 428/469 [13:08<01:09, 1.70s/it]
659
  91%|█████████▏| 429/469 [13:10<01:07, 1.68s/it]
660
  92%|█████████▏| 430/469 [13:12<01:05, 1.68s/it]
661
  92%|█████████▏| 431/469 [13:13<01:05, 1.71s/it]
662
  92%|█████████▏| 432/469 [13:15<01:02, 1.68s/it]
663
  92%|█████████▏| 433/469 [13:17<01:01, 1.71s/it]
664
  93%|█████████▎| 434/469 [13:18<00:58, 1.69s/it]
665
  93%|█████████▎| 435/469 [13:20<00:57, 1.69s/it]
666
  93%|█████████▎| 436/469 [13:22<00:55, 1.68s/it]
667
  93%|█████████▎| 437/469 [13:23<00:53, 1.67s/it]
668
  93%|█████████▎| 438/469 [13:25<00:51, 1.66s/it]
669
  94%|█████████▎| 439/469 [13:27<00:49, 1.64s/it]
670
  94%|█████████▍| 440/469 [13:28<00:47, 1.65s/it]
671
  94%|█████████▍| 441/469 [13:30<00:46, 1.65s/it]
672
  94%|█████████▍| 442/469 [13:32<00:44, 1.63s/it]
673
  94%|█████████▍| 443/469 [13:33<00:43, 1.67s/it]
674
  95%|█████████▍| 444/469 [13:35<00:41, 1.65s/it]
675
  95%|█████████▍| 445/469 [13:36<00:39, 1.64s/it]
676
  95%|█████████▌| 446/469 [13:39<00:40, 1.77s/it]
677
  95%|█████████▌| 447/469 [13:40<00:37, 1.72s/it]
678
  96%|█████████▌| 448/469 [13:42<00:35, 1.68s/it]
679
  96%|█████████▌| 449/469 [13:44<00:36, 1.84s/it]
680
  96%|█████████▌| 450/469 [13:46<00:33, 1.78s/it]
681
  96%|█████████▌| 451/469 [13:47<00:30, 1.72s/it]
682
  96%|█████████▋| 452/469 [13:49<00:30, 1.78s/it]
683
  97%|█████████▋| 453/469 [13:51<00:27, 1.73s/it]
684
  97%|█████████▋| 454/469 [13:52<00:25, 1.70s/it]
685
  97%|█████████▋| 455/469 [13:54<00:23, 1.70s/it]
686
  97%|█████████▋| 456/469 [13:56<00:21, 1.67s/it]
687
  97%|█████████▋| 457/469 [13:57<00:19, 1.66s/it]
688
  98%|█████████▊| 458/469 [13:59<00:18, 1.65s/it]
689
  98%|█████████▊| 459/469 [14:01<00:16, 1.64s/it]
690
  98%|█████████▊| 460/469 [14:02<00:14, 1.65s/it]
691
  98%|█████████▊| 461/469 [14:04<00:13, 1.65s/it]
692
  99%|█████████▊| 462/469 [14:06<00:11, 1.67s/it]
693
  99%|█████████▊| 463/469 [14:07<00:10, 1.67s/it]
694
  99%|█████████▉| 464/469 [14:09<00:09, 1.83s/it]
695
  99%|█████████▉| 465/469 [14:11<00:07, 1.84s/it]
696
  99%|█████████▉| 466/469 [14:13<00:05, 1.80s/it]
697
+ computing/reading sample batch statistics...
698
+ Computing evaluations...
699
+ Inception Score: 37.646392822265625
700
+ FID: 21.19386100577333
701
+ sFID: 71.79977998851734
702
+ Precision: 0.690407122136641
703
+ Recall: 0.358997247638176
pic_npz copy.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
4
+ 基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
5
+ 支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
6
+ 支持从 metadata.jsonl 文件读取图片路径
7
+ """
8
+
9
+ import os
10
+ import argparse
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ import glob
15
+ import json
16
+
17
+
18
+ def create_npz_from_metadata(metadata_jsonl_path, output_path=None):
19
+ """
20
+ 从 metadata.jsonl 文件读取图片路径并构建 .npz 文件
21
+
22
+ Args:
23
+ metadata_jsonl_path (str): metadata.jsonl 文件路径
24
+ output_path (str, optional): 输出 npz 文件路径,默认在 metadata.jsonl 同目录下生成
25
+
26
+ Returns:
27
+ str: 生成的 npz 文件路径
28
+ """
29
+ # 确保 metadata.jsonl 存在
30
+ if not os.path.exists(metadata_jsonl_path):
31
+ raise ValueError(f"metadata.jsonl 文件不存在: {metadata_jsonl_path}")
32
+
33
+ # 获取基础目录
34
+ base_dir = os.path.dirname(metadata_jsonl_path)
35
+
36
+ # 读取 metadata.jsonl
37
+ image_files = []
38
+ with open(metadata_jsonl_path, 'r', encoding='utf-8') as f:
39
+ for line in f:
40
+ line = line.strip()
41
+ if line:
42
+ try:
43
+ data = json.loads(line)
44
+ file_name = data.get('file_name')
45
+ if file_name:
46
+ full_path = os.path.join(base_dir, file_name)
47
+ image_files.append(full_path)
48
+ except json.JSONDecodeError as e:
49
+ print(f"警告: 跳过无效的 JSON 行: {e}")
50
+ continue
51
+
52
+ if len(image_files) == 0:
53
+ raise ValueError(f"在 {metadata_jsonl_path} 中未找到任何有效的图片路径")
54
+
55
+ print(f"从 metadata.jsonl 读取到 {len(image_files)} 张图片路径")
56
+
57
+ # 读取所有图片
58
+ samples = []
59
+ for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
60
+ try:
61
+ # 打开图片并转换为RGB格式(确保一致性)
62
+ with Image.open(img_path) as img:
63
+ # 转换为RGB,确保所有图片都是3通道
64
+ if img.mode != 'RGB':
65
+ img = img.convert('RGB')
66
+
67
+ # 将图片resize到512x512
68
+ img = img.resize((512, 512), Image.LANCZOS)
69
+
70
+ sample_np = np.asarray(img).astype(np.uint8)
71
+
72
+ # 确保图片是3通道
73
+ if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
74
+ print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
75
+ continue
76
+
77
+ samples.append(sample_np)
78
+
79
+ except Exception as e:
80
+ print(f"警告: 无法读取图片 {img_path}: {e}")
81
+ continue
82
+
83
+ if len(samples) == 0:
84
+ raise ValueError("没有成功读取任何有效的图片文件")
85
+
86
+ # 转换为numpy数组
87
+ samples = np.stack(samples)
88
+ print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
89
+
90
+ # 验证数据形状
91
+ assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
92
+ assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
93
+
94
+ # 生成输出路径
95
+ if output_path is None:
96
+ base_name = os.path.splitext(os.path.basename(metadata_jsonl_path))[0]
97
+ output_path = os.path.join(base_dir, f"{base_name}.npz")
98
+
99
+ # 保存为npz文件
100
+ np.savez(output_path, arr_0=samples)
101
+ print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
102
+
103
+ return output_path
104
+
105
+
106
+ def main():
107
+ """
108
+ 主函数:解析命令行参数并执行图片到npz的转换
109
+ """
110
+ parser = argparse.ArgumentParser(
111
+ description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
112
+ formatter_class=argparse.RawDescriptionHelpFormatter,
113
+ epilog="""
114
+ 使用示例:
115
+ python pic_npz.py /path/to/image/folder
116
+ python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
117
+ """
118
+ )
119
+
120
+ parser.add_argument(
121
+ "--image_folder",
122
+ type=str,
123
+ default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples",
124
+ help="包含PNG或JPG图片文件的文件夹路径"
125
+ )
126
+
127
+ # parser.add_argument(
128
+ # "--metadata_jsonl",
129
+ # type=str,
130
+ # default="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl",
131
+ # help="metadata.jsonl 文件路径,用于从 JSONL 文件读取图片路径"
132
+ # )
133
+
134
+ parser.add_argument(
135
+ "--output-dir",
136
+ type=str,
137
+ default=None,
138
+ help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)"
139
+ )
140
+
141
+ args = parser.parse_args()
142
+
143
+ try:
144
+ if args.metadata_jsonl and os.path.exists(args.metadata_jsonl):
145
+ # 使用 metadata.jsonl
146
+ metadata_path = os.path.abspath(args.metadata_jsonl)
147
+ base_dir = os.path.dirname(metadata_path)
148
+ base_name = os.path.splitext(os.path.basename(metadata_path))[0]
149
+
150
+ if args.output_dir:
151
+ os.makedirs(args.output_dir, exist_ok=True)
152
+ output_path = os.path.join(args.output_dir, f"{base_name}.npz")
153
+ else:
154
+ output_path = os.path.join(base_dir, f"{base_name}.npz")
155
+
156
+ npz_path = create_npz_from_metadata(metadata_path, output_path)
157
+ else:
158
+ # 使用图片文件夹
159
+ image_folder_path = os.path.abspath(args.image_folder)
160
+
161
+ if args.output_dir:
162
+ # 如果指定了输出目录,修改生成逻辑
163
+ folder_name = os.path.basename(image_folder_path.rstrip('/'))
164
+ custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
165
+
166
+ # 创建输出目录(如果不存在)
167
+ os.makedirs(args.output_dir, exist_ok=True)
168
+
169
+ # 临时修改函数以支持自定义输出路径
170
+ npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
171
+ else:
172
+ npz_path = create_npz_from_image_folder(image_folder_path)
173
+
174
+ print(f"转换完成!NPZ文件已保存至: {npz_path}")
175
+
176
+ except Exception as e:
177
+ print(f"错误: {e}")
178
+ return 1
179
+
180
+ return 0
181
+
182
+
183
+ def create_npz_from_image_folder_custom(image_folder_path, output_path):
184
+ """
185
+ 从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
186
+
187
+ Args:
188
+ image_folder_path (str): 包含图片文件的文件夹路径
189
+ output_path (str): 输出npz文件的完整路径
190
+
191
+ Returns:
192
+ str: 生成的 npz 文件路径
193
+ """
194
+ # 确保路径存在
195
+ if not os.path.exists(image_folder_path):
196
+ raise ValueError(f"文件夹路径不存在: {image_folder_path}")
197
+
198
+ # 获取所有支持的图片文件
199
+ supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
200
+ image_files = []
201
+
202
+ for extension in supported_extensions:
203
+ pattern = os.path.join(image_folder_path, extension)
204
+ image_files.extend(glob.glob(pattern))
205
+
206
+ # 按文件名排序确保一致性
207
+ image_files.sort()
208
+
209
+ if len(image_files) == 0:
210
+ raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
211
+
212
+ print(f"找到 {len(image_files)} 张图片文件")
213
+
214
+ # 读取所有图片
215
+ samples = []
216
+ for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
217
+ try:
218
+ # 打开图片并转换为RGB格式(确保一致性)
219
+ with Image.open(img_path) as img:
220
+ # 转换为RGB,确保所有图片都是3通道
221
+ if img.mode != 'RGB':
222
+ img = img.convert('RGB')
223
+
224
+ # 将图片resize到512x512
225
+ img = img.resize((512, 512), Image.LANCZOS)
226
+
227
+ sample_np = np.asarray(img).astype(np.uint8)
228
+
229
+ # 确保图片是3通道
230
+ if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
231
+ print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
232
+ continue
233
+
234
+ samples.append(sample_np)
235
+
236
+ except Exception as e:
237
+ print(f"警告: 无法读取图片 {img_path}: {e}")
238
+ continue
239
+
240
+ if len(samples) == 0:
241
+ raise ValueError("没有成功读取任何有效的图片文件")
242
+
243
+ # 转换为numpy数组
244
+ samples = np.stack(samples)
245
+ print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
246
+
247
+ # 验证数据形状
248
+ assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
249
+ assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
250
+
251
+ # 保存为npz文件
252
+ np.savez(output_path, arr_0=samples)
253
+ print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
254
+
255
+ return output_path
256
+
257
+
258
+ if __name__ == "__main__":
259
+ exit(main())
pic_npz.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
4
+ 基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
5
+ 支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
6
+ """
7
+
8
+ import os
9
+ import argparse
10
+ import numpy as np
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ import glob
14
+
15
+
16
+ def main():
17
+ """
18
+ 主函数:解析命令行参数并执行图片到npz的转换
19
+ """
20
+ parser = argparse.ArgumentParser(
21
+ description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
22
+ formatter_class=argparse.RawDescriptionHelpFormatter,
23
+ epilog="""
24
+ 使用示例:
25
+ python pic_npz.py /path/to/image/folder
26
+ python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
27
+ """
28
+ )
29
+
30
+ parser.add_argument(
31
+ "--image_folder",
32
+ type=str,
33
+ default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_new_batch_2",
34
+ help="包含PNG或JPG图片文件的文件夹路径"
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--output-dir",
39
+ type=str,
40
+ default=None,
41
+ help="自定义输出目录(默认为输入文件夹的父级目录)"
42
+ )
43
+
44
+ args = parser.parse_args()
45
+
46
+ try:
47
+ # 仅支持从图片文件夹生成 npz
48
+ image_folder_path = os.path.abspath(args.image_folder)
49
+
50
+ if args.output_dir:
51
+ # 如果指定了输出目录,修改生成逻辑
52
+ folder_name = os.path.basename(image_folder_path.rstrip('/'))
53
+ custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
54
+
55
+ # 创建输出目录(如果不存在)
56
+ os.makedirs(args.output_dir, exist_ok=True)
57
+
58
+ npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
59
+ else:
60
+ npz_path = create_npz_from_image_folder(image_folder_path)
61
+
62
+ print(f"转换完成!NPZ文件已保存至: {npz_path}")
63
+
64
+ except Exception as e:
65
+ print(f"错误: {e}")
66
+ return 1
67
+
68
+ return 0
69
+
70
+
71
+ def create_npz_from_image_folder_custom(image_folder_path, output_path):
72
+ """
73
+ 从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
74
+
75
+ Args:
76
+ image_folder_path (str): 包含图片文件的文件夹路径
77
+ output_path (str): 输出npz文件的完整路径
78
+
79
+ Returns:
80
+ str: 生成的 npz 文件路径
81
+ """
82
+ # 确保路径存在
83
+ if not os.path.exists(image_folder_path):
84
+ raise ValueError(f"文件夹路径不存在: {image_folder_path}")
85
+
86
+ # 获取所有支持的图片文件
87
+ supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
88
+ image_files = []
89
+
90
+ for extension in supported_extensions:
91
+ pattern = os.path.join(image_folder_path, extension)
92
+ image_files.extend(glob.glob(pattern))
93
+
94
+ # 按文件名排序确保一致性
95
+ image_files.sort()
96
+
97
+ if len(image_files) == 0:
98
+ raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
99
+
100
+ print(f"找到 {len(image_files)} 张图片文件")
101
+
102
+ # 读取所有图片
103
+ samples = []
104
+ for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
105
+ try:
106
+ # 打开图片并转换为RGB格式(确保一致性)
107
+ with Image.open(img_path) as img:
108
+ # 转换为RGB,确保所有图片都是3通道
109
+ if img.mode != 'RGB':
110
+ img = img.convert('RGB')
111
+
112
+ # 将图片resize到512x512
113
+ img = img.resize((512, 512), Image.LANCZOS)
114
+
115
+ sample_np = np.asarray(img).astype(np.uint8)
116
+
117
+ # 确保图片是3通道
118
+ if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
119
+ print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
120
+ continue
121
+
122
+ samples.append(sample_np)
123
+
124
+ except Exception as e:
125
+ print(f"警告: 无法读取图片 {img_path}: {e}")
126
+ continue
127
+
128
+ if len(samples) == 0:
129
+ raise ValueError("没有成功读取任何有效的图片文件")
130
+
131
+ # 转换为numpy数组
132
+ samples = np.stack(samples)
133
+ print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
134
+
135
+ # 验证数据形状
136
+ assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
137
+ assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
138
+
139
+ # 保存为npz文件
140
+ np.savez(output_path, arr_0=samples)
141
+ print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
142
+
143
+ return output_path
144
+
145
+
146
+ def create_npz_from_image_folder(image_folder_path):
147
+ """
148
+ 从图片文件夹构建 .npz,输出到该文件夹的父目录,文件名为 <文件夹名>.npz
149
+ """
150
+ parent_dir = os.path.dirname(os.path.abspath(image_folder_path))
151
+ folder_name = os.path.basename(os.path.abspath(image_folder_path).rstrip("/"))
152
+ output_path = os.path.join(parent_dir, f"{folder_name}.npz")
153
+ return create_npz_from_image_folder_custom(image_folder_path, output_path)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ exit(main())
pipeline_stable_diffusion_3.py ADDED
@@ -0,0 +1,1378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ SiglipImageProcessor,
23
+ SiglipVisionModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
31
+ from ...models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import SD3Transformer2DModel
33
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
34
+ from ...utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from ...utils.torch_utils import randn_tensor
43
+ from ..pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import StableDiffusion3PipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ #logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers import StableDiffusion3Pipeline
62
+
63
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
64
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
65
+ ... )
66
+ >>> pipe.to("cuda")
67
+ >>> prompt = "A cat holding a sign that says hello world"
68
+ >>> image = pipe(prompt).images[0]
69
+ >>> image.save("sd3.png")
70
+ ```
71
+ """
72
+
73
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.15,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
85
+
86
+
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
145
+
146
+
147
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
148
+ r"""
149
+ Args:
150
+ transformer ([`SD3Transformer2DModel`]):
151
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
152
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
153
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
154
+ vae ([`AutoencoderKL`]):
155
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
156
+ text_encoder ([`CLIPTextModelWithProjection`]):
157
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
158
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
159
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
160
+ as its dimension.
161
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
162
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
163
+ specifically the
164
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
165
+ variant.
166
+ text_encoder_3 ([`T5EncoderModel`]):
167
+ Frozen text-encoder. Stable Diffusion 3 uses
168
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
169
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
170
+ tokenizer (`CLIPTokenizer`):
171
+ Tokenizer of class
172
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
173
+ tokenizer_2 (`CLIPTokenizer`):
174
+ Second Tokenizer of class
175
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
176
+ tokenizer_3 (`T5TokenizerFast`):
177
+ Tokenizer of class
178
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
179
+ image_encoder (`SiglipVisionModel`, *optional*):
180
+ Pre-trained Vision Model for IP Adapter.
181
+ feature_extractor (`SiglipImageProcessor`, *optional*):
182
+ Image processor for IP Adapter.
183
+ model (`SD3WithRectifiedNoise`, *optional*):
184
+ Optional SD3WithRectifiedNoise model for enhanced noise prediction. If provided, will be used instead of
185
+ the default transformer for denoising.
186
+ """
187
+
188
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
189
+ _optional_components = ["image_encoder", "feature_extractor"]
190
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
191
+
192
+ def __init__(
193
+ self,
194
+ transformer: SD3Transformer2DModel,
195
+ scheduler: FlowMatchEulerDiscreteScheduler,
196
+ vae: AutoencoderKL,
197
+ text_encoder: CLIPTextModelWithProjection,
198
+ tokenizer: CLIPTokenizer,
199
+ text_encoder_2: CLIPTextModelWithProjection,
200
+ tokenizer_2: CLIPTokenizer,
201
+ text_encoder_3: T5EncoderModel,
202
+ tokenizer_3: T5TokenizerFast,
203
+ image_encoder: SiglipVisionModel = None,
204
+ feature_extractor: SiglipImageProcessor = None,
205
+ model = None, # 添加 model 参数
206
+ ):
207
+ super().__init__()
208
+
209
+ self.register_modules(
210
+ vae=vae,
211
+ text_encoder=text_encoder,
212
+ text_encoder_2=text_encoder_2,
213
+ text_encoder_3=text_encoder_3,
214
+ tokenizer=tokenizer,
215
+ tokenizer_2=tokenizer_2,
216
+ tokenizer_3=tokenizer_3,
217
+ transformer=transformer,
218
+ scheduler=scheduler,
219
+ image_encoder=image_encoder,
220
+ feature_extractor=feature_extractor,
221
+ model=model, # 添加 model 参数到 register_modules
222
+ )
223
+ #print(f"VAE is None: {getattr(self, 'vae', None) is None}")
224
+ #if getattr(self, 'vae', None) is not None:
225
+ # print(f"VAE config block_out_channels: {self.vae.config.block_out_channels}")
226
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
227
+ #print(f"VAE scale factor: {self.vae_scale_factor}")
228
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
229
+ self.tokenizer_max_length = (
230
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
231
+ )
232
+ self.default_sample_size = (
233
+ self.transformer.config.sample_size
234
+ if hasattr(self, "transformer") and self.transformer is not None
235
+ else 64
236
+ )
237
+ self.patch_size = (
238
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
239
+ )
240
+ # 添加对 SD3WithRectifiedNoise 模型的支持
241
+ self.model = model
242
+
243
+ def _get_t5_prompt_embeds(
244
+ self,
245
+ prompt: Union[str, List[str]] = None,
246
+ num_images_per_prompt: int = 1,
247
+ max_sequence_length: int = 256,
248
+ device: Optional[torch.device] = None,
249
+ dtype: Optional[torch.dtype] = None,
250
+ ):
251
+ device = device or self._execution_device
252
+ dtype = dtype or self.text_encoder_3.dtype
253
+ #max_sequence_length=77
254
+ prompt = [prompt] if isinstance(prompt, str) else prompt
255
+ batch_size = len(prompt)
256
+
257
+ # print(f"T5处理 - 输入提示: {prompt}")
258
+ # print(f"T5处理 - batch_size: {batch_size}")
259
+ # print(f"T5处理 - num_images_per_prompt: {num_images_per_prompt}")
260
+ # print(f"T5处理 - max_sequence_length: {max_sequence_length}")
261
+ # print(f"T5处理 - device: {device}")
262
+ # print(f"T5处理 - dtype: {dtype}")
263
+
264
+ if self.text_encoder_3 is None:
265
+ #print("T5处理 - text_encoder_3为None,返回零张量")
266
+ return torch.zeros(
267
+ (
268
+ batch_size * num_images_per_prompt,
269
+ self.tokenizer_max_length,
270
+ self.transformer.config.joint_attention_dim,
271
+ ),
272
+ device=device,
273
+ dtype=dtype,
274
+ )
275
+
276
+ text_inputs = self.tokenizer_3(
277
+ prompt,
278
+ padding="max_length",
279
+ max_length=max_sequence_length,
280
+ truncation=True,
281
+ # add_special_tokens=True,
282
+ return_tensors="pt",
283
+ )
284
+ text_input_ids = text_inputs.input_ids
285
+ # if torch.isnan(text_input_ids).any():
286
+ # print("T5处理 - text_input_ids输入提示包含NaN值")
287
+ # else:
288
+ # print("T5处理 - text_input_ids",text_input_ids)
289
+ # print(f"T5处理 - text_input_ids形状: {text_input_ids.shape}")
290
+ # print(f"T5处理 - text_input_ids范围: [{text_input_ids.min().item()}, {text_input_ids.max().item()}]")
291
+ # print(f"T5处理 - tokenizer_3.vocab_size: {self.tokenizer_3.vocab_size}")
292
+
293
+ # # 检查输入token IDs是否包含非法值
294
+ # if torch.any(text_input_ids < 0):
295
+ # print(f"警告:发现负数token ID,最小值: {text_input_ids.min().item()}")
296
+ # if torch.any(text_input_ids >= self.tokenizer_3.vocab_size):
297
+ # print(f"警告:发现超过词汇表大小的token ID,最大值: {text_input_ids.max().item()}, vocab_size: {self.tokenizer_3.vocab_size}")
298
+
299
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
300
+
301
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
302
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
303
+ logger.warning(
304
+ "The following part of your input was truncated because `max_sequence_length` is set to "
305
+ f" {max_sequence_length} tokens: {removed_text}"
306
+ )
307
+
308
+ # 将输入移动到设备并确保数据类型正确
309
+ text_input_ids = text_input_ids.to(device)#, dtype=torch.long)
310
+ # print(f"T5处理 - text_input_ids形状: ",text_input_ids)
311
+ # print(f"T5处理 - text_input_ids设备: {text_input_ids.device}, dtype: {text_input_ids.dtype}")
312
+
313
+ # 检查text_encoder_3的状态
314
+ # print(f"T5处理 - text_encoder_3设备: {next(self.text_encoder_3.parameters()).device}")
315
+ # print(f"T5处理 - text_encoder_3.dtype: {self.text_encoder_3.dtype}")
316
+ with torch.autocast(device.type if isinstance(device, torch.device) else "cuda", enabled=False):
317
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
318
+ #prompt_embeds = self.text_encoder_3(text_input_ids)[0]
319
+ # print(f"T5处理 - T5编码器输出形状: {prompt_embeds.shape}")
320
+ # print(f"T5处理 - T5编码器输出设备: {prompt_embeds.device}")
321
+ # print(f"T5处理 - T5编码器输出dtype: {prompt_embeds.dtype}")
322
+
323
+ # # 检查T5编码器输出是否包含NaN或inf
324
+ # has_nan = torch.isnan(prompt_embeds).any()
325
+ # has_inf = torch.isinf(prompt_embeds).any()
326
+ # if has_nan or has_inf:
327
+ # print(f"警告:T5编码器输出包含NaN: {has_nan} 或inf: {has_inf}")
328
+ # print(f"T5编码器输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
329
+ # if has_nan:
330
+ # nan_locations = torch.where(torch.isnan(prompt_embeds))
331
+ # print(f"NaN位置 - 前10个: {[(nan_locations[i][:10].tolist()) for i in range(len(nan_locations))]}")
332
+
333
+ dtype = self.text_encoder_3.dtype
334
+ # 强制在无 autocast 的上下文中运行 T5 以避免 fp16 溢出为 NaN
335
+
336
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
337
+ #print(f"T5处理 - 转换后prompt_embeds dtype: {prompt_embeds.dtype}, device: {prompt_embeds.device}")
338
+
339
+ _, seq_len, _ = prompt_embeds.shape
340
+
341
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
342
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
343
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
344
+ # print(f"T5处理 - 最终输出形状: {prompt_embeds.shape}")
345
+
346
+ # # 检查最终输出是否包含NaN
347
+ # if torch.isnan(prompt_embeds).any():
348
+ # print(f"警告:T5最终输出包含NaN,位置: {torch.where(torch.isnan(prompt_embeds))}")
349
+ # print(f"T5最终输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
350
+ # print("最终输出",prompt_embeds)
351
+ return prompt_embeds
352
+
353
+ def _get_clip_prompt_embeds(
354
+ self,
355
+ prompt: Union[str, List[str]],
356
+ num_images_per_prompt: int = 1,
357
+ device: Optional[torch.device] = None,
358
+ clip_skip: Optional[int] = None,
359
+ clip_model_index: int = 0,
360
+ ):
361
+ device = device or self._execution_device
362
+
363
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
364
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
365
+
366
+ tokenizer = clip_tokenizers[clip_model_index]
367
+ text_encoder = clip_text_encoders[clip_model_index]
368
+
369
+ # print(f"CLIP处理 - clip_model_index: {clip_model_index}")
370
+ # print(f"CLIP处理 - 使用的tokenizer: {type(tokenizer)}")
371
+ # print(f"CLIP处理 - 使用的text_encoder: {type(text_encoder)}")
372
+
373
+ prompt = [prompt] if isinstance(prompt, str) else prompt
374
+ batch_size = len(prompt)
375
+ # print(f"CLIP处理 - 输入提示: {prompt}")
376
+ # print(f"CLIP处理 - batch_size: {batch_size}")
377
+
378
+ text_inputs = tokenizer(
379
+ prompt,
380
+ padding="max_length",
381
+ max_length=self.tokenizer_max_length,
382
+ truncation=True,
383
+ return_tensors="pt",
384
+ )
385
+
386
+ text_input_ids = text_inputs.input_ids
387
+ # print(f"CLIP处理 - text_input_ids形状: {text_input_ids.shape}")
388
+ # print(f"CLIP处理 - text_input_ids范围: [{text_input_ids.min().item()}, {text_input_ids.max().item()}]")
389
+
390
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
391
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
392
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
393
+ logger.warning(
394
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
395
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
396
+ )
397
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
398
+ pooled_prompt_embeds = prompt_embeds[0]
399
+
400
+ # print(f"CLIP处理 - pooled_prompt_embeds形状: {pooled_prompt_embeds.shape}")
401
+ # print(f"CLIP处理 - pooled_prompt_embeds设备: {pooled_prompt_embeds.device}")
402
+ # print(f"CLIP处理 - pooled_prompt_embeds dtype: {pooled_prompt_embeds.dtype}")
403
+
404
+ # 检查CLIP编码器输出是否包含NaN或inf
405
+ # has_nan = torch.isnan(pooled_prompt_embeds).any()
406
+ # has_inf = torch.isinf(pooled_prompt_embeds).any()
407
+ # if has_nan or has_inf:
408
+ # print(f"警告:CLIP编码器pooled输出包含NaN: {has_nan} 或inf: {has_inf}")
409
+ # print(f"CLIP编码器pooled输出统计 - min: {pooled_prompt_embeds.min().item()}, max: {pooled_prompt_embeds.max().item()}, mean: {pooled_prompt_embeds.mean().item()}")
410
+
411
+ if clip_skip is None:
412
+ prompt_embeds = prompt_embeds.hidden_states[-2]
413
+ else:
414
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
415
+
416
+ # 检查CLIP编码器embeds输出是否包含NaN或inf
417
+ # has_nan = torch.isnan(prompt_embeds).any()
418
+ # has_inf = torch.isinf(prompt_embeds).any()
419
+ # if has_nan or has_inf:
420
+ # print(f"警告:CLIP编码器embeds输出包含NaN: {has_nan} 或inf: {has_inf}")
421
+ # print(f"CLIP编码器embeds输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
422
+
423
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
424
+ #print(f"CLIP处理 - 转换后prompt_embeds dtype: {prompt_embeds.dtype}, device: {prompt_embeds.device}")
425
+
426
+ _, seq_len, _ = prompt_embeds.shape
427
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
428
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
429
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
430
+ # print(f"CLIP处理 - 最终prompt_embeds形状: {prompt_embeds.shape}")
431
+
432
+ # # 检查CLIP编码器embeds最终输出是否包含NaN
433
+ # if torch.isnan(prompt_embeds).any():
434
+ # print(f"警告:CLIP编码器embeds最终输出包含NaN,位置: {torch.where(torch.isnan(prompt_embeds))}")
435
+ # print(f"CLIP编码器embeds最终输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
436
+
437
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
438
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
439
+ # print(f"CLIP处理 - 最终pooled_prompt_embeds形状: {pooled_prompt_embeds.shape}")
440
+
441
+ # # 检查CLIP编码器pooled最终输出是否包含NaN
442
+ # if torch.isnan(pooled_prompt_embeds).any():
443
+ # print(f"警告:CLIP编码器pooled最终输出包含NaN,位置: {torch.where(torch.isnan(pooled_prompt_embeds))}")
444
+ # print(f"CLIP编码器pooled最终输出统计 - min: {pooled_prompt_embeds.min().item()}, max: {pooled_prompt_embeds.max().item()}, mean: {pooled_prompt_embeds.mean().item()}")
445
+
446
+ return prompt_embeds, pooled_prompt_embeds
447
+
448
+ def encode_prompt(
449
+ self,
450
+ prompt: Union[str, List[str]],
451
+ prompt_2: Union[str, List[str]],
452
+ prompt_3: Union[str, List[str]],
453
+ device: Optional[torch.device] = None,
454
+ num_images_per_prompt: int = 1,
455
+ do_classifier_free_guidance: bool = True,
456
+ negative_prompt: Optional[Union[str, List[str]]] = None,
457
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
458
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
459
+ prompt_embeds: Optional[torch.FloatTensor] = None,
460
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
461
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
462
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
463
+ clip_skip: Optional[int] = None,
464
+ max_sequence_length: int = 256,
465
+ lora_scale: Optional[float] = None,
466
+ ):
467
+ r"""
468
+
469
+ Args:
470
+ prompt (`str` or `List[str]`, *optional*):
471
+ prompt to be encoded
472
+ prompt_2 (`str` or `List[str]`, *optional*):
473
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
474
+ used in all text-encoders
475
+ prompt_3 (`str` or `List[str]`, *optional*):
476
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
477
+ used in all text-encoders
478
+ device: (`torch.device`):
479
+ torch device
480
+ num_images_per_prompt (`int`):
481
+ number of images that should be generated per prompt
482
+ do_classifier_free_guidance (`bool`):
483
+ whether to use classifier free guidance or not
484
+ negative_prompt (`str` or `List[str]`, *optional*):
485
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
486
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
487
+ less than `1`).
488
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
489
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
490
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
491
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
492
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
493
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
494
+ prompt_embeds (`torch.FloatTensor`, *optional*):
495
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
496
+ provided, text embeddings will be generated from `prompt` input argument.
497
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
498
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
499
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
500
+ argument.
501
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
502
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
503
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
504
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
505
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
506
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
507
+ input argument.
508
+ clip_skip (`int`, *optional*):
509
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
510
+ the output of the pre-final layer will be used for computing the prompt embeddings.
511
+ lora_scale (`float`, *optional*):
512
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
513
+ """
514
+ device = device or self._execution_device
515
+ # print(f"encode_prompt - 开始处理提示编码")
516
+ # print(f"encode_prompt - device: {device}")
517
+ # print(f"encode_prompt - num_images_per_prompt: {num_images_per_prompt}")
518
+ # print(f"encode_prompt - do_classifier_free_guidance: {do_classifier_free_guidance}")
519
+ # print(f"encode_prompt - max_sequence_length: {max_sequence_length}")
520
+
521
+ # set lora scale so that monkey patched LoRA
522
+ # function of text encoder can correctly access it
523
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
524
+ self._lora_scale = lora_scale
525
+
526
+ # dynamically adjust the LoRA scale
527
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
528
+ scale_lora_layers(self.text_encoder, lora_scale)
529
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
530
+ scale_lora_layers(self.text_encoder_2, lora_scale)
531
+
532
+ prompt = [prompt] if isinstance(prompt, str) else prompt
533
+ if prompt is not None:
534
+ batch_size = len(prompt)
535
+ else:
536
+ batch_size = prompt_embeds.shape[0]
537
+ # print(f"encode_prompt - batch_size: {batch_size}")
538
+
539
+ if prompt_embeds is None:
540
+ prompt_2 = prompt_2 or prompt
541
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
542
+
543
+ prompt_3 = prompt_3 or prompt
544
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
545
+
546
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
547
+ prompt=prompt,
548
+ device=device,
549
+ num_images_per_prompt=num_images_per_prompt,
550
+ clip_skip=clip_skip,
551
+ clip_model_index=0,
552
+ )
553
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
554
+ prompt=prompt_2,
555
+ device=device,
556
+ num_images_per_prompt=num_images_per_prompt,
557
+ clip_skip=clip_skip,
558
+ clip_model_index=1,
559
+ )
560
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
561
+
562
+ t5_prompt_embed = self._get_t5_prompt_embeds(
563
+ prompt=prompt_3,
564
+ num_images_per_prompt=num_images_per_prompt,
565
+ max_sequence_length=max_sequence_length,
566
+ device=device,
567
+ )
568
+
569
+ clip_prompt_embeds = torch.nn.functional.pad(
570
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
571
+ )
572
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
573
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
574
+
575
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
576
+ negative_prompt = negative_prompt or ""
577
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
578
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
579
+
580
+ # normalize str to list
581
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
582
+ negative_prompt_2 = (
583
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
584
+ )
585
+ negative_prompt_3 = (
586
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
587
+ )
588
+
589
+ if prompt is not None and type(prompt) is not type(negative_prompt):
590
+ raise TypeError(
591
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
592
+ f" {type(prompt)}."
593
+ )
594
+ elif batch_size != len(negative_prompt):
595
+ raise ValueError(
596
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
597
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
598
+ " the batch size of `prompt`."
599
+ )
600
+
601
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
602
+ negative_prompt,
603
+ device=device,
604
+ num_images_per_prompt=num_images_per_prompt,
605
+ clip_skip=None,
606
+ clip_model_index=0,
607
+ )
608
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
609
+ negative_prompt_2,
610
+ device=device,
611
+ num_images_per_prompt=num_images_per_prompt,
612
+ clip_skip=None,
613
+ clip_model_index=1,
614
+ )
615
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
616
+
617
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
618
+ prompt=negative_prompt_3,
619
+ num_images_per_prompt=num_images_per_prompt,
620
+ max_sequence_length=max_sequence_length,
621
+ device=device,
622
+ )
623
+
624
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
625
+ negative_clip_prompt_embeds,
626
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
627
+ )
628
+
629
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
630
+ negative_pooled_prompt_embeds = torch.cat(
631
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
632
+ )
633
+
634
+ if self.text_encoder is not None:
635
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
636
+ # Retrieve the original scale by scaling back the LoRA layers
637
+ unscale_lora_layers(self.text_encoder, lora_scale)
638
+
639
+ if self.text_encoder_2 is not None:
640
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
641
+ # Retrieve the original scale by scaling back the LoRA layers
642
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
643
+
644
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
645
+
646
+ def check_inputs(
647
+ self,
648
+ prompt,
649
+ prompt_2,
650
+ prompt_3,
651
+ height,
652
+ width,
653
+ negative_prompt=None,
654
+ negative_prompt_2=None,
655
+ negative_prompt_3=None,
656
+ prompt_embeds=None,
657
+ negative_prompt_embeds=None,
658
+ pooled_prompt_embeds=None,
659
+ negative_pooled_prompt_embeds=None,
660
+ callback_on_step_end_tensor_inputs=None,
661
+ max_sequence_length=None,
662
+ ):
663
+ if (
664
+ height % (self.vae_scale_factor * self.patch_size) != 0
665
+ or width % (self.vae_scale_factor * self.patch_size) != 0
666
+ ):
667
+ raise ValueError(
668
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
669
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
670
+ )
671
+
672
+ if callback_on_step_end_tensor_inputs is not None and not all(
673
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
674
+ ):
675
+ raise ValueError(
676
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
677
+ )
678
+
679
+ if prompt is not None and prompt_embeds is not None:
680
+ raise ValueError(
681
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
682
+ " only forward one of the two."
683
+ )
684
+ elif prompt_2 is not None and prompt_embeds is not None:
685
+ raise ValueError(
686
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
687
+ " only forward one of the two."
688
+ )
689
+ elif prompt_3 is not None and prompt_embeds is not None:
690
+ raise ValueError(
691
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
692
+ " only forward one of the two."
693
+ )
694
+ elif prompt is None and prompt_embeds is None:
695
+ raise ValueError(
696
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
697
+ )
698
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
699
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
700
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
701
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
702
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
703
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
704
+
705
+ if negative_prompt is not None and negative_prompt_embeds is not None:
706
+ raise ValueError(
707
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
708
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
709
+ )
710
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
711
+ raise ValueError(
712
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
713
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
714
+ )
715
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
716
+ raise ValueError(
717
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
718
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
719
+ )
720
+
721
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
722
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
723
+ raise ValueError(
724
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
725
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
726
+ f" {negative_prompt_embeds.shape}."
727
+ )
728
+
729
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
730
+ raise ValueError(
731
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
732
+ )
733
+
734
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
735
+ raise ValueError(
736
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
737
+ )
738
+
739
+ if max_sequence_length is not None and max_sequence_length > 512:
740
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
741
+
742
+ def prepare_latents(
743
+ self,
744
+ batch_size,
745
+ num_channels_latents,
746
+ height,
747
+ width,
748
+ dtype,
749
+ device,
750
+ generator,
751
+ latents=None,
752
+ ):
753
+ if latents is not None:
754
+ return latents.to(device=device, dtype=dtype)
755
+
756
+ shape = (
757
+ batch_size,
758
+ num_channels_latents,
759
+ int(height) // self.vae_scale_factor,
760
+ int(width) // self.vae_scale_factor,
761
+ )
762
+
763
+ if isinstance(generator, list) and len(generator) != batch_size:
764
+ raise ValueError(
765
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
766
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
767
+ )
768
+
769
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
770
+
771
+ return latents
772
+
773
+ @property
774
+ def guidance_scale(self):
775
+ return self._guidance_scale
776
+
777
+ @property
778
+ def skip_guidance_layers(self):
779
+ return self._skip_guidance_layers
780
+
781
+ @property
782
+ def clip_skip(self):
783
+ return self._clip_skip
784
+
785
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
786
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
787
+ # corresponds to doing no classifier free guidance.
788
+ @property
789
+ def do_classifier_free_guidance(self):
790
+ return self._guidance_scale > 1
791
+
792
+ @property
793
+ def joint_attention_kwargs(self):
794
+ return self._joint_attention_kwargs
795
+
796
+ @property
797
+ def num_timesteps(self):
798
+ return self._num_timesteps
799
+
800
+ @property
801
+ def interrupt(self):
802
+ return self._interrupt
803
+
804
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
805
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
806
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
807
+
808
+ Args:
809
+ image (`PipelineImageInput`):
810
+ Input image to be encoded.
811
+ device: (`torch.device`):
812
+ Torch device.
813
+
814
+ Returns:
815
+ `torch.Tensor`: The encoded image feature representation.
816
+ """
817
+ if not isinstance(image, torch.Tensor):
818
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
819
+
820
+ image = image.to(device=device, dtype=self.dtype)
821
+
822
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
823
+
824
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
825
+ def prepare_ip_adapter_image_embeds(
826
+ self,
827
+ ip_adapter_image: Optional[PipelineImageInput] = None,
828
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
829
+ device: Optional[torch.device] = None,
830
+ num_images_per_prompt: int = 1,
831
+ do_classifier_free_guidance: bool = True,
832
+ ) -> torch.Tensor:
833
+ """Prepares image embeddings for use in the IP-Adapter.
834
+
835
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
836
+
837
+ Args:
838
+ ip_adapter_image (`PipelineImageInput`, *optional*):
839
+ The input image to extract features from for IP-Adapter.
840
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
841
+ Precomputed image embeddings.
842
+ device: (`torch.device`, *optional*):
843
+ Torch device.
844
+ num_images_per_prompt (`int`, defaults to 1):
845
+ Number of images that should be generated per prompt.
846
+ do_classifier_free_guidance (`bool`, defaults to True):
847
+ Whether to use classifier free guidance or not.
848
+ """
849
+ device = device or self._execution_device
850
+
851
+ if ip_adapter_image_embeds is not None:
852
+ if do_classifier_free_guidance:
853
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
854
+ else:
855
+ single_image_embeds = ip_adapter_image_embeds
856
+ elif ip_adapter_image is not None:
857
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
858
+ if do_classifier_free_guidance:
859
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
860
+ else:
861
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
862
+
863
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
864
+
865
+ if do_classifier_free_guidance:
866
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
867
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
868
+
869
+ return image_embeds.to(device=device)
870
+
871
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
872
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
873
+ logger.warning(
874
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
875
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
876
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
877
+ )
878
+
879
+ super().enable_sequential_cpu_offload(*args, **kwargs)
880
+
881
+ @torch.no_grad()
882
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
883
+ def __call__(
884
+ self,
885
+ prompt: Union[str, List[str]] = None,
886
+ prompt_2: Optional[Union[str, List[str]]] = None,
887
+ prompt_3: Optional[Union[str, List[str]]] = None,
888
+ height: Optional[int] = None,
889
+ width: Optional[int] = None,
890
+ num_inference_steps: int = 28,
891
+ sigmas: Optional[List[float]] = None,
892
+ guidance_scale: float = 7.0,
893
+ negative_prompt: Optional[Union[str, List[str]]] = None,
894
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
895
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
896
+ num_images_per_prompt: Optional[int] = 1,
897
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
898
+ latents: Optional[torch.FloatTensor] = None,
899
+ prompt_embeds: Optional[torch.FloatTensor] = None,
900
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
901
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
902
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
903
+ ip_adapter_image: Optional[PipelineImageInput] = None,
904
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
905
+ output_type: Optional[str] = "pil",
906
+ return_dict: bool = True,
907
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
908
+ clip_skip: Optional[int] = None,
909
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
910
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
911
+ max_sequence_length: int = 256,
912
+ skip_guidance_layers: List[int] = None,
913
+ skip_layer_guidance_scale: float = 2.8,
914
+ skip_layer_guidance_stop: float = 0.2,
915
+ skip_layer_guidance_start: float = 0.01,
916
+ mu: Optional[float] = None,
917
+ model: Optional[Any] = None, # 添加 model 参数,默认为 None
918
+ ):
919
+ r"""
920
+ Function invoked when calling the pipeline for generation.
921
+
922
+ Args:
923
+ prompt (`str` or `List[str]`, *optional*):
924
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
925
+ instead.
926
+ prompt_2 (`str` or `List[str]`, *optional*):
927
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
928
+ will be used instead
929
+ prompt_3 (`str` or `List[str]`, *optional*):
930
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
931
+ will be used instead
932
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
933
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
934
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
935
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
936
+ num_inference_steps (`int`, *optional*, defaults to 50):
937
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
938
+ expense of slower inference.
939
+ sigmas (`List[float]`, *optional*):
940
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
941
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
942
+ will be used.
943
+ guidance_scale (`float`, *optional*, defaults to 7.0):
944
+ Guidance scale as defined in [Classifier-Free Diffusion
945
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
946
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
947
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
948
+ the text `prompt`, usually at the expense of lower image quality.
949
+ negative_prompt (`str` or `List[str]`, *optional*):
950
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
951
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
952
+ less than `1`).
953
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
954
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
955
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
956
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
957
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
958
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
959
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
960
+ The number of images to generate per prompt.
961
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
962
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
963
+ to make generation deterministic.
964
+ latents (`torch.FloatTensor`, *optional*):
965
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
966
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
967
+ tensor will ge generated by sampling using the supplied random `generator`.
968
+ prompt_embeds (`torch.FloatTensor`, *optional*):
969
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
970
+ provided, text embeddings will be generated from `prompt` input argument.
971
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
972
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
973
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
974
+ argument.
975
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
976
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
977
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
978
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
979
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
980
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
981
+ input argument.
982
+ ip_adapter_image (`PipelineImageInput`, *optional*):
983
+ Optional image input to work with IP Adapters.
984
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
985
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
986
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
987
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
988
+ output_type (`str`, *optional*, defaults to `"pil"`):
989
+ The output format of the generate image. Choose between
990
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
991
+ return_dict (`bool`, *optional*, defaults to `True`):
992
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
993
+ a plain tuple.
994
+ joint_attention_kwargs (`dict`, *optional*):
995
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
996
+ `self.processor` in
997
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
998
+ callback_on_step_end (`Callable`, *optional*):
999
+ A function that calls at the end of each denoising steps during the inference. The function is called
1000
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1001
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1002
+ `callback_on_step_end_tensor_inputs`.
1003
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1004
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1005
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1006
+ `._callback_tensor_inputs` attribute of your pipeline class.
1007
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
1008
+ skip_guidance_layers (`List[int]`, *optional*):
1009
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
1010
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
1011
+ Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
1012
+ skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
1013
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
1014
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
1015
+ with a scale of `1`.
1016
+ skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
1017
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
1018
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
1019
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
1020
+ skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
1021
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
1022
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
1023
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
1024
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
1025
+ model (`SD3WithRectifiedNoise`, *optional*):
1026
+ Optional SD3WithRectifiedNoise model for enhanced noise prediction. If provided, will be used instead of
1027
+ the default transformer for denoising. The model should be an instance of SD3WithRectifiedNoise class.
1028
+
1029
+ Examples:
1030
+
1031
+ Returns:
1032
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
1033
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
1034
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1035
+ """
1036
+
1037
+ height = height or self.default_sample_size * self.vae_scale_factor
1038
+ width = width or self.default_sample_size * self.vae_scale_factor
1039
+ #height=512
1040
+ #width=512
1041
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1042
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1043
+
1044
+ # 1. Check inputs. Raise error if not correct
1045
+ self.check_inputs(
1046
+ prompt,
1047
+ prompt_2,
1048
+ prompt_3,
1049
+ height,
1050
+ width,
1051
+ negative_prompt=negative_prompt,
1052
+ negative_prompt_2=negative_prompt_2,
1053
+ negative_prompt_3=negative_prompt_3,
1054
+ prompt_embeds=prompt_embeds,
1055
+ negative_prompt_embeds=negative_prompt_embeds,
1056
+ pooled_prompt_embeds=pooled_prompt_embeds,
1057
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1058
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1059
+ max_sequence_length=max_sequence_length,
1060
+ )
1061
+
1062
+ self._guidance_scale = guidance_scale
1063
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
1064
+ self._clip_skip = clip_skip
1065
+ self._joint_attention_kwargs = joint_attention_kwargs
1066
+ self._interrupt = False
1067
+
1068
+ # 2. Define call parameters
1069
+ if prompt is not None and isinstance(prompt, str):
1070
+ batch_size = 1
1071
+ elif prompt is not None and isinstance(prompt, list):
1072
+ batch_size = len(prompt)
1073
+ else:
1074
+ batch_size = prompt_embeds.shape[0]
1075
+
1076
+ device = self._execution_device
1077
+
1078
+ lora_scale = (
1079
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1080
+ )
1081
+ (
1082
+ prompt_embeds,
1083
+ negative_prompt_embeds,
1084
+ pooled_prompt_embeds,
1085
+ negative_pooled_prompt_embeds,
1086
+ ) = self.encode_prompt(
1087
+ prompt=prompt,
1088
+ prompt_2=prompt_2,
1089
+ prompt_3=prompt_3,
1090
+ negative_prompt=negative_prompt,
1091
+ negative_prompt_2=negative_prompt_2,
1092
+ negative_prompt_3=negative_prompt_3,
1093
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1094
+ prompt_embeds=prompt_embeds,
1095
+ negative_prompt_embeds=negative_prompt_embeds,
1096
+ pooled_prompt_embeds=pooled_prompt_embeds,
1097
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1098
+ device=device,
1099
+ clip_skip=self.clip_skip,
1100
+ num_images_per_prompt=num_images_per_prompt,
1101
+ max_sequence_length=max_sequence_length,
1102
+ lora_scale=lora_scale,
1103
+ )
1104
+
1105
+ if self.do_classifier_free_guidance:
1106
+ if skip_guidance_layers is not None:
1107
+ original_prompt_embeds = prompt_embeds
1108
+ original_pooled_prompt_embeds = pooled_prompt_embeds
1109
+ # print("检测negative_prompt_embeds",negative_prompt_embeds)
1110
+ #print("检测pooled_prompt_embeds",prompt_embeds)
1111
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1112
+
1113
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1114
+
1115
+ # 4. Prepare latent variables
1116
+ num_channels_latents = self.transformer.config.in_channels
1117
+ latents = self.prepare_latents(
1118
+ batch_size * num_images_per_prompt,
1119
+ num_channels_latents,
1120
+ height,
1121
+ width,
1122
+ prompt_embeds.dtype,
1123
+ device,
1124
+ generator,
1125
+ latents,
1126
+ )
1127
+
1128
+ # 5. Prepare timesteps
1129
+ scheduler_kwargs = {}
1130
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1131
+ _, _, height, width = latents.shape
1132
+ image_seq_len = (height // self.transformer.config.patch_size) * (
1133
+ width // self.transformer.config.patch_size
1134
+ )
1135
+ mu = calculate_shift(
1136
+ image_seq_len,
1137
+ self.scheduler.config.get("base_image_seq_len", 256),
1138
+ self.scheduler.config.get("max_image_seq_len", 4096),
1139
+ self.scheduler.config.get("base_shift", 0.5),
1140
+ self.scheduler.config.get("max_shift", 1.16),
1141
+ )
1142
+ scheduler_kwargs["mu"] = mu
1143
+ elif mu is not None:
1144
+ scheduler_kwargs["mu"] = mu
1145
+ timesteps, num_inference_steps = retrieve_timesteps(
1146
+ self.scheduler,
1147
+ num_inference_steps,
1148
+ device,
1149
+ sigmas=sigmas,
1150
+ **scheduler_kwargs,
1151
+ )
1152
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1153
+ self._num_timesteps = len(timesteps)
1154
+
1155
+ # 6. Prepare image embeddings
1156
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
1157
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1158
+ ip_adapter_image,
1159
+ ip_adapter_image_embeds,
1160
+ device,
1161
+ batch_size * num_images_per_prompt,
1162
+ self.do_classifier_free_guidance,
1163
+ )
1164
+
1165
+ if self.joint_attention_kwargs is None:
1166
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
1167
+ else:
1168
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
1169
+
1170
+ # 7. Denoising loop
1171
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1172
+ for i, t in enumerate(timesteps):
1173
+ if self.interrupt:
1174
+ continue
1175
+
1176
+ # expand the latents if we are doing classifier free guidance
1177
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1178
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1179
+ timestep = t.expand(latent_model_input.shape[0])
1180
+
1181
+ # Check for NaN in latents before transformer
1182
+ if torch.isnan(latents).any():
1183
+ # print(f"NaN detected in latents at step {i}")
1184
+ # print(f"NaN locations: {torch.where(torch.isnan(latents))}")
1185
+ break
1186
+ # 优先使用传入的 model 参数,其次使用类属性中的 model,最后回退到默认 transformer
1187
+ effective_model = model or getattr(self, 'model', None) or self.transformer
1188
+
1189
+ if hasattr(effective_model, '__call__') and callable(effective_model):
1190
+ # 使用有效的模型进行预测
1191
+ # 检查模型是否支持 skip_layers 参数
1192
+ if hasattr(effective_model, 'forward') and 'skip_layers' in inspect.signature(effective_model.forward).parameters:
1193
+ noise_pred_output = effective_model(
1194
+ hidden_states=latent_model_input,
1195
+ timestep=timestep,
1196
+ encoder_hidden_states=prompt_embeds,
1197
+ pooled_projections=pooled_prompt_embeds,
1198
+ #joint_attention_kwargs=self.joint_attention_kwargs,
1199
+ return_dict=False,
1200
+ #skip_layers=skip_guidance_layers if skip_guidance_layers is not None else None,
1201
+ )
1202
+ #print(f"effective_model type: {type(effective_model)}")
1203
+ #print(f"noise_pred_output: {noise_pred_output}")
1204
+ else:
1205
+ # SD3WithRectifiedNoise 不支持 skip_layers 参数
1206
+ noise_pred_output = effective_model(
1207
+ hidden_states=latent_model_input,
1208
+ timestep=timestep,
1209
+ encoder_hidden_states=prompt_embeds,
1210
+ pooled_projections=pooled_prompt_embeds,
1211
+ joint_attention_kwargs=self.joint_attention_kwargs,
1212
+ return_dict=False,
1213
+ )
1214
+ # 正确处理 SD3WithRectifiedNoise 模型的输出
1215
+ # SD3WithRectifiedNoise 返回 (final_output, mean_out, var_out) 元组,我们只需要第一个元素
1216
+ # 如果返回的是字典,则使用 "sample" 键
1217
+ if isinstance(noise_pred_output, dict):
1218
+ noise_pred = noise_pred_output["sample"]
1219
+ elif isinstance(noise_pred_output, tuple):
1220
+ # 对于 SD3WithRectifiedNoise,取第一个输出作为主要预测结果
1221
+ noise_pred = noise_pred_output[0]
1222
+ else:
1223
+ noise_pred = noise_pred_output
1224
+ else:
1225
+ # 使用默认的 transformer 进行预测
1226
+ noise_pred = self.transformer(
1227
+ hidden_states=latent_model_input,
1228
+ timestep=timestep,
1229
+ encoder_hidden_states=prompt_embeds,
1230
+ pooled_projections=pooled_prompt_embeds,
1231
+ joint_attention_kwargs=self.joint_attention_kwargs,
1232
+ return_dict=False,
1233
+ )[0]
1234
+
1235
+ # Check for NaN in noise prediction
1236
+ if torch.isnan(noise_pred).any():
1237
+ # print(f"NaN detected in noise_pred at step {i}")
1238
+ # print(f"NaN locations: {torch.where(torch.isnan(noise_pred))}")
1239
+ # print(f"noise_pred stats - min: {noise_pred.min().item()}, max: {noise_pred.max().item()}, mean: {noise_pred.mean().item()}")
1240
+ break
1241
+
1242
+ # perform guidance
1243
+ if self.do_classifier_free_guidance:
1244
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1245
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1246
+
1247
+ # Check for NaN after guidance
1248
+ if torch.isnan(noise_pred).any():
1249
+ # print(f"NaN detected in noise_pred after guidance at step {i}")
1250
+ # print(f"noise_pred_uncond stats - min: {noise_pred_uncond.min().item()}, max: {noise_pred_uncond.max().item()}, mean: {noise_pred_uncond.mean().item()}")
1251
+ # print(f"noise_pred_text stats - min: {noise_pred_text.min().item()}, max: {noise_pred_text.max().item()}, mean: {noise_pred_text.mean().item()}")
1252
+ # print(f"guidance_scale: {self.guidance_scale}")
1253
+ break
1254
+
1255
+ should_skip_layers = (
1256
+ True
1257
+ if i > num_inference_steps * skip_layer_guidance_start
1258
+ and i < num_inference_steps * skip_layer_guidance_stop
1259
+ else False
1260
+ )
1261
+ if skip_guidance_layers is not None and should_skip_layers:
1262
+ timestep = t.expand(latents.shape[0])
1263
+ latent_model_input = latents
1264
+ # 修改 skip_guidance_layers 部分的 transformer 调用逻辑以支持 SD3WithRectifiedNoise 模型
1265
+ # 优先使用传入的 model 参数,其次使用类属性中的 model,最后回退到默认 transformer
1266
+ effective_model = model or getattr(self, 'model', None) or self.transformer
1267
+
1268
+ if hasattr(effective_model, '__call__') and callable(effective_model):
1269
+ # 使用有效的模型进行预测
1270
+ # 检查模型是否支持 skip_layers 参数
1271
+ if hasattr(effective_model, 'forward') and 'skip_layers' in inspect.signature(effective_model.forward).parameters:
1272
+ noise_pred_skip_output = effective_model(
1273
+ hidden_states=latent_model_input,
1274
+ timestep=timestep,
1275
+ encoder_hidden_states=original_prompt_embeds,
1276
+ pooled_projections=original_pooled_prompt_embeds,
1277
+ joint_attention_kwargs=self.joint_attention_kwargs,
1278
+ return_dict=False,
1279
+ skip_layers=skip_guidance_layers,
1280
+ )
1281
+ else:
1282
+ # SD3WithRectifiedNoise 不支持 skip_layers 参数
1283
+ noise_pred_skip_output = effective_model(
1284
+ hidden_states=latent_model_input,
1285
+ timestep=timestep,
1286
+ encoder_hidden_states=original_prompt_embeds,
1287
+ pooled_projections=original_pooled_prompt_embeds,
1288
+ joint_attention_kwargs=self.joint_attention_kwargs,
1289
+ return_dict=False,
1290
+ )
1291
+ # 正确处理 SD3WithRectifiedNoise 模型的输出
1292
+ # SD3WithRectifiedNoise 返回 (final_output, mean_out, var_out) 元组,我们只需要第一个元素
1293
+ # 如果返回的是字典,则使用 "sample" 键
1294
+ if isinstance(noise_pred_skip_output, dict):
1295
+ noise_pred_skip_layers = noise_pred_skip_output["sample"]
1296
+ elif isinstance(noise_pred_skip_output, tuple):
1297
+ # 对于 SD3WithRectifiedNoise,取第一个输出作为主要预测结果
1298
+ noise_pred_skip_layers = noise_pred_skip_output[0]
1299
+ else:
1300
+ noise_pred_skip_layers = noise_pred_skip_output
1301
+ else:
1302
+ # 使用默认的 transformer 进行预测
1303
+ noise_pred_skip_layers = self.transformer(
1304
+ hidden_states=latent_model_input,
1305
+ timestep=timestep,
1306
+ encoder_hidden_states=original_prompt_embeds,
1307
+ pooled_projections=original_pooled_prompt_embeds,
1308
+ joint_attention_kwargs=self.joint_attention_kwargs,
1309
+ return_dict=False,
1310
+ skip_layers=skip_guidance_layers,
1311
+ )[0]
1312
+
1313
+ # Check for NaN in skip layers noise prediction
1314
+ if torch.isnan(noise_pred_skip_layers).any():
1315
+ # print(f"NaN detected in noise_pred_skip_layers at step {i}")
1316
+ # print(f"NaN locations: {torch.where(torch.isnan(noise_pred_skip_layers))}")
1317
+ break
1318
+
1319
+ noise_pred = (
1320
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
1321
+ )
1322
+
1323
+ # Check for NaN after skip layer guidance
1324
+ if torch.isnan(noise_pred).any():
1325
+ # print(f"NaN detected in noise_pred after skip layer guidance at step {i}")
1326
+ break
1327
+
1328
+ # compute the previous noisy sample x_t -> x_t-1
1329
+ latents_dtype = latents.dtype
1330
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1331
+
1332
+ # Check for NaN in latents after scheduler step
1333
+ if torch.isnan(latents).any():
1334
+ # print(f"NaN detected in latents after scheduler step at step {i}")
1335
+ # print(f"noise_pred stats - min: {noise_pred.min().item()}, max: {noise_pred.max().item()}, mean: {noise_pred.mean().item()}")
1336
+ break
1337
+
1338
+ # Print intermediate results
1339
+ # print(f"Step {i+1}/{num_inference_steps}, Timestep: {t.item():.2f}, Latents mean: {latents.mean().item():.6f}, Latents std: {latents.std().item():.6f}")
1340
+
1341
+ if latents.dtype != latents_dtype:
1342
+ if torch.backends.mps.is_available():
1343
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1344
+ latents = latents.to(latents_dtype)
1345
+
1346
+ if callback_on_step_end is not None:
1347
+ callback_kwargs = {}
1348
+ for k in callback_on_step_end_tensor_inputs:
1349
+ callback_kwargs[k] = locals()[k]
1350
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1351
+
1352
+ latents = callback_outputs.pop("latents", latents)
1353
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1354
+ pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
1355
+
1356
+ # call the callback, if provided
1357
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1358
+ progress_bar.update()
1359
+
1360
+ if XLA_AVAILABLE:
1361
+ xm.mark_step()
1362
+
1363
+ if output_type == "latent":
1364
+ image = latents
1365
+
1366
+ else:
1367
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1368
+
1369
+ image = self.vae.decode(latents, return_dict=False)[0]
1370
+ image = self.image_processor.postprocess(image, output_type=output_type)
1371
+
1372
+ # Offload all models
1373
+ self.maybe_free_model_hooks()
1374
+
1375
+ if not return_dict:
1376
+ return (image,)
1377
+
1378
+ return StableDiffusion3PipelineOutput(images=image)
rectified-noise-batch-2/checkpoint-100000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
rectified-noise-batch-2/checkpoint-120000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
rectified-noise-batch-2/checkpoint-140000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
rectified-noise-batch-2/checkpoint-160000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
rectified-noise-batch-2/checkpoint-180000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
rectified-noise-batch-2/checkpoint-200000/sit_weights/sit_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_sit_layers": 1,
3
+ "hidden_size": 4096,
4
+ "input_dim": 16,
5
+ "num_attention_heads": 16,
6
+ "intermediate_size": 16384,
7
+ "model_type": "rectified_noise",
8
+ "architecture": "SIT",
9
+ "version": "1.0"
10
+ }
run_sd3_lora_rn_pair_sampling.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
4
+
5
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
6
+ LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
7
+ RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
8
+
9
+ CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
10
+ SAMPLE_DIR="./sd3_lora_rn_pair_samples"
11
+
12
+ NUM_INFERENCE_STEPS=40
13
+ GUIDANCE_SCALE=7.0
14
+ HEIGHT=512
15
+ WIDTH=512
16
+ PER_PROC_BATCH_SIZE=1
17
+ IMAGES_PER_CAPTION=1
18
+ MAX_SAMPLES=500
19
+ GLOBAL_SEED=42
20
+ MIXED_PRECISION="fp16"
21
+ NUM_SIT_LAYERS=1
22
+
23
+ ARGS=(
24
+ --pretrained_model_name_or_path "$PRETRAINED_MODEL"
25
+ --captions_jsonl "$CAPTIONS_JSONL"
26
+ --sample_dir "$SAMPLE_DIR"
27
+ --num_inference_steps $NUM_INFERENCE_STEPS
28
+ --guidance_scale $GUIDANCE_SCALE
29
+ --height $HEIGHT
30
+ --width $WIDTH
31
+ --per_proc_batch_size $PER_PROC_BATCH_SIZE
32
+ --images_per_caption $IMAGES_PER_CAPTION
33
+ --max_samples $MAX_SAMPLES
34
+ --global_seed $GLOBAL_SEED
35
+ --num_sit_layers $NUM_SIT_LAYERS
36
+ --mixed_precision $MIXED_PRECISION
37
+ --rectified_weights "$RECTIFIED_WEIGHTS"
38
+ )
39
+
40
+ if [ -n "$LORA_PATH" ]; then
41
+ ARGS+=(--lora_path "$LORA_PATH")
42
+ fi
43
+
44
+ torchrun --nproc_per_node=4 --master_port=25923 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage lora
45
+ torchrun --nproc_per_node=4 --master_port=25924 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage rn
46
+ torchrun --nproc_per_node=4 --master_port=25925 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage pair
47
+
48
+ echo "Sampling done. Output at: $SAMPLE_DIR"
49
+ # nohup bash run_sd3_lora_rn_pair_sampling.sh > run_sd3_lora_rn_pair_sampling.log 2>&1 &
50
+
run_sd3_lora_sampling.log ADDED
The diff for this file is too large to render. See raw diff
 
run_sd3_lora_sampling.sh ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # SD3 LoRA模型采样脚本
4
+ # 使用JSONL文件进行采样的示例脚本
5
+ # 使用方法: ./run_sd3_lora_sampling.sh
6
+
7
+ # 设置GPU设备
8
+ export CUDA_VISIBLE_DEVICES="0,1,2,3" # 使用4个GPU(0,1,2,3)
9
+
10
+ # 内存优化设置
11
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
12
+
13
+ # 模型和LoRA路径配置
14
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
15
+ # LoRA checkpoint路径 - 使用accelerator checkpoint目录
16
+ LORA_CHECKPOINT_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
17
+ # LoRA rank(必须与训练时一致)
18
+ LORA_RANK=32
19
+
20
+ # 采样参数配置
21
+ NUM_INFERENCE_STEPS=40
22
+ GUIDANCE_SCALE=7.0
23
+ HEIGHT=512
24
+ WIDTH=512
25
+ PER_PROC_BATCH_SIZE=1 # 每个GPU的批大小,建议从1开始(SD3模型很大,保持为1以避免内存溢出)
26
+ MAX_SAMPLES=30000 # 最大采样数量限制
27
+
28
+ # 提示词配置
29
+ #NEGATIVE_PROMPT="blurry, low quality, distorted, ugly, bad anatomy"
30
+
31
+ # Caption文件配置
32
+ CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl" # JSONL文件路径
33
+ IMAGES_PER_CAPTION=3 # 每个caption生成几张图片
34
+
35
+ # 输出配置
36
+ SAMPLE_DIR="./sd3_lora_samples_3w"
37
+ GLOBAL_SEED=42
38
+
39
+ echo "开始SD3 LoRA采样(从checkpoint加载)..."
40
+ echo "模型: $PRETRAINED_MODEL"
41
+ echo "LoRA Checkpoint路径: $LORA_CHECKPOINT_PATH"
42
+ echo "LoRA Rank: $LORA_RANK"
43
+ echo "Caption文件: $CAPTIONS_JSONL"
44
+ echo "每个caption生成图片数: $IMAGES_PER_CAPTION"
45
+ echo "图像尺寸: ${HEIGHT}x${WIDTH}"
46
+ echo "引导尺度: $GUIDANCE_SCALE"
47
+ echo "推理步数: $NUM_INFERENCE_STEPS"
48
+
49
+ # 检查必要文件
50
+ if [ ! -f "$CAPTIONS_JSONL" ]; then
51
+ echo "错误: Caption文件 $CAPTIONS_JSONL 不存在"
52
+ exit 1
53
+ fi
54
+
55
+ if [ ! -d "$LORA_CHECKPOINT_PATH" ]; then
56
+ echo "错误: LoRA checkpoint目录 $LORA_CHECKPOINT_PATH 不存在"
57
+ exit 1
58
+ fi
59
+
60
+ # 构建命令参数数组
61
+ CMD_ARGS=(
62
+ "--pretrained_model_name_or_path=$PRETRAINED_MODEL"
63
+ "--lora_checkpoint_path=$LORA_CHECKPOINT_PATH"
64
+ "--lora_rank=$LORA_RANK"
65
+ "--num_inference_steps=$NUM_INFERENCE_STEPS"
66
+ "--guidance_scale=$GUIDANCE_SCALE"
67
+ "--height=$HEIGHT"
68
+ "--width=$WIDTH"
69
+ "--per_proc_batch_size=$PER_PROC_BATCH_SIZE"
70
+ "--captions_jsonl=$CAPTIONS_JSONL"
71
+ "--images_per_caption=$IMAGES_PER_CAPTION"
72
+ "--sample_dir=$SAMPLE_DIR"
73
+ "--global_seed=$GLOBAL_SEED"
74
+ #"--max_samples=$MAX_SAMPLES"
75
+ "--mixed_precision=fp16" # 使用 fp16 以减少内存占用
76
+ # 注意:在多GPU环境下,CPU offload会被代码自动禁用(不支持分布式)
77
+ # 代码会自动检测world_size > 1并禁用CPU offload
78
+ "--enable_cpu_offload"
79
+ )
80
+
81
+ # # 添加负面提示词参数(如果存在)
82
+ # if [ ! -z "$NEGATIVE_PROMPT" ]; then
83
+ # CMD_ARGS+=("--negative_prompt" "$NEGATIVE_PROMPT")
84
+ # fi
85
+
86
+ # 运行分布式采样
87
+ torchrun --nproc_per_node=4 --master_port=25900 sample_sd3_lora_checkpoint_ddp.py "${CMD_ARGS[@]}"
88
+
89
+ echo "采样完成!"
90
+ echo "结果保存在: $SAMPLE_DIR"
91
+ echo "Caption信息保存在: $SAMPLE_DIR/*/captions.txt"
92
+ echo "NPZ文件已生成用于FID评估"
93
+
94
+ # nohup bash run_sd3_lora_sampling.sh > run_sd3_lora_sampling.log 2>&1 &
run_sd3_rectified_sampling.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # 分布式采样:指定 LoRA 与 Rectified(SIT) 权重
4
+
5
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
6
+
7
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
8
+ LOCAL_PIPELINE_PATH="/gemini/space/gzy_new/models/Sida/pipeline_stable_diffusion_3.py"
9
+ LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
10
+ RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
11
+
12
+ CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
13
+ SAMPLE_DIR="./sd3_rectified_samples_batch2_220000"
14
+
15
+ NUM_INFERENCE_STEPS=40
16
+ GUIDANCE_SCALE=7.0
17
+ HEIGHT=512
18
+ WIDTH=512
19
+ PER_PROC_BATCH_SIZE=32
20
+ IMAGES_PER_CAPTION=3
21
+ MAX_SAMPLES=30000
22
+ GLOBAL_SEED=42
23
+ MIXED_PRECISION="fp16" # no / fp16 / bf16
24
+ NUM_SIT_LAYERS=1 # 需与训练一致
25
+
26
+ ARGS=(
27
+ --pretrained_model_name_or_path "$PRETRAINED_MODEL"
28
+ --captions_jsonl "$CAPTIONS_JSONL"
29
+ --sample_dir "$SAMPLE_DIR"
30
+ --num_inference_steps $NUM_INFERENCE_STEPS
31
+ --guidance_scale $GUIDANCE_SCALE
32
+ --height $HEIGHT
33
+ --width $WIDTH
34
+ --per_proc_batch_size $PER_PROC_BATCH_SIZE
35
+ --images_per_caption $IMAGES_PER_CAPTION
36
+ --max_samples $MAX_SAMPLES
37
+ --global_seed $GLOBAL_SEED
38
+ --num_sit_layers $NUM_SIT_LAYERS
39
+ --mixed_precision $MIXED_PRECISION
40
+ )
41
+
42
+ if [ -n "$LORA_PATH" ]; then
43
+ ARGS+=(--lora_path "$LORA_PATH")
44
+ fi
45
+
46
+ if [ -n "$RECTIFIED_WEIGHTS" ]; then
47
+ ARGS+=(--rectified_weights "$RECTIFIED_WEIGHTS")
48
+ fi
49
+
50
+ torchrun --nproc_per_node=4 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"
51
+
52
+ echo "Sampling done. Output at: $SAMPLE_DIR"
53
+ # nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
54
+
55
+
run_sd3_rectified_sampling_old.sh ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # 分布式采样:指定 LoRA 与 Rectified(SIT) 权重
5
+
6
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
7
+ export NCCL_DEBUG=INFO
8
+ export NCCL_DEBUG_SUBSYS=ALL
9
+ export NCCL_IB_DISABLE=1
10
+ export NCCL_P2P_LEVEL=SYS
11
+
12
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
13
+ #"/gemini/space/zhaozy/zhy/hsd/project/pretrained_model/models--stabilityai--stable-diffusion-3-medium-diffusers"
14
+ LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000" # 可为空
15
+ RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-120000" # 可为空(若不用 Rectified)
16
+
17
+ CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
18
+ SAMPLE_DIR="./sd3_rectified_samples_new_batch_2"
19
+
20
+ NUM_INFERENCE_STEPS=40
21
+ GUIDANCE_SCALE=7.0
22
+ HEIGHT=512
23
+ WIDTH=512
24
+ PER_PROC_BATCH_SIZE=32
25
+ IMAGES_PER_CAPTION=3
26
+ MAX_SAMPLES=30000
27
+ GLOBAL_SEED=42
28
+ MIXED_PRECISION="fp16" # no / fp16 / bf16
29
+ NUM_SIT_LAYERS=1 # 需与训练一致
30
+
31
+ ARGS=(
32
+ --pretrained_model_name_or_path "$PRETRAINED_MODEL"
33
+ --captions_jsonl "$CAPTIONS_JSONL"
34
+ --sample_dir "$SAMPLE_DIR"
35
+ --num_inference_steps $NUM_INFERENCE_STEPS
36
+ --guidance_scale $GUIDANCE_SCALE
37
+ --height $HEIGHT
38
+ --width $WIDTH
39
+ --per_proc_batch_size $PER_PROC_BATCH_SIZE
40
+ --images_per_caption $IMAGES_PER_CAPTION
41
+ --max_samples $MAX_SAMPLES
42
+ --global_seed $GLOBAL_SEED
43
+ --num_sit_layers $NUM_SIT_LAYERS
44
+ --mixed_precision $MIXED_PRECISION
45
+ )
46
+
47
+ if [ -n "$LORA_PATH" ]; then
48
+ ARGS+=(--lora_path "$LORA_PATH")
49
+ fi
50
+
51
+ if [ -n "$RECTIFIED_WEIGHTS" ]; then
52
+ ARGS+=(--rectified_weights "$RECTIFIED_WEIGHTS")
53
+ fi
54
+
55
+ echo "[run_sd3_rectified_sampling.sh] start torchrun: $(date)"
56
+
57
+ # 先尝试 4 卡模式,如果失败则退到单卡模式
58
+ if ! torchrun --nproc_per_node=4 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"; then
59
+ ret=$?
60
+ echo "[run_sd3_rectified_sampling.sh] 4卡运行失败(退出码 ${ret}),尝试单卡模式"
61
+ if ! torchrun --nproc_per_node=1 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"; then
62
+ ret2=$?
63
+ echo "[run_sd3_rectified_sampling.sh] 单卡运行也失败(退出码 ${ret2}),请查看具体错误信息。"
64
+ exit $ret2
65
+ fi
66
+ echo "[run_sd3_rectified_sampling.sh] 单卡运行成功,建议降低 per_proc_batch_size 或使用单卡配置继续。"
67
+ fi
68
+
69
+ wait
70
+
71
+ echo "Sampling done. Output at: $SAMPLE_DIR"
72
+ # nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
sample_sd3_lora_checkpoint_ddp.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ SD3 LoRA分布式采样脚本 - 从accelerator checkpoint加载LoRA权重
5
+ 使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估
6
+ """
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tqdm import tqdm
11
+ import os
12
+ from PIL import Image
13
+ import numpy as np
14
+ import math
15
+ import argparse
16
+ import sys
17
+ import json
18
+ import random
19
+ from pathlib import Path
20
+
21
+ from diffusers import (
22
+ StableDiffusion3Pipeline,
23
+ AutoencoderKL,
24
+ FlowMatchEulerDiscreteScheduler,
25
+ SD3Transformer2DModel,
26
+ )
27
+ from transformers import CLIPTokenizer, T5TokenizerFast
28
+ from accelerate import Accelerator
29
+ from peft import LoraConfig, PeftModel
30
+ from peft.utils import get_peft_model_state_dict
31
+ from safetensors.torch import load_file, save_file
32
+
33
+
34
+ def create_npz_from_sample_folder(sample_dir, num_samples):
35
+ """
36
+ 从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
37
+ """
38
+ samples = []
39
+ actual_files = []
40
+
41
+ # 收集所有PNG文件
42
+ for filename in sorted(os.listdir(sample_dir)):
43
+ if filename.endswith('.png'):
44
+ actual_files.append(filename)
45
+
46
+ # 按照数量限制处理
47
+ for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
48
+ if i < len(actual_files):
49
+ sample_path = os.path.join(sample_dir, actual_files[i])
50
+ sample_pil = Image.open(sample_path)
51
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
52
+ samples.append(sample_np)
53
+ else:
54
+ # 如果不够,创建空白图像
55
+ sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
56
+ samples.append(sample_np)
57
+
58
+ if samples:
59
+ samples = np.stack(samples)
60
+ npz_path = f"{sample_dir}.npz"
61
+ np.savez(npz_path, arr_0=samples)
62
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
63
+ return npz_path
64
+ else:
65
+ print("No samples found to create npz file.")
66
+ return None
67
+
68
+
69
+ def extract_lora_from_checkpoint(checkpoint_path, output_lora_path, rank=64, rank0_only=True):
70
+ """
71
+ 从accelerator checkpoint中提取LoRA权重并保存为标准格式
72
+
73
+ Args:
74
+ checkpoint_path: checkpoint目录路径
75
+ output_lora_path: 输出LoRA权重保存路径
76
+ rank: LoRA rank
77
+ rank0_only: 是否只在rank 0上执行
78
+ """
79
+ model_file = os.path.join(checkpoint_path, "model.safetensors")
80
+ if not os.path.exists(model_file):
81
+ if rank0_only:
82
+ print(f"Model file not found: {model_file}")
83
+ return False
84
+
85
+ try:
86
+ # 加载checkpoint state dict
87
+ state_dict = load_file(model_file)
88
+
89
+ if rank0_only:
90
+ print(f"Loaded checkpoint with {len(state_dict)} keys")
91
+
92
+ # 提取LoRA权重
93
+ # Accelerator保存的格式可能是: "transformer.lora_A.weight" 或 "model.transformer.lora_A.weight"
94
+ # 需要转换为diffusers格式: "transformer.lora_A.weight"
95
+ lora_state_dict = {}
96
+
97
+ # 查找所有LoRA相关的键
98
+ lora_keys = []
99
+ for key in state_dict.keys():
100
+ # 检查是否是LoRA权重(lora_A, lora_B, lora_embedding等)
101
+ if 'lora_A' in key or 'lora_B' in key or 'lora_embedding' in key:
102
+ lora_keys.append(key)
103
+
104
+ if rank0_only:
105
+ print(f"Found {len(lora_keys)} LoRA keys")
106
+ if lora_keys:
107
+ print(f"Sample LoRA keys: {lora_keys[:5]}")
108
+
109
+ if not lora_keys:
110
+ if rank0_only:
111
+ print("Warning: No LoRA keys found in checkpoint. Trying alternative extraction method...")
112
+
113
+ # 尝试另一种方法:检查是否有完整的transformer权重
114
+ # 如果是全量微调,我们需要计算LoRA权重 = 微调权重 - 基础权重
115
+ # 但这需要基础模型,所以这里我们假设checkpoint中已经包含了LoRA权重
116
+ # 或者checkpoint保存的是合并后的权重
117
+
118
+ # 检查是否有transformer的完整权重
119
+ transformer_keys = [k for k in state_dict.keys() if 'transformer' in k.lower() and 'lora' not in k.lower()]
120
+ if transformer_keys:
121
+ if rank0_only:
122
+ print(f"Found {len(transformer_keys)} transformer keys (full fine-tuning checkpoint)")
123
+ print("This checkpoint appears to contain full model weights, not LoRA weights.")
124
+ print("You may need to use a different loading method.")
125
+ return False
126
+
127
+ # 转换键名格式:从accelerator格式转换为diffusers格式
128
+ for key in lora_keys:
129
+ # 移除可能的"model."前缀
130
+ new_key = key
131
+ if new_key.startswith("model."):
132
+ new_key = new_key[6:] # 移除"model."前缀
133
+
134
+ # 确保键名符合diffusers格式
135
+ # diffusers格式通常是: "transformer.lora_A.weight" 或 "transformer.transformer_blocks.X.attn.to_q.lora_A.weight"
136
+ lora_state_dict[new_key] = state_dict[key]
137
+
138
+ if not lora_state_dict:
139
+ if rank0_only:
140
+ print("Error: Failed to extract LoRA weights from checkpoint")
141
+ return False
142
+
143
+ # 保存LoRA权重
144
+ if rank0_only:
145
+ os.makedirs(output_lora_path, exist_ok=True)
146
+ lora_file = os.path.join(output_lora_path, "pytorch_lora_weights.safetensors")
147
+ save_file(lora_state_dict, lora_file)
148
+ print(f"Saved LoRA weights to {lora_file} ({len(lora_state_dict)} keys)")
149
+
150
+ return True
151
+
152
+ except Exception as e:
153
+ if rank0_only:
154
+ print(f"Error extracting LoRA from checkpoint: {e}")
155
+ import traceback
156
+ traceback.print_exc()
157
+ return False
158
+
159
+
160
+ def load_lora_from_checkpoint_direct(pipeline, checkpoint_path, rank=64, rank0_print=True):
161
+ """
162
+ 直接从checkpoint加载LoRA权重到pipeline
163
+
164
+ 这个方法尝试直接从checkpoint中加载LoRA权重,而不需要先提取
165
+ """
166
+ model_file = os.path.join(checkpoint_path, "model.safetensors")
167
+ if not os.path.exists(model_file):
168
+ if rank0_print:
169
+ print(f"Model file not found: {model_file}")
170
+ return False
171
+
172
+ try:
173
+ # 加载checkpoint state dict
174
+ state_dict = load_file(model_file)
175
+
176
+ if rank0_print:
177
+ print(f"Loaded checkpoint with {len(state_dict)} keys")
178
+ # 显示前10个键名以便调试
179
+ sample_keys = list(state_dict.keys())[:10]
180
+ print(f"Sample keys: {sample_keys}")
181
+
182
+ # 查找LoRA权重
183
+ lora_keys = [k for k in state_dict.keys() if 'lora_A' in k or 'lora_B' in k or 'lora_embedding' in k]
184
+
185
+ if not lora_keys:
186
+ if rank0_print:
187
+ print("No LoRA keys found in checkpoint.")
188
+ print("This checkpoint might contain merged weights or use a different format.")
189
+ print("Checking checkpoint structure...")
190
+
191
+ # 检查是否是全量微调的checkpoint(包含完整transformer权重)
192
+ transformer_keys = [k for k in state_dict.keys() if 'transformer' in k.lower() and 'lora' not in k.lower()]
193
+ if transformer_keys:
194
+ if rank0_print:
195
+ print(f"Found {len(transformer_keys)} transformer keys")
196
+ print("This appears to be a full fine-tuning checkpoint with merged weights.")
197
+ print("Attempting to use Accelerator to load the checkpoint...")
198
+
199
+ # 尝试使用Accelerator加载checkpoint
200
+ try:
201
+ # 配置LoRA适配器
202
+ transformer_lora_config = LoraConfig(
203
+ r=rank,
204
+ lora_alpha=rank,
205
+ init_lora_weights="gaussian",
206
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
207
+ )
208
+
209
+ # 为transformer添加LoRA适配器
210
+ pipeline.transformer.add_adapter(transformer_lora_config)
211
+
212
+ # 使用Accelerator加载checkpoint
213
+ accelerator = Accelerator()
214
+ # 准备模型
215
+ transformer_prepared = accelerator.prepare(pipeline.transformer)
216
+ # 加载状态
217
+ accelerator.load_state(checkpoint_path)
218
+ # 提取模型
219
+ pipeline.transformer = accelerator.unwrap_model(transformer_prepared)
220
+
221
+ if rank0_print:
222
+ print("Successfully loaded checkpoint using Accelerator")
223
+ return True
224
+ except Exception as e:
225
+ if rank0_print:
226
+ print(f"Failed to load using Accelerator: {e}")
227
+ return False
228
+ else:
229
+ if rank0_print:
230
+ print("Could not identify checkpoint format. Please check the checkpoint structure.")
231
+ return False
232
+
233
+ if rank0_print:
234
+ print(f"Found {len(lora_keys)} LoRA keys")
235
+ print(f"Sample LoRA keys: {lora_keys[:5]}")
236
+
237
+ # 配置LoRA适配器
238
+ transformer_lora_config = LoraConfig(
239
+ r=rank,
240
+ lora_alpha=rank,
241
+ init_lora_weights="gaussian",
242
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
243
+ )
244
+
245
+ # 为transformer添加LoRA适配器
246
+ pipeline.transformer.add_adapter(transformer_lora_config)
247
+
248
+ if rank0_print:
249
+ print("LoRA adapter configured")
250
+
251
+ # 提取并转换LoRA权重
252
+ lora_state_dict = {}
253
+ for key in lora_keys:
254
+ # 移除可能的"model."或"transformer."前缀(取决于accelerator保存格式)
255
+ new_key = key
256
+ # 移除常见的accelerator前缀
257
+ prefixes_to_remove = ["model.", "module.", "transformer."]
258
+ for prefix in prefixes_to_remove:
259
+ if new_key.startswith(prefix):
260
+ new_key = new_key[len(prefix):]
261
+ break
262
+
263
+ # 确保键名符合PEFT格式
264
+ # PEFT格式通常是: "transformer_blocks.X.attn.to_q.lora_A.weight"
265
+ # 或者: "lora_A.weight" (如果已经包含完整路径)
266
+ lora_state_dict[new_key] = state_dict[key]
267
+
268
+ if rank0_print:
269
+ print(f"Extracted {len(lora_state_dict)} LoRA weights")
270
+ print(f"Sample extracted keys: {list(lora_state_dict.keys())[:5]}")
271
+
272
+ # 加载LoRA权重到模型
273
+ # 使用PEFT的load_state_dict方法
274
+ missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
275
+
276
+ if rank0_print:
277
+ if missing_keys:
278
+ print(f"Missing keys: {len(missing_keys)}")
279
+ if len(missing_keys) <= 10:
280
+ for k in missing_keys:
281
+ print(f" - {k}")
282
+ else:
283
+ print(f" (showing first 10 of {len(missing_keys)} missing keys)")
284
+ for k in list(missing_keys)[:10]:
285
+ print(f" - {k}")
286
+ if unexpected_keys:
287
+ print(f"Unexpected keys: {len(unexpected_keys)}")
288
+ if len(unexpected_keys) <= 10:
289
+ for k in unexpected_keys:
290
+ print(f" - {k}")
291
+ else:
292
+ print(f" (showing first 10 of {len(unexpected_keys)} unexpected keys)")
293
+ for k in list(unexpected_keys)[:10]:
294
+ print(f" - {k}")
295
+
296
+ # 检查是否有peft_config
297
+ if hasattr(pipeline.transformer, 'peft_config'):
298
+ if rank0_print:
299
+ print(f"LoRA config found: {list(pipeline.transformer.peft_config.keys())}")
300
+ else:
301
+ if rank0_print:
302
+ print("Warning: No peft_config found after loading LoRA")
303
+
304
+ # 验证LoRA是否真的被加载
305
+ if rank0_print:
306
+ # 检查一个LoRA层的权重是否非零
307
+ has_lora_weights = False
308
+ for name, param in pipeline.transformer.named_parameters():
309
+ if 'lora' in name.lower() and param.requires_grad:
310
+ if param.abs().max().item() > 1e-6:
311
+ has_lora_weights = True
312
+ if rank0_print:
313
+ print(f"Verified LoRA weights loaded (found non-zero LoRA param: {name})")
314
+ break
315
+
316
+ if not has_lora_weights:
317
+ print("Warning: LoRA weights may not have been loaded correctly (all LoRA params are zero or not found)")
318
+
319
+ if rank0_print:
320
+ print("LoRA weights loaded successfully")
321
+
322
+ return True
323
+
324
+ except Exception as e:
325
+ if rank0_print:
326
+ print(f"Error loading LoRA from checkpoint: {e}")
327
+ import traceback
328
+ traceback.print_exc()
329
+ return False
330
+
331
+
332
+ def load_captions_from_jsonl(jsonl_path):
333
+ """
334
+ 从JSONL文件加载caption列表
335
+ """
336
+ captions = []
337
+ try:
338
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
339
+ for line_num, line in enumerate(f, 1):
340
+ line = line.strip()
341
+ if not line:
342
+ continue
343
+
344
+ try:
345
+ data = json.loads(line)
346
+ # 支持多种字段名
347
+ caption = None
348
+ for field in ['caption', 'text', 'prompt', 'description']:
349
+ if field in data and isinstance(data[field], str):
350
+ caption = data[field].strip()
351
+ break
352
+
353
+ if caption:
354
+ captions.append(caption)
355
+ else:
356
+ # 如果没有找到标准字段,取第一个字符串值
357
+ for value in data.values():
358
+ if isinstance(value, str) and value.strip():
359
+ captions.append(value.strip())
360
+ break
361
+
362
+ except json.JSONDecodeError as e:
363
+ print(f"Warning: Invalid JSON on line {line_num}: {e}")
364
+ continue
365
+
366
+ except FileNotFoundError:
367
+ print(f"Error: JSONL file {jsonl_path} not found")
368
+ return []
369
+ except Exception as e:
370
+ print(f"Error loading JSONL file {jsonl_path}: {e}")
371
+ return []
372
+
373
+ print(f"Loaded {len(captions)} captions from {jsonl_path}")
374
+ return captions
375
+
376
+
377
+ def main(args):
378
+ """
379
+ 运行 SD3 LoRA 采样
380
+ """
381
+ assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
382
+ torch.set_grad_enabled(False)
383
+
384
+ # 设置 DDP
385
+ dist.init_process_group("nccl")
386
+ rank = dist.get_rank()
387
+ world_size = dist.get_world_size()
388
+ device = torch.device(f"cuda:{rank}")
389
+ seed = args.global_seed * world_size + rank
390
+ torch.manual_seed(seed)
391
+ torch.cuda.set_device(device)
392
+ print(f"Starting rank={rank}, device={device}, seed={seed}, world_size={world_size}, visible_devices={torch.cuda.device_count()}.")
393
+
394
+ # 加载captions
395
+ captions = []
396
+ if args.captions_jsonl:
397
+ if rank == 0:
398
+ print(f"Loading captions from {args.captions_jsonl}")
399
+ captions = load_captions_from_jsonl(args.captions_jsonl)
400
+ if not captions:
401
+ if rank == 0:
402
+ print("Warning: No captions loaded, using default caption")
403
+ captions = ["a beautiful high quality image"]
404
+ else:
405
+ # 使用默认caption
406
+ captions = ["a beautiful high quality image"]
407
+
408
+ # 计算总的图片数量
409
+ total_images_needed = len(captions) * args.images_per_caption
410
+ # 应用最大样本数限制
411
+ total_images_needed = min(total_images_needed, args.max_samples)
412
+ if rank == 0:
413
+ print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
414
+ print(f"Total images requested: {len(captions) * args.images_per_caption}")
415
+ print(f"Max samples limit: {args.max_samples}")
416
+ print(f"Total images to generate: {total_images_needed}")
417
+
418
+ # 设置数据类型 - 使用混合精度以减少内存占用
419
+ if args.mixed_precision == "fp16":
420
+ dtype = torch.float16
421
+ elif args.mixed_precision == "bf16":
422
+ dtype = torch.bfloat16
423
+ else:
424
+ dtype = torch.float32
425
+
426
+ # 加载基础模型
427
+ if rank == 0:
428
+ print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")
429
+
430
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
431
+ args.pretrained_model_name_or_path,
432
+ revision=args.revision,
433
+ variant=args.variant,
434
+ torch_dtype=dtype,
435
+ )
436
+
437
+ # 从checkpoint加载LoRA权重
438
+ lora_loaded = False
439
+ lora_source = "baseline"
440
+
441
+ if args.lora_checkpoint_path:
442
+ if rank == 0:
443
+ print(f"Loading LoRA weights from checkpoint: {args.lora_checkpoint_path}")
444
+
445
+ # 方法1: 直接从checkpoint加载
446
+ lora_loaded = load_lora_from_checkpoint_direct(
447
+ pipeline,
448
+ args.lora_checkpoint_path,
449
+ rank=args.lora_rank,
450
+ rank0_print=(rank == 0)
451
+ )
452
+
453
+ if lora_loaded:
454
+ lora_source = os.path.basename(args.lora_checkpoint_path.rstrip('/'))
455
+ if rank == 0:
456
+ print("Successfully loaded LoRA weights from checkpoint")
457
+ else:
458
+ if rank == 0:
459
+ print("Failed to load LoRA weights directly from checkpoint")
460
+ print("Trying alternative method: extracting LoRA weights first...")
461
+
462
+ # 方法2: 先提取LoRA权重,再加载
463
+ temp_lora_path = os.path.join(args.lora_checkpoint_path, "extracted_lora")
464
+ if rank == 0:
465
+ extract_success = extract_lora_from_checkpoint(
466
+ args.lora_checkpoint_path,
467
+ temp_lora_path,
468
+ rank=args.lora_rank,
469
+ rank0_only=True
470
+ )
471
+ else:
472
+ extract_success = False
473
+
474
+ dist.barrier() # 等待rank 0完成提取
475
+
476
+ if extract_success and os.path.exists(os.path.join(temp_lora_path, "pytorch_lora_weights.safetensors")):
477
+ if rank == 0:
478
+ print(f"Loading extracted LoRA weights from {temp_lora_path}")
479
+ try:
480
+ pipeline.load_lora_weights(temp_lora_path)
481
+ lora_loaded = True
482
+ lora_source = f"{os.path.basename(args.lora_checkpoint_path.rstrip('/'))}_extracted"
483
+ if rank == 0:
484
+ print("Successfully loaded extracted LoRA weights")
485
+ except Exception as e:
486
+ if rank == 0:
487
+ print(f"Failed to load extracted LoRA weights: {e}")
488
+
489
+ if not lora_loaded:
490
+ if rank == 0:
491
+ print("Warning: No LoRA weights loaded. Using baseline model.")
492
+
493
+ # 启用内存优化选项(必须在移动到设备之前)
494
+ # 注意:在分布式环境下,CPU offload 不支持多GPU,会导致所有进程挤在一张卡上
495
+ # 因此禁用 CPU offload,改用其他内存优化方法
496
+ if args.enable_cpu_offload and world_size > 1:
497
+ if rank == 0:
498
+ print(f"Warning: CPU offload is disabled in multi-GPU mode (world_size={world_size})")
499
+ print("Using device-specific placement instead")
500
+ args.enable_cpu_offload = False
501
+
502
+ if args.enable_cpu_offload:
503
+ if rank == 0:
504
+ print("Enabling CPU offload to save memory (single GPU mode)")
505
+ # CPU offload 会自动管理设备,不需要先 to(device)
506
+ pipeline.enable_model_cpu_offload()
507
+ else:
508
+ # 在分布式环境下,明确将pipeline移动到对应的设备
509
+ if rank == 0:
510
+ print(f"Moving pipeline to device {device} (multi-GPU mode)")
511
+ pipeline = pipeline.to(device)
512
+ if rank == 0:
513
+ print("Enabling memory optimization options")
514
+
515
+ # 检查并启用可用的内存优化方法
516
+ # 注意:所有进程都需要执行这些操作,不仅仅是 rank 0
517
+ if hasattr(pipeline, 'enable_attention_slicing'):
518
+ try:
519
+ pipeline.enable_attention_slicing()
520
+ if rank == 0:
521
+ print(" - Attention slicing enabled")
522
+ except Exception as e:
523
+ if rank == 0:
524
+ print(f" - Warning: Failed to enable attention slicing: {e}")
525
+ else:
526
+ if rank == 0:
527
+ print(" - Attention slicing not available for this pipeline")
528
+
529
+ # SD3 pipeline 可能不支持 enable_vae_slicing,需要检查
530
+ # 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
531
+ enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
532
+ if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
533
+ try:
534
+ enable_vae_slicing_method()
535
+ if rank == 0:
536
+ print(" - VAE slicing enabled")
537
+ except Exception as e:
538
+ if rank == 0:
539
+ print(f" - Warning: Failed to enable VAE slicing: {e}")
540
+ else:
541
+ if rank == 0:
542
+ print(" - VAE slicing not available for this pipeline (SD3 may not support this)")
543
+
544
+ # 验证设备分配
545
+ if rank == 0:
546
+ print(f"Pipeline device verification:")
547
+ print(f" - Transformer device: {next(pipeline.transformer.parameters()).device}")
548
+ print(f" - VAE device: {next(pipeline.vae.parameters()).device}")
549
+ if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
550
+ print(f" - Text encoder device: {next(pipeline.text_encoder.parameters()).device}")
551
+ dist.barrier() # 等待所有进程完成设备分配
552
+
553
+ # 禁用进度条
554
+ pipeline.set_progress_bar_config(disable=True)
555
+
556
+ # 创建保存目录
557
+ folder_name = f"checkpoint-{lora_source}-rank{args.lora_rank}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
558
+ sample_folder_dir = os.path.join(args.sample_dir, folder_name)
559
+
560
+ if rank == 0:
561
+ os.makedirs(sample_folder_dir, exist_ok=True)
562
+ print(f"Saving .png samples at {sample_folder_dir}")
563
+ # 清空caption文件
564
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
565
+ if os.path.exists(caption_file):
566
+ os.remove(caption_file)
567
+ dist.barrier()
568
+
569
+ # 计算采样参数
570
+ n = args.per_proc_batch_size
571
+ global_batch_size = n * dist.get_world_size()
572
+
573
+ # 检查已存在的样本数量
574
+ existing_samples = 0
575
+ if os.path.exists(sample_folder_dir):
576
+ existing_samples = len([
577
+ name for name in os.listdir(sample_folder_dir)
578
+ if os.path.isfile(os.path.join(sample_folder_dir, name)) and name.endswith(".png")
579
+ ])
580
+
581
+ total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
582
+ if rank == 0:
583
+ print(f"Total number of images that will be sampled: {total_samples}")
584
+ print(f"Existing samples: {existing_samples}")
585
+
586
+ assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
587
+ samples_needed_this_gpu = int(total_samples // dist.get_world_size())
588
+ assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
589
+
590
+ iterations = int(samples_needed_this_gpu // n)
591
+ done_iterations = int(int(existing_samples // dist.get_world_size()) // n)
592
+
593
+ pbar = range(done_iterations, iterations)
594
+ pbar = tqdm(pbar) if rank == 0 else pbar
595
+
596
+ # 生成caption和image的映射列表
597
+ caption_image_pairs = []
598
+ for i, caption in enumerate(captions):
599
+ for j in range(args.images_per_caption):
600
+ caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
601
+
602
+ total_generated = existing_samples
603
+
604
+ # 采样循环
605
+ for i in pbar:
606
+ # 获取这个batch对应的caption
607
+ batch_prompts = []
608
+ batch_caption_info = []
609
+
610
+ for j in range(n):
611
+ global_index = i * global_batch_size + j * dist.get_world_size() + rank
612
+ if global_index < len(caption_image_pairs):
613
+ caption, caption_idx, image_idx = caption_image_pairs[global_index]
614
+ batch_prompts.append(caption)
615
+ batch_caption_info.append((caption, caption_idx, image_idx))
616
+ else:
617
+ # 如果超出范围,使用最后一个caption
618
+ if caption_image_pairs:
619
+ caption, caption_idx, image_idx = caption_image_pairs[-1]
620
+ batch_prompts.append(caption)
621
+ batch_caption_info.append((caption, caption_idx, image_idx))
622
+ else:
623
+ batch_prompts.append("a beautiful high quality image")
624
+ batch_caption_info.append(("a beautiful high quality image", 0, 0))
625
+
626
+ # 生成图像 - 为每个图像使用不同的随机种子
627
+ # 确保使用正确的设备进行autocast
628
+ device_str = str(device) # 使用明确的设备字符串,如 "cuda:0", "cuda:1" 等
629
+ with torch.autocast(device_str, dtype=dtype):
630
+ # 为每个prompt生成独立的图像(使用不同的generator)
631
+ images = []
632
+ for k, prompt in enumerate(batch_prompts):
633
+ # 为每个图像创建独立的随机种子
634
+ image_seed = seed + i * 10000 + k * 1000 + rank
635
+ generator = torch.Generator(device=device).manual_seed(image_seed)
636
+
637
+ # 调试信息(仅在第一个batch的第一个图像时打印)
638
+ if i == done_iterations and k == 0 and rank < 2: # 只打印前两个rank
639
+ print(f"[Rank {rank}] Generating image on device {device}, generator device: {generator.device}")
640
+
641
+ image = pipeline(
642
+ prompt=prompt,
643
+ negative_prompt=args.negative_prompt if args.negative_prompt else None,
644
+ height=args.height,
645
+ width=args.width,
646
+ num_inference_steps=args.num_inference_steps,
647
+ guidance_scale=args.guidance_scale,
648
+ generator=generator,
649
+ num_images_per_prompt=1,
650
+ ).images[0]
651
+ images.append(image)
652
+
653
+ # 清理 GPU 缓存以释放内存
654
+ if k == len(batch_prompts) - 1: # 每个 batch 的最后一张图片后清理
655
+ torch.cuda.empty_cache()
656
+
657
+ # 保存图像
658
+ for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
659
+ global_index = i * global_batch_size + j * dist.get_world_size() + rank
660
+ if global_index < len(caption_image_pairs):
661
+ # 保存图片,文件名包含caption索引和图片索引
662
+ filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
663
+ image_path = os.path.join(sample_folder_dir, filename)
664
+ image.save(image_path)
665
+
666
+ # 保存caption信息到文本文件(只在rank 0上操作)
667
+ if rank == 0:
668
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
669
+ with open(caption_file, "a", encoding="utf-8") as f:
670
+ f.write(f"{filename}\t{caption}\n")
671
+
672
+ total_generated += global_batch_size
673
+
674
+ # 每个迭代后清理 GPU 缓存
675
+ torch.cuda.empty_cache()
676
+
677
+ dist.barrier()
678
+
679
+ # 确保所有进程都完成采样
680
+ dist.barrier()
681
+
682
+ # 创建npz文件
683
+ if rank == 0:
684
+ # 重新计算实际生成的图片数量
685
+ actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
686
+ print(f"Actually generated {actual_num_samples} images")
687
+ # 使用实际的图片数量或用户指定的数量,取较小值
688
+ npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
689
+ create_npz_from_sample_folder(sample_folder_dir, npz_samples)
690
+ print("Done.")
691
+
692
+ dist.barrier()
693
+ dist.destroy_process_group()
694
+
695
+
696
+ if __name__ == "__main__":
697
+ parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本 - 从checkpoint加载")
698
+
699
+ # 模型和路径参数
700
+ parser.add_argument(
701
+ "--pretrained_model_name_or_path",
702
+ type=str,
703
+ default="stabilityai/stable-diffusion-3-medium-diffusers",
704
+ help="预训练模型路径或HuggingFace模型ID"
705
+ )
706
+ parser.add_argument(
707
+ "--lora_checkpoint_path",
708
+ type=str,
709
+ required=True,
710
+ help="LoRA checkpoint目录路径(包含model.safetensors的目录)"
711
+ )
712
+ parser.add_argument(
713
+ "--lora_rank",
714
+ type=int,
715
+ default=64,
716
+ help="LoRA rank(必须与训练时一致)"
717
+ )
718
+ parser.add_argument(
719
+ "--revision",
720
+ type=str,
721
+ default=None,
722
+ help="模型修订版本"
723
+ )
724
+ parser.add_argument(
725
+ "--variant",
726
+ type=str,
727
+ default=None,
728
+ help="模型变体,如fp16"
729
+ )
730
+
731
+ # 采样参数
732
+ parser.add_argument(
733
+ "--num_inference_steps",
734
+ type=int,
735
+ default=28,
736
+ help="推理步数"
737
+ )
738
+ parser.add_argument(
739
+ "--guidance_scale",
740
+ type=float,
741
+ default=7.0,
742
+ help="引导尺度"
743
+ )
744
+ parser.add_argument(
745
+ "--height",
746
+ type=int,
747
+ default=1024,
748
+ help="生成图像高度"
749
+ )
750
+ parser.add_argument(
751
+ "--width",
752
+ type=int,
753
+ default=1024,
754
+ help="生成图像宽度"
755
+ )
756
+ parser.add_argument(
757
+ "--negative_prompt",
758
+ type=str,
759
+ default="",
760
+ help="负面提示词"
761
+ )
762
+
763
+ # 批处理和数据集参数
764
+ parser.add_argument(
765
+ "--per_proc_batch_size",
766
+ type=int,
767
+ default=1,
768
+ help="每个进程的批处理大小"
769
+ )
770
+ parser.add_argument(
771
+ "--sample_dir",
772
+ type=str,
773
+ default="sd3_lora_samples",
774
+ help="样本保存目录"
775
+ )
776
+
777
+ # Caption相关参数
778
+ parser.add_argument(
779
+ "--captions_jsonl",
780
+ type=str,
781
+ required=True,
782
+ help="包含caption列表的JSONL文件路径"
783
+ )
784
+ parser.add_argument(
785
+ "--images_per_caption",
786
+ type=int,
787
+ default=1,
788
+ help="每个caption生成的图像数量"
789
+ )
790
+ parser.add_argument(
791
+ "--max_samples",
792
+ type=int,
793
+ default=30000,
794
+ help="最大样本生成数量"
795
+ )
796
+
797
+ # 其他参数
798
+ parser.add_argument(
799
+ "--global_seed",
800
+ type=int,
801
+ default=42,
802
+ help="全局随机种子"
803
+ )
804
+ parser.add_argument(
805
+ "--mixed_precision",
806
+ type=str,
807
+ default="fp16",
808
+ choices=["no", "fp16", "bf16"],
809
+ help="混合精度类型"
810
+ )
811
+ parser.add_argument(
812
+ "--enable_cpu_offload",
813
+ action="store_true",
814
+ help="启用CPU offload以节省显存"
815
+ )
816
+
817
+ args = parser.parse_args()
818
+ main(args)
sample_sd3_lora_ddp.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ SD3 LoRA分布式采样脚本
5
+ 使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估
6
+ """
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tqdm import tqdm
11
+ import os
12
+ from PIL import Image
13
+ import numpy as np
14
+ import math
15
+ import argparse
16
+ import sys
17
+ import json
18
+ import random
19
+ from pathlib import Path
20
+
21
+ from diffusers import (
22
+ StableDiffusion3Pipeline,
23
+ AutoencoderKL,
24
+ FlowMatchEulerDiscreteScheduler,
25
+ SD3Transformer2DModel,
26
+ )
27
+ from transformers import CLIPTokenizer, T5TokenizerFast
28
+ from accelerate import Accelerator
29
+ from peft import LoraConfig
30
+ from peft.utils import get_peft_model_state_dict
31
+
32
+
33
+ def create_npz_from_sample_folder(sample_dir, num_samples):
34
+ """
35
+ 从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
36
+ """
37
+ samples = []
38
+ actual_files = []
39
+
40
+ # 收集所有PNG文件
41
+ for filename in sorted(os.listdir(sample_dir)):
42
+ if filename.endswith('.png'):
43
+ actual_files.append(filename)
44
+
45
+ # 按照数量限制处理
46
+ for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
47
+ if i < len(actual_files):
48
+ sample_path = os.path.join(sample_dir, actual_files[i])
49
+ sample_pil = Image.open(sample_path)
50
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
51
+ samples.append(sample_np)
52
+ else:
53
+ # 如果不够,创建空白图像
54
+ sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
55
+ samples.append(sample_np)
56
+
57
+ if samples:
58
+ samples = np.stack(samples)
59
+ npz_path = f"{sample_dir}.npz"
60
+ np.savez(npz_path, arr_0=samples)
61
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
62
+ return npz_path
63
+ else:
64
+ print("No samples found to create npz file.")
65
+ return None
66
+
67
+
68
+ def find_latest_checkpoint(output_dir):
69
+ """
70
+ 查找最新的检查点目录
71
+ """
72
+ checkpoint_dirs = []
73
+ if os.path.exists(output_dir):
74
+ for item in os.listdir(output_dir):
75
+ if item.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, item)):
76
+ try:
77
+ step = int(item.split("-")[1])
78
+ checkpoint_dirs.append((step, item))
79
+ except (ValueError, IndexError):
80
+ continue
81
+
82
+ if checkpoint_dirs:
83
+ # 按步数排序,返回最新的
84
+ checkpoint_dirs.sort(key=lambda x: x[0])
85
+ latest_step, latest_dir = checkpoint_dirs[-1]
86
+ latest_path = os.path.join(output_dir, latest_dir)
87
+ return latest_path, latest_step
88
+ return None, None
89
+
90
+
91
+ def check_lora_weights_exist(lora_path):
92
+ """
93
+ 检查LoRA权重文件是否存在
94
+ """
95
+ if not lora_path:
96
+ return False
97
+
98
+ # 检查是否是目录
99
+ if os.path.isdir(lora_path):
100
+ # 检查目录中是否有pytorch_lora_weights.safetensors文件
101
+ weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
102
+ if os.path.exists(weight_file):
103
+ return True
104
+ # 检查是否有其他.safetensors文件
105
+ for file in os.listdir(lora_path):
106
+ if file.endswith(".safetensors") and "lora" in file.lower():
107
+ return True
108
+ return False
109
+
110
+ # 检查是否是文件
111
+ elif os.path.isfile(lora_path):
112
+ return lora_path.endswith(".safetensors")
113
+
114
+ return False
115
+
116
+
117
+ def check_full_finetune_checkpoint(checkpoint_path):
118
+ """
119
+ 检查是否是全量微调的checkpoint(包含model.safetensors)
120
+ """
121
+ if not checkpoint_path or not os.path.isdir(checkpoint_path):
122
+ return False
123
+
124
+ # 检查是否有model.safetensors文件(全量微调的标志)
125
+ model_file = os.path.join(checkpoint_path, "model.safetensors")
126
+ return os.path.exists(model_file)
127
+
128
+
129
+ def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0):
130
+ """
131
+ 从检查点加载LoRA权重
132
+ """
133
+ if rank == 0:
134
+ print(f"Loading LoRA weights from checkpoint: {checkpoint_path}")
135
+
136
+ # 直接从检查点目录加载state dict
137
+ try:
138
+ # 使用accelerator来加载检查点
139
+ accelerator = Accelerator()
140
+
141
+ # 先配置LoRA
142
+ transformer_lora_config = LoraConfig(
143
+ r=64, # 假设使用rank=64,可以根据需要调整
144
+ lora_alpha=64,
145
+ init_lora_weights="gaussian",
146
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
147
+ )
148
+
149
+ # 为transformer添加LoRA
150
+ pipeline.transformer.add_adapter(transformer_lora_config)
151
+
152
+ # 加载检查点状态
153
+ accelerator.load_state(checkpoint_path)
154
+
155
+ if rank == 0:
156
+ print(f"Successfully loaded LoRA weights from checkpoint {checkpoint_path}")
157
+
158
+ return True
159
+
160
+ except Exception as e:
161
+ if rank == 0:
162
+ print(f"Error loading LoRA from checkpoint {checkpoint_path}: {e}")
163
+ print("Falling back to baseline model without LoRA")
164
+ return False
165
+
166
+
167
+ def load_captions_from_jsonl(jsonl_path):
168
+ """
169
+ 从JSONL文件加载caption列表
170
+ """
171
+ captions = []
172
+ try:
173
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
174
+ for line_num, line in enumerate(f, 1):
175
+ line = line.strip()
176
+ if not line:
177
+ continue
178
+
179
+ try:
180
+ data = json.loads(line)
181
+ # 支持多种字段名
182
+ caption = None
183
+ for field in ['caption', 'text', 'prompt', 'description']:
184
+ if field in data and isinstance(data[field], str):
185
+ caption = data[field].strip()
186
+ break
187
+
188
+ if caption:
189
+ captions.append(caption)
190
+ else:
191
+ # 如果没有找到标准字段,取第一个字符串值
192
+ for value in data.values():
193
+ if isinstance(value, str) and value.strip():
194
+ captions.append(value.strip())
195
+ break
196
+
197
+ except json.JSONDecodeError as e:
198
+ print(f"Warning: Invalid JSON on line {line_num}: {e}")
199
+ continue
200
+
201
+ except FileNotFoundError:
202
+ print(f"Error: JSONL file {jsonl_path} not found")
203
+ return []
204
+ except Exception as e:
205
+ print(f"Error loading JSONL file {jsonl_path}: {e}")
206
+ return []
207
+
208
+ print(f"Loaded {len(captions)} captions from {jsonl_path}")
209
+ return captions
210
+
211
+
212
+ def main(args):
213
+ """
214
+ 运行 SD3 LoRA 采样
215
+ """
216
+ assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
217
+ torch.set_grad_enabled(False)
218
+
219
+ # 设置 DDP
220
+ dist.init_process_group("nccl")
221
+ rank = dist.get_rank()
222
+ device = rank % torch.cuda.device_count()
223
+ seed = args.global_seed * dist.get_world_size() + rank
224
+ torch.manual_seed(seed)
225
+ torch.cuda.set_device(device)
226
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
227
+
228
+ # 加载captions
229
+ captions = []
230
+ if args.captions_jsonl:
231
+ if rank == 0:
232
+ print(f"Loading captions from {args.captions_jsonl}")
233
+ captions = load_captions_from_jsonl(args.captions_jsonl)
234
+ if not captions:
235
+ if rank == 0:
236
+ print("Warning: No captions loaded, using default caption")
237
+ captions = ["a beautiful high quality image"]
238
+ else:
239
+ # 使用默认caption
240
+ captions = ["a beautiful high quality image"]
241
+
242
+ # 计算总的图片数量
243
+ total_images_needed = len(captions) * args.images_per_caption
244
+ # 应用最大样本数限制
245
+ total_images_needed = min(total_images_needed, args.max_samples)
246
+ if rank == 0:
247
+ print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
248
+ print(f"Total images requested: {len(captions) * args.images_per_caption}")
249
+ print(f"Max samples limit: {args.max_samples}")
250
+ print(f"Total images to generate: {total_images_needed}")
251
+
252
+ # 设置数据类型 - 使用混合精度以减少内存占用
253
+ if args.mixed_precision == "fp16":
254
+ dtype = torch.float16
255
+ elif args.mixed_precision == "bf16":
256
+ dtype = torch.bfloat16
257
+ else:
258
+ dtype = torch.float32
259
+
260
+ # 检查是否是全量微调的checkpoint
261
+ is_full_finetune = False
262
+ if args.lora_path and check_full_finetune_checkpoint(args.lora_path):
263
+ # 全量微调:直接从checkpoint加载
264
+ if rank == 0:
265
+ print(f"Detected full fine-tuning checkpoint, loading from: {args.lora_path}")
266
+ try:
267
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
268
+ args.lora_path,
269
+ revision=args.revision,
270
+ variant=args.variant,
271
+ torch_dtype=dtype,
272
+ )
273
+ is_full_finetune = True
274
+ lora_source = os.path.basename(args.lora_path.rstrip('/'))
275
+ if rank == 0:
276
+ print("Successfully loaded full fine-tuned model from checkpoint")
277
+ except Exception as e:
278
+ if rank == 0:
279
+ print(f"Failed to load full fine-tuned model: {e}")
280
+ print("Falling back to baseline model + LoRA loading")
281
+ is_full_finetune = False
282
+
283
+ # 如果不是全量微调,加载基础模型
284
+ if not is_full_finetune:
285
+ if rank == 0:
286
+ print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")
287
+
288
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
289
+ args.pretrained_model_name_or_path,
290
+ revision=args.revision,
291
+ variant=args.variant,
292
+ torch_dtype=dtype,
293
+ )
294
+
295
+ # 检查和加载 LoRA 权重(仅当不是全量微调时)
296
+ lora_loaded = False
297
+ lora_source = "baseline" if not is_full_finetune else lora_source
298
+
299
+ if not is_full_finetune and args.lora_path:
300
+ # 检查指定的LoRA路径是否存在权重文件
301
+ if check_lora_weights_exist(args.lora_path):
302
+ if rank == 0:
303
+ print(f"Loading LoRA weights from specified path: {args.lora_path}")
304
+ try:
305
+ pipeline.load_lora_weights(args.lora_path)
306
+ lora_loaded = True
307
+ lora_source = os.path.basename(args.lora_path.rstrip('/'))
308
+ if rank == 0:
309
+ print("Successfully loaded LoRA weights from specified path")
310
+ except Exception as e:
311
+ if rank == 0:
312
+ print(f"Failed to load LoRA from specified path: {e}")
313
+ else:
314
+ if rank == 0:
315
+ print(f"No LoRA weights found at specified path: {args.lora_path}")
316
+
317
+ # 如果没有成功加载LoRA权重,尝试从当前目录或检查点加载(仅当不是全量微调时)
318
+ if not is_full_finetune and not lora_loaded:
319
+ # 首先检查当前工作目录是否有权重文件
320
+ current_dir = os.getcwd()
321
+ if check_lora_weights_exist(current_dir):
322
+ if rank == 0:
323
+ print(f"Found LoRA weights in current directory: {current_dir}")
324
+ try:
325
+ pipeline.load_lora_weights(current_dir)
326
+ lora_loaded = True
327
+ lora_source = "current_dir"
328
+ if rank == 0:
329
+ print("Successfully loaded LoRA weights from current directory")
330
+ except Exception as e:
331
+ if rank == 0:
332
+ print(f"Failed to load LoRA from current directory: {e}")
333
+
334
+ # 如果当前目录也没有,检查是否有检查点目录
335
+ if not lora_loaded:
336
+ # 检查常见的输出目录
337
+ possible_output_dirs = [
338
+ "sd3-lora-finetuned",
339
+ "sd3-lora-finetuned-last",
340
+ "output",
341
+ "checkpoints"
342
+ ]
343
+
344
+ checkpoint_found = False
345
+ for output_dir in possible_output_dirs:
346
+ if os.path.exists(output_dir):
347
+ # 首先检查输出目录是否直接包含权重文件
348
+ if check_lora_weights_exist(output_dir):
349
+ if rank == 0:
350
+ print(f"Found LoRA weights in output directory: {output_dir}")
351
+ try:
352
+ pipeline.load_lora_weights(output_dir)
353
+ lora_loaded = True
354
+ lora_source = output_dir
355
+ if rank == 0:
356
+ print(f"Successfully loaded LoRA weights from {output_dir}")
357
+ break
358
+ except Exception as e:
359
+ if rank == 0:
360
+ print(f"Failed to load LoRA from {output_dir}: {e}")
361
+
362
+ # 如果输出目录没有直接的权重文件,查找最新的检查点
363
+ if not lora_loaded:
364
+ latest_checkpoint, latest_step = find_latest_checkpoint(output_dir)
365
+ if latest_checkpoint:
366
+ if rank == 0:
367
+ print(f"Found latest checkpoint: {latest_checkpoint} (step {latest_step})")
368
+
369
+ # 尝试从检查点加载LoRA权重
370
+ if load_lora_from_checkpoint(pipeline, latest_checkpoint, rank):
371
+ lora_loaded = True
372
+ lora_source = f"checkpoint-{latest_step}"
373
+ checkpoint_found = True
374
+ break
375
+
376
+ if not checkpoint_found and not lora_loaded:
377
+ if rank == 0:
378
+ print("No LoRA weights or checkpoints found. Using baseline model.")
379
+
380
+ # 启用内存优化选项(必须在移动到设备之前)
381
+ if args.enable_cpu_offload:
382
+ if rank == 0:
383
+ print("Enabling CPU offload to save memory")
384
+ # CPU offload 会自动管理设备,不需要先 to(device)
385
+ pipeline.enable_model_cpu_offload()
386
+ else:
387
+ # 如果不使用 CPU offload,先移动到设备,然后启用其他优化
388
+ pipeline = pipeline.to(device)
389
+ if rank == 0:
390
+ print("Enabling memory optimization options")
391
+
392
+ # 检查并启用可用的内存优化方法
393
+ # 注意:所有进程都需要执行这些操作,不仅仅是 rank 0
394
+ if hasattr(pipeline, 'enable_attention_slicing'):
395
+ try:
396
+ pipeline.enable_attention_slicing()
397
+ if rank == 0:
398
+ print(" - Attention slicing enabled")
399
+ except Exception as e:
400
+ if rank == 0:
401
+ print(f" - Warning: Failed to enable attention slicing: {e}")
402
+ else:
403
+ if rank == 0:
404
+ print(" - Attention slicing not available for this pipeline")
405
+
406
+ # SD3 pipeline 可能不支持 enable_vae_slicing,需要检查
407
+ # 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
408
+ enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
409
+ if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
410
+ try:
411
+ enable_vae_slicing_method()
412
+ if rank == 0:
413
+ print(" - VAE slicing enabled")
414
+ except Exception as e:
415
+ if rank == 0:
416
+ print(f" - Warning: Failed to enable VAE slicing: {e}")
417
+ else:
418
+ if rank == 0:
419
+ print(" - VAE slicing not available for this pipeline (SD3 may not support this)")
420
+
421
+ # 禁用进度条
422
+ pipeline.set_progress_bar_config(disable=True)
423
+
424
+ # 创建保存目录
425
+ folder_name = f"batch32-rank64-last-sd3-{lora_source}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
426
+ sample_folder_dir = os.path.join(args.sample_dir, folder_name)
427
+
428
+ if rank == 0:
429
+ os.makedirs(sample_folder_dir, exist_ok=True)
430
+ print(f"Saving .png samples at {sample_folder_dir}")
431
+ # 清空caption文件
432
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
433
+ if os.path.exists(caption_file):
434
+ os.remove(caption_file)
435
+ dist.barrier()
436
+
437
+ # 计算采样参数
438
+ n = args.per_proc_batch_size
439
+ global_batch_size = n * dist.get_world_size()
440
+
441
+ # 检查已存在的样本数量
442
+ existing_samples = 0
443
+ if os.path.exists(sample_folder_dir):
444
+ existing_samples = len([
445
+ name for name in os.listdir(sample_folder_dir)
446
+ if os.path.isfile(os.path.join(sample_folder_dir, name)) and name.endswith(".png")
447
+ ])
448
+
449
+ total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
450
+ if rank == 0:
451
+ print(f"Total number of images that will be sampled: {total_samples}")
452
+ print(f"Existing samples: {existing_samples}")
453
+
454
+ assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
455
+ samples_needed_this_gpu = int(total_samples // dist.get_world_size())
456
+ assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
457
+
458
+ iterations = int(samples_needed_this_gpu // n)
459
+ done_iterations = int(int(existing_samples // dist.get_world_size()) // n)
460
+
461
+ pbar = range(done_iterations, iterations)
462
+ pbar = tqdm(pbar) if rank == 0 else pbar
463
+
464
+ # 生成caption和image的映射列表
465
+ caption_image_pairs = []
466
+ for i, caption in enumerate(captions):
467
+ for j in range(args.images_per_caption):
468
+ caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
469
+
470
+ total_generated = existing_samples
471
+
472
+ # 采样循环
473
+ for i in pbar:
474
+ # 获取这个batch对应的caption
475
+ batch_prompts = []
476
+ batch_caption_info = []
477
+
478
+ for j in range(n):
479
+ global_index = i * global_batch_size + j * dist.get_world_size() + rank
480
+ if global_index < len(caption_image_pairs):
481
+ caption, caption_idx, image_idx = caption_image_pairs[global_index]
482
+ batch_prompts.append(caption)
483
+ batch_caption_info.append((caption, caption_idx, image_idx))
484
+ else:
485
+ # 如果超出范围,使用最后一个caption
486
+ if caption_image_pairs:
487
+ caption, caption_idx, image_idx = caption_image_pairs[-1]
488
+ batch_prompts.append(caption)
489
+ batch_caption_info.append((caption, caption_idx, image_idx))
490
+ else:
491
+ batch_prompts.append("a beautiful high quality image")
492
+ batch_caption_info.append(("a beautiful high quality image", 0, 0))
493
+
494
+ # 生成图像 - 为每个图像使用不同的随机种子
495
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
496
+ with torch.autocast(device_str, dtype=dtype):
497
+ # 为每个prompt生成独立的图像(使用不同的generator)
498
+ images = []
499
+ for k, prompt in enumerate(batch_prompts):
500
+ # 为每个图像创建独立的随机种子
501
+ image_seed = seed + i * 10000 + k * 1000 + rank
502
+ generator = torch.Generator(device=device).manual_seed(image_seed)
503
+
504
+ image = pipeline(
505
+ prompt=prompt,
506
+ negative_prompt=args.negative_prompt if args.negative_prompt else None,
507
+ height=args.height,
508
+ width=args.width,
509
+ num_inference_steps=args.num_inference_steps,
510
+ guidance_scale=args.guidance_scale,
511
+ generator=generator,
512
+ num_images_per_prompt=1,
513
+ ).images[0]
514
+ images.append(image)
515
+
516
+ # 清理 GPU 缓存以释放内存
517
+ if k == len(batch_prompts) - 1: # 每个 batch 的最后一张图片后清理
518
+ torch.cuda.empty_cache()
519
+
520
+ # 保存图像
521
+ for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
522
+ global_index = i * global_batch_size + j * dist.get_world_size() + rank
523
+ if global_index < len(caption_image_pairs):
524
+ # 保存图片,文件名包含caption索引和图片索引
525
+ filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
526
+ image_path = os.path.join(sample_folder_dir, filename)
527
+ image.save(image_path)
528
+
529
+ # 保存caption信息到文本文件(只在rank 0上操作)
530
+ if rank == 0:
531
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
532
+ with open(caption_file, "a", encoding="utf-8") as f:
533
+ f.write(f"{filename}\t{caption}\n")
534
+
535
+ total_generated += global_batch_size
536
+
537
+ # 每个迭代后清理 GPU 缓存
538
+ torch.cuda.empty_cache()
539
+
540
+ dist.barrier()
541
+
542
+ # 确保所有进程都完成采样
543
+ dist.barrier()
544
+
545
+ # 创建npz文件
546
+ if rank == 0:
547
+ # 重新计算实际生成的图片数量
548
+ actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
549
+ print(f"Actually generated {actual_num_samples} images")
550
+ # 使用实际的图片数量或用户指定的数量,取较小值
551
+ npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
552
+ create_npz_from_sample_folder(sample_folder_dir, npz_samples)
553
+ print("Done.")
554
+
555
+ dist.barrier()
556
+ dist.destroy_process_group()
557
+
558
+
559
+ if __name__ == "__main__":
560
+ parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本")
561
+
562
+ # 模型和路径参数
563
+ parser.add_argument(
564
+ "--pretrained_model_name_or_path",
565
+ type=str,
566
+ default="stabilityai/stable-diffusion-3-medium-diffusers",
567
+ help="预训练模型路径或HuggingFace模型ID"
568
+ )
569
+ parser.add_argument(
570
+ "--lora_path",
571
+ type=str,
572
+ default=None,
573
+ help="LoRA权重文件路径"
574
+ )
575
+ parser.add_argument(
576
+ "--revision",
577
+ type=str,
578
+ default=None,
579
+ help="模型修订版本"
580
+ )
581
+ parser.add_argument(
582
+ "--variant",
583
+ type=str,
584
+ default=None,
585
+ help="模型变体,如fp16"
586
+ )
587
+
588
+ # 采样参数
589
+ parser.add_argument(
590
+ "--num_inference_steps",
591
+ type=int,
592
+ default=28,
593
+ help="推理步数"
594
+ )
595
+ parser.add_argument(
596
+ "--guidance_scale",
597
+ type=float,
598
+ default=7.0,
599
+ help="引导尺度"
600
+ )
601
+ parser.add_argument(
602
+ "--height",
603
+ type=int,
604
+ default=1024,
605
+ help="生成图像高度"
606
+ )
607
+ parser.add_argument(
608
+ "--width",
609
+ type=int,
610
+ default=1024,
611
+ help="生成图像宽度"
612
+ )
613
+ parser.add_argument(
614
+ "--negative_prompt",
615
+ type=str,
616
+ default="",
617
+ help="负面提示词"
618
+ )
619
+
620
+ # 批处理和数据集参数
621
+ parser.add_argument(
622
+ "--per_proc_batch_size",
623
+ type=int,
624
+ default=1,
625
+ help="每个进程的批处理大小"
626
+ )
627
+ parser.add_argument(
628
+ "--sample_dir",
629
+ type=str,
630
+ default="sd3_lora_samples",
631
+ help="样本保存目录"
632
+ )
633
+
634
+ # Caption相关参数
635
+ parser.add_argument(
636
+ "--captions_jsonl",
637
+ type=str,
638
+ required=True,
639
+ help="包含caption列表的JSONL文件路径"
640
+ )
641
+ parser.add_argument(
642
+ "--images_per_caption",
643
+ type=int,
644
+ default=1,
645
+ help="每个caption生成的图像数量"
646
+ )
647
+ parser.add_argument(
648
+ "--max_samples",
649
+ type=int,
650
+ default=30000,
651
+ help="最大样本生成数量"
652
+ )
653
+
654
+ # 其他参数
655
+ parser.add_argument(
656
+ "--global_seed",
657
+ type=int,
658
+ default=42,
659
+ help="全局随机种子"
660
+ )
661
+ parser.add_argument(
662
+ "--mixed_precision",
663
+ type=str,
664
+ default="fp16",
665
+ choices=["no", "fp16", "bf16"],
666
+ help="混合精度类型"
667
+ )
668
+ parser.add_argument(
669
+ "--enable_cpu_offload",
670
+ action="store_true",
671
+ help="启用CPU offload以节省显存"
672
+ )
673
+
674
+ args = parser.parse_args()
675
+ main(args)
sample_sd3_lora_rn_pair_ddp.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ DDP对照采样:同一文本+同一初始噪声,分别生成 LoRA 与 RN 两类图像,并输出 pair 拼接图与 metadata。
5
+ """
6
+
7
+ import argparse
8
+ import importlib.util
9
+ import json
10
+ import math
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ from PIL import Image
18
+ from tqdm import tqdm
19
+
20
+ from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline
21
+
22
+
23
+ def dynamic_import_training_classes(project_root: str):
24
+ sys.path.insert(0, project_root)
25
+ import train_rectified_noise as trn
26
+
27
+ return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
28
+
29
+
30
+ def load_local_pipeline_class(local_pipeline_path: str):
31
+ """
32
+ 从本地文件加载 StableDiffusion3Pipeline。
33
+ 通过将模块名挂在 diffusers.pipelines.stable_diffusion_3 下,兼容文件内的相对导入。
34
+ """
35
+ module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
36
+ spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
37
+ if spec is None or spec.loader is None:
38
+ raise ImportError(f"Failed to build import spec from: {local_pipeline_path}")
39
+ module = importlib.util.module_from_spec(spec)
40
+ spec.loader.exec_module(module)
41
+ if not hasattr(module, "StableDiffusion3Pipeline"):
42
+ raise ImportError("Local pipeline file has no StableDiffusion3Pipeline symbol.")
43
+ return module.StableDiffusion3Pipeline
44
+
45
+
46
+ def load_captions_from_jsonl(jsonl_path):
47
+ captions = []
48
+ with open(jsonl_path, "r", encoding="utf-8") as f:
49
+ for line in f:
50
+ line = line.strip()
51
+ if not line:
52
+ continue
53
+ try:
54
+ data = json.loads(line)
55
+ cap = None
56
+ for field in ["caption", "text", "prompt", "description"]:
57
+ if field in data and isinstance(data[field], str):
58
+ cap = data[field].strip()
59
+ break
60
+ if cap:
61
+ captions.append(cap)
62
+ except Exception:
63
+ continue
64
+ return captions if captions else ["a beautiful high quality image"]
65
+
66
+
67
+ def load_sit_weights(rectified_module, weights_path: str):
68
+ if os.path.isdir(weights_path):
69
+ search_dirs = [weights_path, os.path.join(weights_path, "sit_weights")]
70
+ for d in search_dirs:
71
+ if not os.path.exists(d):
72
+ continue
73
+ st = os.path.join(d, "pytorch_sit_weights.safetensors")
74
+ if os.path.exists(st):
75
+ from safetensors.torch import load_file
76
+
77
+ state = load_file(st)
78
+ rectified_module.load_state_dict(state, strict=False)
79
+ return True
80
+ for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
81
+ cand = os.path.join(d, name)
82
+ if os.path.exists(cand):
83
+ state = torch.load(cand, map_location="cpu")
84
+ rectified_module.load_state_dict(state, strict=False)
85
+ return True
86
+ return False
87
+ else:
88
+ if weights_path.endswith(".safetensors"):
89
+ from safetensors.torch import load_file
90
+
91
+ state = load_file(weights_path)
92
+ else:
93
+ state = torch.load(weights_path, map_location="cpu")
94
+ rectified_module.load_state_dict(state, strict=False)
95
+ return True
96
+
97
+
98
+ def save_jsonl_line(path, obj):
99
+ with open(path, "a", encoding="utf-8") as f:
100
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
101
+
102
+
103
+ def load_jsonl(path):
104
+ if not os.path.exists(path):
105
+ return []
106
+ rows = []
107
+ with open(path, "r", encoding="utf-8") as f:
108
+ for line in f:
109
+ line = line.strip()
110
+ if not line:
111
+ continue
112
+ rows.append(json.loads(line))
113
+ return rows
114
+
115
+
116
+ def merge_rank_metadata(out_path, rank_paths):
117
+ rows = []
118
+ for rp in rank_paths:
119
+ rows.extend(load_jsonl(rp))
120
+ rows.sort(key=lambda x: x.get("file_name", ""))
121
+ with open(out_path, "w", encoding="utf-8") as f:
122
+ for r in rows:
123
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
124
+
125
+
126
+ def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
127
+ RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
128
+ tfm = base_pipeline.transformer
129
+
130
+ if hasattr(tfm.config, "joint_attention_dim") and tfm.config.joint_attention_dim is not None:
131
+ sit_hidden_size = tfm.config.joint_attention_dim
132
+ elif hasattr(tfm.config, "inner_dim") and tfm.config.inner_dim is not None:
133
+ sit_hidden_size = tfm.config.inner_dim
134
+ else:
135
+ sit_hidden_size = 4096
136
+
137
+ transformer_hidden_size = getattr(tfm.config, "hidden_size", 1536)
138
+ num_attention_heads = getattr(tfm.config, "num_attention_heads", 32)
139
+ input_dim = getattr(tfm.config, "in_channels", 16)
140
+
141
+ rectified_module = RectifiedNoiseModule(
142
+ hidden_size=sit_hidden_size,
143
+ num_sit_layers=num_sit_layers,
144
+ num_attention_heads=num_attention_heads,
145
+ input_dim=input_dim,
146
+ transformer_hidden_size=transformer_hidden_size,
147
+ )
148
+ ok = load_sit_weights(rectified_module, rectified_weights)
149
+ if not ok:
150
+ raise RuntimeError(f"Failed to load rectified weights from: {rectified_weights}")
151
+
152
+ model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
153
+ model.eval()
154
+ return model
155
+
156
+
157
+ def create_npz_from_dir(sample_dir, max_samples):
158
+ import numpy as np
159
+
160
+ files = sorted([x for x in os.listdir(sample_dir) if x.endswith(".png") and x[:-4].isdigit()])
161
+ files = files[:max_samples]
162
+ if not files:
163
+ return None
164
+ arr = []
165
+ for fn in tqdm(files, desc=f"npz:{os.path.basename(sample_dir)}"):
166
+ arr.append(np.asarray(Image.open(os.path.join(sample_dir, fn))).astype(np.uint8))
167
+ arr = np.stack(arr)
168
+ out = f"{sample_dir}.npz"
169
+ np.savez(out, arr_0=arr)
170
+ return out
171
+
172
+
173
+ def set_pipeline_modules_eval(pipe):
174
+ """
175
+ Diffusers pipeline 本身没有 .eval(),需要对内部 nn.Module 分别设为 eval。
176
+ """
177
+ for name in ["transformer", "vae", "text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "model"]:
178
+ module = getattr(pipe, name, None)
179
+ if module is not None and hasattr(module, "eval"):
180
+ module.eval()
181
+
182
+
183
+ def main(args):
184
+ assert torch.cuda.is_available(), "Need GPU"
185
+ dist.init_process_group("nccl")
186
+ rank = dist.get_rank()
187
+ world = dist.get_world_size()
188
+ device = rank % torch.cuda.device_count()
189
+ torch.cuda.set_device(device)
190
+ seed = args.global_seed * world + rank
191
+ torch.manual_seed(seed)
192
+
193
+ dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
194
+
195
+ root = Path(args.sample_dir)
196
+ lora_dir = root / "lora"
197
+ rn_dir = root / "rn"
198
+ pair_dir = root / "pair"
199
+ metadata_path = root / "metadata.jsonl"
200
+ lora_meta = lora_dir / "metadata.jsonl"
201
+ rn_meta = rn_dir / "metadata.jsonl"
202
+ pair_meta = pair_dir / "metadata.jsonl"
203
+ if rank == 0:
204
+ lora_dir.mkdir(parents=True, exist_ok=True)
205
+ rn_dir.mkdir(parents=True, exist_ok=True)
206
+ pair_dir.mkdir(parents=True, exist_ok=True)
207
+ dist.barrier()
208
+
209
+ if args.stage == "lora":
210
+ pipe_lora = DiffusersStableDiffusion3Pipeline.from_pretrained(
211
+ args.pretrained_model_name_or_path,
212
+ revision=args.revision,
213
+ variant=args.variant,
214
+ torch_dtype=dtype,
215
+ ).to(device)
216
+ if args.lora_path:
217
+ pipe_lora.load_lora_weights(args.lora_path)
218
+ pipe_lora.set_progress_bar_config(disable=True)
219
+ set_pipeline_modules_eval(pipe_lora)
220
+
221
+ captions = load_captions_from_jsonl(args.captions_jsonl)
222
+ total_needed = min(len(captions) * args.images_per_caption, args.max_samples)
223
+ n = args.per_proc_batch_size
224
+ global_batch = n * world
225
+ total_samples = int(math.ceil(total_needed / global_batch) * global_batch)
226
+ iters = total_samples // global_batch
227
+ pbar = tqdm(range(iters)) if rank == 0 else range(iters)
228
+
229
+ rank_meta_path = root / f"metadata.rank{rank}.jsonl"
230
+ if rank_meta_path.exists():
231
+ rank_meta_path.unlink()
232
+ rank_lora_meta_path = lora_dir / f"metadata.rank{rank}.jsonl"
233
+ if rank_lora_meta_path.exists():
234
+ rank_lora_meta_path.unlink()
235
+
236
+ for it in pbar:
237
+ for k in range(n):
238
+ global_idx = it * global_batch + k * world + rank
239
+ if global_idx >= total_needed:
240
+ continue
241
+ cap_idx = global_idx // args.images_per_caption
242
+ prompt = captions[cap_idx]
243
+ image_seed = seed + it * 10000 + k * 1000
244
+
245
+ g = torch.Generator(device=device).manual_seed(image_seed)
246
+ latent_h = args.height // pipe_lora.vae_scale_factor
247
+ latent_w = args.width // pipe_lora.vae_scale_factor
248
+ latents = torch.randn(
249
+ (1, pipe_lora.transformer.config.in_channels, latent_h, latent_w),
250
+ device=device,
251
+ dtype=dtype,
252
+ generator=g,
253
+ )
254
+
255
+ with torch.autocast("cuda", dtype=dtype):
256
+ img_lora = pipe_lora(
257
+ prompt=prompt,
258
+ height=args.height,
259
+ width=args.width,
260
+ num_inference_steps=args.num_inference_steps,
261
+ guidance_scale=args.guidance_scale,
262
+ latents=latents,
263
+ num_images_per_prompt=1,
264
+ ).images[0]
265
+
266
+ fn = f"{global_idx:07d}.png"
267
+ img_lora.save(lora_dir / fn)
268
+ save_jsonl_line(str(rank_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed), "lora_file": f"lora/{fn}"})
269
+ save_jsonl_line(str(rank_lora_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed)})
270
+ dist.barrier()
271
+
272
+ dist.barrier()
273
+ if rank == 0:
274
+ merge_rank_metadata(str(metadata_path), [str(root / f"metadata.rank{r}.jsonl") for r in range(world)])
275
+ merge_rank_metadata(str(lora_meta), [str(lora_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
276
+ records = load_jsonl(str(metadata_path))
277
+ create_npz_from_dir(str(lora_dir), len(records))
278
+
279
+ elif args.stage == "rn":
280
+ records = load_jsonl(str(metadata_path))
281
+ if not records:
282
+ raise RuntimeError(f"metadata not found or empty: {metadata_path}. Run --stage lora first.")
283
+ total_needed = min(len(records), args.max_samples)
284
+
285
+ LocalStableDiffusion3Pipeline = load_local_pipeline_class(args.local_pipeline_path)
286
+ pipe_rn = LocalStableDiffusion3Pipeline.from_pretrained(
287
+ args.pretrained_model_name_or_path,
288
+ revision=args.revision,
289
+ variant=args.variant,
290
+ torch_dtype=dtype,
291
+ ).to(device)
292
+ if args.lora_path:
293
+ pipe_rn.load_lora_weights(args.lora_path)
294
+ pipe_rn.model = build_rn_model(pipe_rn, args.rectified_weights, args.num_sit_layers, device)
295
+ pipe_rn.set_progress_bar_config(disable=True)
296
+ set_pipeline_modules_eval(pipe_rn)
297
+
298
+ rank_rn_meta_path = rn_dir / f"metadata.rank{rank}.jsonl"
299
+ if rank_rn_meta_path.exists():
300
+ rank_rn_meta_path.unlink()
301
+
302
+ assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
303
+ pbar = tqdm(assigned) if rank == 0 else assigned
304
+ for rec in pbar:
305
+ fn = rec["file_name"]
306
+ prompt = rec["caption"]
307
+ image_seed = int(rec["seed"])
308
+ g = torch.Generator(device=device).manual_seed(image_seed)
309
+ latent_h = args.height // pipe_rn.vae_scale_factor
310
+ latent_w = args.width // pipe_rn.vae_scale_factor
311
+ latents = torch.randn(
312
+ (1, pipe_rn.transformer.config.in_channels, latent_h, latent_w),
313
+ device=device,
314
+ dtype=dtype,
315
+ generator=g,
316
+ )
317
+ with torch.autocast("cuda", dtype=dtype):
318
+ img_rn = pipe_rn(
319
+ prompt=prompt,
320
+ height=args.height,
321
+ width=args.width,
322
+ num_inference_steps=args.num_inference_steps,
323
+ guidance_scale=args.guidance_scale,
324
+ latents=latents,
325
+ num_images_per_prompt=1,
326
+ ).images[0]
327
+ img_rn.save(rn_dir / fn)
328
+ save_jsonl_line(str(rank_rn_meta_path), {"file_name": fn, "caption": prompt, "seed": image_seed})
329
+ dist.barrier()
330
+ if rank == 0:
331
+ merge_rank_metadata(str(rn_meta), [str(rn_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
332
+ create_npz_from_dir(str(rn_dir), total_needed)
333
+
334
+ elif args.stage == "pair":
335
+ records = load_jsonl(str(metadata_path))
336
+ if not records:
337
+ raise RuntimeError(f"metadata not found: {metadata_path}")
338
+ total_needed = min(len(records), args.max_samples)
339
+ rank_pair_meta_path = pair_dir / f"metadata.rank{rank}.jsonl"
340
+ if rank_pair_meta_path.exists():
341
+ rank_pair_meta_path.unlink()
342
+ assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
343
+ for rec in assigned:
344
+ fn = rec["file_name"]
345
+ lora_img_path = lora_dir / fn
346
+ rn_img_path = rn_dir / fn
347
+ if not lora_img_path.exists() or not rn_img_path.exists():
348
+ continue
349
+ img_lora = Image.open(lora_img_path).convert("RGB")
350
+ img_rn = Image.open(rn_img_path).convert("RGB")
351
+ pair = Image.new("RGB", (img_lora.width + img_rn.width, max(img_lora.height, img_rn.height)))
352
+ pair.paste(img_lora, (0, 0))
353
+ pair.paste(img_rn, (img_lora.width, 0))
354
+ pair.save(pair_dir / fn)
355
+ save_jsonl_line(
356
+ str(rank_pair_meta_path),
357
+ {"file_name": fn, "caption": rec["caption"], "seed": int(rec["seed"]), "pair_file": f"pair/{fn}"},
358
+ )
359
+ dist.barrier()
360
+ if rank == 0:
361
+ merge_rank_metadata(str(pair_meta), [str(pair_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
362
+ # 更新根 metadata,补齐 rn/pair 路径
363
+ rn_set = {r["file_name"] for r in load_jsonl(str(rn_meta))}
364
+ pair_set = {r["file_name"] for r in load_jsonl(str(pair_meta))}
365
+ merged = []
366
+ for r in records[:total_needed]:
367
+ fn = r["file_name"]
368
+ out = dict(r)
369
+ if fn in rn_set:
370
+ out["rn_file"] = f"rn/{fn}"
371
+ if fn in pair_set:
372
+ out["pair_file"] = f"pair/{fn}"
373
+ merged.append(out)
374
+ with open(metadata_path, "w", encoding="utf-8") as f:
375
+ for r in merged:
376
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
377
+
378
+ else:
379
+ raise ValueError(f"Unknown stage: {args.stage}")
380
+
381
+ dist.barrier()
382
+ if rank == 0:
383
+ print(f"Stage {args.stage} done. Output root: {root}")
384
+ dist.barrier()
385
+ dist.destroy_process_group()
386
+
387
+
388
+ if __name__ == "__main__":
389
+ parser = argparse.ArgumentParser(description="DDP compare sampling: LoRA vs RN with same latent/prompt.")
390
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
391
+ parser.add_argument(
392
+ "--local_pipeline_path",
393
+ type=str,
394
+ default=str(Path(__file__).parent / "pipeline_stable_diffusion_3.py"),
395
+ help="RN 分支使用的本地 pipeline 文件路径",
396
+ )
397
+ parser.add_argument("--revision", type=str, default=None)
398
+ parser.add_argument("--variant", type=str, default=None)
399
+ parser.add_argument("--lora_path", type=str, default=None)
400
+ parser.add_argument("--rectified_weights", type=str, required=True)
401
+ parser.add_argument("--num_sit_layers", type=int, default=1)
402
+ parser.add_argument("--captions_jsonl", type=str, required=True)
403
+ parser.add_argument("--sample_dir", type=str, default="./sd3_lora_rn_compare")
404
+ parser.add_argument("--num_inference_steps", type=int, default=40)
405
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
406
+ parser.add_argument("--height", type=int, default=512)
407
+ parser.add_argument("--width", type=int, default=512)
408
+ parser.add_argument("--per_proc_batch_size", type=int, default=4)
409
+ parser.add_argument("--images_per_caption", type=int, default=1)
410
+ parser.add_argument("--max_samples", type=int, default=10000)
411
+ parser.add_argument("--global_seed", type=int, default=42)
412
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
413
+ parser.add_argument("--stage", type=str, default="lora", choices=["lora", "rn", "pair"])
414
+
415
+ args = parser.parse_args()
416
+ main(args)
417
+
sample_sd3_rectified_ddp.py ADDED
@@ -0,0 +1,1316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ 分布式采样脚本:支持指定 LoRA 权重与 Rectified Noise(SIT) 权重
5
+
6
+ 依据 train_rectified_noise.py 的模型结构,加载并组装 SD3WithRectifiedNoise 进行采样。
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import math
13
+ import argparse
14
+ from pathlib import Path
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+ from PIL import Image
21
+
22
+ from accelerate import Accelerator
23
+ from diffusers import StableDiffusion3Pipeline
24
+ from peft import LoraConfig, get_peft_model_state_dict
25
+ from peft.utils import set_peft_model_state_dict
26
+
27
+
28
+ def dynamic_import_training_classes(project_root: str):
29
+ """从 train_rectified_noise.py 动态导入 RectifiedNoiseModule 和 SD3WithRectifiedNoise"""
30
+ sys.path.insert(0, project_root)
31
+ try:
32
+ import train_rectified_noise as trn
33
+ return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
34
+ except Exception as e:
35
+ raise ImportError(f"无法从 train_rectified_noise.py 导入类: {e}")
36
+
37
+ def create_npz_from_sample_folder(sample_dir, num_samples):
38
+ """
39
+ 从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
40
+ """
41
+ samples = []
42
+ actual_files = []
43
+
44
+ # 收集所有PNG文件
45
+ for filename in sorted(os.listdir(sample_dir)):
46
+ if filename.endswith('.png'):
47
+ actual_files.append(filename)
48
+
49
+ # 按照数量限制处理
50
+ for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
51
+ if i < len(actual_files):
52
+ sample_path = os.path.join(sample_dir, actual_files[i])
53
+ sample_pil = Image.open(sample_path)
54
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
55
+ samples.append(sample_np)
56
+ else:
57
+ # 如果不够,创建空白图像
58
+ sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
59
+ samples.append(sample_np)
60
+
61
+ if samples:
62
+ samples = np.stack(samples)
63
+ npz_path = f"{sample_dir}.npz"
64
+ np.savez(npz_path, arr_0=samples)
65
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
66
+ return npz_path
67
+ else:
68
+ print("No samples found to create npz file.")
69
+ return None
70
+
71
+
72
+ def get_existing_sample_count(sample_dir):
73
+ """获取已存在的样本数量和最大索引"""
74
+ if not os.path.exists(sample_dir):
75
+ return 0, -1
76
+
77
+ existing_files = []
78
+ for filename in os.listdir(sample_dir):
79
+ if filename.endswith('.png') and filename[:-4].isdigit():
80
+ try:
81
+ idx = int(filename[:-4])
82
+ existing_files.append(idx)
83
+ except ValueError:
84
+ continue
85
+
86
+ if not existing_files:
87
+ return 0, -1
88
+
89
+ existing_files.sort()
90
+ max_index = existing_files[-1]
91
+ count = len(existing_files)
92
+
93
+ # 检查是否有缺失的文件(从0到max_index应该连续)
94
+ expected_count = max_index + 1
95
+ if count < expected_count:
96
+ print(f"Warning: Found {count} files but expected {expected_count} (missing some indices)")
97
+
98
+ return count, max_index
99
+
100
+
101
+
102
+ def load_sit_weights(rectified_module, weights_path: str, rank=0):
103
+ """加载 Rectified Noise(SIT) 权重,支持 .safetensors / .bin / .pt
104
+ 支持以下目录结构:
105
+ - weights_path/pytorch_sit_weights.safetensors (直接在主目录)
106
+ - weights_path/sit_weights/pytorch_sit_weights.safetensors (在sit_weights子目录)
107
+ """
108
+ if os.path.isdir(weights_path):
109
+ # 首先尝试在主目录查找
110
+ search_paths = [
111
+ weights_path, # 主目录
112
+ os.path.join(weights_path, "sit_weights"), # sit_weights子目录
113
+ ]
114
+
115
+ for search_dir in search_paths:
116
+ if not os.path.exists(search_dir):
117
+ continue
118
+
119
+ # 优先寻找 safetensors
120
+ st_path = os.path.join(search_dir, "pytorch_sit_weights.safetensors")
121
+ if os.path.exists(st_path):
122
+ try:
123
+ from safetensors.torch import load_file
124
+ if rank == 0:
125
+ print(f"Loading rectified weights from: {st_path}")
126
+ state = load_file(st_path)
127
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
128
+ if rank == 0:
129
+ print(f" Loaded rectified weights: {len(state)} keys")
130
+ if missing_keys:
131
+ print(f" Missing keys: {len(missing_keys)}")
132
+ if unexpected_keys:
133
+ print(f" Unexpected keys: {len(unexpected_keys)}")
134
+ return True
135
+ except Exception as e:
136
+ if rank == 0:
137
+ print(f" Failed to load from {st_path}: {e}")
138
+ continue
139
+
140
+ # 其次寻找 bin/pt
141
+ for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
142
+ cand = os.path.join(search_dir, name)
143
+ if os.path.exists(cand):
144
+ try:
145
+ if rank == 0:
146
+ print(f"Loading rectified weights from: {cand}")
147
+ state = torch.load(cand, map_location="cpu")
148
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
149
+ if rank == 0:
150
+ print(f" Loaded rectified weights: {len(state)} keys")
151
+ if missing_keys:
152
+ print(f" Missing keys: {len(missing_keys)}")
153
+ if unexpected_keys:
154
+ print(f" Unexpected keys: {len(unexpected_keys)}")
155
+ return True
156
+ except Exception as e:
157
+ if rank == 0:
158
+ print(f" Failed to load from {cand}: {e}")
159
+ continue
160
+
161
+ # 兜底:目录下任意 pt/bin
162
+ try:
163
+ for fn in os.listdir(search_dir):
164
+ if fn.endswith((".pt", ".bin")):
165
+ cand = os.path.join(search_dir, fn)
166
+ try:
167
+ if rank == 0:
168
+ print(f"Loading rectified weights from: {cand}")
169
+ state = torch.load(cand, map_location="cpu")
170
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
171
+ if rank == 0:
172
+ print(f" Loaded rectified weights: {len(state)} keys")
173
+ return True
174
+ except Exception as e:
175
+ if rank == 0:
176
+ print(f" Failed to load from {cand}: {e}")
177
+ continue
178
+ except Exception:
179
+ pass
180
+
181
+ if rank == 0:
182
+ print(f" ❌ No rectified weights found in {weights_path} or {os.path.join(weights_path, 'sit_weights')}")
183
+ return False
184
+ else:
185
+ # 直接文件
186
+ try:
187
+ if rank == 0:
188
+ print(f"Loading rectified weights from file: {weights_path}")
189
+ if weights_path.endswith(".safetensors"):
190
+ from safetensors.torch import load_file
191
+ state = load_file(weights_path)
192
+ else:
193
+ state = torch.load(weights_path, map_location="cpu")
194
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
195
+ if rank == 0:
196
+ print(f" Loaded rectified weights: {len(state)} keys")
197
+ if missing_keys:
198
+ print(f" Missing keys: {len(missing_keys)}")
199
+ if unexpected_keys:
200
+ print(f" Unexpected keys: {len(unexpected_keys)}")
201
+ return True
202
+ except Exception as e:
203
+ if rank == 0:
204
+ print(f" ❌ Failed to load rectified weights from {weights_path}: {e}")
205
+ return False
206
+
207
+
208
+ def check_lora_weights_exist(lora_path):
209
+ """检查LoRA权重文件是否存在"""
210
+ if not lora_path:
211
+ return False
212
+
213
+ if os.path.isdir(lora_path):
214
+ # 检查目录中是否有pytorch_lora_weights.safetensors文件
215
+ weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
216
+ if os.path.exists(weight_file):
217
+ return True
218
+ # 检查是否有其他.safetensors文件
219
+ for file in os.listdir(lora_path):
220
+ if file.endswith(".safetensors") and "lora" in file.lower():
221
+ return True
222
+ return False
223
+ elif os.path.isfile(lora_path):
224
+ return lora_path.endswith(".safetensors")
225
+
226
+ return False
227
+
228
+
229
+ def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0, lora_rank=64):
230
+ """
231
+ 从accelerator checkpoint目录加载LoRA权重或完整模型权重
232
+ 如果checkpoint包含完整的模型权重(合并后的),直接加载
233
+ 如果只包含LoRA权重,则按LoRA方式加载
234
+ """
235
+ if rank == 0:
236
+ print(f"Loading weights from accelerator checkpoint: {checkpoint_path}")
237
+
238
+ try:
239
+ from safetensors.torch import load_file
240
+ model_file = os.path.join(checkpoint_path, "model.safetensors")
241
+ if not os.path.exists(model_file):
242
+ if rank == 0:
243
+ print(f"Model file not found: {model_file}")
244
+ return False
245
+
246
+ # 加载state dict
247
+ state_dict = load_file(model_file)
248
+ all_keys = list(state_dict.keys())
249
+
250
+ # 检测checkpoint类型:
251
+ # 1. 是否包含base_layer(PEFT格式,需要合并)
252
+ # 2. 是否包含完整的模型权重(合并���的,直接可用)
253
+ # 3. 是否只包含LoRA权重(需要添加适配器)
254
+ lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()]
255
+ base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()]
256
+ non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()]
257
+
258
+ if rank == 0:
259
+ print(f"Checkpoint analysis:")
260
+ print(f" Total keys: {len(all_keys)}")
261
+ print(f" LoRA keys: {len(lora_keys)}")
262
+ print(f" Base layer keys: {len(base_layer_keys)}")
263
+ print(f" Direct transformer weight keys (merged): {len(non_lora_transformer_keys)}")
264
+
265
+ # 如果包含base_layer,说明是PEFT格式,需要合并base_layer + lora
266
+ if len(base_layer_keys) > 0:
267
+ if rank == 0:
268
+ print(f"✓ Detected PEFT format (base_layer + LoRA), merging weights...")
269
+
270
+ # 合并base_layer和lora权重
271
+ merged_state_dict = {}
272
+
273
+ # 首先收集所有需要合并的模块
274
+ modules_to_merge = {}
275
+ # 记录所有非LoRA的transformer权重键名(用于调试)
276
+ non_lora_keys_found = []
277
+
278
+ for key in all_keys:
279
+ # 移除前缀
280
+ new_key = key
281
+ has_transformer_prefix = False
282
+
283
+ if key.startswith('base_model.model.transformer.'):
284
+ new_key = key[len('base_model.model.transformer.'):]
285
+ has_transformer_prefix = True
286
+ elif key.startswith('model.transformer.'):
287
+ new_key = key[len('model.transformer.'):]
288
+ has_transformer_prefix = True
289
+ elif key.startswith('transformer.'):
290
+ new_key = key[len('transformer.'):]
291
+ has_transformer_prefix = True
292
+ elif 'transformer' in key.lower():
293
+ # 可能没有前缀,但包含transformer(如直接是transformer_blocks.0...)
294
+ has_transformer_prefix = True
295
+
296
+ if not has_transformer_prefix:
297
+ continue
298
+
299
+ # 检查是否是base_layer或lora权重
300
+ if '.base_layer.weight' in new_key:
301
+ # 提取模块名(去掉.base_layer.weight部分)
302
+ module_key = new_key.replace('.base_layer.weight', '.weight')
303
+ if module_key not in modules_to_merge:
304
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
305
+ modules_to_merge[module_key]['base_weight'] = (key, state_dict[key])
306
+ elif '.base_layer.bias' in new_key:
307
+ module_key = new_key.replace('.base_layer.bias', '.bias')
308
+ if module_key not in modules_to_merge:
309
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
310
+ modules_to_merge[module_key]['base_bias'] = (key, state_dict[key])
311
+ elif '.lora_A.default.weight' in new_key:
312
+ module_key = new_key.replace('.lora_A.default.weight', '.weight')
313
+ if module_key not in modules_to_merge:
314
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
315
+ modules_to_merge[module_key]['lora_A'] = (key, state_dict[key])
316
+ elif '.lora_B.default.weight' in new_key:
317
+ module_key = new_key.replace('.lora_B.default.weight', '.weight')
318
+ if module_key not in modules_to_merge:
319
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
320
+ modules_to_merge[module_key]['lora_B'] = (key, state_dict[key])
321
+ elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower():
322
+ # 其他非LoRA权重(如pos_embed、time_text_embed、context_embedder等),直接使用
323
+ # 这些权重不在LoRA适配范围内,应该直接从checkpoint加载
324
+ merged_state_dict[new_key] = state_dict[key]
325
+ non_lora_keys_found.append(new_key)
326
+
327
+ if rank == 0:
328
+ print(f" Found {len(non_lora_keys_found)} non-LoRA transformer keys in checkpoint")
329
+ if non_lora_keys_found:
330
+ print(f" Sample non-LoRA keys: {non_lora_keys_found[:10]}")
331
+
332
+ # 合并权重:weight = base_weight + lora_B @ lora_A * (alpha / rank)
333
+ if rank == 0:
334
+ print(f" Merging {len(modules_to_merge)} modules...")
335
+
336
+ import torch
337
+ for module_key, weights in modules_to_merge.items():
338
+ # 处理权重(.weight)
339
+ if weights['base_weight'] is not None:
340
+ base_key, base_weight = weights['base_weight']
341
+ base_weight = base_weight.clone()
342
+
343
+ if weights['lora_A'] is not None and weights['lora_B'] is not None:
344
+ lora_A_key, lora_A = weights['lora_A']
345
+ lora_B_key, lora_B = weights['lora_B']
346
+
347
+ # 检测rank和alpha
348
+ # lora_A: [rank, in_features], lora_B: [out_features, rank]
349
+ rank_value = lora_A.shape[0]
350
+ alpha = rank_value # 通常alpha = rank
351
+
352
+ # 合并:weight = base + (lora_B @ lora_A) * (alpha / rank)
353
+ # lora_B @ lora_A 得到 [out_features, in_features]
354
+ lora_delta = torch.matmul(lora_B, lora_A)
355
+
356
+ if lora_delta.shape == base_weight.shape:
357
+ merged_weight = base_weight + lora_delta * (alpha / rank_value)
358
+ merged_state_dict[module_key] = merged_weight
359
+ if rank == 0 and len(modules_to_merge) <= 20:
360
+ print(f" ✓ Merged {module_key}: {base_weight.shape}")
361
+ else:
362
+ if rank == 0:
363
+ print(f" ⚠️ Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only")
364
+ merged_state_dict[module_key] = base_weight
365
+ else:
366
+ # 只有base权重,没有LoRA
367
+ merged_state_dict[module_key] = base_weight
368
+
369
+ # 处理bias(.bias)- bias通常不需要合并,直接使用base_bias
370
+ if '.bias' in module_key and weights['base_bias'] is not None:
371
+ bias_key, base_bias = weights['base_bias']
372
+ merged_state_dict[module_key] = base_bias.clone()
373
+
374
+ if rank == 0:
375
+ print(f" Merged {len(merged_state_dict)} weights")
376
+ print(f" Sample merged keys: {list(merged_state_dict.keys())[:5]}")
377
+
378
+ # 加载合并后的权重
379
+ try:
380
+ missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False)
381
+
382
+ if rank == 0:
383
+ print(f" Loaded merged weights:")
384
+ print(f" Missing keys: {len(missing_keys)}")
385
+ print(f" Unexpected keys: {len(unexpected_keys)}")
386
+ if missing_keys:
387
+ print(f" Missing keys: {missing_keys}")
388
+ # 检查缺失的keys是否关键
389
+ critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out']
390
+ has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys)
391
+ if has_critical:
392
+ print(f" ⚠️ WARNING: Missing critical keys! These should be loaded from pretrained model.")
393
+ print(f" The missing keys will use values from the pretrained model (not fine-tuned).")
394
+
395
+ # 如果缺失的keys太多或包含关键组件,给出警告
396
+ if len(missing_keys) > 0:
397
+ # 这些缺失的keys会使用pretrained model的默认值
398
+ # 这是正常的,因为LoRA只适配了部分层,其他层保持原样
399
+ if rank == 0:
400
+ print(f" Note: Missing keys will use pretrained model weights (not fine-tuned)")
401
+
402
+ if rank == 0:
403
+ print(f" ✓ Successfully loaded merged model weights")
404
+ return True
405
+
406
+ except Exception as e:
407
+ if rank == 0:
408
+ print(f" ❌ Error loading merged weights: {e}")
409
+ import traceback
410
+ traceback.print_exc()
411
+ return False
412
+
413
+ # 如果包含非LoRA的transformer权重(且没有base_layer),说明是合并后的完整模型
414
+ elif len(non_lora_transformer_keys) > 0:
415
+ if rank == 0:
416
+ print(f"✓ Detected merged model weights (contains full transformer weights)")
417
+ print(f" Loading full model weights directly...")
418
+
419
+ # 提取transformer相关的权重(包括LoRA和基础权重)
420
+ transformer_state_dict = {}
421
+ for key, value in state_dict.items():
422
+ # 移除可能的accelerator包装前缀
423
+ new_key = key
424
+ if key.startswith('base_model.model.transformer.'):
425
+ new_key = key[len('base_model.model.transformer.'):]
426
+ elif key.startswith('model.transformer.'):
427
+ new_key = key[len('model.transformer.'):]
428
+ elif key.startswith('transformer.'):
429
+ new_key = key[len('transformer.'):]
430
+
431
+ # 只保留transformer相关的权重(包括所有transformer子模块)
432
+ # 检查是否是transformer的权重(不包含text_encoder等)
433
+ if (new_key.startswith('transformer_blocks') or
434
+ new_key.startswith('pos_embed') or
435
+ new_key.startswith('time_text_embed') or
436
+ 'lora' in new_key.lower()): # 也包含LoRA权重(如果存在)
437
+ transformer_state_dict[new_key] = value
438
+
439
+ if rank == 0:
440
+ print(f" Extracted {len(transformer_state_dict)} transformer weight keys")
441
+ print(f" Sample keys: {list(transformer_state_dict.keys())[:5]}")
442
+
443
+ # 直接加载到transformer(不使用LoRA适配器)
444
+ try:
445
+ missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False)
446
+
447
+ if rank == 0:
448
+ print(f" Loaded full model weights:")
449
+ print(f" Missing keys: {len(missing_keys)}")
450
+ print(f" Unexpected keys: {len(unexpected_keys)}")
451
+ if missing_keys:
452
+ print(f" Sample missing keys: {missing_keys[:5]}")
453
+ if unexpected_keys:
454
+ print(f" Sample unexpected keys: {unexpected_keys[:5]}")
455
+
456
+ # 如果missing keys太多,可能有问题
457
+ if len(missing_keys) > len(transformer_state_dict) * 0.5:
458
+ if rank == 0:
459
+ print(f" ⚠️ WARNING: Too many missing keys, weights may not be fully loaded")
460
+ return False
461
+
462
+ if rank == 0:
463
+ print(f" ✓ Successfully loaded merged model weights")
464
+ return True
465
+
466
+ except Exception as e:
467
+ if rank == 0:
468
+ print(f" ❌ Error loading full model weights: {e}")
469
+ import traceback
470
+ traceback.print_exc()
471
+ return False
472
+
473
+ # 如果只包含LoRA权重,按原来的方式加载
474
+ if rank == 0:
475
+ print(f"Detected LoRA-only weights, loading as LoRA adapter...")
476
+
477
+ # 首先尝试从checkpoint中检测实际的rank
478
+ detected_rank = None
479
+ for key, value in state_dict.items():
480
+ if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
481
+ # lora_A的形状是 [rank, hidden_size]
482
+ detected_rank = value.shape[0]
483
+ if rank == 0:
484
+ print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
485
+ break
486
+
487
+ # 如果检测到rank,使用检测到的rank;否则使用传入的rank
488
+ actual_rank = detected_rank if detected_rank is not None else lora_rank
489
+ if detected_rank is not None and detected_rank != lora_rank:
490
+ if rank == 0:
491
+ print(f"⚠️ Warning: Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank")
492
+
493
+ # 检查适配器是否已存在,如果存在则先卸载
494
+ # SD3Transformer2DModel没有delete_adapter方法,需要使用unload_lora_weights
495
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
496
+ if "default" in pipeline.transformer.peft_config:
497
+ if rank == 0:
498
+ print("Removing existing 'default' adapter before adding new one...")
499
+ try:
500
+ # 使用pipeline的unload_lora_weights方法
501
+ pipeline.unload_lora_weights()
502
+ if rank == 0:
503
+ print("Successfully unloaded existing LoRA adapter")
504
+ except Exception as e:
505
+ if rank == 0:
506
+ print(f"❌ ERROR: Could not unload existing adapter: {e}")
507
+ print("Cannot proceed without cleaning up adapter")
508
+ return False
509
+
510
+ # 先配置LoRA适配器(必须在加载之前配置)
511
+ # 使用检测到的或传入的rank
512
+ transformer_lora_config = LoraConfig(
513
+ r=actual_rank,
514
+ lora_alpha=actual_rank,
515
+ init_lora_weights="gaussian",
516
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
517
+ )
518
+
519
+ # 为transformer添加LoRA适配器
520
+ pipeline.transformer.add_adapter(transformer_lora_config)
521
+
522
+ if rank == 0:
523
+ print(f"LoRA adapter configured with rank={actual_rank}")
524
+
525
+ # 继续处理LoRA权重加载(state_dict已经在上面加载了)
526
+
527
+ # 提取LoRA权重 - accelerator保存的格式
528
+ # 从accelerator checkpoint的model.safetensors中,键名格式可能是:
529
+ # - transformer_blocks.X.attn.to_q.lora_A.default.weight (PEFT格式,直接可用)
530
+ # - 或者包含其他前缀
531
+ lora_state_dict = {}
532
+ for key, value in state_dict.items():
533
+ if 'lora' in key.lower() and 'transformer' in key.lower():
534
+ # 检查键名格式
535
+ new_key = key
536
+
537
+ # 移除可能的accelerator包装前缀
538
+ # accelerator可能保存为: model.transformer.transformer_blocks...
539
+ # 或者: base_model.model.transformer.transformer_blocks...
540
+ if key.startswith('base_model.model.transformer.'):
541
+ new_key = key[len('base_model.model.transformer.'):]
542
+ elif key.startswith('model.transformer.'):
543
+ new_key = key[len('model.transformer.'):]
544
+ elif key.startswith('transformer.'):
545
+ # 如果已经是transformer_blocks开头,不需要移除transformer.前缀
546
+ # 因为transformer_blocks是transformer的子模块
547
+ if not key[len('transformer.'):].startswith('transformer_blocks'):
548
+ new_key = key[len('transformer.'):]
549
+ else:
550
+ new_key = key[len('transformer.'):]
551
+
552
+ # 只保留transformer相关的LoRA权重
553
+ if 'transformer_blocks' in new_key or 'transformer' in new_key:
554
+ lora_state_dict[new_key] = value
555
+
556
+ if not lora_state_dict:
557
+ if rank == 0:
558
+ print("No LoRA weights found in checkpoint")
559
+ # 打印所有键名用于调试
560
+ all_keys = list(state_dict.keys())
561
+ print(f"Total keys: {len(all_keys)}")
562
+ print(f"First 20 keys: {all_keys[:20]}")
563
+ # 查找包含lora的键
564
+ lora_related = [k for k in all_keys if 'lora' in k.lower()]
565
+ if lora_related:
566
+ print(f"Keys containing 'lora': {lora_related[:10]}")
567
+ return False
568
+
569
+ if rank == 0:
570
+ print(f"Found {len(lora_state_dict)} LoRA weight keys")
571
+ sample_keys = list(lora_state_dict.keys())[:5]
572
+ print(f"Sample LoRA keys: {sample_keys}")
573
+
574
+ # 加载LoRA权重到transformer
575
+ # 注意:从checkpoint提取的键名格式已经是PEFT格式(如:transformer_blocks.0.attn.to_q.lora_A.default.weight)
576
+ # 不需要使用convert_unet_state_dict_to_peft转换,直接使用即可
577
+ try:
578
+ # 检查键名格式
579
+ sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
580
+
581
+ if rank == 0:
582
+ print(f"Original key format: {sample_key}")
583
+
584
+ # 关键问题:set_peft_model_state_dict期望的键名格式
585
+ # 从back/train_dreambooth_lora.py看,需要移除.default后缀
586
+ # 格式应该是:transformer_blocks.X.attn.to_q.lora_A.weight(没有.default)
587
+ # 但accelerator保存的格式是:transformer_blocks.X.attn.to_q.lora_A.default.weight(有.default)
588
+
589
+ # 检查键名格式
590
+ sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
591
+ has_default_suffix = '.default.weight' in sample_key or '.default.bias' in sample_key
592
+
593
+ if rank == 0:
594
+ print(f"Sample key: {sample_key}")
595
+ print(f"Has .default suffix: {has_default_suffix}")
596
+
597
+ # 如果键名包含.default.weight或.default.bias,需要移除.default部分
598
+ # 因为set_peft_model_state_dict期望的格式是:lora_A.weight,而不是lora_A.default.weight
599
+ converted_dict = {}
600
+ for key, value in lora_state_dict.items():
601
+ # 移除.default后缀(如果存在)
602
+ # transformer_blocks.0.attn.to_q.lora_A.default.weight -> transformer_blocks.0.attn.to_q.lora_A.weight
603
+ new_key = key
604
+ if '.default.weight' in new_key:
605
+ new_key = new_key.replace('.default.weight', '.weight')
606
+ elif '.default.bias' in new_key:
607
+ new_key = new_key.replace('.default.bias', '.bias')
608
+ elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')):
609
+ # 处理其他可能的.default位置
610
+ new_key = new_key.replace('.default', '')
611
+
612
+ converted_dict[new_key] = value
613
+
614
+ if rank == 0:
615
+ print(f"Converted {len(converted_dict)} keys (removed .default suffix if present)")
616
+ print(f"Sample converted keys: {list(converted_dict.keys())[:5]}")
617
+
618
+ # 调用set_peft_model_state_dict并检查返回值
619
+ incompatible_keys = set_peft_model_state_dict(
620
+ pipeline.transformer,
621
+ converted_dict,
622
+ adapter_name="default"
623
+ )
624
+
625
+ # 检查加载结果
626
+ if incompatible_keys is not None:
627
+ missing_keys = getattr(incompatible_keys, "missing_keys", [])
628
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", [])
629
+
630
+ if rank == 0:
631
+ print(f"LoRA loading result:")
632
+ print(f" Missing keys: {len(missing_keys)}")
633
+ print(f" Unexpected keys: {len(unexpected_keys)}")
634
+
635
+ if len(missing_keys) > 100:
636
+ print(f" ⚠️ WARNING: Too many missing keys ({len(missing_keys)}), LoRA may not be fully loaded!")
637
+ print(f" Sample missing keys: {missing_keys[:10]}")
638
+ elif missing_keys:
639
+ print(f" Sample missing keys: {missing_keys[:10]}")
640
+
641
+ if unexpected_keys:
642
+ print(f" Unexpected keys: {unexpected_keys[:10]}")
643
+
644
+ # 如果missing keys太多,说明加载失败
645
+ if len(missing_keys) > len(converted_dict) * 0.5: # 超过50%的键缺失
646
+ if rank == 0:
647
+ print("❌ ERROR: Too many missing keys, LoRA weights not loaded correctly!")
648
+ return False
649
+ else:
650
+ if rank == 0:
651
+ print("✓ LoRA weights loaded (no incompatible keys reported)")
652
+
653
+ except RuntimeError as e:
654
+ # 检查是否是size mismatch错误
655
+ error_str = str(e)
656
+ if "size mismatch" in error_str:
657
+ if rank == 0:
658
+ print(f"❌ Size mismatch error: The checkpoint rank doesn't match the adapter rank")
659
+ print(f" This usually means the checkpoint was trained with a different rank")
660
+ # 尝试从错误信息中提取期望的rank
661
+ import re
662
+ # 错误信息格式: "copying a param with shape torch.Size([32, 1536]) from checkpoint"
663
+ match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str)
664
+ if match:
665
+ checkpoint_rank = int(match.group(1))
666
+ if rank == 0:
667
+ print(f" Detected checkpoint rank: {checkpoint_rank}")
668
+ print(f" Adapter was configured with rank: {actual_rank}")
669
+ if checkpoint_rank != actual_rank:
670
+ print(f" ⚠️ Mismatch! Need to recreate adapter with rank={checkpoint_rank}")
671
+ else:
672
+ if rank == 0:
673
+ print(f"❌ Error setting LoRA state dict: {e}")
674
+ import traceback
675
+ traceback.print_exc()
676
+ # 清理适配器以便下次尝试
677
+ try:
678
+ pipeline.unload_lora_weights()
679
+ except:
680
+ pass
681
+ return False
682
+ except Exception as e:
683
+ if rank == 0:
684
+ print(f"❌ Error setting LoRA state dict: {e}")
685
+ import traceback
686
+ traceback.print_exc()
687
+ # 清理适配器以便下次尝试
688
+ try:
689
+ pipeline.unload_lora_weights()
690
+ except:
691
+ pass
692
+ return False
693
+
694
+ # 启用LoRA适配器
695
+ pipeline.transformer.set_adapter("default")
696
+
697
+ # 验证LoRA是否已加载和应用
698
+ if hasattr(pipeline.transformer, 'peft_config'):
699
+ adapters = list(pipeline.transformer.peft_config.keys())
700
+ if rank == 0:
701
+ print(f"LoRA adapters configured: {adapters}")
702
+ # 检查适配器是否启用
703
+ if hasattr(pipeline.transformer, 'active_adapters'):
704
+ # active_adapters 是一个方法,需要调用
705
+ try:
706
+ if callable(pipeline.transformer.active_adapters):
707
+ active = pipeline.transformer.active_adapters()
708
+ else:
709
+ active = pipeline.transformer.active_adapters
710
+ if rank == 0:
711
+ print(f"Active adapters: {active}")
712
+ except:
713
+ if rank == 0:
714
+ print("Could not get active adapters, but LoRA is configured")
715
+
716
+ # 验证LoRA权���是否真的被应用
717
+ # 检查LoRA层的权重是否非零
718
+ lora_layers_found = 0
719
+ nonzero_lora_layers = 0
720
+ total_lora_weight_sum = 0.0
721
+
722
+ for name, module in pipeline.transformer.named_modules():
723
+ if 'lora_A' in name or 'lora_B' in name:
724
+ lora_layers_found += 1
725
+ if hasattr(module, 'weight') and module.weight is not None:
726
+ weight_sum = module.weight.abs().sum().item()
727
+ total_lora_weight_sum += weight_sum
728
+ if weight_sum > 1e-6: # 非零阈值
729
+ nonzero_lora_layers += 1
730
+ if rank == 0 and nonzero_lora_layers <= 3: # 只打印前3个
731
+ print(f"✓ Found non-zero LoRA weight in: {name}, sum={weight_sum:.6f}")
732
+
733
+ if rank == 0:
734
+ print(f"LoRA verification:")
735
+ print(f" Total LoRA layers found: {lora_layers_found}")
736
+ print(f" Non-zero LoRA layers: {nonzero_lora_layers}")
737
+ print(f" Total LoRA weight sum: {total_lora_weight_sum:.6f}")
738
+
739
+ if lora_layers_found == 0:
740
+ print("❌ ERROR: No LoRA layers found in transformer!")
741
+ return False
742
+ elif nonzero_lora_layers == 0:
743
+ print("❌ ERROR: All LoRA weights are zero, LoRA not loaded correctly!")
744
+ return False
745
+ elif nonzero_lora_layers < lora_layers_found * 0.5:
746
+ print(f"⚠️ WARNING: Only {nonzero_lora_layers}/{lora_layers_found} LoRA layers have non-zero weights!")
747
+ print("⚠️ LoRA may not be fully applied!")
748
+ else:
749
+ print(f"✓ LoRA weights verified: {nonzero_lora_layers}/{lora_layers_found} layers have non-zero weights")
750
+
751
+ if nonzero_lora_layers == 0:
752
+ return False
753
+
754
+ if rank == 0:
755
+ print("✓ Successfully loaded and verified LoRA weights from checkpoint")
756
+
757
+ return True
758
+
759
+ except Exception as e:
760
+ if rank == 0:
761
+ print(f"Error loading LoRA from checkpoint: {e}")
762
+ import traceback
763
+ traceback.print_exc()
764
+ return False
765
+
766
+
767
+ def load_captions_from_jsonl(jsonl_path):
768
+ captions = []
769
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
770
+ for line in f:
771
+ line = line.strip()
772
+ if not line:
773
+ continue
774
+ try:
775
+ data = json.loads(line)
776
+ cap = None
777
+ for field in ['caption', 'text', 'prompt', 'description']:
778
+ if field in data and isinstance(data[field], str):
779
+ cap = data[field].strip()
780
+ break
781
+ if cap:
782
+ captions.append(cap)
783
+ except Exception:
784
+ continue
785
+ return captions if captions else ["a beautiful high quality image"]
786
+
787
+
788
+ def main(args):
789
+ assert torch.cuda.is_available(), "需要GPU运行"
790
+ dist.init_process_group("nccl")
791
+ rank = dist.get_rank()
792
+ world_size = dist.get_world_size()
793
+ device = rank % torch.cuda.device_count()
794
+ torch.cuda.set_device(device)
795
+ seed = args.global_seed * world_size + rank
796
+ torch.manual_seed(seed)
797
+
798
+ print(f"[rank{rank}] DDP initialized, device={device}, seed={seed}, world_size={world_size}")
799
+
800
+ # 调试:打印接收到的参数
801
+ if rank == 0:
802
+ print("=" * 80)
803
+ print("参数检查:")
804
+ print(f" lora_path: {args.lora_path}")
805
+ print(f" rectified_weights: {args.rectified_weights}")
806
+ print(f" lora_path is None: {args.lora_path is None}")
807
+ print(f" lora_path is empty: {args.lora_path == '' if args.lora_path else 'N/A'}")
808
+ print(f" rectified_weights is None: {args.rectified_weights is None}")
809
+ print(f" rectified_weights is empty: {args.rectified_weights == '' if args.rectified_weights else 'N/A'}")
810
+ print("=" * 80)
811
+
812
+ # 导入训练脚本中的类
813
+ RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
814
+
815
+ # 加载 pipeline
816
+ dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
817
+ if rank == 0:
818
+ print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path} (dtype={dtype})")
819
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
820
+ args.pretrained_model_name_or_path,
821
+ revision=args.revision,
822
+ variant=args.variant,
823
+ torch_dtype=dtype,
824
+ ).to(device)
825
+
826
+ print(f"[rank{rank}] Pipeline loaded and moved to device {device}")
827
+
828
+ # 加载 LoRA(可选)
829
+ lora_loaded = False
830
+ if args.lora_path:
831
+ if rank == 0:
832
+ print(f"Attempting to load LoRA weights from: {args.lora_path}")
833
+ print(f"LoRA path exists: {os.path.exists(args.lora_path) if args.lora_path else False}")
834
+
835
+ # 首��检查是否是标准的LoRA权重文件/目录
836
+ if check_lora_weights_exist(args.lora_path):
837
+ if rank == 0:
838
+ print("Found standard LoRA weights, loading...")
839
+ try:
840
+ # 检查加载前的transformer参数(用于验证)
841
+ if rank == 0:
842
+ sample_param_before = next(iter(pipeline.transformer.parameters())).clone()
843
+ print(f"Sample transformer param before LoRA (first 5 values): {sample_param_before.flatten()[:5]}")
844
+
845
+ pipeline.load_lora_weights(args.lora_path)
846
+ lora_loaded = True
847
+
848
+ # 验证LoRA是否真的被加载
849
+ if rank == 0:
850
+ sample_param_after = next(iter(pipeline.transformer.parameters())).clone()
851
+ param_diff = (sample_param_after - sample_param_before).abs().max().item()
852
+ print(f"Sample transformer param after LoRA (first 5 values): {sample_param_after.flatten()[:5]}")
853
+ print(f"Max parameter change after LoRA loading: {param_diff}")
854
+ if param_diff < 1e-6:
855
+ print("⚠️ WARNING: LoRA weights may not have been applied (parameter change is very small)")
856
+ else:
857
+ print("✓ LoRA weights appear to have been applied")
858
+
859
+ # 检查是否有peft_config
860
+ if hasattr(pipeline.transformer, 'peft_config'):
861
+ print(f"✓ PEFT config found: {list(pipeline.transformer.peft_config.keys())}")
862
+ else:
863
+ print("⚠️ WARNING: No peft_config found after loading LoRA")
864
+
865
+ if rank == 0:
866
+ print("LoRA loaded successfully from standard format.")
867
+ except Exception as e:
868
+ if rank == 0:
869
+ print(f"Failed to load LoRA from standard format: {e}")
870
+ import traceback
871
+ traceback.print_exc()
872
+
873
+ # 如果不是标准格式,尝试从accelerator checkpoint加载
874
+ if not lora_loaded and os.path.isdir(args.lora_path):
875
+ if rank == 0:
876
+ print("Standard LoRA weights not found, trying accelerator checkpoint format...")
877
+
878
+ # 首先尝试从checkpoint的model.safetensors中检测实际的rank
879
+ # 通过检查LoRA权重的形状来推断rank
880
+ detected_rank = None
881
+ try:
882
+ from safetensors.torch import load_file
883
+ model_file = os.path.join(args.lora_path, "model.safetensors")
884
+ if os.path.exists(model_file):
885
+ state_dict = load_file(model_file)
886
+ # 查找一个LoRA权重来确定rank
887
+ for key, value in state_dict.items():
888
+ if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
889
+ # lora_A的形状是 [rank, hidden_size]
890
+ detected_rank = value.shape[0]
891
+ if rank == 0:
892
+ print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
893
+ break
894
+ except Exception as e:
895
+ if rank == 0:
896
+ print(f"Could not detect rank from checkpoint: {e}")
897
+
898
+ # 构建rank尝试列表
899
+ # 如果检测到rank,优先使用检测到的rank,只尝试一次
900
+ # 如果未检测到,尝试常见的rank值
901
+ if detected_rank is not None:
902
+ rank_list = [detected_rank]
903
+ if rank == 0:
904
+ print(f"Using detected rank: {detected_rank}")
905
+ else:
906
+ # 如果检测失败,尝试常见的rank值(按用户指定的rank优先)
907
+ rank_list = []
908
+ # 如果用户指定了rank(从args.lora_rank),优先尝试
909
+ if hasattr(args, 'lora_rank') and args.lora_rank:
910
+ rank_list.append(args.lora_rank)
911
+ # 添加其他常见的rank值
912
+ for r in [32, 64, 16, 128]:
913
+ if r not in rank_list:
914
+ rank_list.append(r)
915
+ if rank == 0:
916
+ print(f"Rank detection failed, will try ranks in order: {rank_list}")
917
+
918
+ # 尝试不同的rank值
919
+ for lora_rank in rank_list:
920
+ # 在尝试新的rank之前,先清理已存在的适配器
921
+ # 重要:每次尝试前都要清理,否则适配器会保留之前的rank配置
922
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
923
+ if "default" in pipeline.transformer.peft_config:
924
+ try:
925
+ # 使用pipeline的unload_lora_weights方法
926
+ pipeline.unload_lora_weights()
927
+ if rank == 0:
928
+ print(f"Cleaned up existing adapter before trying rank={lora_rank}")
929
+ except Exception as e:
930
+ if rank == 0:
931
+ print(f"Warning: Could not unload adapter: {e}")
932
+ # 如果卸载失败,需要重新创建pipeline
933
+ if rank == 0:
934
+ print("⚠️ WARNING: Cannot unload adapter, will recreate pipeline...")
935
+ # 重新加载pipeline(最后手段)
936
+ try:
937
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
938
+ args.pretrained_model_name_or_path,
939
+ revision=args.revision,
940
+ variant=args.variant,
941
+ torch_dtype=dtype,
942
+ ).to(device)
943
+ if rank == 0:
944
+ print("Pipeline recreated to clear adapter state")
945
+ except Exception as e2:
946
+ if rank == 0:
947
+ print(f"Failed to recreate pipeline: {e2}")
948
+
949
+ if rank == 0:
950
+ print(f"Trying to load with LoRA rank={lora_rank}...")
951
+ lora_loaded = load_lora_from_checkpoint(pipeline, args.lora_path, rank=rank, lora_rank=lora_rank)
952
+ if lora_loaded:
953
+ if rank == 0:
954
+ print(f"✓ Successfully loaded LoRA with rank={lora_rank}")
955
+ break
956
+ elif rank == 0:
957
+ print(f"✗ Failed to load with rank={lora_rank}, trying next rank...")
958
+
959
+ # 如果checkpoint目录加载失败,尝试从输出目录的根目录加载标准LoRA权重
960
+ if not lora_loaded and os.path.isdir(args.lora_path):
961
+ # 检查输出目录的根目录(checkpoint的父目录)
962
+ output_dir = os.path.dirname(args.lora_path.rstrip('/'))
963
+ if output_dir and os.path.exists(output_dir):
964
+ if rank == 0:
965
+ print(f"Trying to load standard LoRA weights from output directory: {output_dir}")
966
+ if check_lora_weights_exist(output_dir):
967
+ try:
968
+ pipeline.load_lora_weights(output_dir)
969
+ lora_loaded = True
970
+ if rank == 0:
971
+ print("LoRA loaded successfully from output directory.")
972
+ except Exception as e:
973
+ if rank == 0:
974
+ print(f"Failed to load LoRA from output directory: {e}")
975
+
976
+ if not lora_loaded:
977
+ if rank == 0:
978
+ print(f"⚠️ WARNING: Failed to load LoRA weights from {args.lora_path}, using baseline model")
979
+ else:
980
+ # 最终验证LoRA是否真的被启用
981
+ if rank == 0:
982
+ print("=" * 80)
983
+ print("LoRA 加载验证:")
984
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
985
+ print(f" ✓ PEFT config exists: {list(pipeline.transformer.peft_config.keys())}")
986
+ # 检查LoRA层的权重
987
+ lora_layers_found = 0
988
+ for name, module in pipeline.transformer.named_modules():
989
+ if 'lora_A' in name or 'lora_B' in name:
990
+ lora_layers_found += 1
991
+ if lora_layers_found <= 3: # 只打印前3个
992
+ if hasattr(module, 'weight'):
993
+ weight_sum = module.weight.abs().sum().item() if module.weight is not None else 0
994
+ print(f" ✓ Found LoRA layer: {name}, weight_sum={weight_sum:.6f}")
995
+ print(f" ✓ Total LoRA layers found: {lora_layers_found}")
996
+ if lora_layers_found == 0:
997
+ print(" ⚠️ WARNING: No LoRA layers found in transformer!")
998
+ else:
999
+ print(" ⚠️ WARNING: No PEFT config found - LoRA may not be active!")
1000
+ print("=" * 80)
1001
+
1002
+ # 构建 RectifiedNoiseModule 并加载权重(仅在提供了 rectified_weights 时)
1003
+ # 安全地检查 rectified_weights 是否有效
1004
+ use_rectified = False
1005
+ rectified_weights_path = None
1006
+ if args.rectified_weights:
1007
+ rectified_weights_str = str(args.rectified_weights).strip()
1008
+ if rectified_weights_str:
1009
+ use_rectified = True
1010
+ rectified_weights_path = rectified_weights_str
1011
+
1012
+ if rank == 0:
1013
+ print(f"use_rectified: {use_rectified}, rectified_weights_path: {rectified_weights_path}")
1014
+
1015
+ if use_rectified:
1016
+ if rank == 0:
1017
+ print(f"Using Rectified Noise module with weights from: {rectified_weights_path}")
1018
+ print(f"[rank{rank}] RectifiedNoiseModule configuration: num_sit_layers={args.num_sit_layers}")
1019
+
1020
+ # 从 transformer 配置推断必要尺寸
1021
+ tfm = pipeline.transformer
1022
+ if hasattr(tfm.config, 'joint_attention_dim') and tfm.config.joint_attention_dim is not None:
1023
+ sit_hidden_size = tfm.config.joint_attention_dim
1024
+ elif hasattr(tfm.config, 'inner_dim') and tfm.config.inner_dim is not None:
1025
+ sit_hidden_size = tfm.config.inner_dim
1026
+ elif hasattr(tfm.config, 'hidden_size') and tfm.config.hidden_size is not None:
1027
+ sit_hidden_size = tfm.config.hidden_size
1028
+ else:
1029
+ sit_hidden_size = 4096
1030
+
1031
+ transformer_hidden_size = getattr(tfm.config, 'hidden_size', 1536)
1032
+ num_attention_heads = getattr(tfm.config, 'num_attention_heads', 32)
1033
+ input_dim = getattr(tfm.config, 'in_channels', 16)
1034
+
1035
+ rectified_module = RectifiedNoiseModule(
1036
+ hidden_size=sit_hidden_size,
1037
+ num_sit_layers=args.num_sit_layers,
1038
+ num_attention_heads=num_attention_heads,
1039
+ input_dim=input_dim,
1040
+ transformer_hidden_size=transformer_hidden_size,
1041
+ )
1042
+ # 加载 SIT 权重
1043
+ ok = load_sit_weights(rectified_module, rectified_weights_path, rank=rank)
1044
+ if rank == 0:
1045
+ if not ok:
1046
+ print("⚠️ Warning: Failed to load rectified weights, will use baseline model without rectified noise")
1047
+ else:
1048
+ print("✓ Successfully loaded rectified noise weights")
1049
+
1050
+ # 组装 SD3WithRectifiedNoise
1051
+ # 关键:SD3WithRectifiedNoise 会保留 transformer 的引用
1052
+ # 但是,SD3WithRectifiedNoise 在 __init__ 中会冻结 transformer 参数
1053
+ # 这不应该影响 LoRA,因为 LoRA 是作为适配器添加的,不是原始参数
1054
+ # 我们需要确保在创建 SD3WithRectifiedNoise 之前,LoRA 适配器已经正确加载和启用
1055
+ if lora_loaded and rank == 0:
1056
+ print("Creating SD3WithRectifiedNoise with LoRA-enabled transformer...")
1057
+ elif rank == 0:
1058
+ print("Creating SD3WithRectifiedNoise...")
1059
+
1060
+ model = SD3WithRectifiedNoise(pipeline.transformer, rectified_module).to(device)
1061
+
1062
+ # 重要:SD3WithRectifiedNoise 的 __init__ 会冻结 transformer 参数
1063
+ # 但 LoRA 适配器应该仍然有效,因为它们是独立的模块
1064
+ # 我们需要确保 LoRA 适配器在包装后仍然可以访问
1065
+
1066
+ # 确保 LoRA 适配器在模型替换后仍然启用
1067
+ if lora_loaded:
1068
+ # 通过model.transformer访问,因为SD3WithRectifiedNoise包装了transformer
1069
+ if hasattr(model.transformer, 'peft_config'):
1070
+ try:
1071
+ # 确保适配器处于启用状态
1072
+ model.transformer.set_adapter("default_0")
1073
+
1074
+ # 验证LoRA权重在包装后是否仍然存在
1075
+ lora_layers_after_wrap = 0
1076
+ nonzero_after_wrap = 0
1077
+ for name, module in model.transformer.named_modules():
1078
+ if 'lora_A' in name or 'lora_B' in name:
1079
+ lora_layers_after_wrap += 1
1080
+ if hasattr(module, 'weight') and module.weight is not None:
1081
+ if module.weight.abs().sum().item() > 1e-6:
1082
+ nonzero_after_wrap += 1
1083
+
1084
+ if rank == 0:
1085
+ print(f"LoRA after SD3WithRectifiedNoise wrapping:")
1086
+ print(f" LoRA layers: {lora_layers_after_wrap}, Non-zero: {nonzero_after_wrap}")
1087
+ if nonzero_after_wrap == 0:
1088
+ print(" ❌ ERROR: All LoRA weights are zero after wrapping!")
1089
+ elif nonzero_after_wrap < lora_layers_after_wrap * 0.5:
1090
+ print(f" ⚠️ WARNING: Only {nonzero_after_wrap}/{lora_layers_after_wrap} LoRA layers have weights!")
1091
+ else:
1092
+ print(f" ✓ LoRA weights preserved after wrapping")
1093
+
1094
+ # 验证适配器是否真的启用
1095
+ if hasattr(model.transformer, 'active_adapters'):
1096
+ try:
1097
+ if callable(model.transformer.active_adapters):
1098
+ active = model.transformer.active_adapters()
1099
+ else:
1100
+ active = model.transformer.active_adapters
1101
+ if rank == 0:
1102
+ print(f" Active adapters: {active}")
1103
+ except:
1104
+ if rank == 0:
1105
+ print(" LoRA adapter re-enabled after model wrapping")
1106
+ else:
1107
+ if rank == 0:
1108
+ print(" LoRA adapter re-enabled after model wrapping")
1109
+ except Exception as e:
1110
+ if rank == 0:
1111
+ print(f"❌ ERROR: Could not re-enable LoRA adapter: {e}")
1112
+ import traceback
1113
+ traceback.print_exc()
1114
+ else:
1115
+ # LoRA权重已经合并到transformer的基础权重中(合并加载方式)
1116
+ # 这种情况下没有peft_config是正常的,因为LoRA已经合并了
1117
+ if rank == 0:
1118
+ print("LoRA loaded via merged weights (no PEFT adapter needed)")
1119
+ print(" ✓ LoRA weights are already merged into transformer base weights")
1120
+ print(" Note: This is expected when loading from merged checkpoint format")
1121
+
1122
+ # 注册到 pipeline(pipeline_stable_diffusion_3.py 已支持 external model)
1123
+ pipeline.model = model
1124
+
1125
+ # 确保模型处于评估模式(LoRA在eval模式下也应该工作)
1126
+ model.eval()
1127
+ model.transformer.eval() # 确保transformer也处于eval模式
1128
+ else:
1129
+ if rank == 0:
1130
+ print("Not using Rectified Noise module, using baseline SD3 pipeline")
1131
+ # 不使用 SD3WithRectifiedNoise,保持原始 pipeline
1132
+ # pipeline.model 保持为原始的 transformer
1133
+
1134
+ # 关键:确保LoRA适配器在推理时被使用
1135
+ # PEFT模型在eval模式下,LoRA适配器应该自动启用,但我们需要确保
1136
+ if lora_loaded:
1137
+ # 获取正确的 transformer 引用
1138
+ transformer_ref = model.transformer if use_rectified else pipeline.transformer
1139
+
1140
+ # 确保transformer的LoRA适配器处于启用状态
1141
+ if hasattr(transformer_ref, 'set_adapter'):
1142
+ try:
1143
+ transformer_ref.set_adapter("default")
1144
+ except:
1145
+ pass
1146
+
1147
+ # 验证LoRA是否真的会被使用
1148
+ if rank == 0:
1149
+ # 检查一个LoRA层的权重
1150
+ lora_found = False
1151
+ for name, module in transformer_ref.named_modules():
1152
+ if 'lora_A' in name and 'default' in name and hasattr(module, 'weight'):
1153
+ if module.weight is not None:
1154
+ weight_sum = module.weight.abs().sum().item()
1155
+ if weight_sum > 0:
1156
+ print(f"✓ Verified LoRA weight in {name}: sum={weight_sum:.6f}")
1157
+ lora_found = True
1158
+ break
1159
+
1160
+ if not lora_found:
1161
+ print("⚠ Warning: Could not verify LoRA weights in model")
1162
+ else:
1163
+ # 额外检查:验证LoRA层是否真的会被调用
1164
+ # 检查一个LoRA Linear层
1165
+ for name, module in transformer_ref.named_modules():
1166
+ if hasattr(module, '__class__') and 'lora' in module.__class__.__name__.lower():
1167
+ if hasattr(module, 'lora_enabled'):
1168
+ enabled = module.lora_enabled
1169
+ if rank == 0:
1170
+ print(f"✓ Found LoRA layer {name}, enabled: {enabled}")
1171
+ break
1172
+
1173
+ print("Model set to eval mode, LoRA should be active during inference")
1174
+
1175
+ # 启用内存优化选项
1176
+ if args.enable_attention_slicing:
1177
+ if rank == 0:
1178
+ print("Enabling attention slicing to save memory")
1179
+ pipeline.enable_attention_slicing()
1180
+
1181
+ if args.enable_vae_slicing:
1182
+ if rank == 0:
1183
+ print("Enabling VAE slicing to save memory")
1184
+ pipeline.enable_vae_slicing()
1185
+
1186
+ if args.enable_cpu_offload:
1187
+ if rank == 0:
1188
+ print("Enabling CPU offload to save memory")
1189
+ pipeline.enable_model_cpu_offload()
1190
+
1191
+ # 禁用进度条以减少输出
1192
+ pipeline.set_progress_bar_config(disable=True)
1193
+
1194
+ # 读入 captions
1195
+ captions = load_captions_from_jsonl(args.captions_jsonl)
1196
+ total_images_needed = min(len(captions) * args.images_per_caption, args.max_samples)
1197
+
1198
+ # 输出目录
1199
+ if rank == 0:
1200
+ os.makedirs(args.sample_dir, exist_ok=True)
1201
+ dist.barrier()
1202
+
1203
+ # 检查已存在的样本
1204
+ existing_count, max_existing_index = get_existing_sample_count(args.sample_dir)
1205
+ if rank == 0:
1206
+ print(f"Found {existing_count} existing samples, max index: {max_existing_index}")
1207
+
1208
+ # 调整需要生成的样本数量
1209
+ remaining_images_needed = max(0, total_images_needed - existing_count)
1210
+ if remaining_images_needed == 0:
1211
+ if rank == 0:
1212
+ print("All required samples already exist. Skipping generation.")
1213
+ print(f"Creating npz from existing samples...")
1214
+ create_npz_from_sample_folder(args.sample_dir, total_images_needed)
1215
+ return
1216
+
1217
+ if rank == 0:
1218
+ print(f"Need to generate {remaining_images_needed} more samples (total needed: {total_images_needed})")
1219
+
1220
+ n = args.per_proc_batch_size
1221
+ global_batch = n * world_size
1222
+ total_samples = int(math.ceil(remaining_images_needed / global_batch) * global_batch)
1223
+ assert total_samples % world_size == 0
1224
+ samples_per_gpu = total_samples // world_size
1225
+ assert samples_per_gpu % n == 0
1226
+ iterations = samples_per_gpu // n
1227
+
1228
+ if rank == 0:
1229
+ print(f"Sampling remaining={remaining_images_needed}, total_samples={total_samples}, per_gpu={samples_per_gpu}, iterations={iterations}")
1230
+
1231
+ pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
1232
+ saved = 0
1233
+
1234
+ autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
1235
+ for it in pbar:
1236
+ if rank == 0 and it % 10 == 0:
1237
+ print(f"[rank{rank}] Sampling iteration {it}/{iterations}")
1238
+ batch_prompts = []
1239
+ base_index = it * global_batch + rank
1240
+ for j in range(n):
1241
+ idx = it * global_batch + j * world_size + rank
1242
+ if idx < remaining_images_needed:
1243
+ cap_idx = idx // args.images_per_caption
1244
+ batch_prompts.append(captions[cap_idx])
1245
+ else:
1246
+ batch_prompts.append("a beautiful high quality image")
1247
+
1248
+ with torch.autocast(autocast_device, dtype=dtype):
1249
+ images = []
1250
+ for k, prompt in enumerate(batch_prompts):
1251
+ image_seed = seed + it * 10000 + k * 1000 + rank
1252
+ generator = torch.Generator(device=device).manual_seed(image_seed)
1253
+ img = pipeline(
1254
+ prompt=prompt,
1255
+ height=args.height,
1256
+ width=args.width,
1257
+ num_inference_steps=args.num_inference_steps,
1258
+ guidance_scale=args.guidance_scale,
1259
+ generator=generator,
1260
+ num_images_per_prompt=1,
1261
+ ).images[0]
1262
+ images.append(img)
1263
+
1264
+ # 保存
1265
+ out_dir = Path(args.sample_dir)
1266
+ if rank == 0 and it == 0:
1267
+ print(f"Saving pngs to: {out_dir}")
1268
+ for j, img in enumerate(images):
1269
+ global_index = it * global_batch + j * world_size + rank + existing_count # 加上已存在的数量
1270
+ if global_index < total_images_needed:
1271
+ filename = f"{global_index:07d}.png"
1272
+ img.save(out_dir / filename)
1273
+ saved += 1
1274
+ dist.barrier()
1275
+
1276
+ if rank == 0:
1277
+ print(f"Done. Saved {saved * world_size} images in total.")
1278
+ actual_num_samples = len([name for name in os.listdir(args.sample_dir) if name.endswith(".png")])
1279
+ print(f"Actually generated {actual_num_samples} images")
1280
+ npz_samples = min(actual_num_samples, total_images_needed)
1281
+ print(f"[rank{rank}] Creating npz from sample folder: {args.sample_dir}, npz_samples={npz_samples}")
1282
+ create_npz_from_sample_folder(args.sample_dir, npz_samples)
1283
+ print("Done creating npz.")
1284
+ print("Done.")
1285
+
1286
+
1287
+ if __name__ == "__main__":
1288
+ parser = argparse.ArgumentParser(description="SD3 LoRA + RectifiedNoise 分布式采样脚本")
1289
+ # 模型
1290
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
1291
+ parser.add_argument("--revision", type=str, default=None)
1292
+ parser.add_argument("--variant", type=str, default=None)
1293
+ # LoRA 与 Rectified
1294
+ parser.add_argument("--lora_path", type=str, default=None, help="LoRA 权重路径(文件或目录)")
1295
+ parser.add_argument("--rectified_weights", type=str, default=None, help="Rectified(SIT) 权重路径(文件或目录)")
1296
+ parser.add_argument("--num_sit_layers", type=int, default=1, help="与训练一致的 SIT 层数")
1297
+ # 采样
1298
+ parser.add_argument("--num_inference_steps", type=int, default=28)
1299
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
1300
+ parser.add_argument("--height", type=int, default=1024)
1301
+ parser.add_argument("--width", type=int, default=1024)
1302
+ parser.add_argument("--per_proc_batch_size", type=int, default=1)
1303
+ parser.add_argument("--images_per_caption", type=int, default=1)
1304
+ parser.add_argument("--max_samples", type=int, default=10000)
1305
+ parser.add_argument("--captions_jsonl", type=str, required=True)
1306
+ parser.add_argument("--sample_dir", type=str, default="sd3_rectified_samples")
1307
+ parser.add_argument("--global_seed", type=int, default=42)
1308
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
1309
+ # 内存优化选项
1310
+ parser.add_argument("--enable_attention_slicing", action="store_true", help="启用 attention slicing 以节省显存")
1311
+ parser.add_argument("--enable_vae_slicing", action="store_true", help="启用 VAE slicing 以节省显存")
1312
+ parser.add_argument("--enable_cpu_offload", action="store_true", help="启用 CPU offload 以节省显存")
1313
+
1314
+ args = parser.parse_args()
1315
+ main(args)
1316
+
sample_sd3_rectified_ddp_old.py ADDED
@@ -0,0 +1,1317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ 分布式采样脚本:支持指定 LoRA 权重与 Rectified Noise(SIT) 权重
5
+
6
+ 依据 train_rectified_noise.py 的模型结构,加载并组装 SD3WithRectifiedNoise 进行采样。
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import math
13
+ import argparse
14
+ from pathlib import Path
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+ from PIL import Image
21
+
22
+ from accelerate import Accelerator
23
+ from diffusers import StableDiffusion3Pipeline
24
+ from peft import LoraConfig, get_peft_model_state_dict
25
+ from peft.utils import set_peft_model_state_dict
26
+
27
+
28
+ def dynamic_import_training_classes(project_root: str):
29
+ """从 train_rectified_noise.py 动态导入 RectifiedNoiseModule 和 SD3WithRectifiedNoise"""
30
+ sys.path.insert(0, project_root)
31
+ try:
32
+ import train_rectified_noise as trn
33
+ return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
34
+ except Exception as e:
35
+ raise ImportError(f"无法从 train_rectified_noise.py 导入类: {e}")
36
+
37
+ def create_npz_from_sample_folder(sample_dir, num_samples):
38
+ """
39
+ 从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
40
+ """
41
+ samples = []
42
+ actual_files = []
43
+
44
+ # 收集所有PNG文件
45
+ for filename in sorted(os.listdir(sample_dir)):
46
+ if filename.endswith('.png'):
47
+ actual_files.append(filename)
48
+
49
+ # 按照数量限制处理
50
+ for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
51
+ if i < len(actual_files):
52
+ sample_path = os.path.join(sample_dir, actual_files[i])
53
+ sample_pil = Image.open(sample_path)
54
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
55
+ samples.append(sample_np)
56
+ else:
57
+ # 如果不够,创建空白图像
58
+ sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
59
+ samples.append(sample_np)
60
+
61
+ if samples:
62
+ samples = np.stack(samples)
63
+ npz_path = f"{sample_dir}.npz"
64
+ np.savez(npz_path, arr_0=samples)
65
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
66
+ return npz_path
67
+ else:
68
+ print("No samples found to create npz file.")
69
+ return None
70
+
71
+
72
+ def load_sit_weights(rectified_module, weights_path: str, rank=0):
73
+ """加载 Rectified Noise(SIT) 权重,支持 .safetensors / .bin / .pt
74
+ 支持以下目录结构:
75
+ - weights_path/pytorch_sit_weights.safetensors (直接在主目录)
76
+ - weights_path/sit_weights/pytorch_sit_weights.safetensors (在sit_weights子目录)
77
+ """
78
+ if os.path.isdir(weights_path):
79
+ # 首先尝试在主目录查找
80
+ search_paths = [
81
+ weights_path, # 主目录
82
+ os.path.join(weights_path, "sit_weights"), # sit_weights子目录
83
+ ]
84
+
85
+ for search_dir in search_paths:
86
+ if not os.path.exists(search_dir):
87
+ continue
88
+
89
+ # 优先寻找 safetensors
90
+ st_path = os.path.join(search_dir, "pytorch_sit_weights.safetensors")
91
+ if os.path.exists(st_path):
92
+ try:
93
+ from safetensors.torch import load_file
94
+ if rank == 0:
95
+ print(f"Loading rectified weights from: {st_path}")
96
+ state = load_file(st_path)
97
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
98
+ if rank == 0:
99
+ print(f" Loaded rectified weights: {len(state)} keys")
100
+ if missing_keys:
101
+ print(f" Missing keys: {len(missing_keys)}")
102
+ if unexpected_keys:
103
+ print(f" Unexpected keys: {len(unexpected_keys)}")
104
+ return True
105
+ except Exception as e:
106
+ if rank == 0:
107
+ print(f" Failed to load from {st_path}: {e}")
108
+ continue
109
+
110
+ # 其次寻找 bin/pt
111
+ for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
112
+ cand = os.path.join(search_dir, name)
113
+ if os.path.exists(cand):
114
+ try:
115
+ if rank == 0:
116
+ print(f"Loading rectified weights from: {cand}")
117
+ state = torch.load(cand, map_location="cpu")
118
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
119
+ if rank == 0:
120
+ print(f" Loaded rectified weights: {len(state)} keys")
121
+ if missing_keys:
122
+ print(f" Missing keys: {len(missing_keys)}")
123
+ if unexpected_keys:
124
+ print(f" Unexpected keys: {len(unexpected_keys)}")
125
+ return True
126
+ except Exception as e:
127
+ if rank == 0:
128
+ print(f" Failed to load from {cand}: {e}")
129
+ continue
130
+
131
+ # 兜底:目录下任意 pt/bin
132
+ try:
133
+ for fn in os.listdir(search_dir):
134
+ if fn.endswith((".pt", ".bin")):
135
+ cand = os.path.join(search_dir, fn)
136
+ try:
137
+ if rank == 0:
138
+ print(f"Loading rectified weights from: {cand}")
139
+ state = torch.load(cand, map_location="cpu")
140
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
141
+ if rank == 0:
142
+ print(f" Loaded rectified weights: {len(state)} keys")
143
+ return True
144
+ except Exception as e:
145
+ if rank == 0:
146
+ print(f" Failed to load from {cand}: {e}")
147
+ continue
148
+ except Exception:
149
+ pass
150
+
151
+ if rank == 0:
152
+ print(f" ❌ No rectified weights found in {weights_path} or {os.path.join(weights_path, 'sit_weights')}")
153
+ return False
154
+ else:
155
+ # 直接文件
156
+ try:
157
+ if rank == 0:
158
+ print(f"Loading rectified weights from file: {weights_path}")
159
+ if weights_path.endswith(".safetensors"):
160
+ from safetensors.torch import load_file
161
+ state = load_file(weights_path)
162
+ else:
163
+ state = torch.load(weights_path, map_location="cpu")
164
+ missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
165
+ if rank == 0:
166
+ print(f" Loaded rectified weights: {len(state)} keys")
167
+ if missing_keys:
168
+ print(f" Missing keys: {len(missing_keys)}")
169
+ if unexpected_keys:
170
+ print(f" Unexpected keys: {len(unexpected_keys)}")
171
+ return True
172
+ except Exception as e:
173
+ if rank == 0:
174
+ print(f" ❌ Failed to load rectified weights from {weights_path}: {e}")
175
+ return False
176
+
177
+
178
+ def check_lora_weights_exist(lora_path):
179
+ """检查LoRA权重文件是否存在"""
180
+ if not lora_path:
181
+ return False
182
+
183
+ if os.path.isdir(lora_path):
184
+ # 检查目录中是否有pytorch_lora_weights.safetensors文件
185
+ weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
186
+ if os.path.exists(weight_file):
187
+ return True
188
+ # 检查是否有其他.safetensors文件
189
+ for file in os.listdir(lora_path):
190
+ if file.endswith(".safetensors") and "lora" in file.lower():
191
+ return True
192
+ return False
193
+ elif os.path.isfile(lora_path):
194
+ return lora_path.endswith(".safetensors")
195
+
196
+ return False
197
+
198
+
199
+ def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0, lora_rank=64):
200
+ """
201
+ 从accelerator checkpoint目录加载LoRA权重或完整模型权重
202
+ 如果checkpoint包含完整的模型权重(合并后的),直接加载
203
+ 如果只包含LoRA权重,则按LoRA方式加载
204
+ """
205
+ if rank == 0:
206
+ print(f"Loading weights from accelerator checkpoint: {checkpoint_path}")
207
+
208
+ try:
209
+ from safetensors.torch import load_file
210
+ model_file = os.path.join(checkpoint_path, "model.safetensors")
211
+ if not os.path.exists(model_file):
212
+ if rank == 0:
213
+ print(f"Model file not found: {model_file}")
214
+ return False
215
+
216
+ # 加载state dict
217
+ state_dict = load_file(model_file)
218
+ all_keys = list(state_dict.keys())
219
+
220
+ # 检测checkpoint类型:
221
+ # 1. 是否包含base_layer(PEFT格式,需要合并)
222
+ # 2. 是否包含完整的模型权重(合并后的,直接可用)
223
+ # 3. 是否只包含LoRA权重(需要添加适配器)
224
+ lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()]
225
+ base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()]
226
+ non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()]
227
+
228
+ if rank == 0:
229
+ print(f"Checkpoint analysis:")
230
+ print(f" Total keys: {len(all_keys)}")
231
+ print(f" LoRA keys: {len(lora_keys)}")
232
+ print(f" Base layer keys: {len(base_layer_keys)}")
233
+ print(f" Direct transformer weight keys (merged): {len(non_lora_transformer_keys)}")
234
+
235
+ # 如果包含base_layer,说明是PEFT格式,需要合并base_layer + lora
236
+ if len(base_layer_keys) > 0:
237
+ if rank == 0:
238
+ print(f"✓ Detected PEFT format (base_layer + LoRA), merging weights...")
239
+
240
+ # 合并base_layer和lora权重
241
+ merged_state_dict = {}
242
+
243
+ # 首先收集所有需要合并的模块
244
+ modules_to_merge = {}
245
+ # 记录所有非LoRA的transformer权重键名(用于调试)
246
+ non_lora_keys_found = []
247
+
248
+ for key in all_keys:
249
+ # 移除前缀
250
+ new_key = key
251
+ has_transformer_prefix = False
252
+
253
+ if key.startswith('base_model.model.transformer.'):
254
+ new_key = key[len('base_model.model.transformer.'):]
255
+ has_transformer_prefix = True
256
+ elif key.startswith('model.transformer.'):
257
+ new_key = key[len('model.transformer.'):]
258
+ has_transformer_prefix = True
259
+ elif key.startswith('transformer.'):
260
+ new_key = key[len('transformer.'):]
261
+ has_transformer_prefix = True
262
+ elif 'transformer' in key.lower():
263
+ # 可能没有前缀,但包含transformer(如直接是transformer_blocks.0...)
264
+ has_transformer_prefix = True
265
+
266
+ if not has_transformer_prefix:
267
+ continue
268
+
269
+ # 检查是否是base_layer或lora权重
270
+ if '.base_layer.weight' in new_key:
271
+ # 提取模块名(去掉.base_layer.weight部分)
272
+ module_key = new_key.replace('.base_layer.weight', '.weight')
273
+ if module_key not in modules_to_merge:
274
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
275
+ modules_to_merge[module_key]['base_weight'] = (key, state_dict[key])
276
+ elif '.base_layer.bias' in new_key:
277
+ module_key = new_key.replace('.base_layer.bias', '.bias')
278
+ if module_key not in modules_to_merge:
279
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
280
+ modules_to_merge[module_key]['base_bias'] = (key, state_dict[key])
281
+ elif '.lora_A.default.weight' in new_key:
282
+ module_key = new_key.replace('.lora_A.default.weight', '.weight')
283
+ if module_key not in modules_to_merge:
284
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
285
+ modules_to_merge[module_key]['lora_A'] = (key, state_dict[key])
286
+ elif '.lora_B.default.weight' in new_key:
287
+ module_key = new_key.replace('.lora_B.default.weight', '.weight')
288
+ if module_key not in modules_to_merge:
289
+ modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
290
+ modules_to_merge[module_key]['lora_B'] = (key, state_dict[key])
291
+ elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower():
292
+ # 其他非LoRA权重(如pos_embed、time_text_embed、context_embedder等),直接使用
293
+ # 这些权重不在LoRA适配范围内,应该直接从checkpoint加载
294
+ merged_state_dict[new_key] = state_dict[key]
295
+ non_lora_keys_found.append(new_key)
296
+
297
+ if rank == 0:
298
+ print(f" Found {len(non_lora_keys_found)} non-LoRA transformer keys in checkpoint")
299
+ if non_lora_keys_found:
300
+ print(f" Sample non-LoRA keys: {non_lora_keys_found[:10]}")
301
+
302
+ # 合并权重:weight = base_weight + lora_B @ lora_A * (alpha / rank)
303
+ if rank == 0:
304
+ print(f" Merging {len(modules_to_merge)} modules...")
305
+
306
+ import torch
307
+ for module_key, weights in modules_to_merge.items():
308
+ # 处理权重(.weight)
309
+ if weights['base_weight'] is not None:
310
+ base_key, base_weight = weights['base_weight']
311
+ base_weight = base_weight.clone()
312
+
313
+ if weights['lora_A'] is not None and weights['lora_B'] is not None:
314
+ lora_A_key, lora_A = weights['lora_A']
315
+ lora_B_key, lora_B = weights['lora_B']
316
+
317
+ # 检测rank和alpha
318
+ # lora_A: [rank, in_features], lora_B: [out_features, rank]
319
+ rank_value = lora_A.shape[0]
320
+ alpha = rank_value # 通常alpha = rank
321
+
322
+ # 合并:weight = base + (lora_B @ lora_A) * (alpha / rank)
323
+ # lora_B @ lora_A 得到 [out_features, in_features]
324
+ lora_delta = torch.matmul(lora_B, lora_A)
325
+
326
+ if lora_delta.shape == base_weight.shape:
327
+ merged_weight = base_weight + lora_delta * (alpha / rank_value)
328
+ merged_state_dict[module_key] = merged_weight
329
+ if rank == 0 and len(modules_to_merge) <= 20:
330
+ print(f" ✓ Merged {module_key}: {base_weight.shape}")
331
+ else:
332
+ if rank == 0:
333
+ print(f" ⚠️ Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only")
334
+ merged_state_dict[module_key] = base_weight
335
+ else:
336
+ # 只有base权重,没有LoRA
337
+ merged_state_dict[module_key] = base_weight
338
+
339
+ # 处理bias(.bias)- bias通常不需要合并,直接使用base_bias
340
+ if '.bias' in module_key and weights['base_bias'] is not None:
341
+ bias_key, base_bias = weights['base_bias']
342
+ merged_state_dict[module_key] = base_bias.clone()
343
+
344
+ if rank == 0:
345
+ print(f" Merged {len(merged_state_dict)} weights")
346
+ print(f" Sample merged keys: {list(merged_state_dict.keys())[:5]}")
347
+
348
+ # 加载合并后的权重
349
+ try:
350
+ missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False)
351
+
352
+ if rank == 0:
353
+ print(f" Loaded merged weights:")
354
+ print(f" Missing keys: {len(missing_keys)}")
355
+ print(f" Unexpected keys: {len(unexpected_keys)}")
356
+ if missing_keys:
357
+ print(f" Missing keys: {missing_keys}")
358
+ # 检查缺失的keys是否关键
359
+ critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out']
360
+ has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys)
361
+ if has_critical:
362
+ print(f" ⚠️ WARNING: Missing critical keys! These should be loaded from pretrained model.")
363
+ print(f" The missing keys will use values from the pretrained model (not fine-tuned).")
364
+
365
+ # 如果缺失的keys太多或包含关键组件,给出警告
366
+ if len(missing_keys) > 0:
367
+ # 这些缺失的keys会使用pretrained model的默认值
368
+ # 这是正常的,因为LoRA只适配了部分层,其他层保持原样
369
+ if rank == 0:
370
+ print(f" Note: Missing keys will use pretrained model weights (not fine-tuned)")
371
+
372
+ if rank == 0:
373
+ print(f" ✓ Successfully loaded merged model weights")
374
+ return True
375
+
376
+ except Exception as e:
377
+ if rank == 0:
378
+ print(f" ❌ Error loading merged weights: {e}")
379
+ import traceback
380
+ traceback.print_exc()
381
+ return False
382
+
383
+ # 如果包含非LoRA的transformer权重(且没有base_layer),说明是合并后的完整模型
384
+ elif len(non_lora_transformer_keys) > 0:
385
+ if rank == 0:
386
+ print(f"✓ Detected merged model weights (contains full transformer weights)")
387
+ print(f" Loading full model weights directly...")
388
+
389
+ # 提取transformer相关的权重(包括LoRA和基础权重)
390
+ transformer_state_dict = {}
391
+ for key, value in state_dict.items():
392
+ # 移除可能的accelerator包装前缀
393
+ new_key = key
394
+ if key.startswith('base_model.model.transformer.'):
395
+ new_key = key[len('base_model.model.transformer.'):]
396
+ elif key.startswith('model.transformer.'):
397
+ new_key = key[len('model.transformer.'):]
398
+ elif key.startswith('transformer.'):
399
+ new_key = key[len('transformer.'):]
400
+
401
+ # 只保留transformer相关的权重(包括所有transformer子模块)
402
+ # 检查是否是transformer的权重(不包含text_encoder等)
403
+ if (new_key.startswith('transformer_blocks') or
404
+ new_key.startswith('pos_embed') or
405
+ new_key.startswith('time_text_embed') or
406
+ 'lora' in new_key.lower()): # 也包含LoRA权重(如果存在)
407
+ transformer_state_dict[new_key] = value
408
+
409
+ if rank == 0:
410
+ print(f" Extracted {len(transformer_state_dict)} transformer weight keys")
411
+ print(f" Sample keys: {list(transformer_state_dict.keys())[:5]}")
412
+
413
+ # 直接加载到transformer(不使用LoRA适配器)
414
+ try:
415
+ missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False)
416
+
417
+ if rank == 0:
418
+ print(f" Loaded full model weights:")
419
+ print(f" Missing keys: {len(missing_keys)}")
420
+ print(f" Unexpected keys: {len(unexpected_keys)}")
421
+ if missing_keys:
422
+ print(f" Sample missing keys: {missing_keys[:5]}")
423
+ if unexpected_keys:
424
+ print(f" Sample unexpected keys: {unexpected_keys[:5]}")
425
+
426
+ # 如果missing keys太多,可能有问题
427
+ if len(missing_keys) > len(transformer_state_dict) * 0.5:
428
+ if rank == 0:
429
+ print(f" ⚠️ WARNING: Too many missing keys, weights may not be fully loaded")
430
+ return False
431
+
432
+ if rank == 0:
433
+ print(f" ✓ Successfully loaded merged model weights")
434
+ return True
435
+
436
+ except Exception as e:
437
+ if rank == 0:
438
+ print(f" ❌ Error loading full model weights: {e}")
439
+ import traceback
440
+ traceback.print_exc()
441
+ return False
442
+
443
+ # 如果只包含LoRA权重,按原来的方式加载
444
+ if rank == 0:
445
+ print(f"Detected LoRA-only weights, loading as LoRA adapter...")
446
+
447
+ # 首先尝试从checkpoint中检测实际的rank
448
+ detected_rank = None
449
+ for key, value in state_dict.items():
450
+ if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
451
+ # lora_A的形状是 [rank, hidden_size]
452
+ detected_rank = value.shape[0]
453
+ if rank == 0:
454
+ print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
455
+ break
456
+
457
+ # 如果检测到rank,使用检测到的rank;否则使用传入的rank
458
+ actual_rank = detected_rank if detected_rank is not None else lora_rank
459
+ if detected_rank is not None and detected_rank != lora_rank:
460
+ if rank == 0:
461
+ print(f"⚠️ Warning: Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank")
462
+
463
+ # 检查适配器是否已存在,如果存在则先卸载
464
+ # SD3Transformer2DModel没有delete_adapter方法,需要使用unload_lora_weights
465
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
466
+ if "default" in pipeline.transformer.peft_config:
467
+ if rank == 0:
468
+ print("Removing existing 'default' adapter before adding new one...")
469
+ try:
470
+ # 使用pipeline的unload_lora_weights方法
471
+ pipeline.unload_lora_weights()
472
+ if rank == 0:
473
+ print("Successfully unloaded existing LoRA adapter")
474
+ except Exception as e:
475
+ if rank == 0:
476
+ print(f"❌ ERROR: Could not unload existing adapter: {e}")
477
+ print("Cannot proceed without cleaning up adapter")
478
+ return False
479
+
480
+ # 先配置LoRA适配器(必须在加载之前配置)
481
+ # 使用检测到的或传入的rank
482
+ transformer_lora_config = LoraConfig(
483
+ r=actual_rank,
484
+ lora_alpha=actual_rank,
485
+ init_lora_weights="gaussian",
486
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
487
+ )
488
+
489
+ # 为transformer添加LoRA适配器
490
+ pipeline.transformer.add_adapter(transformer_lora_config)
491
+
492
+ if rank == 0:
493
+ print(f"LoRA adapter configured with rank={actual_rank}")
494
+
495
+ # 继续处理LoRA权重加载(state_dict已经在上面加载了)
496
+
497
+ # 提取LoRA权重 - accelerator保存的格式
498
+ # 从accelerator checkpoint的model.safetensors中,键名格式可能是:
499
+ # - transformer_blocks.X.attn.to_q.lora_A.default.weight (PEFT格式,直接可用)
500
+ # - 或者包含其他前缀
501
+ lora_state_dict = {}
502
+ for key, value in state_dict.items():
503
+ if 'lora' in key.lower() and 'transformer' in key.lower():
504
+ # 检查键名格式
505
+ new_key = key
506
+
507
+ # 移除可能的accelerator包装前缀
508
+ # accelerator可能保存为: model.transformer.transformer_blocks...
509
+ # 或者: base_model.model.transformer.transformer_blocks...
510
+ if key.startswith('base_model.model.transformer.'):
511
+ new_key = key[len('base_model.model.transformer.'):]
512
+ elif key.startswith('model.transformer.'):
513
+ new_key = key[len('model.transformer.'):]
514
+ elif key.startswith('transformer.'):
515
+ # 如果已经是transformer_blocks开头,不需要移除transformer.前缀
516
+ # 因为transformer_blocks是transformer的子模块
517
+ if not key[len('transformer.'):].startswith('transformer_blocks'):
518
+ new_key = key[len('transformer.'):]
519
+ else:
520
+ new_key = key[len('transformer.'):]
521
+
522
+ # 只保留transformer相关的LoRA权重
523
+ if 'transformer_blocks' in new_key or 'transformer' in new_key:
524
+ lora_state_dict[new_key] = value
525
+
526
+ if not lora_state_dict:
527
+ if rank == 0:
528
+ print("No LoRA weights found in checkpoint")
529
+ # 打印所有键名用于调试
530
+ all_keys = list(state_dict.keys())
531
+ print(f"Total keys: {len(all_keys)}")
532
+ print(f"First 20 keys: {all_keys[:20]}")
533
+ # 查找包含lora的键
534
+ lora_related = [k for k in all_keys if 'lora' in k.lower()]
535
+ if lora_related:
536
+ print(f"Keys containing 'lora': {lora_related[:10]}")
537
+ return False
538
+
539
+ if rank == 0:
540
+ print(f"Found {len(lora_state_dict)} LoRA weight keys")
541
+ sample_keys = list(lora_state_dict.keys())[:5]
542
+ print(f"Sample LoRA keys: {sample_keys}")
543
+
544
+ # 加载LoRA权重到transformer
545
+ # 注意:从checkpoint提取的键名格式已经是PEFT格式(如:transformer_blocks.0.attn.to_q.lora_A.default.weight)
546
+ # 不需要使用convert_unet_state_dict_to_peft转换,直接使用即可
547
+ try:
548
+ # 检查键名格式
549
+ sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
550
+
551
+ if rank == 0:
552
+ print(f"Original key format: {sample_key}")
553
+
554
+ # 关键问题:set_peft_model_state_dict期望的键名格式
555
+ # 从back/train_dreambooth_lora.py看,需要移除.default后缀
556
+ # 格式应该是:transformer_blocks.X.attn.to_q.lora_A.weight(没有.default)
557
+ # 但accelerator保存的格式是:transformer_blocks.X.attn.to_q.lora_A.default.weight(有.default)
558
+
559
+ # 检查键名格式
560
+ sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
561
+ has_default_suffix = '.default.weight' in sample_key or '.default.bias' in sample_key
562
+
563
+ if rank == 0:
564
+ print(f"Sample key: {sample_key}")
565
+ print(f"Has .default suffix: {has_default_suffix}")
566
+
567
+ # 如果键名包含.default.weight或.default.bias,需要移除.default部分
568
+ # 因为set_peft_model_state_dict期望的格式是:lora_A.weight,而不是lora_A.default.weight
569
+ converted_dict = {}
570
+ for key, value in lora_state_dict.items():
571
+ # 移除.default后缀(如果存在)
572
+ # transformer_blocks.0.attn.to_q.lora_A.default.weight -> transformer_blocks.0.attn.to_q.lora_A.weight
573
+ new_key = key
574
+ if '.default.weight' in new_key:
575
+ new_key = new_key.replace('.default.weight', '.weight')
576
+ elif '.default.bias' in new_key:
577
+ new_key = new_key.replace('.default.bias', '.bias')
578
+ elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')):
579
+ # 处理其他可能的.default位置
580
+ new_key = new_key.replace('.default', '')
581
+
582
+ converted_dict[new_key] = value
583
+
584
+ if rank == 0:
585
+ print(f"Converted {len(converted_dict)} keys (removed .default suffix if present)")
586
+ print(f"Sample converted keys: {list(converted_dict.keys())[:5]}")
587
+
588
+ # 调用set_peft_model_state_dict并检查返回值
589
+ incompatible_keys = set_peft_model_state_dict(
590
+ pipeline.transformer,
591
+ converted_dict,
592
+ adapter_name="default"
593
+ )
594
+
595
+ # 检查加载结果
596
+ if incompatible_keys is not None:
597
+ missing_keys = getattr(incompatible_keys, "missing_keys", [])
598
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", [])
599
+
600
+ if rank == 0:
601
+ print(f"LoRA loading result:")
602
+ print(f" Missing keys: {len(missing_keys)}")
603
+ print(f" Unexpected keys: {len(unexpected_keys)}")
604
+
605
+ if len(missing_keys) > 100:
606
+ print(f" ⚠️ WARNING: Too many missing keys ({len(missing_keys)}), LoRA may not be fully loaded!")
607
+ print(f" Sample missing keys: {missing_keys[:10]}")
608
+ elif missing_keys:
609
+ print(f" Sample missing keys: {missing_keys[:10]}")
610
+
611
+ if unexpected_keys:
612
+ print(f" Unexpected keys: {unexpected_keys[:10]}")
613
+
614
+ # 如果missing keys太多,说明加载失败
615
+ if len(missing_keys) > len(converted_dict) * 0.5: # 超过50%的键缺失
616
+ if rank == 0:
617
+ print("❌ ERROR: Too many missing keys, LoRA weights not loaded correctly!")
618
+ return False
619
+ else:
620
+ if rank == 0:
621
+ print("✓ LoRA weights loaded (no incompatible keys reported)")
622
+
623
+ except RuntimeError as e:
624
+ # 检查是否是size mismatch错误
625
+ error_str = str(e)
626
+ if "size mismatch" in error_str:
627
+ if rank == 0:
628
+ print(f"❌ Size mismatch error: The checkpoint rank doesn't match the adapter rank")
629
+ print(f" This usually means the checkpoint was trained with a different rank")
630
+ # 尝试从错误信息中提取期望的rank
631
+ import re
632
+ # 错误信息格式: "copying a param with shape torch.Size([32, 1536]) from checkpoint"
633
+ match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str)
634
+ if match:
635
+ checkpoint_rank = int(match.group(1))
636
+ if rank == 0:
637
+ print(f" Detected checkpoint rank: {checkpoint_rank}")
638
+ print(f" Adapter was configured with rank: {actual_rank}")
639
+ if checkpoint_rank != actual_rank:
640
+ print(f" ⚠️ Mismatch! Need to recreate adapter with rank={checkpoint_rank}")
641
+ else:
642
+ if rank == 0:
643
+ print(f"❌ Error setting LoRA state dict: {e}")
644
+ import traceback
645
+ traceback.print_exc()
646
+ # 清理适配器以便下次尝试
647
+ try:
648
+ pipeline.unload_lora_weights()
649
+ except:
650
+ pass
651
+ return False
652
+ except Exception as e:
653
+ if rank == 0:
654
+ print(f"❌ Error setting LoRA state dict: {e}")
655
+ import traceback
656
+ traceback.print_exc()
657
+ # 清理适配器以便下次尝试
658
+ try:
659
+ pipeline.unload_lora_weights()
660
+ except:
661
+ pass
662
+ return False
663
+
664
+ # 启用LoRA适配器
665
+ pipeline.transformer.set_adapter("default")
666
+
667
+ # 验证LoRA是否已加载和应用
668
+ if hasattr(pipeline.transformer, 'peft_config'):
669
+ adapters = list(pipeline.transformer.peft_config.keys())
670
+ if rank == 0:
671
+ print(f"LoRA adapters configured: {adapters}")
672
+ # 检查适配器是否启用
673
+ if hasattr(pipeline.transformer, 'active_adapters'):
674
+ # active_adapters 是一个方法,需要调用
675
+ try:
676
+ if callable(pipeline.transformer.active_adapters):
677
+ active = pipeline.transformer.active_adapters()
678
+ else:
679
+ active = pipeline.transformer.active_adapters
680
+ if rank == 0:
681
+ print(f"Active adapters: {active}")
682
+ except:
683
+ if rank == 0:
684
+ print("Could not get active adapters, but LoRA is configured")
685
+
686
+ # 验证LoRA权重是否真的被应用
687
+ # 检查LoRA层的权重是否非零
688
+ lora_layers_found = 0
689
+ nonzero_lora_layers = 0
690
+ total_lora_weight_sum = 0.0
691
+
692
+ for name, module in pipeline.transformer.named_modules():
693
+ if 'lora_A' in name or 'lora_B' in name:
694
+ lora_layers_found += 1
695
+ if hasattr(module, 'weight') and module.weight is not None:
696
+ weight_sum = module.weight.abs().sum().item()
697
+ total_lora_weight_sum += weight_sum
698
+ if weight_sum > 1e-6: # 非零阈值
699
+ nonzero_lora_layers += 1
700
+ if rank == 0 and nonzero_lora_layers <= 3: # 只打印前3个
701
+ print(f"✓ Found non-zero LoRA weight in: {name}, sum={weight_sum:.6f}")
702
+
703
+ if rank == 0:
704
+ print(f"LoRA verification:")
705
+ print(f" Total LoRA layers found: {lora_layers_found}")
706
+ print(f" Non-zero LoRA layers: {nonzero_lora_layers}")
707
+ print(f" Total LoRA weight sum: {total_lora_weight_sum:.6f}")
708
+
709
+ if lora_layers_found == 0:
710
+ print("❌ ERROR: No LoRA layers found in transformer!")
711
+ return False
712
+ elif nonzero_lora_layers == 0:
713
+ print("❌ ERROR: All LoRA weights are zero, LoRA not loaded correctly!")
714
+ return False
715
+ elif nonzero_lora_layers < lora_layers_found * 0.5:
716
+ print(f"⚠️ WARNING: Only {nonzero_lora_layers}/{lora_layers_found} LoRA layers have non-zero weights!")
717
+ print("⚠️ LoRA may not be fully applied!")
718
+ else:
719
+ print(f"✓ LoRA weights verified: {nonzero_lora_layers}/{lora_layers_found} layers have non-zero weights")
720
+
721
+ if nonzero_lora_layers == 0:
722
+ return False
723
+
724
+ if rank == 0:
725
+ print("✓ Successfully loaded and verified LoRA weights from checkpoint")
726
+
727
+ return True
728
+
729
+ except Exception as e:
730
+ if rank == 0:
731
+ print(f"Error loading LoRA from checkpoint: {e}")
732
+ import traceback
733
+ traceback.print_exc()
734
+ return False
735
+
736
+
737
+ def load_captions_from_jsonl(jsonl_path):
738
+ captions = []
739
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
740
+ for line in f:
741
+ line = line.strip()
742
+ if not line:
743
+ continue
744
+ try:
745
+ data = json.loads(line)
746
+ cap = None
747
+ for field in ['caption', 'text', 'prompt', 'description']:
748
+ if field in data and isinstance(data[field], str):
749
+ cap = data[field].strip()
750
+ break
751
+ if cap:
752
+ captions.append(cap)
753
+ except Exception:
754
+ continue
755
+ return captions if captions else ["a beautiful high quality image"]
756
+
757
+
758
+ def main(args):
759
+ assert torch.cuda.is_available(), "需要GPU运行"
760
+ dist.init_process_group("nccl")
761
+ rank = dist.get_rank()
762
+ world_size = dist.get_world_size()
763
+ device = rank % torch.cuda.device_count()
764
+ torch.cuda.set_device(device)
765
+ seed = args.global_seed * world_size + rank
766
+ torch.manual_seed(seed)
767
+
768
+ # 调试:打印接收到的参数
769
+ if rank == 0:
770
+ print("=" * 80)
771
+ print("参数检查:")
772
+ print(f" lora_path: {args.lora_path}")
773
+ print(f" rectified_weights: {args.rectified_weights}")
774
+ print(f" lora_path is None: {args.lora_path is None}")
775
+ print(f" lora_path is empty: {args.lora_path == '' if args.lora_path else 'N/A'}")
776
+ print(f" rectified_weights is None: {args.rectified_weights is None}")
777
+ print(f" rectified_weights is empty: {args.rectified_weights == '' if args.rectified_weights else 'N/A'}")
778
+ print("=" * 80)
779
+
780
+ lora_source = "baseline"
781
+
782
+ # 导入训练脚本中的类
783
+ RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
784
+
785
+ # 加载 pipeline
786
+ dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
787
+ if rank == 0:
788
+ print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path} (dtype={dtype})")
789
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
790
+ args.pretrained_model_name_or_path,
791
+ revision=args.revision,
792
+ variant=args.variant,
793
+ torch_dtype=dtype,
794
+ ).to(device)
795
+
796
+ # 加载 LoRA(可选)
797
+ lora_loaded = False
798
+ if args.lora_path:
799
+ if rank == 0:
800
+ print(f"Attempting to load LoRA weights from: {args.lora_path}")
801
+ print(f"LoRA path exists: {os.path.exists(args.lora_path) if args.lora_path else False}")
802
+
803
+ # 首先检查是否是标准的LoRA权重文件/目录
804
+ if check_lora_weights_exist(args.lora_path):
805
+ if rank == 0:
806
+ print("Found standard LoRA weights, loading...")
807
+ try:
808
+ # 检查加载前的transformer参数(用于验证)
809
+ if rank == 0:
810
+ sample_param_before = next(iter(pipeline.transformer.parameters())).clone()
811
+ print(f"Sample transformer param before LoRA (first 5 values): {sample_param_before.flatten()[:5]}")
812
+
813
+ pipeline.load_lora_weights(args.lora_path)
814
+ lora_loaded = True
815
+ lora_source = os.path.basename(args.lora_path.rstrip('/'))
816
+
817
+ # 验证LoRA是否真的被加载
818
+ if rank == 0:
819
+ sample_param_after = next(iter(pipeline.transformer.parameters())).clone()
820
+ param_diff = (sample_param_after - sample_param_before).abs().max().item()
821
+ print(f"Sample transformer param after LoRA (first 5 values): {sample_param_after.flatten()[:5]}")
822
+ print(f"Max parameter change after LoRA loading: {param_diff}")
823
+ if param_diff < 1e-6:
824
+ print("⚠️ WARNING: LoRA weights may not have been applied (parameter change is very small)")
825
+ else:
826
+ print("✓ LoRA weights appear to have been applied")
827
+
828
+ # 检查是否有peft_config
829
+ if hasattr(pipeline.transformer, 'peft_config'):
830
+ print(f"✓ PEFT config found: {list(pipeline.transformer.peft_config.keys())}")
831
+ else:
832
+ print("⚠️ WARNING: No peft_config found after loading LoRA")
833
+
834
+ if rank == 0:
835
+ print("LoRA loaded successfully from standard format.")
836
+ except Exception as e:
837
+ if rank == 0:
838
+ print(f"Failed to load LoRA from standard format: {e}")
839
+ import traceback
840
+ traceback.print_exc()
841
+
842
+ # 如果不是标准格式,尝试从accelerator checkpoint加载
843
+ if not lora_loaded and os.path.isdir(args.lora_path):
844
+ if rank == 0:
845
+ print("Standard LoRA weights not found, trying accelerator checkpoint format...")
846
+
847
+ # 首先尝试从checkpoint的model.safetensors中检测实际的rank
848
+ # 通过检查LoRA权重的形状来推断rank
849
+ detected_rank = None
850
+ try:
851
+ from safetensors.torch import load_file
852
+ model_file = os.path.join(args.lora_path, "model.safetensors")
853
+ if os.path.exists(model_file):
854
+ state_dict = load_file(model_file)
855
+ # 查找一个LoRA权重来确定rank
856
+ for key, value in state_dict.items():
857
+ if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
858
+ # lora_A的形状是 [rank, hidden_size]
859
+ detected_rank = value.shape[0]
860
+ if rank == 0:
861
+ print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
862
+ break
863
+ except Exception as e:
864
+ if rank == 0:
865
+ print(f"Could not detect rank from checkpoint: {e}")
866
+
867
+ # 构建rank尝试列表
868
+ # 如果检测到rank,优先使用检测到的rank,只尝试一次
869
+ # 如果未检测到,尝试常见的rank值
870
+ if detected_rank is not None:
871
+ rank_list = [detected_rank]
872
+ if rank == 0:
873
+ print(f"Using detected rank: {detected_rank}")
874
+ else:
875
+ # 如果检测失败,尝试常见的rank值(按用户指定的rank优先)
876
+ rank_list = []
877
+ # 如果用户指定了rank(从args.lora_rank),优先尝试
878
+ if hasattr(args, 'lora_rank') and args.lora_rank:
879
+ rank_list.append(args.lora_rank)
880
+ # 添加其他常见的rank值
881
+ for r in [32, 64, 16, 128]:
882
+ if r not in rank_list:
883
+ rank_list.append(r)
884
+ if rank == 0:
885
+ print(f"Rank detection failed, will try ranks in order: {rank_list}")
886
+
887
+ # 尝试不同的rank值
888
+ for lora_rank in rank_list:
889
+ # 在尝试新的rank之前,先清理已存在的适配器
890
+ # 重要:每次尝试前都要清理,否则适配器会保留之前的rank配置
891
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
892
+ if "default" in pipeline.transformer.peft_config:
893
+ try:
894
+ # 使用pipeline的unload_lora_weights方法
895
+ pipeline.unload_lora_weights()
896
+ if rank == 0:
897
+ print(f"Cleaned up existing adapter before trying rank={lora_rank}")
898
+ except Exception as e:
899
+ if rank == 0:
900
+ print(f"Warning: Could not unload adapter: {e}")
901
+ # 如果卸载失败,需要重新创建pipeline
902
+ if rank == 0:
903
+ print("⚠️ WARNING: Cannot unload adapter, will recreate pipeline...")
904
+ # 重新加载pipeline(最后手段)
905
+ try:
906
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
907
+ args.pretrained_model_name_or_path,
908
+ revision=args.revision,
909
+ variant=args.variant,
910
+ torch_dtype=dtype,
911
+ ).to(device)
912
+ if rank == 0:
913
+ print("Pipeline recreated to clear adapter state")
914
+ except Exception as e2:
915
+ if rank == 0:
916
+ print(f"Failed to recreate pipeline: {e2}")
917
+
918
+ if rank == 0:
919
+ print(f"Trying to load with LoRA rank={lora_rank}...")
920
+ lora_loaded = load_lora_from_checkpoint(pipeline, args.lora_path, rank=rank, lora_rank=lora_rank)
921
+ if lora_loaded:
922
+ if rank == 0:
923
+ print(f"✓ Successfully loaded LoRA with rank={lora_rank}")
924
+ lora_source = "checkpoint"
925
+ break
926
+ elif rank == 0:
927
+ print(f"✗ Failed to load with rank={lora_rank}, trying next rank...")
928
+
929
+ # 如果checkpoint目录加载失败,尝试从输出目录的根目录加载标准LoRA权重
930
+ if not lora_loaded and os.path.isdir(args.lora_path):
931
+ # 检查输出目录的根目录(checkpoint的父目录)
932
+ output_dir = os.path.dirname(args.lora_path.rstrip('/'))
933
+ if output_dir and os.path.exists(output_dir):
934
+ if rank == 0:
935
+ print(f"Trying to load standard LoRA weights from output directory: {output_dir}")
936
+ if check_lora_weights_exist(output_dir):
937
+ try:
938
+ pipeline.load_lora_weights(output_dir)
939
+ lora_loaded = True
940
+ if rank == 0:
941
+ print("LoRA loaded successfully from output directory.")
942
+ except Exception as e:
943
+ if rank == 0:
944
+ print(f"Failed to load LoRA from output directory: {e}")
945
+
946
+ if not lora_loaded:
947
+ if rank == 0:
948
+ print(f"⚠️ WARNING: Failed to load LoRA weights from {args.lora_path}, using baseline model")
949
+ else:
950
+ # 最终验证LoRA是否真的被启用
951
+ if rank == 0:
952
+ print("=" * 80)
953
+ print("LoRA 加载验证:")
954
+ if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
955
+ print(f" ✓ PEFT config exists: {list(pipeline.transformer.peft_config.keys())}")
956
+ # 检查LoRA层的权重
957
+ lora_layers_found = 0
958
+ for name, module in pipeline.transformer.named_modules():
959
+ if 'lora_A' in name or 'lora_B' in name:
960
+ lora_layers_found += 1
961
+ if lora_layers_found <= 3: # 只打印前3个
962
+ if hasattr(module, 'weight'):
963
+ weight_sum = module.weight.abs().sum().item() if module.weight is not None else 0
964
+ print(f" ✓ Found LoRA layer: {name}, weight_sum={weight_sum:.6f}")
965
+ print(f" ✓ Total LoRA layers found: {lora_layers_found}")
966
+ if lora_layers_found == 0:
967
+ print(" ⚠️ WARNING: No LoRA layers found in transformer!")
968
+ else:
969
+ print(" ⚠️ WARNING: No PEFT config found - LoRA may not be active!")
970
+ print("=" * 80)
971
+
972
+ # 构建 RectifiedNoiseModule 并加载权重(仅在提供了 rectified_weights 时)
973
+ # 安全地检查 rectified_weights 是否有效
974
+ use_rectified = False
975
+ rectified_weights_path = None
976
+ if args.rectified_weights:
977
+ rectified_weights_str = str(args.rectified_weights).strip()
978
+ if rectified_weights_str:
979
+ use_rectified = True
980
+ rectified_weights_path = rectified_weights_str
981
+
982
+ if rank == 0:
983
+ print(f"use_rectified: {use_rectified}, rectified_weights_path: {rectified_weights_path}")
984
+
985
+ if use_rectified:
986
+ if rank == 0:
987
+ print(f"Using Rectified Noise module with weights from: {rectified_weights_path}")
988
+
989
+ # 从 transformer 配置推断必要尺寸
990
+ tfm = pipeline.transformer
991
+ if hasattr(tfm.config, 'joint_attention_dim') and tfm.config.joint_attention_dim is not None:
992
+ sit_hidden_size = tfm.config.joint_attention_dim
993
+ elif hasattr(tfm.config, 'inner_dim') and tfm.config.inner_dim is not None:
994
+ sit_hidden_size = tfm.config.inner_dim
995
+ elif hasattr(tfm.config, 'hidden_size') and tfm.config.hidden_size is not None:
996
+ sit_hidden_size = tfm.config.hidden_size
997
+ else:
998
+ sit_hidden_size = 4096
999
+
1000
+ transformer_hidden_size = getattr(tfm.config, 'hidden_size', 1536)
1001
+ num_attention_heads = getattr(tfm.config, 'num_attention_heads', 32)
1002
+ input_dim = getattr(tfm.config, 'in_channels', 16)
1003
+
1004
+ rectified_module = RectifiedNoiseModule(
1005
+ hidden_size=sit_hidden_size,
1006
+ num_sit_layers=args.num_sit_layers,
1007
+ num_attention_heads=num_attention_heads,
1008
+ input_dim=input_dim,
1009
+ transformer_hidden_size=transformer_hidden_size,
1010
+ )
1011
+ # 加载 SIT 权重
1012
+ ok = load_sit_weights(rectified_module, rectified_weights_path, rank=rank)
1013
+ if rank == 0:
1014
+ if not ok:
1015
+ print("⚠️ Warning: Failed to load rectified weights, will use baseline model without rectified noise")
1016
+ else:
1017
+ print("✓ Successfully loaded rectified noise weights")
1018
+
1019
+ # 组装 SD3WithRectifiedNoise
1020
+ # 关键:SD3WithRectifiedNoise 会保留 transformer 的引用
1021
+ # 但是,SD3WithRectifiedNoise 在 __init__ 中会冻结 transformer 参数
1022
+ # 这不应该影响 LoRA,因为 LoRA 是作为适配器添加的,不是原始参数
1023
+ # 我们需要确保在创建 SD3WithRectifiedNoise 之前,LoRA 适配器已经正确加载和启用
1024
+ if lora_loaded and rank == 0:
1025
+ print("Creating SD3WithRectifiedNoise with LoRA-enabled transformer...")
1026
+ elif rank == 0:
1027
+ print("Creating SD3WithRectifiedNoise...")
1028
+
1029
+ model = SD3WithRectifiedNoise(pipeline.transformer, rectified_module).to(device)
1030
+
1031
+ # 重要:SD3WithRectifiedNoise 的 __init__ 会冻结 transformer 参数
1032
+ # 但 LoRA 适配器应该仍然有效,因为它们是独立的模块
1033
+ # 我们需要确保 LoRA 适配器在包装后仍然可以访问
1034
+
1035
+ # 确保 LoRA 适配器在模型替换后仍然启用
1036
+ if lora_loaded:
1037
+ # 通过model.transformer访问,因为SD3WithRectifiedNoise包装了transformer
1038
+ if hasattr(model.transformer, 'peft_config'):
1039
+ try:
1040
+ # 确保适配器处于启用状态
1041
+ model.transformer.set_adapter("default_0")
1042
+
1043
+ # 验证LoRA权重在包装后是否仍然存在
1044
+ lora_layers_after_wrap = 0
1045
+ nonzero_after_wrap = 0
1046
+ for name, module in model.transformer.named_modules():
1047
+ if 'lora_A' in name or 'lora_B' in name:
1048
+ lora_layers_after_wrap += 1
1049
+ if hasattr(module, 'weight') and module.weight is not None:
1050
+ if module.weight.abs().sum().item() > 1e-6:
1051
+ nonzero_after_wrap += 1
1052
+
1053
+ if rank == 0:
1054
+ print(f"LoRA after SD3WithRectifiedNoise wrapping:")
1055
+ print(f" LoRA layers: {lora_layers_after_wrap}, Non-zero: {nonzero_after_wrap}")
1056
+ if nonzero_after_wrap == 0:
1057
+ print(" ❌ ERROR: All LoRA weights are zero after wrapping!")
1058
+ elif nonzero_after_wrap < lora_layers_after_wrap * 0.5:
1059
+ print(f" ⚠️ WARNING: Only {nonzero_after_wrap}/{lora_layers_after_wrap} LoRA layers have weights!")
1060
+ else:
1061
+ print(f" ✓ LoRA weights preserved after wrapping")
1062
+
1063
+ # 验证适配器是否真的启用
1064
+ if hasattr(model.transformer, 'active_adapters'):
1065
+ try:
1066
+ if callable(model.transformer.active_adapters):
1067
+ active = model.transformer.active_adapters()
1068
+ else:
1069
+ active = model.transformer.active_adapters
1070
+ if rank == 0:
1071
+ print(f" Active adapters: {active}")
1072
+ except:
1073
+ if rank == 0:
1074
+ print(" LoRA adapter re-enabled after model wrapping")
1075
+ else:
1076
+ if rank == 0:
1077
+ print(" LoRA adapter re-enabled after model wrapping")
1078
+ except Exception as e:
1079
+ if rank == 0:
1080
+ print(f"❌ ERROR: Could not re-enable LoRA adapter: {e}")
1081
+ import traceback
1082
+ traceback.print_exc()
1083
+ else:
1084
+ # LoRA权重已经合并到transformer的基础权重中(合并加载方式)
1085
+ # 这种情况下没有peft_config是正常的,因为LoRA已经合并了
1086
+ if rank == 0:
1087
+ print("LoRA loaded via merged weights (no PEFT adapter needed)")
1088
+ print(" ✓ LoRA weights are already merged into transformer base weights")
1089
+ print(" Note: This is expected when loading from merged checkpoint format")
1090
+
1091
+ # 注册到 pipeline(pipeline_stable_diffusion_3.py 已支持 external model)
1092
+ pipeline.model = model
1093
+
1094
+ # 确保模型处于评估模式(LoRA在eval模式下也应该工作)
1095
+ model.eval()
1096
+ model.transformer.eval() # 确保transformer也处于eval模式
1097
+ else:
1098
+ if rank == 0:
1099
+ print("Not using Rectified Noise module, using baseline SD3 pipeline")
1100
+ # 不使用 SD3WithRectifiedNoise,保持原始 pipeline
1101
+ # pipeline.model 保持为原始的 transformer
1102
+
1103
+ # 关键:确保LoRA适配器在推理时被使用
1104
+ # PEFT模型在eval模式下,LoRA适配器应该自动启用,但我们需要确保
1105
+ if lora_loaded:
1106
+ # 获取正确的 transformer 引用
1107
+ transformer_ref = model.transformer if use_rectified else pipeline.transformer
1108
+
1109
+ # 确保transformer的LoRA适配器处于启用状态
1110
+ if hasattr(transformer_ref, 'set_adapter'):
1111
+ try:
1112
+ transformer_ref.set_adapter("default")
1113
+ except:
1114
+ pass
1115
+
1116
+ # 验证LoRA是否真的会被使用
1117
+ if rank == 0:
1118
+ # 检查一个LoRA层的权重
1119
+ lora_found = False
1120
+ for name, module in transformer_ref.named_modules():
1121
+ if 'lora_A' in name and 'default' in name and hasattr(module, 'weight'):
1122
+ if module.weight is not None:
1123
+ weight_sum = module.weight.abs().sum().item()
1124
+ if weight_sum > 0:
1125
+ print(f"✓ Verified LoRA weight in {name}: sum={weight_sum:.6f}")
1126
+ lora_found = True
1127
+ break
1128
+
1129
+ if not lora_found:
1130
+ print("⚠ Warning: Could not verify LoRA weights in model")
1131
+ else:
1132
+ # 额外检查:验证LoRA层是否真的会被调用
1133
+ # 检查一个LoRA Linear层
1134
+ for name, module in transformer_ref.named_modules():
1135
+ if hasattr(module, '__class__') and 'lora' in module.__class__.__name__.lower():
1136
+ if hasattr(module, 'lora_enabled'):
1137
+ enabled = module.lora_enabled
1138
+ if rank == 0:
1139
+ print(f"✓ Found LoRA layer {name}, enabled: {enabled}")
1140
+ break
1141
+
1142
+ print("Model set to eval mode, LoRA should be active during inference")
1143
+
1144
+ # 启用内存优化选项
1145
+ if args.enable_attention_slicing:
1146
+ enable_attention_slicing_method = getattr(pipeline, 'enable_attention_slicing', None)
1147
+ if enable_attention_slicing_method is not None and callable(enable_attention_slicing_method):
1148
+ try:
1149
+ if rank == 0:
1150
+ print("Enabling attention slicing to save memory")
1151
+ enable_attention_slicing_method()
1152
+ except Exception as e:
1153
+ if rank == 0:
1154
+ print(f"Warning: Failed to enable attention slicing: {e}")
1155
+ else:
1156
+ if rank == 0:
1157
+ print("Warning: Attention slicing not available for this pipeline")
1158
+
1159
+ if args.enable_vae_slicing:
1160
+ # 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
1161
+ enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
1162
+ if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
1163
+ try:
1164
+ if rank == 0:
1165
+ print("Enabling VAE slicing to save memory")
1166
+ enable_vae_slicing_method()
1167
+ except Exception as e:
1168
+ if rank == 0:
1169
+ print(f"Warning: Failed to enable VAE slicing: {e}")
1170
+ else:
1171
+ if rank == 0:
1172
+ print("Warning: VAE slicing not available for this pipeline (SD3 may not support this)")
1173
+
1174
+ if args.enable_cpu_offload:
1175
+ if rank == 0:
1176
+ print("Enabling CPU offload to save memory")
1177
+ pipeline.enable_model_cpu_offload()
1178
+
1179
+ # 禁用进度条以减少输出
1180
+ pipeline.set_progress_bar_config(disable=True)
1181
+
1182
+ # 读入 captions
1183
+ captions = load_captions_from_jsonl(args.captions_jsonl)
1184
+ total_images_needed = min(len(captions) * args.images_per_caption, args.max_samples)
1185
+
1186
+ # 生成caption和image的映射列表
1187
+ caption_image_pairs = []
1188
+ for i, caption in enumerate(captions):
1189
+ for j in range(args.images_per_caption):
1190
+ caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
1191
+
1192
+ # 输出目录
1193
+ folder_name = f"sd3-rectified-{lora_source}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
1194
+ sample_folder_dir = os.path.join(args.sample_dir, folder_name)
1195
+
1196
+ if rank == 0:
1197
+ os.makedirs(sample_folder_dir, exist_ok=True)
1198
+ print(f"Saving .png samples at {sample_folder_dir}")
1199
+ # 清空caption文件
1200
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
1201
+ if os.path.exists(caption_file):
1202
+ os.remove(caption_file)
1203
+ dist.barrier()
1204
+
1205
+ n = args.per_proc_batch_size
1206
+ global_batch = n * world_size
1207
+ total_samples = int(math.ceil(total_images_needed / global_batch) * global_batch)
1208
+ assert total_samples % world_size == 0
1209
+ samples_per_gpu = total_samples // world_size
1210
+ assert samples_per_gpu % n == 0
1211
+ iterations = samples_per_gpu // n
1212
+
1213
+ if rank == 0:
1214
+ print(f"Sampling total={total_samples}, per_gpu={samples_per_gpu}, iterations={iterations}")
1215
+
1216
+ pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
1217
+ saved = 0
1218
+
1219
+ autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
1220
+ for it in pbar:
1221
+ # 获取这个batch对应的caption
1222
+ batch_prompts = []
1223
+ batch_caption_info = []
1224
+
1225
+ for j in range(n):
1226
+ global_index = it * global_batch + j * world_size + rank
1227
+ if global_index < len(caption_image_pairs):
1228
+ caption, caption_idx, image_idx = caption_image_pairs[global_index]
1229
+ batch_prompts.append(caption)
1230
+ batch_caption_info.append((caption, caption_idx, image_idx))
1231
+ else:
1232
+ # 如果超出范围,使用最后一个caption
1233
+ if caption_image_pairs:
1234
+ caption, caption_idx, image_idx = caption_image_pairs[-1]
1235
+ batch_prompts.append(caption)
1236
+ batch_caption_info.append((caption, caption_idx, image_idx))
1237
+ else:
1238
+ batch_prompts.append("a beautiful high quality image")
1239
+ batch_caption_info.append(("a beautiful high quality image", 0, 0))
1240
+
1241
+ with torch.autocast(autocast_device, dtype=dtype):
1242
+ images = []
1243
+ for k, prompt in enumerate(batch_prompts):
1244
+ image_seed = seed + it * 10000 + k * 1000 + rank
1245
+ generator = torch.Generator(device=device).manual_seed(image_seed)
1246
+ img = pipeline(
1247
+ prompt=prompt,
1248
+ height=args.height,
1249
+ width=args.width,
1250
+ num_inference_steps=args.num_inference_steps,
1251
+ guidance_scale=args.guidance_scale,
1252
+ generator=generator,
1253
+ num_images_per_prompt=1,
1254
+ ).images[0]
1255
+ images.append(img)
1256
+
1257
+ # 保存
1258
+ for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
1259
+ global_index = it * global_batch + j * world_size + rank
1260
+ if global_index < len(caption_image_pairs):
1261
+ # 保存图片,文件名包含caption索引和图片索引
1262
+ filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
1263
+ image_path = os.path.join(sample_folder_dir, filename)
1264
+ image.save(image_path)
1265
+
1266
+ # 保存caption信息到文本文件(只在rank 0上操作)
1267
+ if rank == 0:
1268
+ caption_file = os.path.join(sample_folder_dir, "captions.txt")
1269
+ with open(caption_file, "a", encoding="utf-8") as f:
1270
+ f.write(f"{filename}\t{caption}\n")
1271
+
1272
+ total_generated = saved * world_size # 近似值
1273
+
1274
+ dist.barrier()
1275
+
1276
+ if rank == 0:
1277
+ print(f"Done. Saved {saved * world_size} images in total.")
1278
+ # 重新计算实际生成的图片数量
1279
+ actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
1280
+ print(f"Actually generated {actual_num_samples} images")
1281
+ # 使用实际的图片数量或用户指定的数量,取较小值
1282
+ npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
1283
+ create_npz_from_sample_folder(sample_folder_dir, npz_samples)
1284
+ print("Done.")
1285
+
1286
+ dist.barrier()
1287
+ dist.destroy_process_group()
1288
+ parser = argparse.ArgumentParser(description="SD3 LoRA + RectifiedNoise 分布式采样脚本")
1289
+ # 模型
1290
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
1291
+ parser.add_argument("--revision", type=str, default=None)
1292
+ parser.add_argument("--variant", type=str, default=None)
1293
+ # LoRA 与 Rectified
1294
+ parser.add_argument("--lora_path", type=str, default=None, help="LoRA 权重路径(文件或目录)")
1295
+ parser.add_argument("--rectified_weights", type=str, default=None, help="Rectified(SIT) 权重路径(文件或目录)")
1296
+ parser.add_argument("--num_sit_layers", type=int, default=1, help="与训练一致的 SIT 层数")
1297
+ # 采样
1298
+ parser.add_argument("--num_inference_steps", type=int, default=28)
1299
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
1300
+ parser.add_argument("--height", type=int, default=1024)
1301
+ parser.add_argument("--width", type=int, default=1024)
1302
+ parser.add_argument("--per_proc_batch_size", type=int, default=1)
1303
+ parser.add_argument("--images_per_caption", type=int, default=1)
1304
+ parser.add_argument("--max_samples", type=int, default=10000)
1305
+ parser.add_argument("--captions_jsonl", type=str, required=True)
1306
+ parser.add_argument("--sample_dir", type=str, default="sd3_rectified_samples")
1307
+ parser.add_argument("--global_seed", type=int, default=42)
1308
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
1309
+ # 内存优化选项
1310
+ parser.add_argument("--enable_attention_slicing", action="store_true", help="启用 attention slicing 以节省显存")
1311
+ parser.add_argument("--enable_vae_slicing", action="store_true", help="启用 VAE slicing 以节省显存")
1312
+ parser.add_argument("--enable_cpu_offload", action="store_true", help="启用 CPU offload 以节省显存")
1313
+
1314
+ args = parser.parse_args()
1315
+ main(args)
1316
+
1317
+
sd3_rectified_samples_batch2_2200005011.01.01.0cfg_cond_true.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Inception Score: 37.646392822265625
2
+ FID: 21.19386100577333
3
+ sFID: 71.79977998851734
4
+ Precision: 0.690407122136641
5
+ Recall: 0.358997247638176
train_lora_sd3.py ADDED
@@ -0,0 +1,1597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """SD3 LoRA fine-tuning script for text2image generation."""
17
+
18
+ import argparse
19
+ import copy
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import datasets
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+ import transformers
35
+ from accelerate import Accelerator
36
+ from accelerate.logging import get_logger
37
+ from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
38
+ from datasets import load_dataset
39
+ from huggingface_hub import create_repo, upload_folder
40
+ from packaging import version
41
+ from peft import LoraConfig, set_peft_model_state_dict
42
+ from peft.utils import get_peft_model_state_dict
43
+ from PIL import Image
44
+ from torchvision import transforms
45
+ from torchvision.transforms.functional import crop
46
+ from tqdm.auto import tqdm
47
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
48
+
49
+ import diffusers
50
+ from diffusers import (
51
+ AutoencoderKL,
52
+ FlowMatchEulerDiscreteScheduler,
53
+ SD3Transformer2DModel,
54
+ StableDiffusion3Pipeline,
55
+ )
56
+ from diffusers.optimization import get_scheduler
57
+ from diffusers.training_utils import (
58
+ _set_state_dict_into_text_encoder,
59
+ cast_training_params,
60
+ compute_density_for_timestep_sampling,
61
+ compute_loss_weighting_for_sd3,
62
+ free_memory,
63
+ )
64
+ from diffusers.utils import (
65
+ check_min_version,
66
+ convert_unet_state_dict_to_peft,
67
+ is_wandb_available,
68
+ )
69
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
70
+ from diffusers.utils.torch_utils import is_compiled_module
71
+
72
+ if is_wandb_available():
73
+ import wandb
74
+
75
+ # Check minimum diffusers version
76
+ check_min_version("0.30.0")
77
+
78
+ logger = get_logger(__name__)
79
+
80
+
81
+ def save_model_card(
82
+ repo_id: str,
83
+ images: list = None,
84
+ base_model: str = None,
85
+ dataset_name: str = None,
86
+ train_text_encoder: bool = False,
87
+ repo_folder: str = None,
88
+ vae_path: str = None,
89
+ ):
90
+ """Save model card for SD3 LoRA model."""
91
+ img_str = ""
92
+ if images is not None:
93
+ for i, image in enumerate(images):
94
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
95
+ img_str += f"![img_{i}](./image_{i}.png)\n"
96
+
97
+ model_description = f"""
98
+ # SD3 LoRA text2image fine-tuning - {repo_id}
99
+
100
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
101
+ {img_str}
102
+
103
+ LoRA for the text encoder was enabled: {train_text_encoder}.
104
+
105
+ Special VAE used for training: {vae_path}.
106
+ """
107
+ model_card = load_or_create_model_card(
108
+ repo_id_or_path=repo_id,
109
+ from_training=True,
110
+ license="other",
111
+ base_model=base_model,
112
+ model_description=model_description,
113
+ inference=True,
114
+ )
115
+
116
+ tags = [
117
+ "stable-diffusion-3",
118
+ "stable-diffusion-3-diffusers",
119
+ "text-to-image",
120
+ "diffusers",
121
+ "diffusers-training",
122
+ "lora",
123
+ "sd3",
124
+ ]
125
+ model_card = populate_model_card(model_card, tags=tags)
126
+ model_card.save(os.path.join(repo_folder, "README.md"))
127
+
128
+
129
+ def log_validation(
130
+ pipeline,
131
+ args,
132
+ accelerator,
133
+ epoch,
134
+ is_final_validation=False,
135
+ global_step=None,
136
+ ):
137
+ """Run validation and log images."""
138
+ logger.info(
139
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
140
+ f" {args.validation_prompt}."
141
+ )
142
+ pipeline = pipeline.to(accelerator.device)
143
+ pipeline.set_progress_bar_config(disable=True)
144
+
145
+ # run inference
146
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
147
+ pipeline_args = {"prompt": args.validation_prompt}
148
+
149
+ if torch.backends.mps.is_available():
150
+ autocast_ctx = nullcontext()
151
+ else:
152
+ autocast_ctx = torch.autocast(accelerator.device.type)
153
+
154
+ with autocast_ctx:
155
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
156
+
157
+ # Save images to output directory
158
+ if accelerator.is_main_process:
159
+ validation_dir = os.path.join(args.output_dir, "validation_images")
160
+ os.makedirs(validation_dir, exist_ok=True)
161
+ for i, image in enumerate(images):
162
+ # Create filename with step and epoch information
163
+ if global_step is not None:
164
+ filename = f"validation_step_{global_step}_epoch_{epoch}_img_{i}.png"
165
+ else:
166
+ filename = f"validation_epoch_{epoch}_img_{i}.png"
167
+
168
+ image_path = os.path.join(validation_dir, filename)
169
+ image.save(image_path)
170
+ logger.info(f"Saved validation image: {image_path}")
171
+
172
+ for tracker in accelerator.trackers if hasattr(accelerator, 'trackers') and accelerator.trackers else []:
173
+ phase_name = "test" if is_final_validation else "validation"
174
+ try:
175
+ if tracker.name == "tensorboard":
176
+ np_images = np.stack([np.asarray(img) for img in images])
177
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
178
+ if tracker.name == "wandb":
179
+ tracker.log(
180
+ {
181
+ phase_name: [
182
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
183
+ ]
184
+ }
185
+ )
186
+ except Exception as e:
187
+ logger.warning(f"Failed to log to {tracker.name}: {e}")
188
+
189
+ del pipeline
190
+ free_memory()
191
+ return images
192
+
193
+
194
+ def import_model_class_from_model_name_or_path(
195
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
196
+ ):
197
+ """Import the correct text encoder class."""
198
+ text_encoder_config = PretrainedConfig.from_pretrained(
199
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
200
+ )
201
+ model_class = text_encoder_config.architectures[0]
202
+
203
+ if model_class == "CLIPTextModelWithProjection":
204
+ from transformers import CLIPTextModelWithProjection
205
+ return CLIPTextModelWithProjection
206
+ elif model_class == "T5EncoderModel":
207
+ from transformers import T5EncoderModel
208
+ return T5EncoderModel
209
+ else:
210
+ raise ValueError(f"{model_class} is not supported.")
211
+
212
+
213
+ def parse_args(input_args=None):
214
+ """Parse command line arguments."""
215
+ parser = argparse.ArgumentParser(description="SD3 LoRA training script.")
216
+
217
+ # Model arguments
218
+ parser.add_argument(
219
+ "--pretrained_model_name_or_path",
220
+ type=str,
221
+ default=None,
222
+ required=True,
223
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
224
+ )
225
+ parser.add_argument(
226
+ "--revision",
227
+ type=str,
228
+ default=None,
229
+ help="Revision of pretrained model identifier from huggingface.co/models.",
230
+ )
231
+ parser.add_argument(
232
+ "--variant",
233
+ type=str,
234
+ default=None,
235
+ help="Variant of the model files, e.g. fp16",
236
+ )
237
+
238
+ # Dataset arguments
239
+ parser.add_argument(
240
+ "--dataset_name",
241
+ type=str,
242
+ default=None,
243
+ help="The name of the Dataset to train on.",
244
+ )
245
+ parser.add_argument(
246
+ "--dataset_config_name",
247
+ type=str,
248
+ default=None,
249
+ help="The config of the Dataset.",
250
+ )
251
+ parser.add_argument(
252
+ "--train_data_dir",
253
+ type=str,
254
+ default=None,
255
+ help="A folder containing the training data.",
256
+ )
257
+ parser.add_argument(
258
+ "--image_column",
259
+ type=str,
260
+ default="image",
261
+ help="The column of the dataset containing an image."
262
+ )
263
+ parser.add_argument(
264
+ "--caption_column",
265
+ type=str,
266
+ default="caption",
267
+ help="The column of the dataset containing a caption.",
268
+ )
269
+
270
+ # Training arguments
271
+ parser.add_argument(
272
+ "--max_sequence_length",
273
+ type=int,
274
+ default=77,
275
+ help="Maximum sequence length to use with the T5 text encoder",
276
+ )
277
+ parser.add_argument(
278
+ "--validation_prompt",
279
+ type=str,
280
+ default=None,
281
+ help="A prompt used during validation.",
282
+ )
283
+ parser.add_argument(
284
+ "--num_validation_images",
285
+ type=int,
286
+ default=4,
287
+ help="Number of images for validation.",
288
+ )
289
+ parser.add_argument(
290
+ "--validation_epochs",
291
+ type=int,
292
+ default=1,
293
+ help="Run validation every X epochs.",
294
+ )
295
+ parser.add_argument(
296
+ "--max_train_samples",
297
+ type=int,
298
+ default=None,
299
+ help="Truncate the number of training examples.",
300
+ )
301
+ parser.add_argument(
302
+ "--output_dir",
303
+ type=str,
304
+ default="sd3-lora-finetuned",
305
+ help="Output directory for model predictions and checkpoints.",
306
+ )
307
+ parser.add_argument(
308
+ "--cache_dir",
309
+ type=str,
310
+ default=None,
311
+ help="Directory to store downloaded models and datasets.",
312
+ )
313
+ parser.add_argument(
314
+ "--seed",
315
+ type=int,
316
+ default=None,
317
+ help="A seed for reproducible training."
318
+ )
319
+ parser.add_argument(
320
+ "--resolution",
321
+ type=int,
322
+ default=1024,
323
+ help="Image resolution for training.",
324
+ )
325
+ parser.add_argument(
326
+ "--center_crop",
327
+ default=False,
328
+ action="store_true",
329
+ help="Whether to center crop input images.",
330
+ )
331
+ parser.add_argument(
332
+ "--random_flip",
333
+ action="store_true",
334
+ help="Whether to randomly flip images horizontally.",
335
+ )
336
+ parser.add_argument(
337
+ "--train_text_encoder",
338
+ action="store_true",
339
+ help="Whether to train the text encoder.",
340
+ )
341
+ parser.add_argument(
342
+ "--train_batch_size",
343
+ type=int,
344
+ default=16,
345
+ help="Batch size for training dataloader."
346
+ )
347
+ parser.add_argument(
348
+ "--num_train_epochs",
349
+ type=int,
350
+ default=100
351
+ )
352
+ parser.add_argument(
353
+ "--max_train_steps",
354
+ type=int,
355
+ default=None,
356
+ help="Total number of training steps.",
357
+ )
358
+ parser.add_argument(
359
+ "--checkpointing_steps",
360
+ type=int,
361
+ default=500,
362
+ help="Save checkpoint every X updates.",
363
+ )
364
+ parser.add_argument(
365
+ "--checkpoints_total_limit",
366
+ type=int,
367
+ default=None,
368
+ help="Max number of checkpoints to store.",
369
+ )
370
+ parser.add_argument(
371
+ "--resume_from_checkpoint",
372
+ type=str,
373
+ default=None,
374
+ help="Path to resume training from checkpoint.",
375
+ )
376
+ parser.add_argument(
377
+ "--gradient_accumulation_steps",
378
+ type=int,
379
+ default=1,
380
+ help="Number of update steps to accumulate.",
381
+ )
382
+ parser.add_argument(
383
+ "--gradient_checkpointing",
384
+ action="store_true",
385
+ help="Use gradient checkpointing to save memory.",
386
+ )
387
+ parser.add_argument(
388
+ "--learning_rate",
389
+ type=float,
390
+ default=1e-4,
391
+ help="Initial learning rate.",
392
+ )
393
+ parser.add_argument(
394
+ "--scale_lr",
395
+ action="store_true",
396
+ default=False,
397
+ help="Scale learning rate by number of GPUs, etc.",
398
+ )
399
+ parser.add_argument(
400
+ "--lr_scheduler",
401
+ type=str,
402
+ default="constant",
403
+ help="Learning rate scheduler type.",
404
+ )
405
+ parser.add_argument(
406
+ "--lr_warmup_steps",
407
+ type=int,
408
+ default=500,
409
+ help="Number of warmup steps."
410
+ )
411
+
412
+ # SD3 specific arguments
413
+ parser.add_argument(
414
+ "--weighting_scheme",
415
+ type=str,
416
+ default="logit_normal",
417
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
418
+ help="Weighting scheme for flow matching loss.",
419
+ )
420
+ parser.add_argument(
421
+ "--logit_mean",
422
+ type=float,
423
+ default=0.0,
424
+ help="Mean for logit_normal weighting."
425
+ )
426
+ parser.add_argument(
427
+ "--logit_std",
428
+ type=float,
429
+ default=1.0,
430
+ help="Std for logit_normal weighting."
431
+ )
432
+ parser.add_argument(
433
+ "--mode_scale",
434
+ type=float,
435
+ default=1.29,
436
+ help="Scale for mode weighting scheme.",
437
+ )
438
+ parser.add_argument(
439
+ "--precondition_outputs",
440
+ type=int,
441
+ default=1,
442
+ help="Whether to precondition model outputs.",
443
+ )
444
+
445
+ # Optimization arguments
446
+ parser.add_argument(
447
+ "--allow_tf32",
448
+ action="store_true",
449
+ help="Allow TF32 on Ampere GPUs.",
450
+ )
451
+ parser.add_argument(
452
+ "--dataloader_num_workers",
453
+ type=int,
454
+ default=0,
455
+ help="Number of data loading workers.",
456
+ )
457
+ parser.add_argument(
458
+ "--use_8bit_adam",
459
+ action="store_true",
460
+ help="Use 8-bit Adam optimizer."
461
+ )
462
+ parser.add_argument(
463
+ "--adam_beta1",
464
+ type=float,
465
+ default=0.9,
466
+ help="Beta1 for Adam optimizer."
467
+ )
468
+ parser.add_argument(
469
+ "--adam_beta2",
470
+ type=float,
471
+ default=0.999,
472
+ help="Beta2 for Adam optimizer."
473
+ )
474
+ parser.add_argument(
475
+ "--adam_weight_decay",
476
+ type=float,
477
+ default=1e-2,
478
+ help="Weight decay for Adam."
479
+ )
480
+ parser.add_argument(
481
+ "--adam_epsilon",
482
+ type=float,
483
+ default=1e-08,
484
+ help="Epsilon for Adam optimizer."
485
+ )
486
+ parser.add_argument(
487
+ "--max_grad_norm",
488
+ default=1.0,
489
+ type=float,
490
+ help="Max gradient norm."
491
+ )
492
+
493
+ # Hub and logging arguments
494
+ parser.add_argument(
495
+ "--push_to_hub",
496
+ action="store_true",
497
+ help="Push model to the Hub."
498
+ )
499
+ parser.add_argument(
500
+ "--hub_token",
501
+ type=str,
502
+ default=None,
503
+ help="Token for Model Hub."
504
+ )
505
+ parser.add_argument(
506
+ "--hub_model_id",
507
+ type=str,
508
+ default=None,
509
+ help="Repository name for the Hub.",
510
+ )
511
+ parser.add_argument(
512
+ "--logging_dir",
513
+ type=str,
514
+ default="logs",
515
+ help="TensorBoard log directory.",
516
+ )
517
+ parser.add_argument(
518
+ "--report_to",
519
+ type=str,
520
+ default="tensorboard",
521
+ help="Logging integration to use.",
522
+ )
523
+ parser.add_argument(
524
+ "--mixed_precision",
525
+ type=str,
526
+ default=None,
527
+ choices=["no", "fp16", "bf16"],
528
+ help="Mixed precision type.",
529
+ )
530
+ parser.add_argument(
531
+ "--local_rank",
532
+ type=int,
533
+ default=-1,
534
+ help="Local rank for distributed training."
535
+ )
536
+
537
+ # LoRA arguments
538
+ parser.add_argument(
539
+ "--rank",
540
+ type=int,
541
+ default=64,
542
+ help="LoRA rank dimension.",
543
+ )
544
+
545
+ if input_args is not None:
546
+ args = parser.parse_args(input_args)
547
+ else:
548
+ args = parser.parse_args()
549
+
550
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
551
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
552
+ args.local_rank = env_local_rank
553
+
554
+ # Sanity checks
555
+ if args.dataset_name is None and args.train_data_dir is None:
556
+ raise ValueError("Need either a dataset name or a training folder.")
557
+
558
+ return args
559
+
560
+
561
+ DATASET_NAME_MAPPING = {
562
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
563
+ }
564
+
565
+
566
+ def tokenize_prompt(tokenizer, prompt):
567
+ """Tokenize prompt using the given tokenizer."""
568
+ text_inputs = tokenizer(
569
+ prompt,
570
+ padding="max_length",
571
+ max_length=77,
572
+ truncation=True,
573
+ return_tensors="pt",
574
+ )
575
+ return text_inputs.input_ids
576
+
577
+
578
+ def _encode_prompt_with_t5(
579
+ text_encoder,
580
+ tokenizer,
581
+ max_sequence_length,
582
+ prompt=None,
583
+ num_images_per_prompt=1,
584
+ device=None,
585
+ text_input_ids=None,
586
+ ):
587
+ """Encode prompt using T5 text encoder."""
588
+ if prompt is not None:
589
+ prompt = [prompt] if isinstance(prompt, str) else prompt
590
+ batch_size = len(prompt)
591
+ else:
592
+ # When prompt is None, we must have text_input_ids
593
+ if text_input_ids is None:
594
+ raise ValueError("Either prompt or text_input_ids must be provided")
595
+ batch_size = text_input_ids.shape[0]
596
+
597
+ if tokenizer is not None and prompt is not None:
598
+ text_inputs = tokenizer(
599
+ prompt,
600
+ padding="max_length",
601
+ max_length=max_sequence_length,
602
+ truncation=True,
603
+ add_special_tokens=True,
604
+ return_tensors="pt",
605
+ )
606
+ text_input_ids = text_inputs.input_ids
607
+ else:
608
+ if text_input_ids is None:
609
+ raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
610
+
611
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
612
+ dtype = text_encoder.dtype
613
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
614
+
615
+ _, seq_len, _ = prompt_embeds.shape
616
+ # duplicate text embeddings for each generation per prompt
617
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
618
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
619
+
620
+ return prompt_embeds
621
+
622
+
623
+ def _encode_prompt_with_clip(
624
+ text_encoder,
625
+ tokenizer,
626
+ prompt: str,
627
+ device=None,
628
+ text_input_ids=None,
629
+ num_images_per_prompt: int = 1,
630
+ ):
631
+ """Encode prompt using CLIP text encoder."""
632
+ if prompt is not None:
633
+ prompt = [prompt] if isinstance(prompt, str) else prompt
634
+ batch_size = len(prompt)
635
+ else:
636
+ # When prompt is None, we must have text_input_ids
637
+ if text_input_ids is None:
638
+ raise ValueError("Either prompt or text_input_ids must be provided")
639
+ batch_size = text_input_ids.shape[0]
640
+
641
+ if tokenizer is not None and prompt is not None:
642
+ text_inputs = tokenizer(
643
+ prompt,
644
+ padding="max_length",
645
+ max_length=77,
646
+ truncation=True,
647
+ return_tensors="pt",
648
+ )
649
+ text_input_ids = text_inputs.input_ids
650
+ else:
651
+ if text_input_ids is None:
652
+ raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
653
+
654
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
655
+ pooled_prompt_embeds = prompt_embeds[0]
656
+ prompt_embeds = prompt_embeds.hidden_states[-2]
657
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
658
+
659
+ _, seq_len, _ = prompt_embeds.shape
660
+ # duplicate text embeddings for each generation per prompt
661
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
662
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
663
+
664
+ return prompt_embeds, pooled_prompt_embeds
665
+
666
+
667
+ def encode_prompt(
668
+ text_encoders,
669
+ tokenizers,
670
+ prompt: str,
671
+ max_sequence_length,
672
+ device=None,
673
+ num_images_per_prompt: int = 1,
674
+ text_input_ids_list=None,
675
+ ):
676
+ """Encode prompt using all three text encoders (SD3 architecture)."""
677
+ if prompt is not None:
678
+ prompt = [prompt] if isinstance(prompt, str) else prompt
679
+
680
+ # Process CLIP encoders (first two)
681
+ clip_tokenizers = tokenizers[:2]
682
+ clip_text_encoders = text_encoders[:2]
683
+
684
+ clip_prompt_embeds_list = []
685
+ clip_pooled_prompt_embeds_list = []
686
+
687
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
688
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
689
+ text_encoder=text_encoder,
690
+ tokenizer=tokenizer,
691
+ prompt=prompt,
692
+ device=device if device is not None else text_encoder.device,
693
+ num_images_per_prompt=num_images_per_prompt,
694
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
695
+ )
696
+ clip_prompt_embeds_list.append(prompt_embeds)
697
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
698
+
699
+ # Concatenate CLIP embeddings
700
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
701
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
702
+
703
+ # Process T5 encoder (third encoder)
704
+ t5_prompt_embed = _encode_prompt_with_t5(
705
+ text_encoders[-1],
706
+ tokenizers[-1],
707
+ max_sequence_length,
708
+ prompt=prompt,
709
+ num_images_per_prompt=num_images_per_prompt,
710
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
711
+ device=device if device is not None else text_encoders[-1].device,
712
+ )
713
+
714
+ # Pad CLIP embeddings to match T5 embedding dimension
715
+ clip_prompt_embeds = torch.nn.functional.pad(
716
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
717
+ )
718
+ # Concatenate all embeddings
719
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
720
+
721
+ return prompt_embeds, pooled_prompt_embeds
722
+
723
+
724
+ def load_dataset_from_jsonl(metadata_path, data_dir, accelerator=None):
725
+ """
726
+ 从 metadata.jsonl 文件加载数据集,避免扫描所有文件。
727
+ 这对于大型数据集在分布式训练中非常重要。
728
+
729
+ 注意:只让主进程读取 jsonl 文件,然后创建数据集。
730
+ 其他进程会等待主进程完成后再继续。
731
+
732
+ Args:
733
+ metadata_path: metadata.jsonl 文件路径
734
+ data_dir: 数据集根目录
735
+ accelerator: Accelerator 对象,用于多进程同步
736
+
737
+ Returns:
738
+ datasets.DatasetDict
739
+ """
740
+ if accelerator is None or accelerator.is_main_process:
741
+ print(f"[INFO] Loading dataset from metadata.jsonl: {metadata_path}", flush=True)
742
+
743
+ # 读取 metadata.jsonl(只让主进程读取,避免多进程竞争)
744
+ data_list = []
745
+ if os.path.exists(metadata_path):
746
+ with open(metadata_path, 'r', encoding='utf-8') as f:
747
+ for line_num, line in enumerate(f):
748
+ try:
749
+ item = json.loads(line.strip())
750
+ file_name = item.get('file_name', '')
751
+ caption = item.get('caption', '')
752
+
753
+ # 构建完整路径
754
+ image_path = os.path.join(data_dir, file_name)
755
+
756
+ # 注意:这里不检查文件是否存在,因为:
757
+ # 1. 检查会非常慢(需要访问文件系统)
758
+ # 2. 在 DataLoader 中加载时会自然处理不存在的文件
759
+ # 3. 可以大大加快数据集加载速度
760
+ data_list.append({
761
+ 'image': image_path,
762
+ 'text': caption
763
+ })
764
+
765
+ # 每处理 100000 条记录打印一次进度(减少打印频率)
766
+ if (line_num + 1) % 100000 == 0 and (accelerator is None or accelerator.is_main_process):
767
+ print(f"[INFO] Processed {line_num + 1} entries from metadata.jsonl", flush=True)
768
+
769
+ except json.JSONDecodeError as e:
770
+ if accelerator is None or accelerator.is_main_process:
771
+ print(f"[WARNING] Skipping invalid JSON at line {line_num + 1}: {e}", flush=True)
772
+ continue
773
+
774
+ if accelerator is None or accelerator.is_main_process:
775
+ print(f"[INFO] Loaded {len(data_list)} image-caption pairs from metadata.jsonl", flush=True)
776
+ else:
777
+ raise FileNotFoundError(f"metadata.jsonl not found at: {metadata_path}")
778
+
779
+ # 创建数据集
780
+ # 注意:'image' 列存储的是路径字符串,不是 PIL Image 对象
781
+ # 图片会在预处理函数中延迟加载
782
+ dataset = datasets.Dataset.from_list(data_list)
783
+
784
+ return datasets.DatasetDict({'train': dataset})
785
+
786
+
787
+ def main(args):
788
+ """Main training function."""
789
+ if args.report_to == "wandb" and args.hub_token is not None:
790
+ raise ValueError(
791
+ "You cannot use both --report_to=wandb and --hub_token due to security risk."
792
+ )
793
+
794
+ logging_dir = Path(args.output_dir, args.logging_dir)
795
+
796
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
797
+ raise ValueError(
798
+ "Mixed precision training with bfloat16 is not supported on MPS."
799
+ )
800
+
801
+ # GPU多卡训练检查
802
+ if torch.cuda.is_available():
803
+ num_gpus = torch.cuda.device_count()
804
+ print(f"Found {num_gpus} GPUs available")
805
+ if num_gpus > 1:
806
+ print(f"Multi-GPU training enabled with {num_gpus} GPUs")
807
+ else:
808
+ print("No CUDA GPUs found, training on CPU")
809
+
810
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
811
+ # 优化多GPU训练的DDP参数
812
+ kwargs = DistributedDataParallelKwargs(
813
+ find_unused_parameters=True,
814
+ gradient_as_bucket_view=True, # 提高多GPU训练效率
815
+ static_graph=False, # 动态图支持
816
+ )
817
+ accelerator = Accelerator(
818
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
819
+ mixed_precision=args.mixed_precision,
820
+ log_with=args.report_to,
821
+ project_config=accelerator_project_config,
822
+ kwargs_handlers=[kwargs],
823
+ )
824
+
825
+ # Logging setup
826
+ logging.basicConfig(
827
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
828
+ datefmt="%m/%d/%Y %H:%M:%S",
829
+ level=logging.INFO,
830
+ )
831
+ logger.info(accelerator.state, main_process_only=False)
832
+ if accelerator.is_main_process:
833
+ print("[INFO] Accelerator initialized", flush=True)
834
+
835
+ # 记录多GPU训练信息
836
+ if accelerator.is_main_process:
837
+ logger.info(f"Number of processes: {accelerator.num_processes}")
838
+ logger.info(f"Distributed type: {accelerator.distributed_type}")
839
+ logger.info(f"Mixed precision: {accelerator.mixed_precision}")
840
+ if torch.cuda.is_available():
841
+ for i in range(torch.cuda.device_count()):
842
+ logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
843
+ logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
844
+
845
+ if accelerator.is_local_main_process:
846
+ datasets.utils.logging.set_verbosity_warning()
847
+ transformers.utils.logging.set_verbosity_warning()
848
+ diffusers.utils.logging.set_verbosity_info()
849
+ else:
850
+ datasets.utils.logging.set_verbosity_error()
851
+ transformers.utils.logging.set_verbosity_error()
852
+ diffusers.utils.logging.set_verbosity_error()
853
+
854
+ # Set training seed
855
+ if args.seed is not None:
856
+ set_seed(args.seed)
857
+ if accelerator.is_main_process:
858
+ print(f"[INFO] Seed set to {args.seed}", flush=True)
859
+
860
+ # Create output directory
861
+ if accelerator.is_main_process:
862
+ if args.output_dir is not None:
863
+ os.makedirs(args.output_dir, exist_ok=True)
864
+
865
+ if args.push_to_hub:
866
+ repo_id = create_repo(
867
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
868
+ exist_ok=True,
869
+ token=args.hub_token
870
+ ).repo_id
871
+
872
+ if accelerator.is_main_process:
873
+ print("[INFO] Loading tokenizers...", flush=True)
874
+
875
+ # Load tokenizers (three for SD3)
876
+ tokenizer_one = CLIPTokenizer.from_pretrained(
877
+ args.pretrained_model_name_or_path,
878
+ subfolder="tokenizer",
879
+ revision=args.revision,
880
+ )
881
+ tokenizer_two = CLIPTokenizer.from_pretrained(
882
+ args.pretrained_model_name_or_path,
883
+ subfolder="tokenizer_2",
884
+ revision=args.revision,
885
+ )
886
+
887
+ if accelerator.is_main_process:
888
+ print("[INFO] Tokenizers loaded. Loading text encoders, VAE, and transformer...", flush=True)
889
+ tokenizer_three = T5TokenizerFast.from_pretrained(
890
+ args.pretrained_model_name_or_path,
891
+ subfolder="tokenizer_3",
892
+ revision=args.revision,
893
+ )
894
+
895
+ # Import text encoder classes
896
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
897
+ args.pretrained_model_name_or_path, args.revision
898
+ )
899
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
900
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
901
+ )
902
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
903
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
904
+ )
905
+
906
+ # Load models
907
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
908
+ args.pretrained_model_name_or_path, subfolder="scheduler"
909
+ )
910
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
911
+
912
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
913
+ args.pretrained_model_name_or_path,
914
+ subfolder="text_encoder",
915
+ revision=args.revision,
916
+ variant=args.variant
917
+ )
918
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
919
+ args.pretrained_model_name_or_path,
920
+ subfolder="text_encoder_2",
921
+ revision=args.revision,
922
+ variant=args.variant
923
+ )
924
+ text_encoder_three = text_encoder_cls_three.from_pretrained(
925
+ args.pretrained_model_name_or_path,
926
+ subfolder="text_encoder_3",
927
+ revision=args.revision,
928
+ variant=args.variant
929
+ )
930
+
931
+ vae = AutoencoderKL.from_pretrained(
932
+ args.pretrained_model_name_or_path,
933
+ subfolder="vae",
934
+ revision=args.revision,
935
+ variant=args.variant,
936
+ )
937
+
938
+ transformer = SD3Transformer2DModel.from_pretrained(
939
+ args.pretrained_model_name_or_path,
940
+ subfolder="transformer",
941
+ revision=args.revision,
942
+ variant=args.variant
943
+ )
944
+
945
+ if accelerator.is_main_process:
946
+ print("[INFO] Text encoders, VAE, and transformer loaded", flush=True)
947
+
948
+ # Freeze non-trainable weights
949
+ transformer.requires_grad_(False)
950
+ vae.requires_grad_(False)
951
+ text_encoder_one.requires_grad_(False)
952
+ text_encoder_two.requires_grad_(False)
953
+ text_encoder_three.requires_grad_(False)
954
+
955
+ # Set precision
956
+ weight_dtype = torch.float32
957
+ if accelerator.mixed_precision == "fp16":
958
+ weight_dtype = torch.float16
959
+ elif accelerator.mixed_precision == "bf16":
960
+ weight_dtype = torch.bfloat16
961
+
962
+ # Move models to device
963
+ vae.to(accelerator.device, dtype=torch.float32) # VAE stays in fp32
964
+ transformer.to(accelerator.device, dtype=weight_dtype)
965
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
966
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
967
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
968
+
969
+ # Enable gradient checkpointing
970
+ if args.gradient_checkpointing:
971
+ transformer.enable_gradient_checkpointing()
972
+ if args.train_text_encoder:
973
+ text_encoder_one.gradient_checkpointing_enable()
974
+ text_encoder_two.gradient_checkpointing_enable()
975
+
976
+ # Configure LoRA for transformer
977
+ transformer_lora_config = LoraConfig(
978
+ r=args.rank,
979
+ lora_alpha=args.rank,
980
+ init_lora_weights="gaussian",
981
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
982
+ )
983
+ transformer.add_adapter(transformer_lora_config)
984
+
985
+ # Configure LoRA for text encoders if enabled
986
+ if args.train_text_encoder:
987
+ text_lora_config = LoraConfig(
988
+ r=args.rank,
989
+ lora_alpha=args.rank,
990
+ init_lora_weights="gaussian",
991
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
992
+ )
993
+ text_encoder_one.add_adapter(text_lora_config)
994
+ text_encoder_two.add_adapter(text_lora_config)
995
+ # Note: T5 encoder typically doesn't use LoRA
996
+
997
+ def unwrap_model(model):
998
+ model = accelerator.unwrap_model(model)
999
+ model = model._orig_mod if is_compiled_module(model) else model
1000
+ return model
1001
+
1002
+ # Enable TF32 for faster training
1003
+ if args.allow_tf32 and torch.cuda.is_available():
1004
+ torch.backends.cuda.matmul.allow_tf32 = True
1005
+
1006
+ # Scale learning rate
1007
+ if args.scale_lr:
1008
+ args.learning_rate = (
1009
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1010
+ )
1011
+
1012
+ # Cast trainable parameters to float32
1013
+ if args.mixed_precision == "fp16":
1014
+ models = [transformer]
1015
+ if args.train_text_encoder:
1016
+ models.extend([text_encoder_one, text_encoder_two])
1017
+ cast_training_params(models, dtype=torch.float32)
1018
+
1019
+ # Setup optimizer
1020
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
1021
+ if args.train_text_encoder:
1022
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
1023
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
1024
+ params_to_optimize = (
1025
+ transformer_lora_parameters
1026
+ + text_lora_parameters_one
1027
+ + text_lora_parameters_two
1028
+ )
1029
+ else:
1030
+ params_to_optimize = transformer_lora_parameters
1031
+
1032
+ # Create optimizer
1033
+ if args.use_8bit_adam:
1034
+ try:
1035
+ import bitsandbytes as bnb
1036
+ except ImportError:
1037
+ raise ImportError("To use 8-bit Adam, install bitsandbytes: pip install bitsandbytes")
1038
+ optimizer_class = bnb.optim.AdamW8bit
1039
+ else:
1040
+ optimizer_class = torch.optim.AdamW
1041
+
1042
+ optimizer = optimizer_class(
1043
+ params_to_optimize,
1044
+ lr=args.learning_rate,
1045
+ betas=(args.adam_beta1, args.adam_beta2),
1046
+ weight_decay=args.adam_weight_decay,
1047
+ eps=args.adam_epsilon,
1048
+ )
1049
+
1050
+ if accelerator.is_main_process:
1051
+ print("[INFO] Optimizer created. Loading dataset...", flush=True)
1052
+
1053
+ # Load dataset - 使用 main_process_first 避免多进程竞争
1054
+ # 优先使用 metadata.jsonl 文件,避免扫描所有文件
1055
+ with accelerator.main_process_first():
1056
+ metadata_path = None
1057
+ if args.train_data_dir is not None:
1058
+ # 检查是否存在 metadata.jsonl
1059
+ potential_metadata = os.path.join(args.train_data_dir, "metadata.jsonl")
1060
+ if os.path.exists(potential_metadata):
1061
+ metadata_path = potential_metadata
1062
+
1063
+ if metadata_path is not None:
1064
+ # 使用 metadata.jsonl 加载数据集(更高效,避免扫描所有文件)
1065
+ if accelerator.is_main_process:
1066
+ print(f"[INFO] Found metadata.jsonl, using efficient loading method", flush=True)
1067
+ dataset = load_dataset_from_jsonl(metadata_path, args.train_data_dir, accelerator)
1068
+ elif args.dataset_name is not None:
1069
+ dataset = load_dataset(
1070
+ args.dataset_name,
1071
+ args.dataset_config_name,
1072
+ cache_dir=args.cache_dir,
1073
+ data_dir=args.train_data_dir
1074
+ )
1075
+ else:
1076
+ # 回退到 imagefolder(可能会很慢)
1077
+ if accelerator.is_main_process:
1078
+ print("[WARNING] No metadata.jsonl found, using imagefolder (may be slow for large datasets)", flush=True)
1079
+ data_files = {}
1080
+ if args.train_data_dir is not None:
1081
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
1082
+ dataset = load_dataset(
1083
+ "imagefolder",
1084
+ data_files=data_files,
1085
+ cache_dir=args.cache_dir,
1086
+ )
1087
+ if accelerator.is_main_process:
1088
+ print("[INFO] Dataset loaded successfully.", flush=True)
1089
+
1090
+ # 确保所有进程等待数据集加载完成
1091
+ accelerator.wait_for_everyone()
1092
+
1093
+ if accelerator.is_main_process:
1094
+ print("[INFO] All processes synchronized. Building transforms and DataLoader...", flush=True)
1095
+
1096
+ # Preprocessing
1097
+ column_names = dataset["train"].column_names
1098
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
1099
+
1100
+ if accelerator.is_main_process:
1101
+ print(f"[INFO] Dataset columns: {column_names}", flush=True)
1102
+
1103
+ # 智能选择 image 列:优先使用指定的列,如果不存在则自动回退
1104
+ if args.image_column is not None and args.image_column in column_names:
1105
+ # 如果指定了列名且存在,使用指定的列
1106
+ image_column = args.image_column
1107
+ else:
1108
+ # 自动选择可用的 image 列
1109
+ if 'image' in column_names:
1110
+ image_column = 'image'
1111
+ else:
1112
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
1113
+
1114
+ # 如果用户指定了列名但不存在,给出警告
1115
+ if args.image_column is not None and args.image_column != image_column:
1116
+ if accelerator.is_main_process:
1117
+ print(f"[WARNING] Specified image_column '{args.image_column}' not found. Using '{image_column}' instead.", flush=True)
1118
+
1119
+ if accelerator.is_main_process:
1120
+ print(f"[INFO] Using image column: {image_column}", flush=True)
1121
+
1122
+ # 智能选择 caption 列:优先使用指定的列,如果不存在则自动回退
1123
+ if args.caption_column is not None and args.caption_column in column_names:
1124
+ # 如果指定了列名且存在,使用指定的列
1125
+ caption_column = args.caption_column
1126
+ else:
1127
+ # 自动选择可用的 caption 列
1128
+ if 'text' in column_names:
1129
+ caption_column = 'text'
1130
+ elif 'caption' in column_names:
1131
+ caption_column = 'caption'
1132
+ else:
1133
+ caption_column = dataset_columns[1] if dataset_columns is not None else (column_names[1] if len(column_names) > 1 else column_names[0])
1134
+
1135
+ # 如果用户指定了列名但不存在,给出警告
1136
+ if args.caption_column is not None and args.caption_column != caption_column:
1137
+ if accelerator.is_main_process:
1138
+ print(f"[WARNING] Specified caption_column '{args.caption_column}' not found. Using '{caption_column}' instead.", flush=True)
1139
+
1140
+ if accelerator.is_main_process:
1141
+ print(f"[INFO] Using caption column: {caption_column}", flush=True)
1142
+
1143
+ def tokenize_captions(examples, is_train=True):
1144
+ captions = []
1145
+ for caption in examples[caption_column]:
1146
+ if isinstance(caption, str):
1147
+ captions.append(caption)
1148
+ elif isinstance(caption, (list, np.ndarray)):
1149
+ captions.append(random.choice(caption) if is_train else caption[0])
1150
+ else:
1151
+ raise ValueError(f"Caption column should contain strings or lists of strings.")
1152
+
1153
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
1154
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
1155
+ tokens_three = tokenize_prompt(tokenizer_three, captions)
1156
+ return tokens_one, tokens_two, tokens_three
1157
+
1158
+ # Image transforms
1159
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
1160
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
1161
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
1162
+ train_transforms = transforms.Compose([
1163
+ transforms.ToTensor(),
1164
+ transforms.Normalize([0.5], [0.5]),
1165
+ ])
1166
+
1167
+ def preprocess_train(examples):
1168
+ # 处理图片:如果 image_column 中是路径字符串,则加载图片;如果是 PIL Image,则直接使用
1169
+ images = []
1170
+ for img in examples[image_column]:
1171
+ if isinstance(img, str):
1172
+ # 如果是路径字符串,加载图片
1173
+ try:
1174
+ img = Image.open(img).convert("RGB")
1175
+ except Exception as e:
1176
+ # 如果加载失败,创建一个占位符
1177
+ if accelerator.is_main_process:
1178
+ print(f"[WARNING] Failed to load image {img}: {e}", flush=True)
1179
+ img = Image.new('RGB', (args.resolution, args.resolution), color='black')
1180
+ elif hasattr(img, 'convert'):
1181
+ # 如果是 PIL Image,直接使用
1182
+ img = img.convert("RGB")
1183
+ else:
1184
+ raise ValueError(f"Unexpected image type: {type(img)}")
1185
+ images.append(img)
1186
+ original_sizes = []
1187
+ all_images = []
1188
+ crop_top_lefts = []
1189
+
1190
+ for image in images:
1191
+ original_sizes.append((image.height, image.width))
1192
+ image = train_resize(image)
1193
+ if args.random_flip and random.random() < 0.5:
1194
+ image = train_flip(image)
1195
+ if args.center_crop:
1196
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
1197
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
1198
+ image = train_crop(image)
1199
+ else:
1200
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
1201
+ image = crop(image, y1, x1, h, w)
1202
+ crop_top_left = (y1, x1)
1203
+ crop_top_lefts.append(crop_top_left)
1204
+ image = train_transforms(image)
1205
+ all_images.append(image)
1206
+
1207
+ examples["original_sizes"] = original_sizes
1208
+ examples["crop_top_lefts"] = crop_top_lefts
1209
+ examples["pixel_values"] = all_images
1210
+
1211
+ tokens_one, tokens_two, tokens_three = tokenize_captions(examples)
1212
+ examples["input_ids_one"] = tokens_one
1213
+ examples["input_ids_two"] = tokens_two
1214
+ examples["input_ids_three"] = tokens_three
1215
+ return examples
1216
+
1217
+ with accelerator.main_process_first():
1218
+ if args.max_train_samples is not None:
1219
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
1220
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
1221
+
1222
+ def collate_fn(examples):
1223
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
1224
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
1225
+ original_sizes = [example["original_sizes"] for example in examples]
1226
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
1227
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
1228
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
1229
+ input_ids_three = torch.stack([example["input_ids_three"] for example in examples])
1230
+
1231
+ return {
1232
+ "pixel_values": pixel_values,
1233
+ "input_ids_one": input_ids_one,
1234
+ "input_ids_two": input_ids_two,
1235
+ "input_ids_three": input_ids_three,
1236
+ "original_sizes": original_sizes,
1237
+ "crop_top_lefts": crop_top_lefts,
1238
+ }
1239
+
1240
+ # 针对多GPU训练优化dataloader设置
1241
+ if args.dataloader_num_workers == 0 and accelerator.num_processes > 1:
1242
+ # 多GPU训练时自动设置数据加载器worker数量
1243
+ args.dataloader_num_workers = min(4, os.cpu_count() // accelerator.num_processes)
1244
+ logger.info(f"Auto-setting dataloader_num_workers to {args.dataloader_num_workers} for multi-GPU training")
1245
+
1246
+ train_dataloader = torch.utils.data.DataLoader(
1247
+ train_dataset,
1248
+ shuffle=True,
1249
+ collate_fn=collate_fn,
1250
+ batch_size=args.train_batch_size,
1251
+ num_workers=args.dataloader_num_workers,
1252
+ pin_memory=True, # 提高GPU数据传输效率
1253
+ persistent_workers=args.dataloader_num_workers > 0, # 保持worker进程活跃
1254
+ )
1255
+
1256
+ if accelerator.is_main_process:
1257
+ print("[INFO] DataLoader ready. Computing training steps and scheduler...", flush=True)
1258
+
1259
+ # Scheduler and math around training steps
1260
+ overrode_max_train_steps = False
1261
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1262
+ if args.max_train_steps is None:
1263
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1264
+ overrode_max_train_steps = True
1265
+
1266
+ lr_scheduler = get_scheduler(
1267
+ args.lr_scheduler,
1268
+ optimizer=optimizer,
1269
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
1270
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
1271
+ )
1272
+
1273
+ # Prepare everything with accelerator
1274
+ if args.train_text_encoder:
1275
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1276
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1277
+ )
1278
+ else:
1279
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1280
+ transformer, optimizer, train_dataloader, lr_scheduler
1281
+ )
1282
+
1283
+ # Recalculate training steps
1284
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1285
+ if overrode_max_train_steps:
1286
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1287
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1288
+
1289
+ # Initialize trackers
1290
+ if accelerator.is_main_process:
1291
+ try:
1292
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
1293
+ except Exception as e:
1294
+ logger.warning(f"Failed to initialize trackers: {e}")
1295
+ logger.warning("Continuing without tracking. You can monitor training through console logs.")
1296
+ # Set report_to to None to avoid further tracking attempts
1297
+ args.report_to = None
1298
+
1299
+ # Train!
1300
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1301
+ logger.info("***** Running training *****")
1302
+ logger.info(f" Num examples = {len(train_dataset)}")
1303
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1304
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1305
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1306
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1307
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1308
+ logger.info(f" Number of GPU processes = {accelerator.num_processes}")
1309
+ if accelerator.num_processes > 1:
1310
+ logger.info(f" Effective batch size per GPU = {args.train_batch_size * args.gradient_accumulation_steps}")
1311
+ logger.info(f" Total effective batch size across all GPUs = {total_batch_size}")
1312
+
1313
+ global_step = 0
1314
+ first_epoch = 0
1315
+ if accelerator.is_main_process:
1316
+ print(
1317
+ f"[INFO] Training setup complete. num_examples={len(train_dataset)}, "
1318
+ f"max_train_steps={args.max_train_steps}, num_epochs={args.num_train_epochs}",
1319
+ flush=True,
1320
+ )
1321
+
1322
+ # Resume from checkpoint if specified
1323
+ if args.resume_from_checkpoint:
1324
+ if args.resume_from_checkpoint != "latest":
1325
+ path = os.path.basename(args.resume_from_checkpoint)
1326
+ else:
1327
+ dirs = os.listdir(args.output_dir)
1328
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1329
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1330
+ path = dirs[-1] if len(dirs) > 0 else None
1331
+
1332
+ if path is None:
1333
+ accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
1334
+ args.resume_from_checkpoint = None
1335
+ initial_global_step = 0
1336
+ else:
1337
+ accelerator.print(f"Resuming from checkpoint {path}")
1338
+ accelerator.load_state(os.path.join(args.output_dir, path))
1339
+ global_step = int(path.split("-")[1])
1340
+ initial_global_step = global_step
1341
+ first_epoch = global_step // num_update_steps_per_epoch
1342
+ else:
1343
+ initial_global_step = 0
1344
+
1345
+ progress_bar = tqdm(
1346
+ range(0, args.max_train_steps),
1347
+ initial=initial_global_step,
1348
+ desc="Steps",
1349
+ disable=not accelerator.is_local_main_process,
1350
+ )
1351
+
1352
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1353
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1354
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1355
+ timesteps = timesteps.to(accelerator.device)
1356
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1357
+ sigma = sigmas[step_indices].flatten()
1358
+ while len(sigma.shape) < n_dim:
1359
+ sigma = sigma.unsqueeze(-1)
1360
+ return sigma
1361
+
1362
+ # Training loop
1363
+ for epoch in range(first_epoch, args.num_train_epochs):
1364
+ transformer.train()
1365
+ if args.train_text_encoder:
1366
+ text_encoder_one.train()
1367
+ text_encoder_two.train()
1368
+
1369
+ if accelerator.is_main_process:
1370
+ print(
1371
+ f"[INFO] Starting epoch {epoch + 1}/{args.num_train_epochs}, current global_step={global_step}",
1372
+ flush=True,
1373
+ )
1374
+
1375
+ train_loss = 0.0
1376
+ for step, batch in enumerate(train_dataloader):
1377
+ with accelerator.accumulate(transformer):
1378
+ # Convert images to latent space
1379
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1380
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1381
+
1382
+ # Apply VAE scaling
1383
+ vae_config_shift_factor = vae.config.shift_factor
1384
+ vae_config_scaling_factor = vae.config.scaling_factor
1385
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
1386
+ model_input = model_input.to(dtype=weight_dtype)
1387
+
1388
+ # Encode prompts
1389
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1390
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
1391
+ tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
1392
+ prompt=None,
1393
+ max_sequence_length=args.max_sequence_length,
1394
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"], batch["input_ids_three"]],
1395
+ )
1396
+
1397
+ # Sample noise and timesteps
1398
+ noise = torch.randn_like(model_input)
1399
+ bsz = model_input.shape[0]
1400
+
1401
+ # Flow Matching timestep sampling
1402
+ u = compute_density_for_timestep_sampling(
1403
+ weighting_scheme=args.weighting_scheme,
1404
+ batch_size=bsz,
1405
+ logit_mean=args.logit_mean,
1406
+ logit_std=args.logit_std,
1407
+ mode_scale=args.mode_scale,
1408
+ )
1409
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1410
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
1411
+
1412
+ # Flow Matching interpolation
1413
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1414
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
1415
+
1416
+ # Predict using SD3 Transformer
1417
+ model_pred = transformer(
1418
+ hidden_states=noisy_model_input,
1419
+ timestep=timesteps,
1420
+ encoder_hidden_states=prompt_embeds,
1421
+ pooled_projections=pooled_prompt_embeds,
1422
+ return_dict=False,
1423
+ )[0]
1424
+
1425
+ # Compute target for Flow Matching
1426
+ if args.precondition_outputs:
1427
+ model_pred = model_pred * (-sigmas) + noisy_model_input
1428
+ target = model_input
1429
+ else:
1430
+ target = noise - model_input
1431
+
1432
+ # Compute loss with weighting
1433
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1434
+ loss = torch.mean(
1435
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1436
+ 1,
1437
+ )
1438
+ loss = loss.mean()
1439
+
1440
+ # Gather loss across processes
1441
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1442
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
1443
+
1444
+ # Backpropagate
1445
+ accelerator.backward(loss)
1446
+ if accelerator.sync_gradients:
1447
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1448
+ optimizer.step()
1449
+ lr_scheduler.step()
1450
+ optimizer.zero_grad()
1451
+
1452
+ # Checks if the accelerator has performed an optimization step
1453
+ if accelerator.sync_gradients:
1454
+ progress_bar.update(1)
1455
+ global_step += 1
1456
+ if hasattr(accelerator, 'trackers') and accelerator.trackers:
1457
+ accelerator.log({"train_loss": train_loss}, step=global_step)
1458
+ train_loss = 0.0
1459
+
1460
+ if accelerator.is_main_process and global_step % 1000 == 0:
1461
+ print(
1462
+ f"[INFO] Optimization step completed at global_step={global_step}, "
1463
+ f"recent step_loss={loss.detach().item():.4f}",
1464
+ flush=True,
1465
+ )
1466
+
1467
+ # Save checkpoint
1468
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
1469
+ if global_step % args.checkpointing_steps == 0:
1470
+ if args.checkpoints_total_limit is not None:
1471
+ checkpoints = os.listdir(args.output_dir)
1472
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1473
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1474
+
1475
+ if len(checkpoints) >= args.checkpoints_total_limit:
1476
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1477
+ removing_checkpoints = checkpoints[0:num_to_remove]
1478
+ logger.info(f"Removing {len(removing_checkpoints)} checkpoints")
1479
+ for removing_checkpoint in removing_checkpoints:
1480
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1481
+ shutil.rmtree(removing_checkpoint)
1482
+
1483
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1484
+ accelerator.save_state(save_path)
1485
+ logger.info(f"Saved state to {save_path}")
1486
+
1487
+ # 同时保存标准的LoRA权重格式,方便采样时直接加载
1488
+ try:
1489
+ # 获取当前模型的LoRA权重
1490
+ unwrapped_transformer = unwrap_model(transformer)
1491
+ transformer_lora_layers = get_peft_model_state_dict(unwrapped_transformer)
1492
+
1493
+ text_encoder_lora_layers = None
1494
+ text_encoder_2_lora_layers = None
1495
+ if args.train_text_encoder:
1496
+ unwrapped_text_encoder_one = unwrap_model(text_encoder_one)
1497
+ unwrapped_text_encoder_two = unwrap_model(text_encoder_two)
1498
+ text_encoder_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_one)
1499
+ text_encoder_2_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_two)
1500
+
1501
+ # 保存为标准LoRA格式到checkpoint目录
1502
+ StableDiffusion3Pipeline.save_lora_weights(
1503
+ save_directory=save_path,
1504
+ transformer_lora_layers=transformer_lora_layers,
1505
+ text_encoder_lora_layers=text_encoder_lora_layers,
1506
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1507
+ )
1508
+ logger.info(f"Saved LoRA weights in standard format to {save_path}")
1509
+ except Exception as e:
1510
+ logger.warning(f"Failed to save LoRA weights in standard format: {e}")
1511
+ logger.warning("Checkpoint saved with accelerator format only. You can extract LoRA weights later.")
1512
+
1513
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1514
+ progress_bar.set_postfix(**logs)
1515
+
1516
+ if global_step >= args.max_train_steps:
1517
+ break
1518
+
1519
+ # Validation
1520
+ if accelerator.is_main_process:
1521
+ if args.validation_prompt is not None :#and epoch % args.validation_epochs == 0:
1522
+ print(f"[INFO] Running validation for epoch {epoch + 1}, global_step={global_step}", flush=True)
1523
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1524
+ args.pretrained_model_name_or_path,
1525
+ vae=vae,
1526
+ text_encoder=unwrap_model(text_encoder_one),
1527
+ text_encoder_2=unwrap_model(text_encoder_two),
1528
+ text_encoder_3=unwrap_model(text_encoder_three),
1529
+ transformer=unwrap_model(transformer),
1530
+ revision=args.revision,
1531
+ variant=args.variant,
1532
+ torch_dtype=weight_dtype,
1533
+ )
1534
+ images = log_validation(pipeline, args, accelerator, epoch, global_step=global_step)
1535
+ del pipeline
1536
+ torch.cuda.empty_cache()
1537
+
1538
+ # Save final LoRA weights
1539
+ accelerator.wait_for_everyone()
1540
+ if accelerator.is_main_process:
1541
+ transformer = unwrap_model(transformer)
1542
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
1543
+
1544
+ if args.train_text_encoder:
1545
+ text_encoder_one = unwrap_model(text_encoder_one)
1546
+ text_encoder_two = unwrap_model(text_encoder_two)
1547
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
1548
+ text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
1549
+ else:
1550
+ text_encoder_lora_layers = None
1551
+ text_encoder_2_lora_layers = None
1552
+
1553
+ StableDiffusion3Pipeline.save_lora_weights(
1554
+ save_directory=args.output_dir,
1555
+ transformer_lora_layers=transformer_lora_layers,
1556
+ text_encoder_lora_layers=text_encoder_lora_layers,
1557
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1558
+ )
1559
+
1560
+ # Final inference
1561
+ if args.mixed_precision == "fp16":
1562
+ vae.to(weight_dtype)
1563
+
1564
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1565
+ args.pretrained_model_name_or_path,
1566
+ vae=vae,
1567
+ revision=args.revision,
1568
+ variant=args.variant,
1569
+ torch_dtype=weight_dtype,
1570
+ )
1571
+ pipeline.load_lora_weights(args.output_dir)
1572
+
1573
+ if args.validation_prompt and args.num_validation_images > 0:
1574
+ images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True, global_step=global_step)
1575
+
1576
+ if args.push_to_hub:
1577
+ save_model_card(
1578
+ repo_id,
1579
+ images=images,
1580
+ base_model=args.pretrained_model_name_or_path,
1581
+ dataset_name=args.dataset_name,
1582
+ train_text_encoder=args.train_text_encoder,
1583
+ repo_folder=args.output_dir,
1584
+ )
1585
+ upload_folder(
1586
+ repo_id=repo_id,
1587
+ folder_path=args.output_dir,
1588
+ commit_message="End of training",
1589
+ ignore_patterns=["step_*", "epoch_*"],
1590
+ )
1591
+
1592
+ accelerator.end_training()
1593
+
1594
+
1595
+ if __name__ == "__main__":
1596
+ args = parse_args()
1597
+ main(args)
train_lora_sd3_new.py ADDED
@@ -0,0 +1,1422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """SD3 LoRA fine-tuning script for text2image generation."""
17
+
18
+ import argparse
19
+ import copy
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import shutil
25
+ from contextlib import nullcontext
26
+ from pathlib import Path
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
37
+ from datasets import load_dataset
38
+ from huggingface_hub import create_repo, upload_folder
39
+ from packaging import version
40
+ from peft import LoraConfig, set_peft_model_state_dict
41
+ from peft.utils import get_peft_model_state_dict
42
+ from torchvision import transforms
43
+ from torchvision.transforms.functional import crop
44
+ from tqdm.auto import tqdm
45
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
46
+
47
+ import diffusers
48
+ from diffusers import (
49
+ AutoencoderKL,
50
+ FlowMatchEulerDiscreteScheduler,
51
+ SD3Transformer2DModel,
52
+ StableDiffusion3Pipeline,
53
+ )
54
+ from diffusers.optimization import get_scheduler
55
+ from diffusers.training_utils import (
56
+ _set_state_dict_into_text_encoder,
57
+ cast_training_params,
58
+ compute_density_for_timestep_sampling,
59
+ compute_loss_weighting_for_sd3,
60
+ free_memory,
61
+ )
62
+ from diffusers.utils import (
63
+ check_min_version,
64
+ convert_unet_state_dict_to_peft,
65
+ is_wandb_available,
66
+ )
67
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
68
+ from diffusers.utils.torch_utils import is_compiled_module
69
+
70
+ if is_wandb_available():
71
+ import wandb
72
+
73
+ # Check minimum diffusers version
74
+ check_min_version("0.30.0")
75
+
76
+ logger = get_logger(__name__)
77
+
78
+
79
+ def save_model_card(
80
+ repo_id: str,
81
+ images: list = None,
82
+ base_model: str = None,
83
+ dataset_name: str = None,
84
+ train_text_encoder: bool = False,
85
+ repo_folder: str = None,
86
+ vae_path: str = None,
87
+ ):
88
+ """Save model card for SD3 LoRA model."""
89
+ img_str = ""
90
+ if images is not None:
91
+ for i, image in enumerate(images):
92
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
93
+ img_str += f"![img_{i}](./image_{i}.png)\n"
94
+
95
+ model_description = f"""
96
+ # SD3 LoRA text2image fine-tuning - {repo_id}
97
+
98
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
99
+ {img_str}
100
+
101
+ LoRA for the text encoder was enabled: {train_text_encoder}.
102
+
103
+ Special VAE used for training: {vae_path}.
104
+ """
105
+ model_card = load_or_create_model_card(
106
+ repo_id_or_path=repo_id,
107
+ from_training=True,
108
+ license="other",
109
+ base_model=base_model,
110
+ model_description=model_description,
111
+ inference=True,
112
+ )
113
+
114
+ tags = [
115
+ "stable-diffusion-3",
116
+ "stable-diffusion-3-diffusers",
117
+ "text-to-image",
118
+ "diffusers",
119
+ "diffusers-training",
120
+ "lora",
121
+ "sd3",
122
+ ]
123
+ model_card = populate_model_card(model_card, tags=tags)
124
+ model_card.save(os.path.join(repo_folder, "README.md"))
125
+
126
+
127
+ def log_validation(
128
+ pipeline,
129
+ args,
130
+ accelerator,
131
+ epoch,
132
+ is_final_validation=False,
133
+ global_step=None,
134
+ ):
135
+ """Run validation and log images."""
136
+ logger.info(
137
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
138
+ f" {args.validation_prompt}."
139
+ )
140
+ pipeline = pipeline.to(accelerator.device)
141
+ pipeline.set_progress_bar_config(disable=True)
142
+
143
+ # run inference
144
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
145
+ pipeline_args = {"prompt": args.validation_prompt}
146
+
147
+ if torch.backends.mps.is_available():
148
+ autocast_ctx = nullcontext()
149
+ else:
150
+ autocast_ctx = torch.autocast(accelerator.device.type)
151
+
152
+ with autocast_ctx:
153
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
154
+
155
+ # Save images to output directory
156
+ if accelerator.is_main_process:
157
+ validation_dir = os.path.join(args.output_dir, "validation_images")
158
+ os.makedirs(validation_dir, exist_ok=True)
159
+ for i, image in enumerate(images):
160
+ # Create filename with step and epoch information
161
+ if global_step is not None:
162
+ filename = f"validation_step_{global_step}_epoch_{epoch}_img_{i}.png"
163
+ else:
164
+ filename = f"validation_epoch_{epoch}_img_{i}.png"
165
+
166
+ image_path = os.path.join(validation_dir, filename)
167
+ image.save(image_path)
168
+ logger.info(f"Saved validation image: {image_path}")
169
+
170
+ for tracker in accelerator.trackers if hasattr(accelerator, 'trackers') and accelerator.trackers else []:
171
+ phase_name = "test" if is_final_validation else "validation"
172
+ try:
173
+ if tracker.name == "tensorboard":
174
+ np_images = np.stack([np.asarray(img) for img in images])
175
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
176
+ if tracker.name == "wandb":
177
+ tracker.log(
178
+ {
179
+ phase_name: [
180
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
181
+ ]
182
+ }
183
+ )
184
+ except Exception as e:
185
+ logger.warning(f"Failed to log to {tracker.name}: {e}")
186
+
187
+ del pipeline
188
+ free_memory()
189
+ return images
190
+
191
+
192
+ def import_model_class_from_model_name_or_path(
193
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
194
+ ):
195
+ """Import the correct text encoder class."""
196
+ text_encoder_config = PretrainedConfig.from_pretrained(
197
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
198
+ )
199
+ model_class = text_encoder_config.architectures[0]
200
+
201
+ if model_class == "CLIPTextModelWithProjection":
202
+ from transformers import CLIPTextModelWithProjection
203
+ return CLIPTextModelWithProjection
204
+ elif model_class == "T5EncoderModel":
205
+ from transformers import T5EncoderModel
206
+ return T5EncoderModel
207
+ else:
208
+ raise ValueError(f"{model_class} is not supported.")
209
+
210
+
211
+ def parse_args(input_args=None):
212
+ """Parse command line arguments."""
213
+ parser = argparse.ArgumentParser(description="SD3 LoRA training script.")
214
+
215
+ # Model arguments
216
+ parser.add_argument(
217
+ "--pretrained_model_name_or_path",
218
+ type=str,
219
+ default=None,
220
+ required=True,
221
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
222
+ )
223
+ parser.add_argument(
224
+ "--revision",
225
+ type=str,
226
+ default=None,
227
+ help="Revision of pretrained model identifier from huggingface.co/models.",
228
+ )
229
+ parser.add_argument(
230
+ "--variant",
231
+ type=str,
232
+ default=None,
233
+ help="Variant of the model files, e.g. fp16",
234
+ )
235
+
236
+ # Dataset arguments
237
+ parser.add_argument(
238
+ "--dataset_name",
239
+ type=str,
240
+ default=None,
241
+ help="The name of the Dataset to train on.",
242
+ )
243
+ parser.add_argument(
244
+ "--dataset_config_name",
245
+ type=str,
246
+ default=None,
247
+ help="The config of the Dataset.",
248
+ )
249
+ parser.add_argument(
250
+ "--train_data_dir",
251
+ type=str,
252
+ default=None,
253
+ help="A folder containing the training data.",
254
+ )
255
+ parser.add_argument(
256
+ "--image_column",
257
+ type=str,
258
+ default="image",
259
+ help="The column of the dataset containing an image."
260
+ )
261
+ parser.add_argument(
262
+ "--caption_column",
263
+ type=str,
264
+ default="caption",
265
+ help="The column of the dataset containing a caption.",
266
+ )
267
+
268
+ # Training arguments
269
+ parser.add_argument(
270
+ "--max_sequence_length",
271
+ type=int,
272
+ default=77,
273
+ help="Maximum sequence length to use with the T5 text encoder",
274
+ )
275
+ parser.add_argument(
276
+ "--validation_prompt",
277
+ type=str,
278
+ default=None,
279
+ help="A prompt used during validation.",
280
+ )
281
+ parser.add_argument(
282
+ "--num_validation_images",
283
+ type=int,
284
+ default=4,
285
+ help="Number of images for validation.",
286
+ )
287
+ parser.add_argument(
288
+ "--validation_epochs",
289
+ type=int,
290
+ default=1,
291
+ help="Run validation every X epochs.",
292
+ )
293
+ parser.add_argument(
294
+ "--max_train_samples",
295
+ type=int,
296
+ default=None,
297
+ help="Truncate the number of training examples.",
298
+ )
299
+ parser.add_argument(
300
+ "--output_dir",
301
+ type=str,
302
+ default="sd3-lora-finetuned",
303
+ help="Output directory for model predictions and checkpoints.",
304
+ )
305
+ parser.add_argument(
306
+ "--cache_dir",
307
+ type=str,
308
+ default=None,
309
+ help="Directory to store downloaded models and datasets.",
310
+ )
311
+ parser.add_argument(
312
+ "--seed",
313
+ type=int,
314
+ default=None,
315
+ help="A seed for reproducible training."
316
+ )
317
+ parser.add_argument(
318
+ "--resolution",
319
+ type=int,
320
+ default=1024,
321
+ help="Image resolution for training.",
322
+ )
323
+ parser.add_argument(
324
+ "--center_crop",
325
+ default=False,
326
+ action="store_true",
327
+ help="Whether to center crop input images.",
328
+ )
329
+ parser.add_argument(
330
+ "--random_flip",
331
+ action="store_true",
332
+ help="Whether to randomly flip images horizontally.",
333
+ )
334
+ parser.add_argument(
335
+ "--train_text_encoder",
336
+ action="store_true",
337
+ help="Whether to train the text encoder.",
338
+ )
339
+ parser.add_argument(
340
+ "--train_batch_size",
341
+ type=int,
342
+ default=16,
343
+ help="Batch size for training dataloader."
344
+ )
345
+ parser.add_argument(
346
+ "--num_train_epochs",
347
+ type=int,
348
+ default=100
349
+ )
350
+ parser.add_argument(
351
+ "--max_train_steps",
352
+ type=int,
353
+ default=None,
354
+ help="Total number of training steps.",
355
+ )
356
+ parser.add_argument(
357
+ "--checkpointing_steps",
358
+ type=int,
359
+ default=500,
360
+ help="Save checkpoint every X updates.",
361
+ )
362
+ parser.add_argument(
363
+ "--checkpoints_total_limit",
364
+ type=int,
365
+ default=None,
366
+ help="Max number of checkpoints to store.",
367
+ )
368
+ parser.add_argument(
369
+ "--resume_from_checkpoint",
370
+ type=str,
371
+ default=None,
372
+ help="Path to resume training from checkpoint.",
373
+ )
374
+ parser.add_argument(
375
+ "--gradient_accumulation_steps",
376
+ type=int,
377
+ default=1,
378
+ help="Number of update steps to accumulate.",
379
+ )
380
+ parser.add_argument(
381
+ "--gradient_checkpointing",
382
+ action="store_true",
383
+ help="Use gradient checkpointing to save memory.",
384
+ )
385
+ parser.add_argument(
386
+ "--learning_rate",
387
+ type=float,
388
+ default=1e-4,
389
+ help="Initial learning rate.",
390
+ )
391
+ parser.add_argument(
392
+ "--scale_lr",
393
+ action="store_true",
394
+ default=False,
395
+ help="Scale learning rate by number of GPUs, etc.",
396
+ )
397
+ parser.add_argument(
398
+ "--lr_scheduler",
399
+ type=str,
400
+ default="constant",
401
+ help="Learning rate scheduler type.",
402
+ )
403
+ parser.add_argument(
404
+ "--lr_warmup_steps",
405
+ type=int,
406
+ default=500,
407
+ help="Number of warmup steps."
408
+ )
409
+
410
+ # SD3 specific arguments
411
+ parser.add_argument(
412
+ "--weighting_scheme",
413
+ type=str,
414
+ default="logit_normal",
415
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
416
+ help="Weighting scheme for flow matching loss.",
417
+ )
418
+ parser.add_argument(
419
+ "--logit_mean",
420
+ type=float,
421
+ default=0.0,
422
+ help="Mean for logit_normal weighting."
423
+ )
424
+ parser.add_argument(
425
+ "--logit_std",
426
+ type=float,
427
+ default=1.0,
428
+ help="Std for logit_normal weighting."
429
+ )
430
+ parser.add_argument(
431
+ "--mode_scale",
432
+ type=float,
433
+ default=1.29,
434
+ help="Scale for mode weighting scheme.",
435
+ )
436
+ parser.add_argument(
437
+ "--precondition_outputs",
438
+ type=int,
439
+ default=1,
440
+ help="Whether to precondition model outputs.",
441
+ )
442
+
443
+ # Optimization arguments
444
+ parser.add_argument(
445
+ "--allow_tf32",
446
+ action="store_true",
447
+ help="Allow TF32 on Ampere GPUs.",
448
+ )
449
+ parser.add_argument(
450
+ "--dataloader_num_workers",
451
+ type=int,
452
+ default=0,
453
+ help="Number of data loading workers.",
454
+ )
455
+ parser.add_argument(
456
+ "--use_8bit_adam",
457
+ action="store_true",
458
+ help="Use 8-bit Adam optimizer."
459
+ )
460
+ parser.add_argument(
461
+ "--adam_beta1",
462
+ type=float,
463
+ default=0.9,
464
+ help="Beta1 for Adam optimizer."
465
+ )
466
+ parser.add_argument(
467
+ "--adam_beta2",
468
+ type=float,
469
+ default=0.999,
470
+ help="Beta2 for Adam optimizer."
471
+ )
472
+ parser.add_argument(
473
+ "--adam_weight_decay",
474
+ type=float,
475
+ default=1e-2,
476
+ help="Weight decay for Adam."
477
+ )
478
+ parser.add_argument(
479
+ "--adam_epsilon",
480
+ type=float,
481
+ default=1e-08,
482
+ help="Epsilon for Adam optimizer."
483
+ )
484
+ parser.add_argument(
485
+ "--max_grad_norm",
486
+ default=1.0,
487
+ type=float,
488
+ help="Max gradient norm."
489
+ )
490
+
491
+ # Hub and logging arguments
492
+ parser.add_argument(
493
+ "--push_to_hub",
494
+ action="store_true",
495
+ help="Push model to the Hub."
496
+ )
497
+ parser.add_argument(
498
+ "--hub_token",
499
+ type=str,
500
+ default=None,
501
+ help="Token for Model Hub."
502
+ )
503
+ parser.add_argument(
504
+ "--hub_model_id",
505
+ type=str,
506
+ default=None,
507
+ help="Repository name for the Hub.",
508
+ )
509
+ parser.add_argument(
510
+ "--logging_dir",
511
+ type=str,
512
+ default="logs",
513
+ help="TensorBoard log directory.",
514
+ )
515
+ parser.add_argument(
516
+ "--report_to",
517
+ type=str,
518
+ default="tensorboard",
519
+ help="Logging integration to use.",
520
+ )
521
+ parser.add_argument(
522
+ "--mixed_precision",
523
+ type=str,
524
+ default=None,
525
+ choices=["no", "fp16", "bf16"],
526
+ help="Mixed precision type.",
527
+ )
528
+ parser.add_argument(
529
+ "--local_rank",
530
+ type=int,
531
+ default=-1,
532
+ help="Local rank for distributed training."
533
+ )
534
+
535
+ # LoRA arguments
536
+ parser.add_argument(
537
+ "--rank",
538
+ type=int,
539
+ default=64,
540
+ help="LoRA rank dimension.",
541
+ )
542
+
543
+ if input_args is not None:
544
+ args = parser.parse_args(input_args)
545
+ else:
546
+ args = parser.parse_args()
547
+
548
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
549
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
550
+ args.local_rank = env_local_rank
551
+
552
+ # Sanity checks
553
+ if args.dataset_name is None and args.train_data_dir is None:
554
+ raise ValueError("Need either a dataset name or a training folder.")
555
+
556
+ return args
557
+
558
+
559
+ DATASET_NAME_MAPPING = {
560
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
561
+ }
562
+
563
+
564
+ def tokenize_prompt(tokenizer, prompt):
565
+ """Tokenize prompt using the given tokenizer."""
566
+ text_inputs = tokenizer(
567
+ prompt,
568
+ padding="max_length",
569
+ max_length=77,
570
+ truncation=True,
571
+ return_tensors="pt",
572
+ )
573
+ return text_inputs.input_ids
574
+
575
+
576
+ def _encode_prompt_with_t5(
577
+ text_encoder,
578
+ tokenizer,
579
+ max_sequence_length,
580
+ prompt=None,
581
+ num_images_per_prompt=1,
582
+ device=None,
583
+ text_input_ids=None,
584
+ ):
585
+ """Encode prompt using T5 text encoder."""
586
+ if prompt is not None:
587
+ prompt = [prompt] if isinstance(prompt, str) else prompt
588
+ batch_size = len(prompt)
589
+ else:
590
+ # When prompt is None, we must have text_input_ids
591
+ if text_input_ids is None:
592
+ raise ValueError("Either prompt or text_input_ids must be provided")
593
+ batch_size = text_input_ids.shape[0]
594
+
595
+ if tokenizer is not None and prompt is not None:
596
+ text_inputs = tokenizer(
597
+ prompt,
598
+ padding="max_length",
599
+ max_length=max_sequence_length,
600
+ truncation=True,
601
+ add_special_tokens=True,
602
+ return_tensors="pt",
603
+ )
604
+ text_input_ids = text_inputs.input_ids
605
+ else:
606
+ if text_input_ids is None:
607
+ raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
608
+
609
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
610
+ dtype = text_encoder.dtype
611
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
612
+
613
+ _, seq_len, _ = prompt_embeds.shape
614
+ # duplicate text embeddings for each generation per prompt
615
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
616
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
617
+
618
+ return prompt_embeds
619
+
620
+
621
+ def _encode_prompt_with_clip(
622
+ text_encoder,
623
+ tokenizer,
624
+ prompt: str,
625
+ device=None,
626
+ text_input_ids=None,
627
+ num_images_per_prompt: int = 1,
628
+ ):
629
+ """Encode prompt using CLIP text encoder."""
630
+ if prompt is not None:
631
+ prompt = [prompt] if isinstance(prompt, str) else prompt
632
+ batch_size = len(prompt)
633
+ else:
634
+ # When prompt is None, we must have text_input_ids
635
+ if text_input_ids is None:
636
+ raise ValueError("Either prompt or text_input_ids must be provided")
637
+ batch_size = text_input_ids.shape[0]
638
+
639
+ if tokenizer is not None and prompt is not None:
640
+ text_inputs = tokenizer(
641
+ prompt,
642
+ padding="max_length",
643
+ max_length=77,
644
+ truncation=True,
645
+ return_tensors="pt",
646
+ )
647
+ text_input_ids = text_inputs.input_ids
648
+ else:
649
+ if text_input_ids is None:
650
+ raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
651
+
652
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
653
+ pooled_prompt_embeds = prompt_embeds[0]
654
+ prompt_embeds = prompt_embeds.hidden_states[-2]
655
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
656
+
657
+ _, seq_len, _ = prompt_embeds.shape
658
+ # duplicate text embeddings for each generation per prompt
659
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
660
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
661
+
662
+ return prompt_embeds, pooled_prompt_embeds
663
+
664
+
665
+ def encode_prompt(
666
+ text_encoders,
667
+ tokenizers,
668
+ prompt: str,
669
+ max_sequence_length,
670
+ device=None,
671
+ num_images_per_prompt: int = 1,
672
+ text_input_ids_list=None,
673
+ ):
674
+ """Encode prompt using all three text encoders (SD3 architecture)."""
675
+ if prompt is not None:
676
+ prompt = [prompt] if isinstance(prompt, str) else prompt
677
+
678
+ # Process CLIP encoders (first two)
679
+ clip_tokenizers = tokenizers[:2]
680
+ clip_text_encoders = text_encoders[:2]
681
+
682
+ clip_prompt_embeds_list = []
683
+ clip_pooled_prompt_embeds_list = []
684
+
685
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
686
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
687
+ text_encoder=text_encoder,
688
+ tokenizer=tokenizer,
689
+ prompt=prompt,
690
+ device=device if device is not None else text_encoder.device,
691
+ num_images_per_prompt=num_images_per_prompt,
692
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
693
+ )
694
+ clip_prompt_embeds_list.append(prompt_embeds)
695
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
696
+
697
+ # Concatenate CLIP embeddings
698
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
699
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
700
+
701
+ # Process T5 encoder (third encoder)
702
+ t5_prompt_embed = _encode_prompt_with_t5(
703
+ text_encoders[-1],
704
+ tokenizers[-1],
705
+ max_sequence_length,
706
+ prompt=prompt,
707
+ num_images_per_prompt=num_images_per_prompt,
708
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
709
+ device=device if device is not None else text_encoders[-1].device,
710
+ )
711
+
712
+ # Pad CLIP embeddings to match T5 embedding dimension
713
+ clip_prompt_embeds = torch.nn.functional.pad(
714
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
715
+ )
716
+ # Concatenate all embeddings
717
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
718
+
719
+ return prompt_embeds, pooled_prompt_embeds
720
+
721
+
722
+ def main(args):
723
+ """Main training function."""
724
+ if args.report_to == "wandb" and args.hub_token is not None:
725
+ raise ValueError(
726
+ "You cannot use both --report_to=wandb and --hub_token due to security risk."
727
+ )
728
+
729
+ logging_dir = Path(args.output_dir, args.logging_dir)
730
+
731
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
732
+ raise ValueError(
733
+ "Mixed precision training with bfloat16 is not supported on MPS."
734
+ )
735
+
736
+ # GPU多卡训练检查
737
+ if torch.cuda.is_available():
738
+ num_gpus = torch.cuda.device_count()
739
+ print(f"Found {num_gpus} GPUs available")
740
+ if num_gpus > 1:
741
+ print(f"Multi-GPU training enabled with {num_gpus} GPUs")
742
+ else:
743
+ print("No CUDA GPUs found, training on CPU")
744
+
745
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
746
+ # 优化多GPU训练的DDP参数
747
+ kwargs = DistributedDataParallelKwargs(
748
+ find_unused_parameters=True,
749
+ gradient_as_bucket_view=True, # 提高多GPU训练效率
750
+ static_graph=False, # 动态图支持
751
+ )
752
+ accelerator = Accelerator(
753
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
754
+ mixed_precision=args.mixed_precision,
755
+ log_with=args.report_to,
756
+ project_config=accelerator_project_config,
757
+ kwargs_handlers=[kwargs],
758
+ )
759
+
760
+ # Logging setup
761
+ logging.basicConfig(
762
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
763
+ datefmt="%m/%d/%Y %H:%M:%S",
764
+ level=logging.INFO,
765
+ )
766
+ logger.info(accelerator.state, main_process_only=False)
767
+
768
+ # 记录多GPU训练信息
769
+ if accelerator.is_main_process:
770
+ logger.info(f"Number of processes: {accelerator.num_processes}")
771
+ logger.info(f"Distributed type: {accelerator.distributed_type}")
772
+ logger.info(f"Mixed precision: {accelerator.mixed_precision}")
773
+ if torch.cuda.is_available():
774
+ for i in range(torch.cuda.device_count()):
775
+ logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
776
+ logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
777
+
778
+ if accelerator.is_local_main_process:
779
+ datasets.utils.logging.set_verbosity_warning()
780
+ transformers.utils.logging.set_verbosity_warning()
781
+ diffusers.utils.logging.set_verbosity_info()
782
+ else:
783
+ datasets.utils.logging.set_verbosity_error()
784
+ transformers.utils.logging.set_verbosity_error()
785
+ diffusers.utils.logging.set_verbosity_error()
786
+
787
+ # Set training seed
788
+ if args.seed is not None:
789
+ set_seed(args.seed)
790
+
791
+ # Create output directory
792
+ if accelerator.is_main_process:
793
+ if args.output_dir is not None:
794
+ os.makedirs(args.output_dir, exist_ok=True)
795
+
796
+ if args.push_to_hub:
797
+ repo_id = create_repo(
798
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
799
+ exist_ok=True,
800
+ token=args.hub_token
801
+ ).repo_id
802
+
803
+ # Load tokenizers (three for SD3)
804
+ tokenizer_one = CLIPTokenizer.from_pretrained(
805
+ args.pretrained_model_name_or_path,
806
+ subfolder="tokenizer",
807
+ revision=args.revision,
808
+ )
809
+ tokenizer_two = CLIPTokenizer.from_pretrained(
810
+ args.pretrained_model_name_or_path,
811
+ subfolder="tokenizer_2",
812
+ revision=args.revision,
813
+ )
814
+ tokenizer_three = T5TokenizerFast.from_pretrained(
815
+ args.pretrained_model_name_or_path,
816
+ subfolder="tokenizer_3",
817
+ revision=args.revision,
818
+ )
819
+
820
+ # Import text encoder classes
821
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
822
+ args.pretrained_model_name_or_path, args.revision
823
+ )
824
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
825
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
826
+ )
827
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
828
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
829
+ )
830
+
831
+ # Load models
832
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
833
+ args.pretrained_model_name_or_path, subfolder="scheduler"
834
+ )
835
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
836
+
837
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
838
+ args.pretrained_model_name_or_path,
839
+ subfolder="text_encoder",
840
+ revision=args.revision,
841
+ variant=args.variant
842
+ )
843
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
844
+ args.pretrained_model_name_or_path,
845
+ subfolder="text_encoder_2",
846
+ revision=args.revision,
847
+ variant=args.variant
848
+ )
849
+ text_encoder_three = text_encoder_cls_three.from_pretrained(
850
+ args.pretrained_model_name_or_path,
851
+ subfolder="text_encoder_3",
852
+ revision=args.revision,
853
+ variant=args.variant
854
+ )
855
+
856
+ vae = AutoencoderKL.from_pretrained(
857
+ args.pretrained_model_name_or_path,
858
+ subfolder="vae",
859
+ revision=args.revision,
860
+ variant=args.variant,
861
+ )
862
+
863
+ transformer = SD3Transformer2DModel.from_pretrained(
864
+ args.pretrained_model_name_or_path,
865
+ subfolder="transformer",
866
+ revision=args.revision,
867
+ variant=args.variant
868
+ )
869
+
870
+ # Freeze non-trainable weights
871
+ transformer.requires_grad_(False)
872
+ vae.requires_grad_(False)
873
+ text_encoder_one.requires_grad_(False)
874
+ text_encoder_two.requires_grad_(False)
875
+ text_encoder_three.requires_grad_(False)
876
+
877
+ # Set precision
878
+ weight_dtype = torch.float32
879
+ if accelerator.mixed_precision == "fp16":
880
+ weight_dtype = torch.float16
881
+ elif accelerator.mixed_precision == "bf16":
882
+ weight_dtype = torch.bfloat16
883
+
884
+ # Move models to device
885
+ vae.to(accelerator.device, dtype=torch.float32) # VAE stays in fp32
886
+ transformer.to(accelerator.device, dtype=weight_dtype)
887
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
888
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
889
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
890
+
891
+ # Enable gradient checkpointing
892
+ if args.gradient_checkpointing:
893
+ transformer.enable_gradient_checkpointing()
894
+ if args.train_text_encoder:
895
+ text_encoder_one.gradient_checkpointing_enable()
896
+ text_encoder_two.gradient_checkpointing_enable()
897
+
898
+ # Configure LoRA for transformer
899
+ transformer_lora_config = LoraConfig(
900
+ r=args.rank,
901
+ lora_alpha=args.rank,
902
+ init_lora_weights="gaussian",
903
+ target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
904
+ )
905
+ transformer.add_adapter(transformer_lora_config)
906
+
907
+ # Configure LoRA for text encoders if enabled
908
+ if args.train_text_encoder:
909
+ text_lora_config = LoraConfig(
910
+ r=args.rank,
911
+ lora_alpha=args.rank,
912
+ init_lora_weights="gaussian",
913
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
914
+ )
915
+ text_encoder_one.add_adapter(text_lora_config)
916
+ text_encoder_two.add_adapter(text_lora_config)
917
+ # Note: T5 encoder typically doesn't use LoRA
918
+
919
+ def unwrap_model(model):
920
+ model = accelerator.unwrap_model(model)
921
+ model = model._orig_mod if is_compiled_module(model) else model
922
+ return model
923
+
924
+ # Enable TF32 for faster training
925
+ if args.allow_tf32 and torch.cuda.is_available():
926
+ torch.backends.cuda.matmul.allow_tf32 = True
927
+
928
+ # Scale learning rate
929
+ if args.scale_lr:
930
+ args.learning_rate = (
931
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
932
+ )
933
+
934
+ # Cast trainable parameters to float32
935
+ if args.mixed_precision == "fp16":
936
+ models = [transformer]
937
+ if args.train_text_encoder:
938
+ models.extend([text_encoder_one, text_encoder_two])
939
+ cast_training_params(models, dtype=torch.float32)
940
+
941
+ # Setup optimizer
942
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
943
+ if args.train_text_encoder:
944
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
945
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
946
+ params_to_optimize = (
947
+ transformer_lora_parameters
948
+ + text_lora_parameters_one
949
+ + text_lora_parameters_two
950
+ )
951
+ else:
952
+ params_to_optimize = transformer_lora_parameters
953
+
954
+ # Create optimizer
955
+ if args.use_8bit_adam:
956
+ try:
957
+ import bitsandbytes as bnb
958
+ except ImportError:
959
+ raise ImportError("To use 8-bit Adam, install bitsandbytes: pip install bitsandbytes")
960
+ optimizer_class = bnb.optim.AdamW8bit
961
+ else:
962
+ optimizer_class = torch.optim.AdamW
963
+
964
+ optimizer = optimizer_class(
965
+ params_to_optimize,
966
+ lr=args.learning_rate,
967
+ betas=(args.adam_beta1, args.adam_beta2),
968
+ weight_decay=args.adam_weight_decay,
969
+ eps=args.adam_epsilon,
970
+ )
971
+
972
+ # Load dataset
973
+ if args.dataset_name is not None:
974
+ dataset = load_dataset(
975
+ args.dataset_name,
976
+ args.dataset_config_name,
977
+ cache_dir=args.cache_dir,
978
+ data_dir=args.train_data_dir
979
+ )
980
+ else:
981
+ data_files = {}
982
+ if args.train_data_dir is not None:
983
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
984
+ dataset = load_dataset(
985
+ "imagefolder",
986
+ data_files=data_files,
987
+ cache_dir=args.cache_dir,
988
+ )
989
+
990
+ # Preprocessing
991
+ column_names = dataset["train"].column_names
992
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
993
+
994
+ if args.image_column is None:
995
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
996
+ else:
997
+ image_column = args.image_column
998
+ if image_column not in column_names:
999
+ raise ValueError(f"--image_column '{args.image_column}' not found in: {column_names}")
1000
+
1001
+ if args.caption_column is None:
1002
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
1003
+ else:
1004
+ caption_column = args.caption_column
1005
+ if caption_column not in column_names:
1006
+ raise ValueError(f"--caption_column '{args.caption_column}' not found in: {column_names}")
1007
+
1008
+ def tokenize_captions(examples, is_train=True):
1009
+ captions = []
1010
+ for caption in examples[caption_column]:
1011
+ if isinstance(caption, str):
1012
+ captions.append(caption)
1013
+ elif isinstance(caption, (list, np.ndarray)):
1014
+ captions.append(random.choice(caption) if is_train else caption[0])
1015
+ else:
1016
+ raise ValueError(f"Caption column should contain strings or lists of strings.")
1017
+
1018
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
1019
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
1020
+ tokens_three = tokenize_prompt(tokenizer_three, captions)
1021
+ return tokens_one, tokens_two, tokens_three
1022
+
1023
+ # Image transforms
1024
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
1025
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
1026
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
1027
+ train_transforms = transforms.Compose([
1028
+ transforms.ToTensor(),
1029
+ transforms.Normalize([0.5], [0.5]),
1030
+ ])
1031
+
1032
+ def preprocess_train(examples):
1033
+ images = [image.convert("RGB") for image in examples[image_column]]
1034
+ original_sizes = []
1035
+ all_images = []
1036
+ crop_top_lefts = []
1037
+
1038
+ for image in images:
1039
+ original_sizes.append((image.height, image.width))
1040
+ image = train_resize(image)
1041
+ if args.random_flip and random.random() < 0.5:
1042
+ image = train_flip(image)
1043
+ if args.center_crop:
1044
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
1045
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
1046
+ image = train_crop(image)
1047
+ else:
1048
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
1049
+ image = crop(image, y1, x1, h, w)
1050
+ crop_top_left = (y1, x1)
1051
+ crop_top_lefts.append(crop_top_left)
1052
+ image = train_transforms(image)
1053
+ all_images.append(image)
1054
+
1055
+ examples["original_sizes"] = original_sizes
1056
+ examples["crop_top_lefts"] = crop_top_lefts
1057
+ examples["pixel_values"] = all_images
1058
+
1059
+ tokens_one, tokens_two, tokens_three = tokenize_captions(examples)
1060
+ examples["input_ids_one"] = tokens_one
1061
+ examples["input_ids_two"] = tokens_two
1062
+ examples["input_ids_three"] = tokens_three
1063
+ return examples
1064
+
1065
+ with accelerator.main_process_first():
1066
+ if args.max_train_samples is not None:
1067
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
1068
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
1069
+
1070
+ def collate_fn(examples):
1071
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
1072
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
1073
+ original_sizes = [example["original_sizes"] for example in examples]
1074
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
1075
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
1076
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
1077
+ input_ids_three = torch.stack([example["input_ids_three"] for example in examples])
1078
+
1079
+ return {
1080
+ "pixel_values": pixel_values,
1081
+ "input_ids_one": input_ids_one,
1082
+ "input_ids_two": input_ids_two,
1083
+ "input_ids_three": input_ids_three,
1084
+ "original_sizes": original_sizes,
1085
+ "crop_top_lefts": crop_top_lefts,
1086
+ }
1087
+
1088
+ # 针对多GPU训练优化dataloader设置
1089
+ if args.dataloader_num_workers == 0 and accelerator.num_processes > 1:
1090
+ # 多GPU训练时自动设置数据加载器worker数量
1091
+ args.dataloader_num_workers = min(4, os.cpu_count() // accelerator.num_processes)
1092
+ logger.info(f"Auto-setting dataloader_num_workers to {args.dataloader_num_workers} for multi-GPU training")
1093
+
1094
+ train_dataloader = torch.utils.data.DataLoader(
1095
+ train_dataset,
1096
+ shuffle=True,
1097
+ collate_fn=collate_fn,
1098
+ batch_size=args.train_batch_size,
1099
+ num_workers=args.dataloader_num_workers,
1100
+ pin_memory=True, # 提高GPU数据传输效率
1101
+ persistent_workers=args.dataloader_num_workers > 0, # 保持worker进程活跃
1102
+ )
1103
+
1104
+ # Scheduler and math around training steps
1105
+ overrode_max_train_steps = False
1106
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1107
+ if args.max_train_steps is None:
1108
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1109
+ overrode_max_train_steps = True
1110
+
1111
+ lr_scheduler = get_scheduler(
1112
+ args.lr_scheduler,
1113
+ optimizer=optimizer,
1114
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
1115
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
1116
+ )
1117
+
1118
+ # Prepare everything with accelerator
1119
+ if args.train_text_encoder:
1120
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1121
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1122
+ )
1123
+ else:
1124
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1125
+ transformer, optimizer, train_dataloader, lr_scheduler
1126
+ )
1127
+
1128
+ # Recalculate training steps
1129
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1130
+ if overrode_max_train_steps:
1131
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1132
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1133
+
1134
+ # Initialize trackers
1135
+ if accelerator.is_main_process:
1136
+ try:
1137
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
1138
+ except Exception as e:
1139
+ logger.warning(f"Failed to initialize trackers: {e}")
1140
+ logger.warning("Continuing without tracking. You can monitor training through console logs.")
1141
+ # Set report_to to None to avoid further tracking attempts
1142
+ args.report_to = None
1143
+
1144
+ # Train!
1145
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1146
+ logger.info("***** Running training *****")
1147
+ logger.info(f" Num examples = {len(train_dataset)}")
1148
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1149
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1150
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1151
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1152
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1153
+ logger.info(f" Number of GPU processes = {accelerator.num_processes}")
1154
+ if accelerator.num_processes > 1:
1155
+ logger.info(f" Effective batch size per GPU = {args.train_batch_size * args.gradient_accumulation_steps}")
1156
+ logger.info(f" Total effective batch size across all GPUs = {total_batch_size}")
1157
+
1158
+ global_step = 0
1159
+ first_epoch = 0
1160
+
1161
+ # Resume from checkpoint if specified
1162
+ if args.resume_from_checkpoint:
1163
+ if args.resume_from_checkpoint != "latest":
1164
+ path = os.path.basename(args.resume_from_checkpoint)
1165
+ else:
1166
+ dirs = os.listdir(args.output_dir)
1167
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1168
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1169
+ path = dirs[-1] if len(dirs) > 0 else None
1170
+
1171
+ if path is None:
1172
+ accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
1173
+ args.resume_from_checkpoint = None
1174
+ initial_global_step = 0
1175
+ else:
1176
+ accelerator.print(f"Resuming from checkpoint {path}")
1177
+ accelerator.load_state(os.path.join(args.output_dir, path))
1178
+ global_step = int(path.split("-")[1])
1179
+ initial_global_step = global_step
1180
+ first_epoch = global_step // num_update_steps_per_epoch
1181
+ else:
1182
+ initial_global_step = 0
1183
+
1184
+ progress_bar = tqdm(
1185
+ range(0, args.max_train_steps),
1186
+ initial=initial_global_step,
1187
+ desc="Steps",
1188
+ disable=not accelerator.is_local_main_process,
1189
+ )
1190
+
1191
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1192
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1193
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1194
+ timesteps = timesteps.to(accelerator.device)
1195
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1196
+ sigma = sigmas[step_indices].flatten()
1197
+ while len(sigma.shape) < n_dim:
1198
+ sigma = sigma.unsqueeze(-1)
1199
+ return sigma
1200
+
1201
+ # Training loop
1202
+ for epoch in range(first_epoch, args.num_train_epochs):
1203
+ transformer.train()
1204
+ if args.train_text_encoder:
1205
+ text_encoder_one.train()
1206
+ text_encoder_two.train()
1207
+
1208
+ train_loss = 0.0
1209
+ for step, batch in enumerate(train_dataloader):
1210
+ with accelerator.accumulate(transformer):
1211
+ # Convert images to latent space
1212
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1213
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1214
+
1215
+ # Apply VAE scaling
1216
+ vae_config_shift_factor = vae.config.shift_factor
1217
+ vae_config_scaling_factor = vae.config.scaling_factor
1218
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
1219
+ model_input = model_input.to(dtype=weight_dtype)
1220
+
1221
+ # Encode prompts
1222
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1223
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
1224
+ tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
1225
+ prompt=None,
1226
+ max_sequence_length=args.max_sequence_length,
1227
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"], batch["input_ids_three"]],
1228
+ )
1229
+
1230
+ # Sample noise and timesteps
1231
+ noise = torch.randn_like(model_input)
1232
+ bsz = model_input.shape[0]
1233
+
1234
+ # Flow Matching timestep sampling
1235
+ u = compute_density_for_timestep_sampling(
1236
+ weighting_scheme=args.weighting_scheme,
1237
+ batch_size=bsz,
1238
+ logit_mean=args.logit_mean,
1239
+ logit_std=args.logit_std,
1240
+ mode_scale=args.mode_scale,
1241
+ )
1242
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1243
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
1244
+
1245
+ # Flow Matching interpolation
1246
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1247
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
1248
+
1249
+ # Predict using SD3 Transformer
1250
+ model_pred = transformer(
1251
+ hidden_states=noisy_model_input,
1252
+ timestep=timesteps,
1253
+ encoder_hidden_states=prompt_embeds,
1254
+ pooled_projections=pooled_prompt_embeds,
1255
+ return_dict=False,
1256
+ )[0]
1257
+
1258
+ # Compute target for Flow Matching
1259
+ if args.precondition_outputs:
1260
+ model_pred = model_pred * (-sigmas) + noisy_model_input
1261
+ target = model_input
1262
+ else:
1263
+ target = noise - model_input
1264
+
1265
+ # Compute loss with weighting
1266
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1267
+ loss = torch.mean(
1268
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1269
+ 1,
1270
+ )
1271
+ loss = loss.mean()
1272
+
1273
+ # Gather loss across processes
1274
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1275
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
1276
+
1277
+ # Backpropagate
1278
+ accelerator.backward(loss)
1279
+ if accelerator.sync_gradients:
1280
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1281
+ optimizer.step()
1282
+ lr_scheduler.step()
1283
+ optimizer.zero_grad()
1284
+
1285
+ # Checks if the accelerator has performed an optimization step
1286
+ if accelerator.sync_gradients:
1287
+ progress_bar.update(1)
1288
+ global_step += 1
1289
+ if hasattr(accelerator, 'trackers') and accelerator.trackers:
1290
+ accelerator.log({"train_loss": train_loss}, step=global_step)
1291
+ train_loss = 0.0
1292
+
1293
+ # Save checkpoint
1294
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
1295
+ if global_step % args.checkpointing_steps == 0:
1296
+ if args.checkpoints_total_limit is not None:
1297
+ checkpoints = os.listdir(args.output_dir)
1298
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1299
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1300
+
1301
+ if len(checkpoints) >= args.checkpoints_total_limit:
1302
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1303
+ removing_checkpoints = checkpoints[0:num_to_remove]
1304
+ logger.info(f"Removing {len(removing_checkpoints)} checkpoints")
1305
+ for removing_checkpoint in removing_checkpoints:
1306
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1307
+ shutil.rmtree(removing_checkpoint)
1308
+
1309
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1310
+ accelerator.save_state(save_path)
1311
+ logger.info(f"Saved state to {save_path}")
1312
+
1313
+ # 同时保存标准的LoRA权重格式,方便采样时直接加载
1314
+ try:
1315
+ # 获取当前模型的LoRA权重
1316
+ unwrapped_transformer = unwrap_model(transformer)
1317
+ transformer_lora_layers = get_peft_model_state_dict(unwrapped_transformer)
1318
+
1319
+ text_encoder_lora_layers = None
1320
+ text_encoder_2_lora_layers = None
1321
+ if args.train_text_encoder:
1322
+ unwrapped_text_encoder_one = unwrap_model(text_encoder_one)
1323
+ unwrapped_text_encoder_two = unwrap_model(text_encoder_two)
1324
+ text_encoder_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_one)
1325
+ text_encoder_2_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_two)
1326
+
1327
+ # 保存为标准LoRA格式到checkpoint目录
1328
+ StableDiffusion3Pipeline.save_lora_weights(
1329
+ save_directory=save_path,
1330
+ transformer_lora_layers=transformer_lora_layers,
1331
+ text_encoder_lora_layers=text_encoder_lora_layers,
1332
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1333
+ )
1334
+ logger.info(f"Saved LoRA weights in standard format to {save_path}")
1335
+ except Exception as e:
1336
+ logger.warning(f"Failed to save LoRA weights in standard format: {e}")
1337
+ logger.warning("Checkpoint saved with accelerator format only. You can extract LoRA weights later.")
1338
+
1339
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1340
+ progress_bar.set_postfix(**logs)
1341
+
1342
+ if global_step >= args.max_train_steps:
1343
+ break
1344
+
1345
+ # Validation
1346
+ if accelerator.is_main_process:
1347
+ if args.validation_prompt is not None :#and epoch % args.validation_epochs == 0:
1348
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1349
+ args.pretrained_model_name_or_path,
1350
+ vae=vae,
1351
+ text_encoder=unwrap_model(text_encoder_one),
1352
+ text_encoder_2=unwrap_model(text_encoder_two),
1353
+ text_encoder_3=unwrap_model(text_encoder_three),
1354
+ transformer=unwrap_model(transformer),
1355
+ revision=args.revision,
1356
+ variant=args.variant,
1357
+ torch_dtype=weight_dtype,
1358
+ )
1359
+ images = log_validation(pipeline, args, accelerator, epoch, global_step=global_step)
1360
+ del pipeline
1361
+ torch.cuda.empty_cache()
1362
+
1363
+ # Save final LoRA weights
1364
+ accelerator.wait_for_everyone()
1365
+ if accelerator.is_main_process:
1366
+ transformer = unwrap_model(transformer)
1367
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
1368
+
1369
+ if args.train_text_encoder:
1370
+ text_encoder_one = unwrap_model(text_encoder_one)
1371
+ text_encoder_two = unwrap_model(text_encoder_two)
1372
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
1373
+ text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
1374
+ else:
1375
+ text_encoder_lora_layers = None
1376
+ text_encoder_2_lora_layers = None
1377
+
1378
+ StableDiffusion3Pipeline.save_lora_weights(
1379
+ save_directory=args.output_dir,
1380
+ transformer_lora_layers=transformer_lora_layers,
1381
+ text_encoder_lora_layers=text_encoder_lora_layers,
1382
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1383
+ )
1384
+
1385
+ # Final inference
1386
+ if args.mixed_precision == "fp16":
1387
+ vae.to(weight_dtype)
1388
+
1389
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1390
+ args.pretrained_model_name_or_path,
1391
+ vae=vae,
1392
+ revision=args.revision,
1393
+ variant=args.variant,
1394
+ torch_dtype=weight_dtype,
1395
+ )
1396
+ pipeline.load_lora_weights(args.output_dir)
1397
+
1398
+ if args.validation_prompt and args.num_validation_images > 0:
1399
+ images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True, global_step=global_step)
1400
+
1401
+ if args.push_to_hub:
1402
+ save_model_card(
1403
+ repo_id,
1404
+ images=images,
1405
+ base_model=args.pretrained_model_name_or_path,
1406
+ dataset_name=args.dataset_name,
1407
+ train_text_encoder=args.train_text_encoder,
1408
+ repo_folder=args.output_dir,
1409
+ )
1410
+ upload_folder(
1411
+ repo_id=repo_id,
1412
+ folder_path=args.output_dir,
1413
+ commit_message="End of training",
1414
+ ignore_patterns=["step_*", "epoch_*"],
1415
+ )
1416
+
1417
+ accelerator.end_training()
1418
+
1419
+
1420
+ if __name__ == "__main__":
1421
+ args = parse_args()
1422
+ main(args)
train_rectified_noise.py ADDED
The diff for this file is too large to render. See raw diff
 
train_rectified_noise.sh ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # SD3 Rectified Noise Training Script
4
+ # 这个脚本展示了如何使用 train_rectified_noise.py 进行训练
5
+
6
+ set -e
7
+
8
+ # 激活正确的conda环境
9
+ source /root/miniconda3/etc/profile.d/conda.sh
10
+ conda activate SiT
11
+
12
+ # 基础配置
13
+ export CUDA_VISIBLE_DEVICES=0,1,2,3 # 设置使用4个GPU(0,1,2,3)
14
+ #export OMP_NUM_THREADS=1
15
+
16
+ # 内存优化设置
17
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
18
+ export TOKENIZERS_PARALLELISM=false
19
+
20
+ # 模型和数据路径
21
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
22
+ LORA_MODEL_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000" # LoRA微调后的SD3模型路径
23
+ TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 训练数据目录
24
+ OUTPUT_DIR="./rectified-noise-batch-2" # 输出目录
25
+
26
+ # 训练参数
27
+ NUM_SIT_LAYERS=1 # SIT块的层数
28
+ SIT_LEARNING_RATE=1e-5 # SIT块的学习率
29
+ KL_LOSS_WEIGHT=0.5 # KL散度损失权重
30
+ RESOLUTION=512 # 图像分辨率
31
+ BATCH_SIZE=2 # 批次大小
32
+ GRADIENT_ACCUMULATION=2 # 梯度累积步数
33
+ MAX_TRAIN_STEPS=500000 # 最大训练步数
34
+
35
+ # 验证参数
36
+ VALIDATION_PROMPT="A bicycle replica with a clock as the front wheel."
37
+ NUM_VALIDATION_IMAGES=1
38
+
39
+ echo "开始 SD3 Rectified Noise 训练..."
40
+ echo "LoRA模型路径: $LORA_MODEL_PATH"
41
+ echo "SIT层数: $NUM_SIT_LAYERS"
42
+ echo "输出目录: $OUTPUT_DIR"
43
+
44
+ # 检查LoRA模型路径是否存在
45
+ if [ ! -d "$LORA_MODEL_PATH" ]; then
46
+ echo "错误: LoRA模型路径不存在: $LORA_MODEL_PATH"
47
+ echo "请先使用 train_lora_sd3.py 训练LoRA模型"
48
+ exit 1
49
+ fi
50
+
51
+ # 使用accelerate启动训练
52
+ # 注意:移除了命令行中的mixed_precision参数,因为已经在accelerate_config.yaml中设置
53
+ accelerate launch --config_file accelerate_config.yaml train_rectified_noise.py \
54
+ --pretrained_model_name_or_path="$PRETRAINED_MODEL" \
55
+ --lora_model_path="$LORA_MODEL_PATH" \
56
+ --train_data_dir="$TRAIN_DATA_DIR" \
57
+ --num_sit_layers=$NUM_SIT_LAYERS \
58
+ --sit_learning_rate=$SIT_LEARNING_RATE \
59
+ --kl_loss_weight=$KL_LOSS_WEIGHT \
60
+ --resolution=$RESOLUTION \
61
+ --train_batch_size=$BATCH_SIZE \
62
+ --gradient_accumulation_steps=$GRADIENT_ACCUMULATION \
63
+ --gradient_checkpointing \
64
+ --learning_rate=1e-5 \
65
+ --time_weight_alpha=5.0 \
66
+ --lr_scheduler="constant" \
67
+ --lr_warmup_steps=0 \
68
+ --max_train_steps=$MAX_TRAIN_STEPS \
69
+ --output_dir="$OUTPUT_DIR" \
70
+ --validation_prompt="$VALIDATION_PROMPT" \
71
+ --num_validation_images=$NUM_VALIDATION_IMAGES \
72
+ --validation_steps=20000 \
73
+ --seed=42 \
74
+ --dataloader_num_workers=8 \
75
+ --save_sit_weights_only \
76
+ --checkpointing_steps=20000 \
77
+ --checkpoints_total_limit=10 \
78
+ --report_to="tensorboard" \
79
+ --logging_dir="./logs"
80
+
81
+ echo "训练完成!"
82
+ echo "SIT权重保存在: $OUTPUT_DIR/sit_weights/"
83
+ echo "验证图像保存在: $OUTPUT_DIR/validation_images/"
84
+
85
+ # 可选:快速测试训练命令
86
+ # cat << 'EOF'
87
+
88
+ # # 快速测试命令(少量步数):
89
+ # accelerate launch train_rectified_noise.py \
90
+ # --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" \
91
+ # --lora_model_path="./sd3-lora-finetuned" \
92
+ # --train_data_dir="./dataset" \
93
+ # --num_sit_layers=2 \
94
+ # --resolution=256 \
95
+ # --train_batch_size=1 \
96
+ # --gradient_accumulation_steps=4 \
97
+ # --max_train_steps=100 \
98
+ # --output_dir="./test-rectified-noise" \
99
+ # --mixed_precision="fp16" \
100
+ # --save_sit_weights_only
101
+
102
+ # EOF
103
+
104
+ # nohup bash train_rectified_noise.sh > train_rectified_noise.log 2>&1 &
train_rectified_noise2.py ADDED
The diff for this file is too large to render. See raw diff
 
train_sd3_lora.log ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ 检测到 4 个GPU
3
+ 每个GPU批次大小: 4
4
+ 总有效批次大小: 16
5
+ ===== SD3 LoRA 多GPU训练开始 =====
6
+ 模型: /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671
7
+ 输出目录: sd3-lora-finetuned-batch-8
8
+ 分辨率: 512
9
+ 每个GPU批次大小: 4
10
+ 梯度累积步数: 1
11
+ 总有效批次大小: 16
12
+ 学习率: 1e-5
13
+ 最大训练步数: 500000
14
+ LoRA Rank: 32
15
+ 使用GPU: 0,1,2,3
16
+ 断点重训: latest
17
+ ===========================================
18
+ 使用 accelerate 启动多GPU训练...
19
+ /root/miniconda3/envs/SiT/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
20
+ warnings.warn(
21
+ Terminated
22
+ ===========================================
23
+ 训练完成!
24
+ 模型保存在: sd3-lora-finetuned-batch-8
25
+ 日志保存在: sd3-lora-finetuned-batch-8/logs
26
+ 验证图片保存在: sd3-lora-finetuned-batch-8/validation_images
27
+ ===========================================
train_sd3_lora.sh ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # SD3 LoRA Fine-tuning Training Script
4
+ # 使用 Stable Diffusion 3 进行 LoRA 微调训练 - 多GPU优化版本
5
+
6
+ # 设置环境变量
7
+ export CUDA_VISIBLE_DEVICES=0,1,2,3 # 根据可用GPU数量调整
8
+ # export PYTHONPATH=$PYTHONPATH:/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco
9
+
10
+ # 检查GPU数量
11
+ num_gpus=$(nvidia-smi --list-gpus | wc -l)
12
+ echo "检测到 $num_gpus 个GPU"
13
+
14
+ # 训练参数配置
15
+ MODEL_NAME="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
16
+ #DATASET_NAME="lambdalabs/naruto-blip-captions" # 或者使用本地数据集路径
17
+ TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 本地数据集路径0
18
+ OUTPUT_DIR="sd3-lora-finetuned-batch-8"
19
+ RESOLUTION=512
20
+ # 多GPU训练时调整批次大小 - 每个GPU的批次大小
21
+ TRAIN_BATCH_SIZE=4 # 每个GPU的批次大小,总批次大小 = TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS
22
+ GRADIENT_ACCUMULATION_STEPS=1 # 梯度累积步数
23
+ LEARNING_RATE=1e-5
24
+ MAX_TRAIN_STEPS=500000
25
+ NUM_TRAIN_EPOCHS=50
26
+ VALIDATION_PROMPT="A photo of a beautiful landscape with mountains and lake"
27
+ NUM_VALIDATION_IMAGES=2
28
+ VALIDATION_EPOCHS=1 # 每10个epoch验证一次
29
+ LORA_RANK=32
30
+ SEED=42
31
+ RESUME_FROM_CHECKPOINT="latest" # 设置为 "latest" 以自动从最新checkpoint恢复,或指定checkpoint路径,或设为 "" 以不恢复
32
+
33
+ # 计算有效批次大小
34
+ effective_batch_size=$((TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS))
35
+ echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
36
+ echo "总有效批次大小: $effective_batch_size"
37
+
38
+ # 创建输出目录
39
+ mkdir -p $OUTPUT_DIR
40
+
41
+ echo "===== SD3 LoRA 多GPU训练开始 ====="
42
+ echo "模型: $MODEL_NAME"
43
+ echo "输出目录: $OUTPUT_DIR"
44
+ echo "分辨率: $RESOLUTION"
45
+ echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
46
+ echo "梯度累积步数: $GRADIENT_ACCUMULATION_STEPS"
47
+ echo "总有效批次大小: $effective_batch_size"
48
+ echo "学习率: $LEARNING_RATE"
49
+ echo "最大训练步数: $MAX_TRAIN_STEPS"
50
+ echo "LoRA Rank: $LORA_RANK"
51
+ echo "使用GPU: $CUDA_VISIBLE_DEVICES"
52
+ echo "断点重训: $RESUME_FROM_CHECKPOINT"
53
+ echo "==========================================="
54
+
55
+ # 检查accelerate配置
56
+ if [ ! -f "accelerate_config.yaml" ]; then
57
+ echo "错误: 未找到 accelerate_config.yaml 文件"
58
+ echo "请运行: accelerate config 来配置多GPU训练"
59
+ exit 1
60
+ fi
61
+
62
+ echo "使用 accelerate 启动多GPU训练..."
63
+
64
+ # 使用 accelerate 启动训练
65
+ accelerate launch --config_file=accelerate_config.yaml train_lora_sd3.py \
66
+ --pretrained_model_name_or_path="$MODEL_NAME" \
67
+ --train_data_dir="$TRAIN_DATA_DIR" \
68
+ --output_dir="$OUTPUT_DIR" \
69
+ --mixed_precision="no" \
70
+ --resolution=$RESOLUTION \
71
+ --train_batch_size=$TRAIN_BATCH_SIZE \
72
+ --gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
73
+ --learning_rate=$LEARNING_RATE \
74
+ --scale_lr \
75
+ --lr_scheduler="cosine" \
76
+ --lr_warmup_steps=100 \
77
+ --max_train_steps=$MAX_TRAIN_STEPS \
78
+ --num_train_epochs=$NUM_TRAIN_EPOCHS \
79
+ --validation_prompt="$VALIDATION_PROMPT" \
80
+ --num_validation_images=$NUM_VALIDATION_IMAGES \
81
+ --validation_epochs=$VALIDATION_EPOCHS \
82
+ --checkpointing_steps=20000 \
83
+ --checkpoints_total_limit=10 \
84
+ --seed=$SEED \
85
+ --rank=$LORA_RANK \
86
+ --gradient_checkpointing \
87
+ --use_8bit_adam \
88
+ --dataloader_num_workers=0 \
89
+ --report_to="tensorboard" \
90
+ --logging_dir="logs" \
91
+ --adam_beta1=0.9 \
92
+ --adam_beta2=0.999 \
93
+ --adam_weight_decay=1e-2 \
94
+ --adam_epsilon=1e-8 \
95
+ --max_grad_norm=1.0 \
96
+ --allow_tf32 \
97
+ --weighting_scheme="logit_normal" \
98
+ --logit_mean=0.0 \
99
+ --logit_std=1.0 \
100
+ --precondition_outputs=1
101
+
102
+ echo "==========================================="
103
+ echo "训练完成!"
104
+ echo "模型保存在: $OUTPUT_DIR"
105
+ echo "日志保存在: $OUTPUT_DIR/logs"
106
+ echo "验证图片保存在: $OUTPUT_DIR/validation_images"
107
+ echo "==========================================="
108
+
109
+ # nohup bash train_sd3_lora.sh > train_sd3_lora.log 2>&1 &
train_sd3_lora2.log ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ 检测到 4 个GPU
3
+ 每个GPU批次大小: 8
4
+ 总有效批次大小: 32
5
+ ===== SD3 LoRA 多GPU训练开始 =====
6
+ 模型: /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671
7
+ 输出目录: sd3-lora-finetuned-batch-32
8
+ 分辨率: 512
9
+ 每个GPU批次大小: 8
10
+ 梯度累积步数: 1
11
+ 总有效批次大小: 32
12
+ 学习率: 1e-5
13
+ 最大训练步数: 500000
14
+ LoRA Rank: 32
15
+ 使用GPU: 0,1,2,3
16
+ ===========================================
17
+ 使用 accelerate 启动多GPU训练...
18
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
19
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
20
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
21
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
22
+ Found 4 GPUs available
23
+ Multi-GPU training enabled with 4 GPUs
24
+ Found 4 GPUs available
25
+ Multi-GPU training enabled with 4 GPUs
26
+ Found 4 GPUs available
27
+ Multi-GPU training enabled with 4 GPUs
28
+ Found 4 GPUs available
29
+ Multi-GPU training enabled with 4 GPUs
30
+ 03/09/2026 10:48:06 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
31
+ Num processes: 4
32
+ Process index: 2
33
+ Local process index: 2
34
+ Device: cuda:2
35
+
36
+ Mixed precision type: no
37
+
38
+ Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
39
+ 03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
40
+ Num processes: 4
41
+ Process index: 0
42
+ Local process index: 0
43
+ Device: cuda:0
44
+
45
+ Mixed precision type: no
46
+
47
+ [INFO] Accelerator initialized
48
+ 03/09/2026 10:48:07 - INFO - __main__ - Number of processes: 4
49
+ 03/09/2026 10:48:07 - INFO - __main__ - Distributed type: MULTI_GPU
50
+ 03/09/2026 10:48:07 - INFO - __main__ - Mixed precision: no
51
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 0: NVIDIA A100-PCIE-40GB
52
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 0 memory: 39.4 GB
53
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 1: NVIDIA A100-PCIE-40GB
54
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 1 memory: 39.4 GB
55
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 2: NVIDIA A100-PCIE-40GB
56
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 2 memory: 39.4 GB
57
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 3: NVIDIA A100-PCIE-40GB
58
+ 03/09/2026 10:48:07 - INFO - __main__ - GPU 3 memory: 39.4 GB
59
+ [INFO] Seed set to 42
60
+ [INFO] Loading tokenizers...
61
+ 03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
62
+ Num processes: 4
63
+ Process index: 1
64
+ Local process index: 1
65
+ Device: cuda:1
66
+
67
+ Mixed precision type: no
68
+
69
+ 03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
70
+ Num processes: 4
71
+ Process index: 3
72
+ Local process index: 3
73
+ Device: cuda:3
74
+
75
+ Mixed precision type: no
76
+
77
+ [INFO] Tokenizers loaded. Loading text encoders, VAE, and transformer...
78
+ You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
79
+ You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
80
+ You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
81
+ You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
82
+ {'max_image_seq_len', 'shift_terminal', 'use_beta_sigmas', 'time_shift_type', 'use_dynamic_shifting', 'stochastic_sampling', 'base_shift', 'invert_sigmas', 'use_exponential_sigmas', 'use_karras_sigmas', 'max_shift', 'base_image_seq_len'} was not found in config. Values will be initialized to default values.
83
+
84
+ {'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
85
+
86
+
87
+
88
+ If your task is similar to the task the model of the checkpoint was trained on, you can already use AutoencoderKL for predictions without further training.
89
+
90
+ {'qk_norm', 'dual_attention_layers'} was not found in config. Values will be initialized to default values.
91
+
92
+
93
+ All model checkpoint weights were used when initializing SD3Transformer2DModel.
94
+
95
+ All the weights of SD3Transformer2DModel were initialized from the model checkpoint at /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671.
96
+ If your task is similar to the task the model of the checkpoint was trained on, you can already use SD3Transformer2DModel for predictions without further training.
97
+ [INFO] Text encoders, VAE, and transformer loaded
98
+ [rank2]:[W309 10:48:36.300339912 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
99
+ [INFO] Optimizer created. Loading dataset...
100
+ [INFO] Found metadata.jsonl, using efficient loading method
101
+ [INFO] Loading dataset from metadata.jsonl: /gemini/space/hsd/project/dataset/cc3m-wds/train/metadata.jsonl
102
+ [rank1]:[W309 10:48:38.132724538 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
103
+ [INFO] Processed 100000 entries from metadata.jsonl
104
+ [rank3]:[W309 10:48:39.804821476 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
105
+ [INFO] Processed 200000 entries from metadata.jsonl
106
+ [INFO] Processed 300000 entries from metadata.jsonl
107
+ [INFO] Processed 400000 entries from metadata.jsonl
108
+ [INFO] Processed 500000 entries from metadata.jsonl
109
+ [INFO] Processed 600000 entries from metadata.jsonl
110
+ [INFO] Processed 700000 entries from metadata.jsonl
111
+ [INFO] Processed 800000 entries from metadata.jsonl
112
+ [INFO] Processed 900000 entries from metadata.jsonl
113
+ [INFO] Processed 1000000 entries from metadata.jsonl
114
+ [INFO] Processed 1100000 entries from metadata.jsonl
115
+ [INFO] Processed 1200000 entries from metadata.jsonl
116
+ [INFO] Processed 1300000 entries from metadata.jsonl
117
+ [INFO] Processed 1400000 entries from metadata.jsonl
118
+ [INFO] Processed 1500000 entries from metadata.jsonl
119
+ [INFO] Processed 1600000 entries from metadata.jsonl
120
+ [INFO] Processed 1700000 entries from metadata.jsonl
121
+ [INFO] Processed 1800000 entries from metadata.jsonl
122
+ [INFO] Processed 1900000 entries from metadata.jsonl
123
+ [INFO] Processed 2000000 entries from metadata.jsonl
124
+ [INFO] Processed 2100000 entries from metadata.jsonl
125
+ [INFO] Processed 2200000 entries from metadata.jsonl
126
+ [INFO] Processed 2300000 entries from metadata.jsonl
127
+ [INFO] Processed 2400000 entries from metadata.jsonl
128
+ [INFO] Processed 2500000 entries from metadata.jsonl
129
+ [INFO] Processed 2600000 entries from metadata.jsonl
130
+ [INFO] Processed 2700000 entries from metadata.jsonl
131
+ [INFO] Processed 2800000 entries from metadata.jsonl
132
+ [INFO] Processed 2900000 entries from metadata.jsonl
133
+ [INFO] Loaded 2905954 image-caption pairs from metadata.jsonl
134
+ [INFO] Dataset loaded successfully.
135
+ [rank0]:[W309 10:49:04.100729496 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
136
+ [INFO] All processes synchronized. Building transforms and DataLoader...
137
+ [INFO] Dataset columns: ['image', 'text']
138
+ [INFO] Using image column: image
139
+ [WARNING] Specified caption_column 'caption' not found. Using 'text' instead.
140
+ [INFO] Using caption column: text
141
+ 03/09/2026 10:49:34 - INFO - __main__ - Auto-setting dataloader_num_workers to 4 for multi-GPU training
142
+ [INFO] DataLoader ready. Computing training steps and scheduler...
143
+ 03/09/2026 10:49:36 - WARNING - __main__ - Failed to initialize trackers: Descriptors cannot be created directly.
144
+ If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
145
+ If you cannot immediately regenerate your protos, some other possible workarounds are:
146
+ 1. Downgrade the protobuf package to 3.20.x or lower.
147
+ 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
148
+
149
+ More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
150
+ 03/09/2026 10:49:36 - WARNING - __main__ - Continuing without tracking. You can monitor training through console logs.
151
+ 03/09/2026 10:49:36 - INFO - __main__ - ***** Running training *****
152
+ 03/09/2026 10:49:36 - INFO - __main__ - Num examples = 2905954
153
+ 03/09/2026 10:49:36 - INFO - __main__ - Num Epochs = 6
154
+ 03/09/2026 10:49:36 - INFO - __main__ - Instantaneous batch size per device = 8
155
+ 03/09/2026 10:49:36 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 32
156
+ 03/09/2026 10:49:36 - INFO - __main__ - Gradient Accumulation steps = 1
157
+ 03/09/2026 10:49:36 - INFO - __main__ - Total optimization steps = 500000
158
+ 03/09/2026 10:49:36 - INFO - __main__ - Number of GPU processes = 4
159
+ 03/09/2026 10:49:36 - INFO - __main__ - Effective batch size per GPU = 8
160
+ 03/09/2026 10:49:36 - INFO - __main__ - Total effective batch size across all GPUs = 32
161
+ [INFO] Training setup complete. num_examples=2905954, max_train_steps=500000, num_epochs=6
162
+
163
+ [rank3]:[W309 10:49:40.986061745 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
164
+ [rank2]:[W309 10:49:40.986962214 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
165
+ [rank1]:[W309 10:49:40.988288933 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
166
+ [rank0]:[W309 10:49:40.989585595 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
167
+
168
+ [rank1]: File "/gemini/space/gzy_new/models/Sida/train_lora_sd3.py", line 1597, in <module>
169
+ [rank1]: main(args)
170
+ [rank1]: File "/gemini/space/gzy_new/models/Sida/train_lora_sd3.py", line 1410, in main
171
+ [rank1]: timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
172
+ [rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler
173
+ [rank1]: _error_if_any_worker_fails()
174
+ [rank1]: RuntimeError: DataLoader worker (pid 8746) is killed by signal: Terminated.
175
+ W0309 10:54:03.295000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7639 closing signal SIGTERM
176
+ W0309 10:54:03.296000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7640 closing signal SIGTERM
177
+ W0309 10:54:03.296000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7641 closing signal SIGTERM
178
+ E0309 10:54:03.762000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -15) local_rank: 0 (pid: 7638) of binary: /root/miniconda3/envs/SiT/bin/python3.10
179
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
180
+ Traceback (most recent call last):
181
+ File "/root/miniconda3/envs/SiT/bin/accelerate", line 6, in <module>
182
+ sys.exit(main())
183
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
184
+ args.func(args)
185
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1189, in launch_command
186
+ multi_gpu_launcher(args)
187
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/launch.py", line 815, in multi_gpu_launcher
188
+ distrib_run.run(args)
189
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
190
+ elastic_launch(
191
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
192
+ return launch_agent(self._config, self._entrypoint, list(args))
193
+ File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
194
+ raise ChildFailedError(
195
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
196
+ ==========================================================
197
+ train_lora_sd3.py FAILED
198
+ ----------------------------------------------------------
199
+ Failures:
200
+ <NO_OTHER_FAILURES>
201
+ ----------------------------------------------------------
202
+ Root Cause (first observed failure):
203
+ [0]:
204
+ time : 2026-03-09_10:54:03
205
+ host : 1406241bacf2123cb0cdd1395a94d0f5-taskrole1-0
206
+ rank : 0 (local_rank: 0)
207
+ exitcode : -15 (pid: 7638)
208
+ error_file: <N/A>
209
+ traceback : Signal 15 (SIGTERM) received by PID 7638
210
+ ==========================================================
211
+ ===========================================
212
+ 训练完成!
213
+ 模型保存在: sd3-lora-finetuned-batch-32
214
+ 日志保存在: sd3-lora-finetuned-batch-32/logs
215
+ 验证图片保存在: sd3-lora-finetuned-batch-32/validation_images
216
+ ===========================================
train_sd3_lora2.sh ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # SD3 LoRA Fine-tuning Training Script
4
+ # 使用 Stable Diffusion 3 进行 LoRA 微调训练 - 多GPU优化版本
5
+
6
+ # 设置环境变量
7
+ export CUDA_VISIBLE_DEVICES=0,1,2,3 # 根据可用GPU数量调整
8
+ # export PYTHONPATH=$PYTHONPATH:/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco
9
+
10
+ # 检查GPU数量
11
+ num_gpus=$(nvidia-smi --list-gpus | wc -l)
12
+ echo "检测到 $num_gpus 个GPU"
13
+
14
+ # 训练参数配置
15
+ MODEL_NAME="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
16
+ #DATASET_NAME="lambdalabs/naruto-blip-captions" # 或者使用本地数据集路径
17
+ TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 本地数据集路径0
18
+ OUTPUT_DIR="sd3-lora-finetuned-batch-32"
19
+ RESOLUTION=512
20
+ # 多GPU训练时调整批次大小 - 每个GPU的批次大小
21
+ TRAIN_BATCH_SIZE=8 # 每个GPU的批次大小,总批次大小 = TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS
22
+ GRADIENT_ACCUMULATION_STEPS=1 # 梯度累积步数
23
+ LEARNING_RATE=1e-5
24
+ MAX_TRAIN_STEPS=500000
25
+ NUM_TRAIN_EPOCHS=50
26
+ VALIDATION_PROMPT="A photo of a beautiful landscape with mountains and lake"
27
+ NUM_VALIDATION_IMAGES=2
28
+ VALIDATION_EPOCHS=1 # 每10个epoch验证一次
29
+ LORA_RANK=32
30
+ SEED=42
31
+
32
+ # 计算有效批次大小
33
+ effective_batch_size=$((TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS))
34
+ echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
35
+ echo "总有效批次大小: $effective_batch_size"
36
+
37
+ # 创建输出目录
38
+ mkdir -p $OUTPUT_DIR
39
+
40
+ echo "===== SD3 LoRA 多GPU训练开始 ====="
41
+ echo "模型: $MODEL_NAME"
42
+ echo "输出目录: $OUTPUT_DIR"
43
+ echo "分辨率: $RESOLUTION"
44
+ echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
45
+ echo "梯度累积步数: $GRADIENT_ACCUMULATION_STEPS"
46
+ echo "总有效批次大小: $effective_batch_size"
47
+ echo "学习率: $LEARNING_RATE"
48
+ echo "最大训练步数: $MAX_TRAIN_STEPS"
49
+ echo "LoRA Rank: $LORA_RANK"
50
+ echo "使用GPU: $CUDA_VISIBLE_DEVICES"
51
+ echo "==========================================="
52
+
53
+ # 检查accelerate配置
54
+ if [ ! -f "accelerate_config.yaml" ]; then
55
+ echo "错误: 未找到 accelerate_config.yaml 文件"
56
+ echo "请运行: accelerate config 来配置多GPU训练"
57
+ exit 1
58
+ fi
59
+
60
+ echo "使用 accelerate 启动多GPU训练..."
61
+
62
+ # 使用 accelerate 启动训练
63
+ accelerate launch --config_file=accelerate_config.yaml train_lora_sd3.py \
64
+ --pretrained_model_name_or_path="$MODEL_NAME" \
65
+ --train_data_dir="$TRAIN_DATA_DIR" \
66
+ --output_dir="$OUTPUT_DIR" \
67
+ --mixed_precision="no" \
68
+ --resolution=$RESOLUTION \
69
+ --train_batch_size=$TRAIN_BATCH_SIZE \
70
+ --gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
71
+ --learning_rate=$LEARNING_RATE \
72
+ --scale_lr \
73
+ --lr_scheduler="cosine" \
74
+ --lr_warmup_steps=100 \
75
+ --max_train_steps=$MAX_TRAIN_STEPS \
76
+ --num_train_epochs=$NUM_TRAIN_EPOCHS \
77
+ --validation_prompt="$VALIDATION_PROMPT" \
78
+ --num_validation_images=$NUM_VALIDATION_IMAGES \
79
+ --validation_epochs=$VALIDATION_EPOCHS \
80
+ --checkpointing_steps=50000 \
81
+ --checkpoints_total_limit=10 \
82
+ --seed=$SEED \
83
+ --rank=$LORA_RANK \
84
+ --gradient_checkpointing \
85
+ --use_8bit_adam \
86
+ --dataloader_num_workers=0 \
87
+ --report_to="tensorboard" \
88
+ --logging_dir="logs" \
89
+ --adam_beta1=0.9 \
90
+ --adam_beta2=0.999 \
91
+ --adam_weight_decay=1e-2 \
92
+ --adam_epsilon=1e-8 \
93
+ --max_grad_norm=1.0 \
94
+ --allow_tf32 \
95
+ --weighting_scheme="logit_normal" \
96
+ --logit_mean=0.0 \
97
+ --logit_std=1.0 \
98
+ --precondition_outputs=1
99
+
100
+ echo "==========================================="
101
+ echo "训练完成!"
102
+ echo "模型保存在: $OUTPUT_DIR"
103
+ echo "日志保存在: $OUTPUT_DIR/logs"
104
+ echo "验证图片保存在: $OUTPUT_DIR/validation_images"
105
+ echo "==========================================="
106
+
107
+ # nohup bash train_sd3_lora2.sh > train_sd3_lora2.log 2>&1 &
visual.sh ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # 分两次主模型运行,最后拼图,避免同时加载双模型导致 OOM
5
+
6
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
7
+ PYTHON_BIN="${PYTHON_BIN:-/root/miniconda3/envs/SiT/bin/python}"
8
+
9
+ PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
10
+ LOCAL_PIPELINE_PATH="/gemini/space/gzy_new/models/Sida/pipeline_stable_diffusion_3.py"
11
+ LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
12
+ RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
13
+
14
+ # 可按需修改为你想看的文本
15
+ PROMPT="young man beside a dog wearing sunglasses"
16
+
17
+ OUTPUT_DIR="/gemini/space/gzy_new/models/Sida/sd3_lora_rn_pair_samples"
18
+ OUTPUT_FILE="${OUTPUT_DIR}/lora_rn_4x8_step180.png"
19
+ LORA_NPZ="${OUTPUT_DIR}/lora_trace_4x8.npz"
20
+ RN_NPZ="${OUTPUT_DIR}/rn_trace_4x8.npz"
21
+
22
+ STEPS=180
23
+ GUIDANCE_SCALE=7.0
24
+ HEIGHT=512
25
+ WIDTH=512
26
+ SEED=42
27
+ MIXED_PRECISION="fp16" # no / fp16 / bf16
28
+ NUM_SIT_LAYERS=1 # 需与训练一致
29
+
30
+ mkdir -p "$OUTPUT_DIR"
31
+
32
+ if [ ! -x "$PYTHON_BIN" ]; then
33
+ echo "ERROR: PYTHON_BIN not executable: $PYTHON_BIN"
34
+ echo "Hint: export PYTHON_BIN=/path/to/your/python"
35
+ exit 1
36
+ fi
37
+
38
+ if [ ! -e "$PRETRAINED_MODEL" ]; then
39
+ echo "ERROR: PRETRAINED_MODEL not found: $PRETRAINED_MODEL"
40
+ exit 1
41
+ fi
42
+ if [ ! -f "$LOCAL_PIPELINE_PATH" ]; then
43
+ echo "ERROR: LOCAL_PIPELINE_PATH not found: $LOCAL_PIPELINE_PATH"
44
+ exit 1
45
+ fi
46
+ if [ ! -e "$LORA_PATH" ]; then
47
+ echo "ERROR: LORA_PATH not found: $LORA_PATH"
48
+ exit 1
49
+ fi
50
+ if [ ! -e "$RECTIFIED_WEIGHTS" ]; then
51
+ echo "ERROR: RECTIFIED_WEIGHTS not found: $RECTIFIED_WEIGHTS"
52
+ exit 1
53
+ fi
54
+
55
+ COMMON_ARGS=(
56
+ --pretrained_model_name_or_path "$PRETRAINED_MODEL"
57
+ --local_pipeline_path "$LOCAL_PIPELINE_PATH"
58
+ --lora_path "$LORA_PATH"
59
+ --rectified_weights "$RECTIFIED_WEIGHTS"
60
+ --num_sit_layers "$NUM_SIT_LAYERS"
61
+ --prompt "$PROMPT"
62
+ --output "$OUTPUT_FILE"
63
+ --lora_npz "$LORA_NPZ"
64
+ --rn_npz "$RN_NPZ"
65
+ --steps "$STEPS"
66
+ --guidance_scale "$GUIDANCE_SCALE"
67
+ --height "$HEIGHT"
68
+ --width "$WIDTH"
69
+ --seed "$SEED"
70
+ --mixed_precision "$MIXED_PRECISION"
71
+ )
72
+
73
+ "$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage lora
74
+ "$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage rn
75
+ "$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage pair
76
+
77
+ echo "Done. Saved to: $OUTPUT_FILE"
78
+ # nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
visualize_lora_rn_4x8.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ 分阶段生成 4x8 对比图(默认 180 步):
5
+ - stage=lora: 仅生成 LoRA 轨迹中间结果并保存中间 npz
6
+ - stage=rn: 仅生成 RN 轨迹中间结果并保存中间 npz
7
+ - stage=pair: 读取两阶段 npz,计算第3/4行并拼接总图
8
+ """
9
+
10
+ import argparse
11
+ import importlib.util
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image
17
+
18
+ from diffusers import AutoencoderKL
19
+ from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline
20
+
21
+
22
+ def dynamic_import_training_classes(project_root: str):
23
+ import sys
24
+ sys.path.insert(0, project_root)
25
+ import train_rectified_noise as trn
26
+ return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
27
+
28
+
29
+ def load_local_pipeline_class(local_pipeline_path: str):
30
+ module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
31
+ spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
32
+ if spec is None or spec.loader is None:
33
+ raise ImportError(f"Failed to import local pipeline: {local_pipeline_path}")
34
+ module = importlib.util.module_from_spec(spec)
35
+ spec.loader.exec_module(module)
36
+ return module.StableDiffusion3Pipeline
37
+
38
+
39
+ def load_sit_weights(rectified_module, weights_path: str):
40
+ p = Path(weights_path)
41
+ if p.is_dir():
42
+ search_dirs = [p, p / "sit_weights"]
43
+ for d in search_dirs:
44
+ if not d.exists():
45
+ continue
46
+ st = d / "pytorch_sit_weights.safetensors"
47
+ if st.exists():
48
+ from safetensors.torch import load_file
49
+ rectified_module.load_state_dict(load_file(str(st)), strict=False)
50
+ return True
51
+ for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
52
+ ck = d / name
53
+ if ck.exists():
54
+ rectified_module.load_state_dict(torch.load(str(ck), map_location="cpu"), strict=False)
55
+ return True
56
+ return False
57
+ if str(p).endswith(".safetensors"):
58
+ from safetensors.torch import load_file
59
+ state = load_file(str(p))
60
+ else:
61
+ state = torch.load(str(p), map_location="cpu")
62
+ rectified_module.load_state_dict(state, strict=False)
63
+ return True
64
+
65
+
66
+ def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
67
+ RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
68
+ tfm = base_pipeline.transformer
69
+ sit_hidden_size = getattr(tfm.config, "joint_attention_dim", None) or getattr(tfm.config, "inner_dim", 4096)
70
+ rectified_module = RectifiedNoiseModule(
71
+ hidden_size=sit_hidden_size,
72
+ num_sit_layers=num_sit_layers,
73
+ num_attention_heads=getattr(tfm.config, "num_attention_heads", 32),
74
+ input_dim=getattr(tfm.config, "in_channels", 16),
75
+ transformer_hidden_size=getattr(tfm.config, "hidden_size", 1536),
76
+ )
77
+ if not load_sit_weights(rectified_module, rectified_weights):
78
+ raise RuntimeError(f"Failed to load rectified weights: {rectified_weights}")
79
+ model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
80
+ model.eval()
81
+ return model
82
+
83
+
84
+ def set_pipeline_modules_eval(pipe):
85
+ for name in ["transformer", "vae", "text_encoder", "text_encoder_2", "text_encoder_3", "model"]:
86
+ m = getattr(pipe, name, None)
87
+ if m is not None and hasattr(m, "eval"):
88
+ m.eval()
89
+
90
+
91
+ def align_rn_branch_dtype(pipe, dtype):
92
+ model = getattr(pipe, "model", None)
93
+ if model is None:
94
+ return
95
+ rn_branch = getattr(model, "rectified_noise_module", None)
96
+ if rn_branch is not None:
97
+ rn_branch.to(device=pipe._execution_device, dtype=dtype)
98
+ if hasattr(model, "to"):
99
+ model.to(device=pipe._execution_device)
100
+
101
+
102
+ @torch.no_grad()
103
+ def decode_latent_to_uint8(vae, latents, normalize=False):
104
+ shift = getattr(vae.config, "shift_factor", 0.0) or 0.0
105
+ scaled = (latents / vae.config.scaling_factor) + shift
106
+ image = vae.decode(scaled, return_dict=False)[0]
107
+ x = image[0].float()
108
+ if normalize:
109
+ x = (x - x.min()) / (x.max() - x.min() + 1e-6)
110
+ else:
111
+ x = (x / 2.0) + 0.5
112
+ x = x.clamp(0, 1)
113
+ return (x.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
114
+
115
+
116
+ def encode_prompt_for_pipe(pipe, prompt, guidance_scale):
117
+ do_cfg = guidance_scale > 1.0
118
+ pe, npe, pp, npp = pipe.encode_prompt(
119
+ prompt=prompt,
120
+ prompt_2=prompt,
121
+ prompt_3=prompt,
122
+ do_classifier_free_guidance=do_cfg,
123
+ num_images_per_prompt=1,
124
+ device=pipe._execution_device,
125
+ )
126
+ if do_cfg:
127
+ pe = torch.cat([npe, pe], dim=0)
128
+ pp = torch.cat([npp, pp], dim=0)
129
+ return pe, pp
130
+
131
+
132
+ @torch.no_grad()
133
+ def run_single_trajectory(pipe, prompt, steps, guidance_scale, height, width, seed, use_rn_model=False, autocast_dtype=None):
134
+ device = pipe._execution_device
135
+ dtype = next(pipe.transformer.parameters()).dtype
136
+ sample_idx = np.linspace(0, steps - 1, 8, dtype=int).tolist()
137
+
138
+ latent_h = height // pipe.vae_scale_factor
139
+ latent_w = width // pipe.vae_scale_factor
140
+ g = torch.Generator(device=device).manual_seed(seed)
141
+ init_latents = torch.randn(
142
+ (1, pipe.transformer.config.in_channels, latent_h, latent_w),
143
+ device=device,
144
+ dtype=dtype,
145
+ generator=g,
146
+ )
147
+
148
+ if use_rn_model:
149
+ effective_model = getattr(pipe, "model", None)
150
+ if effective_model is not None:
151
+ pipe.model = effective_model
152
+
153
+ step_latents = {}
154
+
155
+ def _capture_callback(_pipe, i, _t, callback_kwargs):
156
+ if i in sample_idx:
157
+ step_latents[i] = callback_kwargs["latents"].detach().clone()
158
+ return callback_kwargs
159
+
160
+ # 跟参考脚本一致,主路径走 pipeline(...),通过 callback 抓中间 latent
161
+ if autocast_dtype is None:
162
+ _ = pipe(
163
+ prompt=prompt,
164
+ height=height,
165
+ width=width,
166
+ num_inference_steps=steps,
167
+ guidance_scale=guidance_scale,
168
+ latents=init_latents,
169
+ num_images_per_prompt=1,
170
+ callback_on_step_end=_capture_callback,
171
+ callback_on_step_end_tensor_inputs=["latents"],
172
+ )
173
+ else:
174
+ with torch.autocast("cuda", dtype=autocast_dtype):
175
+ _ = pipe(
176
+ prompt=prompt,
177
+ height=height,
178
+ width=width,
179
+ num_inference_steps=steps,
180
+ guidance_scale=guidance_scale,
181
+ latents=init_latents,
182
+ num_images_per_prompt=1,
183
+ callback_on_step_end=_capture_callback,
184
+ callback_on_step_end_tensor_inputs=["latents"],
185
+ )
186
+
187
+ images = []
188
+ latents = []
189
+ noises = []
190
+ prev_lat = init_latents
191
+ pe = pp = None
192
+ do_cfg = guidance_scale > 1.0
193
+ timesteps = None
194
+ if use_rn_model:
195
+ pe, pp = encode_prompt_for_pipe(pipe, prompt, guidance_scale)
196
+ pipe.scheduler.set_timesteps(steps, device=device)
197
+ timesteps = pipe.scheduler.timesteps
198
+
199
+ for i in sample_idx:
200
+ cur_lat = step_latents[i]
201
+ images.append(decode_latent_to_uint8(pipe.vae, cur_lat, normalize=False))
202
+ latents.append(cur_lat.squeeze(0).float().cpu().numpy())
203
+ if use_rn_model:
204
+ # 第3行严格使用 RN 噪声分支(速度场)输出
205
+ model_in = torch.cat([cur_lat] * 2) if do_cfg else cur_lat
206
+ ts = timesteps[i].expand(model_in.shape[0])
207
+ if autocast_dtype is not None and device == "cuda":
208
+ with torch.autocast("cuda", dtype=autocast_dtype):
209
+ rn_out = pipe.model(
210
+ hidden_states=model_in,
211
+ timestep=ts,
212
+ encoder_hidden_states=pe,
213
+ pooled_projections=pp,
214
+ return_dict=False,
215
+ )
216
+ else:
217
+ rn_out = pipe.model(
218
+ hidden_states=model_in,
219
+ timestep=ts,
220
+ encoder_hidden_states=pe,
221
+ pooled_projections=pp,
222
+ return_dict=False,
223
+ )
224
+ # SD3WithRectifiedNoise: (final_output, mean_out, var_out)
225
+ rn_branch = rn_out[1] if isinstance(rn_out, tuple) and len(rn_out) > 1 else rn_out[0]
226
+ if do_cfg:
227
+ ru, rt = rn_branch.chunk(2)
228
+ rn_branch = ru + guidance_scale * (rt - ru)
229
+ noises.append(rn_branch.squeeze(0).float().cpu().numpy())
230
+ else:
231
+ # lora 阶段保留占位,pair 阶段不会用到
232
+ delta = (cur_lat - prev_lat)
233
+ noises.append(delta.squeeze(0).float().cpu().numpy())
234
+ prev_lat = cur_lat
235
+
236
+ return {
237
+ "images": np.stack(images, axis=0),
238
+ "latents": np.stack(latents, axis=0),
239
+ "noises": np.stack(noises, axis=0),
240
+ "sample_idx": np.array(sample_idx, dtype=np.int32),
241
+ }
242
+
243
+
244
+ def save_grid_4x8(rows, sample_idx, out_path, cell_w=512, cell_h=512):
245
+ cols = 8
246
+ grid = Image.new("RGB", (cols * cell_w, 4 * cell_h), color=(245, 245, 245))
247
+ for r in range(4):
248
+ for c in range(cols):
249
+ img = Image.fromarray(rows[r][c]).resize((cell_w, cell_h), Image.BILINEAR)
250
+ x = c * cell_w
251
+ y = r * cell_h
252
+ grid.paste(img, (x, y))
253
+ grid.save(out_path)
254
+
255
+
256
+ def stage_lora(args, dtype, device):
257
+ pipe = DiffusersStableDiffusion3Pipeline.from_pretrained(
258
+ args.pretrained_model_name_or_path,
259
+ revision=args.revision,
260
+ variant=args.variant,
261
+ torch_dtype=dtype,
262
+ ).to(device)
263
+ pipe.load_lora_weights(args.lora_path)
264
+ set_pipeline_modules_eval(pipe)
265
+ pipe.set_progress_bar_config(disable=True)
266
+ data = run_single_trajectory(
267
+ pipe=pipe,
268
+ prompt=args.prompt,
269
+ steps=args.steps,
270
+ guidance_scale=args.guidance_scale,
271
+ height=args.height,
272
+ width=args.width,
273
+ seed=args.seed,
274
+ use_rn_model=False,
275
+ autocast_dtype=dtype if args.mixed_precision != "no" and device == "cuda" else None,
276
+ )
277
+ np.savez_compressed(args.lora_npz, **data)
278
+ print(f"[lora] saved: {args.lora_npz}")
279
+
280
+
281
+ def stage_rn(args, dtype, device):
282
+ LocalPipe = load_local_pipeline_class(args.local_pipeline_path)
283
+ pipe = LocalPipe.from_pretrained(
284
+ args.pretrained_model_name_or_path,
285
+ revision=args.revision,
286
+ variant=args.variant,
287
+ torch_dtype=dtype,
288
+ ).to(device)
289
+ pipe.load_lora_weights(args.lora_path)
290
+ pipe.model = build_rn_model(pipe, args.rectified_weights, args.num_sit_layers, device)
291
+ # 避免 RN 速度场额外前向时出现 Half/Float 冲突
292
+ align_rn_branch_dtype(pipe, dtype)
293
+ set_pipeline_modules_eval(pipe)
294
+ pipe.set_progress_bar_config(disable=True)
295
+ data = run_single_trajectory(
296
+ pipe=pipe,
297
+ prompt=args.prompt,
298
+ steps=args.steps,
299
+ guidance_scale=args.guidance_scale,
300
+ height=args.height,
301
+ width=args.width,
302
+ seed=args.seed,
303
+ use_rn_model=True,
304
+ autocast_dtype=dtype if args.mixed_precision != "no" and device == "cuda" else None,
305
+ )
306
+ np.savez_compressed(args.rn_npz, **data)
307
+ print(f"[rn] saved: {args.rn_npz}")
308
+
309
+
310
+ def stage_pair(args, dtype, device):
311
+ lora = np.load(args.lora_npz)
312
+ rn = np.load(args.rn_npz)
313
+ lora_images = lora["images"]
314
+ rn_images = rn["images"]
315
+ sample_idx = lora["sample_idx"]
316
+ rn_noises = rn["noises"] # RN 速度场,shape: [8, C, H, W]
317
+
318
+ def _velocity_to_sparse_points(vel_chw, out_h, out_w, q=99.6, point_color=(245, 245, 245)):
319
+ # 通道聚合成强度图,再转成黑底稀疏点图
320
+ mag = np.sqrt(np.sum(np.square(vel_chw.astype(np.float32)), axis=0)) # [H, W]
321
+ thr = np.percentile(mag, q)
322
+ mask = mag >= thr
323
+ # 最近邻放大到输出分辨率
324
+ sy = max(1, out_h // mask.shape[0])
325
+ sx = max(1, out_w // mask.shape[1])
326
+ up = np.repeat(np.repeat(mask, sy, axis=0), sx, axis=1)
327
+ up = up[:out_h, :out_w]
328
+ canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
329
+ canvas[up] = np.array(point_color, dtype=np.uint8)
330
+ return canvas, mag
331
+
332
+ # 第3行:黑底 + 稀疏速度点(接近你最初图风格)
333
+ step_noise_imgs = []
334
+ # 第4行:速度场累积后再做稀疏点图
335
+ sum_noise_imgs = []
336
+ running_mag = None
337
+ for i in range(8):
338
+ step_vis, mag = _velocity_to_sparse_points(rn_noises[i], args.height, args.width, q=99.6)
339
+ if running_mag is None:
340
+ running_mag = mag
341
+ else:
342
+ running_mag = running_mag + mag
343
+ thr_sum = np.percentile(running_mag, 99.2)
344
+ mask_sum = running_mag >= thr_sum
345
+ sy = max(1, args.height // mask_sum.shape[0])
346
+ sx = max(1, args.width // mask_sum.shape[1])
347
+ up_sum = np.repeat(np.repeat(mask_sum, sy, axis=0), sx, axis=1)[:args.height, :args.width]
348
+ sum_vis = np.zeros((args.height, args.width, 3), dtype=np.uint8)
349
+ sum_vis[up_sum] = np.array([245, 245, 245], dtype=np.uint8)
350
+
351
+ step_noise_imgs.append(step_vis)
352
+ sum_noise_imgs.append(sum_vis)
353
+
354
+ rows = [
355
+ [lora_images[i] for i in range(8)],
356
+ [rn_images[i] for i in range(8)],
357
+ step_noise_imgs,
358
+ sum_noise_imgs,
359
+ ]
360
+ save_grid_4x8(rows, sample_idx, args.output, cell_w=args.width, cell_h=args.height)
361
+ print(f"[pair] saved: {args.output}")
362
+
363
+
364
+ def main():
365
+ parser = argparse.ArgumentParser()
366
+ parser.add_argument("--stage", type=str, default="all", choices=["all", "lora", "rn", "pair"])
367
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
368
+ parser.add_argument("--revision", type=str, default=None)
369
+ parser.add_argument("--variant", type=str, default=None)
370
+ parser.add_argument("--local_pipeline_path", type=str, required=True)
371
+ parser.add_argument("--lora_path", type=str, required=True)
372
+ parser.add_argument("--rectified_weights", type=str, required=True)
373
+ parser.add_argument("--num_sit_layers", type=int, default=1)
374
+ parser.add_argument("--prompt", type=str, required=True)
375
+ parser.add_argument("--output", type=str, default="lora_rn_4x8.png")
376
+ parser.add_argument("--lora_npz", type=str, default="lora_trace_4x8.npz")
377
+ parser.add_argument("--rn_npz", type=str, default="rn_trace_4x8.npz")
378
+ parser.add_argument("--steps", type=int, default=180)
379
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
380
+ parser.add_argument("--height", type=int, default=512)
381
+ parser.add_argument("--width", type=int, default=512)
382
+ parser.add_argument("--seed", type=int, default=42)
383
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
384
+ args = parser.parse_args()
385
+
386
+ dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
387
+ device = "cuda" if torch.cuda.is_available() else "cpu"
388
+
389
+ Path(args.output).parent.mkdir(parents=True, exist_ok=True)
390
+ Path(args.lora_npz).parent.mkdir(parents=True, exist_ok=True)
391
+ Path(args.rn_npz).parent.mkdir(parents=True, exist_ok=True)
392
+
393
+ if args.stage in ("all", "lora"):
394
+ stage_lora(args, dtype, device)
395
+ if device == "cuda":
396
+ torch.cuda.empty_cache()
397
+ if args.stage in ("all", "rn"):
398
+ stage_rn(args, dtype, device)
399
+ if device == "cuda":
400
+ torch.cuda.empty_cache()
401
+ if args.stage in ("all", "pair"):
402
+ stage_pair(args, dtype, device)
403
+
404
+
405
+ if __name__ == "__main__":
406
+ main()
visualize_sitf2_noise_evolution.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import imageio
4
+ from tqdm import tqdm
5
+ import torch.distributed as dist
6
+ from models import SiTF1, SiTF2, SiT, CombinedModel
7
+ from download import find_model
8
+ from diffusers.models import AutoencoderKL
9
+
10
+ def tensor_to_img(tensor):
11
+ arr = tensor.detach().cpu().numpy()
12
+ if arr.ndim == 3:
13
+ arr = arr[0]
14
+ arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255
15
+ return arr.astype(np.uint8)
16
+
17
+ def main(
18
+ sit_ckpt, sitf2_ckpt,
19
+ image_size=256,
20
+ patch_size=2,
21
+ hidden_size=1152,
22
+ out_channels=8,
23
+ steps=50,
24
+ gif_path='sitf2_noise_evolution.gif',
25
+ device='cuda'
26
+ ):
27
+ dist.init_process_group("nccl")
28
+ rank = dist.get_rank()
29
+ device = rank % torch.cuda.device_count()
30
+ latent_size = image_size // 8
31
+ sitf1 = SiTF1(
32
+ input_size=latent_size,
33
+ patch_size=2,
34
+ in_channels=4,
35
+ hidden_size=768,
36
+ depth=12,
37
+ num_heads=12,
38
+ mlp_ratio=4.0,
39
+ class_dropout_prob=0.1,
40
+ num_classes=3,
41
+ learn_sigma=False
42
+ ).to(device)
43
+ sitf1_state = find_model(sit_ckpt)
44
+ try:
45
+ sitf1.load_state_dict(sitf1_state["model"], strict=False)
46
+ except Exception:
47
+ sitf1.load_state_dict(sitf1_state, strict=False)
48
+ sitf1.eval()
49
+
50
+ sitf2 = SiTF2(
51
+ hidden_size=768,
52
+ out_channels=8,
53
+ patch_size=2,
54
+ num_heads=12,
55
+ mlp_ratio=4.0,
56
+ depth=2,
57
+ learn_sigma=False,
58
+ learn_mu=True
59
+ ).to(device)
60
+ from torch.nn.parallel import DistributedDataParallel as DDP
61
+ sitf2 = DDP(sitf2, device_ids=[device])
62
+ sitf2_state = find_model(args.sitf2_ckpt)
63
+ try:
64
+ sitf2.load_state_dict(sitf2_state["ema"])
65
+ except Exception:
66
+ sitf2.load_state_dict(sitf2_state)
67
+ sitf2.eval()
68
+
69
+ batch = 1
70
+ x = torch.randn(batch, 4, latent_size, latent_size, device=device)
71
+ x0= x
72
+ y = torch.zeros(batch, dtype=torch.long, device=device)
73
+ t = torch.ones(batch, device=device)
74
+
75
+ imgs = []
76
+ imgs1 = []
77
+ imgs2 = []
78
+ img_original=[]
79
+ sit = SiT(
80
+ input_size=latent_size,
81
+ patch_size=2,
82
+ in_channels=4,
83
+ hidden_size=768,
84
+ depth=12,
85
+ num_heads=12,
86
+ mlp_ratio=4.0,
87
+ class_dropout_prob=0.1,
88
+ num_classes=3,
89
+ learn_sigma=False
90
+ ).to(device)
91
+ try:
92
+ sit.load_state_dict(sitf1_state["model"])
93
+ except Exception:
94
+ sit.load_state_dict(sitf1_state)
95
+ sit.eval()
96
+ combined_model = CombinedModel(sitf1, sitf2).to(device)
97
+ combined_model.eval()
98
+
99
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
100
+
101
+ for step in tqdm(range(steps)):
102
+ t = torch.full((batch,), step / (steps - 1), device=device)
103
+ with torch.no_grad():
104
+ patch_tokens = sitf1(x, t, y)
105
+ t_emb = sitf1.t_embedder(t)
106
+ y_emb = sitf1.y_embedder(y, False)
107
+ c = t_emb + y_emb
108
+ std = sitf2.module.forward_noise(patch_tokens, c)
109
+ x1=x
110
+ drift = sit(x1, t, y)
111
+ delta_t = 1.0 / steps
112
+ x1 = x1 + drift * delta_t
113
+ x_dec = vae.decode(x1 / 0.18215).sample
114
+ img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
115
+ img1 = img[0]
116
+ drift = sit(x, t, y)
117
+ delta_t = 1.0 / steps
118
+ noise = torch.randn_like(x)
119
+ x = x + drift * delta_t
120
+ x_dec = vae.decode(x / 0.18215).sample
121
+ img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
122
+ img2 = img[0]
123
+ imgs.append(img2-img1)
124
+ imgs1.append(img1)
125
+ imgs2.append(img2)
126
+ x=x0
127
+ for step in tqdm(range(steps)):
128
+ t = torch.full((batch,), step / (steps - 1), device=device)
129
+ with torch.no_grad():
130
+ patch_tokens = sitf1(x, t, y)
131
+ t_emb = sitf1.t_embedder(t)
132
+ y_emb = sitf1.y_embedder(y, False)
133
+ c = t_emb + y_emb
134
+ std = sitf2.module.forward_noise(patch_tokens, c)
135
+ drift = sit(x, t, y)
136
+ delta_t = 1.0 / steps
137
+ x = x + drift * delta_t
138
+ x_dec = vae.decode(x / 0.18215).sample
139
+ img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
140
+ img1 = img[0]
141
+ img_original.append(img1)
142
+
143
+ imageio.mimsave(gif_path, imgs, duration=0.1)
144
+ print(f"Saved gif to {gif_path}")
145
+ imageio.mimsave('noise.gif', imgs1, duration=0.1)
146
+ print(f"Saved gif to {gif_path}")
147
+ imageio.mimsave('std.gif', imgs2, duration=0.1)
148
+ imageio.mimsave('nothing.gif', img_original, duration=0.1)
149
+ print(f"Saved gif to {gif_path}")
150
+ if __name__ == '__main__':
151
+ import argparse
152
+ parser = argparse.ArgumentParser()
153
+ parser.add_argument('--ckpt', type=str, default='/gemini/space/gzy/w_w_last/Celeba/w_w_sit_1/0200000.pt')
154
+ parser.add_argument('--sitf2-ckpt', type=str,default='/gemini/space/gzy/w_w_last/Celeba/w_w_sit_1/results/depth-mu-2-014-SiT-XL-2-Linear-velocity-None/checkpoints/0010000.pt')
155
+ parser.add_argument('--steps', type=int, default=100)
156
+ parser.add_argument('--gif-path', type=str, default='sitf2_noise_evolution.gif')
157
+ parser.add_argument('--gif-path2', type=str, default='noise.gif')
158
+ parser.add_argument('--gif-path1', type=str, default='std.gif')
159
+ parser.add_argument('--image-size', type=int, default=256)
160
+ parser.add_argument('--device', type=str, default='cuda')
161
+ args = parser.parse_args()
162
+ main(
163
+ sit_ckpt=args.ckpt,
164
+ sitf2_ckpt=args.sitf2_ckpt,
165
+ image_size=args.image_size,
166
+ steps=args.steps,
167
+ gif_path=args.gif_path,
168
+ device=args.device
169
+ )