Add files using upload-large-folder tool
Browse files- __pycache__/pic_npz.cpython-311.pyc +0 -0
- __pycache__/pipeline_stable_diffusion_3.cpython-310.pyc +0 -0
- __pycache__/sample_sd3_lora_rn_pair_ddp.cpython-311.pyc +0 -0
- __pycache__/train_rectified_noise.cpython-310.pyc +0 -0
- __pycache__/visualize_lora_rn_4x8.cpython-311.pyc +0 -0
- accelerate_config.yaml +16 -0
- cc3m_render.log +0 -0
- cc3m_render.py +155 -0
- download.log +0 -0
- download_sd3_models.py +71 -0
- eval_baseline.log +24 -0
- eval_rectified_noise_new_batch_2.log +24 -0
- evaluate.sh +11 -0
- evaluator_base copy.py +680 -0
- evaluator_base.log +5 -0
- evaluator_base.py +685 -0
- evaluator_rf.py +685 -0
- evaluator_rf_iter22.log +25 -0
- pic_npz copy.py +259 -0
- pic_npz.py +157 -0
- pipeline_stable_diffusion_3.py +1378 -0
- rectified-noise-batch-2/checkpoint-100000/sit_weights/sit_config.json +10 -0
- rectified-noise-batch-2/checkpoint-120000/sit_weights/sit_config.json +10 -0
- rectified-noise-batch-2/checkpoint-140000/sit_weights/sit_config.json +10 -0
- rectified-noise-batch-2/checkpoint-160000/sit_weights/sit_config.json +10 -0
- rectified-noise-batch-2/checkpoint-180000/sit_weights/sit_config.json +10 -0
- rectified-noise-batch-2/checkpoint-200000/sit_weights/sit_config.json +10 -0
- run_sd3_lora_rn_pair_sampling.sh +50 -0
- run_sd3_lora_sampling.log +0 -0
- run_sd3_lora_sampling.sh +94 -0
- run_sd3_rectified_sampling.sh +55 -0
- run_sd3_rectified_sampling_old.sh +72 -0
- sample_sd3_lora_checkpoint_ddp.py +818 -0
- sample_sd3_lora_ddp.py +675 -0
- sample_sd3_lora_rn_pair_ddp.py +417 -0
- sample_sd3_rectified_ddp.py +1316 -0
- sample_sd3_rectified_ddp_old.py +1317 -0
- sd3_rectified_samples_batch2_2200005011.01.01.0cfg_cond_true.txt +5 -0
- train_lora_sd3.py +1597 -0
- train_lora_sd3_new.py +1422 -0
- train_rectified_noise.py +0 -0
- train_rectified_noise.sh +104 -0
- train_rectified_noise2.py +0 -0
- train_sd3_lora.log +27 -0
- train_sd3_lora.sh +109 -0
- train_sd3_lora2.log +216 -0
- train_sd3_lora2.sh +107 -0
- visual.sh +78 -0
- visualize_lora_rn_4x8.py +406 -0
- visualize_sitf2_noise_evolution.py +169 -0
__pycache__/pic_npz.cpython-311.pyc
ADDED
|
Binary file (7.6 kB). View file
|
|
|
__pycache__/pipeline_stable_diffusion_3.cpython-310.pyc
ADDED
|
Binary file (39.5 kB). View file
|
|
|
__pycache__/sample_sd3_lora_rn_pair_ddp.cpython-311.pyc
ADDED
|
Binary file (27.1 kB). View file
|
|
|
__pycache__/train_rectified_noise.cpython-310.pyc
ADDED
|
Binary file (58.7 kB). View file
|
|
|
__pycache__/visualize_lora_rn_4x8.cpython-311.pyc
ADDED
|
Binary file (23.9 kB). View file
|
|
|
accelerate_config.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
machine_rank: 0
|
| 7 |
+
main_training_function: main
|
| 8 |
+
mixed_precision: 'no'
|
| 9 |
+
num_machines: 1
|
| 10 |
+
num_processes: 4
|
| 11 |
+
rdzv_backend: static
|
| 12 |
+
same_network: true
|
| 13 |
+
tpu_env: []
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
cc3m_render.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cc3m_render.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
在 `data_root/` 下已经有 `train/` 和 `validation/` 两个文件夹时:
|
| 6 |
+
分别在这两个文件夹内生成对应的 `metadata.jsonl`,不复制任何图片。
|
| 7 |
+
|
| 8 |
+
`metadata.jsonl` 每行格式:
|
| 9 |
+
{"file_name": "subdir/000026831.jpg", "caption": "..."}
|
| 10 |
+
|
| 11 |
+
其中 `file_name` 是相对当前 split 目录(train/ 或 validation/)的路径。
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 18 |
+
from itertools import islice
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Optional, Tuple
|
| 21 |
+
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_args() -> argparse.Namespace:
|
| 26 |
+
parser = argparse.ArgumentParser(description="Generate per-split metadata.jsonl for imagefolder (no copy)")
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--data_root",
|
| 29 |
+
type=str,
|
| 30 |
+
default="/gemini/space/hsd/project/dataset/cc3m-wds",
|
| 31 |
+
help="数据根目录(必须包含 train/ 和 validation/)",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--jsonl_name",
|
| 35 |
+
type=str,
|
| 36 |
+
default="metadata.jsonl",
|
| 37 |
+
help="每个 split 下生成的 jsonl 文件名(默认 metadata.jsonl)",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--use_txt_caption",
|
| 41 |
+
action="store_true",
|
| 42 |
+
default=True,
|
| 43 |
+
help="优先读取同名 .txt 作为 caption(默认开启),否则回落到 .json",
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--num_workers",
|
| 47 |
+
type=int,
|
| 48 |
+
default=32,
|
| 49 |
+
help="线程数(I/O 密集型建议 8~64 之间按机器调整)",
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--max_images",
|
| 53 |
+
type=int,
|
| 54 |
+
default=None,
|
| 55 |
+
help="每个 split 最多处理多少张图片(None 表示全部,调试可用)",
|
| 56 |
+
)
|
| 57 |
+
return parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def read_caption_from_txt(txt_path: Path) -> Optional[str]:
|
| 61 |
+
if not txt_path.exists():
|
| 62 |
+
return None
|
| 63 |
+
try:
|
| 64 |
+
with txt_path.open("r", encoding="utf-8") as f:
|
| 65 |
+
caption = f.read().strip()
|
| 66 |
+
return caption or None
|
| 67 |
+
except Exception:
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def read_caption_from_json(json_path: Path) -> Optional[str]:
|
| 72 |
+
if not json_path.exists():
|
| 73 |
+
return None
|
| 74 |
+
try:
|
| 75 |
+
with json_path.open("r", encoding="utf-8") as f:
|
| 76 |
+
data = json.load(f)
|
| 77 |
+
for key in ["caption", "text", "description"]:
|
| 78 |
+
if key in data and isinstance(data[key], str) and data[key].strip():
|
| 79 |
+
return data[key].strip()
|
| 80 |
+
except Exception:
|
| 81 |
+
return None
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main() -> None:
|
| 86 |
+
args = parse_args()
|
| 87 |
+
|
| 88 |
+
data_root = Path(args.data_root).resolve()
|
| 89 |
+
if not data_root.exists():
|
| 90 |
+
raise FileNotFoundError(f"数据根目录不存在:{data_root}")
|
| 91 |
+
|
| 92 |
+
splits = [("train", data_root / "train"), ("validation", data_root / "validation")]
|
| 93 |
+
for split_name, split_dir in splits:
|
| 94 |
+
if not split_dir.exists():
|
| 95 |
+
raise FileNotFoundError(f"缺少目录:{split_dir}(需要 train/ 和 validation/)")
|
| 96 |
+
|
| 97 |
+
def iter_images(split_dir: Path):
|
| 98 |
+
for root, _dirs, files in os.walk(split_dir):
|
| 99 |
+
for name in files:
|
| 100 |
+
if name.lower().endswith((".jpg", ".jpeg", ".png")):
|
| 101 |
+
yield Path(root) / name
|
| 102 |
+
|
| 103 |
+
def process_one(img_path: Path, split_dir: Path) -> Optional[Tuple[str, str]]:
|
| 104 |
+
txt_path = img_path.with_suffix(".txt")
|
| 105 |
+
json_path = img_path.with_suffix(".json")
|
| 106 |
+
|
| 107 |
+
caption = None
|
| 108 |
+
if args.use_txt_caption:
|
| 109 |
+
caption = read_caption_from_txt(txt_path)
|
| 110 |
+
if caption is None:
|
| 111 |
+
caption = read_caption_from_json(json_path)
|
| 112 |
+
else:
|
| 113 |
+
caption = read_caption_from_json(json_path)
|
| 114 |
+
if caption is None:
|
| 115 |
+
caption = read_caption_from_txt(txt_path)
|
| 116 |
+
|
| 117 |
+
if caption is None:
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
rel = img_path.relative_to(split_dir)
|
| 121 |
+
return str(rel).replace(os.sep, "/"), caption
|
| 122 |
+
|
| 123 |
+
for split_name, split_dir in splits:
|
| 124 |
+
jsonl_path = split_dir / args.jsonl_name
|
| 125 |
+
|
| 126 |
+
img_iter = iter_images(split_dir)
|
| 127 |
+
if args.max_images is not None:
|
| 128 |
+
img_iter = islice(img_iter, args.max_images)
|
| 129 |
+
|
| 130 |
+
# tqdm 需要可迭代对象,这里不预先收集列表以节省内存
|
| 131 |
+
# 进度条显示 processed 数量(total 可能未知)
|
| 132 |
+
def _task_iter():
|
| 133 |
+
for p in img_iter:
|
| 134 |
+
yield p
|
| 135 |
+
|
| 136 |
+
written = 0
|
| 137 |
+
with jsonl_path.open("w", encoding="utf-8") as f, ThreadPoolExecutor(max_workers=args.num_workers) as ex:
|
| 138 |
+
# executor.map 保持输入顺序;tqdm 显示处理进度
|
| 139 |
+
for result in tqdm(
|
| 140 |
+
ex.map(lambda p: process_one(p, split_dir), _task_iter()),
|
| 141 |
+
desc=f"[{split_name}] Processing",
|
| 142 |
+
):
|
| 143 |
+
if result is None:
|
| 144 |
+
continue
|
| 145 |
+
file_name, caption = result
|
| 146 |
+
f.write(json.dumps({"file_name": file_name, "caption": caption}, ensure_ascii=False) + "\n")
|
| 147 |
+
written += 1
|
| 148 |
+
|
| 149 |
+
print(f"{split_name}: 写入 {written} 条 -> {jsonl_path}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
if __name__ == "__main__":
|
| 153 |
+
main()
|
| 154 |
+
|
| 155 |
+
# nohup python cc3m_render.py > cc3m_render.log 2>&1 &
|
download.log
ADDED
|
File without changes
|
download_sd3_models.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
只负责“按 train_lora_sd3.py 相同的方式”下载 SD3 及相关组件到默认 HF cache:
|
| 6 |
+
通过依次调用 `from_pretrained(..., subfolder=...)` 来触发下载。
|
| 7 |
+
|
| 8 |
+
会下载的子目录(与 train_lora_sd3.py 一致):
|
| 9 |
+
tokenizer, tokenizer_2, tokenizer_3,
|
| 10 |
+
text_encoder, text_encoder_2, text_encoder_3,
|
| 11 |
+
scheduler, vae, transformer
|
| 12 |
+
|
| 13 |
+
用法:
|
| 14 |
+
python download_sd3_models.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers
|
| 15 |
+
|
| 16 |
+
下载完成后训练可离线:
|
| 17 |
+
export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1
|
| 18 |
+
python train_lora_sd3.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers ...
|
| 19 |
+
或直接指向你已经下载好的本地 repo 目录。
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import gc
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from diffusers import StableDiffusion3Pipeline
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_args() -> argparse.Namespace:
|
| 31 |
+
p = argparse.ArgumentParser(description="Download SD3 via a single from_pretrained call (cache warmup only).")
|
| 32 |
+
p.add_argument(
|
| 33 |
+
"--pretrained_model_name_or_path",
|
| 34 |
+
type=str,
|
| 35 |
+
default="stabilityai/stable-diffusion-3-medium-diffusers",
|
| 36 |
+
help="模型 repo id 或本地路径(与 train_lora_sd3.py 参数一致",
|
| 37 |
+
)
|
| 38 |
+
p.add_argument("--revision", type=str, default=None, help="可选:下载特定 revision/branch/tag")
|
| 39 |
+
p.add_argument("--variant", type=str, default=None, help="可选:如 fp16 等 variant(若仓库提供)")
|
| 40 |
+
p.add_argument("--cache_dir", type=str, default=None, help="可选:自定义 HF cache_dir;默认用系统/用户默认")
|
| 41 |
+
return p.parse_args()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main() -> None:
|
| 45 |
+
args = parse_args()
|
| 46 |
+
|
| 47 |
+
model = args.pretrained_model_name_or_path
|
| 48 |
+
|
| 49 |
+
# 最简单:直接加载整条 pipeline,触发把其依赖的全部组件下载进默认 HF cache
|
| 50 |
+
# 注意:from_pretrained 下载的是 pipeline 会用到的文件;本脚本目的就是让训练时不再联网。
|
| 51 |
+
pipe = StableDiffusion3Pipeline.from_pretrained(
|
| 52 |
+
model,
|
| 53 |
+
revision=args.revision,
|
| 54 |
+
variant=args.variant,
|
| 55 |
+
cache_dir=args.cache_dir,
|
| 56 |
+
low_cpu_mem_usage=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# 释放内存(下载已完成)
|
| 60 |
+
del pipe
|
| 61 |
+
gc.collect()
|
| 62 |
+
if torch.cuda.is_available():
|
| 63 |
+
torch.cuda.empty_cache()
|
| 64 |
+
|
| 65 |
+
print("下载/缓存预热完成。后续训练可设置离线环境变量避免联网:")
|
| 66 |
+
print(" export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
main()
|
| 71 |
+
|
eval_baseline.log
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-21 21:11:45.919794: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
|
|
|
|
|
|
|
|
|
| 1 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 2 |
0%| | 1/211 [00:02<07:04, 2.02s/it]
|
| 3 |
1%| | 2/211 [00:03<06:56, 2.00s/it]
|
| 4 |
1%|▏ | 3/211 [00:05<06:50, 1.98s/it]
|
| 5 |
2%|▏ | 4/211 [00:07<06:42, 1.95s/it]
|
| 6 |
2%|▏ | 5/211 [00:09<06:25, 1.87s/it]
|
| 7 |
3%|▎ | 6/211 [00:11<06:20, 1.85s/it]
|
| 8 |
3%|▎ | 7/211 [00:13<06:46, 1.99s/it]
|
| 9 |
4%|▍ | 8/211 [00:15<06:56, 2.05s/it]
|
| 10 |
4%|▍ | 9/211 [00:30<20:28, 6.08s/it]
|
| 11 |
5%|▍ | 10/211 [00:32<15:51, 4.73s/it]
|
| 12 |
5%|▌ | 11/211 [00:34<12:42, 3.81s/it]
|
| 13 |
6%|▌ | 12/211 [00:35<10:26, 3.15s/it]
|
| 14 |
6%|▌ | 13/211 [00:37<08:55, 2.71s/it]
|
| 15 |
7%|▋ | 14/211 [00:39<08:09, 2.49s/it]
|
| 16 |
7%|▋ | 15/211 [00:41<07:32, 2.31s/it]
|
| 17 |
8%|▊ | 16/211 [00:43<06:52, 2.12s/it]
|
| 18 |
8%|▊ | 17/211 [00:44<06:35, 2.04s/it]
|
| 19 |
9%|▊ | 18/211 [00:46<06:18, 1.96s/it]
|
| 20 |
9%|▉ | 19/211 [00:48<06:16, 1.96s/it]
|
| 21 |
9%|▉ | 20/211 [00:50<06:20, 1.99s/it]
|
| 22 |
10%|▉ | 21/211 [00:52<06:01, 1.91s/it]
|
| 23 |
10%|█ | 22/211 [00:54<05:51, 1.86s/it]
|
| 24 |
11%|█ | 23/211 [00:56<05:48, 1.85s/it]
|
| 25 |
11%|█▏ | 24/211 [00:57<05:44, 1.84s/it]
|
| 26 |
12%|█▏ | 25/211 [00:59<05:44, 1.85s/it]
|
| 27 |
12%|█▏ | 26/211 [01:01<05:38, 1.83s/it]
|
| 28 |
13%|█▎ | 27/211 [01:03<05:33, 1.81s/it]
|
| 29 |
13%|█▎ | 28/211 [01:05<06:02, 1.98s/it]
|
| 30 |
14%|█▎ | 29/211 [01:07<05:49, 1.92s/it]
|
| 31 |
14%|█▍ | 30/211 [01:09<05:33, 1.84s/it]
|
| 32 |
15%|█▍ | 31/211 [01:11<05:38, 1.88s/it]
|
| 33 |
15%|█▌ | 32/211 [01:12<05:32, 1.86s/it]
|
| 34 |
16%|█▌ | 33/211 [01:14<05:26, 1.83s/it]
|
| 35 |
16%|█▌ | 34/211 [01:16<05:20, 1.81s/it]
|
| 36 |
17%|█▋ | 35/211 [01:18<05:16, 1.80s/it]
|
| 37 |
17%|█▋ | 36/211 [01:19<05:11, 1.78s/it]
|
| 38 |
18%|█▊ | 37/211 [01:21<05:06, 1.76s/it]
|
| 39 |
18%|█▊ | 38/211 [01:23<05:22, 1.86s/it]
|
| 40 |
18%|█▊ | 39/211 [01:25<05:12, 1.82s/it]
|
| 41 |
19%|█▉ | 40/211 [01:35<11:51, 4.16s/it]
|
| 42 |
19%|█▉ | 41/211 [01:36<09:45, 3.44s/it]
|
| 43 |
20%|█▉ | 42/211 [01:38<08:27, 3.00s/it]
|
| 44 |
20%|██ | 43/211 [01:40<07:25, 2.65s/it]
|
| 45 |
21%|██ | 44/211 [01:42<06:46, 2.43s/it]
|
| 46 |
21%|██▏ | 45/211 [01:44<06:28, 2.34s/it]
|
| 47 |
22%|██▏ | 46/211 [01:46<05:56, 2.16s/it]
|
| 48 |
22%|██▏ | 47/211 [01:48<05:55, 2.17s/it]
|
| 49 |
23%|██▎ | 48/211 [01:50<05:30, 2.03s/it]
|
| 50 |
23%|██▎ | 49/211 [01:52<05:22, 1.99s/it]
|
| 51 |
24%|██▎ | 50/211 [01:54<05:11, 1.94s/it]
|
| 52 |
24%|██▍ | 51/211 [01:56<05:18, 1.99s/it]
|
| 53 |
25%|██▍ | 52/211 [01:57<05:02, 1.90s/it]
|
| 54 |
25%|██▌ | 53/211 [01:59<04:52, 1.85s/it]
|
| 55 |
26%|██▌ | 54/211 [02:01<04:42, 1.80s/it]
|
| 56 |
26%|██▌ | 55/211 [02:03<04:40, 1.80s/it]
|
| 57 |
27%|██▋ | 56/211 [02:05<04:47, 1.85s/it]
|
| 58 |
27%|██▋ | 57/211 [02:06<04:38, 1.81s/it]
|
| 59 |
27%|██▋ | 58/211 [02:08<04:40, 1.83s/it]
|
| 60 |
28%|██▊ | 59/211 [02:10<04:34, 1.81s/it]
|
| 61 |
28%|██▊ | 60/211 [02:12<04:31, 1.80s/it]
|
| 62 |
29%|██▉ | 61/211 [02:13<04:24, 1.76s/it]
|
| 63 |
29%|██▉ | 62/211 [02:15<04:19, 1.74s/it]
|
| 64 |
30%|██▉ | 63/211 [02:17<04:25, 1.79s/it]
|
| 65 |
30%|███ | 64/211 [02:19<04:22, 1.78s/it]
|
| 66 |
31%|███ | 65/211 [02:21<04:21, 1.79s/it]
|
| 67 |
31%|███▏ | 66/211 [02:22<04:16, 1.77s/it]
|
| 68 |
32%|███▏ | 67/211 [02:24<04:27, 1.85s/it]
|
| 69 |
32%|███▏ | 68/211 [02:26<04:23, 1.84s/it]
|
| 70 |
33%|███▎ | 69/211 [02:28<04:21, 1.84s/it]
|
| 71 |
33%|███▎ | 70/211 [02:30<04:15, 1.81s/it]
|
| 72 |
34%|███▎ | 71/211 [02:32<04:13, 1.81s/it]
|
| 73 |
34%|███▍ | 72/211 [02:33<04:09, 1.79s/it]
|
| 74 |
35%|███▍ | 73/211 [02:35<04:04, 1.77s/it]
|
| 75 |
35%|███▌ | 74/211 [02:53<15:24, 6.75s/it]
|
| 76 |
36%|███▌ | 75/211 [02:55<11:58, 5.28s/it]
|
| 77 |
36%|███▌ | 76/211 [02:57<09:31, 4.24s/it]
|
| 78 |
36%|███▋ | 77/211 [02:59<07:46, 3.48s/it]
|
| 79 |
37%|███▋ | 78/211 [03:00<06:33, 2.96s/it]
|
| 80 |
37%|███▋ | 79/211 [03:02<05:47, 2.64s/it]
|
| 81 |
38%|███▊ | 80/211 [03:04<05:19, 2.44s/it]
|
| 82 |
38%|███▊ | 81/211 [03:06<04:47, 2.22s/it]
|
| 83 |
39%|███▉ | 82/211 [03:08<04:27, 2.07s/it]
|
| 84 |
39%|███▉ | 83/211 [03:09<04:12, 1.97s/it]
|
| 85 |
40%|███▉ | 84/211 [03:11<04:02, 1.91s/it]
|
| 86 |
40%|████ | 85/211 [03:13<03:53, 1.85s/it]
|
| 87 |
41%|████ | 86/211 [03:15<03:48, 1.83s/it]
|
| 88 |
41%|████ | 87/211 [03:16<03:42, 1.79s/it]
|
| 89 |
42%|████▏ | 88/211 [03:18<03:40, 1.79s/it]
|
| 90 |
42%|████▏ | 89/211 [03:20<03:35, 1.77s/it]
|
| 91 |
43%|████▎ | 90/211 [03:22<03:33, 1.77s/it]
|
| 92 |
43%|████▎ | 91/211 [03:23<03:30, 1.76s/it]
|
| 93 |
44%|████▎ | 92/211 [03:25<03:35, 1.81s/it]
|
| 94 |
44%|████▍ | 93/211 [03:27<03:28, 1.76s/it]
|
| 95 |
45%|████▍ | 94/211 [03:29<03:25, 1.75s/it]
|
| 96 |
45%|████▌ | 95/211 [03:31<03:22, 1.74s/it]
|
| 97 |
45%|████▌ | 96/211 [03:32<03:24, 1.77s/it]
|
| 98 |
46%|████▌ | 97/211 [03:34<03:21, 1.77s/it]
|
| 99 |
46%|████▋ | 98/211 [03:36<03:19, 1.77s/it]
|
| 100 |
47%|████▋ | 99/211 [03:38<03:18, 1.77s/it]
|
| 101 |
47%|████▋ | 100/211 [03:39<03:14, 1.75s/it]
|
| 102 |
48%|████▊ | 101/211 [03:41<03:13, 1.75s/it]
|
| 103 |
48%|████▊ | 102/211 [03:43<03:13, 1.77s/it]
|
| 104 |
49%|████▉ | 103/211 [03:45<03:08, 1.75s/it]
|
| 105 |
49%|████▉ | 104/211 [03:47<03:12, 1.79s/it]
|
| 106 |
50%|████▉ | 105/211 [03:48<03:07, 1.77s/it]
|
| 107 |
50%|█████ | 106/211 [04:10<13:22, 7.64s/it]
|
| 108 |
51%|█████ | 107/211 [04:12<10:18, 5.94s/it]
|
| 109 |
51%|█████ | 108/211 [04:13<08:02, 4.68s/it]
|
| 110 |
52%|█████▏ | 109/211 [04:15<06:26, 3.79s/it]
|
| 111 |
52%|█████▏ | 110/211 [04:17<05:18, 3.16s/it]
|
| 112 |
53%|█████▎ | 111/211 [04:18<04:33, 2.74s/it]
|
| 113 |
53%|█████▎ | 112/211 [04:20<04:09, 2.52s/it]
|
| 114 |
54%|█████▎ | 113/211 [04:22<03:50, 2.35s/it]
|
| 115 |
54%|█████▍ | 114/211 [04:24<03:30, 2.17s/it]
|
| 116 |
55%|█████▍ | 115/211 [04:26<03:22, 2.11s/it]
|
| 117 |
55%|█████▍ | 116/211 [04:46<11:49, 7.47s/it]
|
| 118 |
55%|█████▌ | 117/211 [04:48<08:58, 5.73s/it]
|
| 119 |
56%|█████▌ | 118/211 [04:50<07:08, 4.61s/it]
|
| 120 |
56%|█████▋ | 119/211 [04:51<05:41, 3.71s/it]
|
| 121 |
57%|█████▋ | 120/211 [04:53<04:42, 3.10s/it]
|
| 122 |
57%|█████▋ | 121/211 [04:55<04:01, 2.69s/it]
|
| 123 |
58%|█████▊ | 122/211 [04:57<03:33, 2.40s/it]
|
| 124 |
58%|█████▊ | 123/211 [04:58<03:13, 2.20s/it]
|
| 125 |
59%|█████▉ | 124/211 [05:00<03:01, 2.08s/it]
|
| 126 |
59%|█████▉ | 125/211 [05:02<02:50, 1.98s/it]
|
| 127 |
60%|█████▉ | 126/211 [05:03<02:38, 1.87s/it]
|
| 128 |
60%|██████ | 127/211 [05:05<02:31, 1.80s/it]
|
| 129 |
61%|██████ | 128/211 [05:07<02:25, 1.75s/it]
|
| 130 |
61%|██████ | 129/211 [05:09<02:29, 1.82s/it]
|
| 131 |
62%|██████▏ | 130/211 [05:10<02:24, 1.79s/it]
|
| 132 |
62%|██████▏ | 131/211 [05:12<02:20, 1.76s/it]
|
| 133 |
63%|██████▎ | 132/211 [05:14<02:17, 1.74s/it]
|
| 134 |
63%|██████▎ | 133/211 [05:15<02:14, 1.73s/it]
|
| 135 |
64%|██████▎ | 134/211 [05:18<02:23, 1.87s/it]
|
| 136 |
64%|██████▍ | 135/211 [05:20<02:22, 1.87s/it]
|
| 137 |
64%|██████▍ | 136/211 [05:21<02:15, 1.80s/it]
|
| 138 |
65%|██████▍ | 137/211 [05:23<02:13, 1.81s/it]
|
| 139 |
65%|██████▌ | 138/211 [05:25<02:10, 1.79s/it]
|
| 140 |
66%|██████▌ | 139/211 [05:26<02:06, 1.75s/it]
|
| 141 |
66%|██████▋ | 140/211 [05:28<02:03, 1.73s/it]
|
| 142 |
67%|██████▋ | 141/211 [05:30<02:04, 1.78s/it]
|
| 143 |
67%|██████▋ | 142/211 [05:32<02:05, 1.82s/it]
|
| 144 |
68%|██████▊ | 143/211 [05:34<02:04, 1.83s/it]
|
| 145 |
68%|██��███▊ | 144/211 [05:35<01:59, 1.78s/it]
|
| 146 |
69%|██████▊ | 145/211 [05:37<01:57, 1.78s/it]
|
| 147 |
69%|██████▉ | 146/211 [05:55<07:10, 6.63s/it]
|
| 148 |
70%|██████▉ | 147/211 [05:57<05:30, 5.17s/it]
|
| 149 |
70%|███████ | 148/211 [05:59<04:23, 4.17s/it]
|
| 150 |
71%|███████ | 149/211 [06:00<03:31, 3.42s/it]
|
| 151 |
71%|███████ | 150/211 [06:02<02:58, 2.92s/it]
|
| 152 |
72%|███████▏ | 151/211 [06:04<02:33, 2.56s/it]
|
| 153 |
72%|███████▏ | 152/211 [06:06<02:15, 2.29s/it]
|
| 154 |
73%|███████▎ | 153/211 [06:07<02:02, 2.12s/it]
|
| 155 |
73%|███████▎ | 154/211 [06:10<02:15, 2.38s/it]
|
| 156 |
73%|███████▎ | 155/211 [06:12<02:00, 2.15s/it]
|
| 157 |
74%|███████▍ | 156/211 [06:14<01:50, 2.00s/it]
|
| 158 |
74%|███████▍ | 157/211 [06:15<01:43, 1.91s/it]
|
| 159 |
75%|███████▍ | 158/211 [06:17<01:38, 1.86s/it]
|
| 160 |
75%|███████▌ | 159/211 [06:19<01:37, 1.88s/it]
|
| 161 |
76%|███████▌ | 160/211 [06:21<01:32, 1.82s/it]
|
| 162 |
76%|███████▋ | 161/211 [06:22<01:29, 1.79s/it]
|
| 163 |
77%|███████▋ | 162/211 [06:24<01:31, 1.87s/it]
|
| 164 |
77%|███████▋ | 163/211 [06:26<01:27, 1.83s/it]
|
| 165 |
78%|███████▊ | 164/211 [06:29<01:36, 2.05s/it]
|
| 166 |
78%|███████▊ | 165/211 [06:31<01:32, 2.01s/it]
|
| 167 |
79%|███████▊ | 166/211 [06:32<01:27, 1.94s/it]
|
| 168 |
79%|███████▉ | 167/211 [06:34<01:21, 1.85s/it]
|
| 169 |
80%|███████▉ | 168/211 [06:36<01:17, 1.79s/it]
|
| 170 |
80%|████████ | 169/211 [06:38<01:19, 1.89s/it]
|
| 171 |
81%|████████ | 170/211 [06:40<01:17, 1.88s/it]
|
| 172 |
81%|████████ | 171/211 [06:41<01:14, 1.85s/it]
|
| 173 |
82%|████████▏ | 172/211 [06:43<01:11, 1.84s/it]
|
| 174 |
82%|████████▏ | 173/211 [06:45<01:08, 1.81s/it]
|
| 175 |
82%|████████▏ | 174/211 [06:47<01:07, 1.84s/it]
|
| 176 |
83%|████████▎ | 175/211 [06:49<01:04, 1.78s/it]
|
| 177 |
83%|████████▎ | 176/211 [06:50<01:02, 1.79s/it]
|
| 178 |
84%|████████▍ | 177/211 [06:52<01:00, 1.77s/it]
|
| 179 |
84%|████████▍ | 178/211 [06:54<01:03, 1.92s/it]
|
| 180 |
85%|████████▍ | 179/211 [06:56<00:59, 1.87s/it]
|
| 181 |
85%|████████▌ | 180/211 [06:58<00:56, 1.81s/it]
|
| 182 |
86%|████████▌ | 181/211 [06:59<00:53, 1.77s/it]
|
| 183 |
86%|████████▋ | 182/211 [07:01<00:53, 1.84s/it]
|
| 184 |
87%|████████▋ | 183/211 [07:03<00:50, 1.81s/it]
|
| 185 |
87%|████████▋ | 184/211 [07:05<00:49, 1.83s/it]
|
| 186 |
88%|████████▊ | 185/211 [07:07<00:48, 1.86s/it]
|
| 187 |
88%|████████▊ | 186/211 [07:09<00:45, 1.83s/it]
|
| 188 |
89%|████████▊ | 187/211 [07:11<00:43, 1.82s/it]
|
| 189 |
89%|████████▉ | 188/211 [07:12<00:42, 1.83s/it]
|
| 190 |
90%|████████▉ | 189/211 [07:14<00:39, 1.77s/it]
|
| 191 |
90%|█████████ | 190/211 [07:16<00:37, 1.77s/it]
|
| 192 |
91%|█████████ | 191/211 [07:17<00:34, 1.73s/it]
|
| 193 |
91%|█████████ | 192/211 [07:19<00:32, 1.71s/it]
|
| 194 |
91%|█████████▏| 193/211 [07:21<00:30, 1.72s/it]
|
| 195 |
92%|█████████▏| 194/211 [07:23<00:29, 1.72s/it]
|
| 196 |
92%|█████████▏| 195/211 [07:24<00:27, 1.72s/it]
|
| 197 |
93%|█████████▎| 196/211 [07:26<00:25, 1.70s/it]
|
| 198 |
93%|█████████▎| 197/211 [07:28<00:24, 1.75s/it]
|
| 199 |
94%|█████████▍| 198/211 [07:30<00:23, 1.78s/it]
|
| 200 |
94%|█████████▍| 199/211 [07:31<00:21, 1.76s/it]
|
| 201 |
95%|█████████▍| 200/211 [07:33<00:19, 1.78s/it]
|
| 202 |
95%|█████████▌| 201/211 [07:35<00:17, 1.76s/it]
|
| 203 |
96%|█████████▌| 202/211 [07:37<00:16, 1.80s/it]
|
| 204 |
96%|█████████▌| 203/211 [07:39<00:14, 1.79s/it]
|
| 205 |
97%|█████████▋| 204/211 [07:40<00:12, 1.76s/it]
|
| 206 |
97%|█████████▋| 205/211 [07:42<00:10, 1.74s/it]
|
| 207 |
98%|█████████▊| 206/211 [07:44<00:08, 1.75s/it]
|
| 208 |
98%|█████████▊| 207/211 [07:45<00:06, 1.72s/it]
|
| 209 |
99%|█████████▊| 208/211 [07:47<00:05, 1.70s/it]
|
| 210 |
99%|█████████▉| 209/211 [07:49<00:03, 1.68s/it]
|
|
|
|
|
|
|
|
|
|
| 211 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 212 |
0%| | 1/469 [00:02<18:09, 2.33s/it]
|
| 213 |
0%| | 2/469 [00:04<15:10, 1.95s/it]
|
| 214 |
1%| | 3/469 [00:05<14:14, 1.83s/it]
|
| 215 |
1%| | 4/469 [00:07<13:52, 1.79s/it]
|
| 216 |
1%| | 5/469 [00:09<13:23, 1.73s/it]
|
| 217 |
1%|▏ | 6/469 [00:12<17:01, 2.21s/it]
|
| 218 |
1%|▏ | 7/469 [00:14<16:26, 2.14s/it]
|
| 219 |
2%|▏ | 8/469 [00:15<15:33, 2.02s/it]
|
| 220 |
2%|▏ | 9/469 [00:17<15:18, 2.00s/it]
|
| 221 |
2%|▏ | 10/469 [00:19<14:43, 1.92s/it]
|
| 222 |
2%|▏ | 11/469 [00:21<14:04, 1.84s/it]
|
| 223 |
3%|▎ | 12/469 [00:23<14:21, 1.89s/it]
|
| 224 |
3%|▎ | 13/469 [00:25<14:16, 1.88s/it]
|
| 225 |
3%|▎ | 14/469 [00:27<14:15, 1.88s/it]
|
| 226 |
3%|▎ | 15/469 [00:28<14:05, 1.86s/it]
|
| 227 |
3%|▎ | 16/469 [00:30<14:24, 1.91s/it]
|
| 228 |
4%|▎ | 17/469 [00:32<14:05, 1.87s/it]
|
| 229 |
4%|▍ | 18/469 [00:34<14:13, 1.89s/it]
|
| 230 |
4%|▍ | 19/469 [00:36<13:44, 1.83s/it]
|
| 231 |
4%|▍ | 20/469 [00:37<13:14, 1.77s/it]
|
| 232 |
4%|▍ | 21/469 [00:39<13:07, 1.76s/it]
|
| 233 |
5%|▍ | 22/469 [00:41<12:51, 1.72s/it]
|
| 234 |
5%|▍ | 23/469 [00:43<12:56, 1.74s/it]
|
| 235 |
5%|▌ | 24/469 [00:44<12:53, 1.74s/it]
|
| 236 |
5%|▌ | 25/469 [00:46<12:43, 1.72s/it]
|
| 237 |
6%|▌ | 26/469 [00:48<12:47, 1.73s/it]
|
| 238 |
6%|▌ | 27/469 [00:49<12:34, 1.71s/it]
|
| 239 |
6%|▌ | 28/469 [00:51<12:53, 1.75s/it]
|
| 240 |
6%|▌ | 29/469 [00:53<12:38, 1.72s/it]
|
| 241 |
6%|▋ | 30/469 [00:55<12:33, 1.72s/it]
|
| 242 |
7%|▋ | 31/469 [00:56<12:17, 1.68s/it]
|
| 243 |
7%|▋ | 32/469 [00:58<12:27, 1.71s/it]
|
| 244 |
7%|▋ | 33/469 [01:00<12:22, 1.70s/it]
|
| 245 |
7%|▋ | 34/469 [01:03<16:09, 2.23s/it]
|
| 246 |
7%|▋ | 35/469 [01:05<15:00, 2.08s/it]
|
| 247 |
8%|▊ | 36/469 [01:07<14:10, 1.96s/it]
|
| 248 |
8%|▊ | 37/469 [01:08<13:48, 1.92s/it]
|
| 249 |
8%|▊ | 38/469 [01:10<13:13, 1.84s/it]
|
| 250 |
8%|▊ | 39/469 [01:12<13:28, 1.88s/it]
|
| 251 |
9%|▊ | 40/469 [01:14<13:04, 1.83s/it]
|
| 252 |
9%|▊ | 41/469 [01:16<13:04, 1.83s/it]
|
| 253 |
9%|▉ | 42/469 [01:19<15:53, 2.23s/it]
|
| 254 |
9%|▉ | 43/469 [01:21<15:36, 2.20s/it]
|
| 255 |
9%|▉ | 44/469 [01:23<14:48, 2.09s/it]
|
| 256 |
10%|▉ | 45/469 [01:24<13:55, 1.97s/it]
|
| 257 |
10%|▉ | 46/469 [01:26<13:20, 1.89s/it]
|
| 258 |
10%|█ | 47/469 [01:28<13:24, 1.91s/it]
|
| 259 |
10%|█ | 48/469 [01:30<13:15, 1.89s/it]
|
| 260 |
10%|█ | 49/469 [01:31<12:39, 1.81s/it]
|
| 261 |
11%|█ | 50/469 [01:33<12:26, 1.78s/it]
|
| 262 |
11%|█ | 51/469 [01:35<12:47, 1.84s/it]
|
| 263 |
11%|█ | 52/469 [01:37<12:28, 1.79s/it]
|
| 264 |
11%|█▏ | 53/469 [01:39<12:09, 1.75s/it]
|
| 265 |
12%|█▏ | 54/469 [01:40<12:13, 1.77s/it]
|
| 266 |
12%|█▏ | 55/469 [01:42<12:03, 1.75s/it]
|
| 267 |
12%|█▏ | 56/469 [01:44<12:13, 1.78s/it]
|
| 268 |
12%|█▏ | 57/469 [01:46<12:27, 1.81s/it]
|
| 269 |
12%|█▏ | 58/469 [01:48<12:53, 1.88s/it]
|
| 270 |
13%|█▎ | 59/469 [01:50<12:46, 1.87s/it]
|
| 271 |
13%|█▎ | 60/469 [01:51<12:25, 1.82s/it]
|
| 272 |
13%|█▎ | 61/469 [01:53<12:01, 1.77s/it]
|
| 273 |
13%|█▎ | 62/469 [01:55<11:49, 1.74s/it]
|
| 274 |
13%|█▎ | 63/469 [01:56<11:43, 1.73s/it]
|
| 275 |
14%|█▎ | 64/469 [01:58<11:37, 1.72s/it]
|
| 276 |
14%|█▍ | 65/469 [02:00<11:38, 1.73s/it]
|
| 277 |
14%|█▍ | 66/469 [02:02<11:28, 1.71s/it]
|
| 278 |
14%|█▍ | 67/469 [02:03<11:30, 1.72s/it]
|
| 279 |
14%|█▍ | 68/469 [02:05<11:30, 1.72s/it]
|
| 280 |
15%|█▍ | 69/469 [02:07<12:03, 1.81s/it]
|
| 281 |
15%|█▍ | 70/469 [02:09<11:46, 1.77s/it]
|
| 282 |
15%|█▌ | 71/469 [02:10<11:36, 1.75s/it]
|
| 283 |
15%|█▌ | 72/469 [02:12<11:34, 1.75s/it]
|
| 284 |
16%|█▌ | 73/469 [02:14<12:02, 1.82s/it]
|
| 285 |
16%|█▌ | 74/469 [02:16<11:41, 1.78s/it]
|
| 286 |
16%|█▌ | 75/469 [02:18<12:10, 1.85s/it]
|
| 287 |
16%|█▌ | 76/469 [02:20<12:04, 1.84s/it]
|
| 288 |
16%|█▋ | 77/469 [02:21<11:44, 1.80s/it]
|
| 289 |
17%|█▋ | 78/469 [02:23<11:44, 1.80s/it]
|
| 290 |
17%|█▋ | 79/469 [02:25<11:44, 1.81s/it]
|
| 291 |
17%|█▋ | 80/469 [02:27<11:43, 1.81s/it]
|
| 292 |
17%|█▋ | 81/469 [02:28<11:22, 1.76s/it]
|
| 293 |
17%|█▋ | 82/469 [02:30<11:10, 1.73s/it]
|
| 294 |
18%|█▊ | 83/469 [02:32<11:12, 1.74s/it]
|
| 295 |
18%|█▊ | 84/469 [02:34<11:11, 1.75s/it]
|
| 296 |
18%|█▊ | 85/469 [02:36<11:32, 1.80s/it]
|
| 297 |
18%|█▊ | 86/469 [02:37<11:32, 1.81s/it]
|
| 298 |
19%|█▊ | 87/469 [02:39<11:16, 1.77s/it]
|
| 299 |
19%|█▉ | 88/469 [02:41<11:07, 1.75s/it]
|
| 300 |
19%|█▉ | 89/469 [02:42<11:00, 1.74s/it]
|
| 301 |
19%|█▉ | 90/469 [02:44<11:09, 1.77s/it]
|
| 302 |
19%|█▉ | 91/469 [02:46<11:15, 1.79s/it]
|
| 303 |
20%|█▉ | 92/469 [02:48<11:09, 1.78s/it]
|
| 304 |
20%|█▉ | 93/469 [02:50<10:57, 1.75s/it]
|
| 305 |
20%|██ | 94/469 [02:51<10:55, 1.75s/it]
|
| 306 |
20%|██ | 95/469 [02:53<10:50, 1.74s/it]
|
| 307 |
20%|██ | 96/469 [02:55<10:42, 1.72s/it]
|
| 308 |
21%|██ | 97/469 [02:56<10:37, 1.71s/it]
|
| 309 |
21%|██ | 98/469 [02:58<10:51, 1.76s/it]
|
| 310 |
21%|██ | 99/469 [03:00<10:52, 1.76s/it]
|
| 311 |
21%|██▏ | 100/469 [03:02<11:08, 1.81s/it]
|
| 312 |
22%|██▏ | 101/469 [03:04<10:52, 1.77s/it]
|
| 313 |
22%|██▏ | 102/469 [03:05<10:48, 1.77s/it]
|
| 314 |
22%|██▏ | 103/469 [03:07<10:35, 1.74s/it]
|
| 315 |
22%|██▏ | 104/469 [03:09<10:28, 1.72s/it]
|
| 316 |
22%|██▏ | 105/469 [03:11<10:38, 1.75s/it]
|
| 317 |
23%|██▎ | 106/469 [03:12<10:29, 1.73s/it]
|
| 318 |
23%|██▎ | 107/469 [03:14<10:22, 1.72s/it]
|
| 319 |
23%|██▎ | 108/469 [03:16<10:41, 1.78s/it]
|
| 320 |
23%|██▎ | 109/469 [03:18<10:29, 1.75s/it]
|
| 321 |
23%|██▎ | 110/469 [03:19<10:32, 1.76s/it]
|
| 322 |
24%|██▎ | 111/469 [03:21<10:25, 1.75s/it]
|
| 323 |
24%|██▍ | 112/469 [03:23<10:24, 1.75s/it]
|
| 324 |
24%|██▍ | 113/469 [03:25<10:37, 1.79s/it]
|
| 325 |
24%|██▍ | 114/469 [03:26<10:24, 1.76s/it]
|
| 326 |
25%|██▍ | 115/469 [03:28<10:27, 1.77s/it]
|
| 327 |
25%|██▍ | 116/469 [03:30<10:22, 1.76s/it]
|
| 328 |
25%|██▍ | 117/469 [03:32<10:28, 1.79s/it]
|
| 329 |
25%|██▌ | 118/469 [03:34<10:52, 1.86s/it]
|
| 330 |
25%|██▌ | 119/469 [03:36<11:00, 1.89s/it]
|
| 331 |
26%|██▌ | 120/469 [03:37<10:42, 1.84s/it]
|
| 332 |
26%|██▌ | 121/469 [03:39<10:37, 1.83s/it]
|
| 333 |
26%|██▌ | 122/469 [03:41<10:17, 1.78s/it]
|
| 334 |
26%|██▌ | 123/469 [03:43<10:21, 1.80s/it]
|
| 335 |
26%|██▋ | 124/469 [03:45<10:14, 1.78s/it]
|
| 336 |
27%|██▋ | 125/469 [03:46<10:15, 1.79s/it]
|
| 337 |
27%|██▋ | 126/469 [03:48<10:12, 1.79s/it]
|
| 338 |
27%|██▋ | 127/469 [03:50<09:59, 1.75s/it]
|
| 339 |
27%|██▋ | 128/469 [03:51<09:49, 1.73s/it]
|
| 340 |
28%|██▊ | 129/469 [03:53<09:53, 1.75s/it]
|
| 341 |
28%|██▊ | 130/469 [03:55<09:42, 1.72s/it]
|
| 342 |
28%|██▊ | 131/469 [03:57<09:37, 1.71s/it]
|
| 343 |
28%|██▊ | 132/469 [03:58<09:49, 1.75s/it]
|
| 344 |
28%|██▊ | 133/469 [04:00<09:59, 1.78s/it]
|
| 345 |
29%|██▊ | 134/469 [04:02<10:22, 1.86s/it]
|
| 346 |
29%|██▉ | 135/469 [04:04<10:20, 1.86s/it]
|
| 347 |
29%|██▉ | 136/469 [04:06<10:26, 1.88s/it]
|
| 348 |
29%|██▉ | 137/469 [04:08<10:02, 1.81s/it]
|
| 349 |
29%|██▉ | 138/469 [04:10<10:16, 1.86s/it]
|
| 350 |
30%|██▉ | 139/469 [04:11<09:58, 1.81s/it]
|
| 351 |
30%|██▉ | 140/469 [04:13<09:59, 1.82s/it]
|
| 352 |
30%|███ | 141/469 [04:15<09:40, 1.77s/it]
|
| 353 |
30%|███ | 142/469 [04:17<09:28, 1.74s/it]
|
| 354 |
30%|███ | 143/469 [04:18<09:16, 1.71s/it]
|
| 355 |
31%|███ | 144/469 [04:20<09:13, 1.70s/it]
|
| 356 |
31%|███ | 145/469 [04:22<09:28, 1.75s/it]
|
| 357 |
31%|███ | 146/469 [04:24<09:18, 1.73s/it]
|
| 358 |
31%|███▏ | 147/469 [04:25<09:35, 1.79s/it]
|
| 359 |
32%|███▏ | 148/469 [04:27<09:28, 1.77s/it]
|
| 360 |
32%|███▏ | 149/469 [04:29<09:12, 1.73s/it]
|
| 361 |
32%|███▏ | 150/469 [04:31<09:11, 1.73s/it]
|
| 362 |
32%|███▏ | 151/469 [04:32<09:18, 1.76s/it]
|
| 363 |
32%|███▏ | 152/469 [04:34<09:23, 1.78s/it]
|
| 364 |
33%|███▎ | 153/469 [04:36<09:21, 1.78s/it]
|
| 365 |
33%|███▎ | 154/469 [04:38<09:11, 1.75s/it]
|
| 366 |
33%|███▎ | 155/469 [04:39<09:09, 1.75s/it]
|
| 367 |
33%|███▎ | 156/469 [04:42<10:07, 1.94s/it]
|
| 368 |
33%|███▎ | 157/469 [04:44<10:39, 2.05s/it]
|
| 369 |
34%|███▎ | 158/469 [04:46<10:12, 1.97s/it]
|
| 370 |
34%|███▍ | 159/469 [04:48<09:42, 1.88s/it]
|
| 371 |
34%|███▍ | 160/469 [04:49<09:42, 1.88s/it]
|
| 372 |
34%|███▍ | 161/469 [04:51<09:24, 1.83s/it]
|
| 373 |
35%|███▍ | 162/469 [04:53<09:08, 1.79s/it]
|
| 374 |
35%|███▍ | 163/469 [04:55<09:01, 1.77s/it]
|
| 375 |
35%|███▍ | 164/469 [04:56<09:08, 1.80s/it]
|
| 376 |
35%|███▌ | 165/469 [04:58<09:07, 1.80s/it]
|
| 377 |
35%|███▌ | 166/469 [05:00<09:06, 1.80s/it]
|
| 378 |
36%|███▌ | 167/469 [05:02<09:00, 1.79s/it]
|
| 379 |
36%|███▌ | 168/469 [05:03<08:47, 1.75s/it]
|
| 380 |
36%|███▌ | 169/469 [05:05<08:51, 1.77s/it]
|
| 381 |
36%|███▌ | 170/469 [05:07<08:35, 1.73s/it]
|
| 382 |
36%|███▋ | 171/469 [05:09<08:25, 1.70s/it]
|
| 383 |
37%|███▋ | 172/469 [05:10<08:24, 1.70s/it]
|
| 384 |
37%|███▋ | 173/469 [05:12<08:20, 1.69s/it]
|
| 385 |
37%|███▋ | 174/469 [05:14<08:20, 1.70s/it]
|
| 386 |
37%|███▋ | 175/469 [05:16<08:42, 1.78s/it]
|
| 387 |
38%|███▊ | 176/469 [05:17<08:34, 1.76s/it]
|
| 388 |
38%|███▊ | 177/469 [05:19<08:31, 1.75s/it]
|
| 389 |
38%|███▊ | 178/469 [05:21<08:34, 1.77s/it]
|
| 390 |
38%|███▊ | 179/469 [05:22<08:22, 1.73s/it]
|
| 391 |
38%|███▊ | 180/469 [05:24<08:22, 1.74s/it]
|
| 392 |
39%|███▊ | 181/469 [05:26<08:14, 1.72s/it]
|
| 393 |
39%|███▉ | 182/469 [05:28<08:19, 1.74s/it]
|
| 394 |
39%|███▉ | 183/469 [05:29<08:17, 1.74s/it]
|
| 395 |
39%|███▉ | 184/469 [05:31<08:29, 1.79s/it]
|
| 396 |
39%|███▉ | 185/469 [05:33<08:36, 1.82s/it]
|
| 397 |
40%|███▉ | 186/469 [05:35<08:36, 1.82s/it]
|
| 398 |
40%|███▉ | 187/469 [05:37<08:29, 1.81s/it]
|
| 399 |
40%|████ | 188/469 [05:39<08:45, 1.87s/it]
|
| 400 |
40%|████ | 189/469 [05:41<08:29, 1.82s/it]
|
| 401 |
41%|████ | 190/469 [05:42<08:28, 1.82s/it]
|
| 402 |
41%|████ | 191/469 [05:44<08:18, 1.79s/it]
|
| 403 |
41%|████ | 192/469 [05:46<08:10, 1.77s/it]
|
| 404 |
41%|████ | 193/469 [05:47<07:58, 1.73s/it]
|
| 405 |
41%|████▏ | 194/469 [05:49<07:52, 1.72s/it]
|
| 406 |
42%|████▏ | 195/469 [05:51<08:08, 1.78s/it]
|
| 407 |
42%|████▏ | 196/469 [05:53<08:19, 1.83s/it]
|
| 408 |
42%|████▏ | 197/469 [05:55<08:35, 1.90s/it]
|
| 409 |
42%|████▏ | 198/469 [05:57<08:28, 1.88s/it]
|
| 410 |
42%|████▏ | 199/469 [05:59<08:17, 1.84s/it]
|
| 411 |
43%|████▎ | 200/469 [06:01<08:22, 1.87s/it]
|
| 412 |
43%|████▎ | 201/469 [06:02<08:09, 1.83s/it]
|
| 413 |
43%|████▎ | 202/469 [06:04<08:08, 1.83s/it]
|
| 414 |
43%|████▎ | 203/469 [06:06<08:21, 1.88s/it]
|
| 415 |
43%|████▎ | 204/469 [06:08<07:59, 1.81s/it]
|
| 416 |
44%|████▎ | 205/469 [06:09<07:46, 1.77s/it]
|
| 417 |
44%|████▍ | 206/469 [06:11<07:38, 1.74s/it]
|
| 418 |
44%|████▍ | 207/469 [06:13<07:27, 1.71s/it]
|
| 419 |
44%|████▍ | 208/469 [06:15<07:40, 1.76s/it]
|
| 420 |
45%|████▍ | 209/469 [06:16<07:35, 1.75s/it]
|
| 421 |
45%|████▍ | 210/469 [06:18<07:30, 1.74s/it]
|
| 422 |
45%|████▍ | 211/469 [06:20<07:35, 1.77s/it]
|
| 423 |
45%|████▌ | 212/469 [06:22<07:32, 1.76s/it]
|
| 424 |
45%|████▌ | 213/469 [06:23<07:29, 1.76s/it]
|
| 425 |
46%|████▌ | 214/469 [06:25<07:38, 1.80s/it]
|
| 426 |
46%|████▌ | 215/469 [06:27<07:34, 1.79s/it]
|
| 427 |
46%|████▌ | 216/469 [06:29<07:28, 1.77s/it]
|
| 428 |
46%|████▋ | 217/469 [06:30<07:20, 1.75s/it]
|
| 429 |
46%|████▋ | 218/469 [06:32<07:08, 1.71s/it]
|
| 430 |
47%|████▋ | 219/469 [06:34<07:04, 1.70s/it]
|
| 431 |
47%|████▋ | 220/469 [06:35<07:02, 1.70s/it]
|
| 432 |
47%|████▋ | 221/469 [06:38<07:54, 1.91s/it]
|
| 433 |
47%|████▋ | 222/469 [06:40<07:49, 1.90s/it]
|
| 434 |
48%|████▊ | 223/469 [06:42<07:37, 1.86s/it]
|
| 435 |
48%|████▊ | 224/469 [06:43<07:18, 1.79s/it]
|
| 436 |
48%|████▊ | 225/469 [06:45<07:42, 1.90s/it]
|
| 437 |
48%|████▊ | 226/469 [06:48<09:05, 2.25s/it]
|
| 438 |
48%|████▊ | 227/469 [06:50<08:33, 2.12s/it]
|
| 439 |
49%|████▊ | 228/469 [06:52<07:55, 1.97s/it]
|
| 440 |
49%|████▉ | 229/469 [06:54<07:40, 1.92s/it]
|
| 441 |
49%|████▉ | 230/469 [06:55<07:33, 1.90s/it]
|
| 442 |
49%|████▉ | 231/469 [06:57<07:17, 1.84s/it]
|
| 443 |
49%|████▉ | 232/469 [06:59<07:05, 1.79s/it]
|
| 444 |
50%|████▉ | 233/469 [07:01<07:03, 1.80s/it]
|
| 445 |
50%|████▉ | 234/469 [07:02<06:59, 1.78s/it]
|
| 446 |
50%|█████ | 235/469 [07:05<08:13, 2.11s/it]
|
| 447 |
50%|█████ | 236/469 [07:07<07:39, 1.97s/it]
|
| 448 |
51%|█████ | 237/469 [07:09<07:15, 1.88s/it]
|
| 449 |
51%|█████ | 238/469 [07:10<07:11, 1.87s/it]
|
| 450 |
51%|█████ | 239/469 [07:12<06:55, 1.81s/it]
|
| 451 |
51%|█████ | 240/469 [07:14<07:03, 1.85s/it]
|
| 452 |
51%|█████▏ | 241/469 [07:16<06:54, 1.82s/it]
|
| 453 |
52%|█████▏ | 242/469 [07:17<06:41, 1.77s/it]
|
| 454 |
52%|█████▏ | 243/469 [07:19<06:33, 1.74s/it]
|
| 455 |
52%|█████▏ | 244/469 [07:21<06:28, 1.73s/it]
|
| 456 |
52%|█████▏ | 245/469 [07:24<07:39, 2.05s/it]
|
| 457 |
52%|█████▏ | 246/469 [07:25<07:19, 1.97s/it]
|
| 458 |
53%|█████▎ | 247/469 [07:27<07:08, 1.93s/it]
|
| 459 |
53%|█████▎ | 248/469 [07:29<06:57, 1.89s/it]
|
| 460 |
53%|█████▎ | 249/469 [07:31<06:40, 1.82s/it]
|
| 461 |
53%|█████▎ | 250/469 [07:33<06:38, 1.82s/it]
|
| 462 |
54%|█████▎ | 251/469 [07:34<06:25, 1.77s/it]
|
| 463 |
54%|█████▎ | 252/469 [07:36<06:15, 1.73s/it]
|
| 464 |
54%|█████▍ | 253/469 [07:37<06:09, 1.71s/it]
|
| 465 |
54%|█████▍ | 254/469 [07:39<06:05, 1.70s/it]
|
| 466 |
54%|█████▍ | 255/469 [07:41<06:02, 1.70s/it]
|
| 467 |
55%|█████▍ | 256/469 [07:42<05:58, 1.68s/it]
|
| 468 |
55%|█████▍ | 257/469 [07:44<06:01, 1.71s/it]
|
| 469 |
55%|█████▌ | 258/469 [07:47<06:54, 1.96s/it]
|
| 470 |
55%|█████▌ | 259/469 [07:49<06:35, 1.88s/it]
|
| 471 |
55%|█████▌ | 260/469 [07:52<08:26, 2.42s/it]
|
| 472 |
56%|█████▌ | 261/469 [07:54<07:35, 2.19s/it]
|
| 473 |
56%|█████▌ | 262/469 [07:56<07:07, 2.07s/it]
|
| 474 |
56%|█████▌ | 263/469 [07:57<06:47, 1.98s/it]
|
| 475 |
56%|█████▋ | 264/469 [07:59<06:25, 1.88s/it]
|
| 476 |
57%|█████▋ | 265/469 [08:01<06:21, 1.87s/it]
|
| 477 |
57%|█████▋ | 266/469 [08:03<06:20, 1.87s/it]
|
| 478 |
57%|█████▋ | 267/469 [08:04<06:08, 1.82s/it]
|
| 479 |
57%|█████▋ | 268/469 [08:06<06:07, 1.83s/it]
|
| 480 |
57%|█████▋ | 269/469 [08:08<05:54, 1.77s/it]
|
| 481 |
58%|█████▊ | 270/469 [08:10<05:59, 1.80s/it]
|
| 482 |
58%|█████▊ | 271/469 [08:12<05:51, 1.77s/it]
|
| 483 |
58%|█████▊ | 272/469 [08:13<05:58, 1.82s/it]
|
| 484 |
58%|█████▊ | 273/469 [08:15<06:06, 1.87s/it]
|
| 485 |
58%|█████▊ | 274/469 [08:17<05:52, 1.81s/it]
|
| 486 |
59%|█████▊ | 275/469 [08:19<05:50, 1.81s/it]
|
| 487 |
59%|█████▉ | 276/469 [08:21<05:42, 1.77s/it]
|
| 488 |
59%|█████▉ | 277/469 [08:22<05:37, 1.76s/it]
|
| 489 |
59%|█████▉ | 278/469 [08:24<05:34, 1.75s/it]
|
| 490 |
59%|█████▉ | 279/469 [08:26<05:38, 1.78s/it]
|
| 491 |
60%|█████▉ | 280/469 [08:28<05:38, 1.79s/it]
|
| 492 |
60%|█████▉ | 281/469 [08:30<05:39, 1.81s/it]
|
| 493 |
60%|██████ | 282/469 [08:31<05:38, 1.81s/it]
|
| 494 |
60%|██████ | 283/469 [08:33<05:30, 1.78s/it]
|
| 495 |
61%|██████ | 284/469 [08:35<05:26, 1.77s/it]
|
| 496 |
61%|██████ | 285/469 [08:36<05:17, 1.73s/it]
|
| 497 |
61%|██████ | 286/469 [08:38<05:14, 1.72s/it]
|
| 498 |
61%|██████ | 287/469 [08:40<05:10, 1.71s/it]
|
| 499 |
61%|██████▏ | 288/469 [08:41<05:05, 1.69s/it]
|
| 500 |
62%|██████▏ | 289/469 [08:43<05:04, 1.69s/it]
|
| 501 |
62%|██████▏ | 290/469 [08:45<05:11, 1.74s/it]
|
| 502 |
62%|██████▏ | 291/469 [08:47<05:11, 1.75s/it]
|
| 503 |
62%|██████▏ | 292/469 [08:48<05:04, 1.72s/it]
|
| 504 |
62%|██████▏ | 293/469 [08:50<05:04, 1.73s/it]
|
| 505 |
63%|██████▎ | 294/469 [08:52<05:19, 1.83s/it]
|
| 506 |
63%|██████▎ | 295/469 [08:54<05:14, 1.81s/it]
|
| 507 |
63%|██████▎ | 296/469 [08:56<05:04, 1.76s/it]
|
| 508 |
63%|██████▎ | 297/469 [08:57<04:58, 1.74s/it]
|
| 509 |
64%|██████▎ | 298/469 [08:59<05:01, 1.76s/it]
|
| 510 |
64%|██████▍ | 299/469 [09:01<04:56, 1.74s/it]
|
| 511 |
64%|██████▍ | 300/469 [09:03<04:52, 1.73s/it]
|
| 512 |
64%|██████▍ | 301/469 [09:04<04:48, 1.72s/it]
|
| 513 |
64%|██████▍ | 302/469 [09:06<04:43, 1.70s/it]
|
| 514 |
65%|██████▍ | 303/469 [09:08<04:50, 1.75s/it]
|
| 515 |
65%|██████▍ | 304/469 [09:10<04:46, 1.74s/it]
|
| 516 |
65%|██████▌ | 305/469 [09:11<04:47, 1.75s/it]
|
| 517 |
65%|██████▌ | 306/469 [09:13<04:42, 1.73s/it]
|
| 518 |
65%|██████▌ | 307/469 [09:15<04:58, 1.84s/it]
|
| 519 |
66%|██████▌ | 308/469 [09:17<04:50, 1.81s/it]
|
| 520 |
66%|██████▌ | 309/469 [09:19<04:46, 1.79s/it]
|
| 521 |
66%|██████▌ | 310/469 [09:20<04:39, 1.76s/it]
|
| 522 |
66%|██████▋ | 311/469 [09:22<04:34, 1.74s/it]
|
| 523 |
67%|██████▋ | 312/469 [09:24<04:36, 1.76s/it]
|
| 524 |
67%|██████▋ | 313/469 [09:26<04:52, 1.88s/it]
|
| 525 |
67%|██████▋ | 314/469 [09:28<04:55, 1.91s/it]
|
| 526 |
67%|██████▋ | 315/469 [09:30<04:43, 1.84s/it]
|
| 527 |
67%|██████▋ | 316/469 [09:31<04:36, 1.81s/it]
|
| 528 |
68%|██████▊ | 317/469 [09:33<04:42, 1.86s/it]
|
| 529 |
68%|██████▊ | 318/469 [09:36<05:38, 2.24s/it]
|
| 530 |
68%|██████▊ | 319/469 [09:38<05:18, 2.13s/it]
|
| 531 |
68%|██████▊ | 320/469 [09:40<04:59, 2.01s/it]
|
| 532 |
68%|██████▊ | 321/469 [09:42<04:49, 1.95s/it]
|
| 533 |
69%|██████▊ | 322/469 [09:43<04:32, 1.86s/it]
|
| 534 |
69%|██████▉ | 323/469 [09:47<05:51, 2.41s/it]
|
| 535 |
69%|██████▉ | 324/469 [09:49<05:17, 2.19s/it]
|
| 536 |
69%|██████▉ | 325/469 [09:51<04:58, 2.08s/it]
|
| 537 |
70%|██████▉ | 326/469 [09:52<04:42, 1.97s/it]
|
| 538 |
70%|██████▉ | 327/469 [09:54<04:29, 1.89s/it]
|
| 539 |
70%|██████▉ | 328/469 [09:56<04:17, 1.82s/it]
|
| 540 |
70%|███████ | 329/469 [09:58<04:12, 1.80s/it]
|
| 541 |
70%|███████ | 330/469 [09:59<04:07, 1.78s/it]
|
| 542 |
71%|███████ | 331/469 [10:01<04:12, 1.83s/it]
|
| 543 |
71%|███████ | 332/469 [10:03<04:06, 1.80s/it]
|
| 544 |
71%|███████ | 333/469 [10:05<03:57, 1.75s/it]
|
| 545 |
71%|███████ | 334/469 [10:06<03:51, 1.71s/it]
|
| 546 |
71%|███████▏ | 335/469 [10:08<03:47, 1.70s/it]
|
| 547 |
72%|███████▏ | 336/469 [10:09<03:44, 1.69s/it]
|
| 548 |
72%|███████▏ | 337/469 [10:11<03:42, 1.68s/it]
|
| 549 |
72%|███████▏ | 338/469 [10:13<03:45, 1.72s/it]
|
| 550 |
72%|███████▏ | 339/469 [10:15<03:41, 1.70s/it]
|
| 551 |
72%|███████▏ | 340/469 [10:16<03:38, 1.70s/it]
|
| 552 |
73%|███████▎ | 341/469 [10:18<03:37, 1.70s/it]
|
| 553 |
73%|███████▎ | 342/469 [10:20<03:47, 1.79s/it]
|
| 554 |
73%|███████▎ | 343/469 [10:22<03:42, 1.77s/it]
|
| 555 |
73%|███████▎ | 344/469 [10:24<03:40, 1.77s/it]
|
| 556 |
74%|███████▎ | 345/469 [10:25<03:39, 1.77s/it]
|
| 557 |
74%|███████▍ | 346/469 [10:27<03:36, 1.76s/it]
|
| 558 |
74%|███████▍ | 347/469 [10:29<03:31, 1.73s/it]
|
| 559 |
74%|███████▍ | 348/469 [10:31<03:35, 1.78s/it]
|
| 560 |
74%|███████▍ | 349/469 [10:32<03:34, 1.78s/it]
|
| 561 |
75%|███████▍ | 350/469 [10:34<03:30, 1.76s/it]
|
| 562 |
75%|███████▍ | 351/469 [10:36<03:23, 1.72s/it]
|
| 563 |
75%|███████▌ | 352/469 [10:38<03:33, 1.82s/it]
|
| 564 |
75%|███████▌ | 353/469 [10:40<03:29, 1.81s/it]
|
| 565 |
75%|███████▌ | 354/469 [10:41<03:27, 1.80s/it]
|
| 566 |
76%|███████▌ | 355/469 [10:43<03:21, 1.77s/it]
|
| 567 |
76%|███████▌ | 356/469 [10:45<03:27, 1.83s/it]
|
| 568 |
76%|███████▌ | 357/469 [10:47<03:18, 1.77s/it]
|
| 569 |
76%|███████▋ | 358/469 [10:48<03:14, 1.75s/it]
|
| 570 |
77%|██��████▋ | 359/469 [10:50<03:09, 1.73s/it]
|
| 571 |
77%|███████▋ | 360/469 [10:52<03:07, 1.72s/it]
|
| 572 |
77%|███████▋ | 361/469 [10:53<03:06, 1.72s/it]
|
| 573 |
77%|███████▋ | 362/469 [10:55<03:07, 1.75s/it]
|
| 574 |
77%|███████▋ | 363/469 [10:57<03:10, 1.79s/it]
|
| 575 |
78%|███████▊ | 364/469 [10:59<03:06, 1.77s/it]
|
| 576 |
78%|███████▊ | 365/469 [11:01<03:00, 1.73s/it]
|
| 577 |
78%|███████▊ | 366/469 [11:02<02:55, 1.71s/it]
|
| 578 |
78%|███████▊ | 367/469 [11:04<02:53, 1.70s/it]
|
| 579 |
78%|███████▊ | 368/469 [11:06<02:53, 1.72s/it]
|
| 580 |
79%|███████▊ | 369/469 [11:07<02:51, 1.71s/it]
|
| 581 |
79%|███████▉ | 370/469 [11:09<02:48, 1.70s/it]
|
| 582 |
79%|███████▉ | 371/469 [11:11<02:48, 1.71s/it]
|
| 583 |
79%|███████▉ | 372/469 [11:13<02:51, 1.77s/it]
|
| 584 |
80%|███████▉ | 373/469 [11:14<02:46, 1.73s/it]
|
| 585 |
80%|███████▉ | 374/469 [11:16<02:45, 1.75s/it]
|
| 586 |
80%|███████▉ | 375/469 [11:18<02:43, 1.74s/it]
|
| 587 |
80%|████████ | 376/469 [11:20<02:42, 1.74s/it]
|
| 588 |
80%|████████ | 377/469 [11:21<02:39, 1.74s/it]
|
| 589 |
81%|████████ | 378/469 [11:23<02:36, 1.72s/it]
|
| 590 |
81%|████████ | 379/469 [11:25<02:35, 1.72s/it]
|
| 591 |
81%|████████ | 380/469 [11:26<02:32, 1.72s/it]
|
| 592 |
81%|████████ | 381/469 [11:28<02:34, 1.76s/it]
|
| 593 |
81%|████████▏ | 382/469 [11:30<02:43, 1.88s/it]
|
| 594 |
82%|████████▏ | 383/469 [11:32<02:35, 1.81s/it]
|
| 595 |
82%|████████▏ | 384/469 [11:34<02:29, 1.76s/it]
|
| 596 |
82%|████████▏ | 385/469 [11:36<02:30, 1.80s/it]
|
| 597 |
82%|████████▏ | 386/469 [11:37<02:26, 1.77s/it]
|
| 598 |
83%|████████▎ | 387/469 [11:39<02:26, 1.78s/it]
|
| 599 |
83%|████████▎ | 388/469 [11:41<02:23, 1.77s/it]
|
| 600 |
83%|████████▎ | 389/469 [11:42<02:18, 1.73s/it]
|
| 601 |
83%|████████▎ | 390/469 [11:44<02:18, 1.75s/it]
|
| 602 |
83%|████████▎ | 391/469 [11:46<02:16, 1.76s/it]
|
| 603 |
84%|████████▎ | 392/469 [11:49<02:33, 1.99s/it]
|
| 604 |
84%|████████▍ | 393/469 [11:50<02:24, 1.90s/it]
|
| 605 |
84%|████████▍ | 394/469 [11:52<02:18, 1.85s/it]
|
| 606 |
84%|████████▍ | 395/469 [11:54<02:12, 1.79s/it]
|
| 607 |
84%|████████▍ | 396/469 [11:58<03:02, 2.50s/it]
|
| 608 |
85%|████████▍ | 397/469 [12:00<02:44, 2.29s/it]
|
| 609 |
85%|████████▍ | 398/469 [12:01<02:31, 2.13s/it]
|
| 610 |
85%|████████▌ | 399/469 [12:03<02:21, 2.02s/it]
|
| 611 |
85%|████████▌ | 400/469 [12:05<02:15, 1.96s/it]
|
| 612 |
86%|████████▌ | 401/469 [12:07<02:10, 1.92s/it]
|
| 613 |
86%|████████▌ | 402/469 [12:08<02:02, 1.83s/it]
|
| 614 |
86%|████████▌ | 403/469 [12:10<01:57, 1.78s/it]
|
| 615 |
86%|████████▌ | 404/469 [12:12<01:53, 1.74s/it]
|
| 616 |
86%|████████▋ | 405/469 [12:13<01:49, 1.72s/it]
|
| 617 |
87%|████████▋ | 406/469 [12:15<01:50, 1.76s/it]
|
| 618 |
87%|████████▋ | 407/469 [12:17<01:48, 1.74s/it]
|
| 619 |
87%|████████▋ | 408/469 [12:19<01:44, 1.72s/it]
|
| 620 |
87%|████████▋ | 409/469 [12:20<01:41, 1.70s/it]
|
| 621 |
87%|████████▋ | 410/469 [12:22<01:45, 1.79s/it]
|
| 622 |
88%|████████▊ | 411/469 [12:24<01:40, 1.74s/it]
|
| 623 |
88%|████████▊ | 412/469 [12:26<01:41, 1.78s/it]
|
| 624 |
88%|████████▊ | 413/469 [12:28<01:39, 1.78s/it]
|
| 625 |
88%|████████▊ | 414/469 [12:29<01:35, 1.74s/it]
|
| 626 |
88%|████████▊ | 415/469 [12:31<01:38, 1.83s/it]
|
| 627 |
89%|████████▊ | 416/469 [12:33<01:36, 1.82s/it]
|
| 628 |
89%|████████▉ | 417/469 [12:35<01:31, 1.76s/it]
|
| 629 |
89%|████████▉ | 418/469 [12:36<01:31, 1.79s/it]
|
| 630 |
89%|████████▉ | 419/469 [12:38<01:26, 1.74s/it]
|
| 631 |
90%|████████▉ | 420/469 [12:40<01:25, 1.75s/it]
|
| 632 |
90%|████████▉ | 421/469 [12:42<01:23, 1.75s/it]
|
| 633 |
90%|████████▉ | 422/469 [12:43<01:20, 1.71s/it]
|
| 634 |
90%|█████████ | 423/469 [12:45<01:20, 1.74s/it]
|
| 635 |
90%|█████████ | 424/469 [12:47<01:18, 1.74s/it]
|
| 636 |
91%|█████████ | 425/469 [12:49<01:17, 1.75s/it]
|
| 637 |
91%|█████████ | 426/469 [12:50<01:15, 1.75s/it]
|
| 638 |
91%|█████████ | 427/469 [12:52<01:14, 1.78s/it]
|
| 639 |
91%|█████████▏| 428/469 [12:54<01:12, 1.77s/it]
|
| 640 |
91%|█████████▏| 429/469 [12:56<01:09, 1.75s/it]
|
| 641 |
92%|█████████▏| 430/469 [12:57<01:07, 1.72s/it]
|
| 642 |
92%|█████████▏| 431/469 [12:59<01:04, 1.70s/it]
|
| 643 |
92%|█████████▏| 432/469 [13:01<01:04, 1.74s/it]
|
| 644 |
92%|█████████▏| 433/469 [13:03<01:07, 1.88s/it]
|
| 645 |
93%|█████████▎| 434/469 [13:05<01:04, 1.86s/it]
|
| 646 |
93%|█████████▎| 435/469 [13:07<01:03, 1.87s/it]
|
| 647 |
93%|█████████▎| 436/469 [13:08<00:59, 1.81s/it]
|
| 648 |
93%|█████████▎| 437/469 [13:10<00:57, 1.81s/it]
|
| 649 |
93%|█████████▎| 438/469 [13:12<00:57, 1.86s/it]
|
| 650 |
94%|█████████▎| 439/469 [13:14<00:54, 1.82s/it]
|
| 651 |
94%|█████████▍| 440/469 [13:16<00:52, 1.82s/it]
|
| 652 |
94%|█████████▍| 441/469 [13:18<00:51, 1.84s/it]
|
| 653 |
94%|█████████▍| 442/469 [13:19<00:48, 1.80s/it]
|
| 654 |
94%|█████████▍| 443/469 [13:21<00:48, 1.86s/it]
|
| 655 |
95%|█████████▍| 444/469 [13:23<00:45, 1.80s/it]
|
| 656 |
95%|█████████▍| 445/469 [13:25<00:42, 1.79s/it]
|
| 657 |
95%|█████████▌| 446/469 [13:27<00:42, 1.85s/it]
|
| 658 |
95%|█████████▌| 447/469 [13:29<00:41, 1.86s/it]
|
| 659 |
96%|█████████▌| 448/469 [13:30<00:39, 1.88s/it]
|
| 660 |
96%|█████████▌| 449/469 [13:32<00:37, 1.86s/it]
|
| 661 |
96%|█████████▌| 450/469 [13:34<00:33, 1.79s/it]
|
| 662 |
96%|█████████▌| 451/469 [13:36<00:31, 1.74s/it]
|
| 663 |
96%|█████████▋| 452/469 [13:37<00:29, 1.73s/it]
|
| 664 |
97%|█████████▋| 453/469 [13:39<00:28, 1.76s/it]
|
| 665 |
97%|█████████▋| 454/469 [13:41<00:26, 1.74s/it]
|
| 666 |
97%|█████████▋| 455/469 [13:42<00:23, 1.70s/it]
|
| 667 |
97%|█████████▋| 456/469 [13:44<00:22, 1.71s/it]
|
| 668 |
97%|█████████▋| 457/469 [13:46<00:20, 1.71s/it]
|
| 669 |
98%|█████████▊| 458/469 [13:48<00:18, 1.71s/it]
|
| 670 |
98%|█████████▊| 459/469 [13:49<00:16, 1.69s/it]
|
| 671 |
98%|█████████▊| 460/469 [13:51<00:15, 1.68s/it]
|
| 672 |
98%|█████████▊| 461/469 [13:53<00:14, 1.76s/it]
|
| 673 |
99%|█████████▊| 462/469 [13:55<00:12, 1.78s/it]
|
| 674 |
99%|█████████▊| 463/469 [13:57<00:10, 1.83s/it]
|
| 675 |
99%|█████████▉| 464/469 [13:58<00:09, 1.86s/it]
|
| 676 |
99%|█████████▉| 465/469 [14:00<00:07, 1.81s/it]
|
| 677 |
99%|█████████▉| 466/469 [14:02<00:05, 1.78s/it]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-21 21:11:32.431141: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
| 2 |
+
2026-03-21 21:11:44.041066: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
|
| 3 |
+
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
| 4 |
+
2026-03-21 21:11:44.095212: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
|
| 5 |
+
2026-03-21 21:11:44.095324: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: d700126ec97dc07d69688b0430c49a6a-taskrole1-0
|
| 6 |
+
2026-03-21 21:11:44.095377: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: d700126ec97dc07d69688b0430c49a6a-taskrole1-0
|
| 7 |
+
2026-03-21 21:11:44.095515: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
|
| 8 |
+
2026-03-21 21:11:44.095581: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 550.127.8
|
| 9 |
+
2026-03-21 21:11:45.185895: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
| 10 |
+
warming up TensorFlow...
|
| 11 |
+
|
| 12 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-21 21:11:45.919794: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
| 13 |
+
|
| 14 |
+
computing reference batch activations...
|
| 15 |
+
|
| 16 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 17 |
0%| | 1/211 [00:02<07:04, 2.02s/it]
|
| 18 |
1%| | 2/211 [00:03<06:56, 2.00s/it]
|
| 19 |
1%|▏ | 3/211 [00:05<06:50, 1.98s/it]
|
| 20 |
2%|▏ | 4/211 [00:07<06:42, 1.95s/it]
|
| 21 |
2%|▏ | 5/211 [00:09<06:25, 1.87s/it]
|
| 22 |
3%|▎ | 6/211 [00:11<06:20, 1.85s/it]
|
| 23 |
3%|▎ | 7/211 [00:13<06:46, 1.99s/it]
|
| 24 |
4%|▍ | 8/211 [00:15<06:56, 2.05s/it]
|
| 25 |
4%|▍ | 9/211 [00:30<20:28, 6.08s/it]
|
| 26 |
5%|▍ | 10/211 [00:32<15:51, 4.73s/it]
|
| 27 |
5%|▌ | 11/211 [00:34<12:42, 3.81s/it]
|
| 28 |
6%|▌ | 12/211 [00:35<10:26, 3.15s/it]
|
| 29 |
6%|▌ | 13/211 [00:37<08:55, 2.71s/it]
|
| 30 |
7%|▋ | 14/211 [00:39<08:09, 2.49s/it]
|
| 31 |
7%|▋ | 15/211 [00:41<07:32, 2.31s/it]
|
| 32 |
8%|▊ | 16/211 [00:43<06:52, 2.12s/it]
|
| 33 |
8%|▊ | 17/211 [00:44<06:35, 2.04s/it]
|
| 34 |
9%|▊ | 18/211 [00:46<06:18, 1.96s/it]
|
| 35 |
9%|▉ | 19/211 [00:48<06:16, 1.96s/it]
|
| 36 |
9%|▉ | 20/211 [00:50<06:20, 1.99s/it]
|
| 37 |
10%|▉ | 21/211 [00:52<06:01, 1.91s/it]
|
| 38 |
10%|█ | 22/211 [00:54<05:51, 1.86s/it]
|
| 39 |
11%|█ | 23/211 [00:56<05:48, 1.85s/it]
|
| 40 |
11%|█▏ | 24/211 [00:57<05:44, 1.84s/it]
|
| 41 |
12%|█▏ | 25/211 [00:59<05:44, 1.85s/it]
|
| 42 |
12%|█▏ | 26/211 [01:01<05:38, 1.83s/it]
|
| 43 |
13%|█▎ | 27/211 [01:03<05:33, 1.81s/it]
|
| 44 |
13%|█▎ | 28/211 [01:05<06:02, 1.98s/it]
|
| 45 |
14%|█▎ | 29/211 [01:07<05:49, 1.92s/it]
|
| 46 |
14%|█▍ | 30/211 [01:09<05:33, 1.84s/it]
|
| 47 |
15%|█▍ | 31/211 [01:11<05:38, 1.88s/it]
|
| 48 |
15%|█▌ | 32/211 [01:12<05:32, 1.86s/it]
|
| 49 |
16%|█▌ | 33/211 [01:14<05:26, 1.83s/it]
|
| 50 |
16%|█▌ | 34/211 [01:16<05:20, 1.81s/it]
|
| 51 |
17%|█▋ | 35/211 [01:18<05:16, 1.80s/it]
|
| 52 |
17%|█▋ | 36/211 [01:19<05:11, 1.78s/it]
|
| 53 |
18%|█▊ | 37/211 [01:21<05:06, 1.76s/it]
|
| 54 |
18%|█▊ | 38/211 [01:23<05:22, 1.86s/it]
|
| 55 |
18%|█▊ | 39/211 [01:25<05:12, 1.82s/it]
|
| 56 |
19%|█▉ | 40/211 [01:35<11:51, 4.16s/it]
|
| 57 |
19%|█▉ | 41/211 [01:36<09:45, 3.44s/it]
|
| 58 |
20%|█▉ | 42/211 [01:38<08:27, 3.00s/it]
|
| 59 |
20%|██ | 43/211 [01:40<07:25, 2.65s/it]
|
| 60 |
21%|██ | 44/211 [01:42<06:46, 2.43s/it]
|
| 61 |
21%|██▏ | 45/211 [01:44<06:28, 2.34s/it]
|
| 62 |
22%|██▏ | 46/211 [01:46<05:56, 2.16s/it]
|
| 63 |
22%|██▏ | 47/211 [01:48<05:55, 2.17s/it]
|
| 64 |
23%|██▎ | 48/211 [01:50<05:30, 2.03s/it]
|
| 65 |
23%|██▎ | 49/211 [01:52<05:22, 1.99s/it]
|
| 66 |
24%|██▎ | 50/211 [01:54<05:11, 1.94s/it]
|
| 67 |
24%|██▍ | 51/211 [01:56<05:18, 1.99s/it]
|
| 68 |
25%|██▍ | 52/211 [01:57<05:02, 1.90s/it]
|
| 69 |
25%|██▌ | 53/211 [01:59<04:52, 1.85s/it]
|
| 70 |
26%|██▌ | 54/211 [02:01<04:42, 1.80s/it]
|
| 71 |
26%|██▌ | 55/211 [02:03<04:40, 1.80s/it]
|
| 72 |
27%|██▋ | 56/211 [02:05<04:47, 1.85s/it]
|
| 73 |
27%|██▋ | 57/211 [02:06<04:38, 1.81s/it]
|
| 74 |
27%|██▋ | 58/211 [02:08<04:40, 1.83s/it]
|
| 75 |
28%|██▊ | 59/211 [02:10<04:34, 1.81s/it]
|
| 76 |
28%|██▊ | 60/211 [02:12<04:31, 1.80s/it]
|
| 77 |
29%|██▉ | 61/211 [02:13<04:24, 1.76s/it]
|
| 78 |
29%|██▉ | 62/211 [02:15<04:19, 1.74s/it]
|
| 79 |
30%|██▉ | 63/211 [02:17<04:25, 1.79s/it]
|
| 80 |
30%|███ | 64/211 [02:19<04:22, 1.78s/it]
|
| 81 |
31%|███ | 65/211 [02:21<04:21, 1.79s/it]
|
| 82 |
31%|███▏ | 66/211 [02:22<04:16, 1.77s/it]
|
| 83 |
32%|███▏ | 67/211 [02:24<04:27, 1.85s/it]
|
| 84 |
32%|███▏ | 68/211 [02:26<04:23, 1.84s/it]
|
| 85 |
33%|███▎ | 69/211 [02:28<04:21, 1.84s/it]
|
| 86 |
33%|███▎ | 70/211 [02:30<04:15, 1.81s/it]
|
| 87 |
34%|███▎ | 71/211 [02:32<04:13, 1.81s/it]
|
| 88 |
34%|███▍ | 72/211 [02:33<04:09, 1.79s/it]
|
| 89 |
35%|███▍ | 73/211 [02:35<04:04, 1.77s/it]
|
| 90 |
35%|███▌ | 74/211 [02:53<15:24, 6.75s/it]
|
| 91 |
36%|███▌ | 75/211 [02:55<11:58, 5.28s/it]
|
| 92 |
36%|███▌ | 76/211 [02:57<09:31, 4.24s/it]
|
| 93 |
36%|███▋ | 77/211 [02:59<07:46, 3.48s/it]
|
| 94 |
37%|███▋ | 78/211 [03:00<06:33, 2.96s/it]
|
| 95 |
37%|███▋ | 79/211 [03:02<05:47, 2.64s/it]
|
| 96 |
38%|███▊ | 80/211 [03:04<05:19, 2.44s/it]
|
| 97 |
38%|███▊ | 81/211 [03:06<04:47, 2.22s/it]
|
| 98 |
39%|███▉ | 82/211 [03:08<04:27, 2.07s/it]
|
| 99 |
39%|███▉ | 83/211 [03:09<04:12, 1.97s/it]
|
| 100 |
40%|███▉ | 84/211 [03:11<04:02, 1.91s/it]
|
| 101 |
40%|████ | 85/211 [03:13<03:53, 1.85s/it]
|
| 102 |
41%|████ | 86/211 [03:15<03:48, 1.83s/it]
|
| 103 |
41%|████ | 87/211 [03:16<03:42, 1.79s/it]
|
| 104 |
42%|████▏ | 88/211 [03:18<03:40, 1.79s/it]
|
| 105 |
42%|████▏ | 89/211 [03:20<03:35, 1.77s/it]
|
| 106 |
43%|████▎ | 90/211 [03:22<03:33, 1.77s/it]
|
| 107 |
43%|████▎ | 91/211 [03:23<03:30, 1.76s/it]
|
| 108 |
44%|████▎ | 92/211 [03:25<03:35, 1.81s/it]
|
| 109 |
44%|████▍ | 93/211 [03:27<03:28, 1.76s/it]
|
| 110 |
45%|████▍ | 94/211 [03:29<03:25, 1.75s/it]
|
| 111 |
45%|████▌ | 95/211 [03:31<03:22, 1.74s/it]
|
| 112 |
45%|████▌ | 96/211 [03:32<03:24, 1.77s/it]
|
| 113 |
46%|████▌ | 97/211 [03:34<03:21, 1.77s/it]
|
| 114 |
46%|████▋ | 98/211 [03:36<03:19, 1.77s/it]
|
| 115 |
47%|████▋ | 99/211 [03:38<03:18, 1.77s/it]
|
| 116 |
47%|████▋ | 100/211 [03:39<03:14, 1.75s/it]
|
| 117 |
48%|████▊ | 101/211 [03:41<03:13, 1.75s/it]
|
| 118 |
48%|████▊ | 102/211 [03:43<03:13, 1.77s/it]
|
| 119 |
49%|████▉ | 103/211 [03:45<03:08, 1.75s/it]
|
| 120 |
49%|████▉ | 104/211 [03:47<03:12, 1.79s/it]
|
| 121 |
50%|████▉ | 105/211 [03:48<03:07, 1.77s/it]
|
| 122 |
50%|█████ | 106/211 [04:10<13:22, 7.64s/it]
|
| 123 |
51%|█████ | 107/211 [04:12<10:18, 5.94s/it]
|
| 124 |
51%|█████ | 108/211 [04:13<08:02, 4.68s/it]
|
| 125 |
52%|█████▏ | 109/211 [04:15<06:26, 3.79s/it]
|
| 126 |
52%|█████▏ | 110/211 [04:17<05:18, 3.16s/it]
|
| 127 |
53%|█████▎ | 111/211 [04:18<04:33, 2.74s/it]
|
| 128 |
53%|█████▎ | 112/211 [04:20<04:09, 2.52s/it]
|
| 129 |
54%|█████▎ | 113/211 [04:22<03:50, 2.35s/it]
|
| 130 |
54%|█████▍ | 114/211 [04:24<03:30, 2.17s/it]
|
| 131 |
55%|█████▍ | 115/211 [04:26<03:22, 2.11s/it]
|
| 132 |
55%|█████▍ | 116/211 [04:46<11:49, 7.47s/it]
|
| 133 |
55%|█████▌ | 117/211 [04:48<08:58, 5.73s/it]
|
| 134 |
56%|█████▌ | 118/211 [04:50<07:08, 4.61s/it]
|
| 135 |
56%|█████▋ | 119/211 [04:51<05:41, 3.71s/it]
|
| 136 |
57%|█████▋ | 120/211 [04:53<04:42, 3.10s/it]
|
| 137 |
57%|█████▋ | 121/211 [04:55<04:01, 2.69s/it]
|
| 138 |
58%|█████▊ | 122/211 [04:57<03:33, 2.40s/it]
|
| 139 |
58%|█████▊ | 123/211 [04:58<03:13, 2.20s/it]
|
| 140 |
59%|█████▉ | 124/211 [05:00<03:01, 2.08s/it]
|
| 141 |
59%|█████▉ | 125/211 [05:02<02:50, 1.98s/it]
|
| 142 |
60%|█████▉ | 126/211 [05:03<02:38, 1.87s/it]
|
| 143 |
60%|██████ | 127/211 [05:05<02:31, 1.80s/it]
|
| 144 |
61%|██████ | 128/211 [05:07<02:25, 1.75s/it]
|
| 145 |
61%|██████ | 129/211 [05:09<02:29, 1.82s/it]
|
| 146 |
62%|██████▏ | 130/211 [05:10<02:24, 1.79s/it]
|
| 147 |
62%|██████▏ | 131/211 [05:12<02:20, 1.76s/it]
|
| 148 |
63%|██████▎ | 132/211 [05:14<02:17, 1.74s/it]
|
| 149 |
63%|██████▎ | 133/211 [05:15<02:14, 1.73s/it]
|
| 150 |
64%|██████▎ | 134/211 [05:18<02:23, 1.87s/it]
|
| 151 |
64%|██████▍ | 135/211 [05:20<02:22, 1.87s/it]
|
| 152 |
64%|██████▍ | 136/211 [05:21<02:15, 1.80s/it]
|
| 153 |
65%|██████▍ | 137/211 [05:23<02:13, 1.81s/it]
|
| 154 |
65%|██████▌ | 138/211 [05:25<02:10, 1.79s/it]
|
| 155 |
66%|██████▌ | 139/211 [05:26<02:06, 1.75s/it]
|
| 156 |
66%|██████▋ | 140/211 [05:28<02:03, 1.73s/it]
|
| 157 |
67%|██████▋ | 141/211 [05:30<02:04, 1.78s/it]
|
| 158 |
67%|██████▋ | 142/211 [05:32<02:05, 1.82s/it]
|
| 159 |
68%|██████▊ | 143/211 [05:34<02:04, 1.83s/it]
|
| 160 |
68%|██��███▊ | 144/211 [05:35<01:59, 1.78s/it]
|
| 161 |
69%|██████▊ | 145/211 [05:37<01:57, 1.78s/it]
|
| 162 |
69%|██████▉ | 146/211 [05:55<07:10, 6.63s/it]
|
| 163 |
70%|██████▉ | 147/211 [05:57<05:30, 5.17s/it]
|
| 164 |
70%|███████ | 148/211 [05:59<04:23, 4.17s/it]
|
| 165 |
71%|███████ | 149/211 [06:00<03:31, 3.42s/it]
|
| 166 |
71%|███████ | 150/211 [06:02<02:58, 2.92s/it]
|
| 167 |
72%|███████▏ | 151/211 [06:04<02:33, 2.56s/it]
|
| 168 |
72%|███████▏ | 152/211 [06:06<02:15, 2.29s/it]
|
| 169 |
73%|███████▎ | 153/211 [06:07<02:02, 2.12s/it]
|
| 170 |
73%|███████▎ | 154/211 [06:10<02:15, 2.38s/it]
|
| 171 |
73%|███████▎ | 155/211 [06:12<02:00, 2.15s/it]
|
| 172 |
74%|███████▍ | 156/211 [06:14<01:50, 2.00s/it]
|
| 173 |
74%|███████▍ | 157/211 [06:15<01:43, 1.91s/it]
|
| 174 |
75%|███████▍ | 158/211 [06:17<01:38, 1.86s/it]
|
| 175 |
75%|███████▌ | 159/211 [06:19<01:37, 1.88s/it]
|
| 176 |
76%|███████▌ | 160/211 [06:21<01:32, 1.82s/it]
|
| 177 |
76%|███████▋ | 161/211 [06:22<01:29, 1.79s/it]
|
| 178 |
77%|███████▋ | 162/211 [06:24<01:31, 1.87s/it]
|
| 179 |
77%|███████▋ | 163/211 [06:26<01:27, 1.83s/it]
|
| 180 |
78%|███████▊ | 164/211 [06:29<01:36, 2.05s/it]
|
| 181 |
78%|███████▊ | 165/211 [06:31<01:32, 2.01s/it]
|
| 182 |
79%|███████▊ | 166/211 [06:32<01:27, 1.94s/it]
|
| 183 |
79%|███████▉ | 167/211 [06:34<01:21, 1.85s/it]
|
| 184 |
80%|███████▉ | 168/211 [06:36<01:17, 1.79s/it]
|
| 185 |
80%|████████ | 169/211 [06:38<01:19, 1.89s/it]
|
| 186 |
81%|████████ | 170/211 [06:40<01:17, 1.88s/it]
|
| 187 |
81%|████████ | 171/211 [06:41<01:14, 1.85s/it]
|
| 188 |
82%|████████▏ | 172/211 [06:43<01:11, 1.84s/it]
|
| 189 |
82%|████████▏ | 173/211 [06:45<01:08, 1.81s/it]
|
| 190 |
82%|████████▏ | 174/211 [06:47<01:07, 1.84s/it]
|
| 191 |
83%|████████▎ | 175/211 [06:49<01:04, 1.78s/it]
|
| 192 |
83%|████████▎ | 176/211 [06:50<01:02, 1.79s/it]
|
| 193 |
84%|████████▍ | 177/211 [06:52<01:00, 1.77s/it]
|
| 194 |
84%|████████▍ | 178/211 [06:54<01:03, 1.92s/it]
|
| 195 |
85%|████████▍ | 179/211 [06:56<00:59, 1.87s/it]
|
| 196 |
85%|████████▌ | 180/211 [06:58<00:56, 1.81s/it]
|
| 197 |
86%|████████▌ | 181/211 [06:59<00:53, 1.77s/it]
|
| 198 |
86%|████████▋ | 182/211 [07:01<00:53, 1.84s/it]
|
| 199 |
87%|████████▋ | 183/211 [07:03<00:50, 1.81s/it]
|
| 200 |
87%|████████▋ | 184/211 [07:05<00:49, 1.83s/it]
|
| 201 |
88%|████████▊ | 185/211 [07:07<00:48, 1.86s/it]
|
| 202 |
88%|████████▊ | 186/211 [07:09<00:45, 1.83s/it]
|
| 203 |
89%|████████▊ | 187/211 [07:11<00:43, 1.82s/it]
|
| 204 |
89%|████████▉ | 188/211 [07:12<00:42, 1.83s/it]
|
| 205 |
90%|████████▉ | 189/211 [07:14<00:39, 1.77s/it]
|
| 206 |
90%|█████████ | 190/211 [07:16<00:37, 1.77s/it]
|
| 207 |
91%|█████████ | 191/211 [07:17<00:34, 1.73s/it]
|
| 208 |
91%|█████████ | 192/211 [07:19<00:32, 1.71s/it]
|
| 209 |
91%|█████████▏| 193/211 [07:21<00:30, 1.72s/it]
|
| 210 |
92%|█████████▏| 194/211 [07:23<00:29, 1.72s/it]
|
| 211 |
92%|█████████▏| 195/211 [07:24<00:27, 1.72s/it]
|
| 212 |
93%|█████████▎| 196/211 [07:26<00:25, 1.70s/it]
|
| 213 |
93%|█████████▎| 197/211 [07:28<00:24, 1.75s/it]
|
| 214 |
94%|█████████▍| 198/211 [07:30<00:23, 1.78s/it]
|
| 215 |
94%|█████████▍| 199/211 [07:31<00:21, 1.76s/it]
|
| 216 |
95%|█████████▍| 200/211 [07:33<00:19, 1.78s/it]
|
| 217 |
95%|█████████▌| 201/211 [07:35<00:17, 1.76s/it]
|
| 218 |
96%|█████████▌| 202/211 [07:37<00:16, 1.80s/it]
|
| 219 |
96%|█████████▌| 203/211 [07:39<00:14, 1.79s/it]
|
| 220 |
97%|█████████▋| 204/211 [07:40<00:12, 1.76s/it]
|
| 221 |
97%|█████████▋| 205/211 [07:42<00:10, 1.74s/it]
|
| 222 |
98%|█████████▊| 206/211 [07:44<00:08, 1.75s/it]
|
| 223 |
98%|█████████▊| 207/211 [07:45<00:06, 1.72s/it]
|
| 224 |
99%|█████████▊| 208/211 [07:47<00:05, 1.70s/it]
|
| 225 |
99%|█████████▉| 209/211 [07:49<00:03, 1.68s/it]
|
| 226 |
+
computing/reading reference batch statistics...
|
| 227 |
+
computing sample batch activations...
|
| 228 |
+
|
| 229 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 230 |
0%| | 1/469 [00:02<18:09, 2.33s/it]
|
| 231 |
0%| | 2/469 [00:04<15:10, 1.95s/it]
|
| 232 |
1%| | 3/469 [00:05<14:14, 1.83s/it]
|
| 233 |
1%| | 4/469 [00:07<13:52, 1.79s/it]
|
| 234 |
1%| | 5/469 [00:09<13:23, 1.73s/it]
|
| 235 |
1%|▏ | 6/469 [00:12<17:01, 2.21s/it]
|
| 236 |
1%|▏ | 7/469 [00:14<16:26, 2.14s/it]
|
| 237 |
2%|▏ | 8/469 [00:15<15:33, 2.02s/it]
|
| 238 |
2%|▏ | 9/469 [00:17<15:18, 2.00s/it]
|
| 239 |
2%|▏ | 10/469 [00:19<14:43, 1.92s/it]
|
| 240 |
2%|▏ | 11/469 [00:21<14:04, 1.84s/it]
|
| 241 |
3%|▎ | 12/469 [00:23<14:21, 1.89s/it]
|
| 242 |
3%|▎ | 13/469 [00:25<14:16, 1.88s/it]
|
| 243 |
3%|▎ | 14/469 [00:27<14:15, 1.88s/it]
|
| 244 |
3%|▎ | 15/469 [00:28<14:05, 1.86s/it]
|
| 245 |
3%|▎ | 16/469 [00:30<14:24, 1.91s/it]
|
| 246 |
4%|▎ | 17/469 [00:32<14:05, 1.87s/it]
|
| 247 |
4%|▍ | 18/469 [00:34<14:13, 1.89s/it]
|
| 248 |
4%|▍ | 19/469 [00:36<13:44, 1.83s/it]
|
| 249 |
4%|▍ | 20/469 [00:37<13:14, 1.77s/it]
|
| 250 |
4%|▍ | 21/469 [00:39<13:07, 1.76s/it]
|
| 251 |
5%|▍ | 22/469 [00:41<12:51, 1.72s/it]
|
| 252 |
5%|▍ | 23/469 [00:43<12:56, 1.74s/it]
|
| 253 |
5%|▌ | 24/469 [00:44<12:53, 1.74s/it]
|
| 254 |
5%|▌ | 25/469 [00:46<12:43, 1.72s/it]
|
| 255 |
6%|▌ | 26/469 [00:48<12:47, 1.73s/it]
|
| 256 |
6%|▌ | 27/469 [00:49<12:34, 1.71s/it]
|
| 257 |
6%|▌ | 28/469 [00:51<12:53, 1.75s/it]
|
| 258 |
6%|▌ | 29/469 [00:53<12:38, 1.72s/it]
|
| 259 |
6%|▋ | 30/469 [00:55<12:33, 1.72s/it]
|
| 260 |
7%|▋ | 31/469 [00:56<12:17, 1.68s/it]
|
| 261 |
7%|▋ | 32/469 [00:58<12:27, 1.71s/it]
|
| 262 |
7%|▋ | 33/469 [01:00<12:22, 1.70s/it]
|
| 263 |
7%|▋ | 34/469 [01:03<16:09, 2.23s/it]
|
| 264 |
7%|▋ | 35/469 [01:05<15:00, 2.08s/it]
|
| 265 |
8%|▊ | 36/469 [01:07<14:10, 1.96s/it]
|
| 266 |
8%|▊ | 37/469 [01:08<13:48, 1.92s/it]
|
| 267 |
8%|▊ | 38/469 [01:10<13:13, 1.84s/it]
|
| 268 |
8%|▊ | 39/469 [01:12<13:28, 1.88s/it]
|
| 269 |
9%|▊ | 40/469 [01:14<13:04, 1.83s/it]
|
| 270 |
9%|▊ | 41/469 [01:16<13:04, 1.83s/it]
|
| 271 |
9%|▉ | 42/469 [01:19<15:53, 2.23s/it]
|
| 272 |
9%|▉ | 43/469 [01:21<15:36, 2.20s/it]
|
| 273 |
9%|▉ | 44/469 [01:23<14:48, 2.09s/it]
|
| 274 |
10%|▉ | 45/469 [01:24<13:55, 1.97s/it]
|
| 275 |
10%|▉ | 46/469 [01:26<13:20, 1.89s/it]
|
| 276 |
10%|█ | 47/469 [01:28<13:24, 1.91s/it]
|
| 277 |
10%|█ | 48/469 [01:30<13:15, 1.89s/it]
|
| 278 |
10%|█ | 49/469 [01:31<12:39, 1.81s/it]
|
| 279 |
11%|█ | 50/469 [01:33<12:26, 1.78s/it]
|
| 280 |
11%|█ | 51/469 [01:35<12:47, 1.84s/it]
|
| 281 |
11%|█ | 52/469 [01:37<12:28, 1.79s/it]
|
| 282 |
11%|█▏ | 53/469 [01:39<12:09, 1.75s/it]
|
| 283 |
12%|█▏ | 54/469 [01:40<12:13, 1.77s/it]
|
| 284 |
12%|█▏ | 55/469 [01:42<12:03, 1.75s/it]
|
| 285 |
12%|█▏ | 56/469 [01:44<12:13, 1.78s/it]
|
| 286 |
12%|█▏ | 57/469 [01:46<12:27, 1.81s/it]
|
| 287 |
12%|█▏ | 58/469 [01:48<12:53, 1.88s/it]
|
| 288 |
13%|█▎ | 59/469 [01:50<12:46, 1.87s/it]
|
| 289 |
13%|█▎ | 60/469 [01:51<12:25, 1.82s/it]
|
| 290 |
13%|█▎ | 61/469 [01:53<12:01, 1.77s/it]
|
| 291 |
13%|█▎ | 62/469 [01:55<11:49, 1.74s/it]
|
| 292 |
13%|█▎ | 63/469 [01:56<11:43, 1.73s/it]
|
| 293 |
14%|█▎ | 64/469 [01:58<11:37, 1.72s/it]
|
| 294 |
14%|█▍ | 65/469 [02:00<11:38, 1.73s/it]
|
| 295 |
14%|█▍ | 66/469 [02:02<11:28, 1.71s/it]
|
| 296 |
14%|█▍ | 67/469 [02:03<11:30, 1.72s/it]
|
| 297 |
14%|█▍ | 68/469 [02:05<11:30, 1.72s/it]
|
| 298 |
15%|█▍ | 69/469 [02:07<12:03, 1.81s/it]
|
| 299 |
15%|█▍ | 70/469 [02:09<11:46, 1.77s/it]
|
| 300 |
15%|█▌ | 71/469 [02:10<11:36, 1.75s/it]
|
| 301 |
15%|█▌ | 72/469 [02:12<11:34, 1.75s/it]
|
| 302 |
16%|█▌ | 73/469 [02:14<12:02, 1.82s/it]
|
| 303 |
16%|█▌ | 74/469 [02:16<11:41, 1.78s/it]
|
| 304 |
16%|█▌ | 75/469 [02:18<12:10, 1.85s/it]
|
| 305 |
16%|█▌ | 76/469 [02:20<12:04, 1.84s/it]
|
| 306 |
16%|█▋ | 77/469 [02:21<11:44, 1.80s/it]
|
| 307 |
17%|█▋ | 78/469 [02:23<11:44, 1.80s/it]
|
| 308 |
17%|█▋ | 79/469 [02:25<11:44, 1.81s/it]
|
| 309 |
17%|█▋ | 80/469 [02:27<11:43, 1.81s/it]
|
| 310 |
17%|█▋ | 81/469 [02:28<11:22, 1.76s/it]
|
| 311 |
17%|█▋ | 82/469 [02:30<11:10, 1.73s/it]
|
| 312 |
18%|█▊ | 83/469 [02:32<11:12, 1.74s/it]
|
| 313 |
18%|█▊ | 84/469 [02:34<11:11, 1.75s/it]
|
| 314 |
18%|█▊ | 85/469 [02:36<11:32, 1.80s/it]
|
| 315 |
18%|█▊ | 86/469 [02:37<11:32, 1.81s/it]
|
| 316 |
19%|█▊ | 87/469 [02:39<11:16, 1.77s/it]
|
| 317 |
19%|█▉ | 88/469 [02:41<11:07, 1.75s/it]
|
| 318 |
19%|█▉ | 89/469 [02:42<11:00, 1.74s/it]
|
| 319 |
19%|█▉ | 90/469 [02:44<11:09, 1.77s/it]
|
| 320 |
19%|█▉ | 91/469 [02:46<11:15, 1.79s/it]
|
| 321 |
20%|█▉ | 92/469 [02:48<11:09, 1.78s/it]
|
| 322 |
20%|█▉ | 93/469 [02:50<10:57, 1.75s/it]
|
| 323 |
20%|██ | 94/469 [02:51<10:55, 1.75s/it]
|
| 324 |
20%|██ | 95/469 [02:53<10:50, 1.74s/it]
|
| 325 |
20%|██ | 96/469 [02:55<10:42, 1.72s/it]
|
| 326 |
21%|██ | 97/469 [02:56<10:37, 1.71s/it]
|
| 327 |
21%|██ | 98/469 [02:58<10:51, 1.76s/it]
|
| 328 |
21%|██ | 99/469 [03:00<10:52, 1.76s/it]
|
| 329 |
21%|██▏ | 100/469 [03:02<11:08, 1.81s/it]
|
| 330 |
22%|██▏ | 101/469 [03:04<10:52, 1.77s/it]
|
| 331 |
22%|██▏ | 102/469 [03:05<10:48, 1.77s/it]
|
| 332 |
22%|██▏ | 103/469 [03:07<10:35, 1.74s/it]
|
| 333 |
22%|██▏ | 104/469 [03:09<10:28, 1.72s/it]
|
| 334 |
22%|██▏ | 105/469 [03:11<10:38, 1.75s/it]
|
| 335 |
23%|██▎ | 106/469 [03:12<10:29, 1.73s/it]
|
| 336 |
23%|██▎ | 107/469 [03:14<10:22, 1.72s/it]
|
| 337 |
23%|██▎ | 108/469 [03:16<10:41, 1.78s/it]
|
| 338 |
23%|██▎ | 109/469 [03:18<10:29, 1.75s/it]
|
| 339 |
23%|██▎ | 110/469 [03:19<10:32, 1.76s/it]
|
| 340 |
24%|██▎ | 111/469 [03:21<10:25, 1.75s/it]
|
| 341 |
24%|██▍ | 112/469 [03:23<10:24, 1.75s/it]
|
| 342 |
24%|██▍ | 113/469 [03:25<10:37, 1.79s/it]
|
| 343 |
24%|██▍ | 114/469 [03:26<10:24, 1.76s/it]
|
| 344 |
25%|██▍ | 115/469 [03:28<10:27, 1.77s/it]
|
| 345 |
25%|██▍ | 116/469 [03:30<10:22, 1.76s/it]
|
| 346 |
25%|██▍ | 117/469 [03:32<10:28, 1.79s/it]
|
| 347 |
25%|██▌ | 118/469 [03:34<10:52, 1.86s/it]
|
| 348 |
25%|██▌ | 119/469 [03:36<11:00, 1.89s/it]
|
| 349 |
26%|██▌ | 120/469 [03:37<10:42, 1.84s/it]
|
| 350 |
26%|██▌ | 121/469 [03:39<10:37, 1.83s/it]
|
| 351 |
26%|██▌ | 122/469 [03:41<10:17, 1.78s/it]
|
| 352 |
26%|██▌ | 123/469 [03:43<10:21, 1.80s/it]
|
| 353 |
26%|██▋ | 124/469 [03:45<10:14, 1.78s/it]
|
| 354 |
27%|██▋ | 125/469 [03:46<10:15, 1.79s/it]
|
| 355 |
27%|██▋ | 126/469 [03:48<10:12, 1.79s/it]
|
| 356 |
27%|██▋ | 127/469 [03:50<09:59, 1.75s/it]
|
| 357 |
27%|██▋ | 128/469 [03:51<09:49, 1.73s/it]
|
| 358 |
28%|██▊ | 129/469 [03:53<09:53, 1.75s/it]
|
| 359 |
28%|██▊ | 130/469 [03:55<09:42, 1.72s/it]
|
| 360 |
28%|██▊ | 131/469 [03:57<09:37, 1.71s/it]
|
| 361 |
28%|██▊ | 132/469 [03:58<09:49, 1.75s/it]
|
| 362 |
28%|██▊ | 133/469 [04:00<09:59, 1.78s/it]
|
| 363 |
29%|██▊ | 134/469 [04:02<10:22, 1.86s/it]
|
| 364 |
29%|██▉ | 135/469 [04:04<10:20, 1.86s/it]
|
| 365 |
29%|██▉ | 136/469 [04:06<10:26, 1.88s/it]
|
| 366 |
29%|██▉ | 137/469 [04:08<10:02, 1.81s/it]
|
| 367 |
29%|██▉ | 138/469 [04:10<10:16, 1.86s/it]
|
| 368 |
30%|██▉ | 139/469 [04:11<09:58, 1.81s/it]
|
| 369 |
30%|██▉ | 140/469 [04:13<09:59, 1.82s/it]
|
| 370 |
30%|███ | 141/469 [04:15<09:40, 1.77s/it]
|
| 371 |
30%|███ | 142/469 [04:17<09:28, 1.74s/it]
|
| 372 |
30%|███ | 143/469 [04:18<09:16, 1.71s/it]
|
| 373 |
31%|███ | 144/469 [04:20<09:13, 1.70s/it]
|
| 374 |
31%|███ | 145/469 [04:22<09:28, 1.75s/it]
|
| 375 |
31%|███ | 146/469 [04:24<09:18, 1.73s/it]
|
| 376 |
31%|███▏ | 147/469 [04:25<09:35, 1.79s/it]
|
| 377 |
32%|███▏ | 148/469 [04:27<09:28, 1.77s/it]
|
| 378 |
32%|███▏ | 149/469 [04:29<09:12, 1.73s/it]
|
| 379 |
32%|███▏ | 150/469 [04:31<09:11, 1.73s/it]
|
| 380 |
32%|███▏ | 151/469 [04:32<09:18, 1.76s/it]
|
| 381 |
32%|███▏ | 152/469 [04:34<09:23, 1.78s/it]
|
| 382 |
33%|███▎ | 153/469 [04:36<09:21, 1.78s/it]
|
| 383 |
33%|███▎ | 154/469 [04:38<09:11, 1.75s/it]
|
| 384 |
33%|███▎ | 155/469 [04:39<09:09, 1.75s/it]
|
| 385 |
33%|███▎ | 156/469 [04:42<10:07, 1.94s/it]
|
| 386 |
33%|███▎ | 157/469 [04:44<10:39, 2.05s/it]
|
| 387 |
34%|███▎ | 158/469 [04:46<10:12, 1.97s/it]
|
| 388 |
34%|███▍ | 159/469 [04:48<09:42, 1.88s/it]
|
| 389 |
34%|███▍ | 160/469 [04:49<09:42, 1.88s/it]
|
| 390 |
34%|███▍ | 161/469 [04:51<09:24, 1.83s/it]
|
| 391 |
35%|███▍ | 162/469 [04:53<09:08, 1.79s/it]
|
| 392 |
35%|███▍ | 163/469 [04:55<09:01, 1.77s/it]
|
| 393 |
35%|███▍ | 164/469 [04:56<09:08, 1.80s/it]
|
| 394 |
35%|███▌ | 165/469 [04:58<09:07, 1.80s/it]
|
| 395 |
35%|███▌ | 166/469 [05:00<09:06, 1.80s/it]
|
| 396 |
36%|███▌ | 167/469 [05:02<09:00, 1.79s/it]
|
| 397 |
36%|███▌ | 168/469 [05:03<08:47, 1.75s/it]
|
| 398 |
36%|███▌ | 169/469 [05:05<08:51, 1.77s/it]
|
| 399 |
36%|███▌ | 170/469 [05:07<08:35, 1.73s/it]
|
| 400 |
36%|███▋ | 171/469 [05:09<08:25, 1.70s/it]
|
| 401 |
37%|███▋ | 172/469 [05:10<08:24, 1.70s/it]
|
| 402 |
37%|███▋ | 173/469 [05:12<08:20, 1.69s/it]
|
| 403 |
37%|███▋ | 174/469 [05:14<08:20, 1.70s/it]
|
| 404 |
37%|███▋ | 175/469 [05:16<08:42, 1.78s/it]
|
| 405 |
38%|███▊ | 176/469 [05:17<08:34, 1.76s/it]
|
| 406 |
38%|███▊ | 177/469 [05:19<08:31, 1.75s/it]
|
| 407 |
38%|███▊ | 178/469 [05:21<08:34, 1.77s/it]
|
| 408 |
38%|███▊ | 179/469 [05:22<08:22, 1.73s/it]
|
| 409 |
38%|███▊ | 180/469 [05:24<08:22, 1.74s/it]
|
| 410 |
39%|███▊ | 181/469 [05:26<08:14, 1.72s/it]
|
| 411 |
39%|███▉ | 182/469 [05:28<08:19, 1.74s/it]
|
| 412 |
39%|███▉ | 183/469 [05:29<08:17, 1.74s/it]
|
| 413 |
39%|███▉ | 184/469 [05:31<08:29, 1.79s/it]
|
| 414 |
39%|███▉ | 185/469 [05:33<08:36, 1.82s/it]
|
| 415 |
40%|███▉ | 186/469 [05:35<08:36, 1.82s/it]
|
| 416 |
40%|███▉ | 187/469 [05:37<08:29, 1.81s/it]
|
| 417 |
40%|████ | 188/469 [05:39<08:45, 1.87s/it]
|
| 418 |
40%|████ | 189/469 [05:41<08:29, 1.82s/it]
|
| 419 |
41%|████ | 190/469 [05:42<08:28, 1.82s/it]
|
| 420 |
41%|████ | 191/469 [05:44<08:18, 1.79s/it]
|
| 421 |
41%|████ | 192/469 [05:46<08:10, 1.77s/it]
|
| 422 |
41%|████ | 193/469 [05:47<07:58, 1.73s/it]
|
| 423 |
41%|████▏ | 194/469 [05:49<07:52, 1.72s/it]
|
| 424 |
42%|████▏ | 195/469 [05:51<08:08, 1.78s/it]
|
| 425 |
42%|████▏ | 196/469 [05:53<08:19, 1.83s/it]
|
| 426 |
42%|████▏ | 197/469 [05:55<08:35, 1.90s/it]
|
| 427 |
42%|████▏ | 198/469 [05:57<08:28, 1.88s/it]
|
| 428 |
42%|████▏ | 199/469 [05:59<08:17, 1.84s/it]
|
| 429 |
43%|████▎ | 200/469 [06:01<08:22, 1.87s/it]
|
| 430 |
43%|████▎ | 201/469 [06:02<08:09, 1.83s/it]
|
| 431 |
43%|████▎ | 202/469 [06:04<08:08, 1.83s/it]
|
| 432 |
43%|████▎ | 203/469 [06:06<08:21, 1.88s/it]
|
| 433 |
43%|████▎ | 204/469 [06:08<07:59, 1.81s/it]
|
| 434 |
44%|████▎ | 205/469 [06:09<07:46, 1.77s/it]
|
| 435 |
44%|████▍ | 206/469 [06:11<07:38, 1.74s/it]
|
| 436 |
44%|████▍ | 207/469 [06:13<07:27, 1.71s/it]
|
| 437 |
44%|████▍ | 208/469 [06:15<07:40, 1.76s/it]
|
| 438 |
45%|████▍ | 209/469 [06:16<07:35, 1.75s/it]
|
| 439 |
45%|████▍ | 210/469 [06:18<07:30, 1.74s/it]
|
| 440 |
45%|████▍ | 211/469 [06:20<07:35, 1.77s/it]
|
| 441 |
45%|████▌ | 212/469 [06:22<07:32, 1.76s/it]
|
| 442 |
45%|████▌ | 213/469 [06:23<07:29, 1.76s/it]
|
| 443 |
46%|████▌ | 214/469 [06:25<07:38, 1.80s/it]
|
| 444 |
46%|████▌ | 215/469 [06:27<07:34, 1.79s/it]
|
| 445 |
46%|████▌ | 216/469 [06:29<07:28, 1.77s/it]
|
| 446 |
46%|████▋ | 217/469 [06:30<07:20, 1.75s/it]
|
| 447 |
46%|████▋ | 218/469 [06:32<07:08, 1.71s/it]
|
| 448 |
47%|████▋ | 219/469 [06:34<07:04, 1.70s/it]
|
| 449 |
47%|████▋ | 220/469 [06:35<07:02, 1.70s/it]
|
| 450 |
47%|████▋ | 221/469 [06:38<07:54, 1.91s/it]
|
| 451 |
47%|████▋ | 222/469 [06:40<07:49, 1.90s/it]
|
| 452 |
48%|████▊ | 223/469 [06:42<07:37, 1.86s/it]
|
| 453 |
48%|████▊ | 224/469 [06:43<07:18, 1.79s/it]
|
| 454 |
48%|████▊ | 225/469 [06:45<07:42, 1.90s/it]
|
| 455 |
48%|████▊ | 226/469 [06:48<09:05, 2.25s/it]
|
| 456 |
48%|████▊ | 227/469 [06:50<08:33, 2.12s/it]
|
| 457 |
49%|████▊ | 228/469 [06:52<07:55, 1.97s/it]
|
| 458 |
49%|████▉ | 229/469 [06:54<07:40, 1.92s/it]
|
| 459 |
49%|████▉ | 230/469 [06:55<07:33, 1.90s/it]
|
| 460 |
49%|████▉ | 231/469 [06:57<07:17, 1.84s/it]
|
| 461 |
49%|████▉ | 232/469 [06:59<07:05, 1.79s/it]
|
| 462 |
50%|████▉ | 233/469 [07:01<07:03, 1.80s/it]
|
| 463 |
50%|████▉ | 234/469 [07:02<06:59, 1.78s/it]
|
| 464 |
50%|█████ | 235/469 [07:05<08:13, 2.11s/it]
|
| 465 |
50%|█████ | 236/469 [07:07<07:39, 1.97s/it]
|
| 466 |
51%|█████ | 237/469 [07:09<07:15, 1.88s/it]
|
| 467 |
51%|█████ | 238/469 [07:10<07:11, 1.87s/it]
|
| 468 |
51%|█████ | 239/469 [07:12<06:55, 1.81s/it]
|
| 469 |
51%|█████ | 240/469 [07:14<07:03, 1.85s/it]
|
| 470 |
51%|█████▏ | 241/469 [07:16<06:54, 1.82s/it]
|
| 471 |
52%|█████▏ | 242/469 [07:17<06:41, 1.77s/it]
|
| 472 |
52%|█████▏ | 243/469 [07:19<06:33, 1.74s/it]
|
| 473 |
52%|█████▏ | 244/469 [07:21<06:28, 1.73s/it]
|
| 474 |
52%|█████▏ | 245/469 [07:24<07:39, 2.05s/it]
|
| 475 |
52%|█████▏ | 246/469 [07:25<07:19, 1.97s/it]
|
| 476 |
53%|█████▎ | 247/469 [07:27<07:08, 1.93s/it]
|
| 477 |
53%|█████▎ | 248/469 [07:29<06:57, 1.89s/it]
|
| 478 |
53%|█████▎ | 249/469 [07:31<06:40, 1.82s/it]
|
| 479 |
53%|█████▎ | 250/469 [07:33<06:38, 1.82s/it]
|
| 480 |
54%|█████▎ | 251/469 [07:34<06:25, 1.77s/it]
|
| 481 |
54%|█████▎ | 252/469 [07:36<06:15, 1.73s/it]
|
| 482 |
54%|█████▍ | 253/469 [07:37<06:09, 1.71s/it]
|
| 483 |
54%|█████▍ | 254/469 [07:39<06:05, 1.70s/it]
|
| 484 |
54%|█████▍ | 255/469 [07:41<06:02, 1.70s/it]
|
| 485 |
55%|█████▍ | 256/469 [07:42<05:58, 1.68s/it]
|
| 486 |
55%|█████▍ | 257/469 [07:44<06:01, 1.71s/it]
|
| 487 |
55%|█████▌ | 258/469 [07:47<06:54, 1.96s/it]
|
| 488 |
55%|█████▌ | 259/469 [07:49<06:35, 1.88s/it]
|
| 489 |
55%|█████▌ | 260/469 [07:52<08:26, 2.42s/it]
|
| 490 |
56%|█████▌ | 261/469 [07:54<07:35, 2.19s/it]
|
| 491 |
56%|█████▌ | 262/469 [07:56<07:07, 2.07s/it]
|
| 492 |
56%|█████▌ | 263/469 [07:57<06:47, 1.98s/it]
|
| 493 |
56%|█████▋ | 264/469 [07:59<06:25, 1.88s/it]
|
| 494 |
57%|█████▋ | 265/469 [08:01<06:21, 1.87s/it]
|
| 495 |
57%|█████▋ | 266/469 [08:03<06:20, 1.87s/it]
|
| 496 |
57%|█████▋ | 267/469 [08:04<06:08, 1.82s/it]
|
| 497 |
57%|█████▋ | 268/469 [08:06<06:07, 1.83s/it]
|
| 498 |
57%|█████▋ | 269/469 [08:08<05:54, 1.77s/it]
|
| 499 |
58%|█████▊ | 270/469 [08:10<05:59, 1.80s/it]
|
| 500 |
58%|█████▊ | 271/469 [08:12<05:51, 1.77s/it]
|
| 501 |
58%|█████▊ | 272/469 [08:13<05:58, 1.82s/it]
|
| 502 |
58%|█████▊ | 273/469 [08:15<06:06, 1.87s/it]
|
| 503 |
58%|█████▊ | 274/469 [08:17<05:52, 1.81s/it]
|
| 504 |
59%|█████▊ | 275/469 [08:19<05:50, 1.81s/it]
|
| 505 |
59%|█████▉ | 276/469 [08:21<05:42, 1.77s/it]
|
| 506 |
59%|█████▉ | 277/469 [08:22<05:37, 1.76s/it]
|
| 507 |
59%|█████▉ | 278/469 [08:24<05:34, 1.75s/it]
|
| 508 |
59%|█████▉ | 279/469 [08:26<05:38, 1.78s/it]
|
| 509 |
60%|█████▉ | 280/469 [08:28<05:38, 1.79s/it]
|
| 510 |
60%|█████▉ | 281/469 [08:30<05:39, 1.81s/it]
|
| 511 |
60%|██████ | 282/469 [08:31<05:38, 1.81s/it]
|
| 512 |
60%|██████ | 283/469 [08:33<05:30, 1.78s/it]
|
| 513 |
61%|██████ | 284/469 [08:35<05:26, 1.77s/it]
|
| 514 |
61%|██████ | 285/469 [08:36<05:17, 1.73s/it]
|
| 515 |
61%|██████ | 286/469 [08:38<05:14, 1.72s/it]
|
| 516 |
61%|██████ | 287/469 [08:40<05:10, 1.71s/it]
|
| 517 |
61%|██████▏ | 288/469 [08:41<05:05, 1.69s/it]
|
| 518 |
62%|██████▏ | 289/469 [08:43<05:04, 1.69s/it]
|
| 519 |
62%|██████▏ | 290/469 [08:45<05:11, 1.74s/it]
|
| 520 |
62%|██████▏ | 291/469 [08:47<05:11, 1.75s/it]
|
| 521 |
62%|██████▏ | 292/469 [08:48<05:04, 1.72s/it]
|
| 522 |
62%|██████▏ | 293/469 [08:50<05:04, 1.73s/it]
|
| 523 |
63%|██████▎ | 294/469 [08:52<05:19, 1.83s/it]
|
| 524 |
63%|██████▎ | 295/469 [08:54<05:14, 1.81s/it]
|
| 525 |
63%|██████▎ | 296/469 [08:56<05:04, 1.76s/it]
|
| 526 |
63%|██████▎ | 297/469 [08:57<04:58, 1.74s/it]
|
| 527 |
64%|██████▎ | 298/469 [08:59<05:01, 1.76s/it]
|
| 528 |
64%|██████▍ | 299/469 [09:01<04:56, 1.74s/it]
|
| 529 |
64%|██████▍ | 300/469 [09:03<04:52, 1.73s/it]
|
| 530 |
64%|██████▍ | 301/469 [09:04<04:48, 1.72s/it]
|
| 531 |
64%|██████▍ | 302/469 [09:06<04:43, 1.70s/it]
|
| 532 |
65%|██████▍ | 303/469 [09:08<04:50, 1.75s/it]
|
| 533 |
65%|██████▍ | 304/469 [09:10<04:46, 1.74s/it]
|
| 534 |
65%|██████▌ | 305/469 [09:11<04:47, 1.75s/it]
|
| 535 |
65%|██████▌ | 306/469 [09:13<04:42, 1.73s/it]
|
| 536 |
65%|██████▌ | 307/469 [09:15<04:58, 1.84s/it]
|
| 537 |
66%|██████▌ | 308/469 [09:17<04:50, 1.81s/it]
|
| 538 |
66%|██████▌ | 309/469 [09:19<04:46, 1.79s/it]
|
| 539 |
66%|██████▌ | 310/469 [09:20<04:39, 1.76s/it]
|
| 540 |
66%|██████▋ | 311/469 [09:22<04:34, 1.74s/it]
|
| 541 |
67%|██████▋ | 312/469 [09:24<04:36, 1.76s/it]
|
| 542 |
67%|██████▋ | 313/469 [09:26<04:52, 1.88s/it]
|
| 543 |
67%|██████▋ | 314/469 [09:28<04:55, 1.91s/it]
|
| 544 |
67%|██████▋ | 315/469 [09:30<04:43, 1.84s/it]
|
| 545 |
67%|██████▋ | 316/469 [09:31<04:36, 1.81s/it]
|
| 546 |
68%|██████▊ | 317/469 [09:33<04:42, 1.86s/it]
|
| 547 |
68%|██████▊ | 318/469 [09:36<05:38, 2.24s/it]
|
| 548 |
68%|██████▊ | 319/469 [09:38<05:18, 2.13s/it]
|
| 549 |
68%|██████▊ | 320/469 [09:40<04:59, 2.01s/it]
|
| 550 |
68%|██████▊ | 321/469 [09:42<04:49, 1.95s/it]
|
| 551 |
69%|██████▊ | 322/469 [09:43<04:32, 1.86s/it]
|
| 552 |
69%|██████▉ | 323/469 [09:47<05:51, 2.41s/it]
|
| 553 |
69%|██████▉ | 324/469 [09:49<05:17, 2.19s/it]
|
| 554 |
69%|██████▉ | 325/469 [09:51<04:58, 2.08s/it]
|
| 555 |
70%|██████▉ | 326/469 [09:52<04:42, 1.97s/it]
|
| 556 |
70%|██████▉ | 327/469 [09:54<04:29, 1.89s/it]
|
| 557 |
70%|██████▉ | 328/469 [09:56<04:17, 1.82s/it]
|
| 558 |
70%|███████ | 329/469 [09:58<04:12, 1.80s/it]
|
| 559 |
70%|███████ | 330/469 [09:59<04:07, 1.78s/it]
|
| 560 |
71%|███████ | 331/469 [10:01<04:12, 1.83s/it]
|
| 561 |
71%|███████ | 332/469 [10:03<04:06, 1.80s/it]
|
| 562 |
71%|███████ | 333/469 [10:05<03:57, 1.75s/it]
|
| 563 |
71%|███████ | 334/469 [10:06<03:51, 1.71s/it]
|
| 564 |
71%|███████▏ | 335/469 [10:08<03:47, 1.70s/it]
|
| 565 |
72%|███████▏ | 336/469 [10:09<03:44, 1.69s/it]
|
| 566 |
72%|███████▏ | 337/469 [10:11<03:42, 1.68s/it]
|
| 567 |
72%|███████▏ | 338/469 [10:13<03:45, 1.72s/it]
|
| 568 |
72%|███████▏ | 339/469 [10:15<03:41, 1.70s/it]
|
| 569 |
72%|███████▏ | 340/469 [10:16<03:38, 1.70s/it]
|
| 570 |
73%|███████▎ | 341/469 [10:18<03:37, 1.70s/it]
|
| 571 |
73%|███████▎ | 342/469 [10:20<03:47, 1.79s/it]
|
| 572 |
73%|███████▎ | 343/469 [10:22<03:42, 1.77s/it]
|
| 573 |
73%|███████▎ | 344/469 [10:24<03:40, 1.77s/it]
|
| 574 |
74%|███████▎ | 345/469 [10:25<03:39, 1.77s/it]
|
| 575 |
74%|███████▍ | 346/469 [10:27<03:36, 1.76s/it]
|
| 576 |
74%|███████▍ | 347/469 [10:29<03:31, 1.73s/it]
|
| 577 |
74%|███████▍ | 348/469 [10:31<03:35, 1.78s/it]
|
| 578 |
74%|███████▍ | 349/469 [10:32<03:34, 1.78s/it]
|
| 579 |
75%|███████▍ | 350/469 [10:34<03:30, 1.76s/it]
|
| 580 |
75%|███████▍ | 351/469 [10:36<03:23, 1.72s/it]
|
| 581 |
75%|███████▌ | 352/469 [10:38<03:33, 1.82s/it]
|
| 582 |
75%|███████▌ | 353/469 [10:40<03:29, 1.81s/it]
|
| 583 |
75%|███████▌ | 354/469 [10:41<03:27, 1.80s/it]
|
| 584 |
76%|███████▌ | 355/469 [10:43<03:21, 1.77s/it]
|
| 585 |
76%|███████▌ | 356/469 [10:45<03:27, 1.83s/it]
|
| 586 |
76%|███████▌ | 357/469 [10:47<03:18, 1.77s/it]
|
| 587 |
76%|███████▋ | 358/469 [10:48<03:14, 1.75s/it]
|
| 588 |
77%|██��████▋ | 359/469 [10:50<03:09, 1.73s/it]
|
| 589 |
77%|███████▋ | 360/469 [10:52<03:07, 1.72s/it]
|
| 590 |
77%|███████▋ | 361/469 [10:53<03:06, 1.72s/it]
|
| 591 |
77%|███████▋ | 362/469 [10:55<03:07, 1.75s/it]
|
| 592 |
77%|███████▋ | 363/469 [10:57<03:10, 1.79s/it]
|
| 593 |
78%|███████▊ | 364/469 [10:59<03:06, 1.77s/it]
|
| 594 |
78%|███████▊ | 365/469 [11:01<03:00, 1.73s/it]
|
| 595 |
78%|███████▊ | 366/469 [11:02<02:55, 1.71s/it]
|
| 596 |
78%|███████▊ | 367/469 [11:04<02:53, 1.70s/it]
|
| 597 |
78%|███████▊ | 368/469 [11:06<02:53, 1.72s/it]
|
| 598 |
79%|███████▊ | 369/469 [11:07<02:51, 1.71s/it]
|
| 599 |
79%|███████▉ | 370/469 [11:09<02:48, 1.70s/it]
|
| 600 |
79%|███████▉ | 371/469 [11:11<02:48, 1.71s/it]
|
| 601 |
79%|███████▉ | 372/469 [11:13<02:51, 1.77s/it]
|
| 602 |
80%|███████▉ | 373/469 [11:14<02:46, 1.73s/it]
|
| 603 |
80%|███████▉ | 374/469 [11:16<02:45, 1.75s/it]
|
| 604 |
80%|███████▉ | 375/469 [11:18<02:43, 1.74s/it]
|
| 605 |
80%|████████ | 376/469 [11:20<02:42, 1.74s/it]
|
| 606 |
80%|████████ | 377/469 [11:21<02:39, 1.74s/it]
|
| 607 |
81%|████████ | 378/469 [11:23<02:36, 1.72s/it]
|
| 608 |
81%|████████ | 379/469 [11:25<02:35, 1.72s/it]
|
| 609 |
81%|████████ | 380/469 [11:26<02:32, 1.72s/it]
|
| 610 |
81%|████████ | 381/469 [11:28<02:34, 1.76s/it]
|
| 611 |
81%|████████▏ | 382/469 [11:30<02:43, 1.88s/it]
|
| 612 |
82%|████████▏ | 383/469 [11:32<02:35, 1.81s/it]
|
| 613 |
82%|████████▏ | 384/469 [11:34<02:29, 1.76s/it]
|
| 614 |
82%|████████▏ | 385/469 [11:36<02:30, 1.80s/it]
|
| 615 |
82%|████████▏ | 386/469 [11:37<02:26, 1.77s/it]
|
| 616 |
83%|████████▎ | 387/469 [11:39<02:26, 1.78s/it]
|
| 617 |
83%|████████▎ | 388/469 [11:41<02:23, 1.77s/it]
|
| 618 |
83%|████████▎ | 389/469 [11:42<02:18, 1.73s/it]
|
| 619 |
83%|████████▎ | 390/469 [11:44<02:18, 1.75s/it]
|
| 620 |
83%|████████▎ | 391/469 [11:46<02:16, 1.76s/it]
|
| 621 |
84%|████████▎ | 392/469 [11:49<02:33, 1.99s/it]
|
| 622 |
84%|████████▍ | 393/469 [11:50<02:24, 1.90s/it]
|
| 623 |
84%|████████▍ | 394/469 [11:52<02:18, 1.85s/it]
|
| 624 |
84%|████████▍ | 395/469 [11:54<02:12, 1.79s/it]
|
| 625 |
84%|████████▍ | 396/469 [11:58<03:02, 2.50s/it]
|
| 626 |
85%|████████▍ | 397/469 [12:00<02:44, 2.29s/it]
|
| 627 |
85%|████████▍ | 398/469 [12:01<02:31, 2.13s/it]
|
| 628 |
85%|████████▌ | 399/469 [12:03<02:21, 2.02s/it]
|
| 629 |
85%|████████▌ | 400/469 [12:05<02:15, 1.96s/it]
|
| 630 |
86%|████████▌ | 401/469 [12:07<02:10, 1.92s/it]
|
| 631 |
86%|████████▌ | 402/469 [12:08<02:02, 1.83s/it]
|
| 632 |
86%|████████▌ | 403/469 [12:10<01:57, 1.78s/it]
|
| 633 |
86%|████████▌ | 404/469 [12:12<01:53, 1.74s/it]
|
| 634 |
86%|████████▋ | 405/469 [12:13<01:49, 1.72s/it]
|
| 635 |
87%|████████▋ | 406/469 [12:15<01:50, 1.76s/it]
|
| 636 |
87%|████████▋ | 407/469 [12:17<01:48, 1.74s/it]
|
| 637 |
87%|████████▋ | 408/469 [12:19<01:44, 1.72s/it]
|
| 638 |
87%|████████▋ | 409/469 [12:20<01:41, 1.70s/it]
|
| 639 |
87%|████████▋ | 410/469 [12:22<01:45, 1.79s/it]
|
| 640 |
88%|████████▊ | 411/469 [12:24<01:40, 1.74s/it]
|
| 641 |
88%|████████▊ | 412/469 [12:26<01:41, 1.78s/it]
|
| 642 |
88%|████████▊ | 413/469 [12:28<01:39, 1.78s/it]
|
| 643 |
88%|████████▊ | 414/469 [12:29<01:35, 1.74s/it]
|
| 644 |
88%|████████▊ | 415/469 [12:31<01:38, 1.83s/it]
|
| 645 |
89%|████████▊ | 416/469 [12:33<01:36, 1.82s/it]
|
| 646 |
89%|████████▉ | 417/469 [12:35<01:31, 1.76s/it]
|
| 647 |
89%|████████▉ | 418/469 [12:36<01:31, 1.79s/it]
|
| 648 |
89%|████████▉ | 419/469 [12:38<01:26, 1.74s/it]
|
| 649 |
90%|████████▉ | 420/469 [12:40<01:25, 1.75s/it]
|
| 650 |
90%|████████▉ | 421/469 [12:42<01:23, 1.75s/it]
|
| 651 |
90%|████████▉ | 422/469 [12:43<01:20, 1.71s/it]
|
| 652 |
90%|█████████ | 423/469 [12:45<01:20, 1.74s/it]
|
| 653 |
90%|█████████ | 424/469 [12:47<01:18, 1.74s/it]
|
| 654 |
91%|█████████ | 425/469 [12:49<01:17, 1.75s/it]
|
| 655 |
91%|█████████ | 426/469 [12:50<01:15, 1.75s/it]
|
| 656 |
91%|█████████ | 427/469 [12:52<01:14, 1.78s/it]
|
| 657 |
91%|█████████▏| 428/469 [12:54<01:12, 1.77s/it]
|
| 658 |
91%|█████████▏| 429/469 [12:56<01:09, 1.75s/it]
|
| 659 |
92%|█████████▏| 430/469 [12:57<01:07, 1.72s/it]
|
| 660 |
92%|█████████▏| 431/469 [12:59<01:04, 1.70s/it]
|
| 661 |
92%|█████████▏| 432/469 [13:01<01:04, 1.74s/it]
|
| 662 |
92%|█████████▏| 433/469 [13:03<01:07, 1.88s/it]
|
| 663 |
93%|█████████▎| 434/469 [13:05<01:04, 1.86s/it]
|
| 664 |
93%|█████████▎| 435/469 [13:07<01:03, 1.87s/it]
|
| 665 |
93%|█████████▎| 436/469 [13:08<00:59, 1.81s/it]
|
| 666 |
93%|█████████▎| 437/469 [13:10<00:57, 1.81s/it]
|
| 667 |
93%|█████████▎| 438/469 [13:12<00:57, 1.86s/it]
|
| 668 |
94%|█████████▎| 439/469 [13:14<00:54, 1.82s/it]
|
| 669 |
94%|█████████▍| 440/469 [13:16<00:52, 1.82s/it]
|
| 670 |
94%|█████████▍| 441/469 [13:18<00:51, 1.84s/it]
|
| 671 |
94%|█████████▍| 442/469 [13:19<00:48, 1.80s/it]
|
| 672 |
94%|█████████▍| 443/469 [13:21<00:48, 1.86s/it]
|
| 673 |
95%|█████████▍| 444/469 [13:23<00:45, 1.80s/it]
|
| 674 |
95%|█████████▍| 445/469 [13:25<00:42, 1.79s/it]
|
| 675 |
95%|█████████▌| 446/469 [13:27<00:42, 1.85s/it]
|
| 676 |
95%|█████████▌| 447/469 [13:29<00:41, 1.86s/it]
|
| 677 |
96%|█████████▌| 448/469 [13:30<00:39, 1.88s/it]
|
| 678 |
96%|█████████▌| 449/469 [13:32<00:37, 1.86s/it]
|
| 679 |
96%|█████████▌| 450/469 [13:34<00:33, 1.79s/it]
|
| 680 |
96%|█████████▌| 451/469 [13:36<00:31, 1.74s/it]
|
| 681 |
96%|█████████▋| 452/469 [13:37<00:29, 1.73s/it]
|
| 682 |
97%|█████████▋| 453/469 [13:39<00:28, 1.76s/it]
|
| 683 |
97%|█████████▋| 454/469 [13:41<00:26, 1.74s/it]
|
| 684 |
97%|█████████▋| 455/469 [13:42<00:23, 1.70s/it]
|
| 685 |
97%|█████████▋| 456/469 [13:44<00:22, 1.71s/it]
|
| 686 |
97%|█████████▋| 457/469 [13:46<00:20, 1.71s/it]
|
| 687 |
98%|█████████▊| 458/469 [13:48<00:18, 1.71s/it]
|
| 688 |
98%|█████████▊| 459/469 [13:49<00:16, 1.69s/it]
|
| 689 |
98%|█████████▊| 460/469 [13:51<00:15, 1.68s/it]
|
| 690 |
98%|█████████▊| 461/469 [13:53<00:14, 1.76s/it]
|
| 691 |
99%|█████████▊| 462/469 [13:55<00:12, 1.78s/it]
|
| 692 |
99%|█████████▊| 463/469 [13:57<00:10, 1.83s/it]
|
| 693 |
99%|█████████▉| 464/469 [13:58<00:09, 1.86s/it]
|
| 694 |
99%|█████████▉| 465/469 [14:00<00:07, 1.81s/it]
|
| 695 |
99%|█████████▉| 466/469 [14:02<00:05, 1.78s/it]
|
| 696 |
+
computing/reading sample batch statistics...
|
| 697 |
+
Computing evaluations...
|
| 698 |
+
Inception Score: 38.328826904296875
|
| 699 |
+
FID: 21.82574123258769
|
| 700 |
+
sFID: 70.92829349483634
|
| 701 |
+
Precision: 0.6937
|
| 702 |
+
Recall: 0.3517815963698579
|
eval_rectified_noise_new_batch_2.log
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-23 16:28:27.563125: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
|
|
|
|
|
|
|
|
|
| 1 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 2 |
0%| | 1/211 [00:02<10:16, 2.94s/it]
|
| 3 |
1%| | 2/211 [00:04<08:15, 2.37s/it]
|
| 4 |
1%|▏ | 3/211 [00:06<07:38, 2.21s/it]
|
| 5 |
2%|▏ | 4/211 [00:08<07:20, 2.13s/it]
|
| 6 |
2%|▏ | 5/211 [00:11<07:25, 2.16s/it]
|
| 7 |
3%|▎ | 6/211 [00:12<06:59, 2.04s/it]
|
| 8 |
3%|▎ | 7/211 [00:15<07:15, 2.13s/it]
|
| 9 |
4%|▍ | 8/211 [00:17<06:55, 2.05s/it]
|
| 10 |
4%|▍ | 9/211 [00:19<06:53, 2.05s/it]
|
| 11 |
5%|▍ | 10/211 [00:21<06:37, 1.98s/it]
|
| 12 |
5%|▌ | 11/211 [00:22<06:18, 1.89s/it]
|
| 13 |
6%|▌ | 12/211 [00:24<06:13, 1.88s/it]
|
| 14 |
6%|▌ | 13/211 [00:26<06:08, 1.86s/it]
|
| 15 |
7%|▋ | 14/211 [00:28<06:02, 1.84s/it]
|
| 16 |
7%|▋ | 15/211 [00:29<05:57, 1.82s/it]
|
| 17 |
8%|▊ | 16/211 [00:31<05:55, 1.82s/it]
|
| 18 |
8%|▊ | 17/211 [00:33<06:00, 1.86s/it]
|
| 19 |
9%|▊ | 18/211 [00:35<06:01, 1.87s/it]
|
| 20 |
9%|▉ | 19/211 [00:37<05:59, 1.87s/it]
|
| 21 |
9%|▉ | 20/211 [00:39<05:51, 1.84s/it]
|
| 22 |
10%|▉ | 21/211 [00:41<05:45, 1.82s/it]
|
| 23 |
10%|█ | 22/211 [00:42<05:43, 1.82s/it]
|
| 24 |
11%|█ | 23/211 [00:44<05:41, 1.82s/it]
|
| 25 |
11%|█▏ | 24/211 [00:47<06:23, 2.05s/it]
|
| 26 |
12%|█▏ | 25/211 [00:50<07:07, 2.30s/it]
|
| 27 |
12%|█▏ | 26/211 [00:51<06:37, 2.15s/it]
|
| 28 |
13%|█▎ | 27/211 [00:53<06:14, 2.04s/it]
|
| 29 |
13%|█▎ | 28/211 [00:55<05:56, 1.95s/it]
|
| 30 |
14%|█▎ | 29/211 [00:57<05:56, 1.96s/it]
|
| 31 |
14%|█▍ | 30/211 [00:59<05:49, 1.93s/it]
|
| 32 |
15%|█▍ | 31/211 [01:01<05:49, 1.94s/it]
|
| 33 |
15%|█▌ | 32/211 [01:03<05:41, 1.91s/it]
|
| 34 |
16%|█▌ | 33/211 [01:05<05:42, 1.93s/it]
|
| 35 |
16%|█▌ | 34/211 [01:07<05:49, 1.97s/it]
|
| 36 |
17%|█▋ | 35/211 [01:09<06:11, 2.11s/it]
|
| 37 |
17%|█▋ | 36/211 [01:11<06:00, 2.06s/it]
|
| 38 |
18%|█▊ | 37/211 [01:13<05:46, 1.99s/it]
|
| 39 |
18%|█▊ | 38/211 [01:15<05:57, 2.07s/it]
|
| 40 |
18%|█▊ | 39/211 [01:17<05:44, 2.01s/it]
|
| 41 |
19%|█▉ | 40/211 [01:19<05:31, 1.94s/it]
|
| 42 |
19%|█▉ | 41/211 [01:21<05:24, 1.91s/it]
|
| 43 |
20%|█▉ | 42/211 [01:22<05:19, 1.89s/it]
|
| 44 |
20%|██ | 43/211 [01:24<05:20, 1.91s/it]
|
| 45 |
21%|██ | 44/211 [01:26<05:09, 1.86s/it]
|
| 46 |
21%|██▏ | 45/211 [01:29<05:50, 2.11s/it]
|
| 47 |
22%|██▏ | 46/211 [01:31<05:46, 2.10s/it]
|
| 48 |
22%|██▏ | 47/211 [01:33<05:33, 2.03s/it]
|
| 49 |
23%|██▎ | 48/211 [01:35<05:17, 1.95s/it]
|
| 50 |
23%|██▎ | 49/211 [01:36<05:14, 1.94s/it]
|
| 51 |
24%|██▎ | 50/211 [01:38<05:03, 1.88s/it]
|
| 52 |
24%|██▍ | 51/211 [01:40<05:19, 2.00s/it]
|
| 53 |
25%|██▍ | 52/211 [01:42<05:13, 1.97s/it]
|
| 54 |
25%|██▌ | 53/211 [01:45<05:19, 2.02s/it]
|
| 55 |
26%|██▌ | 54/211 [01:46<05:05, 1.95s/it]
|
| 56 |
26%|██▌ | 55/211 [01:48<04:55, 1.89s/it]
|
| 57 |
27%|██▋ | 56/211 [01:50<04:56, 1.92s/it]
|
| 58 |
27%|██▋ | 57/211 [01:52<04:48, 1.88s/it]
|
| 59 |
27%|██▋ | 58/211 [01:54<04:44, 1.86s/it]
|
| 60 |
28%|██▊ | 59/211 [01:56<04:47, 1.89s/it]
|
| 61 |
28%|██▊ | 60/211 [01:58<04:57, 1.97s/it]
|
| 62 |
29%|██▉ | 61/211 [02:00<04:46, 1.91s/it]
|
| 63 |
29%|██▉ | 62/211 [02:01<04:38, 1.87s/it]
|
| 64 |
30%|██▉ | 63/211 [02:04<05:05, 2.07s/it]
|
| 65 |
30%|███ | 64/211 [02:06<04:59, 2.04s/it]
|
| 66 |
31%|███ | 65/211 [02:08<05:05, 2.09s/it]
|
| 67 |
31%|███▏ | 66/211 [02:10<04:48, 1.99s/it]
|
| 68 |
32%|███▏ | 67/211 [02:12<04:37, 1.93s/it]
|
| 69 |
32%|███▏ | 68/211 [02:13<04:27, 1.87s/it]
|
| 70 |
33%|███▎ | 69/211 [02:15<04:25, 1.87s/it]
|
| 71 |
33%|███▎ | 70/211 [02:17<04:18, 1.83s/it]
|
| 72 |
34%|███▎ | 71/211 [02:19<04:15, 1.83s/it]
|
| 73 |
34%|███▍ | 72/211 [02:20<04:10, 1.80s/it]
|
| 74 |
35%|███▍ | 73/211 [02:22<04:08, 1.80s/it]
|
| 75 |
35%|███▌ | 74/211 [02:24<04:12, 1.84s/it]
|
| 76 |
36%|███▌ | 75/211 [02:26<04:10, 1.84s/it]
|
| 77 |
36%|███▌ | 76/211 [02:28<04:06, 1.83s/it]
|
| 78 |
36%|███▋ | 77/211 [02:30<04:11, 1.88s/it]
|
| 79 |
37%|███▋ | 78/211 [02:32<04:07, 1.86s/it]
|
| 80 |
37%|███▋ | 79/211 [02:34<04:08, 1.88s/it]
|
| 81 |
38%|███▊ | 80/211 [02:35<04:06, 1.88s/it]
|
| 82 |
38%|███▊ | 81/211 [02:37<04:11, 1.93s/it]
|
| 83 |
39%|███▉ | 82/211 [02:39<04:04, 1.89s/it]
|
| 84 |
39%|███▉ | 83/211 [02:41<04:06, 1.92s/it]
|
| 85 |
40%|███▉ | 84/211 [02:43<04:03, 1.92s/it]
|
| 86 |
40%|████ | 85/211 [02:45<03:55, 1.87s/it]
|
| 87 |
41%|████ | 86/211 [02:47<03:50, 1.84s/it]
|
| 88 |
41%|████ | 87/211 [02:48<03:45, 1.82s/it]
|
| 89 |
42%|████▏ | 88/211 [02:50<03:41, 1.80s/it]
|
| 90 |
42%|████▏ | 89/211 [02:52<03:36, 1.78s/it]
|
| 91 |
43%|████▎ | 90/211 [02:54<03:51, 1.91s/it]
|
| 92 |
43%|████▎ | 91/211 [02:56<03:45, 1.88s/it]
|
| 93 |
44%|████▎ | 92/211 [02:58<03:38, 1.84s/it]
|
| 94 |
44%|████▍ | 93/211 [03:00<03:37, 1.85s/it]
|
| 95 |
45%|████▍ | 94/211 [03:02<03:46, 1.94s/it]
|
| 96 |
45%|████▌ | 95/211 [03:03<03:36, 1.87s/it]
|
| 97 |
45%|████▌ | 96/211 [03:05<03:34, 1.87s/it]
|
| 98 |
46%|████▌ | 97/211 [03:08<03:58, 2.09s/it]
|
| 99 |
46%|████▋ | 98/211 [03:10<03:45, 1.99s/it]
|
| 100 |
47%|████▋ | 99/211 [03:11<03:34, 1.91s/it]
|
| 101 |
47%|████▋ | 100/211 [03:13<03:31, 1.91s/it]
|
| 102 |
48%|████▊ | 101/211 [03:15<03:28, 1.89s/it]
|
| 103 |
48%|████▊ | 102/211 [03:21<05:46, 3.18s/it]
|
| 104 |
49%|████▉ | 103/211 [03:23<04:58, 2.76s/it]
|
| 105 |
49%|████▉ | 104/211 [03:25<04:22, 2.45s/it]
|
| 106 |
50%|████▉ | 105/211 [03:27<04:03, 2.30s/it]
|
| 107 |
50%|█████ | 106/211 [03:29<03:45, 2.15s/it]
|
| 108 |
51%|█████ | 107/211 [03:31<03:35, 2.07s/it]
|
| 109 |
51%|█████ | 108/211 [03:32<03:29, 2.03s/it]
|
| 110 |
52%|█████▏ | 109/211 [03:34<03:21, 1.97s/it]
|
| 111 |
52%|█████▏ | 110/211 [03:37<03:36, 2.15s/it]
|
| 112 |
53%|█████▎ | 111/211 [03:39<03:23, 2.03s/it]
|
| 113 |
53%|█████▎ | 112/211 [03:41<03:17, 1.99s/it]
|
| 114 |
54%|█████▎ | 113/211 [03:42<03:10, 1.95s/it]
|
| 115 |
54%|█████▍ | 114/211 [03:44<03:03, 1.90s/it]
|
| 116 |
55%|█████▍ | 115/211 [03:46<03:02, 1.91s/it]
|
| 117 |
55%|█████▍ | 116/211 [03:48<02:59, 1.89s/it]
|
| 118 |
55%|█████▌ | 117/211 [03:50<03:03, 1.96s/it]
|
| 119 |
56%|█████▌ | 118/211 [03:52<02:56, 1.90s/it]
|
| 120 |
56%|█████▋ | 119/211 [03:54<02:51, 1.86s/it]
|
| 121 |
57%|█████▋ | 120/211 [03:55<02:49, 1.86s/it]
|
| 122 |
57%|█████▋ | 121/211 [03:57<02:52, 1.91s/it]
|
| 123 |
58%|█████▊ | 122/211 [03:59<02:47, 1.89s/it]
|
| 124 |
58%|█████▊ | 123/211 [04:01<02:51, 1.94s/it]
|
| 125 |
59%|█████▉ | 124/211 [04:03<02:47, 1.93s/it]
|
| 126 |
59%|█████▉ | 125/211 [04:05<02:40, 1.87s/it]
|
| 127 |
60%|█████▉ | 126/211 [04:07<02:46, 1.96s/it]
|
| 128 |
60%|██████ | 127/211 [04:09<02:43, 1.95s/it]
|
| 129 |
61%|██████ | 128/211 [04:11<02:42, 1.95s/it]
|
| 130 |
61%|██████ | 129/211 [04:13<02:47, 2.04s/it]
|
| 131 |
62%|██████▏ | 130/211 [04:15<02:41, 2.00s/it]
|
| 132 |
62%|██████▏ | 131/211 [04:17<02:39, 2.00s/it]
|
| 133 |
63%|██████▎ | 132/211 [04:19<02:36, 1.98s/it]
|
| 134 |
63%|██████▎ | 133/211 [04:21<02:31, 1.94s/it]
|
| 135 |
64%|██████▎ | 134/211 [04:23<02:26, 1.91s/it]
|
| 136 |
64%|██████▍ | 135/211 [04:25<02:21, 1.86s/it]
|
| 137 |
64%|██████▍ | 136/211 [04:26<02:20, 1.87s/it]
|
| 138 |
65%|██████▍ | 137/211 [04:29<02:23, 1.94s/it]
|
| 139 |
65%|██████▌ | 138/211 [04:30<02:18, 1.90s/it]
|
| 140 |
66%|██████▌ | 139/211 [04:32<02:13, 1.86s/it]
|
| 141 |
66%|██████▋ | 140/211 [04:34<02:10, 1.84s/it]
|
| 142 |
67%|██████▋ | 141/211 [04:36<02:07, 1.82s/it]
|
| 143 |
67%|██████▋ | 142/211 [04:38<02:06, 1.83s/it]
|
| 144 |
68%|██████▊ | 143/211 [04:39<02:06, 1.86s/it]
|
| 145 |
68%|██��███▊ | 144/211 [04:41<02:02, 1.83s/it]
|
| 146 |
69%|██████▊ | 145/211 [04:43<02:00, 1.82s/it]
|
| 147 |
69%|██████▉ | 146/211 [04:45<01:58, 1.82s/it]
|
| 148 |
70%|██████▉ | 147/211 [04:47<02:06, 1.98s/it]
|
| 149 |
70%|███████ | 148/211 [04:49<02:07, 2.02s/it]
|
| 150 |
71%|███████ | 149/211 [04:51<02:02, 1.97s/it]
|
| 151 |
71%|███████ | 150/211 [04:53<01:59, 1.96s/it]
|
| 152 |
72%|███████▏ | 151/211 [04:55<01:56, 1.95s/it]
|
| 153 |
72%|███████▏ | 152/211 [04:57<01:54, 1.94s/it]
|
| 154 |
73%|███████▎ | 153/211 [04:59<01:49, 1.88s/it]
|
| 155 |
73%|███████▎ | 154/211 [05:01<01:56, 2.04s/it]
|
| 156 |
73%|███████▎ | 155/211 [05:03<01:50, 1.97s/it]
|
| 157 |
74%|███████▍ | 156/211 [05:05<01:48, 1.97s/it]
|
| 158 |
74%|███████▍ | 157/211 [05:07<01:46, 1.96s/it]
|
| 159 |
75%|███████▍ | 158/211 [05:09<01:41, 1.92s/it]
|
| 160 |
75%|███████▌ | 159/211 [05:10<01:38, 1.90s/it]
|
| 161 |
76%|███████▌ | 160/211 [05:12<01:35, 1.88s/it]
|
| 162 |
76%|███████▋ | 161/211 [05:14<01:32, 1.85s/it]
|
| 163 |
77%|███████▋ | 162/211 [05:16<01:32, 1.89s/it]
|
| 164 |
77%|███████▋ | 163/211 [05:18<01:31, 1.90s/it]
|
| 165 |
78%|███████▊ | 164/211 [05:20<01:28, 1.88s/it]
|
| 166 |
78%|███████▊ | 165/211 [05:22<01:26, 1.88s/it]
|
| 167 |
79%|███████▊ | 166/211 [05:24<01:25, 1.89s/it]
|
| 168 |
79%|███████▉ | 167/211 [05:25<01:22, 1.88s/it]
|
| 169 |
80%|███████▉ | 168/211 [05:27<01:22, 1.92s/it]
|
| 170 |
80%|████████ | 169/211 [05:29<01:21, 1.94s/it]
|
| 171 |
81%|████████ | 170/211 [05:31<01:18, 1.92s/it]
|
| 172 |
81%|████████ | 171/211 [05:34<01:26, 2.15s/it]
|
| 173 |
82%|████████▏ | 172/211 [05:36<01:23, 2.14s/it]
|
| 174 |
82%|████████▏ | 173/211 [05:38<01:17, 2.05s/it]
|
| 175 |
82%|████████▏ | 174/211 [05:40<01:13, 1.99s/it]
|
| 176 |
83%|████████▎ | 175/211 [05:42<01:09, 1.93s/it]
|
| 177 |
83%|████████▎ | 176/211 [05:44<01:10, 2.02s/it]
|
| 178 |
84%|████████▍ | 177/211 [05:46<01:10, 2.06s/it]
|
| 179 |
84%|████████▍ | 178/211 [05:49<01:16, 2.33s/it]
|
| 180 |
85%|████████▍ | 179/211 [05:51<01:09, 2.16s/it]
|
| 181 |
85%|████████▌ | 180/211 [05:53<01:03, 2.05s/it]
|
| 182 |
86%|████████▌ | 181/211 [05:54<00:58, 1.96s/it]
|
| 183 |
86%|████████▋ | 182/211 [05:56<00:56, 1.95s/it]
|
| 184 |
87%|████████▋ | 183/211 [05:59<00:57, 2.05s/it]
|
| 185 |
87%|████████▋ | 184/211 [06:00<00:54, 2.01s/it]
|
| 186 |
88%|████████▊ | 185/211 [06:02<00:50, 1.95s/it]
|
| 187 |
88%|████████▊ | 186/211 [06:04<00:47, 1.89s/it]
|
| 188 |
89%|████████▊ | 187/211 [06:06<00:45, 1.88s/it]
|
| 189 |
89%|████████▉ | 188/211 [06:08<00:42, 1.86s/it]
|
| 190 |
90%|████████▉ | 189/211 [06:09<00:40, 1.83s/it]
|
| 191 |
90%|█████████ | 190/211 [06:11<00:38, 1.85s/it]
|
| 192 |
91%|█████████ | 191/211 [06:13<00:36, 1.85s/it]
|
| 193 |
91%|█████████ | 192/211 [06:15<00:34, 1.84s/it]
|
| 194 |
91%|█████████▏| 193/211 [06:17<00:34, 1.92s/it]
|
| 195 |
92%|█████████▏| 194/211 [06:19<00:32, 1.91s/it]
|
| 196 |
92%|█████████▏| 195/211 [06:21<00:31, 1.96s/it]
|
| 197 |
93%|█████████▎| 196/211 [06:23<00:28, 1.92s/it]
|
| 198 |
93%|█████████▎| 197/211 [06:25<00:27, 1.96s/it]
|
| 199 |
94%|█████████▍| 198/211 [06:27<00:25, 1.93s/it]
|
| 200 |
94%|█████████▍| 199/211 [06:29<00:22, 1.88s/it]
|
| 201 |
95%|█████████▍| 200/211 [06:30<00:20, 1.89s/it]
|
| 202 |
95%|█████████▌| 201/211 [06:32<00:18, 1.88s/it]
|
| 203 |
96%|█████████▌| 202/211 [06:34<00:16, 1.86s/it]
|
| 204 |
96%|█████████▌| 203/211 [06:36<00:15, 1.88s/it]
|
| 205 |
97%|█████████▋| 204/211 [06:38<00:13, 1.89s/it]
|
| 206 |
97%|█████████▋| 205/211 [06:40<00:11, 1.91s/it]
|
| 207 |
98%|█████████▊| 206/211 [06:42<00:10, 2.10s/it]
|
| 208 |
98%|█████████▊| 207/211 [06:44<00:08, 2.07s/it]
|
| 209 |
99%|█████████▊| 208/211 [06:46<00:05, 1.98s/it]
|
| 210 |
99%|█████████▉| 209/211 [06:49<00:04, 2.12s/it]
|
|
|
|
|
|
|
|
|
|
| 211 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 212 |
0%| | 1/469 [00:01<14:40, 1.88s/it]
|
| 213 |
0%| | 2/469 [00:04<16:02, 2.06s/it]
|
| 214 |
1%| | 3/469 [00:05<15:06, 1.95s/it]
|
| 215 |
1%| | 4/469 [00:07<14:44, 1.90s/it]
|
| 216 |
1%| | 5/469 [00:09<14:34, 1.88s/it]
|
| 217 |
1%|▏ | 6/469 [00:11<14:24, 1.87s/it]
|
| 218 |
1%|▏ | 7/469 [00:13<14:05, 1.83s/it]
|
| 219 |
2%|▏ | 8/469 [00:15<14:19, 1.87s/it]
|
| 220 |
2%|▏ | 9/469 [00:16<14:13, 1.85s/it]
|
| 221 |
2%|▏ | 10/469 [00:18<14:09, 1.85s/it]
|
| 222 |
2%|▏ | 11/469 [00:20<14:04, 1.84s/it]
|
| 223 |
3%|▎ | 12/469 [00:22<13:59, 1.84s/it]
|
| 224 |
3%|▎ | 13/469 [00:24<14:09, 1.86s/it]
|
| 225 |
3%|▎ | 14/469 [00:26<14:18, 1.89s/it]
|
| 226 |
3%|▎ | 15/469 [00:28<14:01, 1.85s/it]
|
| 227 |
3%|▎ | 16/469 [00:30<14:36, 1.94s/it]
|
| 228 |
4%|▎ | 17/469 [00:31<14:12, 1.89s/it]
|
| 229 |
4%|▍ | 18/469 [00:33<13:58, 1.86s/it]
|
| 230 |
4%|▍ | 19/469 [00:36<14:55, 1.99s/it]
|
| 231 |
4%|▍ | 20/469 [00:38<14:53, 1.99s/it]
|
| 232 |
4%|▍ | 21/469 [00:40<14:50, 1.99s/it]
|
| 233 |
5%|▍ | 22/469 [00:41<14:43, 1.98s/it]
|
| 234 |
5%|▍ | 23/469 [00:44<15:04, 2.03s/it]
|
| 235 |
5%|▌ | 24/469 [00:45<14:34, 1.97s/it]
|
| 236 |
5%|▌ | 25/469 [00:47<14:03, 1.90s/it]
|
| 237 |
6%|▌ | 26/469 [00:49<13:42, 1.86s/it]
|
| 238 |
6%|▌ | 27/469 [00:51<13:31, 1.84s/it]
|
| 239 |
6%|▌ | 28/469 [00:53<13:35, 1.85s/it]
|
| 240 |
6%|▌ | 29/469 [00:54<13:24, 1.83s/it]
|
| 241 |
6%|▋ | 30/469 [00:56<13:30, 1.85s/it]
|
| 242 |
7%|▋ | 31/469 [00:58<13:10, 1.80s/it]
|
| 243 |
7%|▋ | 32/469 [01:00<13:43, 1.88s/it]
|
| 244 |
7%|▋ | 33/469 [01:02<13:23, 1.84s/it]
|
| 245 |
7%|▋ | 34/469 [01:05<15:15, 2.10s/it]
|
| 246 |
7%|▋ | 35/469 [01:06<14:40, 2.03s/it]
|
| 247 |
8%|▊ | 36/469 [01:08<14:23, 1.99s/it]
|
| 248 |
8%|▊ | 37/469 [01:10<13:51, 1.93s/it]
|
| 249 |
8%|▊ | 38/469 [01:12<13:26, 1.87s/it]
|
| 250 |
8%|▊ | 39/469 [01:14<13:13, 1.85s/it]
|
| 251 |
9%|▊ | 40/469 [01:15<13:15, 1.85s/it]
|
| 252 |
9%|▊ | 41/469 [01:17<13:14, 1.86s/it]
|
| 253 |
9%|▉ | 42/469 [01:19<12:58, 1.82s/it]
|
| 254 |
9%|▉ | 43/469 [01:21<13:04, 1.84s/it]
|
| 255 |
9%|▉ | 44/469 [01:23<12:55, 1.82s/it]
|
| 256 |
10%|▉ | 45/469 [01:25<12:56, 1.83s/it]
|
| 257 |
10%|▉ | 46/469 [01:26<12:50, 1.82s/it]
|
| 258 |
10%|█ | 47/469 [01:28<12:47, 1.82s/it]
|
| 259 |
10%|█ | 48/469 [01:30<12:41, 1.81s/it]
|
| 260 |
10%|█ | 49/469 [01:32<12:52, 1.84s/it]
|
| 261 |
11%|█ | 50/469 [01:34<13:00, 1.86s/it]
|
| 262 |
11%|█ | 51/469 [01:36<12:46, 1.83s/it]
|
| 263 |
11%|█ | 52/469 [01:37<12:41, 1.83s/it]
|
| 264 |
11%|█▏ | 53/469 [01:39<12:40, 1.83s/it]
|
| 265 |
12%|█▏ | 54/469 [01:41<12:46, 1.85s/it]
|
| 266 |
12%|█▏ | 55/469 [01:44<15:33, 2.25s/it]
|
| 267 |
12%|█▏ | 56/469 [01:46<14:24, 2.09s/it]
|
| 268 |
12%|█▏ | 57/469 [01:48<13:41, 1.99s/it]
|
| 269 |
12%|█▏ | 58/469 [01:50<13:22, 1.95s/it]
|
| 270 |
13%|█▎ | 59/469 [01:51<12:54, 1.89s/it]
|
| 271 |
13%|█▎ | 60/469 [01:53<12:33, 1.84s/it]
|
| 272 |
13%|█▎ | 61/469 [01:55<12:23, 1.82s/it]
|
| 273 |
13%|█▎ | 62/469 [01:57<12:29, 1.84s/it]
|
| 274 |
13%|█▎ | 63/469 [01:59<12:42, 1.88s/it]
|
| 275 |
14%|█▎ | 64/469 [02:01<12:43, 1.89s/it]
|
| 276 |
14%|█▍ | 65/469 [02:02<12:29, 1.86s/it]
|
| 277 |
14%|█▍ | 66/469 [02:04<12:49, 1.91s/it]
|
| 278 |
14%|█▍ | 67/469 [02:06<12:50, 1.92s/it]
|
| 279 |
14%|█▍ | 68/469 [02:08<12:34, 1.88s/it]
|
| 280 |
15%|█▍ | 69/469 [02:10<12:22, 1.86s/it]
|
| 281 |
15%|█▍ | 70/469 [02:12<12:20, 1.86s/it]
|
| 282 |
15%|█▌ | 71/469 [02:14<12:10, 1.84s/it]
|
| 283 |
15%|█▌ | 72/469 [02:15<12:06, 1.83s/it]
|
| 284 |
16%|█▌ | 73/469 [02:17<12:00, 1.82s/it]
|
| 285 |
16%|█▌ | 74/469 [02:19<12:09, 1.85s/it]
|
| 286 |
16%|█▌ | 75/469 [02:21<12:18, 1.87s/it]
|
| 287 |
16%|█▌ | 76/469 [02:23<12:13, 1.87s/it]
|
| 288 |
16%|█▋ | 77/469 [02:25<12:16, 1.88s/it]
|
| 289 |
17%|█▋ | 78/469 [02:27<12:45, 1.96s/it]
|
| 290 |
17%|█▋ | 79/469 [02:29<12:17, 1.89s/it]
|
| 291 |
17%|█▋ | 80/469 [02:31<12:07, 1.87s/it]
|
| 292 |
17%|█▋ | 81/469 [02:32<11:53, 1.84s/it]
|
| 293 |
17%|█▋ | 82/469 [02:34<12:09, 1.89s/it]
|
| 294 |
18%|█▊ | 83/469 [02:36<12:01, 1.87s/it]
|
| 295 |
18%|█▊ | 84/469 [02:38<11:45, 1.83s/it]
|
| 296 |
18%|█▊ | 85/469 [02:40<12:13, 1.91s/it]
|
| 297 |
18%|█▊ | 86/469 [02:42<12:12, 1.91s/it]
|
| 298 |
19%|█▊ | 87/469 [02:44<11:53, 1.87s/it]
|
| 299 |
19%|█▉ | 88/469 [02:45<11:32, 1.82s/it]
|
| 300 |
19%|█▉ | 89/469 [02:47<11:28, 1.81s/it]
|
| 301 |
19%|█▉ | 90/469 [02:49<11:52, 1.88s/it]
|
| 302 |
19%|█▉ | 91/469 [02:51<11:34, 1.84s/it]
|
| 303 |
20%|█▉ | 92/469 [02:53<11:36, 1.85s/it]
|
| 304 |
20%|█▉ | 93/469 [02:55<11:36, 1.85s/it]
|
| 305 |
20%|██ | 94/469 [02:56<11:27, 1.83s/it]
|
| 306 |
20%|██ | 95/469 [02:58<11:24, 1.83s/it]
|
| 307 |
20%|██ | 96/469 [03:00<11:19, 1.82s/it]
|
| 308 |
21%|██ | 97/469 [03:02<11:25, 1.84s/it]
|
| 309 |
21%|██ | 98/469 [03:04<11:45, 1.90s/it]
|
| 310 |
21%|██ | 99/469 [03:06<11:45, 1.91s/it]
|
| 311 |
21%|██▏ | 100/469 [03:08<11:30, 1.87s/it]
|
| 312 |
22%|██▏ | 101/469 [03:10<12:35, 2.05s/it]
|
| 313 |
22%|██▏ | 102/469 [03:12<12:40, 2.07s/it]
|
| 314 |
22%|██▏ | 103/469 [03:14<12:10, 2.00s/it]
|
| 315 |
22%|██▏ | 104/469 [03:16<11:49, 1.94s/it]
|
| 316 |
22%|██▏ | 105/469 [03:18<11:38, 1.92s/it]
|
| 317 |
23%|██▎ | 106/469 [03:20<11:50, 1.96s/it]
|
| 318 |
23%|██▎ | 107/469 [03:22<11:29, 1.91s/it]
|
| 319 |
23%|██▎ | 108/469 [03:24<12:19, 2.05s/it]
|
| 320 |
23%|██▎ | 109/469 [03:26<11:50, 1.97s/it]
|
| 321 |
23%|██▎ | 110/469 [03:28<12:21, 2.06s/it]
|
| 322 |
24%|██▎ | 111/469 [03:30<12:00, 2.01s/it]
|
| 323 |
24%|██▍ | 112/469 [03:32<11:37, 1.95s/it]
|
| 324 |
24%|██▍ | 113/469 [03:34<11:13, 1.89s/it]
|
| 325 |
24%|██▍ | 114/469 [03:36<11:26, 1.93s/it]
|
| 326 |
25%|██▍ | 115/469 [03:37<11:06, 1.88s/it]
|
| 327 |
25%|██▍ | 116/469 [03:39<11:27, 1.95s/it]
|
| 328 |
25%|██▍ | 117/469 [03:42<12:08, 2.07s/it]
|
| 329 |
25%|██▌ | 118/469 [03:44<11:37, 1.99s/it]
|
| 330 |
25%|██▌ | 119/469 [03:45<11:14, 1.93s/it]
|
| 331 |
26%|██▌ | 120/469 [03:47<10:53, 1.87s/it]
|
| 332 |
26%|██▌ | 121/469 [03:49<10:39, 1.84s/it]
|
| 333 |
26%|██▌ | 122/469 [03:51<10:28, 1.81s/it]
|
| 334 |
26%|██▌ | 123/469 [03:52<10:23, 1.80s/it]
|
| 335 |
26%|██▋ | 124/469 [03:54<10:16, 1.79s/it]
|
| 336 |
27%|██▋ | 125/469 [03:56<10:26, 1.82s/it]
|
| 337 |
27%|██▋ | 126/469 [03:58<10:19, 1.81s/it]
|
| 338 |
27%|██▋ | 127/469 [04:00<10:05, 1.77s/it]
|
| 339 |
27%|██▋ | 128/469 [04:01<10:13, 1.80s/it]
|
| 340 |
28%|██▊ | 129/469 [04:04<10:47, 1.91s/it]
|
| 341 |
28%|██▊ | 130/469 [04:05<10:30, 1.86s/it]
|
| 342 |
28%|██▊ | 131/469 [04:08<11:02, 1.96s/it]
|
| 343 |
28%|██▊ | 132/469 [04:10<11:07, 1.98s/it]
|
| 344 |
28%|██▊ | 133/469 [04:11<11:00, 1.97s/it]
|
| 345 |
29%|██▊ | 134/469 [04:13<10:36, 1.90s/it]
|
| 346 |
29%|██▉ | 135/469 [04:15<10:33, 1.90s/it]
|
| 347 |
29%|██▉ | 136/469 [04:17<10:28, 1.89s/it]
|
| 348 |
29%|██▉ | 137/469 [04:23<16:46, 3.03s/it]
|
| 349 |
29%|██▉ | 138/469 [04:24<14:39, 2.66s/it]
|
| 350 |
30%|██▉ | 139/469 [04:27<13:43, 2.49s/it]
|
| 351 |
30%|██▉ | 140/469 [04:29<13:38, 2.49s/it]
|
| 352 |
30%|███ | 141/469 [04:31<12:21, 2.26s/it]
|
| 353 |
30%|███ | 142/469 [04:33<11:27, 2.10s/it]
|
| 354 |
30%|███ | 143/469 [04:34<10:47, 1.98s/it]
|
| 355 |
31%|███ | 144/469 [04:36<10:23, 1.92s/it]
|
| 356 |
31%|███ | 145/469 [04:38<10:18, 1.91s/it]
|
| 357 |
31%|███ | 146/469 [04:40<10:36, 1.97s/it]
|
| 358 |
31%|███▏ | 147/469 [04:42<10:39, 1.99s/it]
|
| 359 |
32%|███▏ | 148/469 [04:44<10:50, 2.03s/it]
|
| 360 |
32%|███▏ | 149/469 [04:46<10:30, 1.97s/it]
|
| 361 |
32%|███▏ | 150/469 [04:48<10:12, 1.92s/it]
|
| 362 |
32%|███▏ | 151/469 [04:50<09:55, 1.87s/it]
|
| 363 |
32%|███▏ | 152/469 [04:51<10:00, 1.89s/it]
|
| 364 |
33%|███▎ | 153/469 [04:53<10:00, 1.90s/it]
|
| 365 |
33%|███▎ | 154/469 [04:55<09:56, 1.89s/it]
|
| 366 |
33%|███▎ | 155/469 [04:57<09:55, 1.90s/it]
|
| 367 |
33%|███▎ | 156/469 [04:59<09:38, 1.85s/it]
|
| 368 |
33%|███▎ | 157/469 [05:01<09:29, 1.83s/it]
|
| 369 |
34%|███▎ | 158/469 [05:03<09:30, 1.84s/it]
|
| 370 |
34%|███▍ | 159/469 [05:04<09:18, 1.80s/it]
|
| 371 |
34%|███▍ | 160/469 [05:07<09:59, 1.94s/it]
|
| 372 |
34%|███▍ | 161/469 [05:08<09:52, 1.93s/it]
|
| 373 |
35%|███▍ | 162/469 [05:10<09:53, 1.93s/it]
|
| 374 |
35%|███▍ | 163/469 [05:12<09:41, 1.90s/it]
|
| 375 |
35%|███▍ | 164/469 [05:14<09:34, 1.88s/it]
|
| 376 |
35%|███▌ | 165/469 [05:16<09:32, 1.88s/it]
|
| 377 |
35%|███▌ | 166/469 [05:18<09:23, 1.86s/it]
|
| 378 |
36%|███▌ | 167/469 [05:20<09:39, 1.92s/it]
|
| 379 |
36%|███▌ | 168/469 [05:22<09:30, 1.90s/it]
|
| 380 |
36%|███▌ | 169/469 [05:23<09:21, 1.87s/it]
|
| 381 |
36%|███▌ | 170/469 [05:25<09:10, 1.84s/it]
|
| 382 |
36%|███▋ | 171/469 [05:27<09:20, 1.88s/it]
|
| 383 |
37%|███▋ | 172/469 [05:29<09:07, 1.84s/it]
|
| 384 |
37%|███▋ | 173/469 [05:31<09:01, 1.83s/it]
|
| 385 |
37%|███▋ | 174/469 [05:33<09:00, 1.83s/it]
|
| 386 |
37%|███▋ | 175/469 [05:34<08:49, 1.80s/it]
|
| 387 |
38%|███▊ | 176/469 [05:36<09:09, 1.88s/it]
|
| 388 |
38%|███▊ | 177/469 [05:38<09:08, 1.88s/it]
|
| 389 |
38%|███▊ | 178/469 [05:40<09:00, 1.86s/it]
|
| 390 |
38%|███▊ | 179/469 [05:42<08:51, 1.83s/it]
|
| 391 |
38%|███▊ | 180/469 [05:44<08:51, 1.84s/it]
|
| 392 |
39%|███▊ | 181/469 [05:46<10:03, 2.10s/it]
|
| 393 |
39%|███▉ | 182/469 [05:48<09:44, 2.04s/it]
|
| 394 |
39%|███▉ | 183/469 [05:50<09:21, 1.96s/it]
|
| 395 |
39%|███▉ | 184/469 [05:52<09:23, 1.98s/it]
|
| 396 |
39%|███▉ | 185/469 [05:54<09:03, 1.92s/it]
|
| 397 |
40%|███▉ | 186/469 [05:56<08:53, 1.88s/it]
|
| 398 |
40%|███▉ | 187/469 [05:57<08:40, 1.85s/it]
|
| 399 |
40%|████ | 188/469 [05:59<08:58, 1.92s/it]
|
| 400 |
40%|████ | 189/469 [06:01<08:50, 1.89s/it]
|
| 401 |
41%|████ | 190/469 [06:03<08:55, 1.92s/it]
|
| 402 |
41%|████ | 191/469 [06:05<08:49, 1.91s/it]
|
| 403 |
41%|████ | 192/469 [06:07<08:48, 1.91s/it]
|
| 404 |
41%|████ | 193/469 [06:09<08:43, 1.90s/it]
|
| 405 |
41%|████▏ | 194/469 [06:11<08:31, 1.86s/it]
|
| 406 |
42%|████▏ | 195/469 [06:12<08:18, 1.82s/it]
|
| 407 |
42%|████▏ | 196/469 [06:14<08:06, 1.78s/it]
|
| 408 |
42%|████▏ | 197/469 [06:16<08:05, 1.79s/it]
|
| 409 |
42%|████▏ | 198/469 [06:18<08:14, 1.83s/it]
|
| 410 |
42%|████▏ | 199/469 [06:20<08:11, 1.82s/it]
|
| 411 |
43%|████▎ | 200/469 [06:22<08:36, 1.92s/it]
|
| 412 |
43%|████▎ | 201/469 [06:24<08:25, 1.89s/it]
|
| 413 |
43%|████▎ | 202/469 [06:26<08:28, 1.91s/it]
|
| 414 |
43%|████▎ | 203/469 [06:27<08:24, 1.90s/it]
|
| 415 |
43%|████▎ | 204/469 [06:29<08:31, 1.93s/it]
|
| 416 |
44%|████▎ | 205/469 [06:31<08:20, 1.89s/it]
|
| 417 |
44%|████▍ | 206/469 [06:33<08:09, 1.86s/it]
|
| 418 |
44%|████▍ | 207/469 [06:35<08:13, 1.89s/it]
|
| 419 |
44%|████▍ | 208/469 [06:37<08:21, 1.92s/it]
|
| 420 |
45%|████▍ | 209/469 [06:39<08:55, 2.06s/it]
|
| 421 |
45%|████▍ | 210/469 [06:41<08:32, 1.98s/it]
|
| 422 |
45%|████▍ | 211/469 [06:43<08:21, 1.94s/it]
|
| 423 |
45%|████▌ | 212/469 [06:45<08:19, 1.94s/it]
|
| 424 |
45%|████▌ | 213/469 [06:47<08:13, 1.93s/it]
|
| 425 |
46%|████▌ | 214/469 [06:49<08:23, 1.98s/it]
|
| 426 |
46%|████▌ | 215/469 [06:51<08:15, 1.95s/it]
|
| 427 |
46%|████▌ | 216/469 [06:53<08:02, 1.91s/it]
|
| 428 |
46%|████▋ | 217/469 [06:55<07:57, 1.90s/it]
|
| 429 |
46%|████▋ | 218/469 [06:56<07:51, 1.88s/it]
|
| 430 |
47%|████▋ | 219/469 [06:58<07:40, 1.84s/it]
|
| 431 |
47%|████▋ | 220/469 [07:00<07:52, 1.90s/it]
|
| 432 |
47%|████▋ | 221/469 [07:02<07:44, 1.87s/it]
|
| 433 |
47%|████▋ | 222/469 [07:04<08:09, 1.98s/it]
|
| 434 |
48%|████▊ | 223/469 [07:07<09:19, 2.27s/it]
|
| 435 |
48%|████▊ | 224/469 [07:09<08:42, 2.13s/it]
|
| 436 |
48%|████▊ | 225/469 [07:11<08:22, 2.06s/it]
|
| 437 |
48%|████▊ | 226/469 [07:13<08:01, 1.98s/it]
|
| 438 |
48%|████▊ | 227/469 [07:15<07:51, 1.95s/it]
|
| 439 |
49%|████▊ | 228/469 [07:16<07:43, 1.92s/it]
|
| 440 |
49%|████▉ | 229/469 [07:18<07:35, 1.90s/it]
|
| 441 |
49%|████▉ | 230/469 [07:20<07:28, 1.88s/it]
|
| 442 |
49%|████▉ | 231/469 [07:22<07:44, 1.95s/it]
|
| 443 |
49%|████▉ | 232/469 [07:24<07:32, 1.91s/it]
|
| 444 |
50%|████▉ | 233/469 [07:26<07:24, 1.88s/it]
|
| 445 |
50%|████▉ | 234/469 [07:28<07:26, 1.90s/it]
|
| 446 |
50%|█████ | 235/469 [07:30<07:18, 1.88s/it]
|
| 447 |
50%|█████ | 236/469 [07:32<07:21, 1.89s/it]
|
| 448 |
51%|█████ | 237/469 [07:33<07:15, 1.88s/it]
|
| 449 |
51%|█████ | 238/469 [07:35<07:31, 1.95s/it]
|
| 450 |
51%|█████ | 239/469 [07:37<07:25, 1.94s/it]
|
| 451 |
51%|█████ | 240/469 [07:39<07:26, 1.95s/it]
|
| 452 |
51%|█████▏ | 241/469 [07:42<08:39, 2.28s/it]
|
| 453 |
52%|█████▏ | 242/469 [07:44<08:17, 2.19s/it]
|
| 454 |
52%|█████▏ | 243/469 [07:47<08:19, 2.21s/it]
|
| 455 |
52%|█████▏ | 244/469 [07:49<07:57, 2.12s/it]
|
| 456 |
52%|█████▏ | 245/469 [07:50<07:35, 2.03s/it]
|
| 457 |
52%|█████▏ | 246/469 [07:52<07:26, 2.00s/it]
|
| 458 |
53%|█████▎ | 247/469 [07:54<07:19, 1.98s/it]
|
| 459 |
53%|█████▎ | 248/469 [07:56<07:11, 1.95s/it]
|
| 460 |
53%|█████▎ | 249/469 [07:58<07:15, 1.98s/it]
|
| 461 |
53%|█████▎ | 250/469 [08:00<06:58, 1.91s/it]
|
| 462 |
54%|█████▎ | 251/469 [08:02<06:51, 1.89s/it]
|
| 463 |
54%|█████▎ | 252/469 [08:04<06:42, 1.85s/it]
|
| 464 |
54%|█████▍ | 253/469 [08:05<06:33, 1.82s/it]
|
| 465 |
54%|█████▍ | 254/469 [08:07<06:29, 1.81s/it]
|
| 466 |
54%|█████▍ | 255/469 [08:09<06:32, 1.83s/it]
|
| 467 |
55%|█████▍ | 256/469 [08:11<06:24, 1.80s/it]
|
| 468 |
55%|█████▍ | 257/469 [08:12<06:18, 1.79s/it]
|
| 469 |
55%|█████▌ | 258/469 [08:14<06:25, 1.83s/it]
|
| 470 |
55%|█████▌ | 259/469 [08:16<06:25, 1.83s/it]
|
| 471 |
55%|█████▌ | 260/469 [08:18<06:19, 1.82s/it]
|
| 472 |
56%|█████▌ | 261/469 [08:20<06:42, 1.94s/it]
|
| 473 |
56%|█████▌ | 262/469 [08:22<06:32, 1.90s/it]
|
| 474 |
56%|█████▌ | 263/469 [08:24<06:25, 1.87s/it]
|
| 475 |
56%|█████▋ | 264/469 [08:26<06:17, 1.84s/it]
|
| 476 |
57%|█████▋ | 265/469 [08:28<06:20, 1.87s/it]
|
| 477 |
57%|█████▋ | 266/469 [08:29<06:17, 1.86s/it]
|
| 478 |
57%|█████▋ | 267/469 [08:31<06:16, 1.86s/it]
|
| 479 |
57%|█████▋ | 268/469 [08:33<06:12, 1.86s/it]
|
| 480 |
57%|█████▋ | 269/469 [08:35<06:41, 2.01s/it]
|
| 481 |
58%|█████▊ | 270/469 [08:37<06:24, 1.93s/it]
|
| 482 |
58%|█████▊ | 271/469 [08:39<06:10, 1.87s/it]
|
| 483 |
58%|█████▊ | 272/469 [08:41<06:09, 1.87s/it]
|
| 484 |
58%|█████▊ | 273/469 [08:43<06:24, 1.96s/it]
|
| 485 |
58%|█████▊ | 274/469 [08:45<06:11, 1.90s/it]
|
| 486 |
59%|█████▊ | 275/469 [08:47<06:10, 1.91s/it]
|
| 487 |
59%|█████▉ | 276/469 [08:48<06:03, 1.88s/it]
|
| 488 |
59%|█████▉ | 277/469 [08:50<06:01, 1.88s/it]
|
| 489 |
59%|█████▉ | 278/469 [08:52<06:08, 1.93s/it]
|
| 490 |
59%|█████▉ | 279/469 [08:54<05:57, 1.88s/it]
|
| 491 |
60%|█████▉ | 280/469 [08:56<06:11, 1.97s/it]
|
| 492 |
60%|█████▉ | 281/469 [08:58<06:07, 1.95s/it]
|
| 493 |
60%|██████ | 282/469 [09:00<05:57, 1.91s/it]
|
| 494 |
60%|██████ | 283/469 [09:02<05:48, 1.87s/it]
|
| 495 |
61%|██████ | 284/469 [09:04<05:44, 1.86s/it]
|
| 496 |
61%|██████ | 285/469 [09:06<05:45, 1.88s/it]
|
| 497 |
61%|██████ | 286/469 [09:08<06:11, 2.03s/it]
|
| 498 |
61%|██████ | 287/469 [09:10<06:16, 2.07s/it]
|
| 499 |
61%|██████▏ | 288/469 [09:12<05:59, 1.99s/it]
|
| 500 |
62%|██████▏ | 289/469 [09:14<05:50, 1.95s/it]
|
| 501 |
62%|██████▏ | 290/469 [09:16<05:39, 1.90s/it]
|
| 502 |
62%|██████▏ | 291/469 [09:17<05:28, 1.85s/it]
|
| 503 |
62%|██████▏ | 292/469 [09:19<05:25, 1.84s/it]
|
| 504 |
62%|██████▏ | 293/469 [09:21<05:23, 1.84s/it]
|
| 505 |
63%|██████▎ | 294/469 [09:23<05:31, 1.89s/it]
|
| 506 |
63%|██████▎ | 295/469 [09:25<05:20, 1.84s/it]
|
| 507 |
63%|██████▎ | 296/469 [09:26<05:13, 1.81s/it]
|
| 508 |
63%|██████▎ | 297/469 [09:28<05:08, 1.80s/it]
|
| 509 |
64%|██████▎ | 298/469 [09:30<05:04, 1.78s/it]
|
| 510 |
64%|██████▍ | 299/469 [09:32<05:05, 1.80s/it]
|
| 511 |
64%|██████▍ | 300/469 [09:34<05:08, 1.82s/it]
|
| 512 |
64%|██████▍ | 301/469 [09:36<05:07, 1.83s/it]
|
| 513 |
64%|██████▍ | 302/469 [09:37<05:04, 1.82s/it]
|
| 514 |
65%|██████▍ | 303/469 [09:39<05:06, 1.84s/it]
|
| 515 |
65%|██████▍ | 304/469 [09:41<05:06, 1.86s/it]
|
| 516 |
65%|██████▌ | 305/469 [09:43<05:02, 1.84s/it]
|
| 517 |
65%|██████▌ | 306/469 [09:45<05:01, 1.85s/it]
|
| 518 |
65%|██████▌ | 307/469 [09:47<05:00, 1.85s/it]
|
| 519 |
66%|██████▌ | 308/469 [09:49<05:02, 1.88s/it]
|
| 520 |
66%|██████▌ | 309/469 [09:50<05:00, 1.88s/it]
|
| 521 |
66%|██████▌ | 310/469 [09:52<04:53, 1.85s/it]
|
| 522 |
66%|██████▋ | 311/469 [09:54<04:54, 1.86s/it]
|
| 523 |
67%|██████▋ | 312/469 [09:56<04:58, 1.90s/it]
|
| 524 |
67%|██████▋ | 313/469 [09:58<04:50, 1.86s/it]
|
| 525 |
67%|██████▋ | 314/469 [10:00<04:51, 1.88s/it]
|
| 526 |
67%|██████▋ | 315/469 [10:02<04:52, 1.90s/it]
|
| 527 |
67%|██████▋ | 316/469 [10:04<04:53, 1.92s/it]
|
| 528 |
68%|██████▊ | 317/469 [10:06<04:45, 1.88s/it]
|
| 529 |
68%|██████▊ | 318/469 [10:08<04:57, 1.97s/it]
|
| 530 |
68%|██████▊ | 319/469 [10:09<04:46, 1.91s/it]
|
| 531 |
68%|██████▊ | 320/469 [10:11<04:46, 1.92s/it]
|
| 532 |
68%|██████▊ | 321/469 [10:14<05:08, 2.09s/it]
|
| 533 |
69%|██████▊ | 322/469 [10:16<05:13, 2.14s/it]
|
| 534 |
69%|██████▉ | 323/469 [10:18<05:02, 2.07s/it]
|
| 535 |
69%|██████▉ | 324/469 [10:20<04:52, 2.02s/it]
|
| 536 |
69%|██████▉ | 325/469 [10:22<04:42, 1.96s/it]
|
| 537 |
70%|██████▉ | 326/469 [10:24<04:34, 1.92s/it]
|
| 538 |
70%|██████▉ | 327/469 [10:26<04:41, 1.98s/it]
|
| 539 |
70%|██████▉ | 328/469 [10:28<04:31, 1.93s/it]
|
| 540 |
70%|███████ | 329/469 [10:29<04:24, 1.89s/it]
|
| 541 |
70%|███████ | 330/469 [10:31<04:19, 1.87s/it]
|
| 542 |
71%|███████ | 331/469 [10:33<04:13, 1.84s/it]
|
| 543 |
71%|███████ | 332/469 [10:35<04:24, 1.93s/it]
|
| 544 |
71%|███████ | 333/469 [10:37<04:16, 1.89s/it]
|
| 545 |
71%|███████ | 334/469 [10:42<06:23, 2.84s/it]
|
| 546 |
71%|███████▏ | 335/469 [10:44<05:35, 2.51s/it]
|
| 547 |
72%|███████▏ | 336/469 [10:45<05:04, 2.29s/it]
|
| 548 |
72%|███████▏ | 337/469 [10:47<04:42, 2.14s/it]
|
| 549 |
72%|███████▏ | 338/469 [10:49<04:41, 2.15s/it]
|
| 550 |
72%|███████▏ | 339/469 [10:51<04:33, 2.11s/it]
|
| 551 |
72%|███████▏ | 340/469 [10:53<04:21, 2.03s/it]
|
| 552 |
73%|███████▎ | 341/469 [10:55<04:17, 2.01s/it]
|
| 553 |
73%|███████▎ | 342/469 [10:57<04:19, 2.04s/it]
|
| 554 |
73%|███████▎ | 343/469 [10:59<04:18, 2.05s/it]
|
| 555 |
73%|███████▎ | 344/469 [11:01<04:04, 1.96s/it]
|
| 556 |
74%|███████▎ | 345/469 [11:03<04:00, 1.94s/it]
|
| 557 |
74%|███████▍ | 346/469 [11:05<03:52, 1.89s/it]
|
| 558 |
74%|███████▍ | 347/469 [11:07<03:45, 1.85s/it]
|
| 559 |
74%|███████▍ | 348/469 [11:08<03:39, 1.82s/it]
|
| 560 |
74%|███████▍ | 349/469 [11:10<03:38, 1.82s/it]
|
| 561 |
75%|███████▍ | 350/469 [11:12<03:42, 1.87s/it]
|
| 562 |
75%|███████▍ | 351/469 [11:14<03:44, 1.90s/it]
|
| 563 |
75%|███████▌ | 352/469 [11:16<03:40, 1.88s/it]
|
| 564 |
75%|███████▌ | 353/469 [11:25<07:38, 3.95s/it]
|
| 565 |
75%|███████▌ | 354/469 [11:27<06:26, 3.36s/it]
|
| 566 |
76%|███████▌ | 355/469 [11:29<05:30, 2.90s/it]
|
| 567 |
76%|███████▌ | 356/469 [11:31<05:08, 2.73s/it]
|
| 568 |
76%|███████▌ | 357/469 [11:33<04:45, 2.55s/it]
|
| 569 |
76%|███████▋ | 358/469 [11:35<04:23, 2.37s/it]
|
| 570 |
77%|██��████▋ | 359/469 [11:37<04:10, 2.28s/it]
|
| 571 |
77%|███████▋ | 360/469 [11:39<03:51, 2.12s/it]
|
| 572 |
77%|███████▋ | 361/469 [11:41<03:37, 2.02s/it]
|
| 573 |
77%|███████▋ | 362/469 [12:00<13:03, 7.32s/it]
|
| 574 |
77%|███████▋ | 363/469 [12:02<10:03, 5.69s/it]
|
| 575 |
78%|███████▊ | 364/469 [12:04<07:55, 4.53s/it]
|
| 576 |
78%|███████▊ | 365/469 [12:06<06:27, 3.73s/it]
|
| 577 |
78%|███████▊ | 366/469 [12:08<05:26, 3.17s/it]
|
| 578 |
78%|███████▊ | 367/469 [12:09<04:40, 2.75s/it]
|
| 579 |
78%|███████▊ | 368/469 [12:12<04:21, 2.59s/it]
|
| 580 |
79%|███████▊ | 369/469 [12:13<03:55, 2.36s/it]
|
| 581 |
79%|███████▉ | 370/469 [12:15<03:43, 2.25s/it]
|
| 582 |
79%|███████▉ | 371/469 [12:17<03:28, 2.13s/it]
|
| 583 |
79%|███████▉ | 372/469 [12:19<03:21, 2.07s/it]
|
| 584 |
80%|███████▉ | 373/469 [12:21<03:10, 1.99s/it]
|
| 585 |
80%|███████▉ | 374/469 [12:23<03:02, 1.92s/it]
|
| 586 |
80%|███████▉ | 375/469 [12:25<02:55, 1.86s/it]
|
| 587 |
80%|████████ | 376/469 [12:26<02:51, 1.84s/it]
|
| 588 |
80%|████████ | 377/469 [12:28<02:51, 1.86s/it]
|
| 589 |
81%|████████ | 378/469 [12:30<02:47, 1.84s/it]
|
| 590 |
81%|████████ | 379/469 [12:32<02:43, 1.81s/it]
|
| 591 |
81%|████████ | 380/469 [12:34<02:48, 1.89s/it]
|
| 592 |
81%|████████ | 381/469 [12:36<02:45, 1.88s/it]
|
| 593 |
81%|████████▏ | 382/469 [12:38<02:43, 1.88s/it]
|
| 594 |
82%|████████▏ | 383/469 [12:39<02:38, 1.85s/it]
|
| 595 |
82%|████████▏ | 384/469 [12:41<02:36, 1.84s/it]
|
| 596 |
82%|████████▏ | 385/469 [12:43<02:37, 1.88s/it]
|
| 597 |
82%|████████▏ | 386/469 [12:45<02:35, 1.87s/it]
|
| 598 |
83%|████████▎ | 387/469 [12:47<02:45, 2.02s/it]
|
| 599 |
83%|████████▎ | 388/469 [12:49<02:37, 1.94s/it]
|
| 600 |
83%|████████▎ | 389/469 [12:51<02:30, 1.88s/it]
|
| 601 |
83%|████████▎ | 390/469 [12:53<02:32, 1.93s/it]
|
| 602 |
83%|████████▎ | 391/469 [12:55<02:29, 1.92s/it]
|
| 603 |
84%|████████▎ | 392/469 [12:57<02:23, 1.87s/it]
|
| 604 |
84%|████████▍ | 393/469 [13:00<03:00, 2.38s/it]
|
| 605 |
84%|████████▍ | 394/469 [13:02<02:47, 2.23s/it]
|
| 606 |
84%|████████▍ | 395/469 [13:04<02:36, 2.12s/it]
|
| 607 |
84%|████████▍ | 396/469 [13:06<02:27, 2.02s/it]
|
| 608 |
85%|████████▍ | 397/469 [13:08<02:24, 2.00s/it]
|
| 609 |
85%|████████▍ | 398/469 [13:09<02:19, 1.96s/it]
|
| 610 |
85%|████████▌ | 399/469 [13:11<02:14, 1.92s/it]
|
| 611 |
85%|████████▌ | 400/469 [13:13<02:08, 1.86s/it]
|
| 612 |
86%|████████▌ | 401/469 [13:15<02:05, 1.85s/it]
|
| 613 |
86%|████████▌ | 402/469 [13:17<02:02, 1.84s/it]
|
| 614 |
86%|████████▌ | 403/469 [13:18<01:59, 1.81s/it]
|
| 615 |
86%|████████▌ | 404/469 [13:20<02:02, 1.89s/it]
|
| 616 |
86%|████████▋ | 405/469 [13:22<01:59, 1.87s/it]
|
| 617 |
87%|████████▋ | 406/469 [13:24<01:56, 1.85s/it]
|
| 618 |
87%|████████▋ | 407/469 [13:26<01:52, 1.81s/it]
|
| 619 |
87%|████████▋ | 408/469 [13:28<01:50, 1.81s/it]
|
| 620 |
87%|████████▋ | 409/469 [13:30<01:50, 1.84s/it]
|
| 621 |
87%|████████▋ | 410/469 [13:32<01:55, 1.95s/it]
|
| 622 |
88%|████████▊ | 411/469 [13:33<01:49, 1.89s/it]
|
| 623 |
88%|████████▊ | 412/469 [13:35<01:48, 1.90s/it]
|
| 624 |
88%|████████▊ | 413/469 [13:37<01:45, 1.89s/it]
|
| 625 |
88%|████████▊ | 414/469 [13:39<01:42, 1.86s/it]
|
| 626 |
88%|████████▊ | 415/469 [13:41<01:38, 1.82s/it]
|
| 627 |
89%|████████▊ | 416/469 [13:43<01:39, 1.87s/it]
|
| 628 |
89%|████████▉ | 417/469 [13:45<01:38, 1.90s/it]
|
| 629 |
89%|████████▉ | 418/469 [13:50<02:30, 2.96s/it]
|
| 630 |
89%|████████▉ | 419/469 [13:52<02:12, 2.65s/it]
|
| 631 |
90%|████████▉ | 420/469 [13:54<02:01, 2.47s/it]
|
| 632 |
90%|████████▉ | 421/469 [13:56<01:48, 2.26s/it]
|
| 633 |
90%|████████▉ | 422/469 [13:58<01:39, 2.12s/it]
|
| 634 |
90%|█████████ | 423/469 [14:00<01:33, 2.02s/it]
|
| 635 |
90%|█████████ | 424/469 [14:01<01:29, 2.00s/it]
|
| 636 |
91%|█████████ | 425/469 [14:03<01:24, 1.93s/it]
|
| 637 |
91%|█████████ | 426/469 [14:05<01:22, 1.93s/it]
|
| 638 |
91%|█████████ | 427/469 [14:07<01:19, 1.90s/it]
|
| 639 |
91%|█████████▏| 428/469 [14:10<01:30, 2.21s/it]
|
| 640 |
91%|█████████▏| 429/469 [14:12<01:23, 2.09s/it]
|
| 641 |
92%|█████████▏| 430/469 [14:13<01:17, 2.00s/it]
|
| 642 |
92%|█████████▏| 431/469 [14:16<01:17, 2.05s/it]
|
| 643 |
92%|█████████▏| 432/469 [14:17<01:12, 1.96s/it]
|
| 644 |
92%|█████████▏| 433/469 [14:19<01:09, 1.93s/it]
|
| 645 |
93%|█████████▎| 434/469 [14:21<01:06, 1.89s/it]
|
| 646 |
93%|█████████▎| 435/469 [14:23<01:03, 1.88s/it]
|
| 647 |
93%|█████████▎| 436/469 [14:25<01:06, 2.01s/it]
|
| 648 |
93%|█████████▎| 437/469 [14:27<01:02, 1.95s/it]
|
| 649 |
93%|█████████▎| 438/469 [14:29<00:59, 1.93s/it]
|
| 650 |
94%|█████████▎| 439/469 [14:31<00:57, 1.91s/it]
|
| 651 |
94%|█████████▍| 440/469 [14:33<00:59, 2.04s/it]
|
| 652 |
94%|█████████▍| 441/469 [14:35<00:55, 1.99s/it]
|
| 653 |
94%|█████████▍| 442/469 [14:37<00:53, 1.97s/it]
|
| 654 |
94%|█████████▍| 443/469 [14:39<00:49, 1.92s/it]
|
| 655 |
95%|█████████▍| 444/469 [14:41<00:47, 1.89s/it]
|
| 656 |
95%|█████████▍| 445/469 [14:42<00:44, 1.87s/it]
|
| 657 |
95%|█████████▌| 446/469 [14:44<00:44, 1.94s/it]
|
| 658 |
95%|█████████▌| 447/469 [14:46<00:41, 1.88s/it]
|
| 659 |
96%|█████████▌| 448/469 [14:48<00:38, 1.83s/it]
|
| 660 |
96%|█████████▌| 449/469 [14:50<00:37, 1.89s/it]
|
| 661 |
96%|█████████▌| 450/469 [14:52<00:34, 1.84s/it]
|
| 662 |
96%|█████████▌| 451/469 [14:53<00:32, 1.81s/it]
|
| 663 |
96%|█████████▋| 452/469 [14:55<00:30, 1.80s/it]
|
| 664 |
97%|█████████▋| 453/469 [14:57<00:28, 1.77s/it]
|
| 665 |
97%|█████████▋| 454/469 [14:59<00:27, 1.85s/it]
|
| 666 |
97%|█████████▋| 455/469 [15:01<00:25, 1.84s/it]
|
| 667 |
97%|█████████▋| 456/469 [15:03<00:24, 1.92s/it]
|
| 668 |
97%|█████████▋| 457/469 [15:05<00:22, 1.88s/it]
|
| 669 |
98%|█████████▊| 458/469 [15:07<00:20, 1.88s/it]
|
| 670 |
98%|█████████▊| 459/469 [15:09<00:19, 1.92s/it]
|
| 671 |
98%|█████████▊| 460/469 [15:10<00:17, 1.92s/it]
|
| 672 |
98%|█████████▊| 461/469 [15:12<00:15, 1.90s/it]
|
| 673 |
99%|█████████▊| 462/469 [15:15<00:14, 2.11s/it]
|
| 674 |
99%|█████████▊| 463/469 [15:17<00:13, 2.18s/it]
|
| 675 |
99%|█████████▉| 464/469 [15:19<00:10, 2.09s/it]
|
| 676 |
99%|█████████▉| 465/469 [15:21<00:08, 2.07s/it]
|
| 677 |
99%|█████████▉| 466/469 [15:23<00:06, 2.02s/it]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-23 16:28:17.115582: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
| 2 |
+
2026-03-23 16:28:26.392554: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
|
| 3 |
+
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
| 4 |
+
2026-03-23 16:28:26.438747: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
|
| 5 |
+
2026-03-23 16:28:26.438805: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: d9a147cbff28e342e3a570e4cd1afa4e-taskrole1-0
|
| 6 |
+
2026-03-23 16:28:26.438830: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: d9a147cbff28e342e3a570e4cd1afa4e-taskrole1-0
|
| 7 |
+
2026-03-23 16:28:26.438915: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
|
| 8 |
+
2026-03-23 16:28:26.438959: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 535.154.5
|
| 9 |
+
2026-03-23 16:28:26.776048: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
| 10 |
+
warming up TensorFlow...
|
| 11 |
+
|
| 12 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-23 16:28:27.563125: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
| 13 |
+
|
| 14 |
+
computing reference batch activations...
|
| 15 |
+
|
| 16 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 17 |
0%| | 1/211 [00:02<10:16, 2.94s/it]
|
| 18 |
1%| | 2/211 [00:04<08:15, 2.37s/it]
|
| 19 |
1%|▏ | 3/211 [00:06<07:38, 2.21s/it]
|
| 20 |
2%|▏ | 4/211 [00:08<07:20, 2.13s/it]
|
| 21 |
2%|▏ | 5/211 [00:11<07:25, 2.16s/it]
|
| 22 |
3%|▎ | 6/211 [00:12<06:59, 2.04s/it]
|
| 23 |
3%|▎ | 7/211 [00:15<07:15, 2.13s/it]
|
| 24 |
4%|▍ | 8/211 [00:17<06:55, 2.05s/it]
|
| 25 |
4%|▍ | 9/211 [00:19<06:53, 2.05s/it]
|
| 26 |
5%|▍ | 10/211 [00:21<06:37, 1.98s/it]
|
| 27 |
5%|▌ | 11/211 [00:22<06:18, 1.89s/it]
|
| 28 |
6%|▌ | 12/211 [00:24<06:13, 1.88s/it]
|
| 29 |
6%|▌ | 13/211 [00:26<06:08, 1.86s/it]
|
| 30 |
7%|▋ | 14/211 [00:28<06:02, 1.84s/it]
|
| 31 |
7%|▋ | 15/211 [00:29<05:57, 1.82s/it]
|
| 32 |
8%|▊ | 16/211 [00:31<05:55, 1.82s/it]
|
| 33 |
8%|▊ | 17/211 [00:33<06:00, 1.86s/it]
|
| 34 |
9%|▊ | 18/211 [00:35<06:01, 1.87s/it]
|
| 35 |
9%|▉ | 19/211 [00:37<05:59, 1.87s/it]
|
| 36 |
9%|▉ | 20/211 [00:39<05:51, 1.84s/it]
|
| 37 |
10%|▉ | 21/211 [00:41<05:45, 1.82s/it]
|
| 38 |
10%|█ | 22/211 [00:42<05:43, 1.82s/it]
|
| 39 |
11%|█ | 23/211 [00:44<05:41, 1.82s/it]
|
| 40 |
11%|█▏ | 24/211 [00:47<06:23, 2.05s/it]
|
| 41 |
12%|█▏ | 25/211 [00:50<07:07, 2.30s/it]
|
| 42 |
12%|█▏ | 26/211 [00:51<06:37, 2.15s/it]
|
| 43 |
13%|█▎ | 27/211 [00:53<06:14, 2.04s/it]
|
| 44 |
13%|█▎ | 28/211 [00:55<05:56, 1.95s/it]
|
| 45 |
14%|█▎ | 29/211 [00:57<05:56, 1.96s/it]
|
| 46 |
14%|█▍ | 30/211 [00:59<05:49, 1.93s/it]
|
| 47 |
15%|█▍ | 31/211 [01:01<05:49, 1.94s/it]
|
| 48 |
15%|█▌ | 32/211 [01:03<05:41, 1.91s/it]
|
| 49 |
16%|█▌ | 33/211 [01:05<05:42, 1.93s/it]
|
| 50 |
16%|█▌ | 34/211 [01:07<05:49, 1.97s/it]
|
| 51 |
17%|█▋ | 35/211 [01:09<06:11, 2.11s/it]
|
| 52 |
17%|█▋ | 36/211 [01:11<06:00, 2.06s/it]
|
| 53 |
18%|█▊ | 37/211 [01:13<05:46, 1.99s/it]
|
| 54 |
18%|█▊ | 38/211 [01:15<05:57, 2.07s/it]
|
| 55 |
18%|█▊ | 39/211 [01:17<05:44, 2.01s/it]
|
| 56 |
19%|█▉ | 40/211 [01:19<05:31, 1.94s/it]
|
| 57 |
19%|█▉ | 41/211 [01:21<05:24, 1.91s/it]
|
| 58 |
20%|█▉ | 42/211 [01:22<05:19, 1.89s/it]
|
| 59 |
20%|██ | 43/211 [01:24<05:20, 1.91s/it]
|
| 60 |
21%|██ | 44/211 [01:26<05:09, 1.86s/it]
|
| 61 |
21%|██▏ | 45/211 [01:29<05:50, 2.11s/it]
|
| 62 |
22%|██▏ | 46/211 [01:31<05:46, 2.10s/it]
|
| 63 |
22%|██▏ | 47/211 [01:33<05:33, 2.03s/it]
|
| 64 |
23%|██▎ | 48/211 [01:35<05:17, 1.95s/it]
|
| 65 |
23%|██▎ | 49/211 [01:36<05:14, 1.94s/it]
|
| 66 |
24%|██▎ | 50/211 [01:38<05:03, 1.88s/it]
|
| 67 |
24%|██▍ | 51/211 [01:40<05:19, 2.00s/it]
|
| 68 |
25%|██▍ | 52/211 [01:42<05:13, 1.97s/it]
|
| 69 |
25%|██▌ | 53/211 [01:45<05:19, 2.02s/it]
|
| 70 |
26%|██▌ | 54/211 [01:46<05:05, 1.95s/it]
|
| 71 |
26%|██▌ | 55/211 [01:48<04:55, 1.89s/it]
|
| 72 |
27%|██▋ | 56/211 [01:50<04:56, 1.92s/it]
|
| 73 |
27%|██▋ | 57/211 [01:52<04:48, 1.88s/it]
|
| 74 |
27%|██▋ | 58/211 [01:54<04:44, 1.86s/it]
|
| 75 |
28%|██▊ | 59/211 [01:56<04:47, 1.89s/it]
|
| 76 |
28%|██▊ | 60/211 [01:58<04:57, 1.97s/it]
|
| 77 |
29%|██▉ | 61/211 [02:00<04:46, 1.91s/it]
|
| 78 |
29%|██▉ | 62/211 [02:01<04:38, 1.87s/it]
|
| 79 |
30%|██▉ | 63/211 [02:04<05:05, 2.07s/it]
|
| 80 |
30%|███ | 64/211 [02:06<04:59, 2.04s/it]
|
| 81 |
31%|███ | 65/211 [02:08<05:05, 2.09s/it]
|
| 82 |
31%|███▏ | 66/211 [02:10<04:48, 1.99s/it]
|
| 83 |
32%|███▏ | 67/211 [02:12<04:37, 1.93s/it]
|
| 84 |
32%|███▏ | 68/211 [02:13<04:27, 1.87s/it]
|
| 85 |
33%|███▎ | 69/211 [02:15<04:25, 1.87s/it]
|
| 86 |
33%|███▎ | 70/211 [02:17<04:18, 1.83s/it]
|
| 87 |
34%|███▎ | 71/211 [02:19<04:15, 1.83s/it]
|
| 88 |
34%|███▍ | 72/211 [02:20<04:10, 1.80s/it]
|
| 89 |
35%|███▍ | 73/211 [02:22<04:08, 1.80s/it]
|
| 90 |
35%|███▌ | 74/211 [02:24<04:12, 1.84s/it]
|
| 91 |
36%|███▌ | 75/211 [02:26<04:10, 1.84s/it]
|
| 92 |
36%|███▌ | 76/211 [02:28<04:06, 1.83s/it]
|
| 93 |
36%|███▋ | 77/211 [02:30<04:11, 1.88s/it]
|
| 94 |
37%|███▋ | 78/211 [02:32<04:07, 1.86s/it]
|
| 95 |
37%|███▋ | 79/211 [02:34<04:08, 1.88s/it]
|
| 96 |
38%|███▊ | 80/211 [02:35<04:06, 1.88s/it]
|
| 97 |
38%|███▊ | 81/211 [02:37<04:11, 1.93s/it]
|
| 98 |
39%|███▉ | 82/211 [02:39<04:04, 1.89s/it]
|
| 99 |
39%|███▉ | 83/211 [02:41<04:06, 1.92s/it]
|
| 100 |
40%|███▉ | 84/211 [02:43<04:03, 1.92s/it]
|
| 101 |
40%|████ | 85/211 [02:45<03:55, 1.87s/it]
|
| 102 |
41%|████ | 86/211 [02:47<03:50, 1.84s/it]
|
| 103 |
41%|████ | 87/211 [02:48<03:45, 1.82s/it]
|
| 104 |
42%|████▏ | 88/211 [02:50<03:41, 1.80s/it]
|
| 105 |
42%|████▏ | 89/211 [02:52<03:36, 1.78s/it]
|
| 106 |
43%|████▎ | 90/211 [02:54<03:51, 1.91s/it]
|
| 107 |
43%|████▎ | 91/211 [02:56<03:45, 1.88s/it]
|
| 108 |
44%|████▎ | 92/211 [02:58<03:38, 1.84s/it]
|
| 109 |
44%|████▍ | 93/211 [03:00<03:37, 1.85s/it]
|
| 110 |
45%|████▍ | 94/211 [03:02<03:46, 1.94s/it]
|
| 111 |
45%|████▌ | 95/211 [03:03<03:36, 1.87s/it]
|
| 112 |
45%|████▌ | 96/211 [03:05<03:34, 1.87s/it]
|
| 113 |
46%|████▌ | 97/211 [03:08<03:58, 2.09s/it]
|
| 114 |
46%|████▋ | 98/211 [03:10<03:45, 1.99s/it]
|
| 115 |
47%|████▋ | 99/211 [03:11<03:34, 1.91s/it]
|
| 116 |
47%|████▋ | 100/211 [03:13<03:31, 1.91s/it]
|
| 117 |
48%|████▊ | 101/211 [03:15<03:28, 1.89s/it]
|
| 118 |
48%|████▊ | 102/211 [03:21<05:46, 3.18s/it]
|
| 119 |
49%|████▉ | 103/211 [03:23<04:58, 2.76s/it]
|
| 120 |
49%|████▉ | 104/211 [03:25<04:22, 2.45s/it]
|
| 121 |
50%|████▉ | 105/211 [03:27<04:03, 2.30s/it]
|
| 122 |
50%|█████ | 106/211 [03:29<03:45, 2.15s/it]
|
| 123 |
51%|█████ | 107/211 [03:31<03:35, 2.07s/it]
|
| 124 |
51%|█████ | 108/211 [03:32<03:29, 2.03s/it]
|
| 125 |
52%|█████▏ | 109/211 [03:34<03:21, 1.97s/it]
|
| 126 |
52%|█████▏ | 110/211 [03:37<03:36, 2.15s/it]
|
| 127 |
53%|█████▎ | 111/211 [03:39<03:23, 2.03s/it]
|
| 128 |
53%|█████▎ | 112/211 [03:41<03:17, 1.99s/it]
|
| 129 |
54%|█████▎ | 113/211 [03:42<03:10, 1.95s/it]
|
| 130 |
54%|█████▍ | 114/211 [03:44<03:03, 1.90s/it]
|
| 131 |
55%|█████▍ | 115/211 [03:46<03:02, 1.91s/it]
|
| 132 |
55%|█████▍ | 116/211 [03:48<02:59, 1.89s/it]
|
| 133 |
55%|█████▌ | 117/211 [03:50<03:03, 1.96s/it]
|
| 134 |
56%|█████▌ | 118/211 [03:52<02:56, 1.90s/it]
|
| 135 |
56%|█████▋ | 119/211 [03:54<02:51, 1.86s/it]
|
| 136 |
57%|█████▋ | 120/211 [03:55<02:49, 1.86s/it]
|
| 137 |
57%|█████▋ | 121/211 [03:57<02:52, 1.91s/it]
|
| 138 |
58%|█████▊ | 122/211 [03:59<02:47, 1.89s/it]
|
| 139 |
58%|█████▊ | 123/211 [04:01<02:51, 1.94s/it]
|
| 140 |
59%|█████▉ | 124/211 [04:03<02:47, 1.93s/it]
|
| 141 |
59%|█████▉ | 125/211 [04:05<02:40, 1.87s/it]
|
| 142 |
60%|█████▉ | 126/211 [04:07<02:46, 1.96s/it]
|
| 143 |
60%|██████ | 127/211 [04:09<02:43, 1.95s/it]
|
| 144 |
61%|██████ | 128/211 [04:11<02:42, 1.95s/it]
|
| 145 |
61%|██████ | 129/211 [04:13<02:47, 2.04s/it]
|
| 146 |
62%|██████▏ | 130/211 [04:15<02:41, 2.00s/it]
|
| 147 |
62%|██████▏ | 131/211 [04:17<02:39, 2.00s/it]
|
| 148 |
63%|██████▎ | 132/211 [04:19<02:36, 1.98s/it]
|
| 149 |
63%|██████▎ | 133/211 [04:21<02:31, 1.94s/it]
|
| 150 |
64%|██████▎ | 134/211 [04:23<02:26, 1.91s/it]
|
| 151 |
64%|██████▍ | 135/211 [04:25<02:21, 1.86s/it]
|
| 152 |
64%|██████▍ | 136/211 [04:26<02:20, 1.87s/it]
|
| 153 |
65%|██████▍ | 137/211 [04:29<02:23, 1.94s/it]
|
| 154 |
65%|██████▌ | 138/211 [04:30<02:18, 1.90s/it]
|
| 155 |
66%|██████▌ | 139/211 [04:32<02:13, 1.86s/it]
|
| 156 |
66%|██████▋ | 140/211 [04:34<02:10, 1.84s/it]
|
| 157 |
67%|██████▋ | 141/211 [04:36<02:07, 1.82s/it]
|
| 158 |
67%|██████▋ | 142/211 [04:38<02:06, 1.83s/it]
|
| 159 |
68%|██████▊ | 143/211 [04:39<02:06, 1.86s/it]
|
| 160 |
68%|██��███▊ | 144/211 [04:41<02:02, 1.83s/it]
|
| 161 |
69%|██████▊ | 145/211 [04:43<02:00, 1.82s/it]
|
| 162 |
69%|██████▉ | 146/211 [04:45<01:58, 1.82s/it]
|
| 163 |
70%|██████▉ | 147/211 [04:47<02:06, 1.98s/it]
|
| 164 |
70%|███████ | 148/211 [04:49<02:07, 2.02s/it]
|
| 165 |
71%|███████ | 149/211 [04:51<02:02, 1.97s/it]
|
| 166 |
71%|███████ | 150/211 [04:53<01:59, 1.96s/it]
|
| 167 |
72%|███████▏ | 151/211 [04:55<01:56, 1.95s/it]
|
| 168 |
72%|███████▏ | 152/211 [04:57<01:54, 1.94s/it]
|
| 169 |
73%|███████▎ | 153/211 [04:59<01:49, 1.88s/it]
|
| 170 |
73%|███████▎ | 154/211 [05:01<01:56, 2.04s/it]
|
| 171 |
73%|███████▎ | 155/211 [05:03<01:50, 1.97s/it]
|
| 172 |
74%|███████▍ | 156/211 [05:05<01:48, 1.97s/it]
|
| 173 |
74%|███████▍ | 157/211 [05:07<01:46, 1.96s/it]
|
| 174 |
75%|███████▍ | 158/211 [05:09<01:41, 1.92s/it]
|
| 175 |
75%|███████▌ | 159/211 [05:10<01:38, 1.90s/it]
|
| 176 |
76%|███████▌ | 160/211 [05:12<01:35, 1.88s/it]
|
| 177 |
76%|███████▋ | 161/211 [05:14<01:32, 1.85s/it]
|
| 178 |
77%|███████▋ | 162/211 [05:16<01:32, 1.89s/it]
|
| 179 |
77%|███████▋ | 163/211 [05:18<01:31, 1.90s/it]
|
| 180 |
78%|███████▊ | 164/211 [05:20<01:28, 1.88s/it]
|
| 181 |
78%|███████▊ | 165/211 [05:22<01:26, 1.88s/it]
|
| 182 |
79%|███████▊ | 166/211 [05:24<01:25, 1.89s/it]
|
| 183 |
79%|███████▉ | 167/211 [05:25<01:22, 1.88s/it]
|
| 184 |
80%|███████▉ | 168/211 [05:27<01:22, 1.92s/it]
|
| 185 |
80%|████████ | 169/211 [05:29<01:21, 1.94s/it]
|
| 186 |
81%|████████ | 170/211 [05:31<01:18, 1.92s/it]
|
| 187 |
81%|████████ | 171/211 [05:34<01:26, 2.15s/it]
|
| 188 |
82%|████████▏ | 172/211 [05:36<01:23, 2.14s/it]
|
| 189 |
82%|████████▏ | 173/211 [05:38<01:17, 2.05s/it]
|
| 190 |
82%|████████▏ | 174/211 [05:40<01:13, 1.99s/it]
|
| 191 |
83%|████████▎ | 175/211 [05:42<01:09, 1.93s/it]
|
| 192 |
83%|████████▎ | 176/211 [05:44<01:10, 2.02s/it]
|
| 193 |
84%|████████▍ | 177/211 [05:46<01:10, 2.06s/it]
|
| 194 |
84%|████████▍ | 178/211 [05:49<01:16, 2.33s/it]
|
| 195 |
85%|████████▍ | 179/211 [05:51<01:09, 2.16s/it]
|
| 196 |
85%|████████▌ | 180/211 [05:53<01:03, 2.05s/it]
|
| 197 |
86%|████████▌ | 181/211 [05:54<00:58, 1.96s/it]
|
| 198 |
86%|████████▋ | 182/211 [05:56<00:56, 1.95s/it]
|
| 199 |
87%|████████▋ | 183/211 [05:59<00:57, 2.05s/it]
|
| 200 |
87%|████████▋ | 184/211 [06:00<00:54, 2.01s/it]
|
| 201 |
88%|████████▊ | 185/211 [06:02<00:50, 1.95s/it]
|
| 202 |
88%|████████▊ | 186/211 [06:04<00:47, 1.89s/it]
|
| 203 |
89%|████████▊ | 187/211 [06:06<00:45, 1.88s/it]
|
| 204 |
89%|████████▉ | 188/211 [06:08<00:42, 1.86s/it]
|
| 205 |
90%|████████▉ | 189/211 [06:09<00:40, 1.83s/it]
|
| 206 |
90%|█████████ | 190/211 [06:11<00:38, 1.85s/it]
|
| 207 |
91%|█████████ | 191/211 [06:13<00:36, 1.85s/it]
|
| 208 |
91%|█████████ | 192/211 [06:15<00:34, 1.84s/it]
|
| 209 |
91%|█████████▏| 193/211 [06:17<00:34, 1.92s/it]
|
| 210 |
92%|█████████▏| 194/211 [06:19<00:32, 1.91s/it]
|
| 211 |
92%|█████████▏| 195/211 [06:21<00:31, 1.96s/it]
|
| 212 |
93%|█████████▎| 196/211 [06:23<00:28, 1.92s/it]
|
| 213 |
93%|█████████▎| 197/211 [06:25<00:27, 1.96s/it]
|
| 214 |
94%|█████████▍| 198/211 [06:27<00:25, 1.93s/it]
|
| 215 |
94%|█████████▍| 199/211 [06:29<00:22, 1.88s/it]
|
| 216 |
95%|█████████▍| 200/211 [06:30<00:20, 1.89s/it]
|
| 217 |
95%|█████████▌| 201/211 [06:32<00:18, 1.88s/it]
|
| 218 |
96%|█████████▌| 202/211 [06:34<00:16, 1.86s/it]
|
| 219 |
96%|█████████▌| 203/211 [06:36<00:15, 1.88s/it]
|
| 220 |
97%|█████████▋| 204/211 [06:38<00:13, 1.89s/it]
|
| 221 |
97%|█████████▋| 205/211 [06:40<00:11, 1.91s/it]
|
| 222 |
98%|█████████▊| 206/211 [06:42<00:10, 2.10s/it]
|
| 223 |
98%|█████████▊| 207/211 [06:44<00:08, 2.07s/it]
|
| 224 |
99%|█████████▊| 208/211 [06:46<00:05, 1.98s/it]
|
| 225 |
99%|█████████▉| 209/211 [06:49<00:04, 2.12s/it]
|
| 226 |
+
computing/reading reference batch statistics...
|
| 227 |
+
computing sample batch activations...
|
| 228 |
+
|
| 229 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 230 |
0%| | 1/469 [00:01<14:40, 1.88s/it]
|
| 231 |
0%| | 2/469 [00:04<16:02, 2.06s/it]
|
| 232 |
1%| | 3/469 [00:05<15:06, 1.95s/it]
|
| 233 |
1%| | 4/469 [00:07<14:44, 1.90s/it]
|
| 234 |
1%| | 5/469 [00:09<14:34, 1.88s/it]
|
| 235 |
1%|▏ | 6/469 [00:11<14:24, 1.87s/it]
|
| 236 |
1%|▏ | 7/469 [00:13<14:05, 1.83s/it]
|
| 237 |
2%|▏ | 8/469 [00:15<14:19, 1.87s/it]
|
| 238 |
2%|▏ | 9/469 [00:16<14:13, 1.85s/it]
|
| 239 |
2%|▏ | 10/469 [00:18<14:09, 1.85s/it]
|
| 240 |
2%|▏ | 11/469 [00:20<14:04, 1.84s/it]
|
| 241 |
3%|▎ | 12/469 [00:22<13:59, 1.84s/it]
|
| 242 |
3%|▎ | 13/469 [00:24<14:09, 1.86s/it]
|
| 243 |
3%|▎ | 14/469 [00:26<14:18, 1.89s/it]
|
| 244 |
3%|▎ | 15/469 [00:28<14:01, 1.85s/it]
|
| 245 |
3%|▎ | 16/469 [00:30<14:36, 1.94s/it]
|
| 246 |
4%|▎ | 17/469 [00:31<14:12, 1.89s/it]
|
| 247 |
4%|▍ | 18/469 [00:33<13:58, 1.86s/it]
|
| 248 |
4%|▍ | 19/469 [00:36<14:55, 1.99s/it]
|
| 249 |
4%|▍ | 20/469 [00:38<14:53, 1.99s/it]
|
| 250 |
4%|▍ | 21/469 [00:40<14:50, 1.99s/it]
|
| 251 |
5%|▍ | 22/469 [00:41<14:43, 1.98s/it]
|
| 252 |
5%|▍ | 23/469 [00:44<15:04, 2.03s/it]
|
| 253 |
5%|▌ | 24/469 [00:45<14:34, 1.97s/it]
|
| 254 |
5%|▌ | 25/469 [00:47<14:03, 1.90s/it]
|
| 255 |
6%|▌ | 26/469 [00:49<13:42, 1.86s/it]
|
| 256 |
6%|▌ | 27/469 [00:51<13:31, 1.84s/it]
|
| 257 |
6%|▌ | 28/469 [00:53<13:35, 1.85s/it]
|
| 258 |
6%|▌ | 29/469 [00:54<13:24, 1.83s/it]
|
| 259 |
6%|▋ | 30/469 [00:56<13:30, 1.85s/it]
|
| 260 |
7%|▋ | 31/469 [00:58<13:10, 1.80s/it]
|
| 261 |
7%|▋ | 32/469 [01:00<13:43, 1.88s/it]
|
| 262 |
7%|▋ | 33/469 [01:02<13:23, 1.84s/it]
|
| 263 |
7%|▋ | 34/469 [01:05<15:15, 2.10s/it]
|
| 264 |
7%|▋ | 35/469 [01:06<14:40, 2.03s/it]
|
| 265 |
8%|▊ | 36/469 [01:08<14:23, 1.99s/it]
|
| 266 |
8%|▊ | 37/469 [01:10<13:51, 1.93s/it]
|
| 267 |
8%|▊ | 38/469 [01:12<13:26, 1.87s/it]
|
| 268 |
8%|▊ | 39/469 [01:14<13:13, 1.85s/it]
|
| 269 |
9%|▊ | 40/469 [01:15<13:15, 1.85s/it]
|
| 270 |
9%|▊ | 41/469 [01:17<13:14, 1.86s/it]
|
| 271 |
9%|▉ | 42/469 [01:19<12:58, 1.82s/it]
|
| 272 |
9%|▉ | 43/469 [01:21<13:04, 1.84s/it]
|
| 273 |
9%|▉ | 44/469 [01:23<12:55, 1.82s/it]
|
| 274 |
10%|▉ | 45/469 [01:25<12:56, 1.83s/it]
|
| 275 |
10%|▉ | 46/469 [01:26<12:50, 1.82s/it]
|
| 276 |
10%|█ | 47/469 [01:28<12:47, 1.82s/it]
|
| 277 |
10%|█ | 48/469 [01:30<12:41, 1.81s/it]
|
| 278 |
10%|█ | 49/469 [01:32<12:52, 1.84s/it]
|
| 279 |
11%|█ | 50/469 [01:34<13:00, 1.86s/it]
|
| 280 |
11%|█ | 51/469 [01:36<12:46, 1.83s/it]
|
| 281 |
11%|█ | 52/469 [01:37<12:41, 1.83s/it]
|
| 282 |
11%|█▏ | 53/469 [01:39<12:40, 1.83s/it]
|
| 283 |
12%|█▏ | 54/469 [01:41<12:46, 1.85s/it]
|
| 284 |
12%|█▏ | 55/469 [01:44<15:33, 2.25s/it]
|
| 285 |
12%|█▏ | 56/469 [01:46<14:24, 2.09s/it]
|
| 286 |
12%|█▏ | 57/469 [01:48<13:41, 1.99s/it]
|
| 287 |
12%|█▏ | 58/469 [01:50<13:22, 1.95s/it]
|
| 288 |
13%|█▎ | 59/469 [01:51<12:54, 1.89s/it]
|
| 289 |
13%|█▎ | 60/469 [01:53<12:33, 1.84s/it]
|
| 290 |
13%|█▎ | 61/469 [01:55<12:23, 1.82s/it]
|
| 291 |
13%|█▎ | 62/469 [01:57<12:29, 1.84s/it]
|
| 292 |
13%|█▎ | 63/469 [01:59<12:42, 1.88s/it]
|
| 293 |
14%|█▎ | 64/469 [02:01<12:43, 1.89s/it]
|
| 294 |
14%|█▍ | 65/469 [02:02<12:29, 1.86s/it]
|
| 295 |
14%|█▍ | 66/469 [02:04<12:49, 1.91s/it]
|
| 296 |
14%|█▍ | 67/469 [02:06<12:50, 1.92s/it]
|
| 297 |
14%|█▍ | 68/469 [02:08<12:34, 1.88s/it]
|
| 298 |
15%|█▍ | 69/469 [02:10<12:22, 1.86s/it]
|
| 299 |
15%|█▍ | 70/469 [02:12<12:20, 1.86s/it]
|
| 300 |
15%|█▌ | 71/469 [02:14<12:10, 1.84s/it]
|
| 301 |
15%|█▌ | 72/469 [02:15<12:06, 1.83s/it]
|
| 302 |
16%|█▌ | 73/469 [02:17<12:00, 1.82s/it]
|
| 303 |
16%|█▌ | 74/469 [02:19<12:09, 1.85s/it]
|
| 304 |
16%|█▌ | 75/469 [02:21<12:18, 1.87s/it]
|
| 305 |
16%|█▌ | 76/469 [02:23<12:13, 1.87s/it]
|
| 306 |
16%|█▋ | 77/469 [02:25<12:16, 1.88s/it]
|
| 307 |
17%|█▋ | 78/469 [02:27<12:45, 1.96s/it]
|
| 308 |
17%|█▋ | 79/469 [02:29<12:17, 1.89s/it]
|
| 309 |
17%|█▋ | 80/469 [02:31<12:07, 1.87s/it]
|
| 310 |
17%|█▋ | 81/469 [02:32<11:53, 1.84s/it]
|
| 311 |
17%|█▋ | 82/469 [02:34<12:09, 1.89s/it]
|
| 312 |
18%|█▊ | 83/469 [02:36<12:01, 1.87s/it]
|
| 313 |
18%|█▊ | 84/469 [02:38<11:45, 1.83s/it]
|
| 314 |
18%|█▊ | 85/469 [02:40<12:13, 1.91s/it]
|
| 315 |
18%|█▊ | 86/469 [02:42<12:12, 1.91s/it]
|
| 316 |
19%|█▊ | 87/469 [02:44<11:53, 1.87s/it]
|
| 317 |
19%|█▉ | 88/469 [02:45<11:32, 1.82s/it]
|
| 318 |
19%|█▉ | 89/469 [02:47<11:28, 1.81s/it]
|
| 319 |
19%|█▉ | 90/469 [02:49<11:52, 1.88s/it]
|
| 320 |
19%|█▉ | 91/469 [02:51<11:34, 1.84s/it]
|
| 321 |
20%|█▉ | 92/469 [02:53<11:36, 1.85s/it]
|
| 322 |
20%|█▉ | 93/469 [02:55<11:36, 1.85s/it]
|
| 323 |
20%|██ | 94/469 [02:56<11:27, 1.83s/it]
|
| 324 |
20%|██ | 95/469 [02:58<11:24, 1.83s/it]
|
| 325 |
20%|██ | 96/469 [03:00<11:19, 1.82s/it]
|
| 326 |
21%|██ | 97/469 [03:02<11:25, 1.84s/it]
|
| 327 |
21%|██ | 98/469 [03:04<11:45, 1.90s/it]
|
| 328 |
21%|██ | 99/469 [03:06<11:45, 1.91s/it]
|
| 329 |
21%|██▏ | 100/469 [03:08<11:30, 1.87s/it]
|
| 330 |
22%|██▏ | 101/469 [03:10<12:35, 2.05s/it]
|
| 331 |
22%|██▏ | 102/469 [03:12<12:40, 2.07s/it]
|
| 332 |
22%|██▏ | 103/469 [03:14<12:10, 2.00s/it]
|
| 333 |
22%|██▏ | 104/469 [03:16<11:49, 1.94s/it]
|
| 334 |
22%|██▏ | 105/469 [03:18<11:38, 1.92s/it]
|
| 335 |
23%|██▎ | 106/469 [03:20<11:50, 1.96s/it]
|
| 336 |
23%|██▎ | 107/469 [03:22<11:29, 1.91s/it]
|
| 337 |
23%|██▎ | 108/469 [03:24<12:19, 2.05s/it]
|
| 338 |
23%|██▎ | 109/469 [03:26<11:50, 1.97s/it]
|
| 339 |
23%|██▎ | 110/469 [03:28<12:21, 2.06s/it]
|
| 340 |
24%|██▎ | 111/469 [03:30<12:00, 2.01s/it]
|
| 341 |
24%|██▍ | 112/469 [03:32<11:37, 1.95s/it]
|
| 342 |
24%|██▍ | 113/469 [03:34<11:13, 1.89s/it]
|
| 343 |
24%|██▍ | 114/469 [03:36<11:26, 1.93s/it]
|
| 344 |
25%|██▍ | 115/469 [03:37<11:06, 1.88s/it]
|
| 345 |
25%|██▍ | 116/469 [03:39<11:27, 1.95s/it]
|
| 346 |
25%|██▍ | 117/469 [03:42<12:08, 2.07s/it]
|
| 347 |
25%|██▌ | 118/469 [03:44<11:37, 1.99s/it]
|
| 348 |
25%|██▌ | 119/469 [03:45<11:14, 1.93s/it]
|
| 349 |
26%|██▌ | 120/469 [03:47<10:53, 1.87s/it]
|
| 350 |
26%|██▌ | 121/469 [03:49<10:39, 1.84s/it]
|
| 351 |
26%|██▌ | 122/469 [03:51<10:28, 1.81s/it]
|
| 352 |
26%|██▌ | 123/469 [03:52<10:23, 1.80s/it]
|
| 353 |
26%|██▋ | 124/469 [03:54<10:16, 1.79s/it]
|
| 354 |
27%|██▋ | 125/469 [03:56<10:26, 1.82s/it]
|
| 355 |
27%|██▋ | 126/469 [03:58<10:19, 1.81s/it]
|
| 356 |
27%|██▋ | 127/469 [04:00<10:05, 1.77s/it]
|
| 357 |
27%|██▋ | 128/469 [04:01<10:13, 1.80s/it]
|
| 358 |
28%|██▊ | 129/469 [04:04<10:47, 1.91s/it]
|
| 359 |
28%|██▊ | 130/469 [04:05<10:30, 1.86s/it]
|
| 360 |
28%|██▊ | 131/469 [04:08<11:02, 1.96s/it]
|
| 361 |
28%|██▊ | 132/469 [04:10<11:07, 1.98s/it]
|
| 362 |
28%|██▊ | 133/469 [04:11<11:00, 1.97s/it]
|
| 363 |
29%|██▊ | 134/469 [04:13<10:36, 1.90s/it]
|
| 364 |
29%|██▉ | 135/469 [04:15<10:33, 1.90s/it]
|
| 365 |
29%|██▉ | 136/469 [04:17<10:28, 1.89s/it]
|
| 366 |
29%|██▉ | 137/469 [04:23<16:46, 3.03s/it]
|
| 367 |
29%|██▉ | 138/469 [04:24<14:39, 2.66s/it]
|
| 368 |
30%|██▉ | 139/469 [04:27<13:43, 2.49s/it]
|
| 369 |
30%|██▉ | 140/469 [04:29<13:38, 2.49s/it]
|
| 370 |
30%|███ | 141/469 [04:31<12:21, 2.26s/it]
|
| 371 |
30%|███ | 142/469 [04:33<11:27, 2.10s/it]
|
| 372 |
30%|███ | 143/469 [04:34<10:47, 1.98s/it]
|
| 373 |
31%|███ | 144/469 [04:36<10:23, 1.92s/it]
|
| 374 |
31%|███ | 145/469 [04:38<10:18, 1.91s/it]
|
| 375 |
31%|███ | 146/469 [04:40<10:36, 1.97s/it]
|
| 376 |
31%|███▏ | 147/469 [04:42<10:39, 1.99s/it]
|
| 377 |
32%|███▏ | 148/469 [04:44<10:50, 2.03s/it]
|
| 378 |
32%|███▏ | 149/469 [04:46<10:30, 1.97s/it]
|
| 379 |
32%|███▏ | 150/469 [04:48<10:12, 1.92s/it]
|
| 380 |
32%|███▏ | 151/469 [04:50<09:55, 1.87s/it]
|
| 381 |
32%|███▏ | 152/469 [04:51<10:00, 1.89s/it]
|
| 382 |
33%|███▎ | 153/469 [04:53<10:00, 1.90s/it]
|
| 383 |
33%|███▎ | 154/469 [04:55<09:56, 1.89s/it]
|
| 384 |
33%|███▎ | 155/469 [04:57<09:55, 1.90s/it]
|
| 385 |
33%|███▎ | 156/469 [04:59<09:38, 1.85s/it]
|
| 386 |
33%|███▎ | 157/469 [05:01<09:29, 1.83s/it]
|
| 387 |
34%|███▎ | 158/469 [05:03<09:30, 1.84s/it]
|
| 388 |
34%|███▍ | 159/469 [05:04<09:18, 1.80s/it]
|
| 389 |
34%|███▍ | 160/469 [05:07<09:59, 1.94s/it]
|
| 390 |
34%|███▍ | 161/469 [05:08<09:52, 1.93s/it]
|
| 391 |
35%|███▍ | 162/469 [05:10<09:53, 1.93s/it]
|
| 392 |
35%|███▍ | 163/469 [05:12<09:41, 1.90s/it]
|
| 393 |
35%|███▍ | 164/469 [05:14<09:34, 1.88s/it]
|
| 394 |
35%|███▌ | 165/469 [05:16<09:32, 1.88s/it]
|
| 395 |
35%|███▌ | 166/469 [05:18<09:23, 1.86s/it]
|
| 396 |
36%|███▌ | 167/469 [05:20<09:39, 1.92s/it]
|
| 397 |
36%|███▌ | 168/469 [05:22<09:30, 1.90s/it]
|
| 398 |
36%|███▌ | 169/469 [05:23<09:21, 1.87s/it]
|
| 399 |
36%|███▌ | 170/469 [05:25<09:10, 1.84s/it]
|
| 400 |
36%|███▋ | 171/469 [05:27<09:20, 1.88s/it]
|
| 401 |
37%|███▋ | 172/469 [05:29<09:07, 1.84s/it]
|
| 402 |
37%|███▋ | 173/469 [05:31<09:01, 1.83s/it]
|
| 403 |
37%|███▋ | 174/469 [05:33<09:00, 1.83s/it]
|
| 404 |
37%|███▋ | 175/469 [05:34<08:49, 1.80s/it]
|
| 405 |
38%|███▊ | 176/469 [05:36<09:09, 1.88s/it]
|
| 406 |
38%|███▊ | 177/469 [05:38<09:08, 1.88s/it]
|
| 407 |
38%|███▊ | 178/469 [05:40<09:00, 1.86s/it]
|
| 408 |
38%|███▊ | 179/469 [05:42<08:51, 1.83s/it]
|
| 409 |
38%|███▊ | 180/469 [05:44<08:51, 1.84s/it]
|
| 410 |
39%|███▊ | 181/469 [05:46<10:03, 2.10s/it]
|
| 411 |
39%|███▉ | 182/469 [05:48<09:44, 2.04s/it]
|
| 412 |
39%|███▉ | 183/469 [05:50<09:21, 1.96s/it]
|
| 413 |
39%|███▉ | 184/469 [05:52<09:23, 1.98s/it]
|
| 414 |
39%|███▉ | 185/469 [05:54<09:03, 1.92s/it]
|
| 415 |
40%|███▉ | 186/469 [05:56<08:53, 1.88s/it]
|
| 416 |
40%|███▉ | 187/469 [05:57<08:40, 1.85s/it]
|
| 417 |
40%|████ | 188/469 [05:59<08:58, 1.92s/it]
|
| 418 |
40%|████ | 189/469 [06:01<08:50, 1.89s/it]
|
| 419 |
41%|████ | 190/469 [06:03<08:55, 1.92s/it]
|
| 420 |
41%|████ | 191/469 [06:05<08:49, 1.91s/it]
|
| 421 |
41%|████ | 192/469 [06:07<08:48, 1.91s/it]
|
| 422 |
41%|████ | 193/469 [06:09<08:43, 1.90s/it]
|
| 423 |
41%|████▏ | 194/469 [06:11<08:31, 1.86s/it]
|
| 424 |
42%|████▏ | 195/469 [06:12<08:18, 1.82s/it]
|
| 425 |
42%|████▏ | 196/469 [06:14<08:06, 1.78s/it]
|
| 426 |
42%|████▏ | 197/469 [06:16<08:05, 1.79s/it]
|
| 427 |
42%|████▏ | 198/469 [06:18<08:14, 1.83s/it]
|
| 428 |
42%|████▏ | 199/469 [06:20<08:11, 1.82s/it]
|
| 429 |
43%|████▎ | 200/469 [06:22<08:36, 1.92s/it]
|
| 430 |
43%|████▎ | 201/469 [06:24<08:25, 1.89s/it]
|
| 431 |
43%|████▎ | 202/469 [06:26<08:28, 1.91s/it]
|
| 432 |
43%|████▎ | 203/469 [06:27<08:24, 1.90s/it]
|
| 433 |
43%|████▎ | 204/469 [06:29<08:31, 1.93s/it]
|
| 434 |
44%|████▎ | 205/469 [06:31<08:20, 1.89s/it]
|
| 435 |
44%|████▍ | 206/469 [06:33<08:09, 1.86s/it]
|
| 436 |
44%|████▍ | 207/469 [06:35<08:13, 1.89s/it]
|
| 437 |
44%|████▍ | 208/469 [06:37<08:21, 1.92s/it]
|
| 438 |
45%|████▍ | 209/469 [06:39<08:55, 2.06s/it]
|
| 439 |
45%|████▍ | 210/469 [06:41<08:32, 1.98s/it]
|
| 440 |
45%|████▍ | 211/469 [06:43<08:21, 1.94s/it]
|
| 441 |
45%|████▌ | 212/469 [06:45<08:19, 1.94s/it]
|
| 442 |
45%|████▌ | 213/469 [06:47<08:13, 1.93s/it]
|
| 443 |
46%|████▌ | 214/469 [06:49<08:23, 1.98s/it]
|
| 444 |
46%|████▌ | 215/469 [06:51<08:15, 1.95s/it]
|
| 445 |
46%|████▌ | 216/469 [06:53<08:02, 1.91s/it]
|
| 446 |
46%|████▋ | 217/469 [06:55<07:57, 1.90s/it]
|
| 447 |
46%|████▋ | 218/469 [06:56<07:51, 1.88s/it]
|
| 448 |
47%|████▋ | 219/469 [06:58<07:40, 1.84s/it]
|
| 449 |
47%|████▋ | 220/469 [07:00<07:52, 1.90s/it]
|
| 450 |
47%|████▋ | 221/469 [07:02<07:44, 1.87s/it]
|
| 451 |
47%|████▋ | 222/469 [07:04<08:09, 1.98s/it]
|
| 452 |
48%|████▊ | 223/469 [07:07<09:19, 2.27s/it]
|
| 453 |
48%|████▊ | 224/469 [07:09<08:42, 2.13s/it]
|
| 454 |
48%|████▊ | 225/469 [07:11<08:22, 2.06s/it]
|
| 455 |
48%|████▊ | 226/469 [07:13<08:01, 1.98s/it]
|
| 456 |
48%|████▊ | 227/469 [07:15<07:51, 1.95s/it]
|
| 457 |
49%|████▊ | 228/469 [07:16<07:43, 1.92s/it]
|
| 458 |
49%|████▉ | 229/469 [07:18<07:35, 1.90s/it]
|
| 459 |
49%|████▉ | 230/469 [07:20<07:28, 1.88s/it]
|
| 460 |
49%|████▉ | 231/469 [07:22<07:44, 1.95s/it]
|
| 461 |
49%|████▉ | 232/469 [07:24<07:32, 1.91s/it]
|
| 462 |
50%|████▉ | 233/469 [07:26<07:24, 1.88s/it]
|
| 463 |
50%|████▉ | 234/469 [07:28<07:26, 1.90s/it]
|
| 464 |
50%|█████ | 235/469 [07:30<07:18, 1.88s/it]
|
| 465 |
50%|█████ | 236/469 [07:32<07:21, 1.89s/it]
|
| 466 |
51%|█████ | 237/469 [07:33<07:15, 1.88s/it]
|
| 467 |
51%|█████ | 238/469 [07:35<07:31, 1.95s/it]
|
| 468 |
51%|█████ | 239/469 [07:37<07:25, 1.94s/it]
|
| 469 |
51%|█████ | 240/469 [07:39<07:26, 1.95s/it]
|
| 470 |
51%|█████▏ | 241/469 [07:42<08:39, 2.28s/it]
|
| 471 |
52%|█████▏ | 242/469 [07:44<08:17, 2.19s/it]
|
| 472 |
52%|█████▏ | 243/469 [07:47<08:19, 2.21s/it]
|
| 473 |
52%|█████▏ | 244/469 [07:49<07:57, 2.12s/it]
|
| 474 |
52%|█████▏ | 245/469 [07:50<07:35, 2.03s/it]
|
| 475 |
52%|█████▏ | 246/469 [07:52<07:26, 2.00s/it]
|
| 476 |
53%|█████▎ | 247/469 [07:54<07:19, 1.98s/it]
|
| 477 |
53%|█████▎ | 248/469 [07:56<07:11, 1.95s/it]
|
| 478 |
53%|█████▎ | 249/469 [07:58<07:15, 1.98s/it]
|
| 479 |
53%|█████▎ | 250/469 [08:00<06:58, 1.91s/it]
|
| 480 |
54%|█████▎ | 251/469 [08:02<06:51, 1.89s/it]
|
| 481 |
54%|█████▎ | 252/469 [08:04<06:42, 1.85s/it]
|
| 482 |
54%|█████▍ | 253/469 [08:05<06:33, 1.82s/it]
|
| 483 |
54%|█████▍ | 254/469 [08:07<06:29, 1.81s/it]
|
| 484 |
54%|█████▍ | 255/469 [08:09<06:32, 1.83s/it]
|
| 485 |
55%|█████▍ | 256/469 [08:11<06:24, 1.80s/it]
|
| 486 |
55%|█████▍ | 257/469 [08:12<06:18, 1.79s/it]
|
| 487 |
55%|█████▌ | 258/469 [08:14<06:25, 1.83s/it]
|
| 488 |
55%|█████▌ | 259/469 [08:16<06:25, 1.83s/it]
|
| 489 |
55%|█████▌ | 260/469 [08:18<06:19, 1.82s/it]
|
| 490 |
56%|█████▌ | 261/469 [08:20<06:42, 1.94s/it]
|
| 491 |
56%|█████▌ | 262/469 [08:22<06:32, 1.90s/it]
|
| 492 |
56%|█████▌ | 263/469 [08:24<06:25, 1.87s/it]
|
| 493 |
56%|█████▋ | 264/469 [08:26<06:17, 1.84s/it]
|
| 494 |
57%|█████▋ | 265/469 [08:28<06:20, 1.87s/it]
|
| 495 |
57%|█████▋ | 266/469 [08:29<06:17, 1.86s/it]
|
| 496 |
57%|█████▋ | 267/469 [08:31<06:16, 1.86s/it]
|
| 497 |
57%|█████▋ | 268/469 [08:33<06:12, 1.86s/it]
|
| 498 |
57%|█████▋ | 269/469 [08:35<06:41, 2.01s/it]
|
| 499 |
58%|█████▊ | 270/469 [08:37<06:24, 1.93s/it]
|
| 500 |
58%|█████▊ | 271/469 [08:39<06:10, 1.87s/it]
|
| 501 |
58%|█████▊ | 272/469 [08:41<06:09, 1.87s/it]
|
| 502 |
58%|█████▊ | 273/469 [08:43<06:24, 1.96s/it]
|
| 503 |
58%|█████▊ | 274/469 [08:45<06:11, 1.90s/it]
|
| 504 |
59%|█████▊ | 275/469 [08:47<06:10, 1.91s/it]
|
| 505 |
59%|█████▉ | 276/469 [08:48<06:03, 1.88s/it]
|
| 506 |
59%|█████▉ | 277/469 [08:50<06:01, 1.88s/it]
|
| 507 |
59%|█████▉ | 278/469 [08:52<06:08, 1.93s/it]
|
| 508 |
59%|█████▉ | 279/469 [08:54<05:57, 1.88s/it]
|
| 509 |
60%|█████▉ | 280/469 [08:56<06:11, 1.97s/it]
|
| 510 |
60%|█████▉ | 281/469 [08:58<06:07, 1.95s/it]
|
| 511 |
60%|██████ | 282/469 [09:00<05:57, 1.91s/it]
|
| 512 |
60%|██████ | 283/469 [09:02<05:48, 1.87s/it]
|
| 513 |
61%|██████ | 284/469 [09:04<05:44, 1.86s/it]
|
| 514 |
61%|██████ | 285/469 [09:06<05:45, 1.88s/it]
|
| 515 |
61%|██████ | 286/469 [09:08<06:11, 2.03s/it]
|
| 516 |
61%|██████ | 287/469 [09:10<06:16, 2.07s/it]
|
| 517 |
61%|██████▏ | 288/469 [09:12<05:59, 1.99s/it]
|
| 518 |
62%|██████▏ | 289/469 [09:14<05:50, 1.95s/it]
|
| 519 |
62%|██████▏ | 290/469 [09:16<05:39, 1.90s/it]
|
| 520 |
62%|██████▏ | 291/469 [09:17<05:28, 1.85s/it]
|
| 521 |
62%|██████▏ | 292/469 [09:19<05:25, 1.84s/it]
|
| 522 |
62%|██████▏ | 293/469 [09:21<05:23, 1.84s/it]
|
| 523 |
63%|██████▎ | 294/469 [09:23<05:31, 1.89s/it]
|
| 524 |
63%|██████▎ | 295/469 [09:25<05:20, 1.84s/it]
|
| 525 |
63%|██████▎ | 296/469 [09:26<05:13, 1.81s/it]
|
| 526 |
63%|██████▎ | 297/469 [09:28<05:08, 1.80s/it]
|
| 527 |
64%|██████▎ | 298/469 [09:30<05:04, 1.78s/it]
|
| 528 |
64%|██████▍ | 299/469 [09:32<05:05, 1.80s/it]
|
| 529 |
64%|██████▍ | 300/469 [09:34<05:08, 1.82s/it]
|
| 530 |
64%|██████▍ | 301/469 [09:36<05:07, 1.83s/it]
|
| 531 |
64%|██████▍ | 302/469 [09:37<05:04, 1.82s/it]
|
| 532 |
65%|██████▍ | 303/469 [09:39<05:06, 1.84s/it]
|
| 533 |
65%|██████▍ | 304/469 [09:41<05:06, 1.86s/it]
|
| 534 |
65%|██████▌ | 305/469 [09:43<05:02, 1.84s/it]
|
| 535 |
65%|██████▌ | 306/469 [09:45<05:01, 1.85s/it]
|
| 536 |
65%|██████▌ | 307/469 [09:47<05:00, 1.85s/it]
|
| 537 |
66%|██████▌ | 308/469 [09:49<05:02, 1.88s/it]
|
| 538 |
66%|██████▌ | 309/469 [09:50<05:00, 1.88s/it]
|
| 539 |
66%|██████▌ | 310/469 [09:52<04:53, 1.85s/it]
|
| 540 |
66%|██████▋ | 311/469 [09:54<04:54, 1.86s/it]
|
| 541 |
67%|██████▋ | 312/469 [09:56<04:58, 1.90s/it]
|
| 542 |
67%|██████▋ | 313/469 [09:58<04:50, 1.86s/it]
|
| 543 |
67%|██████▋ | 314/469 [10:00<04:51, 1.88s/it]
|
| 544 |
67%|██████▋ | 315/469 [10:02<04:52, 1.90s/it]
|
| 545 |
67%|██████▋ | 316/469 [10:04<04:53, 1.92s/it]
|
| 546 |
68%|██████▊ | 317/469 [10:06<04:45, 1.88s/it]
|
| 547 |
68%|██████▊ | 318/469 [10:08<04:57, 1.97s/it]
|
| 548 |
68%|██████▊ | 319/469 [10:09<04:46, 1.91s/it]
|
| 549 |
68%|██████▊ | 320/469 [10:11<04:46, 1.92s/it]
|
| 550 |
68%|██████▊ | 321/469 [10:14<05:08, 2.09s/it]
|
| 551 |
69%|██████▊ | 322/469 [10:16<05:13, 2.14s/it]
|
| 552 |
69%|██████▉ | 323/469 [10:18<05:02, 2.07s/it]
|
| 553 |
69%|██████▉ | 324/469 [10:20<04:52, 2.02s/it]
|
| 554 |
69%|██████▉ | 325/469 [10:22<04:42, 1.96s/it]
|
| 555 |
70%|██████▉ | 326/469 [10:24<04:34, 1.92s/it]
|
| 556 |
70%|██████▉ | 327/469 [10:26<04:41, 1.98s/it]
|
| 557 |
70%|██████▉ | 328/469 [10:28<04:31, 1.93s/it]
|
| 558 |
70%|███████ | 329/469 [10:29<04:24, 1.89s/it]
|
| 559 |
70%|███████ | 330/469 [10:31<04:19, 1.87s/it]
|
| 560 |
71%|███████ | 331/469 [10:33<04:13, 1.84s/it]
|
| 561 |
71%|███████ | 332/469 [10:35<04:24, 1.93s/it]
|
| 562 |
71%|███████ | 333/469 [10:37<04:16, 1.89s/it]
|
| 563 |
71%|███████ | 334/469 [10:42<06:23, 2.84s/it]
|
| 564 |
71%|███████▏ | 335/469 [10:44<05:35, 2.51s/it]
|
| 565 |
72%|███████▏ | 336/469 [10:45<05:04, 2.29s/it]
|
| 566 |
72%|███████▏ | 337/469 [10:47<04:42, 2.14s/it]
|
| 567 |
72%|███████▏ | 338/469 [10:49<04:41, 2.15s/it]
|
| 568 |
72%|███████▏ | 339/469 [10:51<04:33, 2.11s/it]
|
| 569 |
72%|███████▏ | 340/469 [10:53<04:21, 2.03s/it]
|
| 570 |
73%|███████▎ | 341/469 [10:55<04:17, 2.01s/it]
|
| 571 |
73%|███████▎ | 342/469 [10:57<04:19, 2.04s/it]
|
| 572 |
73%|███████▎ | 343/469 [10:59<04:18, 2.05s/it]
|
| 573 |
73%|███████▎ | 344/469 [11:01<04:04, 1.96s/it]
|
| 574 |
74%|███████▎ | 345/469 [11:03<04:00, 1.94s/it]
|
| 575 |
74%|███████▍ | 346/469 [11:05<03:52, 1.89s/it]
|
| 576 |
74%|███████▍ | 347/469 [11:07<03:45, 1.85s/it]
|
| 577 |
74%|███████▍ | 348/469 [11:08<03:39, 1.82s/it]
|
| 578 |
74%|███████▍ | 349/469 [11:10<03:38, 1.82s/it]
|
| 579 |
75%|███████▍ | 350/469 [11:12<03:42, 1.87s/it]
|
| 580 |
75%|███████▍ | 351/469 [11:14<03:44, 1.90s/it]
|
| 581 |
75%|███████▌ | 352/469 [11:16<03:40, 1.88s/it]
|
| 582 |
75%|███████▌ | 353/469 [11:25<07:38, 3.95s/it]
|
| 583 |
75%|███████▌ | 354/469 [11:27<06:26, 3.36s/it]
|
| 584 |
76%|███████▌ | 355/469 [11:29<05:30, 2.90s/it]
|
| 585 |
76%|███████▌ | 356/469 [11:31<05:08, 2.73s/it]
|
| 586 |
76%|███████▌ | 357/469 [11:33<04:45, 2.55s/it]
|
| 587 |
76%|███████▋ | 358/469 [11:35<04:23, 2.37s/it]
|
| 588 |
77%|██��████▋ | 359/469 [11:37<04:10, 2.28s/it]
|
| 589 |
77%|███████▋ | 360/469 [11:39<03:51, 2.12s/it]
|
| 590 |
77%|███████▋ | 361/469 [11:41<03:37, 2.02s/it]
|
| 591 |
77%|███████▋ | 362/469 [12:00<13:03, 7.32s/it]
|
| 592 |
77%|███████▋ | 363/469 [12:02<10:03, 5.69s/it]
|
| 593 |
78%|███████▊ | 364/469 [12:04<07:55, 4.53s/it]
|
| 594 |
78%|███████▊ | 365/469 [12:06<06:27, 3.73s/it]
|
| 595 |
78%|███████▊ | 366/469 [12:08<05:26, 3.17s/it]
|
| 596 |
78%|███████▊ | 367/469 [12:09<04:40, 2.75s/it]
|
| 597 |
78%|███████▊ | 368/469 [12:12<04:21, 2.59s/it]
|
| 598 |
79%|███████▊ | 369/469 [12:13<03:55, 2.36s/it]
|
| 599 |
79%|███████▉ | 370/469 [12:15<03:43, 2.25s/it]
|
| 600 |
79%|███████▉ | 371/469 [12:17<03:28, 2.13s/it]
|
| 601 |
79%|███████▉ | 372/469 [12:19<03:21, 2.07s/it]
|
| 602 |
80%|███████▉ | 373/469 [12:21<03:10, 1.99s/it]
|
| 603 |
80%|███████▉ | 374/469 [12:23<03:02, 1.92s/it]
|
| 604 |
80%|███████▉ | 375/469 [12:25<02:55, 1.86s/it]
|
| 605 |
80%|████████ | 376/469 [12:26<02:51, 1.84s/it]
|
| 606 |
80%|████████ | 377/469 [12:28<02:51, 1.86s/it]
|
| 607 |
81%|████████ | 378/469 [12:30<02:47, 1.84s/it]
|
| 608 |
81%|████████ | 379/469 [12:32<02:43, 1.81s/it]
|
| 609 |
81%|████████ | 380/469 [12:34<02:48, 1.89s/it]
|
| 610 |
81%|████████ | 381/469 [12:36<02:45, 1.88s/it]
|
| 611 |
81%|████████▏ | 382/469 [12:38<02:43, 1.88s/it]
|
| 612 |
82%|████████▏ | 383/469 [12:39<02:38, 1.85s/it]
|
| 613 |
82%|████████▏ | 384/469 [12:41<02:36, 1.84s/it]
|
| 614 |
82%|████████▏ | 385/469 [12:43<02:37, 1.88s/it]
|
| 615 |
82%|████████▏ | 386/469 [12:45<02:35, 1.87s/it]
|
| 616 |
83%|████████▎ | 387/469 [12:47<02:45, 2.02s/it]
|
| 617 |
83%|████████▎ | 388/469 [12:49<02:37, 1.94s/it]
|
| 618 |
83%|████████▎ | 389/469 [12:51<02:30, 1.88s/it]
|
| 619 |
83%|████████▎ | 390/469 [12:53<02:32, 1.93s/it]
|
| 620 |
83%|████████▎ | 391/469 [12:55<02:29, 1.92s/it]
|
| 621 |
84%|████████▎ | 392/469 [12:57<02:23, 1.87s/it]
|
| 622 |
84%|████████▍ | 393/469 [13:00<03:00, 2.38s/it]
|
| 623 |
84%|████████▍ | 394/469 [13:02<02:47, 2.23s/it]
|
| 624 |
84%|████████▍ | 395/469 [13:04<02:36, 2.12s/it]
|
| 625 |
84%|████████▍ | 396/469 [13:06<02:27, 2.02s/it]
|
| 626 |
85%|████████▍ | 397/469 [13:08<02:24, 2.00s/it]
|
| 627 |
85%|████████▍ | 398/469 [13:09<02:19, 1.96s/it]
|
| 628 |
85%|████████▌ | 399/469 [13:11<02:14, 1.92s/it]
|
| 629 |
85%|████████▌ | 400/469 [13:13<02:08, 1.86s/it]
|
| 630 |
86%|████████▌ | 401/469 [13:15<02:05, 1.85s/it]
|
| 631 |
86%|████████▌ | 402/469 [13:17<02:02, 1.84s/it]
|
| 632 |
86%|████████▌ | 403/469 [13:18<01:59, 1.81s/it]
|
| 633 |
86%|████████▌ | 404/469 [13:20<02:02, 1.89s/it]
|
| 634 |
86%|████████▋ | 405/469 [13:22<01:59, 1.87s/it]
|
| 635 |
87%|████████▋ | 406/469 [13:24<01:56, 1.85s/it]
|
| 636 |
87%|████████▋ | 407/469 [13:26<01:52, 1.81s/it]
|
| 637 |
87%|████████▋ | 408/469 [13:28<01:50, 1.81s/it]
|
| 638 |
87%|████████▋ | 409/469 [13:30<01:50, 1.84s/it]
|
| 639 |
87%|████████▋ | 410/469 [13:32<01:55, 1.95s/it]
|
| 640 |
88%|████████▊ | 411/469 [13:33<01:49, 1.89s/it]
|
| 641 |
88%|████████▊ | 412/469 [13:35<01:48, 1.90s/it]
|
| 642 |
88%|████████▊ | 413/469 [13:37<01:45, 1.89s/it]
|
| 643 |
88%|████████▊ | 414/469 [13:39<01:42, 1.86s/it]
|
| 644 |
88%|████████▊ | 415/469 [13:41<01:38, 1.82s/it]
|
| 645 |
89%|████████▊ | 416/469 [13:43<01:39, 1.87s/it]
|
| 646 |
89%|████████▉ | 417/469 [13:45<01:38, 1.90s/it]
|
| 647 |
89%|████████▉ | 418/469 [13:50<02:30, 2.96s/it]
|
| 648 |
89%|████████▉ | 419/469 [13:52<02:12, 2.65s/it]
|
| 649 |
90%|████████▉ | 420/469 [13:54<02:01, 2.47s/it]
|
| 650 |
90%|████████▉ | 421/469 [13:56<01:48, 2.26s/it]
|
| 651 |
90%|████████▉ | 422/469 [13:58<01:39, 2.12s/it]
|
| 652 |
90%|█████████ | 423/469 [14:00<01:33, 2.02s/it]
|
| 653 |
90%|█████████ | 424/469 [14:01<01:29, 2.00s/it]
|
| 654 |
91%|█████████ | 425/469 [14:03<01:24, 1.93s/it]
|
| 655 |
91%|█████████ | 426/469 [14:05<01:22, 1.93s/it]
|
| 656 |
91%|█████████ | 427/469 [14:07<01:19, 1.90s/it]
|
| 657 |
91%|█████████▏| 428/469 [14:10<01:30, 2.21s/it]
|
| 658 |
91%|█████████▏| 429/469 [14:12<01:23, 2.09s/it]
|
| 659 |
92%|█████████▏| 430/469 [14:13<01:17, 2.00s/it]
|
| 660 |
92%|█████████▏| 431/469 [14:16<01:17, 2.05s/it]
|
| 661 |
92%|█████████▏| 432/469 [14:17<01:12, 1.96s/it]
|
| 662 |
92%|█████████▏| 433/469 [14:19<01:09, 1.93s/it]
|
| 663 |
93%|█████████▎| 434/469 [14:21<01:06, 1.89s/it]
|
| 664 |
93%|█████████▎| 435/469 [14:23<01:03, 1.88s/it]
|
| 665 |
93%|█████████▎| 436/469 [14:25<01:06, 2.01s/it]
|
| 666 |
93%|█████████▎| 437/469 [14:27<01:02, 1.95s/it]
|
| 667 |
93%|█████████▎| 438/469 [14:29<00:59, 1.93s/it]
|
| 668 |
94%|█████████▎| 439/469 [14:31<00:57, 1.91s/it]
|
| 669 |
94%|█████████▍| 440/469 [14:33<00:59, 2.04s/it]
|
| 670 |
94%|█████████▍| 441/469 [14:35<00:55, 1.99s/it]
|
| 671 |
94%|█████████▍| 442/469 [14:37<00:53, 1.97s/it]
|
| 672 |
94%|█████████▍| 443/469 [14:39<00:49, 1.92s/it]
|
| 673 |
95%|█████████▍| 444/469 [14:41<00:47, 1.89s/it]
|
| 674 |
95%|█████████▍| 445/469 [14:42<00:44, 1.87s/it]
|
| 675 |
95%|█████████▌| 446/469 [14:44<00:44, 1.94s/it]
|
| 676 |
95%|█████████▌| 447/469 [14:46<00:41, 1.88s/it]
|
| 677 |
96%|█████████▌| 448/469 [14:48<00:38, 1.83s/it]
|
| 678 |
96%|█████████▌| 449/469 [14:50<00:37, 1.89s/it]
|
| 679 |
96%|█████████▌| 450/469 [14:52<00:34, 1.84s/it]
|
| 680 |
96%|█████████▌| 451/469 [14:53<00:32, 1.81s/it]
|
| 681 |
96%|█████████▋| 452/469 [14:55<00:30, 1.80s/it]
|
| 682 |
97%|█████████▋| 453/469 [14:57<00:28, 1.77s/it]
|
| 683 |
97%|█████████▋| 454/469 [14:59<00:27, 1.85s/it]
|
| 684 |
97%|█████████▋| 455/469 [15:01<00:25, 1.84s/it]
|
| 685 |
97%|█████████▋| 456/469 [15:03<00:24, 1.92s/it]
|
| 686 |
97%|█████████▋| 457/469 [15:05<00:22, 1.88s/it]
|
| 687 |
98%|█████████▊| 458/469 [15:07<00:20, 1.88s/it]
|
| 688 |
98%|█████████▊| 459/469 [15:09<00:19, 1.92s/it]
|
| 689 |
98%|█████████▊| 460/469 [15:10<00:17, 1.92s/it]
|
| 690 |
98%|█████████▊| 461/469 [15:12<00:15, 1.90s/it]
|
| 691 |
99%|█████████▊| 462/469 [15:15<00:14, 2.11s/it]
|
| 692 |
99%|█████████▊| 463/469 [15:17<00:13, 2.18s/it]
|
| 693 |
99%|█████████▉| 464/469 [15:19<00:10, 2.09s/it]
|
| 694 |
99%|█████████▉| 465/469 [15:21<00:08, 2.07s/it]
|
| 695 |
99%|█████████▉| 466/469 [15:23<00:06, 2.02s/it]
|
| 696 |
+
computing/reading sample batch statistics...
|
| 697 |
+
Computing evaluations...
|
| 698 |
+
Inception Score: 37.95753860473633
|
| 699 |
+
FID: 21.04736987276152
|
| 700 |
+
sFID: 71.53442942455422
|
| 701 |
+
Precision: 0.6907333333333333
|
| 702 |
+
Recall: 0.35639366212898904
|
evaluate.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
REF_BATCH="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz"
|
| 4 |
+
CUDA_VISIBLE_DEVICES=1 nohup python evaluator_rf.py \
|
| 5 |
+
--ref_batch ${REF_BATCH} \
|
| 6 |
+
--sample_batch /gemini/space/gzy_new/models/Sida/sd3_rectified_samples_new_batch_2.npz \
|
| 7 |
+
> eval_rectified_noise_new_batch_2.log 2>&1 &
|
| 8 |
+
# CUDA_VISIBLE_DEVICES=0 nohup python evaluator_rf.py \
|
| 9 |
+
# --ref_batch ${REF_BATCH} \
|
| 10 |
+
# --sample_batch "/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512.npz" \
|
| 11 |
+
# > eval_baseline.log 2>&1 &
|
evaluator_base copy.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import zipfile
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import partial
|
| 10 |
+
from multiprocessing import cpu_count
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from typing import Iterable, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import tensorflow.compat.v1 as tf
|
| 17 |
+
from scipy import linalg
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 21 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 22 |
+
|
| 23 |
+
FID_POOL_NAME = "pool_3:0"
|
| 24 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
#/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
|
| 30 |
+
parser.add_argument("--ref_batch", default='/gemini/space/dataset/coco/coco_train_3w.npz',help="path to reference batch npz file")
|
| 31 |
+
parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_lora_samples/batch32-rank64-last-sd3-sd3-lora-finetuned-batch-32-guidance-7.0-steps-20-size-512x512.npz', help="path to sample batch npz file")
|
| 32 |
+
parser.add_argument("--save_path", default='/gemini/space/gzy/w_w_last/w_w_sit_last1/temp/',help="path to sample batch npz file")
|
| 33 |
+
parser.add_argument("--cfg_cond", default=1, type=int)
|
| 34 |
+
parser.add_argument("--step", default=1, type=int)
|
| 35 |
+
parser.add_argument("--cfg", default=1.0, type=float)
|
| 36 |
+
parser.add_argument("--cls_cfg", default=1.0, type=float)
|
| 37 |
+
parser.add_argument("--gh", default=1.0, type=float)
|
| 38 |
+
parser.add_argument("--num_steps", default=250, type=int)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
if not os.path.exists(args.save_path):
|
| 42 |
+
os.mkdir(args.save_path)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
config = tf.ConfigProto(
|
| 46 |
+
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
|
| 47 |
+
)
|
| 48 |
+
config.gpu_options.allow_growth = True
|
| 49 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 50 |
+
|
| 51 |
+
print("warming up TensorFlow...")
|
| 52 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 53 |
+
# than after the next print(), to help prevent confusion.
|
| 54 |
+
evaluator.warmup()
|
| 55 |
+
|
| 56 |
+
print("computing reference batch activations...")
|
| 57 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 58 |
+
print("computing/reading reference batch statistics...")
|
| 59 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 60 |
+
|
| 61 |
+
print("computing sample batch activations...")
|
| 62 |
+
sample_acts = evaluator.read_activations(args.sample_batch)
|
| 63 |
+
print("computing/reading sample batch statistics...")
|
| 64 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
|
| 65 |
+
|
| 66 |
+
print("Computing evaluations...")
|
| 67 |
+
Inception_Score = evaluator.compute_inception_score(sample_acts[0])
|
| 68 |
+
FID = sample_stats.frechet_distance(ref_stats)
|
| 69 |
+
sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
|
| 70 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 71 |
+
|
| 72 |
+
print("Inception Score:", Inception_Score)
|
| 73 |
+
print("FID:", FID)
|
| 74 |
+
print("sFID:", sFID)
|
| 75 |
+
print("Precision:", prec)
|
| 76 |
+
print("Recall:", recall)
|
| 77 |
+
|
| 78 |
+
if args.cfg_cond:
|
| 79 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
|
| 80 |
+
else:
|
| 81 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
|
| 82 |
+
with open(file_path, "w") as file:
|
| 83 |
+
file.write("Inception Score: {}\n".format(Inception_Score))
|
| 84 |
+
file.write("FID: {}\n".format(FID))
|
| 85 |
+
file.write("sFID: {}\n".format(sFID))
|
| 86 |
+
file.write("Precision: {}\n".format(prec))
|
| 87 |
+
file.write("Recall: {}\n".format(recall))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InvalidFIDException(Exception):
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class FIDStatistics:
|
| 95 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 96 |
+
self.mu = mu
|
| 97 |
+
self.sigma = sigma
|
| 98 |
+
|
| 99 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 100 |
+
"""
|
| 101 |
+
Compute the Frechet distance between two sets of statistics.
|
| 102 |
+
"""
|
| 103 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 104 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 105 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 106 |
+
|
| 107 |
+
mu1 = np.atleast_1d(mu1)
|
| 108 |
+
mu2 = np.atleast_1d(mu2)
|
| 109 |
+
|
| 110 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 111 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 112 |
+
|
| 113 |
+
assert (
|
| 114 |
+
mu1.shape == mu2.shape
|
| 115 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 116 |
+
assert (
|
| 117 |
+
sigma1.shape == sigma2.shape
|
| 118 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 119 |
+
|
| 120 |
+
diff = mu1 - mu2
|
| 121 |
+
|
| 122 |
+
# product might be almost singular
|
| 123 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 124 |
+
if not np.isfinite(covmean).all():
|
| 125 |
+
msg = (
|
| 126 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 127 |
+
% eps
|
| 128 |
+
)
|
| 129 |
+
warnings.warn(msg)
|
| 130 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 131 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 132 |
+
|
| 133 |
+
# numerical error might give slight imaginary component
|
| 134 |
+
if np.iscomplexobj(covmean):
|
| 135 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 136 |
+
m = np.max(np.abs(covmean.imag))
|
| 137 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 138 |
+
covmean = covmean.real
|
| 139 |
+
|
| 140 |
+
tr_covmean = np.trace(covmean)
|
| 141 |
+
|
| 142 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Evaluator:
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
session,
|
| 149 |
+
batch_size=64,
|
| 150 |
+
softmax_batch_size=512,
|
| 151 |
+
):
|
| 152 |
+
self.sess = session
|
| 153 |
+
self.batch_size = batch_size
|
| 154 |
+
self.softmax_batch_size = softmax_batch_size
|
| 155 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 156 |
+
with self.sess.graph.as_default():
|
| 157 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 158 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 159 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 160 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 161 |
+
|
| 162 |
+
def warmup(self):
|
| 163 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 164 |
+
|
| 165 |
+
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
| 166 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 167 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 168 |
+
|
| 169 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 170 |
+
"""
|
| 171 |
+
Compute image features for downstream evals.
|
| 172 |
+
|
| 173 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 174 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 175 |
+
dimension. The tuple is (pool_3, spatial).
|
| 176 |
+
"""
|
| 177 |
+
preds = []
|
| 178 |
+
spatial_preds = []
|
| 179 |
+
for batch in tqdm(batches):
|
| 180 |
+
batch = batch.astype(np.float32)
|
| 181 |
+
pred, spatial_pred = self.sess.run(
|
| 182 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 183 |
+
)
|
| 184 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 185 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 186 |
+
return (
|
| 187 |
+
np.concatenate(preds, axis=0),
|
| 188 |
+
np.concatenate(spatial_preds, axis=0),
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def read_statistics(
|
| 192 |
+
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
| 193 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 194 |
+
obj = np.load(npz_path)
|
| 195 |
+
if "mu" in list(obj.keys()):
|
| 196 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 197 |
+
obj["mu_s"], obj["sigma_s"]
|
| 198 |
+
)
|
| 199 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 200 |
+
|
| 201 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 202 |
+
mu = np.mean(activations, axis=0)
|
| 203 |
+
sigma = np.cov(activations, rowvar=False)
|
| 204 |
+
return FIDStatistics(mu, sigma)
|
| 205 |
+
|
| 206 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 207 |
+
softmax_out = []
|
| 208 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 209 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 210 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 211 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 212 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 213 |
+
scores = []
|
| 214 |
+
for i in range(0, len(preds), split_size):
|
| 215 |
+
part = preds[i : i + split_size]
|
| 216 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 217 |
+
kl = np.mean(np.sum(kl, 1))
|
| 218 |
+
scores.append(np.exp(kl))
|
| 219 |
+
return float(np.mean(scores))
|
| 220 |
+
|
| 221 |
+
def compute_prec_recall(
|
| 222 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 223 |
+
) -> Tuple[float, float]:
|
| 224 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 225 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 226 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 227 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 228 |
+
)
|
| 229 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class ManifoldEstimator:
|
| 233 |
+
"""
|
| 234 |
+
A helper for comparing manifolds of feature vectors.
|
| 235 |
+
|
| 236 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
session,
|
| 242 |
+
row_batch_size=10000,
|
| 243 |
+
col_batch_size=10000,
|
| 244 |
+
nhood_sizes=(3,),
|
| 245 |
+
clamp_to_percentile=None,
|
| 246 |
+
eps=1e-5,
|
| 247 |
+
):
|
| 248 |
+
"""
|
| 249 |
+
Estimate the manifold of given feature vectors.
|
| 250 |
+
|
| 251 |
+
:param session: the TensorFlow session.
|
| 252 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 253 |
+
(parameter to trade-off between memory usage and performance).
|
| 254 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 255 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 256 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 257 |
+
the given percentile.
|
| 258 |
+
:param eps: small number for numerical stability.
|
| 259 |
+
"""
|
| 260 |
+
self.distance_block = DistanceBlock(session)
|
| 261 |
+
self.row_batch_size = row_batch_size
|
| 262 |
+
self.col_batch_size = col_batch_size
|
| 263 |
+
self.nhood_sizes = nhood_sizes
|
| 264 |
+
self.num_nhoods = len(nhood_sizes)
|
| 265 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 266 |
+
self.eps = eps
|
| 267 |
+
|
| 268 |
+
def warmup(self):
|
| 269 |
+
feats, radii = (
|
| 270 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 271 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 272 |
+
)
|
| 273 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 274 |
+
|
| 275 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 276 |
+
num_images = len(features)
|
| 277 |
+
|
| 278 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 279 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 280 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 281 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 282 |
+
|
| 283 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 284 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 285 |
+
row_batch = features[begin1:end1]
|
| 286 |
+
|
| 287 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 288 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 289 |
+
col_batch = features[begin2:end2]
|
| 290 |
+
|
| 291 |
+
# Compute distances between batches.
|
| 292 |
+
distance_batch[
|
| 293 |
+
0 : end1 - begin1, begin2:end2
|
| 294 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 295 |
+
|
| 296 |
+
# Find the k-nearest neighbor from the current batch.
|
| 297 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 298 |
+
[
|
| 299 |
+
x[:, self.nhood_sizes]
|
| 300 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 301 |
+
],
|
| 302 |
+
axis=0,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
if self.clamp_to_percentile is not None:
|
| 306 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 307 |
+
radii[radii > max_distances] = 0
|
| 308 |
+
return radii
|
| 309 |
+
|
| 310 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 311 |
+
"""
|
| 312 |
+
Evaluate if new feature vectors are at the manifold.
|
| 313 |
+
"""
|
| 314 |
+
num_eval_images = eval_features.shape[0]
|
| 315 |
+
num_ref_images = radii.shape[0]
|
| 316 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 317 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 318 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 319 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 320 |
+
|
| 321 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 322 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 323 |
+
feature_batch = eval_features[begin1:end1]
|
| 324 |
+
|
| 325 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 326 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 327 |
+
ref_batch = features[begin2:end2]
|
| 328 |
+
|
| 329 |
+
distance_batch[
|
| 330 |
+
0 : end1 - begin1, begin2:end2
|
| 331 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 332 |
+
|
| 333 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 334 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 335 |
+
# the new sample lies at the estimated manifold.
|
| 336 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 337 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 338 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 339 |
+
|
| 340 |
+
max_realism_score[begin1:end1] = np.max(
|
| 341 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 342 |
+
)
|
| 343 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 344 |
+
|
| 345 |
+
return {
|
| 346 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 347 |
+
"batch_predictions": batch_predictions,
|
| 348 |
+
"max_realisim_score": max_realism_score,
|
| 349 |
+
"nearest_indices": nearest_indices,
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
def evaluate_pr(
|
| 353 |
+
self,
|
| 354 |
+
features_1: np.ndarray,
|
| 355 |
+
radii_1: np.ndarray,
|
| 356 |
+
features_2: np.ndarray,
|
| 357 |
+
radii_2: np.ndarray,
|
| 358 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 359 |
+
"""
|
| 360 |
+
Evaluate precision and recall efficiently.
|
| 361 |
+
|
| 362 |
+
:param features_1: [N1 x D] feature vectors for reference batch.
|
| 363 |
+
:param radii_1: [N1 x K1] radii for reference vectors.
|
| 364 |
+
:param features_2: [N2 x D] feature vectors for the other batch.
|
| 365 |
+
:param radii_2: [N x K2] radii for other vectors.
|
| 366 |
+
:return: a tuple of arrays for (precision, recall):
|
| 367 |
+
- precision: an np.ndarray of length K1
|
| 368 |
+
- recall: an np.ndarray of length K2
|
| 369 |
+
"""
|
| 370 |
+
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
|
| 371 |
+
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
|
| 372 |
+
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
| 373 |
+
end_1 = begin_1 + self.row_batch_size
|
| 374 |
+
batch_1 = features_1[begin_1:end_1]
|
| 375 |
+
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
| 376 |
+
end_2 = begin_2 + self.col_batch_size
|
| 377 |
+
batch_2 = features_2[begin_2:end_2]
|
| 378 |
+
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
| 379 |
+
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
| 380 |
+
)
|
| 381 |
+
features_1_status[begin_1:end_1] |= batch_1_in
|
| 382 |
+
features_2_status[begin_2:end_2] |= batch_2_in
|
| 383 |
+
return (
|
| 384 |
+
np.mean(features_2_status.astype(np.float64), axis=0),
|
| 385 |
+
np.mean(features_1_status.astype(np.float64), axis=0),
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class DistanceBlock:
|
| 390 |
+
"""
|
| 391 |
+
Calculate pairwise distances between vectors.
|
| 392 |
+
|
| 393 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
def __init__(self, session):
|
| 397 |
+
self.session = session
|
| 398 |
+
|
| 399 |
+
# Initialize TF graph to calculate pairwise distances.
|
| 400 |
+
with session.graph.as_default():
|
| 401 |
+
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 402 |
+
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 403 |
+
distance_block_16 = _batch_pairwise_distances(
|
| 404 |
+
tf.cast(self._features_batch1, tf.float16),
|
| 405 |
+
tf.cast(self._features_batch2, tf.float16),
|
| 406 |
+
)
|
| 407 |
+
self.distance_block = tf.cond(
|
| 408 |
+
tf.reduce_all(tf.math.is_finite(distance_block_16)),
|
| 409 |
+
lambda: tf.cast(distance_block_16, tf.float32),
|
| 410 |
+
lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Extra logic for less thans.
|
| 414 |
+
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 415 |
+
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 416 |
+
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
|
| 417 |
+
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
|
| 418 |
+
self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
|
| 419 |
+
|
| 420 |
+
def pairwise_distances(self, U, V):
|
| 421 |
+
"""
|
| 422 |
+
Evaluate pairwise distances between two batches of feature vectors.
|
| 423 |
+
"""
|
| 424 |
+
return self.session.run(
|
| 425 |
+
self.distance_block,
|
| 426 |
+
feed_dict={self._features_batch1: U, self._features_batch2: V},
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
|
| 430 |
+
return self.session.run(
|
| 431 |
+
[self._batch_1_in, self._batch_2_in],
|
| 432 |
+
feed_dict={
|
| 433 |
+
self._features_batch1: batch_1,
|
| 434 |
+
self._features_batch2: batch_2,
|
| 435 |
+
self._radii1: radii_1,
|
| 436 |
+
self._radii2: radii_2,
|
| 437 |
+
},
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def _batch_pairwise_distances(U, V):
|
| 442 |
+
"""
|
| 443 |
+
Compute pairwise distances between two batches of feature vectors.
|
| 444 |
+
"""
|
| 445 |
+
with tf.variable_scope("pairwise_dist_block"):
|
| 446 |
+
# Squared norms of each row in U and V.
|
| 447 |
+
norm_u = tf.reduce_sum(tf.square(U), 1)
|
| 448 |
+
norm_v = tf.reduce_sum(tf.square(V), 1)
|
| 449 |
+
|
| 450 |
+
# norm_u as a column and norm_v as a row vectors.
|
| 451 |
+
norm_u = tf.reshape(norm_u, [-1, 1])
|
| 452 |
+
norm_v = tf.reshape(norm_v, [1, -1])
|
| 453 |
+
|
| 454 |
+
# Pairwise squared Euclidean distances.
|
| 455 |
+
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
|
| 456 |
+
|
| 457 |
+
return D
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class NpzArrayReader(ABC):
|
| 461 |
+
@abstractmethod
|
| 462 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 463 |
+
pass
|
| 464 |
+
|
| 465 |
+
@abstractmethod
|
| 466 |
+
def remaining(self) -> int:
|
| 467 |
+
pass
|
| 468 |
+
|
| 469 |
+
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
|
| 470 |
+
def gen_fn():
|
| 471 |
+
while True:
|
| 472 |
+
batch = self.read_batch(batch_size)
|
| 473 |
+
if batch is None:
|
| 474 |
+
break
|
| 475 |
+
yield batch
|
| 476 |
+
|
| 477 |
+
rem = self.remaining()
|
| 478 |
+
num_batches = rem // batch_size + int(rem % batch_size != 0)
|
| 479 |
+
return BatchIterator(gen_fn, num_batches)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class BatchIterator:
|
| 483 |
+
def __init__(self, gen_fn, length):
|
| 484 |
+
self.gen_fn = gen_fn
|
| 485 |
+
self.length = length
|
| 486 |
+
|
| 487 |
+
def __len__(self):
|
| 488 |
+
return self.length
|
| 489 |
+
|
| 490 |
+
def __iter__(self):
|
| 491 |
+
return self.gen_fn()
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class StreamingNpzArrayReader(NpzArrayReader):
|
| 495 |
+
def __init__(self, arr_f, shape, dtype):
|
| 496 |
+
self.arr_f = arr_f
|
| 497 |
+
self.shape = shape
|
| 498 |
+
self.dtype = dtype
|
| 499 |
+
self.idx = 0
|
| 500 |
+
|
| 501 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 502 |
+
if self.idx >= self.shape[0]:
|
| 503 |
+
return None
|
| 504 |
+
|
| 505 |
+
bs = min(batch_size, self.shape[0] - self.idx)
|
| 506 |
+
self.idx += bs
|
| 507 |
+
|
| 508 |
+
if self.dtype.itemsize == 0:
|
| 509 |
+
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
|
| 510 |
+
|
| 511 |
+
read_count = bs * np.prod(self.shape[1:])
|
| 512 |
+
read_size = int(read_count * self.dtype.itemsize)
|
| 513 |
+
data = _read_bytes(self.arr_f, read_size, "array data")
|
| 514 |
+
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
| 515 |
+
|
| 516 |
+
def remaining(self) -> int:
|
| 517 |
+
return max(0, self.shape[0] - self.idx)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class MemoryNpzArrayReader(NpzArrayReader):
|
| 521 |
+
def __init__(self, arr):
|
| 522 |
+
self.arr = arr
|
| 523 |
+
self.idx = 0
|
| 524 |
+
|
| 525 |
+
@classmethod
|
| 526 |
+
def load(cls, path: str, arr_name: str):
|
| 527 |
+
with open(path, "rb") as f:
|
| 528 |
+
arr = np.load(f)[arr_name]
|
| 529 |
+
return cls(arr)
|
| 530 |
+
|
| 531 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 532 |
+
if self.idx >= self.arr.shape[0]:
|
| 533 |
+
return None
|
| 534 |
+
|
| 535 |
+
res = self.arr[self.idx : self.idx + batch_size]
|
| 536 |
+
self.idx += batch_size
|
| 537 |
+
return res
|
| 538 |
+
|
| 539 |
+
def remaining(self) -> int:
|
| 540 |
+
return max(0, self.arr.shape[0] - self.idx)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@contextmanager
|
| 544 |
+
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
|
| 545 |
+
with _open_npy_file(path, arr_name) as arr_f:
|
| 546 |
+
version = np.lib.format.read_magic(arr_f)
|
| 547 |
+
if version == (1, 0):
|
| 548 |
+
header = np.lib.format.read_array_header_1_0(arr_f)
|
| 549 |
+
elif version == (2, 0):
|
| 550 |
+
header = np.lib.format.read_array_header_2_0(arr_f)
|
| 551 |
+
else:
|
| 552 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 553 |
+
return
|
| 554 |
+
shape, fortran, dtype = header
|
| 555 |
+
if fortran or dtype.hasobject:
|
| 556 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 557 |
+
else:
|
| 558 |
+
yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 562 |
+
"""
|
| 563 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 564 |
+
|
| 565 |
+
Read from file-like object until size bytes are read.
|
| 566 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 567 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 568 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 569 |
+
requested.
|
| 570 |
+
"""
|
| 571 |
+
data = bytes()
|
| 572 |
+
while True:
|
| 573 |
+
# io files (default in python3) return None or raise on
|
| 574 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 575 |
+
# done about that. note that regular files can't be non-blocking
|
| 576 |
+
try:
|
| 577 |
+
r = fp.read(size - len(data))
|
| 578 |
+
data += r
|
| 579 |
+
if len(r) == 0 or len(data) == size:
|
| 580 |
+
break
|
| 581 |
+
except io.BlockingIOError:
|
| 582 |
+
pass
|
| 583 |
+
if len(data) != size:
|
| 584 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 585 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 586 |
+
else:
|
| 587 |
+
return data
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
@contextmanager
|
| 591 |
+
def _open_npy_file(path: str, arr_name: str):
|
| 592 |
+
with open(path, "rb") as f:
|
| 593 |
+
with zipfile.ZipFile(f, "r") as zip_f:
|
| 594 |
+
if f"{arr_name}.npy" not in zip_f.namelist():
|
| 595 |
+
raise ValueError(f"missing {arr_name} in npz file")
|
| 596 |
+
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
| 597 |
+
yield arr_f
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def _download_inception_model():
|
| 601 |
+
if os.path.exists(INCEPTION_V3_PATH):
|
| 602 |
+
return
|
| 603 |
+
print("downloading InceptionV3 model...")
|
| 604 |
+
with requests.get(INCEPTION_V3_URL, stream=True) as r:
|
| 605 |
+
r.raise_for_status()
|
| 606 |
+
tmp_path = INCEPTION_V3_PATH + ".tmp"
|
| 607 |
+
with open(tmp_path, "wb") as f:
|
| 608 |
+
for chunk in tqdm(r.iter_content(chunk_size=8192)):
|
| 609 |
+
f.write(chunk)
|
| 610 |
+
os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def _create_feature_graph(input_batch):
|
| 614 |
+
_download_inception_model()
|
| 615 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 616 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 617 |
+
graph_def = tf.GraphDef()
|
| 618 |
+
graph_def.ParseFromString(f.read())
|
| 619 |
+
pool3, spatial = tf.import_graph_def(
|
| 620 |
+
graph_def,
|
| 621 |
+
input_map={f"ExpandDims:0": input_batch},
|
| 622 |
+
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
|
| 623 |
+
name=prefix,
|
| 624 |
+
)
|
| 625 |
+
_update_shapes(pool3)
|
| 626 |
+
spatial = spatial[..., :7]
|
| 627 |
+
return pool3, spatial
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def _create_softmax_graph(input_batch):
|
| 631 |
+
_download_inception_model()
|
| 632 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 633 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 634 |
+
graph_def = tf.GraphDef()
|
| 635 |
+
graph_def.ParseFromString(f.read())
|
| 636 |
+
(matmul,) = tf.import_graph_def(
|
| 637 |
+
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
|
| 638 |
+
)
|
| 639 |
+
w = matmul.inputs[1]
|
| 640 |
+
logits = tf.matmul(input_batch, w)
|
| 641 |
+
return tf.nn.softmax(logits)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def _update_shapes(pool3):
|
| 645 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
|
| 646 |
+
ops = pool3.graph.get_operations()
|
| 647 |
+
for op in ops:
|
| 648 |
+
for o in op.outputs:
|
| 649 |
+
shape = o.get_shape()
|
| 650 |
+
if shape._dims is not None: # pylint: disable=protected-access
|
| 651 |
+
# shape = [s.value for s in shape] TF 1.x
|
| 652 |
+
shape = [s for s in shape] # TF 2.x
|
| 653 |
+
new_shape = []
|
| 654 |
+
for j, s in enumerate(shape):
|
| 655 |
+
if s == 1 and j == 0:
|
| 656 |
+
new_shape.append(None)
|
| 657 |
+
else:
|
| 658 |
+
new_shape.append(s)
|
| 659 |
+
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
| 660 |
+
return pool3
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 664 |
+
num_workers = min(cpu_count(), len(arr))
|
| 665 |
+
chunk_size = len(arr) // num_workers
|
| 666 |
+
extra = len(arr) % num_workers
|
| 667 |
+
|
| 668 |
+
start_idx = 0
|
| 669 |
+
batches = []
|
| 670 |
+
for i in range(num_workers):
|
| 671 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 672 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 673 |
+
start_idx += size
|
| 674 |
+
|
| 675 |
+
with ThreadPool(num_workers) as pool:
|
| 676 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
if __name__ == "__main__":
|
| 680 |
+
main()
|
evaluator_base.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
Traceback (most recent call last):
|
| 3 |
+
File "/gemini/space/gzy_new/models/Sida/evaluator_base.py", line 16, in <module>
|
| 4 |
+
import tensorflow.compat.v1 as tf
|
| 5 |
+
ModuleNotFoundError: No module named 'tensorflow'
|
evaluator_base.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import zipfile
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import partial
|
| 10 |
+
from multiprocessing import cpu_count
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from typing import Iterable, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import tensorflow.compat.v1 as tf
|
| 17 |
+
from scipy import linalg
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 21 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 22 |
+
|
| 23 |
+
FID_POOL_NAME = "pool_3:0"
|
| 24 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
#/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
|
| 30 |
+
parser.add_argument("--ref_batch", default='/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz',help="path to reference batch npz file")
|
| 31 |
+
parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512.npz', help="path to sample batch npz file")
|
| 32 |
+
parser.add_argument("--save_path", default='/gemini/space/gzy_new/models/Sida/sd3_lora_samples_3w/checkpoint-checkpoint-500000-rank32-guidance-7.0-steps-40-size-512x512/result',help="path to sample batch npz file")
|
| 33 |
+
parser.add_argument("--cfg_cond", default=1, type=int)
|
| 34 |
+
parser.add_argument("--step", default=1, type=int)
|
| 35 |
+
parser.add_argument("--cfg", default=1.0, type=float)
|
| 36 |
+
parser.add_argument("--cls_cfg", default=1.0, type=float)
|
| 37 |
+
parser.add_argument("--gh", default=1.0, type=float)
|
| 38 |
+
parser.add_argument("--num_steps", default=50, type=int)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
if not os.path.exists(args.save_path):
|
| 42 |
+
os.mkdir(args.save_path)
|
| 43 |
+
|
| 44 |
+
# NOTE: 当前环境中 TensorFlow 与 CUDA/cuDNN 可能版本不匹配(例如报 "No DNN in stream executor"),
|
| 45 |
+
# 这会导致 GPU 计算失败。这里强制使用 CPU 进行评估(会慢一些,但能保证运行)。
|
| 46 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 47 |
+
|
| 48 |
+
config = tf.ConfigProto(
|
| 49 |
+
allow_soft_placement=True, # allows DecodeJpeg to run on CPU in Inception graph
|
| 50 |
+
device_count={"GPU": 0},
|
| 51 |
+
)
|
| 52 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 53 |
+
|
| 54 |
+
print("warming up TensorFlow...")
|
| 55 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 56 |
+
# than after the next print(), to help prevent confusion.
|
| 57 |
+
evaluator.warmup()
|
| 58 |
+
|
| 59 |
+
print("computing reference batch activations...")
|
| 60 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 61 |
+
print("computing/reading reference batch statistics...")
|
| 62 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 63 |
+
|
| 64 |
+
print("computing sample batch activations...")
|
| 65 |
+
sample_acts = evaluator.read_activations(args.sample_batch)
|
| 66 |
+
print("computing/reading sample batch statistics...")
|
| 67 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
|
| 68 |
+
|
| 69 |
+
print("Computing evaluations...")
|
| 70 |
+
Inception_Score = evaluator.compute_inception_score(sample_acts[0])
|
| 71 |
+
FID = sample_stats.frechet_distance(ref_stats)
|
| 72 |
+
sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
|
| 73 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 74 |
+
|
| 75 |
+
print("Inception Score:", Inception_Score)
|
| 76 |
+
print("FID:", FID)
|
| 77 |
+
print("sFID:", sFID)
|
| 78 |
+
print("Precision:", prec)
|
| 79 |
+
print("Recall:", recall)
|
| 80 |
+
|
| 81 |
+
if args.cfg_cond:
|
| 82 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
|
| 83 |
+
else:
|
| 84 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
|
| 85 |
+
with open(file_path, "w") as file:
|
| 86 |
+
file.write("Inception Score: {}\n".format(Inception_Score))
|
| 87 |
+
file.write("FID: {}\n".format(FID))
|
| 88 |
+
file.write("sFID: {}\n".format(sFID))
|
| 89 |
+
file.write("Precision: {}\n".format(prec))
|
| 90 |
+
file.write("Recall: {}\n".format(recall))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class InvalidFIDException(Exception):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class FIDStatistics:
|
| 98 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 99 |
+
self.mu = mu
|
| 100 |
+
self.sigma = sigma
|
| 101 |
+
|
| 102 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 103 |
+
"""
|
| 104 |
+
Compute the Frechet distance between two sets of statistics.
|
| 105 |
+
"""
|
| 106 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 107 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 108 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 109 |
+
|
| 110 |
+
mu1 = np.atleast_1d(mu1)
|
| 111 |
+
mu2 = np.atleast_1d(mu2)
|
| 112 |
+
|
| 113 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 114 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 115 |
+
|
| 116 |
+
assert (
|
| 117 |
+
mu1.shape == mu2.shape
|
| 118 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 119 |
+
assert (
|
| 120 |
+
sigma1.shape == sigma2.shape
|
| 121 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 122 |
+
|
| 123 |
+
diff = mu1 - mu2
|
| 124 |
+
|
| 125 |
+
# product might be almost singular
|
| 126 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 127 |
+
if not np.isfinite(covmean).all():
|
| 128 |
+
msg = (
|
| 129 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 130 |
+
% eps
|
| 131 |
+
)
|
| 132 |
+
warnings.warn(msg)
|
| 133 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 134 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 135 |
+
|
| 136 |
+
# numerical error might give slight imaginary component
|
| 137 |
+
if np.iscomplexobj(covmean):
|
| 138 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 139 |
+
m = np.max(np.abs(covmean.imag))
|
| 140 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 141 |
+
covmean = covmean.real
|
| 142 |
+
|
| 143 |
+
tr_covmean = np.trace(covmean)
|
| 144 |
+
|
| 145 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Evaluator:
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
session,
|
| 152 |
+
batch_size=64,
|
| 153 |
+
softmax_batch_size=512,
|
| 154 |
+
):
|
| 155 |
+
self.sess = session
|
| 156 |
+
self.batch_size = batch_size
|
| 157 |
+
self.softmax_batch_size = softmax_batch_size
|
| 158 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 159 |
+
with self.sess.graph.as_default():
|
| 160 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 161 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 162 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 163 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 164 |
+
|
| 165 |
+
def warmup(self):
|
| 166 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 167 |
+
|
| 168 |
+
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
| 169 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 170 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 171 |
+
|
| 172 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 173 |
+
"""
|
| 174 |
+
Compute image features for downstream evals.
|
| 175 |
+
|
| 176 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 177 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 178 |
+
dimension. The tuple is (pool_3, spatial).
|
| 179 |
+
"""
|
| 180 |
+
preds = []
|
| 181 |
+
spatial_preds = []
|
| 182 |
+
for batch in tqdm(batches):
|
| 183 |
+
batch = batch.astype(np.float32)
|
| 184 |
+
pred, spatial_pred = self.sess.run(
|
| 185 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 186 |
+
)
|
| 187 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 188 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 189 |
+
return (
|
| 190 |
+
np.concatenate(preds, axis=0),
|
| 191 |
+
np.concatenate(spatial_preds, axis=0),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def read_statistics(
|
| 195 |
+
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
| 196 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 197 |
+
obj = np.load(npz_path)
|
| 198 |
+
if "mu" in list(obj.keys()):
|
| 199 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 200 |
+
obj["mu_s"], obj["sigma_s"]
|
| 201 |
+
)
|
| 202 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 203 |
+
|
| 204 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 205 |
+
mu = np.mean(activations, axis=0)
|
| 206 |
+
sigma = np.cov(activations, rowvar=False)
|
| 207 |
+
return FIDStatistics(mu, sigma)
|
| 208 |
+
|
| 209 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 210 |
+
softmax_out = []
|
| 211 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 212 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 213 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 214 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 215 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 216 |
+
scores = []
|
| 217 |
+
for i in range(0, len(preds), split_size):
|
| 218 |
+
part = preds[i : i + split_size]
|
| 219 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 220 |
+
kl = np.mean(np.sum(kl, 1))
|
| 221 |
+
scores.append(np.exp(kl))
|
| 222 |
+
return float(np.mean(scores))
|
| 223 |
+
|
| 224 |
+
def compute_prec_recall(
|
| 225 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 226 |
+
) -> Tuple[float, float]:
|
| 227 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 228 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 229 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 230 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 231 |
+
)
|
| 232 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class ManifoldEstimator:
|
| 236 |
+
"""
|
| 237 |
+
A helper for comparing manifolds of feature vectors.
|
| 238 |
+
|
| 239 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
session,
|
| 245 |
+
row_batch_size=10000,
|
| 246 |
+
col_batch_size=10000,
|
| 247 |
+
nhood_sizes=(3,),
|
| 248 |
+
clamp_to_percentile=None,
|
| 249 |
+
eps=1e-5,
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Estimate the manifold of given feature vectors.
|
| 253 |
+
|
| 254 |
+
:param session: the TensorFlow session.
|
| 255 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 256 |
+
(parameter to trade-off between memory usage and performance).
|
| 257 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 258 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 259 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 260 |
+
the given percentile.
|
| 261 |
+
:param eps: small number for numerical stability.
|
| 262 |
+
"""
|
| 263 |
+
self.distance_block = DistanceBlock(session)
|
| 264 |
+
self.row_batch_size = row_batch_size
|
| 265 |
+
self.col_batch_size = col_batch_size
|
| 266 |
+
self.nhood_sizes = nhood_sizes
|
| 267 |
+
self.num_nhoods = len(nhood_sizes)
|
| 268 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 269 |
+
self.eps = eps
|
| 270 |
+
|
| 271 |
+
def warmup(self):
|
| 272 |
+
feats, radii = (
|
| 273 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 274 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 275 |
+
)
|
| 276 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 277 |
+
|
| 278 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 279 |
+
num_images = len(features)
|
| 280 |
+
|
| 281 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 282 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 283 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 284 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 285 |
+
|
| 286 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 287 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 288 |
+
row_batch = features[begin1:end1]
|
| 289 |
+
|
| 290 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 291 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 292 |
+
col_batch = features[begin2:end2]
|
| 293 |
+
|
| 294 |
+
# Compute distances between batches.
|
| 295 |
+
distance_batch[
|
| 296 |
+
0 : end1 - begin1, begin2:end2
|
| 297 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 298 |
+
|
| 299 |
+
# Find the k-nearest neighbor from the current batch.
|
| 300 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 301 |
+
[
|
| 302 |
+
x[:, self.nhood_sizes]
|
| 303 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 304 |
+
],
|
| 305 |
+
axis=0,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if self.clamp_to_percentile is not None:
|
| 309 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 310 |
+
radii[radii > max_distances] = 0
|
| 311 |
+
return radii
|
| 312 |
+
|
| 313 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 314 |
+
"""
|
| 315 |
+
Evaluate if new feature vectors are at the manifold.
|
| 316 |
+
"""
|
| 317 |
+
num_eval_images = eval_features.shape[0]
|
| 318 |
+
num_ref_images = radii.shape[0]
|
| 319 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 320 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 321 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 322 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 323 |
+
|
| 324 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 325 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 326 |
+
feature_batch = eval_features[begin1:end1]
|
| 327 |
+
|
| 328 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 329 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 330 |
+
ref_batch = features[begin2:end2]
|
| 331 |
+
|
| 332 |
+
distance_batch[
|
| 333 |
+
0 : end1 - begin1, begin2:end2
|
| 334 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 335 |
+
|
| 336 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 337 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 338 |
+
# the new sample lies at the estimated manifold.
|
| 339 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 340 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 341 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 342 |
+
|
| 343 |
+
max_realism_score[begin1:end1] = np.max(
|
| 344 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 345 |
+
)
|
| 346 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 347 |
+
|
| 348 |
+
return {
|
| 349 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 350 |
+
"batch_predictions": batch_predictions,
|
| 351 |
+
"max_realisim_score": max_realism_score,
|
| 352 |
+
"nearest_indices": nearest_indices,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
def evaluate_pr(
|
| 356 |
+
self,
|
| 357 |
+
features_1: np.ndarray,
|
| 358 |
+
radii_1: np.ndarray,
|
| 359 |
+
features_2: np.ndarray,
|
| 360 |
+
radii_2: np.ndarray,
|
| 361 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 362 |
+
"""
|
| 363 |
+
Evaluate precision and recall efficiently.
|
| 364 |
+
|
| 365 |
+
:param features_1: [N1 x D] feature vectors for reference batch.
|
| 366 |
+
:param radii_1: [N1 x K1] radii for reference vectors.
|
| 367 |
+
:param features_2: [N2 x D] feature vectors for the other batch.
|
| 368 |
+
:param radii_2: [N x K2] radii for other vectors.
|
| 369 |
+
:return: a tuple of arrays for (precision, recall):
|
| 370 |
+
- precision: an np.ndarray of length K1
|
| 371 |
+
- recall: an np.ndarray of length K2
|
| 372 |
+
"""
|
| 373 |
+
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
|
| 374 |
+
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
|
| 375 |
+
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
| 376 |
+
end_1 = begin_1 + self.row_batch_size
|
| 377 |
+
batch_1 = features_1[begin_1:end_1]
|
| 378 |
+
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
| 379 |
+
end_2 = begin_2 + self.col_batch_size
|
| 380 |
+
batch_2 = features_2[begin_2:end_2]
|
| 381 |
+
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
| 382 |
+
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
| 383 |
+
)
|
| 384 |
+
features_1_status[begin_1:end_1] |= batch_1_in
|
| 385 |
+
features_2_status[begin_2:end_2] |= batch_2_in
|
| 386 |
+
return (
|
| 387 |
+
np.mean(features_2_status.astype(np.float64), axis=0),
|
| 388 |
+
np.mean(features_1_status.astype(np.float64), axis=0),
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class DistanceBlock:
|
| 393 |
+
"""
|
| 394 |
+
Calculate pairwise distances between vectors.
|
| 395 |
+
|
| 396 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
def __init__(self, session):
|
| 400 |
+
self.session = session
|
| 401 |
+
|
| 402 |
+
# Initialize TF graph to calculate pairwise distances.
|
| 403 |
+
with session.graph.as_default():
|
| 404 |
+
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 405 |
+
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 406 |
+
distance_block_16 = _batch_pairwise_distances(
|
| 407 |
+
tf.cast(self._features_batch1, tf.float16),
|
| 408 |
+
tf.cast(self._features_batch2, tf.float16),
|
| 409 |
+
)
|
| 410 |
+
self.distance_block = tf.cond(
|
| 411 |
+
tf.reduce_all(tf.math.is_finite(distance_block_16)),
|
| 412 |
+
lambda: tf.cast(distance_block_16, tf.float32),
|
| 413 |
+
lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Extra logic for less thans.
|
| 417 |
+
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 418 |
+
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 419 |
+
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
|
| 420 |
+
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
|
| 421 |
+
self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
|
| 422 |
+
|
| 423 |
+
def pairwise_distances(self, U, V):
|
| 424 |
+
"""
|
| 425 |
+
Evaluate pairwise distances between two batches of feature vectors.
|
| 426 |
+
"""
|
| 427 |
+
return self.session.run(
|
| 428 |
+
self.distance_block,
|
| 429 |
+
feed_dict={self._features_batch1: U, self._features_batch2: V},
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
|
| 433 |
+
return self.session.run(
|
| 434 |
+
[self._batch_1_in, self._batch_2_in],
|
| 435 |
+
feed_dict={
|
| 436 |
+
self._features_batch1: batch_1,
|
| 437 |
+
self._features_batch2: batch_2,
|
| 438 |
+
self._radii1: radii_1,
|
| 439 |
+
self._radii2: radii_2,
|
| 440 |
+
},
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _batch_pairwise_distances(U, V):
|
| 445 |
+
"""
|
| 446 |
+
Compute pairwise distances between two batches of feature vectors.
|
| 447 |
+
"""
|
| 448 |
+
with tf.variable_scope("pairwise_dist_block"):
|
| 449 |
+
# Squared norms of each row in U and V.
|
| 450 |
+
norm_u = tf.reduce_sum(tf.square(U), 1)
|
| 451 |
+
norm_v = tf.reduce_sum(tf.square(V), 1)
|
| 452 |
+
|
| 453 |
+
# norm_u as a column and norm_v as a row vectors.
|
| 454 |
+
norm_u = tf.reshape(norm_u, [-1, 1])
|
| 455 |
+
norm_v = tf.reshape(norm_v, [1, -1])
|
| 456 |
+
|
| 457 |
+
# Pairwise squared Euclidean distances.
|
| 458 |
+
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
|
| 459 |
+
|
| 460 |
+
return D
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class NpzArrayReader(ABC):
|
| 464 |
+
@abstractmethod
|
| 465 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 466 |
+
pass
|
| 467 |
+
|
| 468 |
+
@abstractmethod
|
| 469 |
+
def remaining(self) -> int:
|
| 470 |
+
pass
|
| 471 |
+
|
| 472 |
+
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
|
| 473 |
+
def gen_fn():
|
| 474 |
+
while True:
|
| 475 |
+
batch = self.read_batch(batch_size)
|
| 476 |
+
if batch is None:
|
| 477 |
+
break
|
| 478 |
+
yield batch
|
| 479 |
+
|
| 480 |
+
rem = self.remaining()
|
| 481 |
+
num_batches = rem // batch_size + int(rem % batch_size != 0)
|
| 482 |
+
return BatchIterator(gen_fn, num_batches)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class BatchIterator:
|
| 486 |
+
def __init__(self, gen_fn, length):
|
| 487 |
+
self.gen_fn = gen_fn
|
| 488 |
+
self.length = length
|
| 489 |
+
|
| 490 |
+
def __len__(self):
|
| 491 |
+
return self.length
|
| 492 |
+
|
| 493 |
+
def __iter__(self):
|
| 494 |
+
return self.gen_fn()
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class StreamingNpzArrayReader(NpzArrayReader):
|
| 498 |
+
def __init__(self, arr_f, shape, dtype):
|
| 499 |
+
self.arr_f = arr_f
|
| 500 |
+
self.shape = shape
|
| 501 |
+
self.dtype = dtype
|
| 502 |
+
self.idx = 0
|
| 503 |
+
|
| 504 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 505 |
+
if self.idx >= self.shape[0]:
|
| 506 |
+
return None
|
| 507 |
+
|
| 508 |
+
bs = min(batch_size, self.shape[0] - self.idx)
|
| 509 |
+
self.idx += bs
|
| 510 |
+
|
| 511 |
+
if self.dtype.itemsize == 0:
|
| 512 |
+
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
|
| 513 |
+
|
| 514 |
+
read_count = bs * np.prod(self.shape[1:])
|
| 515 |
+
read_size = int(read_count * self.dtype.itemsize)
|
| 516 |
+
data = _read_bytes(self.arr_f, read_size, "array data")
|
| 517 |
+
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
| 518 |
+
|
| 519 |
+
def remaining(self) -> int:
|
| 520 |
+
return max(0, self.shape[0] - self.idx)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class MemoryNpzArrayReader(NpzArrayReader):
|
| 524 |
+
def __init__(self, arr):
|
| 525 |
+
self.arr = arr
|
| 526 |
+
self.idx = 0
|
| 527 |
+
|
| 528 |
+
@classmethod
|
| 529 |
+
def load(cls, path: str, arr_name: str):
|
| 530 |
+
with open(path, "rb") as f:
|
| 531 |
+
arr = np.load(f)[arr_name]
|
| 532 |
+
return cls(arr)
|
| 533 |
+
|
| 534 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 535 |
+
if self.idx >= self.arr.shape[0]:
|
| 536 |
+
return None
|
| 537 |
+
|
| 538 |
+
res = self.arr[self.idx : self.idx + batch_size]
|
| 539 |
+
self.idx += batch_size
|
| 540 |
+
return res
|
| 541 |
+
|
| 542 |
+
def remaining(self) -> int:
|
| 543 |
+
return max(0, self.arr.shape[0] - self.idx)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@contextmanager
|
| 547 |
+
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
|
| 548 |
+
with _open_npy_file(path, arr_name) as arr_f:
|
| 549 |
+
version = np.lib.format.read_magic(arr_f)
|
| 550 |
+
if version == (1, 0):
|
| 551 |
+
header = np.lib.format.read_array_header_1_0(arr_f)
|
| 552 |
+
elif version == (2, 0):
|
| 553 |
+
header = np.lib.format.read_array_header_2_0(arr_f)
|
| 554 |
+
else:
|
| 555 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 556 |
+
return
|
| 557 |
+
shape, fortran, dtype = header
|
| 558 |
+
if fortran or dtype.hasobject:
|
| 559 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 560 |
+
else:
|
| 561 |
+
yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 565 |
+
"""
|
| 566 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 567 |
+
|
| 568 |
+
Read from file-like object until size bytes are read.
|
| 569 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 570 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 571 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 572 |
+
requested.
|
| 573 |
+
"""
|
| 574 |
+
data = bytes()
|
| 575 |
+
while True:
|
| 576 |
+
# io files (default in python3) return None or raise on
|
| 577 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 578 |
+
# done about that. note that regular files can't be non-blocking
|
| 579 |
+
try:
|
| 580 |
+
r = fp.read(size - len(data))
|
| 581 |
+
data += r
|
| 582 |
+
if len(r) == 0 or len(data) == size:
|
| 583 |
+
break
|
| 584 |
+
except io.BlockingIOError:
|
| 585 |
+
pass
|
| 586 |
+
if len(data) != size:
|
| 587 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 588 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 589 |
+
else:
|
| 590 |
+
return data
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
@contextmanager
|
| 594 |
+
def _open_npy_file(path: str, arr_name: str):
|
| 595 |
+
with open(path, "rb") as f:
|
| 596 |
+
with zipfile.ZipFile(f, "r") as zip_f:
|
| 597 |
+
if f"{arr_name}.npy" not in zip_f.namelist():
|
| 598 |
+
raise ValueError(f"missing {arr_name} in npz file")
|
| 599 |
+
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
| 600 |
+
yield arr_f
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def _download_inception_model():
|
| 604 |
+
if os.path.exists(INCEPTION_V3_PATH):
|
| 605 |
+
return
|
| 606 |
+
print("downloading InceptionV3 model...")
|
| 607 |
+
with requests.get(INCEPTION_V3_URL, stream=True) as r:
|
| 608 |
+
r.raise_for_status()
|
| 609 |
+
tmp_path = INCEPTION_V3_PATH + ".tmp"
|
| 610 |
+
with open(tmp_path, "wb") as f:
|
| 611 |
+
for chunk in tqdm(r.iter_content(chunk_size=8192)):
|
| 612 |
+
f.write(chunk)
|
| 613 |
+
os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def _create_feature_graph(input_batch):
|
| 617 |
+
_download_inception_model()
|
| 618 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 619 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 620 |
+
graph_def = tf.GraphDef()
|
| 621 |
+
graph_def.ParseFromString(f.read())
|
| 622 |
+
pool3, spatial = tf.import_graph_def(
|
| 623 |
+
graph_def,
|
| 624 |
+
input_map={f"ExpandDims:0": input_batch},
|
| 625 |
+
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
|
| 626 |
+
name=prefix,
|
| 627 |
+
)
|
| 628 |
+
_update_shapes(pool3)
|
| 629 |
+
spatial = spatial[..., :7]
|
| 630 |
+
return pool3, spatial
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def _create_softmax_graph(input_batch):
|
| 634 |
+
_download_inception_model()
|
| 635 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 636 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 637 |
+
graph_def = tf.GraphDef()
|
| 638 |
+
graph_def.ParseFromString(f.read())
|
| 639 |
+
(matmul,) = tf.import_graph_def(
|
| 640 |
+
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
|
| 641 |
+
)
|
| 642 |
+
w = matmul.inputs[1]
|
| 643 |
+
logits = tf.matmul(input_batch, w)
|
| 644 |
+
return tf.nn.softmax(logits)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def _update_shapes(pool3):
|
| 648 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
|
| 649 |
+
ops = pool3.graph.get_operations()
|
| 650 |
+
for op in ops:
|
| 651 |
+
for o in op.outputs:
|
| 652 |
+
shape = o.get_shape()
|
| 653 |
+
if shape._dims is not None: # pylint: disable=protected-access
|
| 654 |
+
# shape = [s.value for s in shape] TF 1.x
|
| 655 |
+
shape = [s for s in shape] # TF 2.x
|
| 656 |
+
new_shape = []
|
| 657 |
+
for j, s in enumerate(shape):
|
| 658 |
+
if s == 1 and j == 0:
|
| 659 |
+
new_shape.append(None)
|
| 660 |
+
else:
|
| 661 |
+
new_shape.append(s)
|
| 662 |
+
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
| 663 |
+
return pool3
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 667 |
+
num_workers = min(cpu_count(), len(arr))
|
| 668 |
+
chunk_size = len(arr) // num_workers
|
| 669 |
+
extra = len(arr) % num_workers
|
| 670 |
+
|
| 671 |
+
start_idx = 0
|
| 672 |
+
batches = []
|
| 673 |
+
for i in range(num_workers):
|
| 674 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 675 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 676 |
+
start_idx += size
|
| 677 |
+
|
| 678 |
+
with ThreadPool(num_workers) as pool:
|
| 679 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
main()
|
| 684 |
+
|
| 685 |
+
# nohup python evaluator_base.py > evaluator_base.log 2>&1 &
|
evaluator_rf.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import zipfile
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import partial
|
| 10 |
+
from multiprocessing import cpu_count
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from typing import Iterable, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import tensorflow.compat.v1 as tf
|
| 17 |
+
from scipy import linalg
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 21 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 22 |
+
|
| 23 |
+
FID_POOL_NAME = "pool_3:0"
|
| 24 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
#/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco/sd3_rectified_samples.npz
|
| 30 |
+
parser.add_argument("--ref_batch", default='/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.npz',help="path to reference batch npz file")
|
| 31 |
+
parser.add_argument("--sample_batch", default='/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_batch2_220000.npz', help="path to sample batch npz file")
|
| 32 |
+
parser.add_argument("--save_path", default='/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_batch2_220000',help="path to sample batch npz file")
|
| 33 |
+
parser.add_argument("--cfg_cond", default=1, type=int)
|
| 34 |
+
parser.add_argument("--step", default=1, type=int)
|
| 35 |
+
parser.add_argument("--cfg", default=1.0, type=float)
|
| 36 |
+
parser.add_argument("--cls_cfg", default=1.0, type=float)
|
| 37 |
+
parser.add_argument("--gh", default=1.0, type=float)
|
| 38 |
+
parser.add_argument("--num_steps", default=50, type=int)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
if not os.path.exists(args.save_path):
|
| 42 |
+
os.mkdir(args.save_path)
|
| 43 |
+
|
| 44 |
+
# NOTE: 当前环境中 TensorFlow 与 CUDA/cuDNN 可能版本不匹配(例如报 "No DNN in stream executor"),
|
| 45 |
+
# 这会导致 GPU 计算失败。这里强制使用 CPU 进行评估(会慢一些,但能保证运行)。
|
| 46 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 47 |
+
|
| 48 |
+
config = tf.ConfigProto(
|
| 49 |
+
allow_soft_placement=True, # allows DecodeJpeg to run on CPU in Inception graph
|
| 50 |
+
device_count={"GPU": 0},
|
| 51 |
+
)
|
| 52 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 53 |
+
|
| 54 |
+
print("warming up TensorFlow...")
|
| 55 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 56 |
+
# than after the next print(), to help prevent confusion.
|
| 57 |
+
evaluator.warmup()
|
| 58 |
+
|
| 59 |
+
print("computing reference batch activations...")
|
| 60 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 61 |
+
print("computing/reading reference batch statistics...")
|
| 62 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 63 |
+
|
| 64 |
+
print("computing sample batch activations...")
|
| 65 |
+
sample_acts = evaluator.read_activations(args.sample_batch)
|
| 66 |
+
print("computing/reading sample batch statistics...")
|
| 67 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
|
| 68 |
+
|
| 69 |
+
print("Computing evaluations...")
|
| 70 |
+
Inception_Score = evaluator.compute_inception_score(sample_acts[0])
|
| 71 |
+
FID = sample_stats.frechet_distance(ref_stats)
|
| 72 |
+
sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
|
| 73 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 74 |
+
|
| 75 |
+
print("Inception Score:", Inception_Score)
|
| 76 |
+
print("FID:", FID)
|
| 77 |
+
print("sFID:", sFID)
|
| 78 |
+
print("Precision:", prec)
|
| 79 |
+
print("Recall:", recall)
|
| 80 |
+
|
| 81 |
+
if args.cfg_cond:
|
| 82 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
|
| 83 |
+
else:
|
| 84 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
|
| 85 |
+
with open(file_path, "w") as file:
|
| 86 |
+
file.write("Inception Score: {}\n".format(Inception_Score))
|
| 87 |
+
file.write("FID: {}\n".format(FID))
|
| 88 |
+
file.write("sFID: {}\n".format(sFID))
|
| 89 |
+
file.write("Precision: {}\n".format(prec))
|
| 90 |
+
file.write("Recall: {}\n".format(recall))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class InvalidFIDException(Exception):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class FIDStatistics:
|
| 98 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 99 |
+
self.mu = mu
|
| 100 |
+
self.sigma = sigma
|
| 101 |
+
|
| 102 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 103 |
+
"""
|
| 104 |
+
Compute the Frechet distance between two sets of statistics.
|
| 105 |
+
"""
|
| 106 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 107 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 108 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 109 |
+
|
| 110 |
+
mu1 = np.atleast_1d(mu1)
|
| 111 |
+
mu2 = np.atleast_1d(mu2)
|
| 112 |
+
|
| 113 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 114 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 115 |
+
|
| 116 |
+
assert (
|
| 117 |
+
mu1.shape == mu2.shape
|
| 118 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 119 |
+
assert (
|
| 120 |
+
sigma1.shape == sigma2.shape
|
| 121 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 122 |
+
|
| 123 |
+
diff = mu1 - mu2
|
| 124 |
+
|
| 125 |
+
# product might be almost singular
|
| 126 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 127 |
+
if not np.isfinite(covmean).all():
|
| 128 |
+
msg = (
|
| 129 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 130 |
+
% eps
|
| 131 |
+
)
|
| 132 |
+
warnings.warn(msg)
|
| 133 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 134 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 135 |
+
|
| 136 |
+
# numerical error might give slight imaginary component
|
| 137 |
+
if np.iscomplexobj(covmean):
|
| 138 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 139 |
+
m = np.max(np.abs(covmean.imag))
|
| 140 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 141 |
+
covmean = covmean.real
|
| 142 |
+
|
| 143 |
+
tr_covmean = np.trace(covmean)
|
| 144 |
+
|
| 145 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Evaluator:
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
session,
|
| 152 |
+
batch_size=64,
|
| 153 |
+
softmax_batch_size=512,
|
| 154 |
+
):
|
| 155 |
+
self.sess = session
|
| 156 |
+
self.batch_size = batch_size
|
| 157 |
+
self.softmax_batch_size = softmax_batch_size
|
| 158 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 159 |
+
with self.sess.graph.as_default():
|
| 160 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 161 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 162 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 163 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 164 |
+
|
| 165 |
+
def warmup(self):
|
| 166 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 167 |
+
|
| 168 |
+
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
| 169 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 170 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 171 |
+
|
| 172 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 173 |
+
"""
|
| 174 |
+
Compute image features for downstream evals.
|
| 175 |
+
|
| 176 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 177 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 178 |
+
dimension. The tuple is (pool_3, spatial).
|
| 179 |
+
"""
|
| 180 |
+
preds = []
|
| 181 |
+
spatial_preds = []
|
| 182 |
+
for batch in tqdm(batches):
|
| 183 |
+
batch = batch.astype(np.float32)
|
| 184 |
+
pred, spatial_pred = self.sess.run(
|
| 185 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 186 |
+
)
|
| 187 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 188 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 189 |
+
return (
|
| 190 |
+
np.concatenate(preds, axis=0),
|
| 191 |
+
np.concatenate(spatial_preds, axis=0),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def read_statistics(
|
| 195 |
+
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
| 196 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 197 |
+
obj = np.load(npz_path)
|
| 198 |
+
if "mu" in list(obj.keys()):
|
| 199 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 200 |
+
obj["mu_s"], obj["sigma_s"]
|
| 201 |
+
)
|
| 202 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 203 |
+
|
| 204 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 205 |
+
mu = np.mean(activations, axis=0)
|
| 206 |
+
sigma = np.cov(activations, rowvar=False)
|
| 207 |
+
return FIDStatistics(mu, sigma)
|
| 208 |
+
|
| 209 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 210 |
+
softmax_out = []
|
| 211 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 212 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 213 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 214 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 215 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 216 |
+
scores = []
|
| 217 |
+
for i in range(0, len(preds), split_size):
|
| 218 |
+
part = preds[i : i + split_size]
|
| 219 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 220 |
+
kl = np.mean(np.sum(kl, 1))
|
| 221 |
+
scores.append(np.exp(kl))
|
| 222 |
+
return float(np.mean(scores))
|
| 223 |
+
|
| 224 |
+
def compute_prec_recall(
|
| 225 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 226 |
+
) -> Tuple[float, float]:
|
| 227 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 228 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 229 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 230 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 231 |
+
)
|
| 232 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class ManifoldEstimator:
|
| 236 |
+
"""
|
| 237 |
+
A helper for comparing manifolds of feature vectors.
|
| 238 |
+
|
| 239 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
session,
|
| 245 |
+
row_batch_size=10000,
|
| 246 |
+
col_batch_size=10000,
|
| 247 |
+
nhood_sizes=(3,),
|
| 248 |
+
clamp_to_percentile=None,
|
| 249 |
+
eps=1e-5,
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Estimate the manifold of given feature vectors.
|
| 253 |
+
|
| 254 |
+
:param session: the TensorFlow session.
|
| 255 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 256 |
+
(parameter to trade-off between memory usage and performance).
|
| 257 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 258 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 259 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 260 |
+
the given percentile.
|
| 261 |
+
:param eps: small number for numerical stability.
|
| 262 |
+
"""
|
| 263 |
+
self.distance_block = DistanceBlock(session)
|
| 264 |
+
self.row_batch_size = row_batch_size
|
| 265 |
+
self.col_batch_size = col_batch_size
|
| 266 |
+
self.nhood_sizes = nhood_sizes
|
| 267 |
+
self.num_nhoods = len(nhood_sizes)
|
| 268 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 269 |
+
self.eps = eps
|
| 270 |
+
|
| 271 |
+
def warmup(self):
|
| 272 |
+
feats, radii = (
|
| 273 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 274 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 275 |
+
)
|
| 276 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 277 |
+
|
| 278 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 279 |
+
num_images = len(features)
|
| 280 |
+
|
| 281 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 282 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 283 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 284 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 285 |
+
|
| 286 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 287 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 288 |
+
row_batch = features[begin1:end1]
|
| 289 |
+
|
| 290 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 291 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 292 |
+
col_batch = features[begin2:end2]
|
| 293 |
+
|
| 294 |
+
# Compute distances between batches.
|
| 295 |
+
distance_batch[
|
| 296 |
+
0 : end1 - begin1, begin2:end2
|
| 297 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 298 |
+
|
| 299 |
+
# Find the k-nearest neighbor from the current batch.
|
| 300 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 301 |
+
[
|
| 302 |
+
x[:, self.nhood_sizes]
|
| 303 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 304 |
+
],
|
| 305 |
+
axis=0,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if self.clamp_to_percentile is not None:
|
| 309 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 310 |
+
radii[radii > max_distances] = 0
|
| 311 |
+
return radii
|
| 312 |
+
|
| 313 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 314 |
+
"""
|
| 315 |
+
Evaluate if new feature vectors are at the manifold.
|
| 316 |
+
"""
|
| 317 |
+
num_eval_images = eval_features.shape[0]
|
| 318 |
+
num_ref_images = radii.shape[0]
|
| 319 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 320 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 321 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 322 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 323 |
+
|
| 324 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 325 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 326 |
+
feature_batch = eval_features[begin1:end1]
|
| 327 |
+
|
| 328 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 329 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 330 |
+
ref_batch = features[begin2:end2]
|
| 331 |
+
|
| 332 |
+
distance_batch[
|
| 333 |
+
0 : end1 - begin1, begin2:end2
|
| 334 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 335 |
+
|
| 336 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 337 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 338 |
+
# the new sample lies at the estimated manifold.
|
| 339 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 340 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 341 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 342 |
+
|
| 343 |
+
max_realism_score[begin1:end1] = np.max(
|
| 344 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 345 |
+
)
|
| 346 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 347 |
+
|
| 348 |
+
return {
|
| 349 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 350 |
+
"batch_predictions": batch_predictions,
|
| 351 |
+
"max_realisim_score": max_realism_score,
|
| 352 |
+
"nearest_indices": nearest_indices,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
def evaluate_pr(
|
| 356 |
+
self,
|
| 357 |
+
features_1: np.ndarray,
|
| 358 |
+
radii_1: np.ndarray,
|
| 359 |
+
features_2: np.ndarray,
|
| 360 |
+
radii_2: np.ndarray,
|
| 361 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 362 |
+
"""
|
| 363 |
+
Evaluate precision and recall efficiently.
|
| 364 |
+
|
| 365 |
+
:param features_1: [N1 x D] feature vectors for reference batch.
|
| 366 |
+
:param radii_1: [N1 x K1] radii for reference vectors.
|
| 367 |
+
:param features_2: [N2 x D] feature vectors for the other batch.
|
| 368 |
+
:param radii_2: [N x K2] radii for other vectors.
|
| 369 |
+
:return: a tuple of arrays for (precision, recall):
|
| 370 |
+
- precision: an np.ndarray of length K1
|
| 371 |
+
- recall: an np.ndarray of length K2
|
| 372 |
+
"""
|
| 373 |
+
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
|
| 374 |
+
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
|
| 375 |
+
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
| 376 |
+
end_1 = begin_1 + self.row_batch_size
|
| 377 |
+
batch_1 = features_1[begin_1:end_1]
|
| 378 |
+
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
| 379 |
+
end_2 = begin_2 + self.col_batch_size
|
| 380 |
+
batch_2 = features_2[begin_2:end_2]
|
| 381 |
+
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
| 382 |
+
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
| 383 |
+
)
|
| 384 |
+
features_1_status[begin_1:end_1] |= batch_1_in
|
| 385 |
+
features_2_status[begin_2:end_2] |= batch_2_in
|
| 386 |
+
return (
|
| 387 |
+
np.mean(features_2_status.astype(np.float64), axis=0),
|
| 388 |
+
np.mean(features_1_status.astype(np.float64), axis=0),
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class DistanceBlock:
|
| 393 |
+
"""
|
| 394 |
+
Calculate pairwise distances between vectors.
|
| 395 |
+
|
| 396 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
def __init__(self, session):
|
| 400 |
+
self.session = session
|
| 401 |
+
|
| 402 |
+
# Initialize TF graph to calculate pairwise distances.
|
| 403 |
+
with session.graph.as_default():
|
| 404 |
+
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 405 |
+
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 406 |
+
distance_block_16 = _batch_pairwise_distances(
|
| 407 |
+
tf.cast(self._features_batch1, tf.float16),
|
| 408 |
+
tf.cast(self._features_batch2, tf.float16),
|
| 409 |
+
)
|
| 410 |
+
self.distance_block = tf.cond(
|
| 411 |
+
tf.reduce_all(tf.math.is_finite(distance_block_16)),
|
| 412 |
+
lambda: tf.cast(distance_block_16, tf.float32),
|
| 413 |
+
lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Extra logic for less thans.
|
| 417 |
+
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 418 |
+
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 419 |
+
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
|
| 420 |
+
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
|
| 421 |
+
self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
|
| 422 |
+
|
| 423 |
+
def pairwise_distances(self, U, V):
|
| 424 |
+
"""
|
| 425 |
+
Evaluate pairwise distances between two batches of feature vectors.
|
| 426 |
+
"""
|
| 427 |
+
return self.session.run(
|
| 428 |
+
self.distance_block,
|
| 429 |
+
feed_dict={self._features_batch1: U, self._features_batch2: V},
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
|
| 433 |
+
return self.session.run(
|
| 434 |
+
[self._batch_1_in, self._batch_2_in],
|
| 435 |
+
feed_dict={
|
| 436 |
+
self._features_batch1: batch_1,
|
| 437 |
+
self._features_batch2: batch_2,
|
| 438 |
+
self._radii1: radii_1,
|
| 439 |
+
self._radii2: radii_2,
|
| 440 |
+
},
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _batch_pairwise_distances(U, V):
|
| 445 |
+
"""
|
| 446 |
+
Compute pairwise distances between two batches of feature vectors.
|
| 447 |
+
"""
|
| 448 |
+
with tf.variable_scope("pairwise_dist_block"):
|
| 449 |
+
# Squared norms of each row in U and V.
|
| 450 |
+
norm_u = tf.reduce_sum(tf.square(U), 1)
|
| 451 |
+
norm_v = tf.reduce_sum(tf.square(V), 1)
|
| 452 |
+
|
| 453 |
+
# norm_u as a column and norm_v as a row vectors.
|
| 454 |
+
norm_u = tf.reshape(norm_u, [-1, 1])
|
| 455 |
+
norm_v = tf.reshape(norm_v, [1, -1])
|
| 456 |
+
|
| 457 |
+
# Pairwise squared Euclidean distances.
|
| 458 |
+
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
|
| 459 |
+
|
| 460 |
+
return D
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class NpzArrayReader(ABC):
|
| 464 |
+
@abstractmethod
|
| 465 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 466 |
+
pass
|
| 467 |
+
|
| 468 |
+
@abstractmethod
|
| 469 |
+
def remaining(self) -> int:
|
| 470 |
+
pass
|
| 471 |
+
|
| 472 |
+
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
|
| 473 |
+
def gen_fn():
|
| 474 |
+
while True:
|
| 475 |
+
batch = self.read_batch(batch_size)
|
| 476 |
+
if batch is None:
|
| 477 |
+
break
|
| 478 |
+
yield batch
|
| 479 |
+
|
| 480 |
+
rem = self.remaining()
|
| 481 |
+
num_batches = rem // batch_size + int(rem % batch_size != 0)
|
| 482 |
+
return BatchIterator(gen_fn, num_batches)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class BatchIterator:
|
| 486 |
+
def __init__(self, gen_fn, length):
|
| 487 |
+
self.gen_fn = gen_fn
|
| 488 |
+
self.length = length
|
| 489 |
+
|
| 490 |
+
def __len__(self):
|
| 491 |
+
return self.length
|
| 492 |
+
|
| 493 |
+
def __iter__(self):
|
| 494 |
+
return self.gen_fn()
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class StreamingNpzArrayReader(NpzArrayReader):
|
| 498 |
+
def __init__(self, arr_f, shape, dtype):
|
| 499 |
+
self.arr_f = arr_f
|
| 500 |
+
self.shape = shape
|
| 501 |
+
self.dtype = dtype
|
| 502 |
+
self.idx = 0
|
| 503 |
+
|
| 504 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 505 |
+
if self.idx >= self.shape[0]:
|
| 506 |
+
return None
|
| 507 |
+
|
| 508 |
+
bs = min(batch_size, self.shape[0] - self.idx)
|
| 509 |
+
self.idx += bs
|
| 510 |
+
|
| 511 |
+
if self.dtype.itemsize == 0:
|
| 512 |
+
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
|
| 513 |
+
|
| 514 |
+
read_count = bs * np.prod(self.shape[1:])
|
| 515 |
+
read_size = int(read_count * self.dtype.itemsize)
|
| 516 |
+
data = _read_bytes(self.arr_f, read_size, "array data")
|
| 517 |
+
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
| 518 |
+
|
| 519 |
+
def remaining(self) -> int:
|
| 520 |
+
return max(0, self.shape[0] - self.idx)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class MemoryNpzArrayReader(NpzArrayReader):
|
| 524 |
+
def __init__(self, arr):
|
| 525 |
+
self.arr = arr
|
| 526 |
+
self.idx = 0
|
| 527 |
+
|
| 528 |
+
@classmethod
|
| 529 |
+
def load(cls, path: str, arr_name: str):
|
| 530 |
+
with open(path, "rb") as f:
|
| 531 |
+
arr = np.load(f)[arr_name]
|
| 532 |
+
return cls(arr)
|
| 533 |
+
|
| 534 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 535 |
+
if self.idx >= self.arr.shape[0]:
|
| 536 |
+
return None
|
| 537 |
+
|
| 538 |
+
res = self.arr[self.idx : self.idx + batch_size]
|
| 539 |
+
self.idx += batch_size
|
| 540 |
+
return res
|
| 541 |
+
|
| 542 |
+
def remaining(self) -> int:
|
| 543 |
+
return max(0, self.arr.shape[0] - self.idx)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@contextmanager
|
| 547 |
+
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
|
| 548 |
+
with _open_npy_file(path, arr_name) as arr_f:
|
| 549 |
+
version = np.lib.format.read_magic(arr_f)
|
| 550 |
+
if version == (1, 0):
|
| 551 |
+
header = np.lib.format.read_array_header_1_0(arr_f)
|
| 552 |
+
elif version == (2, 0):
|
| 553 |
+
header = np.lib.format.read_array_header_2_0(arr_f)
|
| 554 |
+
else:
|
| 555 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 556 |
+
return
|
| 557 |
+
shape, fortran, dtype = header
|
| 558 |
+
if fortran or dtype.hasobject:
|
| 559 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 560 |
+
else:
|
| 561 |
+
yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 565 |
+
"""
|
| 566 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 567 |
+
|
| 568 |
+
Read from file-like object until size bytes are read.
|
| 569 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 570 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 571 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 572 |
+
requested.
|
| 573 |
+
"""
|
| 574 |
+
data = bytes()
|
| 575 |
+
while True:
|
| 576 |
+
# io files (default in python3) return None or raise on
|
| 577 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 578 |
+
# done about that. note that regular files can't be non-blocking
|
| 579 |
+
try:
|
| 580 |
+
r = fp.read(size - len(data))
|
| 581 |
+
data += r
|
| 582 |
+
if len(r) == 0 or len(data) == size:
|
| 583 |
+
break
|
| 584 |
+
except io.BlockingIOError:
|
| 585 |
+
pass
|
| 586 |
+
if len(data) != size:
|
| 587 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 588 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 589 |
+
else:
|
| 590 |
+
return data
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
@contextmanager
|
| 594 |
+
def _open_npy_file(path: str, arr_name: str):
|
| 595 |
+
with open(path, "rb") as f:
|
| 596 |
+
with zipfile.ZipFile(f, "r") as zip_f:
|
| 597 |
+
if f"{arr_name}.npy" not in zip_f.namelist():
|
| 598 |
+
raise ValueError(f"missing {arr_name} in npz file")
|
| 599 |
+
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
| 600 |
+
yield arr_f
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def _download_inception_model():
|
| 604 |
+
if os.path.exists(INCEPTION_V3_PATH):
|
| 605 |
+
return
|
| 606 |
+
print("downloading InceptionV3 model...")
|
| 607 |
+
with requests.get(INCEPTION_V3_URL, stream=True) as r:
|
| 608 |
+
r.raise_for_status()
|
| 609 |
+
tmp_path = INCEPTION_V3_PATH + ".tmp"
|
| 610 |
+
with open(tmp_path, "wb") as f:
|
| 611 |
+
for chunk in tqdm(r.iter_content(chunk_size=8192)):
|
| 612 |
+
f.write(chunk)
|
| 613 |
+
os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def _create_feature_graph(input_batch):
|
| 617 |
+
_download_inception_model()
|
| 618 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 619 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 620 |
+
graph_def = tf.GraphDef()
|
| 621 |
+
graph_def.ParseFromString(f.read())
|
| 622 |
+
pool3, spatial = tf.import_graph_def(
|
| 623 |
+
graph_def,
|
| 624 |
+
input_map={f"ExpandDims:0": input_batch},
|
| 625 |
+
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
|
| 626 |
+
name=prefix,
|
| 627 |
+
)
|
| 628 |
+
_update_shapes(pool3)
|
| 629 |
+
spatial = spatial[..., :7]
|
| 630 |
+
return pool3, spatial
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def _create_softmax_graph(input_batch):
|
| 634 |
+
_download_inception_model()
|
| 635 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 636 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 637 |
+
graph_def = tf.GraphDef()
|
| 638 |
+
graph_def.ParseFromString(f.read())
|
| 639 |
+
(matmul,) = tf.import_graph_def(
|
| 640 |
+
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
|
| 641 |
+
)
|
| 642 |
+
w = matmul.inputs[1]
|
| 643 |
+
logits = tf.matmul(input_batch, w)
|
| 644 |
+
return tf.nn.softmax(logits)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def _update_shapes(pool3):
|
| 648 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
|
| 649 |
+
ops = pool3.graph.get_operations()
|
| 650 |
+
for op in ops:
|
| 651 |
+
for o in op.outputs:
|
| 652 |
+
shape = o.get_shape()
|
| 653 |
+
if shape._dims is not None: # pylint: disable=protected-access
|
| 654 |
+
# shape = [s.value for s in shape] TF 1.x
|
| 655 |
+
shape = [s for s in shape] # TF 2.x
|
| 656 |
+
new_shape = []
|
| 657 |
+
for j, s in enumerate(shape):
|
| 658 |
+
if s == 1 and j == 0:
|
| 659 |
+
new_shape.append(None)
|
| 660 |
+
else:
|
| 661 |
+
new_shape.append(s)
|
| 662 |
+
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
| 663 |
+
return pool3
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 667 |
+
num_workers = min(cpu_count(), len(arr))
|
| 668 |
+
chunk_size = len(arr) // num_workers
|
| 669 |
+
extra = len(arr) % num_workers
|
| 670 |
+
|
| 671 |
+
start_idx = 0
|
| 672 |
+
batches = []
|
| 673 |
+
for i in range(num_workers):
|
| 674 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 675 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 676 |
+
start_idx += size
|
| 677 |
+
|
| 678 |
+
with ThreadPool(num_workers) as pool:
|
| 679 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
main()
|
| 684 |
+
|
| 685 |
+
# nohup python evaluator_rf.py > evaluator_rf_iter22.log 2>&1 &
|
evaluator_rf_iter22.log
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-25 14:55:37.841840: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
|
|
|
|
|
|
|
|
|
| 1 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 2 |
0%| | 1/211 [00:02<07:28, 2.13s/it]
|
| 3 |
1%| | 2/211 [00:03<06:40, 1.92s/it]
|
| 4 |
1%|▏ | 3/211 [00:06<08:10, 2.36s/it]
|
| 5 |
2%|▏ | 4/211 [00:08<07:14, 2.10s/it]
|
| 6 |
2%|▏ | 5/211 [00:10<06:43, 1.96s/it]
|
| 7 |
3%|▎ | 6/211 [00:11<06:27, 1.89s/it]
|
| 8 |
3%|▎ | 7/211 [00:13<06:15, 1.84s/it]
|
| 9 |
4%|▍ | 8/211 [00:15<06:05, 1.80s/it]
|
| 10 |
4%|▍ | 9/211 [00:17<06:00, 1.79s/it]
|
| 11 |
5%|▍ | 10/211 [00:18<05:53, 1.76s/it]
|
| 12 |
5%|▌ | 11/211 [00:20<05:47, 1.74s/it]
|
| 13 |
6%|▌ | 12/211 [00:22<05:40, 1.71s/it]
|
| 14 |
6%|▌ | 13/211 [00:23<05:34, 1.69s/it]
|
| 15 |
7%|▋ | 14/211 [00:25<05:32, 1.69s/it]
|
| 16 |
7%|▋ | 15/211 [00:27<05:39, 1.73s/it]
|
| 17 |
8%|▊ | 16/211 [00:29<05:34, 1.71s/it]
|
| 18 |
8%|▊ | 17/211 [00:30<05:34, 1.72s/it]
|
| 19 |
9%|▊ | 18/211 [00:32<06:00, 1.87s/it]
|
| 20 |
9%|▉ | 19/211 [00:34<05:48, 1.82s/it]
|
| 21 |
9%|▉ | 20/211 [00:36<05:44, 1.80s/it]
|
| 22 |
10%|▉ | 21/211 [00:38<05:39, 1.78s/it]
|
| 23 |
10%|█ | 22/211 [00:40<05:40, 1.80s/it]
|
| 24 |
11%|█ | 23/211 [00:41<05:45, 1.84s/it]
|
| 25 |
11%|█▏ | 24/211 [00:43<05:43, 1.84s/it]
|
| 26 |
12%|█▏ | 25/211 [00:45<05:35, 1.80s/it]
|
| 27 |
12%|█▏ | 26/211 [00:47<05:41, 1.84s/it]
|
| 28 |
13%|█▎ | 27/211 [00:49<05:42, 1.86s/it]
|
| 29 |
13%|█▎ | 28/211 [00:51<05:36, 1.84s/it]
|
| 30 |
14%|█▎ | 29/211 [00:52<05:27, 1.80s/it]
|
| 31 |
14%|█▍ | 30/211 [00:54<05:18, 1.76s/it]
|
| 32 |
15%|█▍ | 31/211 [00:56<05:28, 1.83s/it]
|
| 33 |
15%|█▌ | 32/211 [00:58<05:24, 1.81s/it]
|
| 34 |
16%|█▌ | 33/211 [01:00<05:28, 1.85s/it]
|
| 35 |
16%|█▌ | 34/211 [01:01<05:20, 1.81s/it]
|
| 36 |
17%|█▋ | 35/211 [01:03<05:19, 1.82s/it]
|
| 37 |
17%|█▋ | 36/211 [01:05<05:14, 1.80s/it]
|
| 38 |
18%|█▊ | 37/211 [01:07<05:07, 1.77s/it]
|
| 39 |
18%|█▊ | 38/211 [01:08<05:04, 1.76s/it]
|
| 40 |
18%|█▊ | 39/211 [01:10<05:17, 1.85s/it]
|
| 41 |
19%|█▉ | 40/211 [01:12<05:10, 1.82s/it]
|
| 42 |
19%|█▉ | 41/211 [01:14<05:23, 1.90s/it]
|
| 43 |
20%|█▉ | 42/211 [01:16<05:15, 1.87s/it]
|
| 44 |
20%|██ | 43/211 [01:18<05:06, 1.82s/it]
|
| 45 |
21%|██ | 44/211 [01:20<05:01, 1.81s/it]
|
| 46 |
21%|██▏ | 45/211 [01:21<04:55, 1.78s/it]
|
| 47 |
22%|██▏ | 46/211 [01:23<04:54, 1.78s/it]
|
| 48 |
22%|██▏ | 47/211 [01:25<04:49, 1.77s/it]
|
| 49 |
23%|██▎ | 48/211 [01:27<04:54, 1.81s/it]
|
| 50 |
23%|██▎ | 49/211 [01:28<04:48, 1.78s/it]
|
| 51 |
24%|██▎ | 50/211 [01:30<04:42, 1.75s/it]
|
| 52 |
24%|██▍ | 51/211 [01:32<04:36, 1.73s/it]
|
| 53 |
25%|██▍ | 52/211 [01:34<04:44, 1.79s/it]
|
| 54 |
25%|██▌ | 53/211 [01:36<05:01, 1.91s/it]
|
| 55 |
26%|██▌ | 54/211 [01:38<04:49, 1.84s/it]
|
| 56 |
26%|██▌ | 55/211 [01:40<05:11, 2.00s/it]
|
| 57 |
27%|██▋ | 56/211 [01:42<04:57, 1.92s/it]
|
| 58 |
27%|██▋ | 57/211 [01:43<04:43, 1.84s/it]
|
| 59 |
27%|█��▋ | 58/211 [01:45<04:48, 1.88s/it]
|
| 60 |
28%|██▊ | 59/211 [01:47<04:34, 1.81s/it]
|
| 61 |
28%|██▊ | 60/211 [01:49<04:36, 1.83s/it]
|
| 62 |
29%|██▉ | 61/211 [01:51<04:25, 1.77s/it]
|
| 63 |
29%|██▉ | 62/211 [01:52<04:17, 1.73s/it]
|
| 64 |
30%|██▉ | 63/211 [01:54<04:13, 1.71s/it]
|
| 65 |
30%|███ | 64/211 [01:55<04:08, 1.69s/it]
|
| 66 |
31%|███ | 65/211 [01:57<04:09, 1.71s/it]
|
| 67 |
31%|███▏ | 66/211 [01:59<04:11, 1.73s/it]
|
| 68 |
32%|███▏ | 67/211 [02:01<04:11, 1.75s/it]
|
| 69 |
32%|███▏ | 68/211 [02:03<04:16, 1.79s/it]
|
| 70 |
33%|███▎ | 69/211 [02:04<04:11, 1.77s/it]
|
| 71 |
33%|███▎ | 70/211 [02:07<04:30, 1.92s/it]
|
| 72 |
34%|███▎ | 71/211 [02:08<04:20, 1.86s/it]
|
| 73 |
34%|███▍ | 72/211 [02:10<04:12, 1.82s/it]
|
| 74 |
35%|███▍ | 73/211 [02:12<04:07, 1.79s/it]
|
| 75 |
35%|███▌ | 74/211 [02:14<04:12, 1.84s/it]
|
| 76 |
36%|███▌ | 75/211 [02:16<04:19, 1.91s/it]
|
| 77 |
36%|███▌ | 76/211 [02:18<04:06, 1.83s/it]
|
| 78 |
36%|███▋ | 77/211 [02:20<04:17, 1.92s/it]
|
| 79 |
37%|███▋ | 78/211 [02:21<04:06, 1.86s/it]
|
| 80 |
37%|███▋ | 79/211 [02:23<04:03, 1.84s/it]
|
| 81 |
38%|███▊ | 80/211 [02:26<04:34, 2.10s/it]
|
| 82 |
38%|███▊ | 81/211 [02:28<04:16, 1.97s/it]
|
| 83 |
39%|███▉ | 82/211 [02:29<04:01, 1.88s/it]
|
| 84 |
39%|███▉ | 83/211 [02:31<03:54, 1.84s/it]
|
| 85 |
40%|███▉ | 84/211 [02:33<03:48, 1.80s/it]
|
| 86 |
40%|████ | 85/211 [02:34<03:42, 1.77s/it]
|
| 87 |
41%|████ | 86/211 [02:36<03:39, 1.75s/it]
|
| 88 |
41%|████ | 87/211 [02:38<03:37, 1.76s/it]
|
| 89 |
42%|████▏ | 88/211 [02:39<03:33, 1.73s/it]
|
| 90 |
42%|████▏ | 89/211 [02:41<03:27, 1.70s/it]
|
| 91 |
43%|████▎ | 90/211 [02:43<03:27, 1.71s/it]
|
| 92 |
43%|████▎ | 91/211 [02:45<03:27, 1.73s/it]
|
| 93 |
44%|████▎ | 92/211 [02:46<03:23, 1.71s/it]
|
| 94 |
44%|████▍ | 93/211 [02:48<03:22, 1.71s/it]
|
| 95 |
45%|████▍ | 94/211 [02:57<07:25, 3.81s/it]
|
| 96 |
45%|████▌ | 95/211 [02:58<06:11, 3.20s/it]
|
| 97 |
45%|████▌ | 96/211 [03:00<05:17, 2.76s/it]
|
| 98 |
46%|████▌ | 97/211 [03:02<04:47, 2.52s/it]
|
| 99 |
46%|████▋ | 98/211 [03:04<04:15, 2.27s/it]
|
| 100 |
47%|████▋ | 99/211 [03:06<03:56, 2.11s/it]
|
| 101 |
47%|████▋ | 100/211 [03:07<03:42, 2.00s/it]
|
| 102 |
48%|████▊ | 101/211 [03:09<03:28, 1.90s/it]
|
| 103 |
48%|████▊ | 102/211 [03:11<03:17, 1.81s/it]
|
| 104 |
49%|████▉ | 103/211 [03:12<03:12, 1.78s/it]
|
| 105 |
49%|████▉ | 104/211 [03:14<03:05, 1.74s/it]
|
| 106 |
50%|████▉ | 105/211 [03:16<03:01, 1.71s/it]
|
| 107 |
50%|█████ | 106/211 [03:17<03:05, 1.76s/it]
|
| 108 |
51%|█████ | 107/211 [03:19<03:03, 1.76s/it]
|
| 109 |
51%|█████ | 108/211 [03:21<02:58, 1.74s/it]
|
| 110 |
52%|█████▏ | 109/211 [03:23<02:58, 1.75s/it]
|
| 111 |
52%|█████▏ | 110/211 [03:25<02:58, 1.77s/it]
|
| 112 |
53%|█████▎ | 111/211 [03:26<02:51, 1.72s/it]
|
| 113 |
53%|█████▎ | 112/211 [03:28<02:49, 1.71s/it]
|
| 114 |
54%|█████▎ | 113/211 [03:29<02:46, 1.70s/it]
|
| 115 |
54%|█████▍ | 114/211 [03:31<02:44, 1.69s/it]
|
| 116 |
55%|█████▍ | 115/211 [03:33<02:41, 1.68s/it]
|
| 117 |
55%|█████▍ | 116/211 [03:35<02:42, 1.71s/it]
|
| 118 |
55%|█████▌ | 117/211 [03:36<02:38, 1.69s/it]
|
| 119 |
56%|█████▌ | 118/211 [03:38<02:38, 1.70s/it]
|
| 120 |
56%|█████▋ | 119/211 [03:40<02:35, 1.69s/it]
|
| 121 |
57%|█████▋ | 120/211 [03:41<02:36, 1.72s/it]
|
| 122 |
57%|█████▋ | 121/211 [03:43<02:36, 1.74s/it]
|
| 123 |
58%|█████▊ | 122/211 [03:45<02:31, 1.70s/it]
|
| 124 |
58%|█████▊ | 123/211 [03:47<02:34, 1.76s/it]
|
| 125 |
59%|█████▉ | 124/211 [03:48<02:30, 1.73s/it]
|
| 126 |
59%|█████▉ | 125/211 [03:50<02:25, 1.70s/it]
|
| 127 |
60%|█████▉ | 126/211 [03:52<02:22, 1.68s/it]
|
| 128 |
60%|██████ | 127/211 [03:53<02:21, 1.69s/it]
|
| 129 |
61%|██████ | 128/211 [03:55<02:19, 1.68s/it]
|
| 130 |
61%|██████ | 129/211 [03:57<02:17, 1.68s/it]
|
| 131 |
62%|██████▏ | 130/211 [03:59<02:21, 1.75s/it]
|
| 132 |
62%|██████▏ | 131/211 [04:00<02:20, 1.75s/it]
|
| 133 |
63%|██████▎ | 132/211 [04:02<02:18, 1.75s/it]
|
| 134 |
63%|██████▎ | 133/211 [04:04<02:21, 1.81s/it]
|
| 135 |
64%|██████▎ | 134/211 [04:06<02:15, 1.76s/it]
|
| 136 |
64%|██████▍ | 135/211 [04:08<02:17, 1.81s/it]
|
| 137 |
64%|██████▍ | 136/211 [04:10<02:20, 1.88s/it]
|
| 138 |
65%|██████▍ | 137/211 [04:11<02:13, 1.80s/it]
|
| 139 |
65%|██████▌ | 138/211 [04:13<02:07, 1.74s/it]
|
| 140 |
66%|██████▌ | 139/211 [04:16<02:38, 2.20s/it]
|
| 141 |
66%|██████▋ | 140/211 [04:18<02:27, 2.08s/it]
|
| 142 |
67%|██████▋ | 141/211 [04:20<02:16, 1.95s/it]
|
| 143 |
67%|██████▋ | 142/211 [04:21<02:11, 1.91s/it]
|
| 144 |
68%|██████▊ | 143/211 [04:23<02:06, 1.86s/it]
|
| 145 |
68%|██████▊ | 144/211 [04:25<02:02, 1.82s/it]
|
| 146 |
69%|██████▊ | 145/211 [04:27<01:58, 1.80s/it]
|
| 147 |
69%|██████▉ | 146/211 [04:28<01:53, 1.75s/it]
|
| 148 |
70%|██████▉ | 147/211 [04:30<01:54, 1.78s/it]
|
| 149 |
70%|███████ | 148/211 [04:32<01:53, 1.80s/it]
|
| 150 |
71%|███████ | 149/211 [04:34<01:54, 1.85s/it]
|
| 151 |
71%|███████ | 150/211 [04:36<01:49, 1.80s/it]
|
| 152 |
72%|███████▏ | 151/211 [04:37<01:46, 1.77s/it]
|
| 153 |
72%|███████▏ | 152/211 [04:39<01:45, 1.79s/it]
|
| 154 |
73%|███████▎ | 153/211 [04:41<01:43, 1.78s/it]
|
| 155 |
73%|███████▎ | 154/211 [04:43<01:41, 1.78s/it]
|
| 156 |
73%|███████▎ | 155/211 [04:44<01:37, 1.73s/it]
|
| 157 |
74%|███████▍ | 156/211 [04:46<01:35, 1.73s/it]
|
| 158 |
74%|███████▍ | 157/211 [04:48<01:33, 1.73s/it]
|
| 159 |
75%|███████▍ | 158/211 [04:50<01:33, 1.77s/it]
|
| 160 |
75%|███████▌ | 159/211 [04:51<01:30, 1.74s/it]
|
| 161 |
76%|███████▌ | 160/211 [04:53<01:27, 1.72s/it]
|
| 162 |
76%|███████▋ | 161/211 [04:55<01:28, 1.77s/it]
|
| 163 |
77%|███████▋ | 162/211 [04:57<01:28, 1.80s/it]
|
| 164 |
77%|███████▋ | 163/211 [04:58<01:24, 1.75s/it]
|
| 165 |
78%|███████▊ | 164/211 [05:00<01:26, 1.84s/it]
|
| 166 |
78%|███████▊ | 165/211 [05:02<01:21, 1.77s/it]
|
| 167 |
79%|███████▊ | 166/211 [05:04<01:17, 1.72s/it]
|
| 168 |
79%|███████▉ | 167/211 [05:05<01:15, 1.72s/it]
|
| 169 |
80%|███████▉ | 168/211 [05:07<01:16, 1.79s/it]
|
| 170 |
80%|████████ | 169/211 [05:09<01:17, 1.84s/it]
|
| 171 |
81%|████████ | 170/211 [05:11<01:15, 1.83s/it]
|
| 172 |
81%|████████ | 171/211 [05:13<01:12, 1.81s/it]
|
| 173 |
82%|████████▏ | 172/211 [05:15<01:09, 1.79s/it]
|
| 174 |
82%|████████▏ | 173/211 [05:16<01:06, 1.74s/it]
|
| 175 |
82%|████████▏ | 174/211 [05:18<01:04, 1.75s/it]
|
| 176 |
83%|████████▎ | 175/211 [05:20<01:02, 1.74s/it]
|
| 177 |
83%|████████▎ | 176/211 [05:41<04:21, 7.47s/it]
|
| 178 |
84%|████████▍ | 177/211 [05:42<03:17, 5.81s/it]
|
| 179 |
84%|████████▍ | 178/211 [05:44<02:32, 4.63s/it]
|
| 180 |
85%|████████▍ | 179/211 [05:46<02:00, 3.75s/it]
|
| 181 |
85%|████████▌ | 180/211 [05:48<01:36, 3.13s/it]
|
| 182 |
86%|████████▌ | 181/211 [05:49<01:21, 2.72s/it]
|
| 183 |
86%|████████▋ | 182/211 [05:51<01:10, 2.42s/it]
|
| 184 |
87%|████████▋ | 183/211 [05:53<01:01, 2.19s/it]
|
| 185 |
87%|████████▋ | 184/211 [05:55<00:54, 2.03s/it]
|
| 186 |
88%|████████▊ | 185/211 [05:56<00:50, 1.94s/it]
|
| 187 |
88%|████████▊ | 186/211 [05:58<00:46, 1.88s/it]
|
| 188 |
89%|████████▊ | 187/211 [06:00<00:43, 1.82s/it]
|
| 189 |
89%|████████▉ | 188/211 [06:01<00:41, 1.79s/it]
|
| 190 |
90%|████████▉ | 189/211 [06:03<00:40, 1.86s/it]
|
| 191 |
90%|█████████ | 190/211 [06:06<00:45, 2.14s/it]
|
| 192 |
91%|█████████ | 191/211 [06:08<00:40, 2.02s/it]
|
| 193 |
91%|█████████ | 192/211 [06:10<00:36, 1.91s/it]
|
| 194 |
91%|█████████▏| 193/211 [06:11<00:33, 1.85s/it]
|
| 195 |
92%|█████████▏| 194/211 [06:13<00:30, 1.82s/it]
|
| 196 |
92%|█████████▏| 195/211 [06:15<00:28, 1.78s/it]
|
| 197 |
93%|█████████▎| 196/211 [06:17<00:28, 1.92s/it]
|
| 198 |
93%|█████████▎| 197/211 [06:19<00:25, 1.84s/it]
|
| 199 |
94%|█████████▍| 198/211 [06:20<00:23, 1.80s/it]
|
| 200 |
94%|█████████▍| 199/211 [06:22<00:22, 1.84s/it]
|
| 201 |
95%|█████████▍| 200/211 [06:24<00:19, 1.80s/it]
|
| 202 |
95%|█████████▌| 201/211 [06:26<00:17, 1.77s/it]
|
| 203 |
96%|█████████▌| 202/211 [06:28<00:16, 1.82s/it]
|
| 204 |
96%|█████████▌| 203/211 [06:29<00:14, 1.77s/it]
|
| 205 |
97%|█████████▋| 204/211 [06:31<00:12, 1.73s/it]
|
| 206 |
97%|█████████▋| 205/211 [06:33<00:10, 1.72s/it]
|
| 207 |
98%|█████████▊| 206/211 [06:34<00:08, 1.73s/it]
|
| 208 |
98%|█████████▊| 207/211 [06:36<00:06, 1.71s/it]
|
| 209 |
99%|█████████▊| 208/211 [06:39<00:05, 1.97s/it]
|
| 210 |
99%|█████████▉| 209/211 [06:40<00:03, 1.93s/it]
|
|
|
|
|
|
|
|
|
|
| 211 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 212 |
0%| | 1/469 [00:02<15:45, 2.02s/it]
|
| 213 |
0%| | 2/469 [00:03<14:20, 1.84s/it]
|
| 214 |
1%| | 3/469 [00:05<13:51, 1.78s/it]
|
| 215 |
1%| | 4/469 [00:07<13:37, 1.76s/it]
|
| 216 |
1%| | 5/469 [00:08<13:23, 1.73s/it]
|
| 217 |
1%|▏ | 6/469 [00:10<13:11, 1.71s/it]
|
| 218 |
1%|▏ | 7/469 [00:12<13:03, 1.70s/it]
|
| 219 |
2%|▏ | 8/469 [00:13<12:57, 1.69s/it]
|
| 220 |
2%|▏ | 9/469 [00:15<12:53, 1.68s/it]
|
| 221 |
2%|▏ | 10/469 [00:17<13:08, 1.72s/it]
|
| 222 |
2%|▏ | 11/469 [00:26<31:14, 4.09s/it]
|
| 223 |
3%|▎ | 12/469 [00:28<25:29, 3.35s/it]
|
| 224 |
3%|▎ | 13/469 [00:30<21:32, 2.84s/it]
|
| 225 |
3%|▎ | 14/469 [00:31<18:50, 2.49s/it]
|
| 226 |
3%|▎ | 15/469 [00:33<16:50, 2.23s/it]
|
| 227 |
3%|▎ | 16/469 [00:35<15:32, 2.06s/it]
|
| 228 |
4%|▎ | 17/469 [00:36<14:57, 1.99s/it]
|
| 229 |
4%|▍ | 18/469 [00:38<14:12, 1.89s/it]
|
| 230 |
4%|▍ | 19/469 [00:40<13:45, 1.83s/it]
|
| 231 |
4%|▍ | 20/469 [00:41<13:17, 1.78s/it]
|
| 232 |
4%|▍ | 21/469 [00:43<13:08, 1.76s/it]
|
| 233 |
5%|▍ | 22/469 [00:45<12:51, 1.73s/it]
|
| 234 |
5%|▍ | 23/469 [00:47<12:51, 1.73s/it]
|
| 235 |
5%|▌ | 24/469 [00:48<12:38, 1.70s/it]
|
| 236 |
5%|▌ | 25/469 [00:50<12:33, 1.70s/it]
|
| 237 |
6%|▌ | 26/469 [00:51<12:26, 1.68s/it]
|
| 238 |
6%|▌ | 27/469 [00:53<12:26, 1.69s/it]
|
| 239 |
6%|▌ | 28/469 [00:55<12:17, 1.67s/it]
|
| 240 |
6%|▌ | 29/469 [00:57<12:18, 1.68s/it]
|
| 241 |
6%|▋ | 30/469 [00:58<12:12, 1.67s/it]
|
| 242 |
7%|▋ | 31/469 [01:00<12:15, 1.68s/it]
|
| 243 |
7%|▋ | 32/469 [01:02<12:10, 1.67s/it]
|
| 244 |
7%|▋ | 33/469 [01:04<12:55, 1.78s/it]
|
| 245 |
7%|▋ | 34/469 [01:05<12:41, 1.75s/it]
|
| 246 |
7%|▋ | 35/469 [01:07<12:24, 1.72s/it]
|
| 247 |
8%|▊ | 36/469 [01:09<12:34, 1.74s/it]
|
| 248 |
8%|▊ | 37/469 [01:10<12:19, 1.71s/it]
|
| 249 |
8%|▊ | 38/469 [01:12<12:27, 1.73s/it]
|
| 250 |
8%|▊ | 39/469 [01:14<12:14, 1.71s/it]
|
| 251 |
9%|▊ | 40/469 [01:15<12:10, 1.70s/it]
|
| 252 |
9%|▊ | 41/469 [01:17<12:07, 1.70s/it]
|
| 253 |
9%|▉ | 42/469 [01:19<12:09, 1.71s/it]
|
| 254 |
9%|▉ | 43/469 [01:21<12:18, 1.73s/it]
|
| 255 |
9%|▉ | 44/469 [01:22<12:13, 1.72s/it]
|
| 256 |
10%|▉ | 45/469 [01:24<12:06, 1.71s/it]
|
| 257 |
10%|▉ | 46/469 [01:26<12:22, 1.75s/it]
|
| 258 |
10%|█ | 47/469 [01:28<12:35, 1.79s/it]
|
| 259 |
10%|█ | 48/469 [01:29<12:12, 1.74s/it]
|
| 260 |
10%|█ | 49/469 [01:31<12:01, 1.72s/it]
|
| 261 |
11%|█ | 50/469 [01:33<12:24, 1.78s/it]
|
| 262 |
11%|█ | 51/469 [01:35<12:32, 1.80s/it]
|
| 263 |
11%|█ | 52/469 [01:37<12:29, 1.80s/it]
|
| 264 |
11%|█▏ | 53/469 [01:39<13:30, 1.95s/it]
|
| 265 |
12%|█▏ | 54/469 [01:41<12:50, 1.86s/it]
|
| 266 |
12%|█▏ | 55/469 [01:42<12:27, 1.81s/it]
|
| 267 |
12%|█▏ | 56/469 [01:44<12:05, 1.76s/it]
|
| 268 |
12%|█▏ | 57/469 [01:46<12:17, 1.79s/it]
|
| 269 |
12%|█▏ | 58/469 [01:47<12:08, 1.77s/it]
|
| 270 |
13%|█▎ | 59/469 [01:49<12:09, 1.78s/it]
|
| 271 |
13%|█▎ | 60/469 [01:51<12:17, 1.80s/it]
|
| 272 |
13%|█▎ | 61/469 [01:53<12:07, 1.78s/it]
|
| 273 |
13%|█▎ | 62/469 [01:55<12:24, 1.83s/it]
|
| 274 |
13%|█▎ | 63/469 [01:56<12:00, 1.77s/it]
|
| 275 |
14%|█▎ | 64/469 [01:58<11:42, 1.73s/it]
|
| 276 |
14%|█▍ | 65/469 [02:00<11:35, 1.72s/it]
|
| 277 |
14%|█▍ | 66/469 [02:02<11:34, 1.72s/it]
|
| 278 |
14%|█▍ | 67/469 [02:04<12:11, 1.82s/it]
|
| 279 |
14%|█▍ | 68/469 [02:05<11:45, 1.76s/it]
|
| 280 |
15%|█▍ | 69/469 [02:07<11:31, 1.73s/it]
|
| 281 |
15%|█▍ | 70/469 [02:09<11:27, 1.72s/it]
|
| 282 |
15%|█▌ | 71/469 [02:10<11:25, 1.72s/it]
|
| 283 |
15%|█▌ | 72/469 [02:12<11:16, 1.70s/it]
|
| 284 |
16%|█▌ | 73/469 [02:14<11:09, 1.69s/it]
|
| 285 |
16%|█▌ | 74/469 [02:15<11:32, 1.75s/it]
|
| 286 |
16%|█▌ | 75/469 [02:17<11:47, 1.80s/it]
|
| 287 |
16%|█▌ | 76/469 [02:19<11:43, 1.79s/it]
|
| 288 |
16%|█▋ | 77/469 [02:21<12:10, 1.86s/it]
|
| 289 |
17%|█▋ | 78/469 [02:23<11:45, 1.80s/it]
|
| 290 |
17%|█▋ | 79/469 [02:25<12:10, 1.87s/it]
|
| 291 |
17%|█▋ | 80/469 [02:27<11:48, 1.82s/it]
|
| 292 |
17%|█▋ | 81/469 [02:28<11:30, 1.78s/it]
|
| 293 |
17%|█▋ | 82/469 [02:30<11:36, 1.80s/it]
|
| 294 |
18%|█▊ | 83/469 [02:32<11:18, 1.76s/it]
|
| 295 |
18%|█▊ | 84/469 [02:34<11:12, 1.75s/it]
|
| 296 |
18%|█▊ | 85/469 [02:35<11:03, 1.73s/it]
|
| 297 |
18%|█▊ | 86/469 [02:37<11:06, 1.74s/it]
|
| 298 |
19%|█▊ | 87/469 [02:39<11:04, 1.74s/it]
|
| 299 |
19%|█▉ | 88/469 [02:40<10:56, 1.72s/it]
|
| 300 |
19%|█▉ | 89/469 [02:42<11:07, 1.76s/it]
|
| 301 |
19%|█▉ | 90/469 [02:45<12:56, 2.05s/it]
|
| 302 |
19%|█▉ | 91/469 [02:47<12:10, 1.93s/it]
|
| 303 |
20%|█▉ | 92/469 [02:48<11:35, 1.84s/it]
|
| 304 |
20%|█▉ | 93/469 [02:50<11:19, 1.81s/it]
|
| 305 |
20%|██ | 94/469 [02:52<11:04, 1.77s/it]
|
| 306 |
20%|██ | 95/469 [02:53<10:48, 1.73s/it]
|
| 307 |
20%|██ | 96/469 [02:55<10:38, 1.71s/it]
|
| 308 |
21%|██ | 97/469 [02:57<11:01, 1.78s/it]
|
| 309 |
21%|██ | 98/469 [02:59<10:50, 1.75s/it]
|
| 310 |
21%|██ | 99/469 [03:00<10:33, 1.71s/it]
|
| 311 |
21%|██▏ | 100/469 [03:02<10:35, 1.72s/it]
|
| 312 |
22%|██▏ | 101/469 [03:04<11:42, 1.91s/it]
|
| 313 |
22%|██▏ | 102/469 [03:06<11:19, 1.85s/it]
|
| 314 |
22%|██▏ | 103/469 [03:08<11:48, 1.94s/it]
|
| 315 |
22%|██▏ | 104/469 [03:10<11:23, 1.87s/it]
|
| 316 |
22%|██▏ | 105/469 [03:11<10:53, 1.80s/it]
|
| 317 |
23%|██▎ | 106/469 [03:13<10:34, 1.75s/it]
|
| 318 |
23%|██▎ | 107/469 [03:15<10:32, 1.75s/it]
|
| 319 |
23%|██▎ | 108/469 [03:17<10:19, 1.71s/it]
|
| 320 |
23%|██▎ | 109/469 [03:18<10:08, 1.69s/it]
|
| 321 |
23%|██▎ | 110/469 [03:20<11:07, 1.86s/it]
|
| 322 |
24%|██▎ | 111/469 [03:22<10:49, 1.81s/it]
|
| 323 |
24%|██▍ | 112/469 [03:24<10:35, 1.78s/it]
|
| 324 |
24%|██▍ | 113/469 [03:25<10:22, 1.75s/it]
|
| 325 |
24%|██▍ | 114/469 [03:27<10:13, 1.73s/it]
|
| 326 |
25%|██▍ | 115/469 [03:29<10:01, 1.70s/it]
|
| 327 |
25%|██▍ | 116/469 [03:30<09:56, 1.69s/it]
|
| 328 |
25%|██▍ | 117/469 [03:32<09:56, 1.69s/it]
|
| 329 |
25%|██▌ | 118/469 [03:34<09:45, 1.67s/it]
|
| 330 |
25%|██▌ | 119/469 [03:35<09:45, 1.67s/it]
|
| 331 |
26%|██▌ | 120/469 [03:37<09:45, 1.68s/it]
|
| 332 |
26%|██▌ | 121/469 [03:40<11:56, 2.06s/it]
|
| 333 |
26%|██▌ | 122/469 [03:42<11:18, 1.95s/it]
|
| 334 |
26%|██▌ | 123/469 [03:44<11:00, 1.91s/it]
|
| 335 |
26%|██▋ | 124/469 [03:45<10:31, 1.83s/it]
|
| 336 |
27%|██▋ | 125/469 [03:47<10:06, 1.76s/it]
|
| 337 |
27%|██▋ | 126/469 [03:48<09:48, 1.72s/it]
|
| 338 |
27%|██▋ | 127/469 [03:50<09:57, 1.75s/it]
|
| 339 |
27%|██▋ | 128/469 [03:52<09:52, 1.74s/it]
|
| 340 |
28%|██▊ | 129/469 [03:54<09:46, 1.72s/it]
|
| 341 |
28%|██▊ | 130/469 [03:56<09:59, 1.77s/it]
|
| 342 |
28%|██▊ | 131/469 [03:57<09:56, 1.76s/it]
|
| 343 |
28%|██▊ | 132/469 [03:59<09:40, 1.72s/it]
|
| 344 |
28%|██▊ | 133/469 [04:01<10:56, 1.95s/it]
|
| 345 |
29%|██▊ | 134/469 [04:03<10:35, 1.90s/it]
|
| 346 |
29%|██▉ | 135/469 [04:05<10:05, 1.81s/it]
|
| 347 |
29%|██▉ | 136/469 [04:06<09:45, 1.76s/it]
|
| 348 |
29%|██▉ | 137/469 [04:08<09:30, 1.72s/it]
|
| 349 |
29%|██▉ | 138/469 [04:10<09:35, 1.74s/it]
|
| 350 |
30%|██▉ | 139/469 [04:11<09:19, 1.69s/it]
|
| 351 |
30%|██▉ | 140/469 [04:13<09:06, 1.66s/it]
|
| 352 |
30%|███ | 141/469 [04:15<09:00, 1.65s/it]
|
| 353 |
30%|███ | 142/469 [04:16<08:55, 1.64s/it]
|
| 354 |
30%|███ | 143/469 [04:18<08:55, 1.64s/it]
|
| 355 |
31%|███ | 144/469 [04:19<08:47, 1.62s/it]
|
| 356 |
31%|███ | 145/469 [04:21<09:06, 1.69s/it]
|
| 357 |
31%|███ | 146/469 [04:23<09:06, 1.69s/it]
|
| 358 |
31%|███▏ | 147/469 [04:25<08:56, 1.67s/it]
|
| 359 |
32%|███▏ | 148/469 [04:26<08:56, 1.67s/it]
|
| 360 |
32%|███▏ | 149/469 [04:28<09:09, 1.72s/it]
|
| 361 |
32%|███▏ | 150/469 [04:30<08:57, 1.69s/it]
|
| 362 |
32%|███▏ | 151/469 [04:31<08:47, 1.66s/it]
|
| 363 |
32%|███▏ | 152/469 [04:33<08:41, 1.64s/it]
|
| 364 |
33%|███▎ | 153/469 [04:35<08:34, 1.63s/it]
|
| 365 |
33%|███▎ | 154/469 [04:36<08:32, 1.63s/it]
|
| 366 |
33%|███▎ | 155/469 [04:38<08:25, 1.61s/it]
|
| 367 |
33%|███▎ | 156/469 [04:39<08:23, 1.61s/it]
|
| 368 |
33%|███▎ | 157/469 [04:41<08:22, 1.61s/it]
|
| 369 |
34%|███▎ | 158/469 [04:47<15:58, 3.08s/it]
|
| 370 |
34%|███▍ | 159/469 [04:49<13:44, 2.66s/it]
|
| 371 |
34%|███▍ | 160/469 [04:51<12:27, 2.42s/it]
|
| 372 |
34%|███▍ | 161/469 [04:53<11:14, 2.19s/it]
|
| 373 |
35%|███▍ | 162/469 [04:54<10:22, 2.03s/it]
|
| 374 |
35%|███▍ | 163/469 [04:56<09:58, 1.96s/it]
|
| 375 |
35%|███▍ | 164/469 [04:58<09:35, 1.89s/it]
|
| 376 |
35%|███▌ | 165/469 [04:59<09:06, 1.80s/it]
|
| 377 |
35%|███▌ | 166/469 [05:01<08:47, 1.74s/it]
|
| 378 |
36%|███▌ | 167/469 [05:03<08:31, 1.69s/it]
|
| 379 |
36%|███▌ | 168/469 [05:05<08:55, 1.78s/it]
|
| 380 |
36%|███▌ | 169/469 [05:07<09:07, 1.82s/it]
|
| 381 |
36%|███▌ | 170/469 [05:08<08:54, 1.79s/it]
|
| 382 |
36%|███▋ | 171/469 [05:10<08:46, 1.77s/it]
|
| 383 |
37%|███▋ | 172/469 [05:12<08:31, 1.72s/it]
|
| 384 |
37%|███▋ | 173/469 [05:13<08:23, 1.70s/it]
|
| 385 |
37%|███▋ | 174/469 [05:15<08:16, 1.68s/it]
|
| 386 |
37%|███▋ | 175/469 [05:17<08:12, 1.67s/it]
|
| 387 |
38%|███▊ | 176/469 [05:18<08:14, 1.69s/it]
|
| 388 |
38%|███▊ | 177/469 [05:20<08:13, 1.69s/it]
|
| 389 |
38%|███▊ | 178/469 [05:22<08:07, 1.68s/it]
|
| 390 |
38%|███▊ | 179/469 [05:23<08:00, 1.66s/it]
|
| 391 |
38%|███▊ | 180/469 [05:25<07:58, 1.66s/it]
|
| 392 |
39%|███▊ | 181/469 [05:26<07:53, 1.64s/it]
|
| 393 |
39%|███▉ | 182/469 [05:28<07:47, 1.63s/it]
|
| 394 |
39%|███▉ | 183/469 [05:30<07:48, 1.64s/it]
|
| 395 |
39%|███▉ | 184/469 [05:31<07:50, 1.65s/it]
|
| 396 |
39%|███▉ | 185/469 [05:33<07:44, 1.64s/it]
|
| 397 |
40%|███▉ | 186/469 [05:35<07:50, 1.66s/it]
|
| 398 |
40%|███▉ | 187/469 [05:37<08:00, 1.70s/it]
|
| 399 |
40%|████ | 188/469 [05:38<07:55, 1.69s/it]
|
| 400 |
40%|████ | 189/469 [05:40<07:48, 1.67s/it]
|
| 401 |
41%|████ | 190/469 [05:42<07:53, 1.70s/it]
|
| 402 |
41%|████ | 191/469 [05:43<07:44, 1.67s/it]
|
| 403 |
41%|████ | 192/469 [05:45<07:38, 1.65s/it]
|
| 404 |
41%|████ | 193/469 [05:46<07:32, 1.64s/it]
|
| 405 |
41%|████▏ | 194/469 [05:48<07:54, 1.73s/it]
|
| 406 |
42%|███��▏ | 195/469 [05:50<07:45, 1.70s/it]
|
| 407 |
42%|████▏ | 196/469 [05:52<07:44, 1.70s/it]
|
| 408 |
42%|████▏ | 197/469 [05:53<07:48, 1.72s/it]
|
| 409 |
42%|████▏ | 198/469 [05:55<07:40, 1.70s/it]
|
| 410 |
42%|████▏ | 199/469 [05:57<08:02, 1.79s/it]
|
| 411 |
43%|████▎ | 200/469 [05:59<07:49, 1.74s/it]
|
| 412 |
43%|████▎ | 201/469 [06:00<07:41, 1.72s/it]
|
| 413 |
43%|████▎ | 202/469 [06:02<07:38, 1.72s/it]
|
| 414 |
43%|████▎ | 203/469 [06:04<07:37, 1.72s/it]
|
| 415 |
43%|████▎ | 204/469 [06:05<07:29, 1.70s/it]
|
| 416 |
44%|████▎ | 205/469 [06:07<07:31, 1.71s/it]
|
| 417 |
44%|████▍ | 206/469 [06:09<07:21, 1.68s/it]
|
| 418 |
44%|████▍ | 207/469 [06:10<07:16, 1.67s/it]
|
| 419 |
44%|████▍ | 208/469 [06:12<07:09, 1.64s/it]
|
| 420 |
45%|████▍ | 209/469 [06:14<07:07, 1.64s/it]
|
| 421 |
45%|████▍ | 210/469 [06:15<07:11, 1.67s/it]
|
| 422 |
45%|████▍ | 211/469 [06:17<07:21, 1.71s/it]
|
| 423 |
45%|████▌ | 212/469 [06:19<07:27, 1.74s/it]
|
| 424 |
45%|████▌ | 213/469 [06:21<07:20, 1.72s/it]
|
| 425 |
46%|████▌ | 214/469 [06:22<07:09, 1.69s/it]
|
| 426 |
46%|████▌ | 215/469 [06:24<07:07, 1.68s/it]
|
| 427 |
46%|████▌ | 216/469 [06:26<07:00, 1.66s/it]
|
| 428 |
46%|████▋ | 217/469 [06:27<06:56, 1.65s/it]
|
| 429 |
46%|████▋ | 218/469 [06:29<06:48, 1.63s/it]
|
| 430 |
47%|████▋ | 219/469 [06:31<07:08, 1.72s/it]
|
| 431 |
47%|████▋ | 220/469 [06:32<06:57, 1.68s/it]
|
| 432 |
47%|████▋ | 221/469 [06:34<06:56, 1.68s/it]
|
| 433 |
47%|████▋ | 222/469 [06:36<06:54, 1.68s/it]
|
| 434 |
48%|████▊ | 223/469 [06:37<06:47, 1.66s/it]
|
| 435 |
48%|████▊ | 224/469 [06:39<06:44, 1.65s/it]
|
| 436 |
48%|████▊ | 225/469 [06:41<06:42, 1.65s/it]
|
| 437 |
48%|████▊ | 226/469 [06:42<06:35, 1.63s/it]
|
| 438 |
48%|████▊ | 227/469 [06:44<06:31, 1.62s/it]
|
| 439 |
49%|████▊ | 228/469 [06:46<06:57, 1.73s/it]
|
| 440 |
49%|████▉ | 229/469 [06:47<06:47, 1.70s/it]
|
| 441 |
49%|████▉ | 230/469 [06:49<06:46, 1.70s/it]
|
| 442 |
49%|████▉ | 231/469 [06:51<06:40, 1.68s/it]
|
| 443 |
49%|████▉ | 232/469 [06:52<06:33, 1.66s/it]
|
| 444 |
50%|████▉ | 233/469 [06:54<06:46, 1.72s/it]
|
| 445 |
50%|████▉ | 234/469 [06:56<07:17, 1.86s/it]
|
| 446 |
50%|█████ | 235/469 [06:58<07:00, 1.80s/it]
|
| 447 |
50%|█████ | 236/469 [07:00<06:47, 1.75s/it]
|
| 448 |
51%|█████ | 237/469 [07:01<06:40, 1.72s/it]
|
| 449 |
51%|█████ | 238/469 [07:03<06:31, 1.69s/it]
|
| 450 |
51%|█████ | 239/469 [07:05<06:23, 1.67s/it]
|
| 451 |
51%|█████ | 240/469 [07:06<06:18, 1.65s/it]
|
| 452 |
51%|█████▏ | 241/469 [07:08<06:14, 1.64s/it]
|
| 453 |
52%|█████▏ | 242/469 [07:09<06:10, 1.63s/it]
|
| 454 |
52%|█████▏ | 243/469 [07:11<06:41, 1.77s/it]
|
| 455 |
52%|█████▏ | 244/469 [07:13<06:27, 1.72s/it]
|
| 456 |
52%|█████▏ | 245/469 [07:15<06:49, 1.83s/it]
|
| 457 |
52%|█████▏ | 246/469 [07:17<06:33, 1.76s/it]
|
| 458 |
53%|█████▎ | 247/469 [07:18<06:22, 1.72s/it]
|
| 459 |
53%|█████▎ | 248/469 [07:20<06:19, 1.72s/it]
|
| 460 |
53%|█████▎ | 249/469 [07:22<06:12, 1.69s/it]
|
| 461 |
53%|█████▎ | 250/469 [07:24<06:16, 1.72s/it]
|
| 462 |
54%|█████▎ | 251/469 [07:25<06:10, 1.70s/it]
|
| 463 |
54%|█████▎ | 252/469 [07:27<06:11, 1.71s/it]
|
| 464 |
54%|█████▍ | 253/469 [07:29<06:03, 1.68s/it]
|
| 465 |
54%|█████▍ | 254/469 [07:30<06:04, 1.70s/it]
|
| 466 |
54%|█████▍ | 255/469 [07:32<05:56, 1.67s/it]
|
| 467 |
55%|█████▍ | 256/469 [07:34<06:02, 1.70s/it]
|
| 468 |
55%|█████▍ | 257/469 [07:35<05:56, 1.68s/it]
|
| 469 |
55%|█████▌ | 258/469 [07:37<06:02, 1.72s/it]
|
| 470 |
55%|█████▌ | 259/469 [07:39<05:57, 1.70s/it]
|
| 471 |
55%|█████▌ | 260/469 [07:40<05:50, 1.68s/it]
|
| 472 |
56%|█████▌ | 261/469 [07:42<05:50, 1.69s/it]
|
| 473 |
56%|█████▌ | 262/469 [07:44<05:52, 1.70s/it]
|
| 474 |
56%|█████▌ | 263/469 [07:46<05:50, 1.70s/it]
|
| 475 |
56%|█████▋ | 264/469 [07:47<05:48, 1.70s/it]
|
| 476 |
57%|█████▋ | 265/469 [07:49<05:45, 1.69s/it]
|
| 477 |
57%|█████▋ | 266/469 [07:50<05:37, 1.66s/it]
|
| 478 |
57%|█████▋ | 267/469 [07:52<05:33, 1.65s/it]
|
| 479 |
57%|█████▋ | 268/469 [07:58<09:27, 2.82s/it]
|
| 480 |
57%|█████▋ | 269/469 [07:59<08:23, 2.52s/it]
|
| 481 |
58%|█████▊ | 270/469 [08:01<07:26, 2.24s/it]
|
| 482 |
58%|█████▊ | 271/469 [08:03<06:53, 2.09s/it]
|
| 483 |
58%|█████▊ | 272/469 [08:04<06:26, 1.96s/it]
|
| 484 |
58%|█████▊ | 273/469 [08:06<06:03, 1.86s/it]
|
| 485 |
58%|█████▊ | 274/469 [08:08<05:44, 1.77s/it]
|
| 486 |
59%|█████▊ | 275/469 [08:09<05:33, 1.72s/it]
|
| 487 |
59%|█████▉ | 276/469 [08:11<05:25, 1.69s/it]
|
| 488 |
59%|█████▉ | 277/469 [08:12<05:21, 1.68s/it]
|
| 489 |
59%|█████▉ | 278/469 [08:14<05:21, 1.68s/it]
|
| 490 |
59%|█████▉ | 279/469 [08:16<05:26, 1.72s/it]
|
| 491 |
60%|█████▉ | 280/469 [08:18<05:16, 1.67s/it]
|
| 492 |
60%|█████▉ | 281/469 [08:20<05:45, 1.84s/it]
|
| 493 |
60%|██████ | 282/469 [08:21<05:31, 1.77s/it]
|
| 494 |
60%|██████ | 283/469 [08:23<05:26, 1.75s/it]
|
| 495 |
61%|██████ | 284/469 [08:25<05:17, 1.72s/it]
|
| 496 |
61%|██████ | 285/469 [08:26<05:10, 1.69s/it]
|
| 497 |
61%|██████ | 286/469 [08:28<05:21, 1.76s/it]
|
| 498 |
61%|██████ | 287/469 [08:30<05:16, 1.74s/it]
|
| 499 |
61%|██████▏ | 288/469 [08:32<05:13, 1.73s/it]
|
| 500 |
62%|██████▏ | 289/469 [08:33<05:03, 1.69s/it]
|
| 501 |
62%|██████▏ | 290/469 [08:35<05:10, 1.74s/it]
|
| 502 |
62%|██████▏ | 291/469 [08:37<05:17, 1.78s/it]
|
| 503 |
62%|██████▏ | 292/469 [08:39<05:17, 1.79s/it]
|
| 504 |
62%|██████▏ | 293/469 [08:41<05:19, 1.81s/it]
|
| 505 |
63%|██████▎ | 294/469 [08:42<05:06, 1.75s/it]
|
| 506 |
63%|██████▎ | 295/469 [08:44<04:55, 1.70s/it]
|
| 507 |
63%|██████▎ | 296/469 [08:46<04:53, 1.70s/it]
|
| 508 |
63%|██████▎ | 297/469 [08:47<04:46, 1.67s/it]
|
| 509 |
64%|██████▎ | 298/469 [08:49<04:40, 1.64s/it]
|
| 510 |
64%|██████▍ | 299/469 [08:50<04:39, 1.65s/it]
|
| 511 |
64%|██████▍ | 300/469 [08:52<04:37, 1.64s/it]
|
| 512 |
64%|██████▍ | 301/469 [08:54<04:39, 1.66s/it]
|
| 513 |
64%|██████▍ | 302/469 [08:55<04:35, 1.65s/it]
|
| 514 |
65%|██████▍ | 303/469 [08:57<04:36, 1.66s/it]
|
| 515 |
65%|██████▍ | 304/469 [08:59<04:36, 1.67s/it]
|
| 516 |
65%|██████▌ | 305/469 [09:00<04:34, 1.68s/it]
|
| 517 |
65%|██████▌ | 306/469 [09:02<04:32, 1.67s/it]
|
| 518 |
65%|██████▌ | 307/469 [09:04<04:30, 1.67s/it]
|
| 519 |
66%|██████▌ | 308/469 [09:05<04:25, 1.65s/it]
|
| 520 |
66%|██████▌ | 309/469 [09:07<04:23, 1.65s/it]
|
| 521 |
66%|██████▌ | 310/469 [09:09<04:20, 1.64s/it]
|
| 522 |
66%|██████▋ | 311/469 [09:10<04:18, 1.64s/it]
|
| 523 |
67%|██████▋ | 312/469 [09:12<04:19, 1.65s/it]
|
| 524 |
67%|██████▋ | 313/469 [09:14<04:15, 1.64s/it]
|
| 525 |
67%|██████▋ | 314/469 [09:15<04:10, 1.62s/it]
|
| 526 |
67%|██████▋ | 315/469 [09:17<04:08, 1.61s/it]
|
| 527 |
67%|██████▋ | 316/469 [09:19<04:21, 1.71s/it]
|
| 528 |
68%|██████▊ | 317/469 [09:20<04:23, 1.73s/it]
|
| 529 |
68%|██████▊ | 318/469 [09:22<04:18, 1.71s/it]
|
| 530 |
68%|██████▊ | 319/469 [09:24<04:20, 1.74s/it]
|
| 531 |
68%|██████▊ | 320/469 [09:26<04:15, 1.71s/it]
|
| 532 |
68%|██████▊ | 321/469 [09:27<04:21, 1.77s/it]
|
| 533 |
69%|██████▊ | 322/469 [09:29<04:12, 1.72s/it]
|
| 534 |
69%|██████▉ | 323/469 [09:31<04:04, 1.67s/it]
|
| 535 |
69%|██████▉ | 324/469 [09:32<04:01, 1.67s/it]
|
| 536 |
69%|██████▉ | 325/469 [09:34<03:55, 1.64s/it]
|
| 537 |
70%|██████▉ | 326/469 [09:35<03:52, 1.62s/it]
|
| 538 |
70%|██████▉ | 327/469 [09:37<03:49, 1.61s/it]
|
| 539 |
70%|██████▉ | 328/469 [09:39<03:52, 1.65s/it]
|
| 540 |
70%|███████ | 329/469 [09:41<04:01, 1.72s/it]
|
| 541 |
70%|███████ | 330/469 [09:43<04:06, 1.77s/it]
|
| 542 |
71%|███████ | 331/469 [09:44<04:03, 1.76s/it]
|
| 543 |
71%|███████ | 332/469 [09:46<03:55, 1.72s/it]
|
| 544 |
71%|███████ | 333/469 [09:47<03:48, 1.68s/it]
|
| 545 |
71%|███████ | 334/469 [09:49<03:46, 1.68s/it]
|
| 546 |
71%|███████▏ | 335/469 [09:51<03:42, 1.66s/it]
|
| 547 |
72%|███████▏ | 336/469 [09:52<03:37, 1.63s/it]
|
| 548 |
72%|███████▏ | 337/469 [09:54<03:33, 1.62s/it]
|
| 549 |
72%|███████▏ | 338/469 [09:56<03:33, 1.63s/it]
|
| 550 |
72%|███████▏ | 339/469 [09:57<03:31, 1.62s/it]
|
| 551 |
72%|███████▏ | 340/469 [09:59<03:27, 1.61s/it]
|
| 552 |
73%|███████▎ | 341/469 [10:01<03:32, 1.66s/it]
|
| 553 |
73%|███████▎ | 342/469 [10:03<03:41, 1.75s/it]
|
| 554 |
73%|███████▎ | 343/469 [10:04<03:35, 1.71s/it]
|
| 555 |
73%|███████▎ | 344/469 [10:06<03:31, 1.70s/it]
|
| 556 |
74%|███████▎ | 345/469 [10:27<15:27, 7.48s/it]
|
| 557 |
74%|███████▍ | 346/469 [10:29<11:48, 5.76s/it]
|
| 558 |
74%|███████▍ | 347/469 [10:30<09:19, 4.59s/it]
|
| 559 |
74%|███████▍ | 348/469 [10:32<07:29, 3.71s/it]
|
| 560 |
74%|███████▍ | 349/469 [10:34<06:11, 3.10s/it]
|
| 561 |
75%|███████▍ | 350/469 [10:36<05:49, 2.93s/it]
|
| 562 |
75%|███████▍ | 351/469 [10:38<05:09, 2.63s/it]
|
| 563 |
75%|███████▌ | 352/469 [10:40<04:39, 2.39s/it]
|
| 564 |
75%|███████▌ | 353/469 [10:42<04:16, 2.21s/it]
|
| 565 |
75%|███████▌ | 354/469 [10:43<03:53, 2.03s/it]
|
| 566 |
76%|███████▌ | 355/469 [10:45<03:35, 1.89s/it]
|
| 567 |
76%|███████▌ | 356/469 [10:47<03:30, 1.86s/it]
|
| 568 |
76%|███████▌ | 357/469 [10:48<03:23, 1.82s/it]
|
| 569 |
76%|███████▋ | 358/469 [10:50<03:19, 1.80s/it]
|
| 570 |
77%|███████▋ | 359/469 [10:52<03:20, 1.82s/it]
|
| 571 |
77%|███████▋ | 360/469 [10:54<03:23, 1.87s/it]
|
| 572 |
77%|███████▋ | 361/469 [10:56<03:16, 1.82s/it]
|
| 573 |
77%|███████▋ | 362/469 [10:58<03:24, 1.91s/it]
|
| 574 |
77%|███████▋ | 363/469 [11:00<03:12, 1.82s/it]
|
| 575 |
78%|███████▊ | 364/469 [11:01<03:05, 1.77s/it]
|
| 576 |
78%|███████▊ | 365/469 [11:03<02:58, 1.72s/it]
|
| 577 |
78%|███████▊ | 366/469 [11:04<02:54, 1.69s/it]
|
| 578 |
78%|███████▊ | 367/469 [11:06<02:52, 1.69s/it]
|
| 579 |
78%|███████▊ | 368/469 [11:08<02:50, 1.68s/it]
|
| 580 |
79%|███████▊ | 369/469 [11:10<02:50, 1.71s/it]
|
| 581 |
79%|███████▉ | 370/469 [11:11<02:49, 1.71s/it]
|
| 582 |
79%|███████▉ | 371/469 [11:13<02:46, 1.70s/it]
|
| 583 |
79%|███████▉ | 372/469 [11:15<02:49, 1.74s/it]
|
| 584 |
80%|███████▉ | 373/469 [11:16<02:43, 1.70s/it]
|
| 585 |
80%|███████▉ | 374/469 [11:18<02:42, 1.71s/it]
|
| 586 |
80%|███████▉ | 375/469 [11:20<02:44, 1.75s/it]
|
| 587 |
80%|████████ | 376/469 [11:22<02:47, 1.80s/it]
|
| 588 |
80%|████████ | 377/469 [11:24<02:41, 1.76s/it]
|
| 589 |
81%|████████ | 378/469 [11:25<02:34, 1.70s/it]
|
| 590 |
81%|████████ | 379/469 [11:27<02:30, 1.67s/it]
|
| 591 |
81%|████████ | 380/469 [11:29<02:33, 1.72s/it]
|
| 592 |
81%|████████ | 381/469 [11:30<02:28, 1.69s/it]
|
| 593 |
81%|████████▏ | 382/469 [11:32<02:24, 1.66s/it]
|
| 594 |
82%|████████▏ | 383/469 [11:34<02:29, 1.73s/it]
|
| 595 |
82%|████████▏ | 384/469 [11:36<02:33, 1.80s/it]
|
| 596 |
82%|████████▏ | 385/469 [11:38<02:34, 1.84s/it]
|
| 597 |
82%|████████▏ | 386/469 [11:39<02:27, 1.77s/it]
|
| 598 |
83%|████████▎ | 387/469 [11:41<02:26, 1.79s/it]
|
| 599 |
83%|████████▎ | 388/469 [11:43<02:21, 1.74s/it]
|
| 600 |
83%|████████▎ | 389/469 [11:45<02:23, 1.79s/it]
|
| 601 |
83%|████████▎ | 390/469 [11:46<02:16, 1.73s/it]
|
| 602 |
83%|████████▎ | 391/469 [11:48<02:12, 1.70s/it]
|
| 603 |
84%|████████▎ | 392/469 [11:49<02:09, 1.69s/it]
|
| 604 |
84%|████████▍ | 393/469 [11:51<02:05, 1.66s/it]
|
| 605 |
84%|████████▍ | 394/469 [11:53<02:03, 1.65s/it]
|
| 606 |
84%|████████▍ | 395/469 [11:54<02:01, 1.65s/it]
|
| 607 |
84%|████████▍ | 396/469 [11:56<01:59, 1.63s/it]
|
| 608 |
85%|████████▍ | 397/469 [11:58<01:58, 1.64s/it]
|
| 609 |
85%|████████▍ | 398/469 [11:59<02:00, 1.70s/it]
|
| 610 |
85%|████████▌ | 399/469 [12:01<01:57, 1.68s/it]
|
| 611 |
85%|████████▌ | 400/469 [12:03<01:54, 1.66s/it]
|
| 612 |
86%|████████▌ | 401/469 [12:04<01:51, 1.64s/it]
|
| 613 |
86%|████████▌ | 402/469 [12:25<08:21, 7.49s/it]
|
| 614 |
86%|████████▌ | 403/469 [12:27<06:17, 5.72s/it]
|
| 615 |
86%|████████▌ | 404/469 [12:29<04:51, 4.49s/it]
|
| 616 |
86%|████████▋ | 405/469 [12:30<03:53, 3.65s/it]
|
| 617 |
87%|████████▋ | 406/469 [12:32<03:12, 3.06s/it]
|
| 618 |
87%|████████▋ | 407/469 [12:34<02:42, 2.63s/it]
|
| 619 |
87%|████████▋ | 408/469 [12:35<02:21, 2.32s/it]
|
| 620 |
87%|████████▋ | 409/469 [12:37<02:06, 2.10s/it]
|
| 621 |
87%|████████▋ | 410/469 [12:38<01:54, 1.94s/it]
|
| 622 |
88%|████████▊ | 411/469 [12:40<01:46, 1.83s/it]
|
| 623 |
88%|████████▊ | 412/469 [12:41<01:39, 1.75s/it]
|
| 624 |
88%|████████▊ | 413/469 [12:43<01:36, 1.73s/it]
|
| 625 |
88%|████████▊ | 414/469 [12:45<01:33, 1.69s/it]
|
| 626 |
88%|████████▊ | 415/469 [12:46<01:29, 1.65s/it]
|
| 627 |
89%|████████▊ | 416/469 [12:48<01:28, 1.68s/it]
|
| 628 |
89%|████████▉ | 417/469 [12:50<01:26, 1.67s/it]
|
| 629 |
89%|████████▉ | 418/469 [12:51<01:24, 1.65s/it]
|
| 630 |
89%|████████▉ | 419/469 [12:53<01:21, 1.64s/it]
|
| 631 |
90%|████████▉ | 420/469 [12:55<01:21, 1.66s/it]
|
| 632 |
90%|████████▉ | 421/469 [12:56<01:19, 1.66s/it]
|
| 633 |
90%|████████▉ | 422/469 [12:58<01:18, 1.67s/it]
|
| 634 |
90%|█████████ | 423/469 [13:00<01:18, 1.71s/it]
|
| 635 |
90%|█████████ | 424/469 [13:02<01:18, 1.73s/it]
|
| 636 |
91%|█████████ | 425/469 [13:03<01:14, 1.69s/it]
|
| 637 |
91%|█████████ | 426/469 [13:05<01:11, 1.66s/it]
|
| 638 |
91%|█████████ | 427/469 [13:07<01:13, 1.74s/it]
|
| 639 |
91%|█████████▏| 428/469 [13:08<01:09, 1.70s/it]
|
| 640 |
91%|█████████▏| 429/469 [13:10<01:07, 1.68s/it]
|
| 641 |
92%|█████████▏| 430/469 [13:12<01:05, 1.68s/it]
|
| 642 |
92%|█████████▏| 431/469 [13:13<01:05, 1.71s/it]
|
| 643 |
92%|█████████▏| 432/469 [13:15<01:02, 1.68s/it]
|
| 644 |
92%|█████████▏| 433/469 [13:17<01:01, 1.71s/it]
|
| 645 |
93%|█████████▎| 434/469 [13:18<00:58, 1.69s/it]
|
| 646 |
93%|█████████▎| 435/469 [13:20<00:57, 1.69s/it]
|
| 647 |
93%|█████████▎| 436/469 [13:22<00:55, 1.68s/it]
|
| 648 |
93%|█████████▎| 437/469 [13:23<00:53, 1.67s/it]
|
| 649 |
93%|█████████▎| 438/469 [13:25<00:51, 1.66s/it]
|
| 650 |
94%|█████████▎| 439/469 [13:27<00:49, 1.64s/it]
|
| 651 |
94%|█████████▍| 440/469 [13:28<00:47, 1.65s/it]
|
| 652 |
94%|█████████▍| 441/469 [13:30<00:46, 1.65s/it]
|
| 653 |
94%|█████████▍| 442/469 [13:32<00:44, 1.63s/it]
|
| 654 |
94%|█████████▍| 443/469 [13:33<00:43, 1.67s/it]
|
| 655 |
95%|█████████▍| 444/469 [13:35<00:41, 1.65s/it]
|
| 656 |
95%|█████████▍| 445/469 [13:36<00:39, 1.64s/it]
|
| 657 |
95%|█████████▌| 446/469 [13:39<00:40, 1.77s/it]
|
| 658 |
95%|█████████▌| 447/469 [13:40<00:37, 1.72s/it]
|
| 659 |
96%|█████████▌| 448/469 [13:42<00:35, 1.68s/it]
|
| 660 |
96%|█████████▌| 449/469 [13:44<00:36, 1.84s/it]
|
| 661 |
96%|█████████▌| 450/469 [13:46<00:33, 1.78s/it]
|
| 662 |
96%|█████████▌| 451/469 [13:47<00:30, 1.72s/it]
|
| 663 |
96%|█████████▋| 452/469 [13:49<00:30, 1.78s/it]
|
| 664 |
97%|█████████▋| 453/469 [13:51<00:27, 1.73s/it]
|
| 665 |
97%|█████████▋| 454/469 [13:52<00:25, 1.70s/it]
|
| 666 |
97%|█████████▋| 455/469 [13:54<00:23, 1.70s/it]
|
| 667 |
97%|█████████▋| 456/469 [13:56<00:21, 1.67s/it]
|
| 668 |
97%|█████████▋| 457/469 [13:57<00:19, 1.66s/it]
|
| 669 |
98%|█████████▊| 458/469 [13:59<00:18, 1.65s/it]
|
| 670 |
98%|█████████▊| 459/469 [14:01<00:16, 1.64s/it]
|
| 671 |
98%|█████████▊| 460/469 [14:02<00:14, 1.65s/it]
|
| 672 |
98%|█████████▊| 461/469 [14:04<00:13, 1.65s/it]
|
| 673 |
99%|█████████▊| 462/469 [14:06<00:11, 1.67s/it]
|
| 674 |
99%|█████████▊| 463/469 [14:07<00:10, 1.67s/it]
|
| 675 |
99%|█████████▉| 464/469 [14:09<00:09, 1.83s/it]
|
| 676 |
99%|█████████▉| 465/469 [14:11<00:07, 1.84s/it]
|
| 677 |
99%|█████████▉| 466/469 [14:13<00:05, 1.80s/it]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
2026-03-25 14:55:31.459313: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
| 3 |
+
2026-03-25 14:55:36.111147: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA
|
| 4 |
+
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
| 5 |
+
2026-03-25 14:55:36.137507: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
|
| 6 |
+
2026-03-25 14:55:36.137574: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: 66d2d54653616c6252364513da490658-taskrole1-0
|
| 7 |
+
2026-03-25 14:55:36.137623: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: 66d2d54653616c6252364513da490658-taskrole1-0
|
| 8 |
+
2026-03-25 14:55:36.137710: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
|
| 9 |
+
2026-03-25 14:55:36.137756: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 535.154.5
|
| 10 |
+
2026-03-25 14:55:37.175397: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
| 11 |
+
warming up TensorFlow...
|
| 12 |
+
|
| 13 |
0%| | 0/1 [00:00<?, ?it/s]2026-03-25 14:55:37.841840: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
|
| 14 |
+
|
| 15 |
+
computing reference batch activations...
|
| 16 |
+
|
| 17 |
0%| | 0/211 [00:00<?, ?it/s]
|
| 18 |
0%| | 1/211 [00:02<07:28, 2.13s/it]
|
| 19 |
1%| | 2/211 [00:03<06:40, 1.92s/it]
|
| 20 |
1%|▏ | 3/211 [00:06<08:10, 2.36s/it]
|
| 21 |
2%|▏ | 4/211 [00:08<07:14, 2.10s/it]
|
| 22 |
2%|▏ | 5/211 [00:10<06:43, 1.96s/it]
|
| 23 |
3%|▎ | 6/211 [00:11<06:27, 1.89s/it]
|
| 24 |
3%|▎ | 7/211 [00:13<06:15, 1.84s/it]
|
| 25 |
4%|▍ | 8/211 [00:15<06:05, 1.80s/it]
|
| 26 |
4%|▍ | 9/211 [00:17<06:00, 1.79s/it]
|
| 27 |
5%|▍ | 10/211 [00:18<05:53, 1.76s/it]
|
| 28 |
5%|▌ | 11/211 [00:20<05:47, 1.74s/it]
|
| 29 |
6%|▌ | 12/211 [00:22<05:40, 1.71s/it]
|
| 30 |
6%|▌ | 13/211 [00:23<05:34, 1.69s/it]
|
| 31 |
7%|▋ | 14/211 [00:25<05:32, 1.69s/it]
|
| 32 |
7%|▋ | 15/211 [00:27<05:39, 1.73s/it]
|
| 33 |
8%|▊ | 16/211 [00:29<05:34, 1.71s/it]
|
| 34 |
8%|▊ | 17/211 [00:30<05:34, 1.72s/it]
|
| 35 |
9%|▊ | 18/211 [00:32<06:00, 1.87s/it]
|
| 36 |
9%|▉ | 19/211 [00:34<05:48, 1.82s/it]
|
| 37 |
9%|▉ | 20/211 [00:36<05:44, 1.80s/it]
|
| 38 |
10%|▉ | 21/211 [00:38<05:39, 1.78s/it]
|
| 39 |
10%|█ | 22/211 [00:40<05:40, 1.80s/it]
|
| 40 |
11%|█ | 23/211 [00:41<05:45, 1.84s/it]
|
| 41 |
11%|█▏ | 24/211 [00:43<05:43, 1.84s/it]
|
| 42 |
12%|█▏ | 25/211 [00:45<05:35, 1.80s/it]
|
| 43 |
12%|█▏ | 26/211 [00:47<05:41, 1.84s/it]
|
| 44 |
13%|█▎ | 27/211 [00:49<05:42, 1.86s/it]
|
| 45 |
13%|█▎ | 28/211 [00:51<05:36, 1.84s/it]
|
| 46 |
14%|█▎ | 29/211 [00:52<05:27, 1.80s/it]
|
| 47 |
14%|█▍ | 30/211 [00:54<05:18, 1.76s/it]
|
| 48 |
15%|█▍ | 31/211 [00:56<05:28, 1.83s/it]
|
| 49 |
15%|█▌ | 32/211 [00:58<05:24, 1.81s/it]
|
| 50 |
16%|█▌ | 33/211 [01:00<05:28, 1.85s/it]
|
| 51 |
16%|█▌ | 34/211 [01:01<05:20, 1.81s/it]
|
| 52 |
17%|█▋ | 35/211 [01:03<05:19, 1.82s/it]
|
| 53 |
17%|█▋ | 36/211 [01:05<05:14, 1.80s/it]
|
| 54 |
18%|█▊ | 37/211 [01:07<05:07, 1.77s/it]
|
| 55 |
18%|█▊ | 38/211 [01:08<05:04, 1.76s/it]
|
| 56 |
18%|█▊ | 39/211 [01:10<05:17, 1.85s/it]
|
| 57 |
19%|█▉ | 40/211 [01:12<05:10, 1.82s/it]
|
| 58 |
19%|█▉ | 41/211 [01:14<05:23, 1.90s/it]
|
| 59 |
20%|█▉ | 42/211 [01:16<05:15, 1.87s/it]
|
| 60 |
20%|██ | 43/211 [01:18<05:06, 1.82s/it]
|
| 61 |
21%|██ | 44/211 [01:20<05:01, 1.81s/it]
|
| 62 |
21%|██▏ | 45/211 [01:21<04:55, 1.78s/it]
|
| 63 |
22%|██▏ | 46/211 [01:23<04:54, 1.78s/it]
|
| 64 |
22%|██▏ | 47/211 [01:25<04:49, 1.77s/it]
|
| 65 |
23%|██▎ | 48/211 [01:27<04:54, 1.81s/it]
|
| 66 |
23%|██▎ | 49/211 [01:28<04:48, 1.78s/it]
|
| 67 |
24%|██▎ | 50/211 [01:30<04:42, 1.75s/it]
|
| 68 |
24%|██▍ | 51/211 [01:32<04:36, 1.73s/it]
|
| 69 |
25%|██▍ | 52/211 [01:34<04:44, 1.79s/it]
|
| 70 |
25%|██▌ | 53/211 [01:36<05:01, 1.91s/it]
|
| 71 |
26%|██▌ | 54/211 [01:38<04:49, 1.84s/it]
|
| 72 |
26%|██▌ | 55/211 [01:40<05:11, 2.00s/it]
|
| 73 |
27%|██▋ | 56/211 [01:42<04:57, 1.92s/it]
|
| 74 |
27%|██▋ | 57/211 [01:43<04:43, 1.84s/it]
|
| 75 |
27%|█��▋ | 58/211 [01:45<04:48, 1.88s/it]
|
| 76 |
28%|██▊ | 59/211 [01:47<04:34, 1.81s/it]
|
| 77 |
28%|██▊ | 60/211 [01:49<04:36, 1.83s/it]
|
| 78 |
29%|██▉ | 61/211 [01:51<04:25, 1.77s/it]
|
| 79 |
29%|██▉ | 62/211 [01:52<04:17, 1.73s/it]
|
| 80 |
30%|██▉ | 63/211 [01:54<04:13, 1.71s/it]
|
| 81 |
30%|███ | 64/211 [01:55<04:08, 1.69s/it]
|
| 82 |
31%|███ | 65/211 [01:57<04:09, 1.71s/it]
|
| 83 |
31%|███▏ | 66/211 [01:59<04:11, 1.73s/it]
|
| 84 |
32%|███▏ | 67/211 [02:01<04:11, 1.75s/it]
|
| 85 |
32%|███▏ | 68/211 [02:03<04:16, 1.79s/it]
|
| 86 |
33%|███▎ | 69/211 [02:04<04:11, 1.77s/it]
|
| 87 |
33%|███▎ | 70/211 [02:07<04:30, 1.92s/it]
|
| 88 |
34%|███▎ | 71/211 [02:08<04:20, 1.86s/it]
|
| 89 |
34%|███▍ | 72/211 [02:10<04:12, 1.82s/it]
|
| 90 |
35%|███▍ | 73/211 [02:12<04:07, 1.79s/it]
|
| 91 |
35%|███▌ | 74/211 [02:14<04:12, 1.84s/it]
|
| 92 |
36%|███▌ | 75/211 [02:16<04:19, 1.91s/it]
|
| 93 |
36%|███▌ | 76/211 [02:18<04:06, 1.83s/it]
|
| 94 |
36%|███▋ | 77/211 [02:20<04:17, 1.92s/it]
|
| 95 |
37%|███▋ | 78/211 [02:21<04:06, 1.86s/it]
|
| 96 |
37%|███▋ | 79/211 [02:23<04:03, 1.84s/it]
|
| 97 |
38%|███▊ | 80/211 [02:26<04:34, 2.10s/it]
|
| 98 |
38%|███▊ | 81/211 [02:28<04:16, 1.97s/it]
|
| 99 |
39%|███▉ | 82/211 [02:29<04:01, 1.88s/it]
|
| 100 |
39%|███▉ | 83/211 [02:31<03:54, 1.84s/it]
|
| 101 |
40%|███▉ | 84/211 [02:33<03:48, 1.80s/it]
|
| 102 |
40%|████ | 85/211 [02:34<03:42, 1.77s/it]
|
| 103 |
41%|████ | 86/211 [02:36<03:39, 1.75s/it]
|
| 104 |
41%|████ | 87/211 [02:38<03:37, 1.76s/it]
|
| 105 |
42%|████▏ | 88/211 [02:39<03:33, 1.73s/it]
|
| 106 |
42%|████▏ | 89/211 [02:41<03:27, 1.70s/it]
|
| 107 |
43%|████▎ | 90/211 [02:43<03:27, 1.71s/it]
|
| 108 |
43%|████▎ | 91/211 [02:45<03:27, 1.73s/it]
|
| 109 |
44%|████▎ | 92/211 [02:46<03:23, 1.71s/it]
|
| 110 |
44%|████▍ | 93/211 [02:48<03:22, 1.71s/it]
|
| 111 |
45%|████▍ | 94/211 [02:57<07:25, 3.81s/it]
|
| 112 |
45%|████▌ | 95/211 [02:58<06:11, 3.20s/it]
|
| 113 |
45%|████▌ | 96/211 [03:00<05:17, 2.76s/it]
|
| 114 |
46%|████▌ | 97/211 [03:02<04:47, 2.52s/it]
|
| 115 |
46%|████▋ | 98/211 [03:04<04:15, 2.27s/it]
|
| 116 |
47%|████▋ | 99/211 [03:06<03:56, 2.11s/it]
|
| 117 |
47%|████▋ | 100/211 [03:07<03:42, 2.00s/it]
|
| 118 |
48%|████▊ | 101/211 [03:09<03:28, 1.90s/it]
|
| 119 |
48%|████▊ | 102/211 [03:11<03:17, 1.81s/it]
|
| 120 |
49%|████▉ | 103/211 [03:12<03:12, 1.78s/it]
|
| 121 |
49%|████▉ | 104/211 [03:14<03:05, 1.74s/it]
|
| 122 |
50%|████▉ | 105/211 [03:16<03:01, 1.71s/it]
|
| 123 |
50%|█████ | 106/211 [03:17<03:05, 1.76s/it]
|
| 124 |
51%|█████ | 107/211 [03:19<03:03, 1.76s/it]
|
| 125 |
51%|█████ | 108/211 [03:21<02:58, 1.74s/it]
|
| 126 |
52%|█████▏ | 109/211 [03:23<02:58, 1.75s/it]
|
| 127 |
52%|█████▏ | 110/211 [03:25<02:58, 1.77s/it]
|
| 128 |
53%|█████▎ | 111/211 [03:26<02:51, 1.72s/it]
|
| 129 |
53%|█████▎ | 112/211 [03:28<02:49, 1.71s/it]
|
| 130 |
54%|█████▎ | 113/211 [03:29<02:46, 1.70s/it]
|
| 131 |
54%|█████▍ | 114/211 [03:31<02:44, 1.69s/it]
|
| 132 |
55%|█████▍ | 115/211 [03:33<02:41, 1.68s/it]
|
| 133 |
55%|█████▍ | 116/211 [03:35<02:42, 1.71s/it]
|
| 134 |
55%|█████▌ | 117/211 [03:36<02:38, 1.69s/it]
|
| 135 |
56%|█████▌ | 118/211 [03:38<02:38, 1.70s/it]
|
| 136 |
56%|█████▋ | 119/211 [03:40<02:35, 1.69s/it]
|
| 137 |
57%|█████▋ | 120/211 [03:41<02:36, 1.72s/it]
|
| 138 |
57%|█████▋ | 121/211 [03:43<02:36, 1.74s/it]
|
| 139 |
58%|█████▊ | 122/211 [03:45<02:31, 1.70s/it]
|
| 140 |
58%|█████▊ | 123/211 [03:47<02:34, 1.76s/it]
|
| 141 |
59%|█████▉ | 124/211 [03:48<02:30, 1.73s/it]
|
| 142 |
59%|█████▉ | 125/211 [03:50<02:25, 1.70s/it]
|
| 143 |
60%|█████▉ | 126/211 [03:52<02:22, 1.68s/it]
|
| 144 |
60%|██████ | 127/211 [03:53<02:21, 1.69s/it]
|
| 145 |
61%|██████ | 128/211 [03:55<02:19, 1.68s/it]
|
| 146 |
61%|██████ | 129/211 [03:57<02:17, 1.68s/it]
|
| 147 |
62%|██████▏ | 130/211 [03:59<02:21, 1.75s/it]
|
| 148 |
62%|██████▏ | 131/211 [04:00<02:20, 1.75s/it]
|
| 149 |
63%|██████▎ | 132/211 [04:02<02:18, 1.75s/it]
|
| 150 |
63%|██████▎ | 133/211 [04:04<02:21, 1.81s/it]
|
| 151 |
64%|██████▎ | 134/211 [04:06<02:15, 1.76s/it]
|
| 152 |
64%|██████▍ | 135/211 [04:08<02:17, 1.81s/it]
|
| 153 |
64%|██████▍ | 136/211 [04:10<02:20, 1.88s/it]
|
| 154 |
65%|██████▍ | 137/211 [04:11<02:13, 1.80s/it]
|
| 155 |
65%|██████▌ | 138/211 [04:13<02:07, 1.74s/it]
|
| 156 |
66%|██████▌ | 139/211 [04:16<02:38, 2.20s/it]
|
| 157 |
66%|██████▋ | 140/211 [04:18<02:27, 2.08s/it]
|
| 158 |
67%|██████▋ | 141/211 [04:20<02:16, 1.95s/it]
|
| 159 |
67%|██████▋ | 142/211 [04:21<02:11, 1.91s/it]
|
| 160 |
68%|██████▊ | 143/211 [04:23<02:06, 1.86s/it]
|
| 161 |
68%|██████▊ | 144/211 [04:25<02:02, 1.82s/it]
|
| 162 |
69%|██████▊ | 145/211 [04:27<01:58, 1.80s/it]
|
| 163 |
69%|██████▉ | 146/211 [04:28<01:53, 1.75s/it]
|
| 164 |
70%|██████▉ | 147/211 [04:30<01:54, 1.78s/it]
|
| 165 |
70%|███████ | 148/211 [04:32<01:53, 1.80s/it]
|
| 166 |
71%|███████ | 149/211 [04:34<01:54, 1.85s/it]
|
| 167 |
71%|███████ | 150/211 [04:36<01:49, 1.80s/it]
|
| 168 |
72%|███████▏ | 151/211 [04:37<01:46, 1.77s/it]
|
| 169 |
72%|███████▏ | 152/211 [04:39<01:45, 1.79s/it]
|
| 170 |
73%|███████▎ | 153/211 [04:41<01:43, 1.78s/it]
|
| 171 |
73%|███████▎ | 154/211 [04:43<01:41, 1.78s/it]
|
| 172 |
73%|███████▎ | 155/211 [04:44<01:37, 1.73s/it]
|
| 173 |
74%|███████▍ | 156/211 [04:46<01:35, 1.73s/it]
|
| 174 |
74%|███████▍ | 157/211 [04:48<01:33, 1.73s/it]
|
| 175 |
75%|███████▍ | 158/211 [04:50<01:33, 1.77s/it]
|
| 176 |
75%|███████▌ | 159/211 [04:51<01:30, 1.74s/it]
|
| 177 |
76%|███████▌ | 160/211 [04:53<01:27, 1.72s/it]
|
| 178 |
76%|███████▋ | 161/211 [04:55<01:28, 1.77s/it]
|
| 179 |
77%|███████▋ | 162/211 [04:57<01:28, 1.80s/it]
|
| 180 |
77%|███████▋ | 163/211 [04:58<01:24, 1.75s/it]
|
| 181 |
78%|███████▊ | 164/211 [05:00<01:26, 1.84s/it]
|
| 182 |
78%|███████▊ | 165/211 [05:02<01:21, 1.77s/it]
|
| 183 |
79%|███████▊ | 166/211 [05:04<01:17, 1.72s/it]
|
| 184 |
79%|███████▉ | 167/211 [05:05<01:15, 1.72s/it]
|
| 185 |
80%|███████▉ | 168/211 [05:07<01:16, 1.79s/it]
|
| 186 |
80%|████████ | 169/211 [05:09<01:17, 1.84s/it]
|
| 187 |
81%|████████ | 170/211 [05:11<01:15, 1.83s/it]
|
| 188 |
81%|████████ | 171/211 [05:13<01:12, 1.81s/it]
|
| 189 |
82%|████████▏ | 172/211 [05:15<01:09, 1.79s/it]
|
| 190 |
82%|████████▏ | 173/211 [05:16<01:06, 1.74s/it]
|
| 191 |
82%|████████▏ | 174/211 [05:18<01:04, 1.75s/it]
|
| 192 |
83%|████████▎ | 175/211 [05:20<01:02, 1.74s/it]
|
| 193 |
83%|████████▎ | 176/211 [05:41<04:21, 7.47s/it]
|
| 194 |
84%|████████▍ | 177/211 [05:42<03:17, 5.81s/it]
|
| 195 |
84%|████████▍ | 178/211 [05:44<02:32, 4.63s/it]
|
| 196 |
85%|████████▍ | 179/211 [05:46<02:00, 3.75s/it]
|
| 197 |
85%|████████▌ | 180/211 [05:48<01:36, 3.13s/it]
|
| 198 |
86%|████████▌ | 181/211 [05:49<01:21, 2.72s/it]
|
| 199 |
86%|████████▋ | 182/211 [05:51<01:10, 2.42s/it]
|
| 200 |
87%|████████▋ | 183/211 [05:53<01:01, 2.19s/it]
|
| 201 |
87%|████████▋ | 184/211 [05:55<00:54, 2.03s/it]
|
| 202 |
88%|████████▊ | 185/211 [05:56<00:50, 1.94s/it]
|
| 203 |
88%|████████▊ | 186/211 [05:58<00:46, 1.88s/it]
|
| 204 |
89%|████████▊ | 187/211 [06:00<00:43, 1.82s/it]
|
| 205 |
89%|████████▉ | 188/211 [06:01<00:41, 1.79s/it]
|
| 206 |
90%|████████▉ | 189/211 [06:03<00:40, 1.86s/it]
|
| 207 |
90%|█████████ | 190/211 [06:06<00:45, 2.14s/it]
|
| 208 |
91%|█████████ | 191/211 [06:08<00:40, 2.02s/it]
|
| 209 |
91%|█████████ | 192/211 [06:10<00:36, 1.91s/it]
|
| 210 |
91%|█████████▏| 193/211 [06:11<00:33, 1.85s/it]
|
| 211 |
92%|█████████▏| 194/211 [06:13<00:30, 1.82s/it]
|
| 212 |
92%|█████████▏| 195/211 [06:15<00:28, 1.78s/it]
|
| 213 |
93%|█████████▎| 196/211 [06:17<00:28, 1.92s/it]
|
| 214 |
93%|█████████▎| 197/211 [06:19<00:25, 1.84s/it]
|
| 215 |
94%|█████████▍| 198/211 [06:20<00:23, 1.80s/it]
|
| 216 |
94%|█████████▍| 199/211 [06:22<00:22, 1.84s/it]
|
| 217 |
95%|█████████▍| 200/211 [06:24<00:19, 1.80s/it]
|
| 218 |
95%|█████████▌| 201/211 [06:26<00:17, 1.77s/it]
|
| 219 |
96%|█████████▌| 202/211 [06:28<00:16, 1.82s/it]
|
| 220 |
96%|█████████▌| 203/211 [06:29<00:14, 1.77s/it]
|
| 221 |
97%|█████████▋| 204/211 [06:31<00:12, 1.73s/it]
|
| 222 |
97%|█████████▋| 205/211 [06:33<00:10, 1.72s/it]
|
| 223 |
98%|█████████▊| 206/211 [06:34<00:08, 1.73s/it]
|
| 224 |
98%|█████████▊| 207/211 [06:36<00:06, 1.71s/it]
|
| 225 |
99%|█████████▊| 208/211 [06:39<00:05, 1.97s/it]
|
| 226 |
99%|█████████▉| 209/211 [06:40<00:03, 1.93s/it]
|
| 227 |
+
computing/reading reference batch statistics...
|
| 228 |
+
computing sample batch activations...
|
| 229 |
+
|
| 230 |
0%| | 0/469 [00:00<?, ?it/s]
|
| 231 |
0%| | 1/469 [00:02<15:45, 2.02s/it]
|
| 232 |
0%| | 2/469 [00:03<14:20, 1.84s/it]
|
| 233 |
1%| | 3/469 [00:05<13:51, 1.78s/it]
|
| 234 |
1%| | 4/469 [00:07<13:37, 1.76s/it]
|
| 235 |
1%| | 5/469 [00:08<13:23, 1.73s/it]
|
| 236 |
1%|▏ | 6/469 [00:10<13:11, 1.71s/it]
|
| 237 |
1%|▏ | 7/469 [00:12<13:03, 1.70s/it]
|
| 238 |
2%|▏ | 8/469 [00:13<12:57, 1.69s/it]
|
| 239 |
2%|▏ | 9/469 [00:15<12:53, 1.68s/it]
|
| 240 |
2%|▏ | 10/469 [00:17<13:08, 1.72s/it]
|
| 241 |
2%|▏ | 11/469 [00:26<31:14, 4.09s/it]
|
| 242 |
3%|▎ | 12/469 [00:28<25:29, 3.35s/it]
|
| 243 |
3%|▎ | 13/469 [00:30<21:32, 2.84s/it]
|
| 244 |
3%|▎ | 14/469 [00:31<18:50, 2.49s/it]
|
| 245 |
3%|▎ | 15/469 [00:33<16:50, 2.23s/it]
|
| 246 |
3%|▎ | 16/469 [00:35<15:32, 2.06s/it]
|
| 247 |
4%|▎ | 17/469 [00:36<14:57, 1.99s/it]
|
| 248 |
4%|▍ | 18/469 [00:38<14:12, 1.89s/it]
|
| 249 |
4%|▍ | 19/469 [00:40<13:45, 1.83s/it]
|
| 250 |
4%|▍ | 20/469 [00:41<13:17, 1.78s/it]
|
| 251 |
4%|▍ | 21/469 [00:43<13:08, 1.76s/it]
|
| 252 |
5%|▍ | 22/469 [00:45<12:51, 1.73s/it]
|
| 253 |
5%|▍ | 23/469 [00:47<12:51, 1.73s/it]
|
| 254 |
5%|▌ | 24/469 [00:48<12:38, 1.70s/it]
|
| 255 |
5%|▌ | 25/469 [00:50<12:33, 1.70s/it]
|
| 256 |
6%|▌ | 26/469 [00:51<12:26, 1.68s/it]
|
| 257 |
6%|▌ | 27/469 [00:53<12:26, 1.69s/it]
|
| 258 |
6%|▌ | 28/469 [00:55<12:17, 1.67s/it]
|
| 259 |
6%|▌ | 29/469 [00:57<12:18, 1.68s/it]
|
| 260 |
6%|▋ | 30/469 [00:58<12:12, 1.67s/it]
|
| 261 |
7%|▋ | 31/469 [01:00<12:15, 1.68s/it]
|
| 262 |
7%|▋ | 32/469 [01:02<12:10, 1.67s/it]
|
| 263 |
7%|▋ | 33/469 [01:04<12:55, 1.78s/it]
|
| 264 |
7%|▋ | 34/469 [01:05<12:41, 1.75s/it]
|
| 265 |
7%|▋ | 35/469 [01:07<12:24, 1.72s/it]
|
| 266 |
8%|▊ | 36/469 [01:09<12:34, 1.74s/it]
|
| 267 |
8%|▊ | 37/469 [01:10<12:19, 1.71s/it]
|
| 268 |
8%|▊ | 38/469 [01:12<12:27, 1.73s/it]
|
| 269 |
8%|▊ | 39/469 [01:14<12:14, 1.71s/it]
|
| 270 |
9%|▊ | 40/469 [01:15<12:10, 1.70s/it]
|
| 271 |
9%|▊ | 41/469 [01:17<12:07, 1.70s/it]
|
| 272 |
9%|▉ | 42/469 [01:19<12:09, 1.71s/it]
|
| 273 |
9%|▉ | 43/469 [01:21<12:18, 1.73s/it]
|
| 274 |
9%|▉ | 44/469 [01:22<12:13, 1.72s/it]
|
| 275 |
10%|▉ | 45/469 [01:24<12:06, 1.71s/it]
|
| 276 |
10%|▉ | 46/469 [01:26<12:22, 1.75s/it]
|
| 277 |
10%|█ | 47/469 [01:28<12:35, 1.79s/it]
|
| 278 |
10%|█ | 48/469 [01:29<12:12, 1.74s/it]
|
| 279 |
10%|█ | 49/469 [01:31<12:01, 1.72s/it]
|
| 280 |
11%|█ | 50/469 [01:33<12:24, 1.78s/it]
|
| 281 |
11%|█ | 51/469 [01:35<12:32, 1.80s/it]
|
| 282 |
11%|█ | 52/469 [01:37<12:29, 1.80s/it]
|
| 283 |
11%|█▏ | 53/469 [01:39<13:30, 1.95s/it]
|
| 284 |
12%|█▏ | 54/469 [01:41<12:50, 1.86s/it]
|
| 285 |
12%|█▏ | 55/469 [01:42<12:27, 1.81s/it]
|
| 286 |
12%|█▏ | 56/469 [01:44<12:05, 1.76s/it]
|
| 287 |
12%|█▏ | 57/469 [01:46<12:17, 1.79s/it]
|
| 288 |
12%|█▏ | 58/469 [01:47<12:08, 1.77s/it]
|
| 289 |
13%|█▎ | 59/469 [01:49<12:09, 1.78s/it]
|
| 290 |
13%|█▎ | 60/469 [01:51<12:17, 1.80s/it]
|
| 291 |
13%|█▎ | 61/469 [01:53<12:07, 1.78s/it]
|
| 292 |
13%|█▎ | 62/469 [01:55<12:24, 1.83s/it]
|
| 293 |
13%|█▎ | 63/469 [01:56<12:00, 1.77s/it]
|
| 294 |
14%|█▎ | 64/469 [01:58<11:42, 1.73s/it]
|
| 295 |
14%|█▍ | 65/469 [02:00<11:35, 1.72s/it]
|
| 296 |
14%|█▍ | 66/469 [02:02<11:34, 1.72s/it]
|
| 297 |
14%|█▍ | 67/469 [02:04<12:11, 1.82s/it]
|
| 298 |
14%|█▍ | 68/469 [02:05<11:45, 1.76s/it]
|
| 299 |
15%|█▍ | 69/469 [02:07<11:31, 1.73s/it]
|
| 300 |
15%|█▍ | 70/469 [02:09<11:27, 1.72s/it]
|
| 301 |
15%|█▌ | 71/469 [02:10<11:25, 1.72s/it]
|
| 302 |
15%|█▌ | 72/469 [02:12<11:16, 1.70s/it]
|
| 303 |
16%|█▌ | 73/469 [02:14<11:09, 1.69s/it]
|
| 304 |
16%|█▌ | 74/469 [02:15<11:32, 1.75s/it]
|
| 305 |
16%|█▌ | 75/469 [02:17<11:47, 1.80s/it]
|
| 306 |
16%|█▌ | 76/469 [02:19<11:43, 1.79s/it]
|
| 307 |
16%|█▋ | 77/469 [02:21<12:10, 1.86s/it]
|
| 308 |
17%|█▋ | 78/469 [02:23<11:45, 1.80s/it]
|
| 309 |
17%|█▋ | 79/469 [02:25<12:10, 1.87s/it]
|
| 310 |
17%|█▋ | 80/469 [02:27<11:48, 1.82s/it]
|
| 311 |
17%|█▋ | 81/469 [02:28<11:30, 1.78s/it]
|
| 312 |
17%|█▋ | 82/469 [02:30<11:36, 1.80s/it]
|
| 313 |
18%|█▊ | 83/469 [02:32<11:18, 1.76s/it]
|
| 314 |
18%|█▊ | 84/469 [02:34<11:12, 1.75s/it]
|
| 315 |
18%|█▊ | 85/469 [02:35<11:03, 1.73s/it]
|
| 316 |
18%|█▊ | 86/469 [02:37<11:06, 1.74s/it]
|
| 317 |
19%|█▊ | 87/469 [02:39<11:04, 1.74s/it]
|
| 318 |
19%|█▉ | 88/469 [02:40<10:56, 1.72s/it]
|
| 319 |
19%|█▉ | 89/469 [02:42<11:07, 1.76s/it]
|
| 320 |
19%|█▉ | 90/469 [02:45<12:56, 2.05s/it]
|
| 321 |
19%|█▉ | 91/469 [02:47<12:10, 1.93s/it]
|
| 322 |
20%|█▉ | 92/469 [02:48<11:35, 1.84s/it]
|
| 323 |
20%|█▉ | 93/469 [02:50<11:19, 1.81s/it]
|
| 324 |
20%|██ | 94/469 [02:52<11:04, 1.77s/it]
|
| 325 |
20%|██ | 95/469 [02:53<10:48, 1.73s/it]
|
| 326 |
20%|██ | 96/469 [02:55<10:38, 1.71s/it]
|
| 327 |
21%|██ | 97/469 [02:57<11:01, 1.78s/it]
|
| 328 |
21%|██ | 98/469 [02:59<10:50, 1.75s/it]
|
| 329 |
21%|██ | 99/469 [03:00<10:33, 1.71s/it]
|
| 330 |
21%|██▏ | 100/469 [03:02<10:35, 1.72s/it]
|
| 331 |
22%|██▏ | 101/469 [03:04<11:42, 1.91s/it]
|
| 332 |
22%|██▏ | 102/469 [03:06<11:19, 1.85s/it]
|
| 333 |
22%|██▏ | 103/469 [03:08<11:48, 1.94s/it]
|
| 334 |
22%|██▏ | 104/469 [03:10<11:23, 1.87s/it]
|
| 335 |
22%|██▏ | 105/469 [03:11<10:53, 1.80s/it]
|
| 336 |
23%|██▎ | 106/469 [03:13<10:34, 1.75s/it]
|
| 337 |
23%|██▎ | 107/469 [03:15<10:32, 1.75s/it]
|
| 338 |
23%|██▎ | 108/469 [03:17<10:19, 1.71s/it]
|
| 339 |
23%|██▎ | 109/469 [03:18<10:08, 1.69s/it]
|
| 340 |
23%|██▎ | 110/469 [03:20<11:07, 1.86s/it]
|
| 341 |
24%|██▎ | 111/469 [03:22<10:49, 1.81s/it]
|
| 342 |
24%|██▍ | 112/469 [03:24<10:35, 1.78s/it]
|
| 343 |
24%|██▍ | 113/469 [03:25<10:22, 1.75s/it]
|
| 344 |
24%|██▍ | 114/469 [03:27<10:13, 1.73s/it]
|
| 345 |
25%|██▍ | 115/469 [03:29<10:01, 1.70s/it]
|
| 346 |
25%|██▍ | 116/469 [03:30<09:56, 1.69s/it]
|
| 347 |
25%|██▍ | 117/469 [03:32<09:56, 1.69s/it]
|
| 348 |
25%|██▌ | 118/469 [03:34<09:45, 1.67s/it]
|
| 349 |
25%|██▌ | 119/469 [03:35<09:45, 1.67s/it]
|
| 350 |
26%|██▌ | 120/469 [03:37<09:45, 1.68s/it]
|
| 351 |
26%|██▌ | 121/469 [03:40<11:56, 2.06s/it]
|
| 352 |
26%|██▌ | 122/469 [03:42<11:18, 1.95s/it]
|
| 353 |
26%|██▌ | 123/469 [03:44<11:00, 1.91s/it]
|
| 354 |
26%|██▋ | 124/469 [03:45<10:31, 1.83s/it]
|
| 355 |
27%|██▋ | 125/469 [03:47<10:06, 1.76s/it]
|
| 356 |
27%|██▋ | 126/469 [03:48<09:48, 1.72s/it]
|
| 357 |
27%|██▋ | 127/469 [03:50<09:57, 1.75s/it]
|
| 358 |
27%|██▋ | 128/469 [03:52<09:52, 1.74s/it]
|
| 359 |
28%|██▊ | 129/469 [03:54<09:46, 1.72s/it]
|
| 360 |
28%|██▊ | 130/469 [03:56<09:59, 1.77s/it]
|
| 361 |
28%|██▊ | 131/469 [03:57<09:56, 1.76s/it]
|
| 362 |
28%|██▊ | 132/469 [03:59<09:40, 1.72s/it]
|
| 363 |
28%|██▊ | 133/469 [04:01<10:56, 1.95s/it]
|
| 364 |
29%|██▊ | 134/469 [04:03<10:35, 1.90s/it]
|
| 365 |
29%|██▉ | 135/469 [04:05<10:05, 1.81s/it]
|
| 366 |
29%|██▉ | 136/469 [04:06<09:45, 1.76s/it]
|
| 367 |
29%|██▉ | 137/469 [04:08<09:30, 1.72s/it]
|
| 368 |
29%|██▉ | 138/469 [04:10<09:35, 1.74s/it]
|
| 369 |
30%|██▉ | 139/469 [04:11<09:19, 1.69s/it]
|
| 370 |
30%|██▉ | 140/469 [04:13<09:06, 1.66s/it]
|
| 371 |
30%|███ | 141/469 [04:15<09:00, 1.65s/it]
|
| 372 |
30%|███ | 142/469 [04:16<08:55, 1.64s/it]
|
| 373 |
30%|███ | 143/469 [04:18<08:55, 1.64s/it]
|
| 374 |
31%|███ | 144/469 [04:19<08:47, 1.62s/it]
|
| 375 |
31%|███ | 145/469 [04:21<09:06, 1.69s/it]
|
| 376 |
31%|███ | 146/469 [04:23<09:06, 1.69s/it]
|
| 377 |
31%|███▏ | 147/469 [04:25<08:56, 1.67s/it]
|
| 378 |
32%|███▏ | 148/469 [04:26<08:56, 1.67s/it]
|
| 379 |
32%|███▏ | 149/469 [04:28<09:09, 1.72s/it]
|
| 380 |
32%|███▏ | 150/469 [04:30<08:57, 1.69s/it]
|
| 381 |
32%|███▏ | 151/469 [04:31<08:47, 1.66s/it]
|
| 382 |
32%|███▏ | 152/469 [04:33<08:41, 1.64s/it]
|
| 383 |
33%|███▎ | 153/469 [04:35<08:34, 1.63s/it]
|
| 384 |
33%|███▎ | 154/469 [04:36<08:32, 1.63s/it]
|
| 385 |
33%|███▎ | 155/469 [04:38<08:25, 1.61s/it]
|
| 386 |
33%|███▎ | 156/469 [04:39<08:23, 1.61s/it]
|
| 387 |
33%|███▎ | 157/469 [04:41<08:22, 1.61s/it]
|
| 388 |
34%|███▎ | 158/469 [04:47<15:58, 3.08s/it]
|
| 389 |
34%|███▍ | 159/469 [04:49<13:44, 2.66s/it]
|
| 390 |
34%|███▍ | 160/469 [04:51<12:27, 2.42s/it]
|
| 391 |
34%|███▍ | 161/469 [04:53<11:14, 2.19s/it]
|
| 392 |
35%|███▍ | 162/469 [04:54<10:22, 2.03s/it]
|
| 393 |
35%|███▍ | 163/469 [04:56<09:58, 1.96s/it]
|
| 394 |
35%|███▍ | 164/469 [04:58<09:35, 1.89s/it]
|
| 395 |
35%|███▌ | 165/469 [04:59<09:06, 1.80s/it]
|
| 396 |
35%|███▌ | 166/469 [05:01<08:47, 1.74s/it]
|
| 397 |
36%|███▌ | 167/469 [05:03<08:31, 1.69s/it]
|
| 398 |
36%|███▌ | 168/469 [05:05<08:55, 1.78s/it]
|
| 399 |
36%|███▌ | 169/469 [05:07<09:07, 1.82s/it]
|
| 400 |
36%|███▌ | 170/469 [05:08<08:54, 1.79s/it]
|
| 401 |
36%|███▋ | 171/469 [05:10<08:46, 1.77s/it]
|
| 402 |
37%|███▋ | 172/469 [05:12<08:31, 1.72s/it]
|
| 403 |
37%|███▋ | 173/469 [05:13<08:23, 1.70s/it]
|
| 404 |
37%|███▋ | 174/469 [05:15<08:16, 1.68s/it]
|
| 405 |
37%|███▋ | 175/469 [05:17<08:12, 1.67s/it]
|
| 406 |
38%|███▊ | 176/469 [05:18<08:14, 1.69s/it]
|
| 407 |
38%|███▊ | 177/469 [05:20<08:13, 1.69s/it]
|
| 408 |
38%|███▊ | 178/469 [05:22<08:07, 1.68s/it]
|
| 409 |
38%|███▊ | 179/469 [05:23<08:00, 1.66s/it]
|
| 410 |
38%|███▊ | 180/469 [05:25<07:58, 1.66s/it]
|
| 411 |
39%|███▊ | 181/469 [05:26<07:53, 1.64s/it]
|
| 412 |
39%|███▉ | 182/469 [05:28<07:47, 1.63s/it]
|
| 413 |
39%|███▉ | 183/469 [05:30<07:48, 1.64s/it]
|
| 414 |
39%|███▉ | 184/469 [05:31<07:50, 1.65s/it]
|
| 415 |
39%|███▉ | 185/469 [05:33<07:44, 1.64s/it]
|
| 416 |
40%|███▉ | 186/469 [05:35<07:50, 1.66s/it]
|
| 417 |
40%|███▉ | 187/469 [05:37<08:00, 1.70s/it]
|
| 418 |
40%|████ | 188/469 [05:38<07:55, 1.69s/it]
|
| 419 |
40%|████ | 189/469 [05:40<07:48, 1.67s/it]
|
| 420 |
41%|████ | 190/469 [05:42<07:53, 1.70s/it]
|
| 421 |
41%|████ | 191/469 [05:43<07:44, 1.67s/it]
|
| 422 |
41%|████ | 192/469 [05:45<07:38, 1.65s/it]
|
| 423 |
41%|████ | 193/469 [05:46<07:32, 1.64s/it]
|
| 424 |
41%|████▏ | 194/469 [05:48<07:54, 1.73s/it]
|
| 425 |
42%|███��▏ | 195/469 [05:50<07:45, 1.70s/it]
|
| 426 |
42%|████▏ | 196/469 [05:52<07:44, 1.70s/it]
|
| 427 |
42%|████▏ | 197/469 [05:53<07:48, 1.72s/it]
|
| 428 |
42%|████▏ | 198/469 [05:55<07:40, 1.70s/it]
|
| 429 |
42%|████▏ | 199/469 [05:57<08:02, 1.79s/it]
|
| 430 |
43%|████▎ | 200/469 [05:59<07:49, 1.74s/it]
|
| 431 |
43%|████▎ | 201/469 [06:00<07:41, 1.72s/it]
|
| 432 |
43%|████▎ | 202/469 [06:02<07:38, 1.72s/it]
|
| 433 |
43%|████▎ | 203/469 [06:04<07:37, 1.72s/it]
|
| 434 |
43%|████▎ | 204/469 [06:05<07:29, 1.70s/it]
|
| 435 |
44%|████▎ | 205/469 [06:07<07:31, 1.71s/it]
|
| 436 |
44%|████▍ | 206/469 [06:09<07:21, 1.68s/it]
|
| 437 |
44%|████▍ | 207/469 [06:10<07:16, 1.67s/it]
|
| 438 |
44%|████▍ | 208/469 [06:12<07:09, 1.64s/it]
|
| 439 |
45%|████▍ | 209/469 [06:14<07:07, 1.64s/it]
|
| 440 |
45%|████▍ | 210/469 [06:15<07:11, 1.67s/it]
|
| 441 |
45%|████▍ | 211/469 [06:17<07:21, 1.71s/it]
|
| 442 |
45%|████▌ | 212/469 [06:19<07:27, 1.74s/it]
|
| 443 |
45%|████▌ | 213/469 [06:21<07:20, 1.72s/it]
|
| 444 |
46%|████▌ | 214/469 [06:22<07:09, 1.69s/it]
|
| 445 |
46%|████▌ | 215/469 [06:24<07:07, 1.68s/it]
|
| 446 |
46%|████▌ | 216/469 [06:26<07:00, 1.66s/it]
|
| 447 |
46%|████▋ | 217/469 [06:27<06:56, 1.65s/it]
|
| 448 |
46%|████▋ | 218/469 [06:29<06:48, 1.63s/it]
|
| 449 |
47%|████▋ | 219/469 [06:31<07:08, 1.72s/it]
|
| 450 |
47%|████▋ | 220/469 [06:32<06:57, 1.68s/it]
|
| 451 |
47%|████▋ | 221/469 [06:34<06:56, 1.68s/it]
|
| 452 |
47%|████▋ | 222/469 [06:36<06:54, 1.68s/it]
|
| 453 |
48%|████▊ | 223/469 [06:37<06:47, 1.66s/it]
|
| 454 |
48%|████▊ | 224/469 [06:39<06:44, 1.65s/it]
|
| 455 |
48%|████▊ | 225/469 [06:41<06:42, 1.65s/it]
|
| 456 |
48%|████▊ | 226/469 [06:42<06:35, 1.63s/it]
|
| 457 |
48%|████▊ | 227/469 [06:44<06:31, 1.62s/it]
|
| 458 |
49%|████▊ | 228/469 [06:46<06:57, 1.73s/it]
|
| 459 |
49%|████▉ | 229/469 [06:47<06:47, 1.70s/it]
|
| 460 |
49%|████▉ | 230/469 [06:49<06:46, 1.70s/it]
|
| 461 |
49%|████▉ | 231/469 [06:51<06:40, 1.68s/it]
|
| 462 |
49%|████▉ | 232/469 [06:52<06:33, 1.66s/it]
|
| 463 |
50%|████▉ | 233/469 [06:54<06:46, 1.72s/it]
|
| 464 |
50%|████▉ | 234/469 [06:56<07:17, 1.86s/it]
|
| 465 |
50%|█████ | 235/469 [06:58<07:00, 1.80s/it]
|
| 466 |
50%|█████ | 236/469 [07:00<06:47, 1.75s/it]
|
| 467 |
51%|█████ | 237/469 [07:01<06:40, 1.72s/it]
|
| 468 |
51%|█████ | 238/469 [07:03<06:31, 1.69s/it]
|
| 469 |
51%|█████ | 239/469 [07:05<06:23, 1.67s/it]
|
| 470 |
51%|█████ | 240/469 [07:06<06:18, 1.65s/it]
|
| 471 |
51%|█████▏ | 241/469 [07:08<06:14, 1.64s/it]
|
| 472 |
52%|█████▏ | 242/469 [07:09<06:10, 1.63s/it]
|
| 473 |
52%|█████▏ | 243/469 [07:11<06:41, 1.77s/it]
|
| 474 |
52%|█████▏ | 244/469 [07:13<06:27, 1.72s/it]
|
| 475 |
52%|█████▏ | 245/469 [07:15<06:49, 1.83s/it]
|
| 476 |
52%|█████▏ | 246/469 [07:17<06:33, 1.76s/it]
|
| 477 |
53%|█████▎ | 247/469 [07:18<06:22, 1.72s/it]
|
| 478 |
53%|█████▎ | 248/469 [07:20<06:19, 1.72s/it]
|
| 479 |
53%|█████▎ | 249/469 [07:22<06:12, 1.69s/it]
|
| 480 |
53%|█████▎ | 250/469 [07:24<06:16, 1.72s/it]
|
| 481 |
54%|█████▎ | 251/469 [07:25<06:10, 1.70s/it]
|
| 482 |
54%|█████▎ | 252/469 [07:27<06:11, 1.71s/it]
|
| 483 |
54%|█████▍ | 253/469 [07:29<06:03, 1.68s/it]
|
| 484 |
54%|█████▍ | 254/469 [07:30<06:04, 1.70s/it]
|
| 485 |
54%|█████▍ | 255/469 [07:32<05:56, 1.67s/it]
|
| 486 |
55%|█████▍ | 256/469 [07:34<06:02, 1.70s/it]
|
| 487 |
55%|█████▍ | 257/469 [07:35<05:56, 1.68s/it]
|
| 488 |
55%|█████▌ | 258/469 [07:37<06:02, 1.72s/it]
|
| 489 |
55%|█████▌ | 259/469 [07:39<05:57, 1.70s/it]
|
| 490 |
55%|█████▌ | 260/469 [07:40<05:50, 1.68s/it]
|
| 491 |
56%|█████▌ | 261/469 [07:42<05:50, 1.69s/it]
|
| 492 |
56%|█████▌ | 262/469 [07:44<05:52, 1.70s/it]
|
| 493 |
56%|█████▌ | 263/469 [07:46<05:50, 1.70s/it]
|
| 494 |
56%|█████▋ | 264/469 [07:47<05:48, 1.70s/it]
|
| 495 |
57%|█████▋ | 265/469 [07:49<05:45, 1.69s/it]
|
| 496 |
57%|█████▋ | 266/469 [07:50<05:37, 1.66s/it]
|
| 497 |
57%|█████▋ | 267/469 [07:52<05:33, 1.65s/it]
|
| 498 |
57%|█████▋ | 268/469 [07:58<09:27, 2.82s/it]
|
| 499 |
57%|█████▋ | 269/469 [07:59<08:23, 2.52s/it]
|
| 500 |
58%|█████▊ | 270/469 [08:01<07:26, 2.24s/it]
|
| 501 |
58%|█████▊ | 271/469 [08:03<06:53, 2.09s/it]
|
| 502 |
58%|█████▊ | 272/469 [08:04<06:26, 1.96s/it]
|
| 503 |
58%|█████▊ | 273/469 [08:06<06:03, 1.86s/it]
|
| 504 |
58%|█████▊ | 274/469 [08:08<05:44, 1.77s/it]
|
| 505 |
59%|█████▊ | 275/469 [08:09<05:33, 1.72s/it]
|
| 506 |
59%|█████▉ | 276/469 [08:11<05:25, 1.69s/it]
|
| 507 |
59%|█████▉ | 277/469 [08:12<05:21, 1.68s/it]
|
| 508 |
59%|█████▉ | 278/469 [08:14<05:21, 1.68s/it]
|
| 509 |
59%|█████▉ | 279/469 [08:16<05:26, 1.72s/it]
|
| 510 |
60%|█████▉ | 280/469 [08:18<05:16, 1.67s/it]
|
| 511 |
60%|█████▉ | 281/469 [08:20<05:45, 1.84s/it]
|
| 512 |
60%|██████ | 282/469 [08:21<05:31, 1.77s/it]
|
| 513 |
60%|██████ | 283/469 [08:23<05:26, 1.75s/it]
|
| 514 |
61%|██████ | 284/469 [08:25<05:17, 1.72s/it]
|
| 515 |
61%|██████ | 285/469 [08:26<05:10, 1.69s/it]
|
| 516 |
61%|██████ | 286/469 [08:28<05:21, 1.76s/it]
|
| 517 |
61%|██████ | 287/469 [08:30<05:16, 1.74s/it]
|
| 518 |
61%|██████▏ | 288/469 [08:32<05:13, 1.73s/it]
|
| 519 |
62%|██████▏ | 289/469 [08:33<05:03, 1.69s/it]
|
| 520 |
62%|██████▏ | 290/469 [08:35<05:10, 1.74s/it]
|
| 521 |
62%|██████▏ | 291/469 [08:37<05:17, 1.78s/it]
|
| 522 |
62%|██████▏ | 292/469 [08:39<05:17, 1.79s/it]
|
| 523 |
62%|██████▏ | 293/469 [08:41<05:19, 1.81s/it]
|
| 524 |
63%|██████▎ | 294/469 [08:42<05:06, 1.75s/it]
|
| 525 |
63%|██████▎ | 295/469 [08:44<04:55, 1.70s/it]
|
| 526 |
63%|██████▎ | 296/469 [08:46<04:53, 1.70s/it]
|
| 527 |
63%|██████▎ | 297/469 [08:47<04:46, 1.67s/it]
|
| 528 |
64%|██████▎ | 298/469 [08:49<04:40, 1.64s/it]
|
| 529 |
64%|██████▍ | 299/469 [08:50<04:39, 1.65s/it]
|
| 530 |
64%|██████▍ | 300/469 [08:52<04:37, 1.64s/it]
|
| 531 |
64%|██████▍ | 301/469 [08:54<04:39, 1.66s/it]
|
| 532 |
64%|██████▍ | 302/469 [08:55<04:35, 1.65s/it]
|
| 533 |
65%|██████▍ | 303/469 [08:57<04:36, 1.66s/it]
|
| 534 |
65%|██████▍ | 304/469 [08:59<04:36, 1.67s/it]
|
| 535 |
65%|██████▌ | 305/469 [09:00<04:34, 1.68s/it]
|
| 536 |
65%|██████▌ | 306/469 [09:02<04:32, 1.67s/it]
|
| 537 |
65%|██████▌ | 307/469 [09:04<04:30, 1.67s/it]
|
| 538 |
66%|██████▌ | 308/469 [09:05<04:25, 1.65s/it]
|
| 539 |
66%|██████▌ | 309/469 [09:07<04:23, 1.65s/it]
|
| 540 |
66%|██████▌ | 310/469 [09:09<04:20, 1.64s/it]
|
| 541 |
66%|██████▋ | 311/469 [09:10<04:18, 1.64s/it]
|
| 542 |
67%|██████▋ | 312/469 [09:12<04:19, 1.65s/it]
|
| 543 |
67%|██████▋ | 313/469 [09:14<04:15, 1.64s/it]
|
| 544 |
67%|██████▋ | 314/469 [09:15<04:10, 1.62s/it]
|
| 545 |
67%|██████▋ | 315/469 [09:17<04:08, 1.61s/it]
|
| 546 |
67%|██████▋ | 316/469 [09:19<04:21, 1.71s/it]
|
| 547 |
68%|██████▊ | 317/469 [09:20<04:23, 1.73s/it]
|
| 548 |
68%|██████▊ | 318/469 [09:22<04:18, 1.71s/it]
|
| 549 |
68%|██████▊ | 319/469 [09:24<04:20, 1.74s/it]
|
| 550 |
68%|██████▊ | 320/469 [09:26<04:15, 1.71s/it]
|
| 551 |
68%|██████▊ | 321/469 [09:27<04:21, 1.77s/it]
|
| 552 |
69%|██████▊ | 322/469 [09:29<04:12, 1.72s/it]
|
| 553 |
69%|██████▉ | 323/469 [09:31<04:04, 1.67s/it]
|
| 554 |
69%|██████▉ | 324/469 [09:32<04:01, 1.67s/it]
|
| 555 |
69%|██████▉ | 325/469 [09:34<03:55, 1.64s/it]
|
| 556 |
70%|██████▉ | 326/469 [09:35<03:52, 1.62s/it]
|
| 557 |
70%|██████▉ | 327/469 [09:37<03:49, 1.61s/it]
|
| 558 |
70%|██████▉ | 328/469 [09:39<03:52, 1.65s/it]
|
| 559 |
70%|███████ | 329/469 [09:41<04:01, 1.72s/it]
|
| 560 |
70%|███████ | 330/469 [09:43<04:06, 1.77s/it]
|
| 561 |
71%|███████ | 331/469 [09:44<04:03, 1.76s/it]
|
| 562 |
71%|███████ | 332/469 [09:46<03:55, 1.72s/it]
|
| 563 |
71%|███████ | 333/469 [09:47<03:48, 1.68s/it]
|
| 564 |
71%|███████ | 334/469 [09:49<03:46, 1.68s/it]
|
| 565 |
71%|███████▏ | 335/469 [09:51<03:42, 1.66s/it]
|
| 566 |
72%|███████▏ | 336/469 [09:52<03:37, 1.63s/it]
|
| 567 |
72%|███████▏ | 337/469 [09:54<03:33, 1.62s/it]
|
| 568 |
72%|███████▏ | 338/469 [09:56<03:33, 1.63s/it]
|
| 569 |
72%|███████▏ | 339/469 [09:57<03:31, 1.62s/it]
|
| 570 |
72%|███████▏ | 340/469 [09:59<03:27, 1.61s/it]
|
| 571 |
73%|███████▎ | 341/469 [10:01<03:32, 1.66s/it]
|
| 572 |
73%|███████▎ | 342/469 [10:03<03:41, 1.75s/it]
|
| 573 |
73%|███████▎ | 343/469 [10:04<03:35, 1.71s/it]
|
| 574 |
73%|███████▎ | 344/469 [10:06<03:31, 1.70s/it]
|
| 575 |
74%|███████▎ | 345/469 [10:27<15:27, 7.48s/it]
|
| 576 |
74%|███████▍ | 346/469 [10:29<11:48, 5.76s/it]
|
| 577 |
74%|███████▍ | 347/469 [10:30<09:19, 4.59s/it]
|
| 578 |
74%|███████▍ | 348/469 [10:32<07:29, 3.71s/it]
|
| 579 |
74%|███████▍ | 349/469 [10:34<06:11, 3.10s/it]
|
| 580 |
75%|███████▍ | 350/469 [10:36<05:49, 2.93s/it]
|
| 581 |
75%|███████▍ | 351/469 [10:38<05:09, 2.63s/it]
|
| 582 |
75%|███████▌ | 352/469 [10:40<04:39, 2.39s/it]
|
| 583 |
75%|███████▌ | 353/469 [10:42<04:16, 2.21s/it]
|
| 584 |
75%|███████▌ | 354/469 [10:43<03:53, 2.03s/it]
|
| 585 |
76%|███████▌ | 355/469 [10:45<03:35, 1.89s/it]
|
| 586 |
76%|███████▌ | 356/469 [10:47<03:30, 1.86s/it]
|
| 587 |
76%|███████▌ | 357/469 [10:48<03:23, 1.82s/it]
|
| 588 |
76%|███████▋ | 358/469 [10:50<03:19, 1.80s/it]
|
| 589 |
77%|███████▋ | 359/469 [10:52<03:20, 1.82s/it]
|
| 590 |
77%|███████▋ | 360/469 [10:54<03:23, 1.87s/it]
|
| 591 |
77%|███████▋ | 361/469 [10:56<03:16, 1.82s/it]
|
| 592 |
77%|███████▋ | 362/469 [10:58<03:24, 1.91s/it]
|
| 593 |
77%|███████▋ | 363/469 [11:00<03:12, 1.82s/it]
|
| 594 |
78%|███████▊ | 364/469 [11:01<03:05, 1.77s/it]
|
| 595 |
78%|███████▊ | 365/469 [11:03<02:58, 1.72s/it]
|
| 596 |
78%|███████▊ | 366/469 [11:04<02:54, 1.69s/it]
|
| 597 |
78%|███████▊ | 367/469 [11:06<02:52, 1.69s/it]
|
| 598 |
78%|███████▊ | 368/469 [11:08<02:50, 1.68s/it]
|
| 599 |
79%|███████▊ | 369/469 [11:10<02:50, 1.71s/it]
|
| 600 |
79%|███████▉ | 370/469 [11:11<02:49, 1.71s/it]
|
| 601 |
79%|███████▉ | 371/469 [11:13<02:46, 1.70s/it]
|
| 602 |
79%|███████▉ | 372/469 [11:15<02:49, 1.74s/it]
|
| 603 |
80%|███████▉ | 373/469 [11:16<02:43, 1.70s/it]
|
| 604 |
80%|███████▉ | 374/469 [11:18<02:42, 1.71s/it]
|
| 605 |
80%|███████▉ | 375/469 [11:20<02:44, 1.75s/it]
|
| 606 |
80%|████████ | 376/469 [11:22<02:47, 1.80s/it]
|
| 607 |
80%|████████ | 377/469 [11:24<02:41, 1.76s/it]
|
| 608 |
81%|████████ | 378/469 [11:25<02:34, 1.70s/it]
|
| 609 |
81%|████████ | 379/469 [11:27<02:30, 1.67s/it]
|
| 610 |
81%|████████ | 380/469 [11:29<02:33, 1.72s/it]
|
| 611 |
81%|████████ | 381/469 [11:30<02:28, 1.69s/it]
|
| 612 |
81%|████████▏ | 382/469 [11:32<02:24, 1.66s/it]
|
| 613 |
82%|████████▏ | 383/469 [11:34<02:29, 1.73s/it]
|
| 614 |
82%|████████▏ | 384/469 [11:36<02:33, 1.80s/it]
|
| 615 |
82%|████████▏ | 385/469 [11:38<02:34, 1.84s/it]
|
| 616 |
82%|████████▏ | 386/469 [11:39<02:27, 1.77s/it]
|
| 617 |
83%|████████▎ | 387/469 [11:41<02:26, 1.79s/it]
|
| 618 |
83%|████████▎ | 388/469 [11:43<02:21, 1.74s/it]
|
| 619 |
83%|████████▎ | 389/469 [11:45<02:23, 1.79s/it]
|
| 620 |
83%|████████▎ | 390/469 [11:46<02:16, 1.73s/it]
|
| 621 |
83%|████████▎ | 391/469 [11:48<02:12, 1.70s/it]
|
| 622 |
84%|████████▎ | 392/469 [11:49<02:09, 1.69s/it]
|
| 623 |
84%|████████▍ | 393/469 [11:51<02:05, 1.66s/it]
|
| 624 |
84%|████████▍ | 394/469 [11:53<02:03, 1.65s/it]
|
| 625 |
84%|████████▍ | 395/469 [11:54<02:01, 1.65s/it]
|
| 626 |
84%|████████▍ | 396/469 [11:56<01:59, 1.63s/it]
|
| 627 |
85%|████████▍ | 397/469 [11:58<01:58, 1.64s/it]
|
| 628 |
85%|████████▍ | 398/469 [11:59<02:00, 1.70s/it]
|
| 629 |
85%|████████▌ | 399/469 [12:01<01:57, 1.68s/it]
|
| 630 |
85%|████████▌ | 400/469 [12:03<01:54, 1.66s/it]
|
| 631 |
86%|████████▌ | 401/469 [12:04<01:51, 1.64s/it]
|
| 632 |
86%|████████▌ | 402/469 [12:25<08:21, 7.49s/it]
|
| 633 |
86%|████████▌ | 403/469 [12:27<06:17, 5.72s/it]
|
| 634 |
86%|████████▌ | 404/469 [12:29<04:51, 4.49s/it]
|
| 635 |
86%|████████▋ | 405/469 [12:30<03:53, 3.65s/it]
|
| 636 |
87%|████████▋ | 406/469 [12:32<03:12, 3.06s/it]
|
| 637 |
87%|████████▋ | 407/469 [12:34<02:42, 2.63s/it]
|
| 638 |
87%|████████▋ | 408/469 [12:35<02:21, 2.32s/it]
|
| 639 |
87%|████████▋ | 409/469 [12:37<02:06, 2.10s/it]
|
| 640 |
87%|████████▋ | 410/469 [12:38<01:54, 1.94s/it]
|
| 641 |
88%|████████▊ | 411/469 [12:40<01:46, 1.83s/it]
|
| 642 |
88%|████████▊ | 412/469 [12:41<01:39, 1.75s/it]
|
| 643 |
88%|████████▊ | 413/469 [12:43<01:36, 1.73s/it]
|
| 644 |
88%|████████▊ | 414/469 [12:45<01:33, 1.69s/it]
|
| 645 |
88%|████████▊ | 415/469 [12:46<01:29, 1.65s/it]
|
| 646 |
89%|████████▊ | 416/469 [12:48<01:28, 1.68s/it]
|
| 647 |
89%|████████▉ | 417/469 [12:50<01:26, 1.67s/it]
|
| 648 |
89%|████████▉ | 418/469 [12:51<01:24, 1.65s/it]
|
| 649 |
89%|████████▉ | 419/469 [12:53<01:21, 1.64s/it]
|
| 650 |
90%|████████▉ | 420/469 [12:55<01:21, 1.66s/it]
|
| 651 |
90%|████████▉ | 421/469 [12:56<01:19, 1.66s/it]
|
| 652 |
90%|████████▉ | 422/469 [12:58<01:18, 1.67s/it]
|
| 653 |
90%|█████████ | 423/469 [13:00<01:18, 1.71s/it]
|
| 654 |
90%|█████████ | 424/469 [13:02<01:18, 1.73s/it]
|
| 655 |
91%|█████████ | 425/469 [13:03<01:14, 1.69s/it]
|
| 656 |
91%|█████████ | 426/469 [13:05<01:11, 1.66s/it]
|
| 657 |
91%|█████████ | 427/469 [13:07<01:13, 1.74s/it]
|
| 658 |
91%|█████████▏| 428/469 [13:08<01:09, 1.70s/it]
|
| 659 |
91%|█████████▏| 429/469 [13:10<01:07, 1.68s/it]
|
| 660 |
92%|█████████▏| 430/469 [13:12<01:05, 1.68s/it]
|
| 661 |
92%|█████████▏| 431/469 [13:13<01:05, 1.71s/it]
|
| 662 |
92%|█████████▏| 432/469 [13:15<01:02, 1.68s/it]
|
| 663 |
92%|█████████▏| 433/469 [13:17<01:01, 1.71s/it]
|
| 664 |
93%|█████████▎| 434/469 [13:18<00:58, 1.69s/it]
|
| 665 |
93%|█████████▎| 435/469 [13:20<00:57, 1.69s/it]
|
| 666 |
93%|█████████▎| 436/469 [13:22<00:55, 1.68s/it]
|
| 667 |
93%|█████████▎| 437/469 [13:23<00:53, 1.67s/it]
|
| 668 |
93%|█████████▎| 438/469 [13:25<00:51, 1.66s/it]
|
| 669 |
94%|█████████▎| 439/469 [13:27<00:49, 1.64s/it]
|
| 670 |
94%|█████████▍| 440/469 [13:28<00:47, 1.65s/it]
|
| 671 |
94%|█████████▍| 441/469 [13:30<00:46, 1.65s/it]
|
| 672 |
94%|█████████▍| 442/469 [13:32<00:44, 1.63s/it]
|
| 673 |
94%|█████████▍| 443/469 [13:33<00:43, 1.67s/it]
|
| 674 |
95%|█████████▍| 444/469 [13:35<00:41, 1.65s/it]
|
| 675 |
95%|█████████▍| 445/469 [13:36<00:39, 1.64s/it]
|
| 676 |
95%|█████████▌| 446/469 [13:39<00:40, 1.77s/it]
|
| 677 |
95%|█████████▌| 447/469 [13:40<00:37, 1.72s/it]
|
| 678 |
96%|█████████▌| 448/469 [13:42<00:35, 1.68s/it]
|
| 679 |
96%|█████████▌| 449/469 [13:44<00:36, 1.84s/it]
|
| 680 |
96%|█████████▌| 450/469 [13:46<00:33, 1.78s/it]
|
| 681 |
96%|█████████▌| 451/469 [13:47<00:30, 1.72s/it]
|
| 682 |
96%|█████████▋| 452/469 [13:49<00:30, 1.78s/it]
|
| 683 |
97%|█████████▋| 453/469 [13:51<00:27, 1.73s/it]
|
| 684 |
97%|█████████▋| 454/469 [13:52<00:25, 1.70s/it]
|
| 685 |
97%|█████████▋| 455/469 [13:54<00:23, 1.70s/it]
|
| 686 |
97%|█████████▋| 456/469 [13:56<00:21, 1.67s/it]
|
| 687 |
97%|█████████▋| 457/469 [13:57<00:19, 1.66s/it]
|
| 688 |
98%|█████████▊| 458/469 [13:59<00:18, 1.65s/it]
|
| 689 |
98%|█████████▊| 459/469 [14:01<00:16, 1.64s/it]
|
| 690 |
98%|█████████▊| 460/469 [14:02<00:14, 1.65s/it]
|
| 691 |
98%|█████████▊| 461/469 [14:04<00:13, 1.65s/it]
|
| 692 |
99%|█████████▊| 462/469 [14:06<00:11, 1.67s/it]
|
| 693 |
99%|█████████▊| 463/469 [14:07<00:10, 1.67s/it]
|
| 694 |
99%|█████████▉| 464/469 [14:09<00:09, 1.83s/it]
|
| 695 |
99%|█████████▉| 465/469 [14:11<00:07, 1.84s/it]
|
| 696 |
99%|█████████▉| 466/469 [14:13<00:05, 1.80s/it]
|
| 697 |
+
computing/reading sample batch statistics...
|
| 698 |
+
Computing evaluations...
|
| 699 |
+
Inception Score: 37.646392822265625
|
| 700 |
+
FID: 21.19386100577333
|
| 701 |
+
sFID: 71.79977998851734
|
| 702 |
+
Precision: 0.690407122136641
|
| 703 |
+
Recall: 0.358997247638176
|
pic_npz copy.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
|
| 4 |
+
基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
|
| 5 |
+
支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
|
| 6 |
+
支持从 metadata.jsonl 文件读取图片路径
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import glob
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_npz_from_metadata(metadata_jsonl_path, output_path=None):
|
| 19 |
+
"""
|
| 20 |
+
从 metadata.jsonl 文件读取图片路径并构建 .npz 文件
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
metadata_jsonl_path (str): metadata.jsonl 文件路径
|
| 24 |
+
output_path (str, optional): 输出 npz 文件路径,默认在 metadata.jsonl 同目录下生成
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
str: 生成的 npz 文件路径
|
| 28 |
+
"""
|
| 29 |
+
# 确保 metadata.jsonl 存在
|
| 30 |
+
if not os.path.exists(metadata_jsonl_path):
|
| 31 |
+
raise ValueError(f"metadata.jsonl 文件不存在: {metadata_jsonl_path}")
|
| 32 |
+
|
| 33 |
+
# 获取基础目录
|
| 34 |
+
base_dir = os.path.dirname(metadata_jsonl_path)
|
| 35 |
+
|
| 36 |
+
# 读取 metadata.jsonl
|
| 37 |
+
image_files = []
|
| 38 |
+
with open(metadata_jsonl_path, 'r', encoding='utf-8') as f:
|
| 39 |
+
for line in f:
|
| 40 |
+
line = line.strip()
|
| 41 |
+
if line:
|
| 42 |
+
try:
|
| 43 |
+
data = json.loads(line)
|
| 44 |
+
file_name = data.get('file_name')
|
| 45 |
+
if file_name:
|
| 46 |
+
full_path = os.path.join(base_dir, file_name)
|
| 47 |
+
image_files.append(full_path)
|
| 48 |
+
except json.JSONDecodeError as e:
|
| 49 |
+
print(f"警告: 跳过无效的 JSON 行: {e}")
|
| 50 |
+
continue
|
| 51 |
+
|
| 52 |
+
if len(image_files) == 0:
|
| 53 |
+
raise ValueError(f"在 {metadata_jsonl_path} 中未找到任何有效的图片路径")
|
| 54 |
+
|
| 55 |
+
print(f"从 metadata.jsonl 读取到 {len(image_files)} 张图片路径")
|
| 56 |
+
|
| 57 |
+
# 读取所有图片
|
| 58 |
+
samples = []
|
| 59 |
+
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
|
| 60 |
+
try:
|
| 61 |
+
# 打开图片并转换为RGB格式(确保一致性)
|
| 62 |
+
with Image.open(img_path) as img:
|
| 63 |
+
# 转换为RGB,确保所有图片都是3通道
|
| 64 |
+
if img.mode != 'RGB':
|
| 65 |
+
img = img.convert('RGB')
|
| 66 |
+
|
| 67 |
+
# 将图片resize到512x512
|
| 68 |
+
img = img.resize((512, 512), Image.LANCZOS)
|
| 69 |
+
|
| 70 |
+
sample_np = np.asarray(img).astype(np.uint8)
|
| 71 |
+
|
| 72 |
+
# 确保图片是3通道
|
| 73 |
+
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
|
| 74 |
+
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
|
| 75 |
+
continue
|
| 76 |
+
|
| 77 |
+
samples.append(sample_np)
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"警告: 无法读取图片 {img_path}: {e}")
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
if len(samples) == 0:
|
| 84 |
+
raise ValueError("没有成功读取任何有效的图片文件")
|
| 85 |
+
|
| 86 |
+
# 转换为numpy数组
|
| 87 |
+
samples = np.stack(samples)
|
| 88 |
+
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
|
| 89 |
+
|
| 90 |
+
# 验证数据形状
|
| 91 |
+
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
|
| 92 |
+
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
|
| 93 |
+
|
| 94 |
+
# 生成输出路径
|
| 95 |
+
if output_path is None:
|
| 96 |
+
base_name = os.path.splitext(os.path.basename(metadata_jsonl_path))[0]
|
| 97 |
+
output_path = os.path.join(base_dir, f"{base_name}.npz")
|
| 98 |
+
|
| 99 |
+
# 保存为npz文件
|
| 100 |
+
np.savez(output_path, arr_0=samples)
|
| 101 |
+
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
|
| 102 |
+
|
| 103 |
+
return output_path
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
|
| 107 |
+
"""
|
| 108 |
+
主函数:解析命令行参数并执行图片到npz的转换
|
| 109 |
+
"""
|
| 110 |
+
parser = argparse.ArgumentParser(
|
| 111 |
+
description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
|
| 112 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 113 |
+
epilog="""
|
| 114 |
+
使用示例:
|
| 115 |
+
python pic_npz.py /path/to/image/folder
|
| 116 |
+
python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
|
| 117 |
+
"""
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--image_folder",
|
| 122 |
+
type=str,
|
| 123 |
+
default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples",
|
| 124 |
+
help="包含PNG或JPG图片文件的文件夹路径"
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# parser.add_argument(
|
| 128 |
+
# "--metadata_jsonl",
|
| 129 |
+
# type=str,
|
| 130 |
+
# default="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl",
|
| 131 |
+
# help="metadata.jsonl 文件路径,用于从 JSONL 文件读取图片路径"
|
| 132 |
+
# )
|
| 133 |
+
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--output-dir",
|
| 136 |
+
type=str,
|
| 137 |
+
default=None,
|
| 138 |
+
help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
args = parser.parse_args()
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
if args.metadata_jsonl and os.path.exists(args.metadata_jsonl):
|
| 145 |
+
# 使用 metadata.jsonl
|
| 146 |
+
metadata_path = os.path.abspath(args.metadata_jsonl)
|
| 147 |
+
base_dir = os.path.dirname(metadata_path)
|
| 148 |
+
base_name = os.path.splitext(os.path.basename(metadata_path))[0]
|
| 149 |
+
|
| 150 |
+
if args.output_dir:
|
| 151 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 152 |
+
output_path = os.path.join(args.output_dir, f"{base_name}.npz")
|
| 153 |
+
else:
|
| 154 |
+
output_path = os.path.join(base_dir, f"{base_name}.npz")
|
| 155 |
+
|
| 156 |
+
npz_path = create_npz_from_metadata(metadata_path, output_path)
|
| 157 |
+
else:
|
| 158 |
+
# 使用图片文件夹
|
| 159 |
+
image_folder_path = os.path.abspath(args.image_folder)
|
| 160 |
+
|
| 161 |
+
if args.output_dir:
|
| 162 |
+
# 如果指定了输出目录,修改生成逻辑
|
| 163 |
+
folder_name = os.path.basename(image_folder_path.rstrip('/'))
|
| 164 |
+
custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
|
| 165 |
+
|
| 166 |
+
# 创建输出目录(如果不存在)
|
| 167 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 168 |
+
|
| 169 |
+
# 临时修改函数以支持自定义输出路径
|
| 170 |
+
npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
|
| 171 |
+
else:
|
| 172 |
+
npz_path = create_npz_from_image_folder(image_folder_path)
|
| 173 |
+
|
| 174 |
+
print(f"转换完成!NPZ文件已保存至: {npz_path}")
|
| 175 |
+
|
| 176 |
+
except Exception as e:
|
| 177 |
+
print(f"错误: {e}")
|
| 178 |
+
return 1
|
| 179 |
+
|
| 180 |
+
return 0
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def create_npz_from_image_folder_custom(image_folder_path, output_path):
|
| 184 |
+
"""
|
| 185 |
+
从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
image_folder_path (str): 包含图片文件的文件夹路径
|
| 189 |
+
output_path (str): 输出npz文件的完整路径
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
str: 生成的 npz 文件路径
|
| 193 |
+
"""
|
| 194 |
+
# 确保路径存在
|
| 195 |
+
if not os.path.exists(image_folder_path):
|
| 196 |
+
raise ValueError(f"文件夹路径不存在: {image_folder_path}")
|
| 197 |
+
|
| 198 |
+
# 获取所有支持的图片文件
|
| 199 |
+
supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
|
| 200 |
+
image_files = []
|
| 201 |
+
|
| 202 |
+
for extension in supported_extensions:
|
| 203 |
+
pattern = os.path.join(image_folder_path, extension)
|
| 204 |
+
image_files.extend(glob.glob(pattern))
|
| 205 |
+
|
| 206 |
+
# 按文件名排序确保一致性
|
| 207 |
+
image_files.sort()
|
| 208 |
+
|
| 209 |
+
if len(image_files) == 0:
|
| 210 |
+
raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
|
| 211 |
+
|
| 212 |
+
print(f"找到 {len(image_files)} 张图片文件")
|
| 213 |
+
|
| 214 |
+
# 读取所有图片
|
| 215 |
+
samples = []
|
| 216 |
+
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
|
| 217 |
+
try:
|
| 218 |
+
# 打开图片并转换为RGB格式(确保一致性)
|
| 219 |
+
with Image.open(img_path) as img:
|
| 220 |
+
# 转换为RGB,确保所有图片都是3通道
|
| 221 |
+
if img.mode != 'RGB':
|
| 222 |
+
img = img.convert('RGB')
|
| 223 |
+
|
| 224 |
+
# 将图片resize到512x512
|
| 225 |
+
img = img.resize((512, 512), Image.LANCZOS)
|
| 226 |
+
|
| 227 |
+
sample_np = np.asarray(img).astype(np.uint8)
|
| 228 |
+
|
| 229 |
+
# 确保图片是3通道
|
| 230 |
+
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
|
| 231 |
+
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
|
| 232 |
+
continue
|
| 233 |
+
|
| 234 |
+
samples.append(sample_np)
|
| 235 |
+
|
| 236 |
+
except Exception as e:
|
| 237 |
+
print(f"警告: 无法读取图片 {img_path}: {e}")
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
if len(samples) == 0:
|
| 241 |
+
raise ValueError("没有成功读取任何有效的图片文件")
|
| 242 |
+
|
| 243 |
+
# 转换为numpy数组
|
| 244 |
+
samples = np.stack(samples)
|
| 245 |
+
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
|
| 246 |
+
|
| 247 |
+
# 验证数据形状
|
| 248 |
+
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
|
| 249 |
+
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
|
| 250 |
+
|
| 251 |
+
# 保存为npz文件
|
| 252 |
+
np.savez(output_path, arr_0=samples)
|
| 253 |
+
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
|
| 254 |
+
|
| 255 |
+
return output_path
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
|
| 259 |
+
exit(main())
|
pic_npz.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
|
| 4 |
+
基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
|
| 5 |
+
支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import glob
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
"""
|
| 18 |
+
主函数:解析命令行参数并执行图片到npz的转换
|
| 19 |
+
"""
|
| 20 |
+
parser = argparse.ArgumentParser(
|
| 21 |
+
description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
|
| 22 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 23 |
+
epilog="""
|
| 24 |
+
使用示例:
|
| 25 |
+
python pic_npz.py /path/to/image/folder
|
| 26 |
+
python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
|
| 27 |
+
"""
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--image_folder",
|
| 32 |
+
type=str,
|
| 33 |
+
default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples_new_batch_2",
|
| 34 |
+
help="包含PNG或JPG图片文件的文件夹路径"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--output-dir",
|
| 39 |
+
type=str,
|
| 40 |
+
default=None,
|
| 41 |
+
help="自定义输出目录(默认为输入文件夹的父级目录)"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
args = parser.parse_args()
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
# 仅支持从图片文件夹生成 npz
|
| 48 |
+
image_folder_path = os.path.abspath(args.image_folder)
|
| 49 |
+
|
| 50 |
+
if args.output_dir:
|
| 51 |
+
# 如果指定了输出目录,修改生成逻辑
|
| 52 |
+
folder_name = os.path.basename(image_folder_path.rstrip('/'))
|
| 53 |
+
custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
|
| 54 |
+
|
| 55 |
+
# 创建输出目录(如果不存在)
|
| 56 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 57 |
+
|
| 58 |
+
npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
|
| 59 |
+
else:
|
| 60 |
+
npz_path = create_npz_from_image_folder(image_folder_path)
|
| 61 |
+
|
| 62 |
+
print(f"转换完成!NPZ文件已保存至: {npz_path}")
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print(f"错误: {e}")
|
| 66 |
+
return 1
|
| 67 |
+
|
| 68 |
+
return 0
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def create_npz_from_image_folder_custom(image_folder_path, output_path):
|
| 72 |
+
"""
|
| 73 |
+
从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
image_folder_path (str): 包含图片文件的文件夹路径
|
| 77 |
+
output_path (str): 输出npz文件的完整路径
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
str: 生成的 npz 文件路径
|
| 81 |
+
"""
|
| 82 |
+
# 确保路径存在
|
| 83 |
+
if not os.path.exists(image_folder_path):
|
| 84 |
+
raise ValueError(f"文件夹路径不存在: {image_folder_path}")
|
| 85 |
+
|
| 86 |
+
# 获取所有支持的图片文件
|
| 87 |
+
supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
|
| 88 |
+
image_files = []
|
| 89 |
+
|
| 90 |
+
for extension in supported_extensions:
|
| 91 |
+
pattern = os.path.join(image_folder_path, extension)
|
| 92 |
+
image_files.extend(glob.glob(pattern))
|
| 93 |
+
|
| 94 |
+
# 按文件名排序确保一致性
|
| 95 |
+
image_files.sort()
|
| 96 |
+
|
| 97 |
+
if len(image_files) == 0:
|
| 98 |
+
raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
|
| 99 |
+
|
| 100 |
+
print(f"找到 {len(image_files)} 张图片文件")
|
| 101 |
+
|
| 102 |
+
# 读取所有图片
|
| 103 |
+
samples = []
|
| 104 |
+
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
|
| 105 |
+
try:
|
| 106 |
+
# 打开图片并转换为RGB格式(确保一致性)
|
| 107 |
+
with Image.open(img_path) as img:
|
| 108 |
+
# 转换为RGB,确保所有图片都是3通道
|
| 109 |
+
if img.mode != 'RGB':
|
| 110 |
+
img = img.convert('RGB')
|
| 111 |
+
|
| 112 |
+
# 将图片resize到512x512
|
| 113 |
+
img = img.resize((512, 512), Image.LANCZOS)
|
| 114 |
+
|
| 115 |
+
sample_np = np.asarray(img).astype(np.uint8)
|
| 116 |
+
|
| 117 |
+
# 确保图片是3通道
|
| 118 |
+
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
|
| 119 |
+
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
samples.append(sample_np)
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"警告: 无法读取图片 {img_path}: {e}")
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
if len(samples) == 0:
|
| 129 |
+
raise ValueError("没有成功读取任何有效的图片文件")
|
| 130 |
+
|
| 131 |
+
# 转换为numpy数组
|
| 132 |
+
samples = np.stack(samples)
|
| 133 |
+
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
|
| 134 |
+
|
| 135 |
+
# 验证数据形状
|
| 136 |
+
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
|
| 137 |
+
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
|
| 138 |
+
|
| 139 |
+
# 保存为npz文件
|
| 140 |
+
np.savez(output_path, arr_0=samples)
|
| 141 |
+
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
|
| 142 |
+
|
| 143 |
+
return output_path
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def create_npz_from_image_folder(image_folder_path):
|
| 147 |
+
"""
|
| 148 |
+
从图片文件夹构建 .npz,输出到该文件夹的父目录,文件名为 <文件夹名>.npz
|
| 149 |
+
"""
|
| 150 |
+
parent_dir = os.path.dirname(os.path.abspath(image_folder_path))
|
| 151 |
+
folder_name = os.path.basename(os.path.abspath(image_folder_path).rstrip("/"))
|
| 152 |
+
output_path = os.path.join(parent_dir, f"{folder_name}.npz")
|
| 153 |
+
return create_npz_from_image_folder_custom(image_folder_path, output_path)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
exit(main())
|
pipeline_stable_diffusion_3.py
ADDED
|
@@ -0,0 +1,1378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import (
|
| 20 |
+
CLIPTextModelWithProjection,
|
| 21 |
+
CLIPTokenizer,
|
| 22 |
+
SiglipImageProcessor,
|
| 23 |
+
SiglipVisionModel,
|
| 24 |
+
T5EncoderModel,
|
| 25 |
+
T5TokenizerFast,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 29 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 30 |
+
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
|
| 31 |
+
from ...models.autoencoders import AutoencoderKL
|
| 32 |
+
from ...models.transformers import SD3Transformer2DModel
|
| 33 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 34 |
+
from ...utils import (
|
| 35 |
+
USE_PEFT_BACKEND,
|
| 36 |
+
is_torch_xla_available,
|
| 37 |
+
logging,
|
| 38 |
+
replace_example_docstring,
|
| 39 |
+
scale_lora_layers,
|
| 40 |
+
unscale_lora_layers,
|
| 41 |
+
)
|
| 42 |
+
from ...utils.torch_utils import randn_tensor
|
| 43 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 44 |
+
from .pipeline_output import StableDiffusion3PipelineOutput
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_torch_xla_available():
|
| 48 |
+
import torch_xla.core.xla_model as xm
|
| 49 |
+
|
| 50 |
+
XLA_AVAILABLE = True
|
| 51 |
+
else:
|
| 52 |
+
XLA_AVAILABLE = False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
EXAMPLE_DOC_STRING = """
|
| 58 |
+
Examples:
|
| 59 |
+
```py
|
| 60 |
+
>>> import torch
|
| 61 |
+
>>> from diffusers import StableDiffusion3Pipeline
|
| 62 |
+
|
| 63 |
+
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
| 64 |
+
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
| 65 |
+
... )
|
| 66 |
+
>>> pipe.to("cuda")
|
| 67 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 68 |
+
>>> image = pipe(prompt).images[0]
|
| 69 |
+
>>> image.save("sd3.png")
|
| 70 |
+
```
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 74 |
+
def calculate_shift(
|
| 75 |
+
image_seq_len,
|
| 76 |
+
base_seq_len: int = 256,
|
| 77 |
+
max_seq_len: int = 4096,
|
| 78 |
+
base_shift: float = 0.5,
|
| 79 |
+
max_shift: float = 1.15,
|
| 80 |
+
):
|
| 81 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 82 |
+
b = base_shift - m * base_seq_len
|
| 83 |
+
mu = image_seq_len * m + b
|
| 84 |
+
return mu
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 88 |
+
def retrieve_timesteps(
|
| 89 |
+
scheduler,
|
| 90 |
+
num_inference_steps: Optional[int] = None,
|
| 91 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 92 |
+
timesteps: Optional[List[int]] = None,
|
| 93 |
+
sigmas: Optional[List[float]] = None,
|
| 94 |
+
**kwargs,
|
| 95 |
+
):
|
| 96 |
+
r"""
|
| 97 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 98 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
scheduler (`SchedulerMixin`):
|
| 102 |
+
The scheduler to get timesteps from.
|
| 103 |
+
num_inference_steps (`int`):
|
| 104 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 105 |
+
must be `None`.
|
| 106 |
+
device (`str` or `torch.device`, *optional*):
|
| 107 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 108 |
+
timesteps (`List[int]`, *optional*):
|
| 109 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 110 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 111 |
+
sigmas (`List[float]`, *optional*):
|
| 112 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 113 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 117 |
+
second element is the number of inference steps.
|
| 118 |
+
"""
|
| 119 |
+
if timesteps is not None and sigmas is not None:
|
| 120 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 121 |
+
if timesteps is not None:
|
| 122 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 123 |
+
if not accepts_timesteps:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 126 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 127 |
+
)
|
| 128 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 129 |
+
timesteps = scheduler.timesteps
|
| 130 |
+
num_inference_steps = len(timesteps)
|
| 131 |
+
elif sigmas is not None:
|
| 132 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accept_sigmas:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
else:
|
| 142 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 143 |
+
timesteps = scheduler.timesteps
|
| 144 |
+
return timesteps, num_inference_steps
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
|
| 148 |
+
r"""
|
| 149 |
+
Args:
|
| 150 |
+
transformer ([`SD3Transformer2DModel`]):
|
| 151 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 152 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 153 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 154 |
+
vae ([`AutoencoderKL`]):
|
| 155 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 156 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
| 157 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 158 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
| 159 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
| 160 |
+
as its dimension.
|
| 161 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
| 162 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 163 |
+
specifically the
|
| 164 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
| 165 |
+
variant.
|
| 166 |
+
text_encoder_3 ([`T5EncoderModel`]):
|
| 167 |
+
Frozen text-encoder. Stable Diffusion 3 uses
|
| 168 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 169 |
+
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 170 |
+
tokenizer (`CLIPTokenizer`):
|
| 171 |
+
Tokenizer of class
|
| 172 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 173 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 174 |
+
Second Tokenizer of class
|
| 175 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 176 |
+
tokenizer_3 (`T5TokenizerFast`):
|
| 177 |
+
Tokenizer of class
|
| 178 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 179 |
+
image_encoder (`SiglipVisionModel`, *optional*):
|
| 180 |
+
Pre-trained Vision Model for IP Adapter.
|
| 181 |
+
feature_extractor (`SiglipImageProcessor`, *optional*):
|
| 182 |
+
Image processor for IP Adapter.
|
| 183 |
+
model (`SD3WithRectifiedNoise`, *optional*):
|
| 184 |
+
Optional SD3WithRectifiedNoise model for enhanced noise prediction. If provided, will be used instead of
|
| 185 |
+
the default transformer for denoising.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
| 189 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 190 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
|
| 191 |
+
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
transformer: SD3Transformer2DModel,
|
| 195 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 196 |
+
vae: AutoencoderKL,
|
| 197 |
+
text_encoder: CLIPTextModelWithProjection,
|
| 198 |
+
tokenizer: CLIPTokenizer,
|
| 199 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 200 |
+
tokenizer_2: CLIPTokenizer,
|
| 201 |
+
text_encoder_3: T5EncoderModel,
|
| 202 |
+
tokenizer_3: T5TokenizerFast,
|
| 203 |
+
image_encoder: SiglipVisionModel = None,
|
| 204 |
+
feature_extractor: SiglipImageProcessor = None,
|
| 205 |
+
model = None, # 添加 model 参数
|
| 206 |
+
):
|
| 207 |
+
super().__init__()
|
| 208 |
+
|
| 209 |
+
self.register_modules(
|
| 210 |
+
vae=vae,
|
| 211 |
+
text_encoder=text_encoder,
|
| 212 |
+
text_encoder_2=text_encoder_2,
|
| 213 |
+
text_encoder_3=text_encoder_3,
|
| 214 |
+
tokenizer=tokenizer,
|
| 215 |
+
tokenizer_2=tokenizer_2,
|
| 216 |
+
tokenizer_3=tokenizer_3,
|
| 217 |
+
transformer=transformer,
|
| 218 |
+
scheduler=scheduler,
|
| 219 |
+
image_encoder=image_encoder,
|
| 220 |
+
feature_extractor=feature_extractor,
|
| 221 |
+
model=model, # 添加 model 参数到 register_modules
|
| 222 |
+
)
|
| 223 |
+
#print(f"VAE is None: {getattr(self, 'vae', None) is None}")
|
| 224 |
+
#if getattr(self, 'vae', None) is not None:
|
| 225 |
+
# print(f"VAE config block_out_channels: {self.vae.config.block_out_channels}")
|
| 226 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 227 |
+
#print(f"VAE scale factor: {self.vae_scale_factor}")
|
| 228 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 229 |
+
self.tokenizer_max_length = (
|
| 230 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 231 |
+
)
|
| 232 |
+
self.default_sample_size = (
|
| 233 |
+
self.transformer.config.sample_size
|
| 234 |
+
if hasattr(self, "transformer") and self.transformer is not None
|
| 235 |
+
else 64
|
| 236 |
+
)
|
| 237 |
+
self.patch_size = (
|
| 238 |
+
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
| 239 |
+
)
|
| 240 |
+
# 添加对 SD3WithRectifiedNoise 模型的支持
|
| 241 |
+
self.model = model
|
| 242 |
+
|
| 243 |
+
def _get_t5_prompt_embeds(
|
| 244 |
+
self,
|
| 245 |
+
prompt: Union[str, List[str]] = None,
|
| 246 |
+
num_images_per_prompt: int = 1,
|
| 247 |
+
max_sequence_length: int = 256,
|
| 248 |
+
device: Optional[torch.device] = None,
|
| 249 |
+
dtype: Optional[torch.dtype] = None,
|
| 250 |
+
):
|
| 251 |
+
device = device or self._execution_device
|
| 252 |
+
dtype = dtype or self.text_encoder_3.dtype
|
| 253 |
+
#max_sequence_length=77
|
| 254 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 255 |
+
batch_size = len(prompt)
|
| 256 |
+
|
| 257 |
+
# print(f"T5处理 - 输入提示: {prompt}")
|
| 258 |
+
# print(f"T5处理 - batch_size: {batch_size}")
|
| 259 |
+
# print(f"T5处理 - num_images_per_prompt: {num_images_per_prompt}")
|
| 260 |
+
# print(f"T5处理 - max_sequence_length: {max_sequence_length}")
|
| 261 |
+
# print(f"T5处理 - device: {device}")
|
| 262 |
+
# print(f"T5处理 - dtype: {dtype}")
|
| 263 |
+
|
| 264 |
+
if self.text_encoder_3 is None:
|
| 265 |
+
#print("T5处理 - text_encoder_3为None,返回零张量")
|
| 266 |
+
return torch.zeros(
|
| 267 |
+
(
|
| 268 |
+
batch_size * num_images_per_prompt,
|
| 269 |
+
self.tokenizer_max_length,
|
| 270 |
+
self.transformer.config.joint_attention_dim,
|
| 271 |
+
),
|
| 272 |
+
device=device,
|
| 273 |
+
dtype=dtype,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
text_inputs = self.tokenizer_3(
|
| 277 |
+
prompt,
|
| 278 |
+
padding="max_length",
|
| 279 |
+
max_length=max_sequence_length,
|
| 280 |
+
truncation=True,
|
| 281 |
+
# add_special_tokens=True,
|
| 282 |
+
return_tensors="pt",
|
| 283 |
+
)
|
| 284 |
+
text_input_ids = text_inputs.input_ids
|
| 285 |
+
# if torch.isnan(text_input_ids).any():
|
| 286 |
+
# print("T5处理 - text_input_ids输入提示包含NaN值")
|
| 287 |
+
# else:
|
| 288 |
+
# print("T5处理 - text_input_ids",text_input_ids)
|
| 289 |
+
# print(f"T5处理 - text_input_ids形状: {text_input_ids.shape}")
|
| 290 |
+
# print(f"T5处理 - text_input_ids范围: [{text_input_ids.min().item()}, {text_input_ids.max().item()}]")
|
| 291 |
+
# print(f"T5处理 - tokenizer_3.vocab_size: {self.tokenizer_3.vocab_size}")
|
| 292 |
+
|
| 293 |
+
# # 检查输入token IDs是否包含非法值
|
| 294 |
+
# if torch.any(text_input_ids < 0):
|
| 295 |
+
# print(f"警告:发现负数token ID,最小值: {text_input_ids.min().item()}")
|
| 296 |
+
# if torch.any(text_input_ids >= self.tokenizer_3.vocab_size):
|
| 297 |
+
# print(f"警告:发现超过词汇表大小的token ID,最大值: {text_input_ids.max().item()}, vocab_size: {self.tokenizer_3.vocab_size}")
|
| 298 |
+
|
| 299 |
+
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
| 300 |
+
|
| 301 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 302 |
+
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 303 |
+
logger.warning(
|
| 304 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 305 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# 将输入移动到设备并确保数据类型正确
|
| 309 |
+
text_input_ids = text_input_ids.to(device)#, dtype=torch.long)
|
| 310 |
+
# print(f"T5处理 - text_input_ids形状: ",text_input_ids)
|
| 311 |
+
# print(f"T5处理 - text_input_ids设备: {text_input_ids.device}, dtype: {text_input_ids.dtype}")
|
| 312 |
+
|
| 313 |
+
# 检查text_encoder_3的状态
|
| 314 |
+
# print(f"T5处理 - text_encoder_3设备: {next(self.text_encoder_3.parameters()).device}")
|
| 315 |
+
# print(f"T5处理 - text_encoder_3.dtype: {self.text_encoder_3.dtype}")
|
| 316 |
+
with torch.autocast(device.type if isinstance(device, torch.device) else "cuda", enabled=False):
|
| 317 |
+
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
| 318 |
+
#prompt_embeds = self.text_encoder_3(text_input_ids)[0]
|
| 319 |
+
# print(f"T5处理 - T5编码器输出形状: {prompt_embeds.shape}")
|
| 320 |
+
# print(f"T5处理 - T5编码器输出设备: {prompt_embeds.device}")
|
| 321 |
+
# print(f"T5处理 - T5编码器输出dtype: {prompt_embeds.dtype}")
|
| 322 |
+
|
| 323 |
+
# # 检查T5编码器输出是否包含NaN或inf
|
| 324 |
+
# has_nan = torch.isnan(prompt_embeds).any()
|
| 325 |
+
# has_inf = torch.isinf(prompt_embeds).any()
|
| 326 |
+
# if has_nan or has_inf:
|
| 327 |
+
# print(f"警告:T5编码器输出包含NaN: {has_nan} 或inf: {has_inf}")
|
| 328 |
+
# print(f"T5编码器输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
|
| 329 |
+
# if has_nan:
|
| 330 |
+
# nan_locations = torch.where(torch.isnan(prompt_embeds))
|
| 331 |
+
# print(f"NaN位置 - 前10个: {[(nan_locations[i][:10].tolist()) for i in range(len(nan_locations))]}")
|
| 332 |
+
|
| 333 |
+
dtype = self.text_encoder_3.dtype
|
| 334 |
+
# 强制在无 autocast 的上下文中运行 T5 以避免 fp16 溢出为 NaN
|
| 335 |
+
|
| 336 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 337 |
+
#print(f"T5处理 - 转换后prompt_embeds dtype: {prompt_embeds.dtype}, device: {prompt_embeds.device}")
|
| 338 |
+
|
| 339 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 340 |
+
|
| 341 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 342 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 343 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 344 |
+
# print(f"T5处理 - 最终输出形状: {prompt_embeds.shape}")
|
| 345 |
+
|
| 346 |
+
# # 检查最终输出是否包含NaN
|
| 347 |
+
# if torch.isnan(prompt_embeds).any():
|
| 348 |
+
# print(f"警告:T5最终输出包含NaN,位置: {torch.where(torch.isnan(prompt_embeds))}")
|
| 349 |
+
# print(f"T5最终输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
|
| 350 |
+
# print("最终输出",prompt_embeds)
|
| 351 |
+
return prompt_embeds
|
| 352 |
+
|
| 353 |
+
def _get_clip_prompt_embeds(
|
| 354 |
+
self,
|
| 355 |
+
prompt: Union[str, List[str]],
|
| 356 |
+
num_images_per_prompt: int = 1,
|
| 357 |
+
device: Optional[torch.device] = None,
|
| 358 |
+
clip_skip: Optional[int] = None,
|
| 359 |
+
clip_model_index: int = 0,
|
| 360 |
+
):
|
| 361 |
+
device = device or self._execution_device
|
| 362 |
+
|
| 363 |
+
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
| 364 |
+
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
| 365 |
+
|
| 366 |
+
tokenizer = clip_tokenizers[clip_model_index]
|
| 367 |
+
text_encoder = clip_text_encoders[clip_model_index]
|
| 368 |
+
|
| 369 |
+
# print(f"CLIP处理 - clip_model_index: {clip_model_index}")
|
| 370 |
+
# print(f"CLIP处理 - 使用的tokenizer: {type(tokenizer)}")
|
| 371 |
+
# print(f"CLIP处理 - 使用的text_encoder: {type(text_encoder)}")
|
| 372 |
+
|
| 373 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 374 |
+
batch_size = len(prompt)
|
| 375 |
+
# print(f"CLIP处理 - 输入提示: {prompt}")
|
| 376 |
+
# print(f"CLIP处理 - batch_size: {batch_size}")
|
| 377 |
+
|
| 378 |
+
text_inputs = tokenizer(
|
| 379 |
+
prompt,
|
| 380 |
+
padding="max_length",
|
| 381 |
+
max_length=self.tokenizer_max_length,
|
| 382 |
+
truncation=True,
|
| 383 |
+
return_tensors="pt",
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
text_input_ids = text_inputs.input_ids
|
| 387 |
+
# print(f"CLIP处理 - text_input_ids形状: {text_input_ids.shape}")
|
| 388 |
+
# print(f"CLIP处理 - text_input_ids范围: [{text_input_ids.min().item()}, {text_input_ids.max().item()}]")
|
| 389 |
+
|
| 390 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 391 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 392 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 393 |
+
logger.warning(
|
| 394 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 395 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 396 |
+
)
|
| 397 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 398 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 399 |
+
|
| 400 |
+
# print(f"CLIP处理 - pooled_prompt_embeds形状: {pooled_prompt_embeds.shape}")
|
| 401 |
+
# print(f"CLIP处理 - pooled_prompt_embeds设备: {pooled_prompt_embeds.device}")
|
| 402 |
+
# print(f"CLIP处理 - pooled_prompt_embeds dtype: {pooled_prompt_embeds.dtype}")
|
| 403 |
+
|
| 404 |
+
# 检查CLIP编码器输出是否包含NaN或inf
|
| 405 |
+
# has_nan = torch.isnan(pooled_prompt_embeds).any()
|
| 406 |
+
# has_inf = torch.isinf(pooled_prompt_embeds).any()
|
| 407 |
+
# if has_nan or has_inf:
|
| 408 |
+
# print(f"警告:CLIP编码器pooled输出包含NaN: {has_nan} 或inf: {has_inf}")
|
| 409 |
+
# print(f"CLIP编码器pooled输出统计 - min: {pooled_prompt_embeds.min().item()}, max: {pooled_prompt_embeds.max().item()}, mean: {pooled_prompt_embeds.mean().item()}")
|
| 410 |
+
|
| 411 |
+
if clip_skip is None:
|
| 412 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 413 |
+
else:
|
| 414 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
| 415 |
+
|
| 416 |
+
# 检查CLIP编码器embeds输出是否包含NaN或inf
|
| 417 |
+
# has_nan = torch.isnan(prompt_embeds).any()
|
| 418 |
+
# has_inf = torch.isinf(prompt_embeds).any()
|
| 419 |
+
# if has_nan or has_inf:
|
| 420 |
+
# print(f"警告:CLIP编码器embeds输出包含NaN: {has_nan} 或inf: {has_inf}")
|
| 421 |
+
# print(f"CLIP编码器embeds输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
|
| 422 |
+
|
| 423 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 424 |
+
#print(f"CLIP处理 - 转换后prompt_embeds dtype: {prompt_embeds.dtype}, device: {prompt_embeds.device}")
|
| 425 |
+
|
| 426 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 427 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 428 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 429 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 430 |
+
# print(f"CLIP处理 - 最终prompt_embeds形状: {prompt_embeds.shape}")
|
| 431 |
+
|
| 432 |
+
# # 检查CLIP编码器embeds最终输出是否包含NaN
|
| 433 |
+
# if torch.isnan(prompt_embeds).any():
|
| 434 |
+
# print(f"警告:CLIP编码器embeds最终输出包含NaN,位置: {torch.where(torch.isnan(prompt_embeds))}")
|
| 435 |
+
# print(f"CLIP编码器embeds最终输出统计 - min: {prompt_embeds.min().item()}, max: {prompt_embeds.max().item()}, mean: {prompt_embeds.mean().item()}")
|
| 436 |
+
|
| 437 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 438 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 439 |
+
# print(f"CLIP处理 - 最终pooled_prompt_embeds形状: {pooled_prompt_embeds.shape}")
|
| 440 |
+
|
| 441 |
+
# # 检查CLIP编码器pooled最终输出是否包含NaN
|
| 442 |
+
# if torch.isnan(pooled_prompt_embeds).any():
|
| 443 |
+
# print(f"警告:CLIP编码器pooled最终输出包含NaN,位置: {torch.where(torch.isnan(pooled_prompt_embeds))}")
|
| 444 |
+
# print(f"CLIP编码器pooled最终输出统计 - min: {pooled_prompt_embeds.min().item()}, max: {pooled_prompt_embeds.max().item()}, mean: {pooled_prompt_embeds.mean().item()}")
|
| 445 |
+
|
| 446 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 447 |
+
|
| 448 |
+
def encode_prompt(
|
| 449 |
+
self,
|
| 450 |
+
prompt: Union[str, List[str]],
|
| 451 |
+
prompt_2: Union[str, List[str]],
|
| 452 |
+
prompt_3: Union[str, List[str]],
|
| 453 |
+
device: Optional[torch.device] = None,
|
| 454 |
+
num_images_per_prompt: int = 1,
|
| 455 |
+
do_classifier_free_guidance: bool = True,
|
| 456 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 457 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 458 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 459 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 460 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 461 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 462 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 463 |
+
clip_skip: Optional[int] = None,
|
| 464 |
+
max_sequence_length: int = 256,
|
| 465 |
+
lora_scale: Optional[float] = None,
|
| 466 |
+
):
|
| 467 |
+
r"""
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 471 |
+
prompt to be encoded
|
| 472 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 473 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 474 |
+
used in all text-encoders
|
| 475 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 476 |
+
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 477 |
+
used in all text-encoders
|
| 478 |
+
device: (`torch.device`):
|
| 479 |
+
torch device
|
| 480 |
+
num_images_per_prompt (`int`):
|
| 481 |
+
number of images that should be generated per prompt
|
| 482 |
+
do_classifier_free_guidance (`bool`):
|
| 483 |
+
whether to use classifier free guidance or not
|
| 484 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 485 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 486 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 487 |
+
less than `1`).
|
| 488 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 489 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 490 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 491 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 492 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 493 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 494 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 495 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 496 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 497 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 498 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 499 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 500 |
+
argument.
|
| 501 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 502 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 503 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 504 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 505 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 506 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 507 |
+
input argument.
|
| 508 |
+
clip_skip (`int`, *optional*):
|
| 509 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 510 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 511 |
+
lora_scale (`float`, *optional*):
|
| 512 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 513 |
+
"""
|
| 514 |
+
device = device or self._execution_device
|
| 515 |
+
# print(f"encode_prompt - 开始处理提示编码")
|
| 516 |
+
# print(f"encode_prompt - device: {device}")
|
| 517 |
+
# print(f"encode_prompt - num_images_per_prompt: {num_images_per_prompt}")
|
| 518 |
+
# print(f"encode_prompt - do_classifier_free_guidance: {do_classifier_free_guidance}")
|
| 519 |
+
# print(f"encode_prompt - max_sequence_length: {max_sequence_length}")
|
| 520 |
+
|
| 521 |
+
# set lora scale so that monkey patched LoRA
|
| 522 |
+
# function of text encoder can correctly access it
|
| 523 |
+
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
| 524 |
+
self._lora_scale = lora_scale
|
| 525 |
+
|
| 526 |
+
# dynamically adjust the LoRA scale
|
| 527 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 528 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 529 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 530 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 531 |
+
|
| 532 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 533 |
+
if prompt is not None:
|
| 534 |
+
batch_size = len(prompt)
|
| 535 |
+
else:
|
| 536 |
+
batch_size = prompt_embeds.shape[0]
|
| 537 |
+
# print(f"encode_prompt - batch_size: {batch_size}")
|
| 538 |
+
|
| 539 |
+
if prompt_embeds is None:
|
| 540 |
+
prompt_2 = prompt_2 or prompt
|
| 541 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 542 |
+
|
| 543 |
+
prompt_3 = prompt_3 or prompt
|
| 544 |
+
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
| 545 |
+
|
| 546 |
+
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 547 |
+
prompt=prompt,
|
| 548 |
+
device=device,
|
| 549 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 550 |
+
clip_skip=clip_skip,
|
| 551 |
+
clip_model_index=0,
|
| 552 |
+
)
|
| 553 |
+
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 554 |
+
prompt=prompt_2,
|
| 555 |
+
device=device,
|
| 556 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 557 |
+
clip_skip=clip_skip,
|
| 558 |
+
clip_model_index=1,
|
| 559 |
+
)
|
| 560 |
+
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
| 561 |
+
|
| 562 |
+
t5_prompt_embed = self._get_t5_prompt_embeds(
|
| 563 |
+
prompt=prompt_3,
|
| 564 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 565 |
+
max_sequence_length=max_sequence_length,
|
| 566 |
+
device=device,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 570 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 571 |
+
)
|
| 572 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 573 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
| 574 |
+
|
| 575 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 576 |
+
negative_prompt = negative_prompt or ""
|
| 577 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
| 578 |
+
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
| 579 |
+
|
| 580 |
+
# normalize str to list
|
| 581 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 582 |
+
negative_prompt_2 = (
|
| 583 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
| 584 |
+
)
|
| 585 |
+
negative_prompt_3 = (
|
| 586 |
+
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 590 |
+
raise TypeError(
|
| 591 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 592 |
+
f" {type(prompt)}."
|
| 593 |
+
)
|
| 594 |
+
elif batch_size != len(negative_prompt):
|
| 595 |
+
raise ValueError(
|
| 596 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 597 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 598 |
+
" the batch size of `prompt`."
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 602 |
+
negative_prompt,
|
| 603 |
+
device=device,
|
| 604 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 605 |
+
clip_skip=None,
|
| 606 |
+
clip_model_index=0,
|
| 607 |
+
)
|
| 608 |
+
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 609 |
+
negative_prompt_2,
|
| 610 |
+
device=device,
|
| 611 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 612 |
+
clip_skip=None,
|
| 613 |
+
clip_model_index=1,
|
| 614 |
+
)
|
| 615 |
+
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
| 616 |
+
|
| 617 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
| 618 |
+
prompt=negative_prompt_3,
|
| 619 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 620 |
+
max_sequence_length=max_sequence_length,
|
| 621 |
+
device=device,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
| 625 |
+
negative_clip_prompt_embeds,
|
| 626 |
+
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
| 630 |
+
negative_pooled_prompt_embeds = torch.cat(
|
| 631 |
+
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
if self.text_encoder is not None:
|
| 635 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 636 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 637 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 638 |
+
|
| 639 |
+
if self.text_encoder_2 is not None:
|
| 640 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 641 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 642 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 643 |
+
|
| 644 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 645 |
+
|
| 646 |
+
def check_inputs(
|
| 647 |
+
self,
|
| 648 |
+
prompt,
|
| 649 |
+
prompt_2,
|
| 650 |
+
prompt_3,
|
| 651 |
+
height,
|
| 652 |
+
width,
|
| 653 |
+
negative_prompt=None,
|
| 654 |
+
negative_prompt_2=None,
|
| 655 |
+
negative_prompt_3=None,
|
| 656 |
+
prompt_embeds=None,
|
| 657 |
+
negative_prompt_embeds=None,
|
| 658 |
+
pooled_prompt_embeds=None,
|
| 659 |
+
negative_pooled_prompt_embeds=None,
|
| 660 |
+
callback_on_step_end_tensor_inputs=None,
|
| 661 |
+
max_sequence_length=None,
|
| 662 |
+
):
|
| 663 |
+
if (
|
| 664 |
+
height % (self.vae_scale_factor * self.patch_size) != 0
|
| 665 |
+
or width % (self.vae_scale_factor * self.patch_size) != 0
|
| 666 |
+
):
|
| 667 |
+
raise ValueError(
|
| 668 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
| 669 |
+
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 673 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 674 |
+
):
|
| 675 |
+
raise ValueError(
|
| 676 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
if prompt is not None and prompt_embeds is not None:
|
| 680 |
+
raise ValueError(
|
| 681 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 682 |
+
" only forward one of the two."
|
| 683 |
+
)
|
| 684 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 685 |
+
raise ValueError(
|
| 686 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 687 |
+
" only forward one of the two."
|
| 688 |
+
)
|
| 689 |
+
elif prompt_3 is not None and prompt_embeds is not None:
|
| 690 |
+
raise ValueError(
|
| 691 |
+
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 692 |
+
" only forward one of the two."
|
| 693 |
+
)
|
| 694 |
+
elif prompt is None and prompt_embeds is None:
|
| 695 |
+
raise ValueError(
|
| 696 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 697 |
+
)
|
| 698 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 699 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 700 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 701 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 702 |
+
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
| 703 |
+
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
| 704 |
+
|
| 705 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 706 |
+
raise ValueError(
|
| 707 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 708 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 709 |
+
)
|
| 710 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 711 |
+
raise ValueError(
|
| 712 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 713 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 714 |
+
)
|
| 715 |
+
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
| 716 |
+
raise ValueError(
|
| 717 |
+
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
| 718 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 722 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 723 |
+
raise ValueError(
|
| 724 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 725 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 726 |
+
f" {negative_prompt_embeds.shape}."
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 730 |
+
raise ValueError(
|
| 731 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 735 |
+
raise ValueError(
|
| 736 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 740 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 741 |
+
|
| 742 |
+
def prepare_latents(
|
| 743 |
+
self,
|
| 744 |
+
batch_size,
|
| 745 |
+
num_channels_latents,
|
| 746 |
+
height,
|
| 747 |
+
width,
|
| 748 |
+
dtype,
|
| 749 |
+
device,
|
| 750 |
+
generator,
|
| 751 |
+
latents=None,
|
| 752 |
+
):
|
| 753 |
+
if latents is not None:
|
| 754 |
+
return latents.to(device=device, dtype=dtype)
|
| 755 |
+
|
| 756 |
+
shape = (
|
| 757 |
+
batch_size,
|
| 758 |
+
num_channels_latents,
|
| 759 |
+
int(height) // self.vae_scale_factor,
|
| 760 |
+
int(width) // self.vae_scale_factor,
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 764 |
+
raise ValueError(
|
| 765 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 766 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 770 |
+
|
| 771 |
+
return latents
|
| 772 |
+
|
| 773 |
+
@property
|
| 774 |
+
def guidance_scale(self):
|
| 775 |
+
return self._guidance_scale
|
| 776 |
+
|
| 777 |
+
@property
|
| 778 |
+
def skip_guidance_layers(self):
|
| 779 |
+
return self._skip_guidance_layers
|
| 780 |
+
|
| 781 |
+
@property
|
| 782 |
+
def clip_skip(self):
|
| 783 |
+
return self._clip_skip
|
| 784 |
+
|
| 785 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 786 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 787 |
+
# corresponds to doing no classifier free guidance.
|
| 788 |
+
@property
|
| 789 |
+
def do_classifier_free_guidance(self):
|
| 790 |
+
return self._guidance_scale > 1
|
| 791 |
+
|
| 792 |
+
@property
|
| 793 |
+
def joint_attention_kwargs(self):
|
| 794 |
+
return self._joint_attention_kwargs
|
| 795 |
+
|
| 796 |
+
@property
|
| 797 |
+
def num_timesteps(self):
|
| 798 |
+
return self._num_timesteps
|
| 799 |
+
|
| 800 |
+
@property
|
| 801 |
+
def interrupt(self):
|
| 802 |
+
return self._interrupt
|
| 803 |
+
|
| 804 |
+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
|
| 805 |
+
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
|
| 806 |
+
"""Encodes the given image into a feature representation using a pre-trained image encoder.
|
| 807 |
+
|
| 808 |
+
Args:
|
| 809 |
+
image (`PipelineImageInput`):
|
| 810 |
+
Input image to be encoded.
|
| 811 |
+
device: (`torch.device`):
|
| 812 |
+
Torch device.
|
| 813 |
+
|
| 814 |
+
Returns:
|
| 815 |
+
`torch.Tensor`: The encoded image feature representation.
|
| 816 |
+
"""
|
| 817 |
+
if not isinstance(image, torch.Tensor):
|
| 818 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 819 |
+
|
| 820 |
+
image = image.to(device=device, dtype=self.dtype)
|
| 821 |
+
|
| 822 |
+
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 823 |
+
|
| 824 |
+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
|
| 825 |
+
def prepare_ip_adapter_image_embeds(
|
| 826 |
+
self,
|
| 827 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 828 |
+
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
| 829 |
+
device: Optional[torch.device] = None,
|
| 830 |
+
num_images_per_prompt: int = 1,
|
| 831 |
+
do_classifier_free_guidance: bool = True,
|
| 832 |
+
) -> torch.Tensor:
|
| 833 |
+
"""Prepares image embeddings for use in the IP-Adapter.
|
| 834 |
+
|
| 835 |
+
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
|
| 836 |
+
|
| 837 |
+
Args:
|
| 838 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 839 |
+
The input image to extract features from for IP-Adapter.
|
| 840 |
+
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
| 841 |
+
Precomputed image embeddings.
|
| 842 |
+
device: (`torch.device`, *optional*):
|
| 843 |
+
Torch device.
|
| 844 |
+
num_images_per_prompt (`int`, defaults to 1):
|
| 845 |
+
Number of images that should be generated per prompt.
|
| 846 |
+
do_classifier_free_guidance (`bool`, defaults to True):
|
| 847 |
+
Whether to use classifier free guidance or not.
|
| 848 |
+
"""
|
| 849 |
+
device = device or self._execution_device
|
| 850 |
+
|
| 851 |
+
if ip_adapter_image_embeds is not None:
|
| 852 |
+
if do_classifier_free_guidance:
|
| 853 |
+
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
|
| 854 |
+
else:
|
| 855 |
+
single_image_embeds = ip_adapter_image_embeds
|
| 856 |
+
elif ip_adapter_image is not None:
|
| 857 |
+
single_image_embeds = self.encode_image(ip_adapter_image, device)
|
| 858 |
+
if do_classifier_free_guidance:
|
| 859 |
+
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
|
| 860 |
+
else:
|
| 861 |
+
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
|
| 862 |
+
|
| 863 |
+
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 864 |
+
|
| 865 |
+
if do_classifier_free_guidance:
|
| 866 |
+
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
| 867 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
|
| 868 |
+
|
| 869 |
+
return image_embeds.to(device=device)
|
| 870 |
+
|
| 871 |
+
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
| 872 |
+
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
|
| 873 |
+
logger.warning(
|
| 874 |
+
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
|
| 875 |
+
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
|
| 876 |
+
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
super().enable_sequential_cpu_offload(*args, **kwargs)
|
| 880 |
+
|
| 881 |
+
@torch.no_grad()
|
| 882 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 883 |
+
def __call__(
|
| 884 |
+
self,
|
| 885 |
+
prompt: Union[str, List[str]] = None,
|
| 886 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 887 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
| 888 |
+
height: Optional[int] = None,
|
| 889 |
+
width: Optional[int] = None,
|
| 890 |
+
num_inference_steps: int = 28,
|
| 891 |
+
sigmas: Optional[List[float]] = None,
|
| 892 |
+
guidance_scale: float = 7.0,
|
| 893 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 894 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 895 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 896 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 897 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 898 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 899 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 900 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 901 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 902 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 903 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 904 |
+
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
| 905 |
+
output_type: Optional[str] = "pil",
|
| 906 |
+
return_dict: bool = True,
|
| 907 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 908 |
+
clip_skip: Optional[int] = None,
|
| 909 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 910 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 911 |
+
max_sequence_length: int = 256,
|
| 912 |
+
skip_guidance_layers: List[int] = None,
|
| 913 |
+
skip_layer_guidance_scale: float = 2.8,
|
| 914 |
+
skip_layer_guidance_stop: float = 0.2,
|
| 915 |
+
skip_layer_guidance_start: float = 0.01,
|
| 916 |
+
mu: Optional[float] = None,
|
| 917 |
+
model: Optional[Any] = None, # 添加 model 参数,默认为 None
|
| 918 |
+
):
|
| 919 |
+
r"""
|
| 920 |
+
Function invoked when calling the pipeline for generation.
|
| 921 |
+
|
| 922 |
+
Args:
|
| 923 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 924 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 925 |
+
instead.
|
| 926 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 927 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 928 |
+
will be used instead
|
| 929 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 930 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 931 |
+
will be used instead
|
| 932 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 933 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 934 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 935 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 936 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 937 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 938 |
+
expense of slower inference.
|
| 939 |
+
sigmas (`List[float]`, *optional*):
|
| 940 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 941 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 942 |
+
will be used.
|
| 943 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 944 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 945 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 946 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 947 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 948 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 949 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 950 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 951 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 952 |
+
less than `1`).
|
| 953 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 954 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 955 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
| 956 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 957 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 958 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
| 959 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 960 |
+
The number of images to generate per prompt.
|
| 961 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 962 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 963 |
+
to make generation deterministic.
|
| 964 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 965 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 966 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 967 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 968 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 969 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 970 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 971 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 972 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 973 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 974 |
+
argument.
|
| 975 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 976 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 977 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 978 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 979 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 980 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 981 |
+
input argument.
|
| 982 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 983 |
+
Optional image input to work with IP Adapters.
|
| 984 |
+
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
| 985 |
+
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
|
| 986 |
+
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
|
| 987 |
+
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 988 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 989 |
+
The output format of the generate image. Choose between
|
| 990 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 991 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 992 |
+
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
| 993 |
+
a plain tuple.
|
| 994 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 995 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 996 |
+
`self.processor` in
|
| 997 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 998 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 999 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 1000 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 1001 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 1002 |
+
`callback_on_step_end_tensor_inputs`.
|
| 1003 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1004 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 1005 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 1006 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 1007 |
+
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
| 1008 |
+
skip_guidance_layers (`List[int]`, *optional*):
|
| 1009 |
+
A list of integers that specify layers to skip during guidance. If not provided, all layers will be
|
| 1010 |
+
used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
|
| 1011 |
+
Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
|
| 1012 |
+
skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
|
| 1013 |
+
`skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
|
| 1014 |
+
with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
|
| 1015 |
+
with a scale of `1`.
|
| 1016 |
+
skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 1017 |
+
`skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
|
| 1018 |
+
`skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
|
| 1019 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
|
| 1020 |
+
skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 1021 |
+
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
|
| 1022 |
+
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
|
| 1023 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
|
| 1024 |
+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
|
| 1025 |
+
model (`SD3WithRectifiedNoise`, *optional*):
|
| 1026 |
+
Optional SD3WithRectifiedNoise model for enhanced noise prediction. If provided, will be used instead of
|
| 1027 |
+
the default transformer for denoising. The model should be an instance of SD3WithRectifiedNoise class.
|
| 1028 |
+
|
| 1029 |
+
Examples:
|
| 1030 |
+
|
| 1031 |
+
Returns:
|
| 1032 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
| 1033 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
| 1034 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 1035 |
+
"""
|
| 1036 |
+
|
| 1037 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 1038 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 1039 |
+
#height=512
|
| 1040 |
+
#width=512
|
| 1041 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1042 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1043 |
+
|
| 1044 |
+
# 1. Check inputs. Raise error if not correct
|
| 1045 |
+
self.check_inputs(
|
| 1046 |
+
prompt,
|
| 1047 |
+
prompt_2,
|
| 1048 |
+
prompt_3,
|
| 1049 |
+
height,
|
| 1050 |
+
width,
|
| 1051 |
+
negative_prompt=negative_prompt,
|
| 1052 |
+
negative_prompt_2=negative_prompt_2,
|
| 1053 |
+
negative_prompt_3=negative_prompt_3,
|
| 1054 |
+
prompt_embeds=prompt_embeds,
|
| 1055 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1056 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1057 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1058 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1059 |
+
max_sequence_length=max_sequence_length,
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
self._guidance_scale = guidance_scale
|
| 1063 |
+
self._skip_layer_guidance_scale = skip_layer_guidance_scale
|
| 1064 |
+
self._clip_skip = clip_skip
|
| 1065 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 1066 |
+
self._interrupt = False
|
| 1067 |
+
|
| 1068 |
+
# 2. Define call parameters
|
| 1069 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1070 |
+
batch_size = 1
|
| 1071 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1072 |
+
batch_size = len(prompt)
|
| 1073 |
+
else:
|
| 1074 |
+
batch_size = prompt_embeds.shape[0]
|
| 1075 |
+
|
| 1076 |
+
device = self._execution_device
|
| 1077 |
+
|
| 1078 |
+
lora_scale = (
|
| 1079 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 1080 |
+
)
|
| 1081 |
+
(
|
| 1082 |
+
prompt_embeds,
|
| 1083 |
+
negative_prompt_embeds,
|
| 1084 |
+
pooled_prompt_embeds,
|
| 1085 |
+
negative_pooled_prompt_embeds,
|
| 1086 |
+
) = self.encode_prompt(
|
| 1087 |
+
prompt=prompt,
|
| 1088 |
+
prompt_2=prompt_2,
|
| 1089 |
+
prompt_3=prompt_3,
|
| 1090 |
+
negative_prompt=negative_prompt,
|
| 1091 |
+
negative_prompt_2=negative_prompt_2,
|
| 1092 |
+
negative_prompt_3=negative_prompt_3,
|
| 1093 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1094 |
+
prompt_embeds=prompt_embeds,
|
| 1095 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1096 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1097 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1098 |
+
device=device,
|
| 1099 |
+
clip_skip=self.clip_skip,
|
| 1100 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1101 |
+
max_sequence_length=max_sequence_length,
|
| 1102 |
+
lora_scale=lora_scale,
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
if self.do_classifier_free_guidance:
|
| 1106 |
+
if skip_guidance_layers is not None:
|
| 1107 |
+
original_prompt_embeds = prompt_embeds
|
| 1108 |
+
original_pooled_prompt_embeds = pooled_prompt_embeds
|
| 1109 |
+
# print("检测negative_prompt_embeds",negative_prompt_embeds)
|
| 1110 |
+
#print("检测pooled_prompt_embeds",prompt_embeds)
|
| 1111 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1112 |
+
|
| 1113 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 1114 |
+
|
| 1115 |
+
# 4. Prepare latent variables
|
| 1116 |
+
num_channels_latents = self.transformer.config.in_channels
|
| 1117 |
+
latents = self.prepare_latents(
|
| 1118 |
+
batch_size * num_images_per_prompt,
|
| 1119 |
+
num_channels_latents,
|
| 1120 |
+
height,
|
| 1121 |
+
width,
|
| 1122 |
+
prompt_embeds.dtype,
|
| 1123 |
+
device,
|
| 1124 |
+
generator,
|
| 1125 |
+
latents,
|
| 1126 |
+
)
|
| 1127 |
+
|
| 1128 |
+
# 5. Prepare timesteps
|
| 1129 |
+
scheduler_kwargs = {}
|
| 1130 |
+
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
|
| 1131 |
+
_, _, height, width = latents.shape
|
| 1132 |
+
image_seq_len = (height // self.transformer.config.patch_size) * (
|
| 1133 |
+
width // self.transformer.config.patch_size
|
| 1134 |
+
)
|
| 1135 |
+
mu = calculate_shift(
|
| 1136 |
+
image_seq_len,
|
| 1137 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1138 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1139 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1140 |
+
self.scheduler.config.get("max_shift", 1.16),
|
| 1141 |
+
)
|
| 1142 |
+
scheduler_kwargs["mu"] = mu
|
| 1143 |
+
elif mu is not None:
|
| 1144 |
+
scheduler_kwargs["mu"] = mu
|
| 1145 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1146 |
+
self.scheduler,
|
| 1147 |
+
num_inference_steps,
|
| 1148 |
+
device,
|
| 1149 |
+
sigmas=sigmas,
|
| 1150 |
+
**scheduler_kwargs,
|
| 1151 |
+
)
|
| 1152 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1153 |
+
self._num_timesteps = len(timesteps)
|
| 1154 |
+
|
| 1155 |
+
# 6. Prepare image embeddings
|
| 1156 |
+
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
|
| 1157 |
+
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1158 |
+
ip_adapter_image,
|
| 1159 |
+
ip_adapter_image_embeds,
|
| 1160 |
+
device,
|
| 1161 |
+
batch_size * num_images_per_prompt,
|
| 1162 |
+
self.do_classifier_free_guidance,
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
if self.joint_attention_kwargs is None:
|
| 1166 |
+
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
|
| 1167 |
+
else:
|
| 1168 |
+
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
|
| 1169 |
+
|
| 1170 |
+
# 7. Denoising loop
|
| 1171 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1172 |
+
for i, t in enumerate(timesteps):
|
| 1173 |
+
if self.interrupt:
|
| 1174 |
+
continue
|
| 1175 |
+
|
| 1176 |
+
# expand the latents if we are doing classifier free guidance
|
| 1177 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1178 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1179 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1180 |
+
|
| 1181 |
+
# Check for NaN in latents before transformer
|
| 1182 |
+
if torch.isnan(latents).any():
|
| 1183 |
+
# print(f"NaN detected in latents at step {i}")
|
| 1184 |
+
# print(f"NaN locations: {torch.where(torch.isnan(latents))}")
|
| 1185 |
+
break
|
| 1186 |
+
# 优先使用传入的 model 参数,其次使用类属性中的 model,最后回退到默认 transformer
|
| 1187 |
+
effective_model = model or getattr(self, 'model', None) or self.transformer
|
| 1188 |
+
|
| 1189 |
+
if hasattr(effective_model, '__call__') and callable(effective_model):
|
| 1190 |
+
# 使用有效的模型进行预测
|
| 1191 |
+
# 检查模型是否支持 skip_layers 参数
|
| 1192 |
+
if hasattr(effective_model, 'forward') and 'skip_layers' in inspect.signature(effective_model.forward).parameters:
|
| 1193 |
+
noise_pred_output = effective_model(
|
| 1194 |
+
hidden_states=latent_model_input,
|
| 1195 |
+
timestep=timestep,
|
| 1196 |
+
encoder_hidden_states=prompt_embeds,
|
| 1197 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1198 |
+
#joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1199 |
+
return_dict=False,
|
| 1200 |
+
#skip_layers=skip_guidance_layers if skip_guidance_layers is not None else None,
|
| 1201 |
+
)
|
| 1202 |
+
#print(f"effective_model type: {type(effective_model)}")
|
| 1203 |
+
#print(f"noise_pred_output: {noise_pred_output}")
|
| 1204 |
+
else:
|
| 1205 |
+
# SD3WithRectifiedNoise 不支持 skip_layers 参数
|
| 1206 |
+
noise_pred_output = effective_model(
|
| 1207 |
+
hidden_states=latent_model_input,
|
| 1208 |
+
timestep=timestep,
|
| 1209 |
+
encoder_hidden_states=prompt_embeds,
|
| 1210 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1211 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1212 |
+
return_dict=False,
|
| 1213 |
+
)
|
| 1214 |
+
# 正确处理 SD3WithRectifiedNoise 模型的输出
|
| 1215 |
+
# SD3WithRectifiedNoise 返回 (final_output, mean_out, var_out) 元组,我们只需要第一个元素
|
| 1216 |
+
# 如果返回的是字典,则使用 "sample" 键
|
| 1217 |
+
if isinstance(noise_pred_output, dict):
|
| 1218 |
+
noise_pred = noise_pred_output["sample"]
|
| 1219 |
+
elif isinstance(noise_pred_output, tuple):
|
| 1220 |
+
# 对于 SD3WithRectifiedNoise,取第一个输出作为主要预测结果
|
| 1221 |
+
noise_pred = noise_pred_output[0]
|
| 1222 |
+
else:
|
| 1223 |
+
noise_pred = noise_pred_output
|
| 1224 |
+
else:
|
| 1225 |
+
# 使用默认的 transformer 进行预测
|
| 1226 |
+
noise_pred = self.transformer(
|
| 1227 |
+
hidden_states=latent_model_input,
|
| 1228 |
+
timestep=timestep,
|
| 1229 |
+
encoder_hidden_states=prompt_embeds,
|
| 1230 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1231 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1232 |
+
return_dict=False,
|
| 1233 |
+
)[0]
|
| 1234 |
+
|
| 1235 |
+
# Check for NaN in noise prediction
|
| 1236 |
+
if torch.isnan(noise_pred).any():
|
| 1237 |
+
# print(f"NaN detected in noise_pred at step {i}")
|
| 1238 |
+
# print(f"NaN locations: {torch.where(torch.isnan(noise_pred))}")
|
| 1239 |
+
# print(f"noise_pred stats - min: {noise_pred.min().item()}, max: {noise_pred.max().item()}, mean: {noise_pred.mean().item()}")
|
| 1240 |
+
break
|
| 1241 |
+
|
| 1242 |
+
# perform guidance
|
| 1243 |
+
if self.do_classifier_free_guidance:
|
| 1244 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1245 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1246 |
+
|
| 1247 |
+
# Check for NaN after guidance
|
| 1248 |
+
if torch.isnan(noise_pred).any():
|
| 1249 |
+
# print(f"NaN detected in noise_pred after guidance at step {i}")
|
| 1250 |
+
# print(f"noise_pred_uncond stats - min: {noise_pred_uncond.min().item()}, max: {noise_pred_uncond.max().item()}, mean: {noise_pred_uncond.mean().item()}")
|
| 1251 |
+
# print(f"noise_pred_text stats - min: {noise_pred_text.min().item()}, max: {noise_pred_text.max().item()}, mean: {noise_pred_text.mean().item()}")
|
| 1252 |
+
# print(f"guidance_scale: {self.guidance_scale}")
|
| 1253 |
+
break
|
| 1254 |
+
|
| 1255 |
+
should_skip_layers = (
|
| 1256 |
+
True
|
| 1257 |
+
if i > num_inference_steps * skip_layer_guidance_start
|
| 1258 |
+
and i < num_inference_steps * skip_layer_guidance_stop
|
| 1259 |
+
else False
|
| 1260 |
+
)
|
| 1261 |
+
if skip_guidance_layers is not None and should_skip_layers:
|
| 1262 |
+
timestep = t.expand(latents.shape[0])
|
| 1263 |
+
latent_model_input = latents
|
| 1264 |
+
# 修改 skip_guidance_layers 部分的 transformer 调用逻辑以支持 SD3WithRectifiedNoise 模型
|
| 1265 |
+
# 优先使用传入的 model 参数,其次使用类属性中的 model,最后回退到默认 transformer
|
| 1266 |
+
effective_model = model or getattr(self, 'model', None) or self.transformer
|
| 1267 |
+
|
| 1268 |
+
if hasattr(effective_model, '__call__') and callable(effective_model):
|
| 1269 |
+
# 使用有效的模型进行预测
|
| 1270 |
+
# 检查模型是否支持 skip_layers 参数
|
| 1271 |
+
if hasattr(effective_model, 'forward') and 'skip_layers' in inspect.signature(effective_model.forward).parameters:
|
| 1272 |
+
noise_pred_skip_output = effective_model(
|
| 1273 |
+
hidden_states=latent_model_input,
|
| 1274 |
+
timestep=timestep,
|
| 1275 |
+
encoder_hidden_states=original_prompt_embeds,
|
| 1276 |
+
pooled_projections=original_pooled_prompt_embeds,
|
| 1277 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1278 |
+
return_dict=False,
|
| 1279 |
+
skip_layers=skip_guidance_layers,
|
| 1280 |
+
)
|
| 1281 |
+
else:
|
| 1282 |
+
# SD3WithRectifiedNoise 不支持 skip_layers 参数
|
| 1283 |
+
noise_pred_skip_output = effective_model(
|
| 1284 |
+
hidden_states=latent_model_input,
|
| 1285 |
+
timestep=timestep,
|
| 1286 |
+
encoder_hidden_states=original_prompt_embeds,
|
| 1287 |
+
pooled_projections=original_pooled_prompt_embeds,
|
| 1288 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1289 |
+
return_dict=False,
|
| 1290 |
+
)
|
| 1291 |
+
# 正确处理 SD3WithRectifiedNoise 模型的输出
|
| 1292 |
+
# SD3WithRectifiedNoise 返回 (final_output, mean_out, var_out) 元组,我们只需要第一个元素
|
| 1293 |
+
# 如果返回的是字典,则使用 "sample" 键
|
| 1294 |
+
if isinstance(noise_pred_skip_output, dict):
|
| 1295 |
+
noise_pred_skip_layers = noise_pred_skip_output["sample"]
|
| 1296 |
+
elif isinstance(noise_pred_skip_output, tuple):
|
| 1297 |
+
# 对于 SD3WithRectifiedNoise,取第一个输出作为主要预测结果
|
| 1298 |
+
noise_pred_skip_layers = noise_pred_skip_output[0]
|
| 1299 |
+
else:
|
| 1300 |
+
noise_pred_skip_layers = noise_pred_skip_output
|
| 1301 |
+
else:
|
| 1302 |
+
# 使用默认的 transformer 进行预测
|
| 1303 |
+
noise_pred_skip_layers = self.transformer(
|
| 1304 |
+
hidden_states=latent_model_input,
|
| 1305 |
+
timestep=timestep,
|
| 1306 |
+
encoder_hidden_states=original_prompt_embeds,
|
| 1307 |
+
pooled_projections=original_pooled_prompt_embeds,
|
| 1308 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1309 |
+
return_dict=False,
|
| 1310 |
+
skip_layers=skip_guidance_layers,
|
| 1311 |
+
)[0]
|
| 1312 |
+
|
| 1313 |
+
# Check for NaN in skip layers noise prediction
|
| 1314 |
+
if torch.isnan(noise_pred_skip_layers).any():
|
| 1315 |
+
# print(f"NaN detected in noise_pred_skip_layers at step {i}")
|
| 1316 |
+
# print(f"NaN locations: {torch.where(torch.isnan(noise_pred_skip_layers))}")
|
| 1317 |
+
break
|
| 1318 |
+
|
| 1319 |
+
noise_pred = (
|
| 1320 |
+
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
|
| 1321 |
+
)
|
| 1322 |
+
|
| 1323 |
+
# Check for NaN after skip layer guidance
|
| 1324 |
+
if torch.isnan(noise_pred).any():
|
| 1325 |
+
# print(f"NaN detected in noise_pred after skip layer guidance at step {i}")
|
| 1326 |
+
break
|
| 1327 |
+
|
| 1328 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1329 |
+
latents_dtype = latents.dtype
|
| 1330 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1331 |
+
|
| 1332 |
+
# Check for NaN in latents after scheduler step
|
| 1333 |
+
if torch.isnan(latents).any():
|
| 1334 |
+
# print(f"NaN detected in latents after scheduler step at step {i}")
|
| 1335 |
+
# print(f"noise_pred stats - min: {noise_pred.min().item()}, max: {noise_pred.max().item()}, mean: {noise_pred.mean().item()}")
|
| 1336 |
+
break
|
| 1337 |
+
|
| 1338 |
+
# Print intermediate results
|
| 1339 |
+
# print(f"Step {i+1}/{num_inference_steps}, Timestep: {t.item():.2f}, Latents mean: {latents.mean().item():.6f}, Latents std: {latents.std().item():.6f}")
|
| 1340 |
+
|
| 1341 |
+
if latents.dtype != latents_dtype:
|
| 1342 |
+
if torch.backends.mps.is_available():
|
| 1343 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1344 |
+
latents = latents.to(latents_dtype)
|
| 1345 |
+
|
| 1346 |
+
if callback_on_step_end is not None:
|
| 1347 |
+
callback_kwargs = {}
|
| 1348 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1349 |
+
callback_kwargs[k] = locals()[k]
|
| 1350 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1351 |
+
|
| 1352 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1353 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1354 |
+
pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
|
| 1355 |
+
|
| 1356 |
+
# call the callback, if provided
|
| 1357 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1358 |
+
progress_bar.update()
|
| 1359 |
+
|
| 1360 |
+
if XLA_AVAILABLE:
|
| 1361 |
+
xm.mark_step()
|
| 1362 |
+
|
| 1363 |
+
if output_type == "latent":
|
| 1364 |
+
image = latents
|
| 1365 |
+
|
| 1366 |
+
else:
|
| 1367 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1368 |
+
|
| 1369 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1370 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1371 |
+
|
| 1372 |
+
# Offload all models
|
| 1373 |
+
self.maybe_free_model_hooks()
|
| 1374 |
+
|
| 1375 |
+
if not return_dict:
|
| 1376 |
+
return (image,)
|
| 1377 |
+
|
| 1378 |
+
return StableDiffusion3PipelineOutput(images=image)
|
rectified-noise-batch-2/checkpoint-100000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
rectified-noise-batch-2/checkpoint-120000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
rectified-noise-batch-2/checkpoint-140000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
rectified-noise-batch-2/checkpoint-160000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
rectified-noise-batch-2/checkpoint-180000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
rectified-noise-batch-2/checkpoint-200000/sit_weights/sit_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_sit_layers": 1,
|
| 3 |
+
"hidden_size": 4096,
|
| 4 |
+
"input_dim": 16,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"intermediate_size": 16384,
|
| 7 |
+
"model_type": "rectified_noise",
|
| 8 |
+
"architecture": "SIT",
|
| 9 |
+
"version": "1.0"
|
| 10 |
+
}
|
run_sd3_lora_rn_pair_sampling.sh
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
| 4 |
+
|
| 5 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 6 |
+
LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
|
| 7 |
+
RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
|
| 8 |
+
|
| 9 |
+
CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
|
| 10 |
+
SAMPLE_DIR="./sd3_lora_rn_pair_samples"
|
| 11 |
+
|
| 12 |
+
NUM_INFERENCE_STEPS=40
|
| 13 |
+
GUIDANCE_SCALE=7.0
|
| 14 |
+
HEIGHT=512
|
| 15 |
+
WIDTH=512
|
| 16 |
+
PER_PROC_BATCH_SIZE=1
|
| 17 |
+
IMAGES_PER_CAPTION=1
|
| 18 |
+
MAX_SAMPLES=500
|
| 19 |
+
GLOBAL_SEED=42
|
| 20 |
+
MIXED_PRECISION="fp16"
|
| 21 |
+
NUM_SIT_LAYERS=1
|
| 22 |
+
|
| 23 |
+
ARGS=(
|
| 24 |
+
--pretrained_model_name_or_path "$PRETRAINED_MODEL"
|
| 25 |
+
--captions_jsonl "$CAPTIONS_JSONL"
|
| 26 |
+
--sample_dir "$SAMPLE_DIR"
|
| 27 |
+
--num_inference_steps $NUM_INFERENCE_STEPS
|
| 28 |
+
--guidance_scale $GUIDANCE_SCALE
|
| 29 |
+
--height $HEIGHT
|
| 30 |
+
--width $WIDTH
|
| 31 |
+
--per_proc_batch_size $PER_PROC_BATCH_SIZE
|
| 32 |
+
--images_per_caption $IMAGES_PER_CAPTION
|
| 33 |
+
--max_samples $MAX_SAMPLES
|
| 34 |
+
--global_seed $GLOBAL_SEED
|
| 35 |
+
--num_sit_layers $NUM_SIT_LAYERS
|
| 36 |
+
--mixed_precision $MIXED_PRECISION
|
| 37 |
+
--rectified_weights "$RECTIFIED_WEIGHTS"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if [ -n "$LORA_PATH" ]; then
|
| 41 |
+
ARGS+=(--lora_path "$LORA_PATH")
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
torchrun --nproc_per_node=4 --master_port=25923 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage lora
|
| 45 |
+
torchrun --nproc_per_node=4 --master_port=25924 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage rn
|
| 46 |
+
torchrun --nproc_per_node=4 --master_port=25925 sample_sd3_lora_rn_pair_ddp.py "${ARGS[@]}" --stage pair
|
| 47 |
+
|
| 48 |
+
echo "Sampling done. Output at: $SAMPLE_DIR"
|
| 49 |
+
# nohup bash run_sd3_lora_rn_pair_sampling.sh > run_sd3_lora_rn_pair_sampling.log 2>&1 &
|
| 50 |
+
|
run_sd3_lora_sampling.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
run_sd3_lora_sampling.sh
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# SD3 LoRA模型采样脚本
|
| 4 |
+
# 使用JSONL文件进行采样的示例脚本
|
| 5 |
+
# 使用方法: ./run_sd3_lora_sampling.sh
|
| 6 |
+
|
| 7 |
+
# 设置GPU设备
|
| 8 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3" # 使用4个GPU(0,1,2,3)
|
| 9 |
+
|
| 10 |
+
# 内存优化设置
|
| 11 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 12 |
+
|
| 13 |
+
# 模型和LoRA路径配置
|
| 14 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 15 |
+
# LoRA checkpoint路径 - 使用accelerator checkpoint目录
|
| 16 |
+
LORA_CHECKPOINT_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
|
| 17 |
+
# LoRA rank(必须与训练时一致)
|
| 18 |
+
LORA_RANK=32
|
| 19 |
+
|
| 20 |
+
# 采样参数配置
|
| 21 |
+
NUM_INFERENCE_STEPS=40
|
| 22 |
+
GUIDANCE_SCALE=7.0
|
| 23 |
+
HEIGHT=512
|
| 24 |
+
WIDTH=512
|
| 25 |
+
PER_PROC_BATCH_SIZE=1 # 每个GPU的批大小,建议从1开始(SD3模型很大,保持为1以避免内存溢出)
|
| 26 |
+
MAX_SAMPLES=30000 # 最大采样数量限制
|
| 27 |
+
|
| 28 |
+
# 提示词配置
|
| 29 |
+
#NEGATIVE_PROMPT="blurry, low quality, distorted, ugly, bad anatomy"
|
| 30 |
+
|
| 31 |
+
# Caption文件配置
|
| 32 |
+
CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl" # JSONL文件路径
|
| 33 |
+
IMAGES_PER_CAPTION=3 # 每个caption生成几张图片
|
| 34 |
+
|
| 35 |
+
# 输出配置
|
| 36 |
+
SAMPLE_DIR="./sd3_lora_samples_3w"
|
| 37 |
+
GLOBAL_SEED=42
|
| 38 |
+
|
| 39 |
+
echo "开始SD3 LoRA采样(从checkpoint加载)..."
|
| 40 |
+
echo "模型: $PRETRAINED_MODEL"
|
| 41 |
+
echo "LoRA Checkpoint路径: $LORA_CHECKPOINT_PATH"
|
| 42 |
+
echo "LoRA Rank: $LORA_RANK"
|
| 43 |
+
echo "Caption文件: $CAPTIONS_JSONL"
|
| 44 |
+
echo "每个caption生成图片数: $IMAGES_PER_CAPTION"
|
| 45 |
+
echo "图像尺寸: ${HEIGHT}x${WIDTH}"
|
| 46 |
+
echo "引导尺度: $GUIDANCE_SCALE"
|
| 47 |
+
echo "推理步数: $NUM_INFERENCE_STEPS"
|
| 48 |
+
|
| 49 |
+
# 检查必要文件
|
| 50 |
+
if [ ! -f "$CAPTIONS_JSONL" ]; then
|
| 51 |
+
echo "错误: Caption文件 $CAPTIONS_JSONL 不存在"
|
| 52 |
+
exit 1
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
if [ ! -d "$LORA_CHECKPOINT_PATH" ]; then
|
| 56 |
+
echo "错误: LoRA checkpoint目录 $LORA_CHECKPOINT_PATH 不存在"
|
| 57 |
+
exit 1
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
# 构建命令参数数组
|
| 61 |
+
CMD_ARGS=(
|
| 62 |
+
"--pretrained_model_name_or_path=$PRETRAINED_MODEL"
|
| 63 |
+
"--lora_checkpoint_path=$LORA_CHECKPOINT_PATH"
|
| 64 |
+
"--lora_rank=$LORA_RANK"
|
| 65 |
+
"--num_inference_steps=$NUM_INFERENCE_STEPS"
|
| 66 |
+
"--guidance_scale=$GUIDANCE_SCALE"
|
| 67 |
+
"--height=$HEIGHT"
|
| 68 |
+
"--width=$WIDTH"
|
| 69 |
+
"--per_proc_batch_size=$PER_PROC_BATCH_SIZE"
|
| 70 |
+
"--captions_jsonl=$CAPTIONS_JSONL"
|
| 71 |
+
"--images_per_caption=$IMAGES_PER_CAPTION"
|
| 72 |
+
"--sample_dir=$SAMPLE_DIR"
|
| 73 |
+
"--global_seed=$GLOBAL_SEED"
|
| 74 |
+
#"--max_samples=$MAX_SAMPLES"
|
| 75 |
+
"--mixed_precision=fp16" # 使用 fp16 以减少内存占用
|
| 76 |
+
# 注意:在多GPU环境下,CPU offload会被代码自动禁用(不支持分布式)
|
| 77 |
+
# 代码会自动检测world_size > 1并禁用CPU offload
|
| 78 |
+
"--enable_cpu_offload"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# # 添加负面提示词参数(如果存在)
|
| 82 |
+
# if [ ! -z "$NEGATIVE_PROMPT" ]; then
|
| 83 |
+
# CMD_ARGS+=("--negative_prompt" "$NEGATIVE_PROMPT")
|
| 84 |
+
# fi
|
| 85 |
+
|
| 86 |
+
# 运行分布式采样
|
| 87 |
+
torchrun --nproc_per_node=4 --master_port=25900 sample_sd3_lora_checkpoint_ddp.py "${CMD_ARGS[@]}"
|
| 88 |
+
|
| 89 |
+
echo "采样完成!"
|
| 90 |
+
echo "结果保存在: $SAMPLE_DIR"
|
| 91 |
+
echo "Caption信息保存在: $SAMPLE_DIR/*/captions.txt"
|
| 92 |
+
echo "NPZ文件已生成用于FID评估"
|
| 93 |
+
|
| 94 |
+
# nohup bash run_sd3_lora_sampling.sh > run_sd3_lora_sampling.log 2>&1 &
|
run_sd3_rectified_sampling.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# 分布式采样:指定 LoRA 与 Rectified(SIT) 权重
|
| 4 |
+
|
| 5 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
| 6 |
+
|
| 7 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 8 |
+
LOCAL_PIPELINE_PATH="/gemini/space/gzy_new/models/Sida/pipeline_stable_diffusion_3.py"
|
| 9 |
+
LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
|
| 10 |
+
RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
|
| 11 |
+
|
| 12 |
+
CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
|
| 13 |
+
SAMPLE_DIR="./sd3_rectified_samples_batch2_220000"
|
| 14 |
+
|
| 15 |
+
NUM_INFERENCE_STEPS=40
|
| 16 |
+
GUIDANCE_SCALE=7.0
|
| 17 |
+
HEIGHT=512
|
| 18 |
+
WIDTH=512
|
| 19 |
+
PER_PROC_BATCH_SIZE=32
|
| 20 |
+
IMAGES_PER_CAPTION=3
|
| 21 |
+
MAX_SAMPLES=30000
|
| 22 |
+
GLOBAL_SEED=42
|
| 23 |
+
MIXED_PRECISION="fp16" # no / fp16 / bf16
|
| 24 |
+
NUM_SIT_LAYERS=1 # 需与训练一致
|
| 25 |
+
|
| 26 |
+
ARGS=(
|
| 27 |
+
--pretrained_model_name_or_path "$PRETRAINED_MODEL"
|
| 28 |
+
--captions_jsonl "$CAPTIONS_JSONL"
|
| 29 |
+
--sample_dir "$SAMPLE_DIR"
|
| 30 |
+
--num_inference_steps $NUM_INFERENCE_STEPS
|
| 31 |
+
--guidance_scale $GUIDANCE_SCALE
|
| 32 |
+
--height $HEIGHT
|
| 33 |
+
--width $WIDTH
|
| 34 |
+
--per_proc_batch_size $PER_PROC_BATCH_SIZE
|
| 35 |
+
--images_per_caption $IMAGES_PER_CAPTION
|
| 36 |
+
--max_samples $MAX_SAMPLES
|
| 37 |
+
--global_seed $GLOBAL_SEED
|
| 38 |
+
--num_sit_layers $NUM_SIT_LAYERS
|
| 39 |
+
--mixed_precision $MIXED_PRECISION
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if [ -n "$LORA_PATH" ]; then
|
| 43 |
+
ARGS+=(--lora_path "$LORA_PATH")
|
| 44 |
+
fi
|
| 45 |
+
|
| 46 |
+
if [ -n "$RECTIFIED_WEIGHTS" ]; then
|
| 47 |
+
ARGS+=(--rectified_weights "$RECTIFIED_WEIGHTS")
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
torchrun --nproc_per_node=4 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"
|
| 51 |
+
|
| 52 |
+
echo "Sampling done. Output at: $SAMPLE_DIR"
|
| 53 |
+
# nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
|
| 54 |
+
|
| 55 |
+
|
run_sd3_rectified_sampling_old.sh
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
# 分布式采样:指定 LoRA 与 Rectified(SIT) 权重
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
| 7 |
+
export NCCL_DEBUG=INFO
|
| 8 |
+
export NCCL_DEBUG_SUBSYS=ALL
|
| 9 |
+
export NCCL_IB_DISABLE=1
|
| 10 |
+
export NCCL_P2P_LEVEL=SYS
|
| 11 |
+
|
| 12 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 13 |
+
#"/gemini/space/zhaozy/zhy/hsd/project/pretrained_model/models--stabilityai--stable-diffusion-3-medium-diffusers"
|
| 14 |
+
LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000" # 可为空
|
| 15 |
+
RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-120000" # 可为空(若不用 Rectified)
|
| 16 |
+
|
| 17 |
+
CAPTIONS_JSONL="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl"
|
| 18 |
+
SAMPLE_DIR="./sd3_rectified_samples_new_batch_2"
|
| 19 |
+
|
| 20 |
+
NUM_INFERENCE_STEPS=40
|
| 21 |
+
GUIDANCE_SCALE=7.0
|
| 22 |
+
HEIGHT=512
|
| 23 |
+
WIDTH=512
|
| 24 |
+
PER_PROC_BATCH_SIZE=32
|
| 25 |
+
IMAGES_PER_CAPTION=3
|
| 26 |
+
MAX_SAMPLES=30000
|
| 27 |
+
GLOBAL_SEED=42
|
| 28 |
+
MIXED_PRECISION="fp16" # no / fp16 / bf16
|
| 29 |
+
NUM_SIT_LAYERS=1 # 需与训练一致
|
| 30 |
+
|
| 31 |
+
ARGS=(
|
| 32 |
+
--pretrained_model_name_or_path "$PRETRAINED_MODEL"
|
| 33 |
+
--captions_jsonl "$CAPTIONS_JSONL"
|
| 34 |
+
--sample_dir "$SAMPLE_DIR"
|
| 35 |
+
--num_inference_steps $NUM_INFERENCE_STEPS
|
| 36 |
+
--guidance_scale $GUIDANCE_SCALE
|
| 37 |
+
--height $HEIGHT
|
| 38 |
+
--width $WIDTH
|
| 39 |
+
--per_proc_batch_size $PER_PROC_BATCH_SIZE
|
| 40 |
+
--images_per_caption $IMAGES_PER_CAPTION
|
| 41 |
+
--max_samples $MAX_SAMPLES
|
| 42 |
+
--global_seed $GLOBAL_SEED
|
| 43 |
+
--num_sit_layers $NUM_SIT_LAYERS
|
| 44 |
+
--mixed_precision $MIXED_PRECISION
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if [ -n "$LORA_PATH" ]; then
|
| 48 |
+
ARGS+=(--lora_path "$LORA_PATH")
|
| 49 |
+
fi
|
| 50 |
+
|
| 51 |
+
if [ -n "$RECTIFIED_WEIGHTS" ]; then
|
| 52 |
+
ARGS+=(--rectified_weights "$RECTIFIED_WEIGHTS")
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
echo "[run_sd3_rectified_sampling.sh] start torchrun: $(date)"
|
| 56 |
+
|
| 57 |
+
# 先尝试 4 卡模式,如果失败则退到单卡模式
|
| 58 |
+
if ! torchrun --nproc_per_node=4 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"; then
|
| 59 |
+
ret=$?
|
| 60 |
+
echo "[run_sd3_rectified_sampling.sh] 4卡运行失败(退出码 ${ret}),尝试单卡模式"
|
| 61 |
+
if ! torchrun --nproc_per_node=1 --master_port=25913 sample_sd3_rectified_ddp.py "${ARGS[@]}"; then
|
| 62 |
+
ret2=$?
|
| 63 |
+
echo "[run_sd3_rectified_sampling.sh] 单卡运行也失败(退出码 ${ret2}),请查看具体错误信息。"
|
| 64 |
+
exit $ret2
|
| 65 |
+
fi
|
| 66 |
+
echo "[run_sd3_rectified_sampling.sh] 单卡运行成功,建议降低 per_proc_batch_size 或使用单卡配置继续。"
|
| 67 |
+
fi
|
| 68 |
+
|
| 69 |
+
wait
|
| 70 |
+
|
| 71 |
+
echo "Sampling done. Output at: $SAMPLE_DIR"
|
| 72 |
+
# nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
|
sample_sd3_lora_checkpoint_ddp.py
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
SD3 LoRA分布式采样脚本 - 从accelerator checkpoint加载LoRA权重
|
| 5 |
+
使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import os
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import math
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
import json
|
| 18 |
+
import random
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
from diffusers import (
|
| 22 |
+
StableDiffusion3Pipeline,
|
| 23 |
+
AutoencoderKL,
|
| 24 |
+
FlowMatchEulerDiscreteScheduler,
|
| 25 |
+
SD3Transformer2DModel,
|
| 26 |
+
)
|
| 27 |
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
| 28 |
+
from accelerate import Accelerator
|
| 29 |
+
from peft import LoraConfig, PeftModel
|
| 30 |
+
from peft.utils import get_peft_model_state_dict
|
| 31 |
+
from safetensors.torch import load_file, save_file
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_npz_from_sample_folder(sample_dir, num_samples):
|
| 35 |
+
"""
|
| 36 |
+
从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
|
| 37 |
+
"""
|
| 38 |
+
samples = []
|
| 39 |
+
actual_files = []
|
| 40 |
+
|
| 41 |
+
# 收集所有PNG文件
|
| 42 |
+
for filename in sorted(os.listdir(sample_dir)):
|
| 43 |
+
if filename.endswith('.png'):
|
| 44 |
+
actual_files.append(filename)
|
| 45 |
+
|
| 46 |
+
# 按照数量限制处理
|
| 47 |
+
for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
|
| 48 |
+
if i < len(actual_files):
|
| 49 |
+
sample_path = os.path.join(sample_dir, actual_files[i])
|
| 50 |
+
sample_pil = Image.open(sample_path)
|
| 51 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
| 52 |
+
samples.append(sample_np)
|
| 53 |
+
else:
|
| 54 |
+
# 如果不够,创建空白图像
|
| 55 |
+
sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 56 |
+
samples.append(sample_np)
|
| 57 |
+
|
| 58 |
+
if samples:
|
| 59 |
+
samples = np.stack(samples)
|
| 60 |
+
npz_path = f"{sample_dir}.npz"
|
| 61 |
+
np.savez(npz_path, arr_0=samples)
|
| 62 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
| 63 |
+
return npz_path
|
| 64 |
+
else:
|
| 65 |
+
print("No samples found to create npz file.")
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def extract_lora_from_checkpoint(checkpoint_path, output_lora_path, rank=64, rank0_only=True):
|
| 70 |
+
"""
|
| 71 |
+
从accelerator checkpoint中提取LoRA权重并保存为标准格式
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
checkpoint_path: checkpoint目录路径
|
| 75 |
+
output_lora_path: 输出LoRA权重保存路径
|
| 76 |
+
rank: LoRA rank
|
| 77 |
+
rank0_only: 是否只在rank 0上执行
|
| 78 |
+
"""
|
| 79 |
+
model_file = os.path.join(checkpoint_path, "model.safetensors")
|
| 80 |
+
if not os.path.exists(model_file):
|
| 81 |
+
if rank0_only:
|
| 82 |
+
print(f"Model file not found: {model_file}")
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
# 加载checkpoint state dict
|
| 87 |
+
state_dict = load_file(model_file)
|
| 88 |
+
|
| 89 |
+
if rank0_only:
|
| 90 |
+
print(f"Loaded checkpoint with {len(state_dict)} keys")
|
| 91 |
+
|
| 92 |
+
# 提取LoRA权重
|
| 93 |
+
# Accelerator保存的格式可能是: "transformer.lora_A.weight" 或 "model.transformer.lora_A.weight"
|
| 94 |
+
# 需要转换为diffusers格式: "transformer.lora_A.weight"
|
| 95 |
+
lora_state_dict = {}
|
| 96 |
+
|
| 97 |
+
# 查找所有LoRA相关的键
|
| 98 |
+
lora_keys = []
|
| 99 |
+
for key in state_dict.keys():
|
| 100 |
+
# 检查是否是LoRA权重(lora_A, lora_B, lora_embedding等)
|
| 101 |
+
if 'lora_A' in key or 'lora_B' in key or 'lora_embedding' in key:
|
| 102 |
+
lora_keys.append(key)
|
| 103 |
+
|
| 104 |
+
if rank0_only:
|
| 105 |
+
print(f"Found {len(lora_keys)} LoRA keys")
|
| 106 |
+
if lora_keys:
|
| 107 |
+
print(f"Sample LoRA keys: {lora_keys[:5]}")
|
| 108 |
+
|
| 109 |
+
if not lora_keys:
|
| 110 |
+
if rank0_only:
|
| 111 |
+
print("Warning: No LoRA keys found in checkpoint. Trying alternative extraction method...")
|
| 112 |
+
|
| 113 |
+
# 尝试另一种方法:检查是否有完整的transformer权重
|
| 114 |
+
# 如果是全量微调,我们需要计算LoRA权重 = 微调权重 - 基础权重
|
| 115 |
+
# 但这需要基础模型,所以这里我们假设checkpoint中已经包含了LoRA权重
|
| 116 |
+
# 或者checkpoint保存的是合并后的权重
|
| 117 |
+
|
| 118 |
+
# 检查是否有transformer的完整权重
|
| 119 |
+
transformer_keys = [k for k in state_dict.keys() if 'transformer' in k.lower() and 'lora' not in k.lower()]
|
| 120 |
+
if transformer_keys:
|
| 121 |
+
if rank0_only:
|
| 122 |
+
print(f"Found {len(transformer_keys)} transformer keys (full fine-tuning checkpoint)")
|
| 123 |
+
print("This checkpoint appears to contain full model weights, not LoRA weights.")
|
| 124 |
+
print("You may need to use a different loading method.")
|
| 125 |
+
return False
|
| 126 |
+
|
| 127 |
+
# 转换键名格式:从accelerator格式转换为diffusers格式
|
| 128 |
+
for key in lora_keys:
|
| 129 |
+
# 移除可能的"model."前缀
|
| 130 |
+
new_key = key
|
| 131 |
+
if new_key.startswith("model."):
|
| 132 |
+
new_key = new_key[6:] # 移除"model."前缀
|
| 133 |
+
|
| 134 |
+
# 确保键名符合diffusers格式
|
| 135 |
+
# diffusers格式通常是: "transformer.lora_A.weight" 或 "transformer.transformer_blocks.X.attn.to_q.lora_A.weight"
|
| 136 |
+
lora_state_dict[new_key] = state_dict[key]
|
| 137 |
+
|
| 138 |
+
if not lora_state_dict:
|
| 139 |
+
if rank0_only:
|
| 140 |
+
print("Error: Failed to extract LoRA weights from checkpoint")
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
# 保存LoRA权重
|
| 144 |
+
if rank0_only:
|
| 145 |
+
os.makedirs(output_lora_path, exist_ok=True)
|
| 146 |
+
lora_file = os.path.join(output_lora_path, "pytorch_lora_weights.safetensors")
|
| 147 |
+
save_file(lora_state_dict, lora_file)
|
| 148 |
+
print(f"Saved LoRA weights to {lora_file} ({len(lora_state_dict)} keys)")
|
| 149 |
+
|
| 150 |
+
return True
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
if rank0_only:
|
| 154 |
+
print(f"Error extracting LoRA from checkpoint: {e}")
|
| 155 |
+
import traceback
|
| 156 |
+
traceback.print_exc()
|
| 157 |
+
return False
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def load_lora_from_checkpoint_direct(pipeline, checkpoint_path, rank=64, rank0_print=True):
|
| 161 |
+
"""
|
| 162 |
+
直接从checkpoint加载LoRA权重到pipeline
|
| 163 |
+
|
| 164 |
+
这个方法尝试直接从checkpoint中加载LoRA权重,而不需要先提取
|
| 165 |
+
"""
|
| 166 |
+
model_file = os.path.join(checkpoint_path, "model.safetensors")
|
| 167 |
+
if not os.path.exists(model_file):
|
| 168 |
+
if rank0_print:
|
| 169 |
+
print(f"Model file not found: {model_file}")
|
| 170 |
+
return False
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
# 加载checkpoint state dict
|
| 174 |
+
state_dict = load_file(model_file)
|
| 175 |
+
|
| 176 |
+
if rank0_print:
|
| 177 |
+
print(f"Loaded checkpoint with {len(state_dict)} keys")
|
| 178 |
+
# 显示前10个键名以便调试
|
| 179 |
+
sample_keys = list(state_dict.keys())[:10]
|
| 180 |
+
print(f"Sample keys: {sample_keys}")
|
| 181 |
+
|
| 182 |
+
# 查找LoRA权重
|
| 183 |
+
lora_keys = [k for k in state_dict.keys() if 'lora_A' in k or 'lora_B' in k or 'lora_embedding' in k]
|
| 184 |
+
|
| 185 |
+
if not lora_keys:
|
| 186 |
+
if rank0_print:
|
| 187 |
+
print("No LoRA keys found in checkpoint.")
|
| 188 |
+
print("This checkpoint might contain merged weights or use a different format.")
|
| 189 |
+
print("Checking checkpoint structure...")
|
| 190 |
+
|
| 191 |
+
# 检查是否是全量微调的checkpoint(包含完整transformer权重)
|
| 192 |
+
transformer_keys = [k for k in state_dict.keys() if 'transformer' in k.lower() and 'lora' not in k.lower()]
|
| 193 |
+
if transformer_keys:
|
| 194 |
+
if rank0_print:
|
| 195 |
+
print(f"Found {len(transformer_keys)} transformer keys")
|
| 196 |
+
print("This appears to be a full fine-tuning checkpoint with merged weights.")
|
| 197 |
+
print("Attempting to use Accelerator to load the checkpoint...")
|
| 198 |
+
|
| 199 |
+
# 尝试使用Accelerator加载checkpoint
|
| 200 |
+
try:
|
| 201 |
+
# 配置LoRA适配器
|
| 202 |
+
transformer_lora_config = LoraConfig(
|
| 203 |
+
r=rank,
|
| 204 |
+
lora_alpha=rank,
|
| 205 |
+
init_lora_weights="gaussian",
|
| 206 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# 为transformer添加LoRA适配器
|
| 210 |
+
pipeline.transformer.add_adapter(transformer_lora_config)
|
| 211 |
+
|
| 212 |
+
# 使用Accelerator加载checkpoint
|
| 213 |
+
accelerator = Accelerator()
|
| 214 |
+
# 准备模型
|
| 215 |
+
transformer_prepared = accelerator.prepare(pipeline.transformer)
|
| 216 |
+
# 加载状态
|
| 217 |
+
accelerator.load_state(checkpoint_path)
|
| 218 |
+
# 提取模型
|
| 219 |
+
pipeline.transformer = accelerator.unwrap_model(transformer_prepared)
|
| 220 |
+
|
| 221 |
+
if rank0_print:
|
| 222 |
+
print("Successfully loaded checkpoint using Accelerator")
|
| 223 |
+
return True
|
| 224 |
+
except Exception as e:
|
| 225 |
+
if rank0_print:
|
| 226 |
+
print(f"Failed to load using Accelerator: {e}")
|
| 227 |
+
return False
|
| 228 |
+
else:
|
| 229 |
+
if rank0_print:
|
| 230 |
+
print("Could not identify checkpoint format. Please check the checkpoint structure.")
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
if rank0_print:
|
| 234 |
+
print(f"Found {len(lora_keys)} LoRA keys")
|
| 235 |
+
print(f"Sample LoRA keys: {lora_keys[:5]}")
|
| 236 |
+
|
| 237 |
+
# 配置LoRA适配器
|
| 238 |
+
transformer_lora_config = LoraConfig(
|
| 239 |
+
r=rank,
|
| 240 |
+
lora_alpha=rank,
|
| 241 |
+
init_lora_weights="gaussian",
|
| 242 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# 为transformer添加LoRA适配器
|
| 246 |
+
pipeline.transformer.add_adapter(transformer_lora_config)
|
| 247 |
+
|
| 248 |
+
if rank0_print:
|
| 249 |
+
print("LoRA adapter configured")
|
| 250 |
+
|
| 251 |
+
# 提取并转换LoRA权重
|
| 252 |
+
lora_state_dict = {}
|
| 253 |
+
for key in lora_keys:
|
| 254 |
+
# 移除可能的"model."或"transformer."前缀(取决于accelerator保存格式)
|
| 255 |
+
new_key = key
|
| 256 |
+
# 移除常见的accelerator前缀
|
| 257 |
+
prefixes_to_remove = ["model.", "module.", "transformer."]
|
| 258 |
+
for prefix in prefixes_to_remove:
|
| 259 |
+
if new_key.startswith(prefix):
|
| 260 |
+
new_key = new_key[len(prefix):]
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
# 确保键名符合PEFT格式
|
| 264 |
+
# PEFT格式通常是: "transformer_blocks.X.attn.to_q.lora_A.weight"
|
| 265 |
+
# 或者: "lora_A.weight" (如果已经包含完整路径)
|
| 266 |
+
lora_state_dict[new_key] = state_dict[key]
|
| 267 |
+
|
| 268 |
+
if rank0_print:
|
| 269 |
+
print(f"Extracted {len(lora_state_dict)} LoRA weights")
|
| 270 |
+
print(f"Sample extracted keys: {list(lora_state_dict.keys())[:5]}")
|
| 271 |
+
|
| 272 |
+
# 加载LoRA权重到模型
|
| 273 |
+
# 使用PEFT的load_state_dict方法
|
| 274 |
+
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
|
| 275 |
+
|
| 276 |
+
if rank0_print:
|
| 277 |
+
if missing_keys:
|
| 278 |
+
print(f"Missing keys: {len(missing_keys)}")
|
| 279 |
+
if len(missing_keys) <= 10:
|
| 280 |
+
for k in missing_keys:
|
| 281 |
+
print(f" - {k}")
|
| 282 |
+
else:
|
| 283 |
+
print(f" (showing first 10 of {len(missing_keys)} missing keys)")
|
| 284 |
+
for k in list(missing_keys)[:10]:
|
| 285 |
+
print(f" - {k}")
|
| 286 |
+
if unexpected_keys:
|
| 287 |
+
print(f"Unexpected keys: {len(unexpected_keys)}")
|
| 288 |
+
if len(unexpected_keys) <= 10:
|
| 289 |
+
for k in unexpected_keys:
|
| 290 |
+
print(f" - {k}")
|
| 291 |
+
else:
|
| 292 |
+
print(f" (showing first 10 of {len(unexpected_keys)} unexpected keys)")
|
| 293 |
+
for k in list(unexpected_keys)[:10]:
|
| 294 |
+
print(f" - {k}")
|
| 295 |
+
|
| 296 |
+
# 检查是否有peft_config
|
| 297 |
+
if hasattr(pipeline.transformer, 'peft_config'):
|
| 298 |
+
if rank0_print:
|
| 299 |
+
print(f"LoRA config found: {list(pipeline.transformer.peft_config.keys())}")
|
| 300 |
+
else:
|
| 301 |
+
if rank0_print:
|
| 302 |
+
print("Warning: No peft_config found after loading LoRA")
|
| 303 |
+
|
| 304 |
+
# 验证LoRA是否真的被加载
|
| 305 |
+
if rank0_print:
|
| 306 |
+
# 检查一个LoRA层的权重是否非零
|
| 307 |
+
has_lora_weights = False
|
| 308 |
+
for name, param in pipeline.transformer.named_parameters():
|
| 309 |
+
if 'lora' in name.lower() and param.requires_grad:
|
| 310 |
+
if param.abs().max().item() > 1e-6:
|
| 311 |
+
has_lora_weights = True
|
| 312 |
+
if rank0_print:
|
| 313 |
+
print(f"Verified LoRA weights loaded (found non-zero LoRA param: {name})")
|
| 314 |
+
break
|
| 315 |
+
|
| 316 |
+
if not has_lora_weights:
|
| 317 |
+
print("Warning: LoRA weights may not have been loaded correctly (all LoRA params are zero or not found)")
|
| 318 |
+
|
| 319 |
+
if rank0_print:
|
| 320 |
+
print("LoRA weights loaded successfully")
|
| 321 |
+
|
| 322 |
+
return True
|
| 323 |
+
|
| 324 |
+
except Exception as e:
|
| 325 |
+
if rank0_print:
|
| 326 |
+
print(f"Error loading LoRA from checkpoint: {e}")
|
| 327 |
+
import traceback
|
| 328 |
+
traceback.print_exc()
|
| 329 |
+
return False
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def load_captions_from_jsonl(jsonl_path):
|
| 333 |
+
"""
|
| 334 |
+
从JSONL文件加载caption列表
|
| 335 |
+
"""
|
| 336 |
+
captions = []
|
| 337 |
+
try:
|
| 338 |
+
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
| 339 |
+
for line_num, line in enumerate(f, 1):
|
| 340 |
+
line = line.strip()
|
| 341 |
+
if not line:
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
try:
|
| 345 |
+
data = json.loads(line)
|
| 346 |
+
# 支持多种字段名
|
| 347 |
+
caption = None
|
| 348 |
+
for field in ['caption', 'text', 'prompt', 'description']:
|
| 349 |
+
if field in data and isinstance(data[field], str):
|
| 350 |
+
caption = data[field].strip()
|
| 351 |
+
break
|
| 352 |
+
|
| 353 |
+
if caption:
|
| 354 |
+
captions.append(caption)
|
| 355 |
+
else:
|
| 356 |
+
# 如果没有找到标准字段,取第一个字符串值
|
| 357 |
+
for value in data.values():
|
| 358 |
+
if isinstance(value, str) and value.strip():
|
| 359 |
+
captions.append(value.strip())
|
| 360 |
+
break
|
| 361 |
+
|
| 362 |
+
except json.JSONDecodeError as e:
|
| 363 |
+
print(f"Warning: Invalid JSON on line {line_num}: {e}")
|
| 364 |
+
continue
|
| 365 |
+
|
| 366 |
+
except FileNotFoundError:
|
| 367 |
+
print(f"Error: JSONL file {jsonl_path} not found")
|
| 368 |
+
return []
|
| 369 |
+
except Exception as e:
|
| 370 |
+
print(f"Error loading JSONL file {jsonl_path}: {e}")
|
| 371 |
+
return []
|
| 372 |
+
|
| 373 |
+
print(f"Loaded {len(captions)} captions from {jsonl_path}")
|
| 374 |
+
return captions
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def main(args):
|
| 378 |
+
"""
|
| 379 |
+
运行 SD3 LoRA 采样
|
| 380 |
+
"""
|
| 381 |
+
assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
|
| 382 |
+
torch.set_grad_enabled(False)
|
| 383 |
+
|
| 384 |
+
# 设置 DDP
|
| 385 |
+
dist.init_process_group("nccl")
|
| 386 |
+
rank = dist.get_rank()
|
| 387 |
+
world_size = dist.get_world_size()
|
| 388 |
+
device = torch.device(f"cuda:{rank}")
|
| 389 |
+
seed = args.global_seed * world_size + rank
|
| 390 |
+
torch.manual_seed(seed)
|
| 391 |
+
torch.cuda.set_device(device)
|
| 392 |
+
print(f"Starting rank={rank}, device={device}, seed={seed}, world_size={world_size}, visible_devices={torch.cuda.device_count()}.")
|
| 393 |
+
|
| 394 |
+
# 加载captions
|
| 395 |
+
captions = []
|
| 396 |
+
if args.captions_jsonl:
|
| 397 |
+
if rank == 0:
|
| 398 |
+
print(f"Loading captions from {args.captions_jsonl}")
|
| 399 |
+
captions = load_captions_from_jsonl(args.captions_jsonl)
|
| 400 |
+
if not captions:
|
| 401 |
+
if rank == 0:
|
| 402 |
+
print("Warning: No captions loaded, using default caption")
|
| 403 |
+
captions = ["a beautiful high quality image"]
|
| 404 |
+
else:
|
| 405 |
+
# 使用默认caption
|
| 406 |
+
captions = ["a beautiful high quality image"]
|
| 407 |
+
|
| 408 |
+
# 计算总的图片数量
|
| 409 |
+
total_images_needed = len(captions) * args.images_per_caption
|
| 410 |
+
# 应用最大样本数限制
|
| 411 |
+
total_images_needed = min(total_images_needed, args.max_samples)
|
| 412 |
+
if rank == 0:
|
| 413 |
+
print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
|
| 414 |
+
print(f"Total images requested: {len(captions) * args.images_per_caption}")
|
| 415 |
+
print(f"Max samples limit: {args.max_samples}")
|
| 416 |
+
print(f"Total images to generate: {total_images_needed}")
|
| 417 |
+
|
| 418 |
+
# 设置数据类型 - 使用混合精度以减少内存占用
|
| 419 |
+
if args.mixed_precision == "fp16":
|
| 420 |
+
dtype = torch.float16
|
| 421 |
+
elif args.mixed_precision == "bf16":
|
| 422 |
+
dtype = torch.bfloat16
|
| 423 |
+
else:
|
| 424 |
+
dtype = torch.float32
|
| 425 |
+
|
| 426 |
+
# 加载基础模型
|
| 427 |
+
if rank == 0:
|
| 428 |
+
print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")
|
| 429 |
+
|
| 430 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 431 |
+
args.pretrained_model_name_or_path,
|
| 432 |
+
revision=args.revision,
|
| 433 |
+
variant=args.variant,
|
| 434 |
+
torch_dtype=dtype,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# 从checkpoint加载LoRA权重
|
| 438 |
+
lora_loaded = False
|
| 439 |
+
lora_source = "baseline"
|
| 440 |
+
|
| 441 |
+
if args.lora_checkpoint_path:
|
| 442 |
+
if rank == 0:
|
| 443 |
+
print(f"Loading LoRA weights from checkpoint: {args.lora_checkpoint_path}")
|
| 444 |
+
|
| 445 |
+
# 方法1: 直接从checkpoint加载
|
| 446 |
+
lora_loaded = load_lora_from_checkpoint_direct(
|
| 447 |
+
pipeline,
|
| 448 |
+
args.lora_checkpoint_path,
|
| 449 |
+
rank=args.lora_rank,
|
| 450 |
+
rank0_print=(rank == 0)
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
if lora_loaded:
|
| 454 |
+
lora_source = os.path.basename(args.lora_checkpoint_path.rstrip('/'))
|
| 455 |
+
if rank == 0:
|
| 456 |
+
print("Successfully loaded LoRA weights from checkpoint")
|
| 457 |
+
else:
|
| 458 |
+
if rank == 0:
|
| 459 |
+
print("Failed to load LoRA weights directly from checkpoint")
|
| 460 |
+
print("Trying alternative method: extracting LoRA weights first...")
|
| 461 |
+
|
| 462 |
+
# 方法2: 先提取LoRA权重,再加载
|
| 463 |
+
temp_lora_path = os.path.join(args.lora_checkpoint_path, "extracted_lora")
|
| 464 |
+
if rank == 0:
|
| 465 |
+
extract_success = extract_lora_from_checkpoint(
|
| 466 |
+
args.lora_checkpoint_path,
|
| 467 |
+
temp_lora_path,
|
| 468 |
+
rank=args.lora_rank,
|
| 469 |
+
rank0_only=True
|
| 470 |
+
)
|
| 471 |
+
else:
|
| 472 |
+
extract_success = False
|
| 473 |
+
|
| 474 |
+
dist.barrier() # 等待rank 0完成提取
|
| 475 |
+
|
| 476 |
+
if extract_success and os.path.exists(os.path.join(temp_lora_path, "pytorch_lora_weights.safetensors")):
|
| 477 |
+
if rank == 0:
|
| 478 |
+
print(f"Loading extracted LoRA weights from {temp_lora_path}")
|
| 479 |
+
try:
|
| 480 |
+
pipeline.load_lora_weights(temp_lora_path)
|
| 481 |
+
lora_loaded = True
|
| 482 |
+
lora_source = f"{os.path.basename(args.lora_checkpoint_path.rstrip('/'))}_extracted"
|
| 483 |
+
if rank == 0:
|
| 484 |
+
print("Successfully loaded extracted LoRA weights")
|
| 485 |
+
except Exception as e:
|
| 486 |
+
if rank == 0:
|
| 487 |
+
print(f"Failed to load extracted LoRA weights: {e}")
|
| 488 |
+
|
| 489 |
+
if not lora_loaded:
|
| 490 |
+
if rank == 0:
|
| 491 |
+
print("Warning: No LoRA weights loaded. Using baseline model.")
|
| 492 |
+
|
| 493 |
+
# 启用内存优化选项(必须在移动到设备之前)
|
| 494 |
+
# 注意:在分布式环境下,CPU offload 不支持多GPU,会导致所有进程挤在一张卡上
|
| 495 |
+
# 因此禁用 CPU offload,改用其他内存优化方法
|
| 496 |
+
if args.enable_cpu_offload and world_size > 1:
|
| 497 |
+
if rank == 0:
|
| 498 |
+
print(f"Warning: CPU offload is disabled in multi-GPU mode (world_size={world_size})")
|
| 499 |
+
print("Using device-specific placement instead")
|
| 500 |
+
args.enable_cpu_offload = False
|
| 501 |
+
|
| 502 |
+
if args.enable_cpu_offload:
|
| 503 |
+
if rank == 0:
|
| 504 |
+
print("Enabling CPU offload to save memory (single GPU mode)")
|
| 505 |
+
# CPU offload 会自动管理设备,不需要先 to(device)
|
| 506 |
+
pipeline.enable_model_cpu_offload()
|
| 507 |
+
else:
|
| 508 |
+
# 在分布式环境下,明确将pipeline移动到对应的设备
|
| 509 |
+
if rank == 0:
|
| 510 |
+
print(f"Moving pipeline to device {device} (multi-GPU mode)")
|
| 511 |
+
pipeline = pipeline.to(device)
|
| 512 |
+
if rank == 0:
|
| 513 |
+
print("Enabling memory optimization options")
|
| 514 |
+
|
| 515 |
+
# 检查并启用可用的内存优化方法
|
| 516 |
+
# 注意:所有进程都需要执行这些操作,不仅仅是 rank 0
|
| 517 |
+
if hasattr(pipeline, 'enable_attention_slicing'):
|
| 518 |
+
try:
|
| 519 |
+
pipeline.enable_attention_slicing()
|
| 520 |
+
if rank == 0:
|
| 521 |
+
print(" - Attention slicing enabled")
|
| 522 |
+
except Exception as e:
|
| 523 |
+
if rank == 0:
|
| 524 |
+
print(f" - Warning: Failed to enable attention slicing: {e}")
|
| 525 |
+
else:
|
| 526 |
+
if rank == 0:
|
| 527 |
+
print(" - Attention slicing not available for this pipeline")
|
| 528 |
+
|
| 529 |
+
# SD3 pipeline 可能不支持 enable_vae_slicing,需要检查
|
| 530 |
+
# 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
|
| 531 |
+
enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
|
| 532 |
+
if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
|
| 533 |
+
try:
|
| 534 |
+
enable_vae_slicing_method()
|
| 535 |
+
if rank == 0:
|
| 536 |
+
print(" - VAE slicing enabled")
|
| 537 |
+
except Exception as e:
|
| 538 |
+
if rank == 0:
|
| 539 |
+
print(f" - Warning: Failed to enable VAE slicing: {e}")
|
| 540 |
+
else:
|
| 541 |
+
if rank == 0:
|
| 542 |
+
print(" - VAE slicing not available for this pipeline (SD3 may not support this)")
|
| 543 |
+
|
| 544 |
+
# 验证设备分配
|
| 545 |
+
if rank == 0:
|
| 546 |
+
print(f"Pipeline device verification:")
|
| 547 |
+
print(f" - Transformer device: {next(pipeline.transformer.parameters()).device}")
|
| 548 |
+
print(f" - VAE device: {next(pipeline.vae.parameters()).device}")
|
| 549 |
+
if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
|
| 550 |
+
print(f" - Text encoder device: {next(pipeline.text_encoder.parameters()).device}")
|
| 551 |
+
dist.barrier() # 等待所有进程完成设备分配
|
| 552 |
+
|
| 553 |
+
# 禁用进度条
|
| 554 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 555 |
+
|
| 556 |
+
# 创建保存目录
|
| 557 |
+
folder_name = f"checkpoint-{lora_source}-rank{args.lora_rank}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
|
| 558 |
+
sample_folder_dir = os.path.join(args.sample_dir, folder_name)
|
| 559 |
+
|
| 560 |
+
if rank == 0:
|
| 561 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
| 562 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
| 563 |
+
# 清空caption文件
|
| 564 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 565 |
+
if os.path.exists(caption_file):
|
| 566 |
+
os.remove(caption_file)
|
| 567 |
+
dist.barrier()
|
| 568 |
+
|
| 569 |
+
# 计算采样参数
|
| 570 |
+
n = args.per_proc_batch_size
|
| 571 |
+
global_batch_size = n * dist.get_world_size()
|
| 572 |
+
|
| 573 |
+
# 检查已存在的样本数量
|
| 574 |
+
existing_samples = 0
|
| 575 |
+
if os.path.exists(sample_folder_dir):
|
| 576 |
+
existing_samples = len([
|
| 577 |
+
name for name in os.listdir(sample_folder_dir)
|
| 578 |
+
if os.path.isfile(os.path.join(sample_folder_dir, name)) and name.endswith(".png")
|
| 579 |
+
])
|
| 580 |
+
|
| 581 |
+
total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
|
| 582 |
+
if rank == 0:
|
| 583 |
+
print(f"Total number of images that will be sampled: {total_samples}")
|
| 584 |
+
print(f"Existing samples: {existing_samples}")
|
| 585 |
+
|
| 586 |
+
assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
|
| 587 |
+
samples_needed_this_gpu = int(total_samples // dist.get_world_size())
|
| 588 |
+
assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
|
| 589 |
+
|
| 590 |
+
iterations = int(samples_needed_this_gpu // n)
|
| 591 |
+
done_iterations = int(int(existing_samples // dist.get_world_size()) // n)
|
| 592 |
+
|
| 593 |
+
pbar = range(done_iterations, iterations)
|
| 594 |
+
pbar = tqdm(pbar) if rank == 0 else pbar
|
| 595 |
+
|
| 596 |
+
# 生成caption和image的映射列表
|
| 597 |
+
caption_image_pairs = []
|
| 598 |
+
for i, caption in enumerate(captions):
|
| 599 |
+
for j in range(args.images_per_caption):
|
| 600 |
+
caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
|
| 601 |
+
|
| 602 |
+
total_generated = existing_samples
|
| 603 |
+
|
| 604 |
+
# 采样循环
|
| 605 |
+
for i in pbar:
|
| 606 |
+
# 获取这个batch对应的caption
|
| 607 |
+
batch_prompts = []
|
| 608 |
+
batch_caption_info = []
|
| 609 |
+
|
| 610 |
+
for j in range(n):
|
| 611 |
+
global_index = i * global_batch_size + j * dist.get_world_size() + rank
|
| 612 |
+
if global_index < len(caption_image_pairs):
|
| 613 |
+
caption, caption_idx, image_idx = caption_image_pairs[global_index]
|
| 614 |
+
batch_prompts.append(caption)
|
| 615 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 616 |
+
else:
|
| 617 |
+
# 如果超出范围,使用最后一个caption
|
| 618 |
+
if caption_image_pairs:
|
| 619 |
+
caption, caption_idx, image_idx = caption_image_pairs[-1]
|
| 620 |
+
batch_prompts.append(caption)
|
| 621 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 622 |
+
else:
|
| 623 |
+
batch_prompts.append("a beautiful high quality image")
|
| 624 |
+
batch_caption_info.append(("a beautiful high quality image", 0, 0))
|
| 625 |
+
|
| 626 |
+
# 生成图像 - 为每个图像使用不同的随机种子
|
| 627 |
+
# 确保使用正确的设备进行autocast
|
| 628 |
+
device_str = str(device) # 使用明确的设备字符串,如 "cuda:0", "cuda:1" 等
|
| 629 |
+
with torch.autocast(device_str, dtype=dtype):
|
| 630 |
+
# 为每个prompt生成独立的图像(使用不同的generator)
|
| 631 |
+
images = []
|
| 632 |
+
for k, prompt in enumerate(batch_prompts):
|
| 633 |
+
# 为每个图像创建独立的随机种子
|
| 634 |
+
image_seed = seed + i * 10000 + k * 1000 + rank
|
| 635 |
+
generator = torch.Generator(device=device).manual_seed(image_seed)
|
| 636 |
+
|
| 637 |
+
# 调试信息(仅在第一个batch的第一个图像时打印)
|
| 638 |
+
if i == done_iterations and k == 0 and rank < 2: # 只打印前两个rank
|
| 639 |
+
print(f"[Rank {rank}] Generating image on device {device}, generator device: {generator.device}")
|
| 640 |
+
|
| 641 |
+
image = pipeline(
|
| 642 |
+
prompt=prompt,
|
| 643 |
+
negative_prompt=args.negative_prompt if args.negative_prompt else None,
|
| 644 |
+
height=args.height,
|
| 645 |
+
width=args.width,
|
| 646 |
+
num_inference_steps=args.num_inference_steps,
|
| 647 |
+
guidance_scale=args.guidance_scale,
|
| 648 |
+
generator=generator,
|
| 649 |
+
num_images_per_prompt=1,
|
| 650 |
+
).images[0]
|
| 651 |
+
images.append(image)
|
| 652 |
+
|
| 653 |
+
# 清理 GPU 缓存以释放内存
|
| 654 |
+
if k == len(batch_prompts) - 1: # 每个 batch 的最后一张图片后清理
|
| 655 |
+
torch.cuda.empty_cache()
|
| 656 |
+
|
| 657 |
+
# 保存图像
|
| 658 |
+
for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
|
| 659 |
+
global_index = i * global_batch_size + j * dist.get_world_size() + rank
|
| 660 |
+
if global_index < len(caption_image_pairs):
|
| 661 |
+
# 保存图片,文件名包含caption索引和图片索引
|
| 662 |
+
filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
|
| 663 |
+
image_path = os.path.join(sample_folder_dir, filename)
|
| 664 |
+
image.save(image_path)
|
| 665 |
+
|
| 666 |
+
# 保存caption信息到文本文件(只在rank 0上操作)
|
| 667 |
+
if rank == 0:
|
| 668 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 669 |
+
with open(caption_file, "a", encoding="utf-8") as f:
|
| 670 |
+
f.write(f"{filename}\t{caption}\n")
|
| 671 |
+
|
| 672 |
+
total_generated += global_batch_size
|
| 673 |
+
|
| 674 |
+
# 每个迭代后清理 GPU 缓存
|
| 675 |
+
torch.cuda.empty_cache()
|
| 676 |
+
|
| 677 |
+
dist.barrier()
|
| 678 |
+
|
| 679 |
+
# 确保所有进程都完成采样
|
| 680 |
+
dist.barrier()
|
| 681 |
+
|
| 682 |
+
# 创建npz文件
|
| 683 |
+
if rank == 0:
|
| 684 |
+
# 重新计算实际生成的图片数量
|
| 685 |
+
actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
|
| 686 |
+
print(f"Actually generated {actual_num_samples} images")
|
| 687 |
+
# 使用实际的图片数量或用户指定的数量,取较小值
|
| 688 |
+
npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
|
| 689 |
+
create_npz_from_sample_folder(sample_folder_dir, npz_samples)
|
| 690 |
+
print("Done.")
|
| 691 |
+
|
| 692 |
+
dist.barrier()
|
| 693 |
+
dist.destroy_process_group()
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
if __name__ == "__main__":
|
| 697 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本 - 从checkpoint加载")
|
| 698 |
+
|
| 699 |
+
# 模型和路径参数
|
| 700 |
+
parser.add_argument(
|
| 701 |
+
"--pretrained_model_name_or_path",
|
| 702 |
+
type=str,
|
| 703 |
+
default="stabilityai/stable-diffusion-3-medium-diffusers",
|
| 704 |
+
help="预训练模型路径或HuggingFace模型ID"
|
| 705 |
+
)
|
| 706 |
+
parser.add_argument(
|
| 707 |
+
"--lora_checkpoint_path",
|
| 708 |
+
type=str,
|
| 709 |
+
required=True,
|
| 710 |
+
help="LoRA checkpoint目录路径(包含model.safetensors的目录)"
|
| 711 |
+
)
|
| 712 |
+
parser.add_argument(
|
| 713 |
+
"--lora_rank",
|
| 714 |
+
type=int,
|
| 715 |
+
default=64,
|
| 716 |
+
help="LoRA rank(必须与训练时一致)"
|
| 717 |
+
)
|
| 718 |
+
parser.add_argument(
|
| 719 |
+
"--revision",
|
| 720 |
+
type=str,
|
| 721 |
+
default=None,
|
| 722 |
+
help="模型修订版本"
|
| 723 |
+
)
|
| 724 |
+
parser.add_argument(
|
| 725 |
+
"--variant",
|
| 726 |
+
type=str,
|
| 727 |
+
default=None,
|
| 728 |
+
help="模型变体,如fp16"
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
# 采样参数
|
| 732 |
+
parser.add_argument(
|
| 733 |
+
"--num_inference_steps",
|
| 734 |
+
type=int,
|
| 735 |
+
default=28,
|
| 736 |
+
help="推理步数"
|
| 737 |
+
)
|
| 738 |
+
parser.add_argument(
|
| 739 |
+
"--guidance_scale",
|
| 740 |
+
type=float,
|
| 741 |
+
default=7.0,
|
| 742 |
+
help="引导尺度"
|
| 743 |
+
)
|
| 744 |
+
parser.add_argument(
|
| 745 |
+
"--height",
|
| 746 |
+
type=int,
|
| 747 |
+
default=1024,
|
| 748 |
+
help="生成图像高度"
|
| 749 |
+
)
|
| 750 |
+
parser.add_argument(
|
| 751 |
+
"--width",
|
| 752 |
+
type=int,
|
| 753 |
+
default=1024,
|
| 754 |
+
help="生成图像宽度"
|
| 755 |
+
)
|
| 756 |
+
parser.add_argument(
|
| 757 |
+
"--negative_prompt",
|
| 758 |
+
type=str,
|
| 759 |
+
default="",
|
| 760 |
+
help="负面提示词"
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
# 批处理和数据集参数
|
| 764 |
+
parser.add_argument(
|
| 765 |
+
"--per_proc_batch_size",
|
| 766 |
+
type=int,
|
| 767 |
+
default=1,
|
| 768 |
+
help="每个进程的批处理大小"
|
| 769 |
+
)
|
| 770 |
+
parser.add_argument(
|
| 771 |
+
"--sample_dir",
|
| 772 |
+
type=str,
|
| 773 |
+
default="sd3_lora_samples",
|
| 774 |
+
help="样本保存目录"
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
# Caption相关参数
|
| 778 |
+
parser.add_argument(
|
| 779 |
+
"--captions_jsonl",
|
| 780 |
+
type=str,
|
| 781 |
+
required=True,
|
| 782 |
+
help="包含caption列表的JSONL文件路径"
|
| 783 |
+
)
|
| 784 |
+
parser.add_argument(
|
| 785 |
+
"--images_per_caption",
|
| 786 |
+
type=int,
|
| 787 |
+
default=1,
|
| 788 |
+
help="每个caption生成的图像数量"
|
| 789 |
+
)
|
| 790 |
+
parser.add_argument(
|
| 791 |
+
"--max_samples",
|
| 792 |
+
type=int,
|
| 793 |
+
default=30000,
|
| 794 |
+
help="最大样本生成数量"
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
# 其他参数
|
| 798 |
+
parser.add_argument(
|
| 799 |
+
"--global_seed",
|
| 800 |
+
type=int,
|
| 801 |
+
default=42,
|
| 802 |
+
help="全局随机种子"
|
| 803 |
+
)
|
| 804 |
+
parser.add_argument(
|
| 805 |
+
"--mixed_precision",
|
| 806 |
+
type=str,
|
| 807 |
+
default="fp16",
|
| 808 |
+
choices=["no", "fp16", "bf16"],
|
| 809 |
+
help="混合精度类型"
|
| 810 |
+
)
|
| 811 |
+
parser.add_argument(
|
| 812 |
+
"--enable_cpu_offload",
|
| 813 |
+
action="store_true",
|
| 814 |
+
help="启用CPU offload以节省显存"
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
args = parser.parse_args()
|
| 818 |
+
main(args)
|
sample_sd3_lora_ddp.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
SD3 LoRA分布式采样脚本
|
| 5 |
+
使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import os
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import math
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
import json
|
| 18 |
+
import random
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
from diffusers import (
|
| 22 |
+
StableDiffusion3Pipeline,
|
| 23 |
+
AutoencoderKL,
|
| 24 |
+
FlowMatchEulerDiscreteScheduler,
|
| 25 |
+
SD3Transformer2DModel,
|
| 26 |
+
)
|
| 27 |
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
| 28 |
+
from accelerate import Accelerator
|
| 29 |
+
from peft import LoraConfig
|
| 30 |
+
from peft.utils import get_peft_model_state_dict
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_npz_from_sample_folder(sample_dir, num_samples):
|
| 34 |
+
"""
|
| 35 |
+
从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
|
| 36 |
+
"""
|
| 37 |
+
samples = []
|
| 38 |
+
actual_files = []
|
| 39 |
+
|
| 40 |
+
# 收集所有PNG文件
|
| 41 |
+
for filename in sorted(os.listdir(sample_dir)):
|
| 42 |
+
if filename.endswith('.png'):
|
| 43 |
+
actual_files.append(filename)
|
| 44 |
+
|
| 45 |
+
# 按照数量限制处理
|
| 46 |
+
for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
|
| 47 |
+
if i < len(actual_files):
|
| 48 |
+
sample_path = os.path.join(sample_dir, actual_files[i])
|
| 49 |
+
sample_pil = Image.open(sample_path)
|
| 50 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
| 51 |
+
samples.append(sample_np)
|
| 52 |
+
else:
|
| 53 |
+
# 如果不够,创建空白图像
|
| 54 |
+
sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 55 |
+
samples.append(sample_np)
|
| 56 |
+
|
| 57 |
+
if samples:
|
| 58 |
+
samples = np.stack(samples)
|
| 59 |
+
npz_path = f"{sample_dir}.npz"
|
| 60 |
+
np.savez(npz_path, arr_0=samples)
|
| 61 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
| 62 |
+
return npz_path
|
| 63 |
+
else:
|
| 64 |
+
print("No samples found to create npz file.")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def find_latest_checkpoint(output_dir):
|
| 69 |
+
"""
|
| 70 |
+
查找最新的检查点目录
|
| 71 |
+
"""
|
| 72 |
+
checkpoint_dirs = []
|
| 73 |
+
if os.path.exists(output_dir):
|
| 74 |
+
for item in os.listdir(output_dir):
|
| 75 |
+
if item.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, item)):
|
| 76 |
+
try:
|
| 77 |
+
step = int(item.split("-")[1])
|
| 78 |
+
checkpoint_dirs.append((step, item))
|
| 79 |
+
except (ValueError, IndexError):
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
if checkpoint_dirs:
|
| 83 |
+
# 按步数排序,返回最新的
|
| 84 |
+
checkpoint_dirs.sort(key=lambda x: x[0])
|
| 85 |
+
latest_step, latest_dir = checkpoint_dirs[-1]
|
| 86 |
+
latest_path = os.path.join(output_dir, latest_dir)
|
| 87 |
+
return latest_path, latest_step
|
| 88 |
+
return None, None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def check_lora_weights_exist(lora_path):
|
| 92 |
+
"""
|
| 93 |
+
检查LoRA权重文件是否存在
|
| 94 |
+
"""
|
| 95 |
+
if not lora_path:
|
| 96 |
+
return False
|
| 97 |
+
|
| 98 |
+
# 检查是否是目录
|
| 99 |
+
if os.path.isdir(lora_path):
|
| 100 |
+
# 检查目录中是否有pytorch_lora_weights.safetensors文件
|
| 101 |
+
weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
|
| 102 |
+
if os.path.exists(weight_file):
|
| 103 |
+
return True
|
| 104 |
+
# 检查是否有其他.safetensors文件
|
| 105 |
+
for file in os.listdir(lora_path):
|
| 106 |
+
if file.endswith(".safetensors") and "lora" in file.lower():
|
| 107 |
+
return True
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# 检查是否是文件
|
| 111 |
+
elif os.path.isfile(lora_path):
|
| 112 |
+
return lora_path.endswith(".safetensors")
|
| 113 |
+
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def check_full_finetune_checkpoint(checkpoint_path):
|
| 118 |
+
"""
|
| 119 |
+
检查是否是全量微调的checkpoint(包含model.safetensors)
|
| 120 |
+
"""
|
| 121 |
+
if not checkpoint_path or not os.path.isdir(checkpoint_path):
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
# 检查是否有model.safetensors文件(全量微调的标志)
|
| 125 |
+
model_file = os.path.join(checkpoint_path, "model.safetensors")
|
| 126 |
+
return os.path.exists(model_file)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0):
|
| 130 |
+
"""
|
| 131 |
+
从检查点加载LoRA权重
|
| 132 |
+
"""
|
| 133 |
+
if rank == 0:
|
| 134 |
+
print(f"Loading LoRA weights from checkpoint: {checkpoint_path}")
|
| 135 |
+
|
| 136 |
+
# 直接从检查点目录加载state dict
|
| 137 |
+
try:
|
| 138 |
+
# 使用accelerator来加载检查点
|
| 139 |
+
accelerator = Accelerator()
|
| 140 |
+
|
| 141 |
+
# 先配置LoRA
|
| 142 |
+
transformer_lora_config = LoraConfig(
|
| 143 |
+
r=64, # 假设使用rank=64,可以根据需要调整
|
| 144 |
+
lora_alpha=64,
|
| 145 |
+
init_lora_weights="gaussian",
|
| 146 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# 为transformer添加LoRA
|
| 150 |
+
pipeline.transformer.add_adapter(transformer_lora_config)
|
| 151 |
+
|
| 152 |
+
# 加载检查点状态
|
| 153 |
+
accelerator.load_state(checkpoint_path)
|
| 154 |
+
|
| 155 |
+
if rank == 0:
|
| 156 |
+
print(f"Successfully loaded LoRA weights from checkpoint {checkpoint_path}")
|
| 157 |
+
|
| 158 |
+
return True
|
| 159 |
+
|
| 160 |
+
except Exception as e:
|
| 161 |
+
if rank == 0:
|
| 162 |
+
print(f"Error loading LoRA from checkpoint {checkpoint_path}: {e}")
|
| 163 |
+
print("Falling back to baseline model without LoRA")
|
| 164 |
+
return False
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_captions_from_jsonl(jsonl_path):
|
| 168 |
+
"""
|
| 169 |
+
从JSONL文件加载caption列表
|
| 170 |
+
"""
|
| 171 |
+
captions = []
|
| 172 |
+
try:
|
| 173 |
+
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
| 174 |
+
for line_num, line in enumerate(f, 1):
|
| 175 |
+
line = line.strip()
|
| 176 |
+
if not line:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
try:
|
| 180 |
+
data = json.loads(line)
|
| 181 |
+
# 支持多种字段名
|
| 182 |
+
caption = None
|
| 183 |
+
for field in ['caption', 'text', 'prompt', 'description']:
|
| 184 |
+
if field in data and isinstance(data[field], str):
|
| 185 |
+
caption = data[field].strip()
|
| 186 |
+
break
|
| 187 |
+
|
| 188 |
+
if caption:
|
| 189 |
+
captions.append(caption)
|
| 190 |
+
else:
|
| 191 |
+
# 如果没有找到标准字段,取第一个字符串值
|
| 192 |
+
for value in data.values():
|
| 193 |
+
if isinstance(value, str) and value.strip():
|
| 194 |
+
captions.append(value.strip())
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
except json.JSONDecodeError as e:
|
| 198 |
+
print(f"Warning: Invalid JSON on line {line_num}: {e}")
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
except FileNotFoundError:
|
| 202 |
+
print(f"Error: JSONL file {jsonl_path} not found")
|
| 203 |
+
return []
|
| 204 |
+
except Exception as e:
|
| 205 |
+
print(f"Error loading JSONL file {jsonl_path}: {e}")
|
| 206 |
+
return []
|
| 207 |
+
|
| 208 |
+
print(f"Loaded {len(captions)} captions from {jsonl_path}")
|
| 209 |
+
return captions
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def main(args):
|
| 213 |
+
"""
|
| 214 |
+
运行 SD3 LoRA 采样
|
| 215 |
+
"""
|
| 216 |
+
assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
|
| 217 |
+
torch.set_grad_enabled(False)
|
| 218 |
+
|
| 219 |
+
# 设置 DDP
|
| 220 |
+
dist.init_process_group("nccl")
|
| 221 |
+
rank = dist.get_rank()
|
| 222 |
+
device = rank % torch.cuda.device_count()
|
| 223 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
| 224 |
+
torch.manual_seed(seed)
|
| 225 |
+
torch.cuda.set_device(device)
|
| 226 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
| 227 |
+
|
| 228 |
+
# 加载captions
|
| 229 |
+
captions = []
|
| 230 |
+
if args.captions_jsonl:
|
| 231 |
+
if rank == 0:
|
| 232 |
+
print(f"Loading captions from {args.captions_jsonl}")
|
| 233 |
+
captions = load_captions_from_jsonl(args.captions_jsonl)
|
| 234 |
+
if not captions:
|
| 235 |
+
if rank == 0:
|
| 236 |
+
print("Warning: No captions loaded, using default caption")
|
| 237 |
+
captions = ["a beautiful high quality image"]
|
| 238 |
+
else:
|
| 239 |
+
# 使用默认caption
|
| 240 |
+
captions = ["a beautiful high quality image"]
|
| 241 |
+
|
| 242 |
+
# 计算总的图片数量
|
| 243 |
+
total_images_needed = len(captions) * args.images_per_caption
|
| 244 |
+
# 应用最大样本数限制
|
| 245 |
+
total_images_needed = min(total_images_needed, args.max_samples)
|
| 246 |
+
if rank == 0:
|
| 247 |
+
print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
|
| 248 |
+
print(f"Total images requested: {len(captions) * args.images_per_caption}")
|
| 249 |
+
print(f"Max samples limit: {args.max_samples}")
|
| 250 |
+
print(f"Total images to generate: {total_images_needed}")
|
| 251 |
+
|
| 252 |
+
# 设置数据类型 - 使用混合精度以减少内存占用
|
| 253 |
+
if args.mixed_precision == "fp16":
|
| 254 |
+
dtype = torch.float16
|
| 255 |
+
elif args.mixed_precision == "bf16":
|
| 256 |
+
dtype = torch.bfloat16
|
| 257 |
+
else:
|
| 258 |
+
dtype = torch.float32
|
| 259 |
+
|
| 260 |
+
# 检查是否是全量微调的checkpoint
|
| 261 |
+
is_full_finetune = False
|
| 262 |
+
if args.lora_path and check_full_finetune_checkpoint(args.lora_path):
|
| 263 |
+
# 全量微调:直接从checkpoint加载
|
| 264 |
+
if rank == 0:
|
| 265 |
+
print(f"Detected full fine-tuning checkpoint, loading from: {args.lora_path}")
|
| 266 |
+
try:
|
| 267 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 268 |
+
args.lora_path,
|
| 269 |
+
revision=args.revision,
|
| 270 |
+
variant=args.variant,
|
| 271 |
+
torch_dtype=dtype,
|
| 272 |
+
)
|
| 273 |
+
is_full_finetune = True
|
| 274 |
+
lora_source = os.path.basename(args.lora_path.rstrip('/'))
|
| 275 |
+
if rank == 0:
|
| 276 |
+
print("Successfully loaded full fine-tuned model from checkpoint")
|
| 277 |
+
except Exception as e:
|
| 278 |
+
if rank == 0:
|
| 279 |
+
print(f"Failed to load full fine-tuned model: {e}")
|
| 280 |
+
print("Falling back to baseline model + LoRA loading")
|
| 281 |
+
is_full_finetune = False
|
| 282 |
+
|
| 283 |
+
# 如果不是全量微调,加载基础模型
|
| 284 |
+
if not is_full_finetune:
|
| 285 |
+
if rank == 0:
|
| 286 |
+
print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")
|
| 287 |
+
|
| 288 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 289 |
+
args.pretrained_model_name_or_path,
|
| 290 |
+
revision=args.revision,
|
| 291 |
+
variant=args.variant,
|
| 292 |
+
torch_dtype=dtype,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# 检查和加载 LoRA 权重(仅当不是全量微调时)
|
| 296 |
+
lora_loaded = False
|
| 297 |
+
lora_source = "baseline" if not is_full_finetune else lora_source
|
| 298 |
+
|
| 299 |
+
if not is_full_finetune and args.lora_path:
|
| 300 |
+
# 检查指定的LoRA路径是否存在权重文件
|
| 301 |
+
if check_lora_weights_exist(args.lora_path):
|
| 302 |
+
if rank == 0:
|
| 303 |
+
print(f"Loading LoRA weights from specified path: {args.lora_path}")
|
| 304 |
+
try:
|
| 305 |
+
pipeline.load_lora_weights(args.lora_path)
|
| 306 |
+
lora_loaded = True
|
| 307 |
+
lora_source = os.path.basename(args.lora_path.rstrip('/'))
|
| 308 |
+
if rank == 0:
|
| 309 |
+
print("Successfully loaded LoRA weights from specified path")
|
| 310 |
+
except Exception as e:
|
| 311 |
+
if rank == 0:
|
| 312 |
+
print(f"Failed to load LoRA from specified path: {e}")
|
| 313 |
+
else:
|
| 314 |
+
if rank == 0:
|
| 315 |
+
print(f"No LoRA weights found at specified path: {args.lora_path}")
|
| 316 |
+
|
| 317 |
+
# 如果没有成功加载LoRA权重,尝试从当前目录或检查点加载(仅当不是全量微调时)
|
| 318 |
+
if not is_full_finetune and not lora_loaded:
|
| 319 |
+
# 首先检查当前工作目录是否有权重文件
|
| 320 |
+
current_dir = os.getcwd()
|
| 321 |
+
if check_lora_weights_exist(current_dir):
|
| 322 |
+
if rank == 0:
|
| 323 |
+
print(f"Found LoRA weights in current directory: {current_dir}")
|
| 324 |
+
try:
|
| 325 |
+
pipeline.load_lora_weights(current_dir)
|
| 326 |
+
lora_loaded = True
|
| 327 |
+
lora_source = "current_dir"
|
| 328 |
+
if rank == 0:
|
| 329 |
+
print("Successfully loaded LoRA weights from current directory")
|
| 330 |
+
except Exception as e:
|
| 331 |
+
if rank == 0:
|
| 332 |
+
print(f"Failed to load LoRA from current directory: {e}")
|
| 333 |
+
|
| 334 |
+
# 如果当前目录也没有,检查是否有检查点目录
|
| 335 |
+
if not lora_loaded:
|
| 336 |
+
# 检查常见的输出目录
|
| 337 |
+
possible_output_dirs = [
|
| 338 |
+
"sd3-lora-finetuned",
|
| 339 |
+
"sd3-lora-finetuned-last",
|
| 340 |
+
"output",
|
| 341 |
+
"checkpoints"
|
| 342 |
+
]
|
| 343 |
+
|
| 344 |
+
checkpoint_found = False
|
| 345 |
+
for output_dir in possible_output_dirs:
|
| 346 |
+
if os.path.exists(output_dir):
|
| 347 |
+
# 首先检查输出目录是否直接包含权重文件
|
| 348 |
+
if check_lora_weights_exist(output_dir):
|
| 349 |
+
if rank == 0:
|
| 350 |
+
print(f"Found LoRA weights in output directory: {output_dir}")
|
| 351 |
+
try:
|
| 352 |
+
pipeline.load_lora_weights(output_dir)
|
| 353 |
+
lora_loaded = True
|
| 354 |
+
lora_source = output_dir
|
| 355 |
+
if rank == 0:
|
| 356 |
+
print(f"Successfully loaded LoRA weights from {output_dir}")
|
| 357 |
+
break
|
| 358 |
+
except Exception as e:
|
| 359 |
+
if rank == 0:
|
| 360 |
+
print(f"Failed to load LoRA from {output_dir}: {e}")
|
| 361 |
+
|
| 362 |
+
# 如果输出目录没有直接的权重文件,查找最新的检查点
|
| 363 |
+
if not lora_loaded:
|
| 364 |
+
latest_checkpoint, latest_step = find_latest_checkpoint(output_dir)
|
| 365 |
+
if latest_checkpoint:
|
| 366 |
+
if rank == 0:
|
| 367 |
+
print(f"Found latest checkpoint: {latest_checkpoint} (step {latest_step})")
|
| 368 |
+
|
| 369 |
+
# 尝试从检查点加载LoRA权重
|
| 370 |
+
if load_lora_from_checkpoint(pipeline, latest_checkpoint, rank):
|
| 371 |
+
lora_loaded = True
|
| 372 |
+
lora_source = f"checkpoint-{latest_step}"
|
| 373 |
+
checkpoint_found = True
|
| 374 |
+
break
|
| 375 |
+
|
| 376 |
+
if not checkpoint_found and not lora_loaded:
|
| 377 |
+
if rank == 0:
|
| 378 |
+
print("No LoRA weights or checkpoints found. Using baseline model.")
|
| 379 |
+
|
| 380 |
+
# 启用内存优化选项(必须在移动到设备之前)
|
| 381 |
+
if args.enable_cpu_offload:
|
| 382 |
+
if rank == 0:
|
| 383 |
+
print("Enabling CPU offload to save memory")
|
| 384 |
+
# CPU offload 会自动管理设备,不需要先 to(device)
|
| 385 |
+
pipeline.enable_model_cpu_offload()
|
| 386 |
+
else:
|
| 387 |
+
# 如果不使用 CPU offload,先移动到设备,然后启用其他优化
|
| 388 |
+
pipeline = pipeline.to(device)
|
| 389 |
+
if rank == 0:
|
| 390 |
+
print("Enabling memory optimization options")
|
| 391 |
+
|
| 392 |
+
# 检查并启用可用的内存优化方法
|
| 393 |
+
# 注意:所有进程都需要执行这些操作,不仅仅是 rank 0
|
| 394 |
+
if hasattr(pipeline, 'enable_attention_slicing'):
|
| 395 |
+
try:
|
| 396 |
+
pipeline.enable_attention_slicing()
|
| 397 |
+
if rank == 0:
|
| 398 |
+
print(" - Attention slicing enabled")
|
| 399 |
+
except Exception as e:
|
| 400 |
+
if rank == 0:
|
| 401 |
+
print(f" - Warning: Failed to enable attention slicing: {e}")
|
| 402 |
+
else:
|
| 403 |
+
if rank == 0:
|
| 404 |
+
print(" - Attention slicing not available for this pipeline")
|
| 405 |
+
|
| 406 |
+
# SD3 pipeline 可能不支持 enable_vae_slicing,需要检查
|
| 407 |
+
# 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
|
| 408 |
+
enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
|
| 409 |
+
if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
|
| 410 |
+
try:
|
| 411 |
+
enable_vae_slicing_method()
|
| 412 |
+
if rank == 0:
|
| 413 |
+
print(" - VAE slicing enabled")
|
| 414 |
+
except Exception as e:
|
| 415 |
+
if rank == 0:
|
| 416 |
+
print(f" - Warning: Failed to enable VAE slicing: {e}")
|
| 417 |
+
else:
|
| 418 |
+
if rank == 0:
|
| 419 |
+
print(" - VAE slicing not available for this pipeline (SD3 may not support this)")
|
| 420 |
+
|
| 421 |
+
# 禁用进度条
|
| 422 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 423 |
+
|
| 424 |
+
# 创建保存目录
|
| 425 |
+
folder_name = f"batch32-rank64-last-sd3-{lora_source}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
|
| 426 |
+
sample_folder_dir = os.path.join(args.sample_dir, folder_name)
|
| 427 |
+
|
| 428 |
+
if rank == 0:
|
| 429 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
| 430 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
| 431 |
+
# 清空caption文件
|
| 432 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 433 |
+
if os.path.exists(caption_file):
|
| 434 |
+
os.remove(caption_file)
|
| 435 |
+
dist.barrier()
|
| 436 |
+
|
| 437 |
+
# 计算采样参数
|
| 438 |
+
n = args.per_proc_batch_size
|
| 439 |
+
global_batch_size = n * dist.get_world_size()
|
| 440 |
+
|
| 441 |
+
# 检查已存在的样本数量
|
| 442 |
+
existing_samples = 0
|
| 443 |
+
if os.path.exists(sample_folder_dir):
|
| 444 |
+
existing_samples = len([
|
| 445 |
+
name for name in os.listdir(sample_folder_dir)
|
| 446 |
+
if os.path.isfile(os.path.join(sample_folder_dir, name)) and name.endswith(".png")
|
| 447 |
+
])
|
| 448 |
+
|
| 449 |
+
total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
|
| 450 |
+
if rank == 0:
|
| 451 |
+
print(f"Total number of images that will be sampled: {total_samples}")
|
| 452 |
+
print(f"Existing samples: {existing_samples}")
|
| 453 |
+
|
| 454 |
+
assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
|
| 455 |
+
samples_needed_this_gpu = int(total_samples // dist.get_world_size())
|
| 456 |
+
assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
|
| 457 |
+
|
| 458 |
+
iterations = int(samples_needed_this_gpu // n)
|
| 459 |
+
done_iterations = int(int(existing_samples // dist.get_world_size()) // n)
|
| 460 |
+
|
| 461 |
+
pbar = range(done_iterations, iterations)
|
| 462 |
+
pbar = tqdm(pbar) if rank == 0 else pbar
|
| 463 |
+
|
| 464 |
+
# 生成caption和image的映射列表
|
| 465 |
+
caption_image_pairs = []
|
| 466 |
+
for i, caption in enumerate(captions):
|
| 467 |
+
for j in range(args.images_per_caption):
|
| 468 |
+
caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
|
| 469 |
+
|
| 470 |
+
total_generated = existing_samples
|
| 471 |
+
|
| 472 |
+
# 采样循环
|
| 473 |
+
for i in pbar:
|
| 474 |
+
# 获取这个batch对应的caption
|
| 475 |
+
batch_prompts = []
|
| 476 |
+
batch_caption_info = []
|
| 477 |
+
|
| 478 |
+
for j in range(n):
|
| 479 |
+
global_index = i * global_batch_size + j * dist.get_world_size() + rank
|
| 480 |
+
if global_index < len(caption_image_pairs):
|
| 481 |
+
caption, caption_idx, image_idx = caption_image_pairs[global_index]
|
| 482 |
+
batch_prompts.append(caption)
|
| 483 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 484 |
+
else:
|
| 485 |
+
# 如果超出范围,使用最后一个caption
|
| 486 |
+
if caption_image_pairs:
|
| 487 |
+
caption, caption_idx, image_idx = caption_image_pairs[-1]
|
| 488 |
+
batch_prompts.append(caption)
|
| 489 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 490 |
+
else:
|
| 491 |
+
batch_prompts.append("a beautiful high quality image")
|
| 492 |
+
batch_caption_info.append(("a beautiful high quality image", 0, 0))
|
| 493 |
+
|
| 494 |
+
# 生成图像 - 为每个图像使用不同的随机种子
|
| 495 |
+
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 496 |
+
with torch.autocast(device_str, dtype=dtype):
|
| 497 |
+
# 为每个prompt生成独立的图像(使用不同的generator)
|
| 498 |
+
images = []
|
| 499 |
+
for k, prompt in enumerate(batch_prompts):
|
| 500 |
+
# 为每个图像创建独立的随机种子
|
| 501 |
+
image_seed = seed + i * 10000 + k * 1000 + rank
|
| 502 |
+
generator = torch.Generator(device=device).manual_seed(image_seed)
|
| 503 |
+
|
| 504 |
+
image = pipeline(
|
| 505 |
+
prompt=prompt,
|
| 506 |
+
negative_prompt=args.negative_prompt if args.negative_prompt else None,
|
| 507 |
+
height=args.height,
|
| 508 |
+
width=args.width,
|
| 509 |
+
num_inference_steps=args.num_inference_steps,
|
| 510 |
+
guidance_scale=args.guidance_scale,
|
| 511 |
+
generator=generator,
|
| 512 |
+
num_images_per_prompt=1,
|
| 513 |
+
).images[0]
|
| 514 |
+
images.append(image)
|
| 515 |
+
|
| 516 |
+
# 清理 GPU 缓存以释放内存
|
| 517 |
+
if k == len(batch_prompts) - 1: # 每个 batch 的最后一张图片后清理
|
| 518 |
+
torch.cuda.empty_cache()
|
| 519 |
+
|
| 520 |
+
# 保存图像
|
| 521 |
+
for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
|
| 522 |
+
global_index = i * global_batch_size + j * dist.get_world_size() + rank
|
| 523 |
+
if global_index < len(caption_image_pairs):
|
| 524 |
+
# 保存图片,文件名包含caption索引和图片索引
|
| 525 |
+
filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
|
| 526 |
+
image_path = os.path.join(sample_folder_dir, filename)
|
| 527 |
+
image.save(image_path)
|
| 528 |
+
|
| 529 |
+
# 保存caption信息到文本文件(只在rank 0上操作)
|
| 530 |
+
if rank == 0:
|
| 531 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 532 |
+
with open(caption_file, "a", encoding="utf-8") as f:
|
| 533 |
+
f.write(f"{filename}\t{caption}\n")
|
| 534 |
+
|
| 535 |
+
total_generated += global_batch_size
|
| 536 |
+
|
| 537 |
+
# 每个迭代后清理 GPU 缓存
|
| 538 |
+
torch.cuda.empty_cache()
|
| 539 |
+
|
| 540 |
+
dist.barrier()
|
| 541 |
+
|
| 542 |
+
# 确保所有进程都完成采样
|
| 543 |
+
dist.barrier()
|
| 544 |
+
|
| 545 |
+
# 创建npz文件
|
| 546 |
+
if rank == 0:
|
| 547 |
+
# 重新计算实际生成的图片数量
|
| 548 |
+
actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
|
| 549 |
+
print(f"Actually generated {actual_num_samples} images")
|
| 550 |
+
# 使用实际的图片数量或用户指定的数量,取较小值
|
| 551 |
+
npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
|
| 552 |
+
create_npz_from_sample_folder(sample_folder_dir, npz_samples)
|
| 553 |
+
print("Done.")
|
| 554 |
+
|
| 555 |
+
dist.barrier()
|
| 556 |
+
dist.destroy_process_group()
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
if __name__ == "__main__":
|
| 560 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本")
|
| 561 |
+
|
| 562 |
+
# 模型和路径参数
|
| 563 |
+
parser.add_argument(
|
| 564 |
+
"--pretrained_model_name_or_path",
|
| 565 |
+
type=str,
|
| 566 |
+
default="stabilityai/stable-diffusion-3-medium-diffusers",
|
| 567 |
+
help="预训练模型路径或HuggingFace模型ID"
|
| 568 |
+
)
|
| 569 |
+
parser.add_argument(
|
| 570 |
+
"--lora_path",
|
| 571 |
+
type=str,
|
| 572 |
+
default=None,
|
| 573 |
+
help="LoRA权重文件路径"
|
| 574 |
+
)
|
| 575 |
+
parser.add_argument(
|
| 576 |
+
"--revision",
|
| 577 |
+
type=str,
|
| 578 |
+
default=None,
|
| 579 |
+
help="模型修订版本"
|
| 580 |
+
)
|
| 581 |
+
parser.add_argument(
|
| 582 |
+
"--variant",
|
| 583 |
+
type=str,
|
| 584 |
+
default=None,
|
| 585 |
+
help="模型变体,如fp16"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# 采样参数
|
| 589 |
+
parser.add_argument(
|
| 590 |
+
"--num_inference_steps",
|
| 591 |
+
type=int,
|
| 592 |
+
default=28,
|
| 593 |
+
help="推理步数"
|
| 594 |
+
)
|
| 595 |
+
parser.add_argument(
|
| 596 |
+
"--guidance_scale",
|
| 597 |
+
type=float,
|
| 598 |
+
default=7.0,
|
| 599 |
+
help="引导尺度"
|
| 600 |
+
)
|
| 601 |
+
parser.add_argument(
|
| 602 |
+
"--height",
|
| 603 |
+
type=int,
|
| 604 |
+
default=1024,
|
| 605 |
+
help="生成图像高度"
|
| 606 |
+
)
|
| 607 |
+
parser.add_argument(
|
| 608 |
+
"--width",
|
| 609 |
+
type=int,
|
| 610 |
+
default=1024,
|
| 611 |
+
help="生成图像宽度"
|
| 612 |
+
)
|
| 613 |
+
parser.add_argument(
|
| 614 |
+
"--negative_prompt",
|
| 615 |
+
type=str,
|
| 616 |
+
default="",
|
| 617 |
+
help="负面提示词"
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
# 批处理和数据集参数
|
| 621 |
+
parser.add_argument(
|
| 622 |
+
"--per_proc_batch_size",
|
| 623 |
+
type=int,
|
| 624 |
+
default=1,
|
| 625 |
+
help="每个进程的批处理大小"
|
| 626 |
+
)
|
| 627 |
+
parser.add_argument(
|
| 628 |
+
"--sample_dir",
|
| 629 |
+
type=str,
|
| 630 |
+
default="sd3_lora_samples",
|
| 631 |
+
help="样本保存目录"
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
# Caption相关参数
|
| 635 |
+
parser.add_argument(
|
| 636 |
+
"--captions_jsonl",
|
| 637 |
+
type=str,
|
| 638 |
+
required=True,
|
| 639 |
+
help="包含caption列表的JSONL文件路径"
|
| 640 |
+
)
|
| 641 |
+
parser.add_argument(
|
| 642 |
+
"--images_per_caption",
|
| 643 |
+
type=int,
|
| 644 |
+
default=1,
|
| 645 |
+
help="每个caption生成的图像数量"
|
| 646 |
+
)
|
| 647 |
+
parser.add_argument(
|
| 648 |
+
"--max_samples",
|
| 649 |
+
type=int,
|
| 650 |
+
default=30000,
|
| 651 |
+
help="最大样本生成数量"
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# 其他参数
|
| 655 |
+
parser.add_argument(
|
| 656 |
+
"--global_seed",
|
| 657 |
+
type=int,
|
| 658 |
+
default=42,
|
| 659 |
+
help="全局随机种子"
|
| 660 |
+
)
|
| 661 |
+
parser.add_argument(
|
| 662 |
+
"--mixed_precision",
|
| 663 |
+
type=str,
|
| 664 |
+
default="fp16",
|
| 665 |
+
choices=["no", "fp16", "bf16"],
|
| 666 |
+
help="混合精度类型"
|
| 667 |
+
)
|
| 668 |
+
parser.add_argument(
|
| 669 |
+
"--enable_cpu_offload",
|
| 670 |
+
action="store_true",
|
| 671 |
+
help="启用CPU offload以节省显存"
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
args = parser.parse_args()
|
| 675 |
+
main(args)
|
sample_sd3_lora_rn_pair_ddp.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
DDP对照采样:同一文本+同一初始噪声,分别生成 LoRA 与 RN 两类图像,并输出 pair 拼接图与 metadata。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import importlib.util
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def dynamic_import_training_classes(project_root: str):
|
| 24 |
+
sys.path.insert(0, project_root)
|
| 25 |
+
import train_rectified_noise as trn
|
| 26 |
+
|
| 27 |
+
return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_local_pipeline_class(local_pipeline_path: str):
|
| 31 |
+
"""
|
| 32 |
+
从本地文件加载 StableDiffusion3Pipeline。
|
| 33 |
+
通过将模块名挂在 diffusers.pipelines.stable_diffusion_3 下,兼容文件内的相对导入。
|
| 34 |
+
"""
|
| 35 |
+
module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
|
| 36 |
+
spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
|
| 37 |
+
if spec is None or spec.loader is None:
|
| 38 |
+
raise ImportError(f"Failed to build import spec from: {local_pipeline_path}")
|
| 39 |
+
module = importlib.util.module_from_spec(spec)
|
| 40 |
+
spec.loader.exec_module(module)
|
| 41 |
+
if not hasattr(module, "StableDiffusion3Pipeline"):
|
| 42 |
+
raise ImportError("Local pipeline file has no StableDiffusion3Pipeline symbol.")
|
| 43 |
+
return module.StableDiffusion3Pipeline
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_captions_from_jsonl(jsonl_path):
|
| 47 |
+
captions = []
|
| 48 |
+
with open(jsonl_path, "r", encoding="utf-8") as f:
|
| 49 |
+
for line in f:
|
| 50 |
+
line = line.strip()
|
| 51 |
+
if not line:
|
| 52 |
+
continue
|
| 53 |
+
try:
|
| 54 |
+
data = json.loads(line)
|
| 55 |
+
cap = None
|
| 56 |
+
for field in ["caption", "text", "prompt", "description"]:
|
| 57 |
+
if field in data and isinstance(data[field], str):
|
| 58 |
+
cap = data[field].strip()
|
| 59 |
+
break
|
| 60 |
+
if cap:
|
| 61 |
+
captions.append(cap)
|
| 62 |
+
except Exception:
|
| 63 |
+
continue
|
| 64 |
+
return captions if captions else ["a beautiful high quality image"]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_sit_weights(rectified_module, weights_path: str):
|
| 68 |
+
if os.path.isdir(weights_path):
|
| 69 |
+
search_dirs = [weights_path, os.path.join(weights_path, "sit_weights")]
|
| 70 |
+
for d in search_dirs:
|
| 71 |
+
if not os.path.exists(d):
|
| 72 |
+
continue
|
| 73 |
+
st = os.path.join(d, "pytorch_sit_weights.safetensors")
|
| 74 |
+
if os.path.exists(st):
|
| 75 |
+
from safetensors.torch import load_file
|
| 76 |
+
|
| 77 |
+
state = load_file(st)
|
| 78 |
+
rectified_module.load_state_dict(state, strict=False)
|
| 79 |
+
return True
|
| 80 |
+
for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
|
| 81 |
+
cand = os.path.join(d, name)
|
| 82 |
+
if os.path.exists(cand):
|
| 83 |
+
state = torch.load(cand, map_location="cpu")
|
| 84 |
+
rectified_module.load_state_dict(state, strict=False)
|
| 85 |
+
return True
|
| 86 |
+
return False
|
| 87 |
+
else:
|
| 88 |
+
if weights_path.endswith(".safetensors"):
|
| 89 |
+
from safetensors.torch import load_file
|
| 90 |
+
|
| 91 |
+
state = load_file(weights_path)
|
| 92 |
+
else:
|
| 93 |
+
state = torch.load(weights_path, map_location="cpu")
|
| 94 |
+
rectified_module.load_state_dict(state, strict=False)
|
| 95 |
+
return True
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def save_jsonl_line(path, obj):
|
| 99 |
+
with open(path, "a", encoding="utf-8") as f:
|
| 100 |
+
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_jsonl(path):
|
| 104 |
+
if not os.path.exists(path):
|
| 105 |
+
return []
|
| 106 |
+
rows = []
|
| 107 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 108 |
+
for line in f:
|
| 109 |
+
line = line.strip()
|
| 110 |
+
if not line:
|
| 111 |
+
continue
|
| 112 |
+
rows.append(json.loads(line))
|
| 113 |
+
return rows
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def merge_rank_metadata(out_path, rank_paths):
|
| 117 |
+
rows = []
|
| 118 |
+
for rp in rank_paths:
|
| 119 |
+
rows.extend(load_jsonl(rp))
|
| 120 |
+
rows.sort(key=lambda x: x.get("file_name", ""))
|
| 121 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 122 |
+
for r in rows:
|
| 123 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
|
| 127 |
+
RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
|
| 128 |
+
tfm = base_pipeline.transformer
|
| 129 |
+
|
| 130 |
+
if hasattr(tfm.config, "joint_attention_dim") and tfm.config.joint_attention_dim is not None:
|
| 131 |
+
sit_hidden_size = tfm.config.joint_attention_dim
|
| 132 |
+
elif hasattr(tfm.config, "inner_dim") and tfm.config.inner_dim is not None:
|
| 133 |
+
sit_hidden_size = tfm.config.inner_dim
|
| 134 |
+
else:
|
| 135 |
+
sit_hidden_size = 4096
|
| 136 |
+
|
| 137 |
+
transformer_hidden_size = getattr(tfm.config, "hidden_size", 1536)
|
| 138 |
+
num_attention_heads = getattr(tfm.config, "num_attention_heads", 32)
|
| 139 |
+
input_dim = getattr(tfm.config, "in_channels", 16)
|
| 140 |
+
|
| 141 |
+
rectified_module = RectifiedNoiseModule(
|
| 142 |
+
hidden_size=sit_hidden_size,
|
| 143 |
+
num_sit_layers=num_sit_layers,
|
| 144 |
+
num_attention_heads=num_attention_heads,
|
| 145 |
+
input_dim=input_dim,
|
| 146 |
+
transformer_hidden_size=transformer_hidden_size,
|
| 147 |
+
)
|
| 148 |
+
ok = load_sit_weights(rectified_module, rectified_weights)
|
| 149 |
+
if not ok:
|
| 150 |
+
raise RuntimeError(f"Failed to load rectified weights from: {rectified_weights}")
|
| 151 |
+
|
| 152 |
+
model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
|
| 153 |
+
model.eval()
|
| 154 |
+
return model
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def create_npz_from_dir(sample_dir, max_samples):
|
| 158 |
+
import numpy as np
|
| 159 |
+
|
| 160 |
+
files = sorted([x for x in os.listdir(sample_dir) if x.endswith(".png") and x[:-4].isdigit()])
|
| 161 |
+
files = files[:max_samples]
|
| 162 |
+
if not files:
|
| 163 |
+
return None
|
| 164 |
+
arr = []
|
| 165 |
+
for fn in tqdm(files, desc=f"npz:{os.path.basename(sample_dir)}"):
|
| 166 |
+
arr.append(np.asarray(Image.open(os.path.join(sample_dir, fn))).astype(np.uint8))
|
| 167 |
+
arr = np.stack(arr)
|
| 168 |
+
out = f"{sample_dir}.npz"
|
| 169 |
+
np.savez(out, arr_0=arr)
|
| 170 |
+
return out
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def set_pipeline_modules_eval(pipe):
|
| 174 |
+
"""
|
| 175 |
+
Diffusers pipeline 本身没有 .eval(),需要对内部 nn.Module 分别设为 eval。
|
| 176 |
+
"""
|
| 177 |
+
for name in ["transformer", "vae", "text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "model"]:
|
| 178 |
+
module = getattr(pipe, name, None)
|
| 179 |
+
if module is not None and hasattr(module, "eval"):
|
| 180 |
+
module.eval()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def main(args):
|
| 184 |
+
assert torch.cuda.is_available(), "Need GPU"
|
| 185 |
+
dist.init_process_group("nccl")
|
| 186 |
+
rank = dist.get_rank()
|
| 187 |
+
world = dist.get_world_size()
|
| 188 |
+
device = rank % torch.cuda.device_count()
|
| 189 |
+
torch.cuda.set_device(device)
|
| 190 |
+
seed = args.global_seed * world + rank
|
| 191 |
+
torch.manual_seed(seed)
|
| 192 |
+
|
| 193 |
+
dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
|
| 194 |
+
|
| 195 |
+
root = Path(args.sample_dir)
|
| 196 |
+
lora_dir = root / "lora"
|
| 197 |
+
rn_dir = root / "rn"
|
| 198 |
+
pair_dir = root / "pair"
|
| 199 |
+
metadata_path = root / "metadata.jsonl"
|
| 200 |
+
lora_meta = lora_dir / "metadata.jsonl"
|
| 201 |
+
rn_meta = rn_dir / "metadata.jsonl"
|
| 202 |
+
pair_meta = pair_dir / "metadata.jsonl"
|
| 203 |
+
if rank == 0:
|
| 204 |
+
lora_dir.mkdir(parents=True, exist_ok=True)
|
| 205 |
+
rn_dir.mkdir(parents=True, exist_ok=True)
|
| 206 |
+
pair_dir.mkdir(parents=True, exist_ok=True)
|
| 207 |
+
dist.barrier()
|
| 208 |
+
|
| 209 |
+
if args.stage == "lora":
|
| 210 |
+
pipe_lora = DiffusersStableDiffusion3Pipeline.from_pretrained(
|
| 211 |
+
args.pretrained_model_name_or_path,
|
| 212 |
+
revision=args.revision,
|
| 213 |
+
variant=args.variant,
|
| 214 |
+
torch_dtype=dtype,
|
| 215 |
+
).to(device)
|
| 216 |
+
if args.lora_path:
|
| 217 |
+
pipe_lora.load_lora_weights(args.lora_path)
|
| 218 |
+
pipe_lora.set_progress_bar_config(disable=True)
|
| 219 |
+
set_pipeline_modules_eval(pipe_lora)
|
| 220 |
+
|
| 221 |
+
captions = load_captions_from_jsonl(args.captions_jsonl)
|
| 222 |
+
total_needed = min(len(captions) * args.images_per_caption, args.max_samples)
|
| 223 |
+
n = args.per_proc_batch_size
|
| 224 |
+
global_batch = n * world
|
| 225 |
+
total_samples = int(math.ceil(total_needed / global_batch) * global_batch)
|
| 226 |
+
iters = total_samples // global_batch
|
| 227 |
+
pbar = tqdm(range(iters)) if rank == 0 else range(iters)
|
| 228 |
+
|
| 229 |
+
rank_meta_path = root / f"metadata.rank{rank}.jsonl"
|
| 230 |
+
if rank_meta_path.exists():
|
| 231 |
+
rank_meta_path.unlink()
|
| 232 |
+
rank_lora_meta_path = lora_dir / f"metadata.rank{rank}.jsonl"
|
| 233 |
+
if rank_lora_meta_path.exists():
|
| 234 |
+
rank_lora_meta_path.unlink()
|
| 235 |
+
|
| 236 |
+
for it in pbar:
|
| 237 |
+
for k in range(n):
|
| 238 |
+
global_idx = it * global_batch + k * world + rank
|
| 239 |
+
if global_idx >= total_needed:
|
| 240 |
+
continue
|
| 241 |
+
cap_idx = global_idx // args.images_per_caption
|
| 242 |
+
prompt = captions[cap_idx]
|
| 243 |
+
image_seed = seed + it * 10000 + k * 1000
|
| 244 |
+
|
| 245 |
+
g = torch.Generator(device=device).manual_seed(image_seed)
|
| 246 |
+
latent_h = args.height // pipe_lora.vae_scale_factor
|
| 247 |
+
latent_w = args.width // pipe_lora.vae_scale_factor
|
| 248 |
+
latents = torch.randn(
|
| 249 |
+
(1, pipe_lora.transformer.config.in_channels, latent_h, latent_w),
|
| 250 |
+
device=device,
|
| 251 |
+
dtype=dtype,
|
| 252 |
+
generator=g,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
with torch.autocast("cuda", dtype=dtype):
|
| 256 |
+
img_lora = pipe_lora(
|
| 257 |
+
prompt=prompt,
|
| 258 |
+
height=args.height,
|
| 259 |
+
width=args.width,
|
| 260 |
+
num_inference_steps=args.num_inference_steps,
|
| 261 |
+
guidance_scale=args.guidance_scale,
|
| 262 |
+
latents=latents,
|
| 263 |
+
num_images_per_prompt=1,
|
| 264 |
+
).images[0]
|
| 265 |
+
|
| 266 |
+
fn = f"{global_idx:07d}.png"
|
| 267 |
+
img_lora.save(lora_dir / fn)
|
| 268 |
+
save_jsonl_line(str(rank_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed), "lora_file": f"lora/{fn}"})
|
| 269 |
+
save_jsonl_line(str(rank_lora_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed)})
|
| 270 |
+
dist.barrier()
|
| 271 |
+
|
| 272 |
+
dist.barrier()
|
| 273 |
+
if rank == 0:
|
| 274 |
+
merge_rank_metadata(str(metadata_path), [str(root / f"metadata.rank{r}.jsonl") for r in range(world)])
|
| 275 |
+
merge_rank_metadata(str(lora_meta), [str(lora_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
|
| 276 |
+
records = load_jsonl(str(metadata_path))
|
| 277 |
+
create_npz_from_dir(str(lora_dir), len(records))
|
| 278 |
+
|
| 279 |
+
elif args.stage == "rn":
|
| 280 |
+
records = load_jsonl(str(metadata_path))
|
| 281 |
+
if not records:
|
| 282 |
+
raise RuntimeError(f"metadata not found or empty: {metadata_path}. Run --stage lora first.")
|
| 283 |
+
total_needed = min(len(records), args.max_samples)
|
| 284 |
+
|
| 285 |
+
LocalStableDiffusion3Pipeline = load_local_pipeline_class(args.local_pipeline_path)
|
| 286 |
+
pipe_rn = LocalStableDiffusion3Pipeline.from_pretrained(
|
| 287 |
+
args.pretrained_model_name_or_path,
|
| 288 |
+
revision=args.revision,
|
| 289 |
+
variant=args.variant,
|
| 290 |
+
torch_dtype=dtype,
|
| 291 |
+
).to(device)
|
| 292 |
+
if args.lora_path:
|
| 293 |
+
pipe_rn.load_lora_weights(args.lora_path)
|
| 294 |
+
pipe_rn.model = build_rn_model(pipe_rn, args.rectified_weights, args.num_sit_layers, device)
|
| 295 |
+
pipe_rn.set_progress_bar_config(disable=True)
|
| 296 |
+
set_pipeline_modules_eval(pipe_rn)
|
| 297 |
+
|
| 298 |
+
rank_rn_meta_path = rn_dir / f"metadata.rank{rank}.jsonl"
|
| 299 |
+
if rank_rn_meta_path.exists():
|
| 300 |
+
rank_rn_meta_path.unlink()
|
| 301 |
+
|
| 302 |
+
assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
|
| 303 |
+
pbar = tqdm(assigned) if rank == 0 else assigned
|
| 304 |
+
for rec in pbar:
|
| 305 |
+
fn = rec["file_name"]
|
| 306 |
+
prompt = rec["caption"]
|
| 307 |
+
image_seed = int(rec["seed"])
|
| 308 |
+
g = torch.Generator(device=device).manual_seed(image_seed)
|
| 309 |
+
latent_h = args.height // pipe_rn.vae_scale_factor
|
| 310 |
+
latent_w = args.width // pipe_rn.vae_scale_factor
|
| 311 |
+
latents = torch.randn(
|
| 312 |
+
(1, pipe_rn.transformer.config.in_channels, latent_h, latent_w),
|
| 313 |
+
device=device,
|
| 314 |
+
dtype=dtype,
|
| 315 |
+
generator=g,
|
| 316 |
+
)
|
| 317 |
+
with torch.autocast("cuda", dtype=dtype):
|
| 318 |
+
img_rn = pipe_rn(
|
| 319 |
+
prompt=prompt,
|
| 320 |
+
height=args.height,
|
| 321 |
+
width=args.width,
|
| 322 |
+
num_inference_steps=args.num_inference_steps,
|
| 323 |
+
guidance_scale=args.guidance_scale,
|
| 324 |
+
latents=latents,
|
| 325 |
+
num_images_per_prompt=1,
|
| 326 |
+
).images[0]
|
| 327 |
+
img_rn.save(rn_dir / fn)
|
| 328 |
+
save_jsonl_line(str(rank_rn_meta_path), {"file_name": fn, "caption": prompt, "seed": image_seed})
|
| 329 |
+
dist.barrier()
|
| 330 |
+
if rank == 0:
|
| 331 |
+
merge_rank_metadata(str(rn_meta), [str(rn_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
|
| 332 |
+
create_npz_from_dir(str(rn_dir), total_needed)
|
| 333 |
+
|
| 334 |
+
elif args.stage == "pair":
|
| 335 |
+
records = load_jsonl(str(metadata_path))
|
| 336 |
+
if not records:
|
| 337 |
+
raise RuntimeError(f"metadata not found: {metadata_path}")
|
| 338 |
+
total_needed = min(len(records), args.max_samples)
|
| 339 |
+
rank_pair_meta_path = pair_dir / f"metadata.rank{rank}.jsonl"
|
| 340 |
+
if rank_pair_meta_path.exists():
|
| 341 |
+
rank_pair_meta_path.unlink()
|
| 342 |
+
assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
|
| 343 |
+
for rec in assigned:
|
| 344 |
+
fn = rec["file_name"]
|
| 345 |
+
lora_img_path = lora_dir / fn
|
| 346 |
+
rn_img_path = rn_dir / fn
|
| 347 |
+
if not lora_img_path.exists() or not rn_img_path.exists():
|
| 348 |
+
continue
|
| 349 |
+
img_lora = Image.open(lora_img_path).convert("RGB")
|
| 350 |
+
img_rn = Image.open(rn_img_path).convert("RGB")
|
| 351 |
+
pair = Image.new("RGB", (img_lora.width + img_rn.width, max(img_lora.height, img_rn.height)))
|
| 352 |
+
pair.paste(img_lora, (0, 0))
|
| 353 |
+
pair.paste(img_rn, (img_lora.width, 0))
|
| 354 |
+
pair.save(pair_dir / fn)
|
| 355 |
+
save_jsonl_line(
|
| 356 |
+
str(rank_pair_meta_path),
|
| 357 |
+
{"file_name": fn, "caption": rec["caption"], "seed": int(rec["seed"]), "pair_file": f"pair/{fn}"},
|
| 358 |
+
)
|
| 359 |
+
dist.barrier()
|
| 360 |
+
if rank == 0:
|
| 361 |
+
merge_rank_metadata(str(pair_meta), [str(pair_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
|
| 362 |
+
# 更新根 metadata,补齐 rn/pair 路径
|
| 363 |
+
rn_set = {r["file_name"] for r in load_jsonl(str(rn_meta))}
|
| 364 |
+
pair_set = {r["file_name"] for r in load_jsonl(str(pair_meta))}
|
| 365 |
+
merged = []
|
| 366 |
+
for r in records[:total_needed]:
|
| 367 |
+
fn = r["file_name"]
|
| 368 |
+
out = dict(r)
|
| 369 |
+
if fn in rn_set:
|
| 370 |
+
out["rn_file"] = f"rn/{fn}"
|
| 371 |
+
if fn in pair_set:
|
| 372 |
+
out["pair_file"] = f"pair/{fn}"
|
| 373 |
+
merged.append(out)
|
| 374 |
+
with open(metadata_path, "w", encoding="utf-8") as f:
|
| 375 |
+
for r in merged:
|
| 376 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 377 |
+
|
| 378 |
+
else:
|
| 379 |
+
raise ValueError(f"Unknown stage: {args.stage}")
|
| 380 |
+
|
| 381 |
+
dist.barrier()
|
| 382 |
+
if rank == 0:
|
| 383 |
+
print(f"Stage {args.stage} done. Output root: {root}")
|
| 384 |
+
dist.barrier()
|
| 385 |
+
dist.destroy_process_group()
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
if __name__ == "__main__":
|
| 389 |
+
parser = argparse.ArgumentParser(description="DDP compare sampling: LoRA vs RN with same latent/prompt.")
|
| 390 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
|
| 391 |
+
parser.add_argument(
|
| 392 |
+
"--local_pipeline_path",
|
| 393 |
+
type=str,
|
| 394 |
+
default=str(Path(__file__).parent / "pipeline_stable_diffusion_3.py"),
|
| 395 |
+
help="RN 分支使用的本地 pipeline 文件路径",
|
| 396 |
+
)
|
| 397 |
+
parser.add_argument("--revision", type=str, default=None)
|
| 398 |
+
parser.add_argument("--variant", type=str, default=None)
|
| 399 |
+
parser.add_argument("--lora_path", type=str, default=None)
|
| 400 |
+
parser.add_argument("--rectified_weights", type=str, required=True)
|
| 401 |
+
parser.add_argument("--num_sit_layers", type=int, default=1)
|
| 402 |
+
parser.add_argument("--captions_jsonl", type=str, required=True)
|
| 403 |
+
parser.add_argument("--sample_dir", type=str, default="./sd3_lora_rn_compare")
|
| 404 |
+
parser.add_argument("--num_inference_steps", type=int, default=40)
|
| 405 |
+
parser.add_argument("--guidance_scale", type=float, default=7.0)
|
| 406 |
+
parser.add_argument("--height", type=int, default=512)
|
| 407 |
+
parser.add_argument("--width", type=int, default=512)
|
| 408 |
+
parser.add_argument("--per_proc_batch_size", type=int, default=4)
|
| 409 |
+
parser.add_argument("--images_per_caption", type=int, default=1)
|
| 410 |
+
parser.add_argument("--max_samples", type=int, default=10000)
|
| 411 |
+
parser.add_argument("--global_seed", type=int, default=42)
|
| 412 |
+
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
| 413 |
+
parser.add_argument("--stage", type=str, default="lora", choices=["lora", "rn", "pair"])
|
| 414 |
+
|
| 415 |
+
args = parser.parse_args()
|
| 416 |
+
main(args)
|
| 417 |
+
|
sample_sd3_rectified_ddp.py
ADDED
|
@@ -0,0 +1,1316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
分布式采样脚本:支持指定 LoRA 权重与 Rectified Noise(SIT) 权重
|
| 5 |
+
|
| 6 |
+
依据 train_rectified_noise.py 的模型结构,加载并组装 SD3WithRectifiedNoise 进行采样。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
from accelerate import Accelerator
|
| 23 |
+
from diffusers import StableDiffusion3Pipeline
|
| 24 |
+
from peft import LoraConfig, get_peft_model_state_dict
|
| 25 |
+
from peft.utils import set_peft_model_state_dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def dynamic_import_training_classes(project_root: str):
|
| 29 |
+
"""从 train_rectified_noise.py 动态导入 RectifiedNoiseModule 和 SD3WithRectifiedNoise"""
|
| 30 |
+
sys.path.insert(0, project_root)
|
| 31 |
+
try:
|
| 32 |
+
import train_rectified_noise as trn
|
| 33 |
+
return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
|
| 34 |
+
except Exception as e:
|
| 35 |
+
raise ImportError(f"无法从 train_rectified_noise.py 导入类: {e}")
|
| 36 |
+
|
| 37 |
+
def create_npz_from_sample_folder(sample_dir, num_samples):
|
| 38 |
+
"""
|
| 39 |
+
从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
|
| 40 |
+
"""
|
| 41 |
+
samples = []
|
| 42 |
+
actual_files = []
|
| 43 |
+
|
| 44 |
+
# 收集所有PNG文件
|
| 45 |
+
for filename in sorted(os.listdir(sample_dir)):
|
| 46 |
+
if filename.endswith('.png'):
|
| 47 |
+
actual_files.append(filename)
|
| 48 |
+
|
| 49 |
+
# 按照数量限制处理
|
| 50 |
+
for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
|
| 51 |
+
if i < len(actual_files):
|
| 52 |
+
sample_path = os.path.join(sample_dir, actual_files[i])
|
| 53 |
+
sample_pil = Image.open(sample_path)
|
| 54 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
| 55 |
+
samples.append(sample_np)
|
| 56 |
+
else:
|
| 57 |
+
# 如果不够,创建空白图像
|
| 58 |
+
sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 59 |
+
samples.append(sample_np)
|
| 60 |
+
|
| 61 |
+
if samples:
|
| 62 |
+
samples = np.stack(samples)
|
| 63 |
+
npz_path = f"{sample_dir}.npz"
|
| 64 |
+
np.savez(npz_path, arr_0=samples)
|
| 65 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
| 66 |
+
return npz_path
|
| 67 |
+
else:
|
| 68 |
+
print("No samples found to create npz file.")
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_existing_sample_count(sample_dir):
|
| 73 |
+
"""获取已存在的样本数量和最大索引"""
|
| 74 |
+
if not os.path.exists(sample_dir):
|
| 75 |
+
return 0, -1
|
| 76 |
+
|
| 77 |
+
existing_files = []
|
| 78 |
+
for filename in os.listdir(sample_dir):
|
| 79 |
+
if filename.endswith('.png') and filename[:-4].isdigit():
|
| 80 |
+
try:
|
| 81 |
+
idx = int(filename[:-4])
|
| 82 |
+
existing_files.append(idx)
|
| 83 |
+
except ValueError:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
if not existing_files:
|
| 87 |
+
return 0, -1
|
| 88 |
+
|
| 89 |
+
existing_files.sort()
|
| 90 |
+
max_index = existing_files[-1]
|
| 91 |
+
count = len(existing_files)
|
| 92 |
+
|
| 93 |
+
# 检查是否有缺失的文件(从0到max_index应该连续)
|
| 94 |
+
expected_count = max_index + 1
|
| 95 |
+
if count < expected_count:
|
| 96 |
+
print(f"Warning: Found {count} files but expected {expected_count} (missing some indices)")
|
| 97 |
+
|
| 98 |
+
return count, max_index
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_sit_weights(rectified_module, weights_path: str, rank=0):
|
| 103 |
+
"""加载 Rectified Noise(SIT) 权重,支持 .safetensors / .bin / .pt
|
| 104 |
+
支持以下目录结构:
|
| 105 |
+
- weights_path/pytorch_sit_weights.safetensors (直接在主目录)
|
| 106 |
+
- weights_path/sit_weights/pytorch_sit_weights.safetensors (在sit_weights子目录)
|
| 107 |
+
"""
|
| 108 |
+
if os.path.isdir(weights_path):
|
| 109 |
+
# 首先尝试在主目录查找
|
| 110 |
+
search_paths = [
|
| 111 |
+
weights_path, # 主目录
|
| 112 |
+
os.path.join(weights_path, "sit_weights"), # sit_weights子目录
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
for search_dir in search_paths:
|
| 116 |
+
if not os.path.exists(search_dir):
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
# 优先寻找 safetensors
|
| 120 |
+
st_path = os.path.join(search_dir, "pytorch_sit_weights.safetensors")
|
| 121 |
+
if os.path.exists(st_path):
|
| 122 |
+
try:
|
| 123 |
+
from safetensors.torch import load_file
|
| 124 |
+
if rank == 0:
|
| 125 |
+
print(f"Loading rectified weights from: {st_path}")
|
| 126 |
+
state = load_file(st_path)
|
| 127 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 128 |
+
if rank == 0:
|
| 129 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 130 |
+
if missing_keys:
|
| 131 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 132 |
+
if unexpected_keys:
|
| 133 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 134 |
+
return True
|
| 135 |
+
except Exception as e:
|
| 136 |
+
if rank == 0:
|
| 137 |
+
print(f" Failed to load from {st_path}: {e}")
|
| 138 |
+
continue
|
| 139 |
+
|
| 140 |
+
# 其次寻找 bin/pt
|
| 141 |
+
for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
|
| 142 |
+
cand = os.path.join(search_dir, name)
|
| 143 |
+
if os.path.exists(cand):
|
| 144 |
+
try:
|
| 145 |
+
if rank == 0:
|
| 146 |
+
print(f"Loading rectified weights from: {cand}")
|
| 147 |
+
state = torch.load(cand, map_location="cpu")
|
| 148 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 149 |
+
if rank == 0:
|
| 150 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 151 |
+
if missing_keys:
|
| 152 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 153 |
+
if unexpected_keys:
|
| 154 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 155 |
+
return True
|
| 156 |
+
except Exception as e:
|
| 157 |
+
if rank == 0:
|
| 158 |
+
print(f" Failed to load from {cand}: {e}")
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
# 兜底:目录下任意 pt/bin
|
| 162 |
+
try:
|
| 163 |
+
for fn in os.listdir(search_dir):
|
| 164 |
+
if fn.endswith((".pt", ".bin")):
|
| 165 |
+
cand = os.path.join(search_dir, fn)
|
| 166 |
+
try:
|
| 167 |
+
if rank == 0:
|
| 168 |
+
print(f"Loading rectified weights from: {cand}")
|
| 169 |
+
state = torch.load(cand, map_location="cpu")
|
| 170 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 171 |
+
if rank == 0:
|
| 172 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 173 |
+
return True
|
| 174 |
+
except Exception as e:
|
| 175 |
+
if rank == 0:
|
| 176 |
+
print(f" Failed to load from {cand}: {e}")
|
| 177 |
+
continue
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
if rank == 0:
|
| 182 |
+
print(f" ❌ No rectified weights found in {weights_path} or {os.path.join(weights_path, 'sit_weights')}")
|
| 183 |
+
return False
|
| 184 |
+
else:
|
| 185 |
+
# 直接文件
|
| 186 |
+
try:
|
| 187 |
+
if rank == 0:
|
| 188 |
+
print(f"Loading rectified weights from file: {weights_path}")
|
| 189 |
+
if weights_path.endswith(".safetensors"):
|
| 190 |
+
from safetensors.torch import load_file
|
| 191 |
+
state = load_file(weights_path)
|
| 192 |
+
else:
|
| 193 |
+
state = torch.load(weights_path, map_location="cpu")
|
| 194 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 195 |
+
if rank == 0:
|
| 196 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 197 |
+
if missing_keys:
|
| 198 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 199 |
+
if unexpected_keys:
|
| 200 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 201 |
+
return True
|
| 202 |
+
except Exception as e:
|
| 203 |
+
if rank == 0:
|
| 204 |
+
print(f" ❌ Failed to load rectified weights from {weights_path}: {e}")
|
| 205 |
+
return False
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def check_lora_weights_exist(lora_path):
|
| 209 |
+
"""检查LoRA权重文件是否存在"""
|
| 210 |
+
if not lora_path:
|
| 211 |
+
return False
|
| 212 |
+
|
| 213 |
+
if os.path.isdir(lora_path):
|
| 214 |
+
# 检查目录中是否有pytorch_lora_weights.safetensors文件
|
| 215 |
+
weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
|
| 216 |
+
if os.path.exists(weight_file):
|
| 217 |
+
return True
|
| 218 |
+
# 检查是否有其他.safetensors文件
|
| 219 |
+
for file in os.listdir(lora_path):
|
| 220 |
+
if file.endswith(".safetensors") and "lora" in file.lower():
|
| 221 |
+
return True
|
| 222 |
+
return False
|
| 223 |
+
elif os.path.isfile(lora_path):
|
| 224 |
+
return lora_path.endswith(".safetensors")
|
| 225 |
+
|
| 226 |
+
return False
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0, lora_rank=64):
|
| 230 |
+
"""
|
| 231 |
+
从accelerator checkpoint目录加载LoRA权重或完整模型权重
|
| 232 |
+
如果checkpoint包含完整的模型权重(合并后的),直接加载
|
| 233 |
+
如果只包含LoRA权重,则按LoRA方式加载
|
| 234 |
+
"""
|
| 235 |
+
if rank == 0:
|
| 236 |
+
print(f"Loading weights from accelerator checkpoint: {checkpoint_path}")
|
| 237 |
+
|
| 238 |
+
try:
|
| 239 |
+
from safetensors.torch import load_file
|
| 240 |
+
model_file = os.path.join(checkpoint_path, "model.safetensors")
|
| 241 |
+
if not os.path.exists(model_file):
|
| 242 |
+
if rank == 0:
|
| 243 |
+
print(f"Model file not found: {model_file}")
|
| 244 |
+
return False
|
| 245 |
+
|
| 246 |
+
# 加载state dict
|
| 247 |
+
state_dict = load_file(model_file)
|
| 248 |
+
all_keys = list(state_dict.keys())
|
| 249 |
+
|
| 250 |
+
# 检测checkpoint类型:
|
| 251 |
+
# 1. 是否包含base_layer(PEFT格式,需要合并)
|
| 252 |
+
# 2. 是否包含完整的模型权重(合并���的,直接可用)
|
| 253 |
+
# 3. 是否只包含LoRA权重(需要添加适配器)
|
| 254 |
+
lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()]
|
| 255 |
+
base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()]
|
| 256 |
+
non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()]
|
| 257 |
+
|
| 258 |
+
if rank == 0:
|
| 259 |
+
print(f"Checkpoint analysis:")
|
| 260 |
+
print(f" Total keys: {len(all_keys)}")
|
| 261 |
+
print(f" LoRA keys: {len(lora_keys)}")
|
| 262 |
+
print(f" Base layer keys: {len(base_layer_keys)}")
|
| 263 |
+
print(f" Direct transformer weight keys (merged): {len(non_lora_transformer_keys)}")
|
| 264 |
+
|
| 265 |
+
# 如果包含base_layer,说明是PEFT格式,需要合并base_layer + lora
|
| 266 |
+
if len(base_layer_keys) > 0:
|
| 267 |
+
if rank == 0:
|
| 268 |
+
print(f"✓ Detected PEFT format (base_layer + LoRA), merging weights...")
|
| 269 |
+
|
| 270 |
+
# 合并base_layer和lora权重
|
| 271 |
+
merged_state_dict = {}
|
| 272 |
+
|
| 273 |
+
# 首先收集所有需要合并的模块
|
| 274 |
+
modules_to_merge = {}
|
| 275 |
+
# 记录所有非LoRA的transformer权重键名(用于调试)
|
| 276 |
+
non_lora_keys_found = []
|
| 277 |
+
|
| 278 |
+
for key in all_keys:
|
| 279 |
+
# 移除前缀
|
| 280 |
+
new_key = key
|
| 281 |
+
has_transformer_prefix = False
|
| 282 |
+
|
| 283 |
+
if key.startswith('base_model.model.transformer.'):
|
| 284 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 285 |
+
has_transformer_prefix = True
|
| 286 |
+
elif key.startswith('model.transformer.'):
|
| 287 |
+
new_key = key[len('model.transformer.'):]
|
| 288 |
+
has_transformer_prefix = True
|
| 289 |
+
elif key.startswith('transformer.'):
|
| 290 |
+
new_key = key[len('transformer.'):]
|
| 291 |
+
has_transformer_prefix = True
|
| 292 |
+
elif 'transformer' in key.lower():
|
| 293 |
+
# 可能没有前缀,但包含transformer(如直接是transformer_blocks.0...)
|
| 294 |
+
has_transformer_prefix = True
|
| 295 |
+
|
| 296 |
+
if not has_transformer_prefix:
|
| 297 |
+
continue
|
| 298 |
+
|
| 299 |
+
# 检查是否是base_layer或lora权重
|
| 300 |
+
if '.base_layer.weight' in new_key:
|
| 301 |
+
# 提取模块名(去掉.base_layer.weight部分)
|
| 302 |
+
module_key = new_key.replace('.base_layer.weight', '.weight')
|
| 303 |
+
if module_key not in modules_to_merge:
|
| 304 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 305 |
+
modules_to_merge[module_key]['base_weight'] = (key, state_dict[key])
|
| 306 |
+
elif '.base_layer.bias' in new_key:
|
| 307 |
+
module_key = new_key.replace('.base_layer.bias', '.bias')
|
| 308 |
+
if module_key not in modules_to_merge:
|
| 309 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 310 |
+
modules_to_merge[module_key]['base_bias'] = (key, state_dict[key])
|
| 311 |
+
elif '.lora_A.default.weight' in new_key:
|
| 312 |
+
module_key = new_key.replace('.lora_A.default.weight', '.weight')
|
| 313 |
+
if module_key not in modules_to_merge:
|
| 314 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 315 |
+
modules_to_merge[module_key]['lora_A'] = (key, state_dict[key])
|
| 316 |
+
elif '.lora_B.default.weight' in new_key:
|
| 317 |
+
module_key = new_key.replace('.lora_B.default.weight', '.weight')
|
| 318 |
+
if module_key not in modules_to_merge:
|
| 319 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 320 |
+
modules_to_merge[module_key]['lora_B'] = (key, state_dict[key])
|
| 321 |
+
elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower():
|
| 322 |
+
# 其他非LoRA权重(如pos_embed、time_text_embed、context_embedder等),直接使用
|
| 323 |
+
# 这些权重不在LoRA适配范围内,应该直接从checkpoint加载
|
| 324 |
+
merged_state_dict[new_key] = state_dict[key]
|
| 325 |
+
non_lora_keys_found.append(new_key)
|
| 326 |
+
|
| 327 |
+
if rank == 0:
|
| 328 |
+
print(f" Found {len(non_lora_keys_found)} non-LoRA transformer keys in checkpoint")
|
| 329 |
+
if non_lora_keys_found:
|
| 330 |
+
print(f" Sample non-LoRA keys: {non_lora_keys_found[:10]}")
|
| 331 |
+
|
| 332 |
+
# 合并权重:weight = base_weight + lora_B @ lora_A * (alpha / rank)
|
| 333 |
+
if rank == 0:
|
| 334 |
+
print(f" Merging {len(modules_to_merge)} modules...")
|
| 335 |
+
|
| 336 |
+
import torch
|
| 337 |
+
for module_key, weights in modules_to_merge.items():
|
| 338 |
+
# 处理权重(.weight)
|
| 339 |
+
if weights['base_weight'] is not None:
|
| 340 |
+
base_key, base_weight = weights['base_weight']
|
| 341 |
+
base_weight = base_weight.clone()
|
| 342 |
+
|
| 343 |
+
if weights['lora_A'] is not None and weights['lora_B'] is not None:
|
| 344 |
+
lora_A_key, lora_A = weights['lora_A']
|
| 345 |
+
lora_B_key, lora_B = weights['lora_B']
|
| 346 |
+
|
| 347 |
+
# 检测rank和alpha
|
| 348 |
+
# lora_A: [rank, in_features], lora_B: [out_features, rank]
|
| 349 |
+
rank_value = lora_A.shape[0]
|
| 350 |
+
alpha = rank_value # 通常alpha = rank
|
| 351 |
+
|
| 352 |
+
# 合并:weight = base + (lora_B @ lora_A) * (alpha / rank)
|
| 353 |
+
# lora_B @ lora_A 得到 [out_features, in_features]
|
| 354 |
+
lora_delta = torch.matmul(lora_B, lora_A)
|
| 355 |
+
|
| 356 |
+
if lora_delta.shape == base_weight.shape:
|
| 357 |
+
merged_weight = base_weight + lora_delta * (alpha / rank_value)
|
| 358 |
+
merged_state_dict[module_key] = merged_weight
|
| 359 |
+
if rank == 0 and len(modules_to_merge) <= 20:
|
| 360 |
+
print(f" ✓ Merged {module_key}: {base_weight.shape}")
|
| 361 |
+
else:
|
| 362 |
+
if rank == 0:
|
| 363 |
+
print(f" ⚠️ Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only")
|
| 364 |
+
merged_state_dict[module_key] = base_weight
|
| 365 |
+
else:
|
| 366 |
+
# 只有base权重,没有LoRA
|
| 367 |
+
merged_state_dict[module_key] = base_weight
|
| 368 |
+
|
| 369 |
+
# 处理bias(.bias)- bias通常不需要合并,直接使用base_bias
|
| 370 |
+
if '.bias' in module_key and weights['base_bias'] is not None:
|
| 371 |
+
bias_key, base_bias = weights['base_bias']
|
| 372 |
+
merged_state_dict[module_key] = base_bias.clone()
|
| 373 |
+
|
| 374 |
+
if rank == 0:
|
| 375 |
+
print(f" Merged {len(merged_state_dict)} weights")
|
| 376 |
+
print(f" Sample merged keys: {list(merged_state_dict.keys())[:5]}")
|
| 377 |
+
|
| 378 |
+
# 加载合并后的权重
|
| 379 |
+
try:
|
| 380 |
+
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False)
|
| 381 |
+
|
| 382 |
+
if rank == 0:
|
| 383 |
+
print(f" Loaded merged weights:")
|
| 384 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 385 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 386 |
+
if missing_keys:
|
| 387 |
+
print(f" Missing keys: {missing_keys}")
|
| 388 |
+
# 检查缺失的keys是否关键
|
| 389 |
+
critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out']
|
| 390 |
+
has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys)
|
| 391 |
+
if has_critical:
|
| 392 |
+
print(f" ⚠️ WARNING: Missing critical keys! These should be loaded from pretrained model.")
|
| 393 |
+
print(f" The missing keys will use values from the pretrained model (not fine-tuned).")
|
| 394 |
+
|
| 395 |
+
# 如果缺失的keys太多或包含关键组件,给出警告
|
| 396 |
+
if len(missing_keys) > 0:
|
| 397 |
+
# 这些缺失的keys会使用pretrained model的默认值
|
| 398 |
+
# 这是正常的,因为LoRA只适配了部分层,其他层保持原样
|
| 399 |
+
if rank == 0:
|
| 400 |
+
print(f" Note: Missing keys will use pretrained model weights (not fine-tuned)")
|
| 401 |
+
|
| 402 |
+
if rank == 0:
|
| 403 |
+
print(f" ✓ Successfully loaded merged model weights")
|
| 404 |
+
return True
|
| 405 |
+
|
| 406 |
+
except Exception as e:
|
| 407 |
+
if rank == 0:
|
| 408 |
+
print(f" ❌ Error loading merged weights: {e}")
|
| 409 |
+
import traceback
|
| 410 |
+
traceback.print_exc()
|
| 411 |
+
return False
|
| 412 |
+
|
| 413 |
+
# 如果包含非LoRA的transformer权重(且没有base_layer),说明是合并后的完整模型
|
| 414 |
+
elif len(non_lora_transformer_keys) > 0:
|
| 415 |
+
if rank == 0:
|
| 416 |
+
print(f"✓ Detected merged model weights (contains full transformer weights)")
|
| 417 |
+
print(f" Loading full model weights directly...")
|
| 418 |
+
|
| 419 |
+
# 提取transformer相关的权重(包括LoRA和基础权重)
|
| 420 |
+
transformer_state_dict = {}
|
| 421 |
+
for key, value in state_dict.items():
|
| 422 |
+
# 移除可能的accelerator包装前缀
|
| 423 |
+
new_key = key
|
| 424 |
+
if key.startswith('base_model.model.transformer.'):
|
| 425 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 426 |
+
elif key.startswith('model.transformer.'):
|
| 427 |
+
new_key = key[len('model.transformer.'):]
|
| 428 |
+
elif key.startswith('transformer.'):
|
| 429 |
+
new_key = key[len('transformer.'):]
|
| 430 |
+
|
| 431 |
+
# 只保留transformer相关的权重(包括所有transformer子模块)
|
| 432 |
+
# 检查是否是transformer的权重(不包含text_encoder等)
|
| 433 |
+
if (new_key.startswith('transformer_blocks') or
|
| 434 |
+
new_key.startswith('pos_embed') or
|
| 435 |
+
new_key.startswith('time_text_embed') or
|
| 436 |
+
'lora' in new_key.lower()): # 也包含LoRA权重(如果存在)
|
| 437 |
+
transformer_state_dict[new_key] = value
|
| 438 |
+
|
| 439 |
+
if rank == 0:
|
| 440 |
+
print(f" Extracted {len(transformer_state_dict)} transformer weight keys")
|
| 441 |
+
print(f" Sample keys: {list(transformer_state_dict.keys())[:5]}")
|
| 442 |
+
|
| 443 |
+
# 直接加载到transformer(不使用LoRA适配器)
|
| 444 |
+
try:
|
| 445 |
+
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False)
|
| 446 |
+
|
| 447 |
+
if rank == 0:
|
| 448 |
+
print(f" Loaded full model weights:")
|
| 449 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 450 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 451 |
+
if missing_keys:
|
| 452 |
+
print(f" Sample missing keys: {missing_keys[:5]}")
|
| 453 |
+
if unexpected_keys:
|
| 454 |
+
print(f" Sample unexpected keys: {unexpected_keys[:5]}")
|
| 455 |
+
|
| 456 |
+
# 如果missing keys太多,可能有问题
|
| 457 |
+
if len(missing_keys) > len(transformer_state_dict) * 0.5:
|
| 458 |
+
if rank == 0:
|
| 459 |
+
print(f" ⚠️ WARNING: Too many missing keys, weights may not be fully loaded")
|
| 460 |
+
return False
|
| 461 |
+
|
| 462 |
+
if rank == 0:
|
| 463 |
+
print(f" ✓ Successfully loaded merged model weights")
|
| 464 |
+
return True
|
| 465 |
+
|
| 466 |
+
except Exception as e:
|
| 467 |
+
if rank == 0:
|
| 468 |
+
print(f" ❌ Error loading full model weights: {e}")
|
| 469 |
+
import traceback
|
| 470 |
+
traceback.print_exc()
|
| 471 |
+
return False
|
| 472 |
+
|
| 473 |
+
# 如果只包含LoRA权重,按原来的方式加载
|
| 474 |
+
if rank == 0:
|
| 475 |
+
print(f"Detected LoRA-only weights, loading as LoRA adapter...")
|
| 476 |
+
|
| 477 |
+
# 首先尝试从checkpoint中检测实际的rank
|
| 478 |
+
detected_rank = None
|
| 479 |
+
for key, value in state_dict.items():
|
| 480 |
+
if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
|
| 481 |
+
# lora_A的形状是 [rank, hidden_size]
|
| 482 |
+
detected_rank = value.shape[0]
|
| 483 |
+
if rank == 0:
|
| 484 |
+
print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
# 如果检测到rank,使用检测到的rank;否则使用传入的rank
|
| 488 |
+
actual_rank = detected_rank if detected_rank is not None else lora_rank
|
| 489 |
+
if detected_rank is not None and detected_rank != lora_rank:
|
| 490 |
+
if rank == 0:
|
| 491 |
+
print(f"⚠️ Warning: Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank")
|
| 492 |
+
|
| 493 |
+
# 检查适配器是否已存在,如果存在则先卸载
|
| 494 |
+
# SD3Transformer2DModel没有delete_adapter方法,需要使用unload_lora_weights
|
| 495 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 496 |
+
if "default" in pipeline.transformer.peft_config:
|
| 497 |
+
if rank == 0:
|
| 498 |
+
print("Removing existing 'default' adapter before adding new one...")
|
| 499 |
+
try:
|
| 500 |
+
# 使用pipeline的unload_lora_weights方法
|
| 501 |
+
pipeline.unload_lora_weights()
|
| 502 |
+
if rank == 0:
|
| 503 |
+
print("Successfully unloaded existing LoRA adapter")
|
| 504 |
+
except Exception as e:
|
| 505 |
+
if rank == 0:
|
| 506 |
+
print(f"❌ ERROR: Could not unload existing adapter: {e}")
|
| 507 |
+
print("Cannot proceed without cleaning up adapter")
|
| 508 |
+
return False
|
| 509 |
+
|
| 510 |
+
# 先配置LoRA适配器(必须在加载之前配置)
|
| 511 |
+
# 使用检测到的或传入的rank
|
| 512 |
+
transformer_lora_config = LoraConfig(
|
| 513 |
+
r=actual_rank,
|
| 514 |
+
lora_alpha=actual_rank,
|
| 515 |
+
init_lora_weights="gaussian",
|
| 516 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
# 为transformer添加LoRA适配器
|
| 520 |
+
pipeline.transformer.add_adapter(transformer_lora_config)
|
| 521 |
+
|
| 522 |
+
if rank == 0:
|
| 523 |
+
print(f"LoRA adapter configured with rank={actual_rank}")
|
| 524 |
+
|
| 525 |
+
# 继续处理LoRA权重加载(state_dict已经在上面加载了)
|
| 526 |
+
|
| 527 |
+
# 提取LoRA权重 - accelerator保存的格式
|
| 528 |
+
# 从accelerator checkpoint的model.safetensors中,键名格式可能是:
|
| 529 |
+
# - transformer_blocks.X.attn.to_q.lora_A.default.weight (PEFT格式,直接可用)
|
| 530 |
+
# - 或者包含其他前缀
|
| 531 |
+
lora_state_dict = {}
|
| 532 |
+
for key, value in state_dict.items():
|
| 533 |
+
if 'lora' in key.lower() and 'transformer' in key.lower():
|
| 534 |
+
# 检查键名格式
|
| 535 |
+
new_key = key
|
| 536 |
+
|
| 537 |
+
# 移除可能的accelerator包装前缀
|
| 538 |
+
# accelerator可能保存为: model.transformer.transformer_blocks...
|
| 539 |
+
# 或者: base_model.model.transformer.transformer_blocks...
|
| 540 |
+
if key.startswith('base_model.model.transformer.'):
|
| 541 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 542 |
+
elif key.startswith('model.transformer.'):
|
| 543 |
+
new_key = key[len('model.transformer.'):]
|
| 544 |
+
elif key.startswith('transformer.'):
|
| 545 |
+
# 如果已经是transformer_blocks开头,不需要移除transformer.前缀
|
| 546 |
+
# 因为transformer_blocks是transformer的子模块
|
| 547 |
+
if not key[len('transformer.'):].startswith('transformer_blocks'):
|
| 548 |
+
new_key = key[len('transformer.'):]
|
| 549 |
+
else:
|
| 550 |
+
new_key = key[len('transformer.'):]
|
| 551 |
+
|
| 552 |
+
# 只保留transformer相关的LoRA权重
|
| 553 |
+
if 'transformer_blocks' in new_key or 'transformer' in new_key:
|
| 554 |
+
lora_state_dict[new_key] = value
|
| 555 |
+
|
| 556 |
+
if not lora_state_dict:
|
| 557 |
+
if rank == 0:
|
| 558 |
+
print("No LoRA weights found in checkpoint")
|
| 559 |
+
# 打印所有键名用于调试
|
| 560 |
+
all_keys = list(state_dict.keys())
|
| 561 |
+
print(f"Total keys: {len(all_keys)}")
|
| 562 |
+
print(f"First 20 keys: {all_keys[:20]}")
|
| 563 |
+
# 查找包含lora的键
|
| 564 |
+
lora_related = [k for k in all_keys if 'lora' in k.lower()]
|
| 565 |
+
if lora_related:
|
| 566 |
+
print(f"Keys containing 'lora': {lora_related[:10]}")
|
| 567 |
+
return False
|
| 568 |
+
|
| 569 |
+
if rank == 0:
|
| 570 |
+
print(f"Found {len(lora_state_dict)} LoRA weight keys")
|
| 571 |
+
sample_keys = list(lora_state_dict.keys())[:5]
|
| 572 |
+
print(f"Sample LoRA keys: {sample_keys}")
|
| 573 |
+
|
| 574 |
+
# 加载LoRA权重到transformer
|
| 575 |
+
# 注意:从checkpoint提取的键名格式已经是PEFT格式(如:transformer_blocks.0.attn.to_q.lora_A.default.weight)
|
| 576 |
+
# 不需要使用convert_unet_state_dict_to_peft转换,直接使用即可
|
| 577 |
+
try:
|
| 578 |
+
# 检查键名格式
|
| 579 |
+
sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
|
| 580 |
+
|
| 581 |
+
if rank == 0:
|
| 582 |
+
print(f"Original key format: {sample_key}")
|
| 583 |
+
|
| 584 |
+
# 关键问题:set_peft_model_state_dict期望的键名格式
|
| 585 |
+
# 从back/train_dreambooth_lora.py看,需要移除.default后缀
|
| 586 |
+
# 格式应该是:transformer_blocks.X.attn.to_q.lora_A.weight(没有.default)
|
| 587 |
+
# 但accelerator保存的格式是:transformer_blocks.X.attn.to_q.lora_A.default.weight(有.default)
|
| 588 |
+
|
| 589 |
+
# 检查键名格式
|
| 590 |
+
sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
|
| 591 |
+
has_default_suffix = '.default.weight' in sample_key or '.default.bias' in sample_key
|
| 592 |
+
|
| 593 |
+
if rank == 0:
|
| 594 |
+
print(f"Sample key: {sample_key}")
|
| 595 |
+
print(f"Has .default suffix: {has_default_suffix}")
|
| 596 |
+
|
| 597 |
+
# 如果键名包含.default.weight或.default.bias,需要移除.default部分
|
| 598 |
+
# 因为set_peft_model_state_dict期望的格式是:lora_A.weight,而不是lora_A.default.weight
|
| 599 |
+
converted_dict = {}
|
| 600 |
+
for key, value in lora_state_dict.items():
|
| 601 |
+
# 移除.default后缀(如果存在)
|
| 602 |
+
# transformer_blocks.0.attn.to_q.lora_A.default.weight -> transformer_blocks.0.attn.to_q.lora_A.weight
|
| 603 |
+
new_key = key
|
| 604 |
+
if '.default.weight' in new_key:
|
| 605 |
+
new_key = new_key.replace('.default.weight', '.weight')
|
| 606 |
+
elif '.default.bias' in new_key:
|
| 607 |
+
new_key = new_key.replace('.default.bias', '.bias')
|
| 608 |
+
elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')):
|
| 609 |
+
# 处理其他可能的.default位置
|
| 610 |
+
new_key = new_key.replace('.default', '')
|
| 611 |
+
|
| 612 |
+
converted_dict[new_key] = value
|
| 613 |
+
|
| 614 |
+
if rank == 0:
|
| 615 |
+
print(f"Converted {len(converted_dict)} keys (removed .default suffix if present)")
|
| 616 |
+
print(f"Sample converted keys: {list(converted_dict.keys())[:5]}")
|
| 617 |
+
|
| 618 |
+
# 调用set_peft_model_state_dict并检查返回值
|
| 619 |
+
incompatible_keys = set_peft_model_state_dict(
|
| 620 |
+
pipeline.transformer,
|
| 621 |
+
converted_dict,
|
| 622 |
+
adapter_name="default"
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# 检查加载结果
|
| 626 |
+
if incompatible_keys is not None:
|
| 627 |
+
missing_keys = getattr(incompatible_keys, "missing_keys", [])
|
| 628 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", [])
|
| 629 |
+
|
| 630 |
+
if rank == 0:
|
| 631 |
+
print(f"LoRA loading result:")
|
| 632 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 633 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 634 |
+
|
| 635 |
+
if len(missing_keys) > 100:
|
| 636 |
+
print(f" ⚠️ WARNING: Too many missing keys ({len(missing_keys)}), LoRA may not be fully loaded!")
|
| 637 |
+
print(f" Sample missing keys: {missing_keys[:10]}")
|
| 638 |
+
elif missing_keys:
|
| 639 |
+
print(f" Sample missing keys: {missing_keys[:10]}")
|
| 640 |
+
|
| 641 |
+
if unexpected_keys:
|
| 642 |
+
print(f" Unexpected keys: {unexpected_keys[:10]}")
|
| 643 |
+
|
| 644 |
+
# 如果missing keys太多,说明加载失败
|
| 645 |
+
if len(missing_keys) > len(converted_dict) * 0.5: # 超过50%的键缺失
|
| 646 |
+
if rank == 0:
|
| 647 |
+
print("❌ ERROR: Too many missing keys, LoRA weights not loaded correctly!")
|
| 648 |
+
return False
|
| 649 |
+
else:
|
| 650 |
+
if rank == 0:
|
| 651 |
+
print("✓ LoRA weights loaded (no incompatible keys reported)")
|
| 652 |
+
|
| 653 |
+
except RuntimeError as e:
|
| 654 |
+
# 检查是否是size mismatch错误
|
| 655 |
+
error_str = str(e)
|
| 656 |
+
if "size mismatch" in error_str:
|
| 657 |
+
if rank == 0:
|
| 658 |
+
print(f"❌ Size mismatch error: The checkpoint rank doesn't match the adapter rank")
|
| 659 |
+
print(f" This usually means the checkpoint was trained with a different rank")
|
| 660 |
+
# 尝试从错误信息中提取期望的rank
|
| 661 |
+
import re
|
| 662 |
+
# 错误信息格式: "copying a param with shape torch.Size([32, 1536]) from checkpoint"
|
| 663 |
+
match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str)
|
| 664 |
+
if match:
|
| 665 |
+
checkpoint_rank = int(match.group(1))
|
| 666 |
+
if rank == 0:
|
| 667 |
+
print(f" Detected checkpoint rank: {checkpoint_rank}")
|
| 668 |
+
print(f" Adapter was configured with rank: {actual_rank}")
|
| 669 |
+
if checkpoint_rank != actual_rank:
|
| 670 |
+
print(f" ⚠️ Mismatch! Need to recreate adapter with rank={checkpoint_rank}")
|
| 671 |
+
else:
|
| 672 |
+
if rank == 0:
|
| 673 |
+
print(f"❌ Error setting LoRA state dict: {e}")
|
| 674 |
+
import traceback
|
| 675 |
+
traceback.print_exc()
|
| 676 |
+
# 清理适配器以便下次尝试
|
| 677 |
+
try:
|
| 678 |
+
pipeline.unload_lora_weights()
|
| 679 |
+
except:
|
| 680 |
+
pass
|
| 681 |
+
return False
|
| 682 |
+
except Exception as e:
|
| 683 |
+
if rank == 0:
|
| 684 |
+
print(f"❌ Error setting LoRA state dict: {e}")
|
| 685 |
+
import traceback
|
| 686 |
+
traceback.print_exc()
|
| 687 |
+
# 清理适配器以便下次尝试
|
| 688 |
+
try:
|
| 689 |
+
pipeline.unload_lora_weights()
|
| 690 |
+
except:
|
| 691 |
+
pass
|
| 692 |
+
return False
|
| 693 |
+
|
| 694 |
+
# 启用LoRA适配器
|
| 695 |
+
pipeline.transformer.set_adapter("default")
|
| 696 |
+
|
| 697 |
+
# 验证LoRA是否已加载和应用
|
| 698 |
+
if hasattr(pipeline.transformer, 'peft_config'):
|
| 699 |
+
adapters = list(pipeline.transformer.peft_config.keys())
|
| 700 |
+
if rank == 0:
|
| 701 |
+
print(f"LoRA adapters configured: {adapters}")
|
| 702 |
+
# 检查适配器是否启用
|
| 703 |
+
if hasattr(pipeline.transformer, 'active_adapters'):
|
| 704 |
+
# active_adapters 是一个方法,需要调用
|
| 705 |
+
try:
|
| 706 |
+
if callable(pipeline.transformer.active_adapters):
|
| 707 |
+
active = pipeline.transformer.active_adapters()
|
| 708 |
+
else:
|
| 709 |
+
active = pipeline.transformer.active_adapters
|
| 710 |
+
if rank == 0:
|
| 711 |
+
print(f"Active adapters: {active}")
|
| 712 |
+
except:
|
| 713 |
+
if rank == 0:
|
| 714 |
+
print("Could not get active adapters, but LoRA is configured")
|
| 715 |
+
|
| 716 |
+
# 验证LoRA权���是否真的被应用
|
| 717 |
+
# 检查LoRA层的权重是否非零
|
| 718 |
+
lora_layers_found = 0
|
| 719 |
+
nonzero_lora_layers = 0
|
| 720 |
+
total_lora_weight_sum = 0.0
|
| 721 |
+
|
| 722 |
+
for name, module in pipeline.transformer.named_modules():
|
| 723 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 724 |
+
lora_layers_found += 1
|
| 725 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
| 726 |
+
weight_sum = module.weight.abs().sum().item()
|
| 727 |
+
total_lora_weight_sum += weight_sum
|
| 728 |
+
if weight_sum > 1e-6: # 非零阈值
|
| 729 |
+
nonzero_lora_layers += 1
|
| 730 |
+
if rank == 0 and nonzero_lora_layers <= 3: # 只打印前3个
|
| 731 |
+
print(f"✓ Found non-zero LoRA weight in: {name}, sum={weight_sum:.6f}")
|
| 732 |
+
|
| 733 |
+
if rank == 0:
|
| 734 |
+
print(f"LoRA verification:")
|
| 735 |
+
print(f" Total LoRA layers found: {lora_layers_found}")
|
| 736 |
+
print(f" Non-zero LoRA layers: {nonzero_lora_layers}")
|
| 737 |
+
print(f" Total LoRA weight sum: {total_lora_weight_sum:.6f}")
|
| 738 |
+
|
| 739 |
+
if lora_layers_found == 0:
|
| 740 |
+
print("❌ ERROR: No LoRA layers found in transformer!")
|
| 741 |
+
return False
|
| 742 |
+
elif nonzero_lora_layers == 0:
|
| 743 |
+
print("❌ ERROR: All LoRA weights are zero, LoRA not loaded correctly!")
|
| 744 |
+
return False
|
| 745 |
+
elif nonzero_lora_layers < lora_layers_found * 0.5:
|
| 746 |
+
print(f"⚠️ WARNING: Only {nonzero_lora_layers}/{lora_layers_found} LoRA layers have non-zero weights!")
|
| 747 |
+
print("⚠️ LoRA may not be fully applied!")
|
| 748 |
+
else:
|
| 749 |
+
print(f"✓ LoRA weights verified: {nonzero_lora_layers}/{lora_layers_found} layers have non-zero weights")
|
| 750 |
+
|
| 751 |
+
if nonzero_lora_layers == 0:
|
| 752 |
+
return False
|
| 753 |
+
|
| 754 |
+
if rank == 0:
|
| 755 |
+
print("✓ Successfully loaded and verified LoRA weights from checkpoint")
|
| 756 |
+
|
| 757 |
+
return True
|
| 758 |
+
|
| 759 |
+
except Exception as e:
|
| 760 |
+
if rank == 0:
|
| 761 |
+
print(f"Error loading LoRA from checkpoint: {e}")
|
| 762 |
+
import traceback
|
| 763 |
+
traceback.print_exc()
|
| 764 |
+
return False
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def load_captions_from_jsonl(jsonl_path):
|
| 768 |
+
captions = []
|
| 769 |
+
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
| 770 |
+
for line in f:
|
| 771 |
+
line = line.strip()
|
| 772 |
+
if not line:
|
| 773 |
+
continue
|
| 774 |
+
try:
|
| 775 |
+
data = json.loads(line)
|
| 776 |
+
cap = None
|
| 777 |
+
for field in ['caption', 'text', 'prompt', 'description']:
|
| 778 |
+
if field in data and isinstance(data[field], str):
|
| 779 |
+
cap = data[field].strip()
|
| 780 |
+
break
|
| 781 |
+
if cap:
|
| 782 |
+
captions.append(cap)
|
| 783 |
+
except Exception:
|
| 784 |
+
continue
|
| 785 |
+
return captions if captions else ["a beautiful high quality image"]
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def main(args):
|
| 789 |
+
assert torch.cuda.is_available(), "需要GPU运行"
|
| 790 |
+
dist.init_process_group("nccl")
|
| 791 |
+
rank = dist.get_rank()
|
| 792 |
+
world_size = dist.get_world_size()
|
| 793 |
+
device = rank % torch.cuda.device_count()
|
| 794 |
+
torch.cuda.set_device(device)
|
| 795 |
+
seed = args.global_seed * world_size + rank
|
| 796 |
+
torch.manual_seed(seed)
|
| 797 |
+
|
| 798 |
+
print(f"[rank{rank}] DDP initialized, device={device}, seed={seed}, world_size={world_size}")
|
| 799 |
+
|
| 800 |
+
# 调试:打印接收到的参数
|
| 801 |
+
if rank == 0:
|
| 802 |
+
print("=" * 80)
|
| 803 |
+
print("参数检查:")
|
| 804 |
+
print(f" lora_path: {args.lora_path}")
|
| 805 |
+
print(f" rectified_weights: {args.rectified_weights}")
|
| 806 |
+
print(f" lora_path is None: {args.lora_path is None}")
|
| 807 |
+
print(f" lora_path is empty: {args.lora_path == '' if args.lora_path else 'N/A'}")
|
| 808 |
+
print(f" rectified_weights is None: {args.rectified_weights is None}")
|
| 809 |
+
print(f" rectified_weights is empty: {args.rectified_weights == '' if args.rectified_weights else 'N/A'}")
|
| 810 |
+
print("=" * 80)
|
| 811 |
+
|
| 812 |
+
# 导入训练脚本中的类
|
| 813 |
+
RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
|
| 814 |
+
|
| 815 |
+
# 加载 pipeline
|
| 816 |
+
dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
|
| 817 |
+
if rank == 0:
|
| 818 |
+
print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path} (dtype={dtype})")
|
| 819 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 820 |
+
args.pretrained_model_name_or_path,
|
| 821 |
+
revision=args.revision,
|
| 822 |
+
variant=args.variant,
|
| 823 |
+
torch_dtype=dtype,
|
| 824 |
+
).to(device)
|
| 825 |
+
|
| 826 |
+
print(f"[rank{rank}] Pipeline loaded and moved to device {device}")
|
| 827 |
+
|
| 828 |
+
# 加载 LoRA(可选)
|
| 829 |
+
lora_loaded = False
|
| 830 |
+
if args.lora_path:
|
| 831 |
+
if rank == 0:
|
| 832 |
+
print(f"Attempting to load LoRA weights from: {args.lora_path}")
|
| 833 |
+
print(f"LoRA path exists: {os.path.exists(args.lora_path) if args.lora_path else False}")
|
| 834 |
+
|
| 835 |
+
# 首��检查是否是标准的LoRA权重文件/目录
|
| 836 |
+
if check_lora_weights_exist(args.lora_path):
|
| 837 |
+
if rank == 0:
|
| 838 |
+
print("Found standard LoRA weights, loading...")
|
| 839 |
+
try:
|
| 840 |
+
# 检查加载前的transformer参数(用于验证)
|
| 841 |
+
if rank == 0:
|
| 842 |
+
sample_param_before = next(iter(pipeline.transformer.parameters())).clone()
|
| 843 |
+
print(f"Sample transformer param before LoRA (first 5 values): {sample_param_before.flatten()[:5]}")
|
| 844 |
+
|
| 845 |
+
pipeline.load_lora_weights(args.lora_path)
|
| 846 |
+
lora_loaded = True
|
| 847 |
+
|
| 848 |
+
# 验证LoRA是否真的被加载
|
| 849 |
+
if rank == 0:
|
| 850 |
+
sample_param_after = next(iter(pipeline.transformer.parameters())).clone()
|
| 851 |
+
param_diff = (sample_param_after - sample_param_before).abs().max().item()
|
| 852 |
+
print(f"Sample transformer param after LoRA (first 5 values): {sample_param_after.flatten()[:5]}")
|
| 853 |
+
print(f"Max parameter change after LoRA loading: {param_diff}")
|
| 854 |
+
if param_diff < 1e-6:
|
| 855 |
+
print("⚠️ WARNING: LoRA weights may not have been applied (parameter change is very small)")
|
| 856 |
+
else:
|
| 857 |
+
print("✓ LoRA weights appear to have been applied")
|
| 858 |
+
|
| 859 |
+
# 检查是否有peft_config
|
| 860 |
+
if hasattr(pipeline.transformer, 'peft_config'):
|
| 861 |
+
print(f"✓ PEFT config found: {list(pipeline.transformer.peft_config.keys())}")
|
| 862 |
+
else:
|
| 863 |
+
print("⚠️ WARNING: No peft_config found after loading LoRA")
|
| 864 |
+
|
| 865 |
+
if rank == 0:
|
| 866 |
+
print("LoRA loaded successfully from standard format.")
|
| 867 |
+
except Exception as e:
|
| 868 |
+
if rank == 0:
|
| 869 |
+
print(f"Failed to load LoRA from standard format: {e}")
|
| 870 |
+
import traceback
|
| 871 |
+
traceback.print_exc()
|
| 872 |
+
|
| 873 |
+
# 如果不是标准格式,尝试从accelerator checkpoint加载
|
| 874 |
+
if not lora_loaded and os.path.isdir(args.lora_path):
|
| 875 |
+
if rank == 0:
|
| 876 |
+
print("Standard LoRA weights not found, trying accelerator checkpoint format...")
|
| 877 |
+
|
| 878 |
+
# 首先尝试从checkpoint的model.safetensors中检测实际的rank
|
| 879 |
+
# 通过检查LoRA权重的形状来推断rank
|
| 880 |
+
detected_rank = None
|
| 881 |
+
try:
|
| 882 |
+
from safetensors.torch import load_file
|
| 883 |
+
model_file = os.path.join(args.lora_path, "model.safetensors")
|
| 884 |
+
if os.path.exists(model_file):
|
| 885 |
+
state_dict = load_file(model_file)
|
| 886 |
+
# 查找一个LoRA权重来确定rank
|
| 887 |
+
for key, value in state_dict.items():
|
| 888 |
+
if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
|
| 889 |
+
# lora_A的形状是 [rank, hidden_size]
|
| 890 |
+
detected_rank = value.shape[0]
|
| 891 |
+
if rank == 0:
|
| 892 |
+
print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
|
| 893 |
+
break
|
| 894 |
+
except Exception as e:
|
| 895 |
+
if rank == 0:
|
| 896 |
+
print(f"Could not detect rank from checkpoint: {e}")
|
| 897 |
+
|
| 898 |
+
# 构建rank尝试列表
|
| 899 |
+
# 如果检测到rank,优先使用检测到的rank,只尝试一次
|
| 900 |
+
# 如果未检测到,尝试常见的rank值
|
| 901 |
+
if detected_rank is not None:
|
| 902 |
+
rank_list = [detected_rank]
|
| 903 |
+
if rank == 0:
|
| 904 |
+
print(f"Using detected rank: {detected_rank}")
|
| 905 |
+
else:
|
| 906 |
+
# 如果检测失败,尝试常见的rank值(按用户指定的rank优先)
|
| 907 |
+
rank_list = []
|
| 908 |
+
# 如果用户指定了rank(从args.lora_rank),优先尝试
|
| 909 |
+
if hasattr(args, 'lora_rank') and args.lora_rank:
|
| 910 |
+
rank_list.append(args.lora_rank)
|
| 911 |
+
# 添加其他常见的rank值
|
| 912 |
+
for r in [32, 64, 16, 128]:
|
| 913 |
+
if r not in rank_list:
|
| 914 |
+
rank_list.append(r)
|
| 915 |
+
if rank == 0:
|
| 916 |
+
print(f"Rank detection failed, will try ranks in order: {rank_list}")
|
| 917 |
+
|
| 918 |
+
# 尝试不同的rank值
|
| 919 |
+
for lora_rank in rank_list:
|
| 920 |
+
# 在尝试新的rank之前,先清理已存在的适配器
|
| 921 |
+
# 重要:每次尝试前都要清理,否则适配器会保留之前的rank配置
|
| 922 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 923 |
+
if "default" in pipeline.transformer.peft_config:
|
| 924 |
+
try:
|
| 925 |
+
# 使用pipeline的unload_lora_weights方法
|
| 926 |
+
pipeline.unload_lora_weights()
|
| 927 |
+
if rank == 0:
|
| 928 |
+
print(f"Cleaned up existing adapter before trying rank={lora_rank}")
|
| 929 |
+
except Exception as e:
|
| 930 |
+
if rank == 0:
|
| 931 |
+
print(f"Warning: Could not unload adapter: {e}")
|
| 932 |
+
# 如果卸载失败,需要重新创建pipeline
|
| 933 |
+
if rank == 0:
|
| 934 |
+
print("⚠️ WARNING: Cannot unload adapter, will recreate pipeline...")
|
| 935 |
+
# 重新加载pipeline(最后手段)
|
| 936 |
+
try:
|
| 937 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 938 |
+
args.pretrained_model_name_or_path,
|
| 939 |
+
revision=args.revision,
|
| 940 |
+
variant=args.variant,
|
| 941 |
+
torch_dtype=dtype,
|
| 942 |
+
).to(device)
|
| 943 |
+
if rank == 0:
|
| 944 |
+
print("Pipeline recreated to clear adapter state")
|
| 945 |
+
except Exception as e2:
|
| 946 |
+
if rank == 0:
|
| 947 |
+
print(f"Failed to recreate pipeline: {e2}")
|
| 948 |
+
|
| 949 |
+
if rank == 0:
|
| 950 |
+
print(f"Trying to load with LoRA rank={lora_rank}...")
|
| 951 |
+
lora_loaded = load_lora_from_checkpoint(pipeline, args.lora_path, rank=rank, lora_rank=lora_rank)
|
| 952 |
+
if lora_loaded:
|
| 953 |
+
if rank == 0:
|
| 954 |
+
print(f"✓ Successfully loaded LoRA with rank={lora_rank}")
|
| 955 |
+
break
|
| 956 |
+
elif rank == 0:
|
| 957 |
+
print(f"✗ Failed to load with rank={lora_rank}, trying next rank...")
|
| 958 |
+
|
| 959 |
+
# 如果checkpoint目录加载失败,尝试从输出目录的根目录加载标准LoRA权重
|
| 960 |
+
if not lora_loaded and os.path.isdir(args.lora_path):
|
| 961 |
+
# 检查输出目录的根目录(checkpoint的父目录)
|
| 962 |
+
output_dir = os.path.dirname(args.lora_path.rstrip('/'))
|
| 963 |
+
if output_dir and os.path.exists(output_dir):
|
| 964 |
+
if rank == 0:
|
| 965 |
+
print(f"Trying to load standard LoRA weights from output directory: {output_dir}")
|
| 966 |
+
if check_lora_weights_exist(output_dir):
|
| 967 |
+
try:
|
| 968 |
+
pipeline.load_lora_weights(output_dir)
|
| 969 |
+
lora_loaded = True
|
| 970 |
+
if rank == 0:
|
| 971 |
+
print("LoRA loaded successfully from output directory.")
|
| 972 |
+
except Exception as e:
|
| 973 |
+
if rank == 0:
|
| 974 |
+
print(f"Failed to load LoRA from output directory: {e}")
|
| 975 |
+
|
| 976 |
+
if not lora_loaded:
|
| 977 |
+
if rank == 0:
|
| 978 |
+
print(f"⚠️ WARNING: Failed to load LoRA weights from {args.lora_path}, using baseline model")
|
| 979 |
+
else:
|
| 980 |
+
# 最终验证LoRA是否真的被启用
|
| 981 |
+
if rank == 0:
|
| 982 |
+
print("=" * 80)
|
| 983 |
+
print("LoRA 加载验证:")
|
| 984 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 985 |
+
print(f" ✓ PEFT config exists: {list(pipeline.transformer.peft_config.keys())}")
|
| 986 |
+
# 检查LoRA层的权重
|
| 987 |
+
lora_layers_found = 0
|
| 988 |
+
for name, module in pipeline.transformer.named_modules():
|
| 989 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 990 |
+
lora_layers_found += 1
|
| 991 |
+
if lora_layers_found <= 3: # 只打印前3个
|
| 992 |
+
if hasattr(module, 'weight'):
|
| 993 |
+
weight_sum = module.weight.abs().sum().item() if module.weight is not None else 0
|
| 994 |
+
print(f" ✓ Found LoRA layer: {name}, weight_sum={weight_sum:.6f}")
|
| 995 |
+
print(f" ✓ Total LoRA layers found: {lora_layers_found}")
|
| 996 |
+
if lora_layers_found == 0:
|
| 997 |
+
print(" ⚠️ WARNING: No LoRA layers found in transformer!")
|
| 998 |
+
else:
|
| 999 |
+
print(" ⚠️ WARNING: No PEFT config found - LoRA may not be active!")
|
| 1000 |
+
print("=" * 80)
|
| 1001 |
+
|
| 1002 |
+
# 构建 RectifiedNoiseModule 并加载权重(仅在提供了 rectified_weights 时)
|
| 1003 |
+
# 安全地检查 rectified_weights 是否有效
|
| 1004 |
+
use_rectified = False
|
| 1005 |
+
rectified_weights_path = None
|
| 1006 |
+
if args.rectified_weights:
|
| 1007 |
+
rectified_weights_str = str(args.rectified_weights).strip()
|
| 1008 |
+
if rectified_weights_str:
|
| 1009 |
+
use_rectified = True
|
| 1010 |
+
rectified_weights_path = rectified_weights_str
|
| 1011 |
+
|
| 1012 |
+
if rank == 0:
|
| 1013 |
+
print(f"use_rectified: {use_rectified}, rectified_weights_path: {rectified_weights_path}")
|
| 1014 |
+
|
| 1015 |
+
if use_rectified:
|
| 1016 |
+
if rank == 0:
|
| 1017 |
+
print(f"Using Rectified Noise module with weights from: {rectified_weights_path}")
|
| 1018 |
+
print(f"[rank{rank}] RectifiedNoiseModule configuration: num_sit_layers={args.num_sit_layers}")
|
| 1019 |
+
|
| 1020 |
+
# 从 transformer 配置推断必要尺寸
|
| 1021 |
+
tfm = pipeline.transformer
|
| 1022 |
+
if hasattr(tfm.config, 'joint_attention_dim') and tfm.config.joint_attention_dim is not None:
|
| 1023 |
+
sit_hidden_size = tfm.config.joint_attention_dim
|
| 1024 |
+
elif hasattr(tfm.config, 'inner_dim') and tfm.config.inner_dim is not None:
|
| 1025 |
+
sit_hidden_size = tfm.config.inner_dim
|
| 1026 |
+
elif hasattr(tfm.config, 'hidden_size') and tfm.config.hidden_size is not None:
|
| 1027 |
+
sit_hidden_size = tfm.config.hidden_size
|
| 1028 |
+
else:
|
| 1029 |
+
sit_hidden_size = 4096
|
| 1030 |
+
|
| 1031 |
+
transformer_hidden_size = getattr(tfm.config, 'hidden_size', 1536)
|
| 1032 |
+
num_attention_heads = getattr(tfm.config, 'num_attention_heads', 32)
|
| 1033 |
+
input_dim = getattr(tfm.config, 'in_channels', 16)
|
| 1034 |
+
|
| 1035 |
+
rectified_module = RectifiedNoiseModule(
|
| 1036 |
+
hidden_size=sit_hidden_size,
|
| 1037 |
+
num_sit_layers=args.num_sit_layers,
|
| 1038 |
+
num_attention_heads=num_attention_heads,
|
| 1039 |
+
input_dim=input_dim,
|
| 1040 |
+
transformer_hidden_size=transformer_hidden_size,
|
| 1041 |
+
)
|
| 1042 |
+
# 加载 SIT 权重
|
| 1043 |
+
ok = load_sit_weights(rectified_module, rectified_weights_path, rank=rank)
|
| 1044 |
+
if rank == 0:
|
| 1045 |
+
if not ok:
|
| 1046 |
+
print("⚠️ Warning: Failed to load rectified weights, will use baseline model without rectified noise")
|
| 1047 |
+
else:
|
| 1048 |
+
print("✓ Successfully loaded rectified noise weights")
|
| 1049 |
+
|
| 1050 |
+
# 组装 SD3WithRectifiedNoise
|
| 1051 |
+
# 关键:SD3WithRectifiedNoise 会保留 transformer 的引用
|
| 1052 |
+
# 但是,SD3WithRectifiedNoise 在 __init__ 中会冻结 transformer 参数
|
| 1053 |
+
# 这不应该影响 LoRA,因为 LoRA 是作为适配器添加的,不是原始参数
|
| 1054 |
+
# 我们需要确保在创建 SD3WithRectifiedNoise 之前,LoRA 适配器已经正确加载和启用
|
| 1055 |
+
if lora_loaded and rank == 0:
|
| 1056 |
+
print("Creating SD3WithRectifiedNoise with LoRA-enabled transformer...")
|
| 1057 |
+
elif rank == 0:
|
| 1058 |
+
print("Creating SD3WithRectifiedNoise...")
|
| 1059 |
+
|
| 1060 |
+
model = SD3WithRectifiedNoise(pipeline.transformer, rectified_module).to(device)
|
| 1061 |
+
|
| 1062 |
+
# 重要:SD3WithRectifiedNoise 的 __init__ 会冻结 transformer 参数
|
| 1063 |
+
# 但 LoRA 适配器应该仍然有效,因为它们是独立的模块
|
| 1064 |
+
# 我们需要确保 LoRA 适配器在包装后仍然可以访问
|
| 1065 |
+
|
| 1066 |
+
# 确保 LoRA 适配器在模型替换后仍然启用
|
| 1067 |
+
if lora_loaded:
|
| 1068 |
+
# 通过model.transformer访问,因为SD3WithRectifiedNoise包装了transformer
|
| 1069 |
+
if hasattr(model.transformer, 'peft_config'):
|
| 1070 |
+
try:
|
| 1071 |
+
# 确保适配器处于启用状态
|
| 1072 |
+
model.transformer.set_adapter("default_0")
|
| 1073 |
+
|
| 1074 |
+
# 验证LoRA权重在包装后是否仍然存在
|
| 1075 |
+
lora_layers_after_wrap = 0
|
| 1076 |
+
nonzero_after_wrap = 0
|
| 1077 |
+
for name, module in model.transformer.named_modules():
|
| 1078 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 1079 |
+
lora_layers_after_wrap += 1
|
| 1080 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
| 1081 |
+
if module.weight.abs().sum().item() > 1e-6:
|
| 1082 |
+
nonzero_after_wrap += 1
|
| 1083 |
+
|
| 1084 |
+
if rank == 0:
|
| 1085 |
+
print(f"LoRA after SD3WithRectifiedNoise wrapping:")
|
| 1086 |
+
print(f" LoRA layers: {lora_layers_after_wrap}, Non-zero: {nonzero_after_wrap}")
|
| 1087 |
+
if nonzero_after_wrap == 0:
|
| 1088 |
+
print(" ❌ ERROR: All LoRA weights are zero after wrapping!")
|
| 1089 |
+
elif nonzero_after_wrap < lora_layers_after_wrap * 0.5:
|
| 1090 |
+
print(f" ⚠️ WARNING: Only {nonzero_after_wrap}/{lora_layers_after_wrap} LoRA layers have weights!")
|
| 1091 |
+
else:
|
| 1092 |
+
print(f" ✓ LoRA weights preserved after wrapping")
|
| 1093 |
+
|
| 1094 |
+
# 验证适配器是否真的启用
|
| 1095 |
+
if hasattr(model.transformer, 'active_adapters'):
|
| 1096 |
+
try:
|
| 1097 |
+
if callable(model.transformer.active_adapters):
|
| 1098 |
+
active = model.transformer.active_adapters()
|
| 1099 |
+
else:
|
| 1100 |
+
active = model.transformer.active_adapters
|
| 1101 |
+
if rank == 0:
|
| 1102 |
+
print(f" Active adapters: {active}")
|
| 1103 |
+
except:
|
| 1104 |
+
if rank == 0:
|
| 1105 |
+
print(" LoRA adapter re-enabled after model wrapping")
|
| 1106 |
+
else:
|
| 1107 |
+
if rank == 0:
|
| 1108 |
+
print(" LoRA adapter re-enabled after model wrapping")
|
| 1109 |
+
except Exception as e:
|
| 1110 |
+
if rank == 0:
|
| 1111 |
+
print(f"❌ ERROR: Could not re-enable LoRA adapter: {e}")
|
| 1112 |
+
import traceback
|
| 1113 |
+
traceback.print_exc()
|
| 1114 |
+
else:
|
| 1115 |
+
# LoRA权重已经合并到transformer的基础权重中(合并加载方式)
|
| 1116 |
+
# 这种情况下没有peft_config是正常的,因为LoRA已经合并了
|
| 1117 |
+
if rank == 0:
|
| 1118 |
+
print("LoRA loaded via merged weights (no PEFT adapter needed)")
|
| 1119 |
+
print(" ✓ LoRA weights are already merged into transformer base weights")
|
| 1120 |
+
print(" Note: This is expected when loading from merged checkpoint format")
|
| 1121 |
+
|
| 1122 |
+
# 注册到 pipeline(pipeline_stable_diffusion_3.py 已支持 external model)
|
| 1123 |
+
pipeline.model = model
|
| 1124 |
+
|
| 1125 |
+
# 确保模型处于评估模式(LoRA在eval模式下也应该工作)
|
| 1126 |
+
model.eval()
|
| 1127 |
+
model.transformer.eval() # 确保transformer也处于eval模式
|
| 1128 |
+
else:
|
| 1129 |
+
if rank == 0:
|
| 1130 |
+
print("Not using Rectified Noise module, using baseline SD3 pipeline")
|
| 1131 |
+
# 不使用 SD3WithRectifiedNoise,保持原始 pipeline
|
| 1132 |
+
# pipeline.model 保持为原始的 transformer
|
| 1133 |
+
|
| 1134 |
+
# 关键:确保LoRA适配器在推理时被使用
|
| 1135 |
+
# PEFT模型在eval模式下,LoRA适配器应该自动启用,但我们需要确保
|
| 1136 |
+
if lora_loaded:
|
| 1137 |
+
# 获取正确的 transformer 引用
|
| 1138 |
+
transformer_ref = model.transformer if use_rectified else pipeline.transformer
|
| 1139 |
+
|
| 1140 |
+
# 确保transformer的LoRA适配器处于启用状态
|
| 1141 |
+
if hasattr(transformer_ref, 'set_adapter'):
|
| 1142 |
+
try:
|
| 1143 |
+
transformer_ref.set_adapter("default")
|
| 1144 |
+
except:
|
| 1145 |
+
pass
|
| 1146 |
+
|
| 1147 |
+
# 验证LoRA是否真的会被使用
|
| 1148 |
+
if rank == 0:
|
| 1149 |
+
# 检查一个LoRA层的权重
|
| 1150 |
+
lora_found = False
|
| 1151 |
+
for name, module in transformer_ref.named_modules():
|
| 1152 |
+
if 'lora_A' in name and 'default' in name and hasattr(module, 'weight'):
|
| 1153 |
+
if module.weight is not None:
|
| 1154 |
+
weight_sum = module.weight.abs().sum().item()
|
| 1155 |
+
if weight_sum > 0:
|
| 1156 |
+
print(f"✓ Verified LoRA weight in {name}: sum={weight_sum:.6f}")
|
| 1157 |
+
lora_found = True
|
| 1158 |
+
break
|
| 1159 |
+
|
| 1160 |
+
if not lora_found:
|
| 1161 |
+
print("⚠ Warning: Could not verify LoRA weights in model")
|
| 1162 |
+
else:
|
| 1163 |
+
# 额外检查:验证LoRA层是否真的会被调用
|
| 1164 |
+
# 检查一个LoRA Linear层
|
| 1165 |
+
for name, module in transformer_ref.named_modules():
|
| 1166 |
+
if hasattr(module, '__class__') and 'lora' in module.__class__.__name__.lower():
|
| 1167 |
+
if hasattr(module, 'lora_enabled'):
|
| 1168 |
+
enabled = module.lora_enabled
|
| 1169 |
+
if rank == 0:
|
| 1170 |
+
print(f"✓ Found LoRA layer {name}, enabled: {enabled}")
|
| 1171 |
+
break
|
| 1172 |
+
|
| 1173 |
+
print("Model set to eval mode, LoRA should be active during inference")
|
| 1174 |
+
|
| 1175 |
+
# 启用内存优化选项
|
| 1176 |
+
if args.enable_attention_slicing:
|
| 1177 |
+
if rank == 0:
|
| 1178 |
+
print("Enabling attention slicing to save memory")
|
| 1179 |
+
pipeline.enable_attention_slicing()
|
| 1180 |
+
|
| 1181 |
+
if args.enable_vae_slicing:
|
| 1182 |
+
if rank == 0:
|
| 1183 |
+
print("Enabling VAE slicing to save memory")
|
| 1184 |
+
pipeline.enable_vae_slicing()
|
| 1185 |
+
|
| 1186 |
+
if args.enable_cpu_offload:
|
| 1187 |
+
if rank == 0:
|
| 1188 |
+
print("Enabling CPU offload to save memory")
|
| 1189 |
+
pipeline.enable_model_cpu_offload()
|
| 1190 |
+
|
| 1191 |
+
# 禁用进度条以减少输出
|
| 1192 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 1193 |
+
|
| 1194 |
+
# 读入 captions
|
| 1195 |
+
captions = load_captions_from_jsonl(args.captions_jsonl)
|
| 1196 |
+
total_images_needed = min(len(captions) * args.images_per_caption, args.max_samples)
|
| 1197 |
+
|
| 1198 |
+
# 输出目录
|
| 1199 |
+
if rank == 0:
|
| 1200 |
+
os.makedirs(args.sample_dir, exist_ok=True)
|
| 1201 |
+
dist.barrier()
|
| 1202 |
+
|
| 1203 |
+
# 检查已存在的样本
|
| 1204 |
+
existing_count, max_existing_index = get_existing_sample_count(args.sample_dir)
|
| 1205 |
+
if rank == 0:
|
| 1206 |
+
print(f"Found {existing_count} existing samples, max index: {max_existing_index}")
|
| 1207 |
+
|
| 1208 |
+
# 调整需要生成的样本数量
|
| 1209 |
+
remaining_images_needed = max(0, total_images_needed - existing_count)
|
| 1210 |
+
if remaining_images_needed == 0:
|
| 1211 |
+
if rank == 0:
|
| 1212 |
+
print("All required samples already exist. Skipping generation.")
|
| 1213 |
+
print(f"Creating npz from existing samples...")
|
| 1214 |
+
create_npz_from_sample_folder(args.sample_dir, total_images_needed)
|
| 1215 |
+
return
|
| 1216 |
+
|
| 1217 |
+
if rank == 0:
|
| 1218 |
+
print(f"Need to generate {remaining_images_needed} more samples (total needed: {total_images_needed})")
|
| 1219 |
+
|
| 1220 |
+
n = args.per_proc_batch_size
|
| 1221 |
+
global_batch = n * world_size
|
| 1222 |
+
total_samples = int(math.ceil(remaining_images_needed / global_batch) * global_batch)
|
| 1223 |
+
assert total_samples % world_size == 0
|
| 1224 |
+
samples_per_gpu = total_samples // world_size
|
| 1225 |
+
assert samples_per_gpu % n == 0
|
| 1226 |
+
iterations = samples_per_gpu // n
|
| 1227 |
+
|
| 1228 |
+
if rank == 0:
|
| 1229 |
+
print(f"Sampling remaining={remaining_images_needed}, total_samples={total_samples}, per_gpu={samples_per_gpu}, iterations={iterations}")
|
| 1230 |
+
|
| 1231 |
+
pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
|
| 1232 |
+
saved = 0
|
| 1233 |
+
|
| 1234 |
+
autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1235 |
+
for it in pbar:
|
| 1236 |
+
if rank == 0 and it % 10 == 0:
|
| 1237 |
+
print(f"[rank{rank}] Sampling iteration {it}/{iterations}")
|
| 1238 |
+
batch_prompts = []
|
| 1239 |
+
base_index = it * global_batch + rank
|
| 1240 |
+
for j in range(n):
|
| 1241 |
+
idx = it * global_batch + j * world_size + rank
|
| 1242 |
+
if idx < remaining_images_needed:
|
| 1243 |
+
cap_idx = idx // args.images_per_caption
|
| 1244 |
+
batch_prompts.append(captions[cap_idx])
|
| 1245 |
+
else:
|
| 1246 |
+
batch_prompts.append("a beautiful high quality image")
|
| 1247 |
+
|
| 1248 |
+
with torch.autocast(autocast_device, dtype=dtype):
|
| 1249 |
+
images = []
|
| 1250 |
+
for k, prompt in enumerate(batch_prompts):
|
| 1251 |
+
image_seed = seed + it * 10000 + k * 1000 + rank
|
| 1252 |
+
generator = torch.Generator(device=device).manual_seed(image_seed)
|
| 1253 |
+
img = pipeline(
|
| 1254 |
+
prompt=prompt,
|
| 1255 |
+
height=args.height,
|
| 1256 |
+
width=args.width,
|
| 1257 |
+
num_inference_steps=args.num_inference_steps,
|
| 1258 |
+
guidance_scale=args.guidance_scale,
|
| 1259 |
+
generator=generator,
|
| 1260 |
+
num_images_per_prompt=1,
|
| 1261 |
+
).images[0]
|
| 1262 |
+
images.append(img)
|
| 1263 |
+
|
| 1264 |
+
# 保存
|
| 1265 |
+
out_dir = Path(args.sample_dir)
|
| 1266 |
+
if rank == 0 and it == 0:
|
| 1267 |
+
print(f"Saving pngs to: {out_dir}")
|
| 1268 |
+
for j, img in enumerate(images):
|
| 1269 |
+
global_index = it * global_batch + j * world_size + rank + existing_count # 加上已存在的数量
|
| 1270 |
+
if global_index < total_images_needed:
|
| 1271 |
+
filename = f"{global_index:07d}.png"
|
| 1272 |
+
img.save(out_dir / filename)
|
| 1273 |
+
saved += 1
|
| 1274 |
+
dist.barrier()
|
| 1275 |
+
|
| 1276 |
+
if rank == 0:
|
| 1277 |
+
print(f"Done. Saved {saved * world_size} images in total.")
|
| 1278 |
+
actual_num_samples = len([name for name in os.listdir(args.sample_dir) if name.endswith(".png")])
|
| 1279 |
+
print(f"Actually generated {actual_num_samples} images")
|
| 1280 |
+
npz_samples = min(actual_num_samples, total_images_needed)
|
| 1281 |
+
print(f"[rank{rank}] Creating npz from sample folder: {args.sample_dir}, npz_samples={npz_samples}")
|
| 1282 |
+
create_npz_from_sample_folder(args.sample_dir, npz_samples)
|
| 1283 |
+
print("Done creating npz.")
|
| 1284 |
+
print("Done.")
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
if __name__ == "__main__":
|
| 1288 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA + RectifiedNoise 分布式采样脚本")
|
| 1289 |
+
# 模型
|
| 1290 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
|
| 1291 |
+
parser.add_argument("--revision", type=str, default=None)
|
| 1292 |
+
parser.add_argument("--variant", type=str, default=None)
|
| 1293 |
+
# LoRA 与 Rectified
|
| 1294 |
+
parser.add_argument("--lora_path", type=str, default=None, help="LoRA 权重路径(文件或目录)")
|
| 1295 |
+
parser.add_argument("--rectified_weights", type=str, default=None, help="Rectified(SIT) 权重路径(文件或目录)")
|
| 1296 |
+
parser.add_argument("--num_sit_layers", type=int, default=1, help="与训练一致的 SIT 层数")
|
| 1297 |
+
# 采样
|
| 1298 |
+
parser.add_argument("--num_inference_steps", type=int, default=28)
|
| 1299 |
+
parser.add_argument("--guidance_scale", type=float, default=7.0)
|
| 1300 |
+
parser.add_argument("--height", type=int, default=1024)
|
| 1301 |
+
parser.add_argument("--width", type=int, default=1024)
|
| 1302 |
+
parser.add_argument("--per_proc_batch_size", type=int, default=1)
|
| 1303 |
+
parser.add_argument("--images_per_caption", type=int, default=1)
|
| 1304 |
+
parser.add_argument("--max_samples", type=int, default=10000)
|
| 1305 |
+
parser.add_argument("--captions_jsonl", type=str, required=True)
|
| 1306 |
+
parser.add_argument("--sample_dir", type=str, default="sd3_rectified_samples")
|
| 1307 |
+
parser.add_argument("--global_seed", type=int, default=42)
|
| 1308 |
+
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
| 1309 |
+
# 内存优化选项
|
| 1310 |
+
parser.add_argument("--enable_attention_slicing", action="store_true", help="启用 attention slicing 以节省显存")
|
| 1311 |
+
parser.add_argument("--enable_vae_slicing", action="store_true", help="启用 VAE slicing 以节省显存")
|
| 1312 |
+
parser.add_argument("--enable_cpu_offload", action="store_true", help="启用 CPU offload 以节省显存")
|
| 1313 |
+
|
| 1314 |
+
args = parser.parse_args()
|
| 1315 |
+
main(args)
|
| 1316 |
+
|
sample_sd3_rectified_ddp_old.py
ADDED
|
@@ -0,0 +1,1317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
分布式采样脚本:支持指定 LoRA 权重与 Rectified Noise(SIT) 权重
|
| 5 |
+
|
| 6 |
+
依据 train_rectified_noise.py 的模型结构,加载并组装 SD3WithRectifiedNoise 进行采样。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
from accelerate import Accelerator
|
| 23 |
+
from diffusers import StableDiffusion3Pipeline
|
| 24 |
+
from peft import LoraConfig, get_peft_model_state_dict
|
| 25 |
+
from peft.utils import set_peft_model_state_dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def dynamic_import_training_classes(project_root: str):
|
| 29 |
+
"""从 train_rectified_noise.py 动态导入 RectifiedNoiseModule 和 SD3WithRectifiedNoise"""
|
| 30 |
+
sys.path.insert(0, project_root)
|
| 31 |
+
try:
|
| 32 |
+
import train_rectified_noise as trn
|
| 33 |
+
return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
|
| 34 |
+
except Exception as e:
|
| 35 |
+
raise ImportError(f"无法从 train_rectified_noise.py 导入类: {e}")
|
| 36 |
+
|
| 37 |
+
def create_npz_from_sample_folder(sample_dir, num_samples):
|
| 38 |
+
"""
|
| 39 |
+
从样本文件夹构建单个.npz文件,保持与sample_ddp_new相同的格式
|
| 40 |
+
"""
|
| 41 |
+
samples = []
|
| 42 |
+
actual_files = []
|
| 43 |
+
|
| 44 |
+
# 收集所有PNG文件
|
| 45 |
+
for filename in sorted(os.listdir(sample_dir)):
|
| 46 |
+
if filename.endswith('.png'):
|
| 47 |
+
actual_files.append(filename)
|
| 48 |
+
|
| 49 |
+
# 按照数量限制处理
|
| 50 |
+
for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
|
| 51 |
+
if i < len(actual_files):
|
| 52 |
+
sample_path = os.path.join(sample_dir, actual_files[i])
|
| 53 |
+
sample_pil = Image.open(sample_path)
|
| 54 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
| 55 |
+
samples.append(sample_np)
|
| 56 |
+
else:
|
| 57 |
+
# 如果不够,创建空白图像
|
| 58 |
+
sample_np = np.zeros((512, 512, 3), dtype=np.uint8)
|
| 59 |
+
samples.append(sample_np)
|
| 60 |
+
|
| 61 |
+
if samples:
|
| 62 |
+
samples = np.stack(samples)
|
| 63 |
+
npz_path = f"{sample_dir}.npz"
|
| 64 |
+
np.savez(npz_path, arr_0=samples)
|
| 65 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
| 66 |
+
return npz_path
|
| 67 |
+
else:
|
| 68 |
+
print("No samples found to create npz file.")
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_sit_weights(rectified_module, weights_path: str, rank=0):
|
| 73 |
+
"""加载 Rectified Noise(SIT) 权重,支持 .safetensors / .bin / .pt
|
| 74 |
+
支持以下目录结构:
|
| 75 |
+
- weights_path/pytorch_sit_weights.safetensors (直接在主目录)
|
| 76 |
+
- weights_path/sit_weights/pytorch_sit_weights.safetensors (在sit_weights子目录)
|
| 77 |
+
"""
|
| 78 |
+
if os.path.isdir(weights_path):
|
| 79 |
+
# 首先尝试在主目录查找
|
| 80 |
+
search_paths = [
|
| 81 |
+
weights_path, # 主目录
|
| 82 |
+
os.path.join(weights_path, "sit_weights"), # sit_weights子目录
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
for search_dir in search_paths:
|
| 86 |
+
if not os.path.exists(search_dir):
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
# 优先寻找 safetensors
|
| 90 |
+
st_path = os.path.join(search_dir, "pytorch_sit_weights.safetensors")
|
| 91 |
+
if os.path.exists(st_path):
|
| 92 |
+
try:
|
| 93 |
+
from safetensors.torch import load_file
|
| 94 |
+
if rank == 0:
|
| 95 |
+
print(f"Loading rectified weights from: {st_path}")
|
| 96 |
+
state = load_file(st_path)
|
| 97 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 98 |
+
if rank == 0:
|
| 99 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 100 |
+
if missing_keys:
|
| 101 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 102 |
+
if unexpected_keys:
|
| 103 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 104 |
+
return True
|
| 105 |
+
except Exception as e:
|
| 106 |
+
if rank == 0:
|
| 107 |
+
print(f" Failed to load from {st_path}: {e}")
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# 其次寻找 bin/pt
|
| 111 |
+
for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
|
| 112 |
+
cand = os.path.join(search_dir, name)
|
| 113 |
+
if os.path.exists(cand):
|
| 114 |
+
try:
|
| 115 |
+
if rank == 0:
|
| 116 |
+
print(f"Loading rectified weights from: {cand}")
|
| 117 |
+
state = torch.load(cand, map_location="cpu")
|
| 118 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 119 |
+
if rank == 0:
|
| 120 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 121 |
+
if missing_keys:
|
| 122 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 123 |
+
if unexpected_keys:
|
| 124 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 125 |
+
return True
|
| 126 |
+
except Exception as e:
|
| 127 |
+
if rank == 0:
|
| 128 |
+
print(f" Failed to load from {cand}: {e}")
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
# 兜底:目录下任意 pt/bin
|
| 132 |
+
try:
|
| 133 |
+
for fn in os.listdir(search_dir):
|
| 134 |
+
if fn.endswith((".pt", ".bin")):
|
| 135 |
+
cand = os.path.join(search_dir, fn)
|
| 136 |
+
try:
|
| 137 |
+
if rank == 0:
|
| 138 |
+
print(f"Loading rectified weights from: {cand}")
|
| 139 |
+
state = torch.load(cand, map_location="cpu")
|
| 140 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 141 |
+
if rank == 0:
|
| 142 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 143 |
+
return True
|
| 144 |
+
except Exception as e:
|
| 145 |
+
if rank == 0:
|
| 146 |
+
print(f" Failed to load from {cand}: {e}")
|
| 147 |
+
continue
|
| 148 |
+
except Exception:
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
if rank == 0:
|
| 152 |
+
print(f" ❌ No rectified weights found in {weights_path} or {os.path.join(weights_path, 'sit_weights')}")
|
| 153 |
+
return False
|
| 154 |
+
else:
|
| 155 |
+
# 直接文件
|
| 156 |
+
try:
|
| 157 |
+
if rank == 0:
|
| 158 |
+
print(f"Loading rectified weights from file: {weights_path}")
|
| 159 |
+
if weights_path.endswith(".safetensors"):
|
| 160 |
+
from safetensors.torch import load_file
|
| 161 |
+
state = load_file(weights_path)
|
| 162 |
+
else:
|
| 163 |
+
state = torch.load(weights_path, map_location="cpu")
|
| 164 |
+
missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
|
| 165 |
+
if rank == 0:
|
| 166 |
+
print(f" Loaded rectified weights: {len(state)} keys")
|
| 167 |
+
if missing_keys:
|
| 168 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 169 |
+
if unexpected_keys:
|
| 170 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 171 |
+
return True
|
| 172 |
+
except Exception as e:
|
| 173 |
+
if rank == 0:
|
| 174 |
+
print(f" ❌ Failed to load rectified weights from {weights_path}: {e}")
|
| 175 |
+
return False
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def check_lora_weights_exist(lora_path):
|
| 179 |
+
"""检查LoRA权重文件是否存在"""
|
| 180 |
+
if not lora_path:
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
if os.path.isdir(lora_path):
|
| 184 |
+
# 检查目录中是否有pytorch_lora_weights.safetensors文件
|
| 185 |
+
weight_file = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
|
| 186 |
+
if os.path.exists(weight_file):
|
| 187 |
+
return True
|
| 188 |
+
# 检查是否有其他.safetensors文件
|
| 189 |
+
for file in os.listdir(lora_path):
|
| 190 |
+
if file.endswith(".safetensors") and "lora" in file.lower():
|
| 191 |
+
return True
|
| 192 |
+
return False
|
| 193 |
+
elif os.path.isfile(lora_path):
|
| 194 |
+
return lora_path.endswith(".safetensors")
|
| 195 |
+
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0, lora_rank=64):
|
| 200 |
+
"""
|
| 201 |
+
从accelerator checkpoint目录加载LoRA权重或完整模型权重
|
| 202 |
+
如果checkpoint包含完整的模型权重(合并后的),直接加载
|
| 203 |
+
如果只包含LoRA权重,则按LoRA方式加载
|
| 204 |
+
"""
|
| 205 |
+
if rank == 0:
|
| 206 |
+
print(f"Loading weights from accelerator checkpoint: {checkpoint_path}")
|
| 207 |
+
|
| 208 |
+
try:
|
| 209 |
+
from safetensors.torch import load_file
|
| 210 |
+
model_file = os.path.join(checkpoint_path, "model.safetensors")
|
| 211 |
+
if not os.path.exists(model_file):
|
| 212 |
+
if rank == 0:
|
| 213 |
+
print(f"Model file not found: {model_file}")
|
| 214 |
+
return False
|
| 215 |
+
|
| 216 |
+
# 加载state dict
|
| 217 |
+
state_dict = load_file(model_file)
|
| 218 |
+
all_keys = list(state_dict.keys())
|
| 219 |
+
|
| 220 |
+
# 检测checkpoint类型:
|
| 221 |
+
# 1. 是否包含base_layer(PEFT格式,需要合并)
|
| 222 |
+
# 2. 是否包含完整的模型权重(合并后的,直接可用)
|
| 223 |
+
# 3. 是否只包含LoRA权重(需要添加适配器)
|
| 224 |
+
lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()]
|
| 225 |
+
base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()]
|
| 226 |
+
non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()]
|
| 227 |
+
|
| 228 |
+
if rank == 0:
|
| 229 |
+
print(f"Checkpoint analysis:")
|
| 230 |
+
print(f" Total keys: {len(all_keys)}")
|
| 231 |
+
print(f" LoRA keys: {len(lora_keys)}")
|
| 232 |
+
print(f" Base layer keys: {len(base_layer_keys)}")
|
| 233 |
+
print(f" Direct transformer weight keys (merged): {len(non_lora_transformer_keys)}")
|
| 234 |
+
|
| 235 |
+
# 如果包含base_layer,说明是PEFT格式,需要合并base_layer + lora
|
| 236 |
+
if len(base_layer_keys) > 0:
|
| 237 |
+
if rank == 0:
|
| 238 |
+
print(f"✓ Detected PEFT format (base_layer + LoRA), merging weights...")
|
| 239 |
+
|
| 240 |
+
# 合并base_layer和lora权重
|
| 241 |
+
merged_state_dict = {}
|
| 242 |
+
|
| 243 |
+
# 首先收集所有需要合并的模块
|
| 244 |
+
modules_to_merge = {}
|
| 245 |
+
# 记录所有非LoRA的transformer权重键名(用于调试)
|
| 246 |
+
non_lora_keys_found = []
|
| 247 |
+
|
| 248 |
+
for key in all_keys:
|
| 249 |
+
# 移除前缀
|
| 250 |
+
new_key = key
|
| 251 |
+
has_transformer_prefix = False
|
| 252 |
+
|
| 253 |
+
if key.startswith('base_model.model.transformer.'):
|
| 254 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 255 |
+
has_transformer_prefix = True
|
| 256 |
+
elif key.startswith('model.transformer.'):
|
| 257 |
+
new_key = key[len('model.transformer.'):]
|
| 258 |
+
has_transformer_prefix = True
|
| 259 |
+
elif key.startswith('transformer.'):
|
| 260 |
+
new_key = key[len('transformer.'):]
|
| 261 |
+
has_transformer_prefix = True
|
| 262 |
+
elif 'transformer' in key.lower():
|
| 263 |
+
# 可能没有前缀,但包含transformer(如直接是transformer_blocks.0...)
|
| 264 |
+
has_transformer_prefix = True
|
| 265 |
+
|
| 266 |
+
if not has_transformer_prefix:
|
| 267 |
+
continue
|
| 268 |
+
|
| 269 |
+
# 检查是否是base_layer或lora权重
|
| 270 |
+
if '.base_layer.weight' in new_key:
|
| 271 |
+
# 提取模块名(去掉.base_layer.weight部分)
|
| 272 |
+
module_key = new_key.replace('.base_layer.weight', '.weight')
|
| 273 |
+
if module_key not in modules_to_merge:
|
| 274 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 275 |
+
modules_to_merge[module_key]['base_weight'] = (key, state_dict[key])
|
| 276 |
+
elif '.base_layer.bias' in new_key:
|
| 277 |
+
module_key = new_key.replace('.base_layer.bias', '.bias')
|
| 278 |
+
if module_key not in modules_to_merge:
|
| 279 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 280 |
+
modules_to_merge[module_key]['base_bias'] = (key, state_dict[key])
|
| 281 |
+
elif '.lora_A.default.weight' in new_key:
|
| 282 |
+
module_key = new_key.replace('.lora_A.default.weight', '.weight')
|
| 283 |
+
if module_key not in modules_to_merge:
|
| 284 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 285 |
+
modules_to_merge[module_key]['lora_A'] = (key, state_dict[key])
|
| 286 |
+
elif '.lora_B.default.weight' in new_key:
|
| 287 |
+
module_key = new_key.replace('.lora_B.default.weight', '.weight')
|
| 288 |
+
if module_key not in modules_to_merge:
|
| 289 |
+
modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
|
| 290 |
+
modules_to_merge[module_key]['lora_B'] = (key, state_dict[key])
|
| 291 |
+
elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower():
|
| 292 |
+
# 其他非LoRA权重(如pos_embed、time_text_embed、context_embedder等),直接使用
|
| 293 |
+
# 这些权重不在LoRA适配范围内,应该直接从checkpoint加载
|
| 294 |
+
merged_state_dict[new_key] = state_dict[key]
|
| 295 |
+
non_lora_keys_found.append(new_key)
|
| 296 |
+
|
| 297 |
+
if rank == 0:
|
| 298 |
+
print(f" Found {len(non_lora_keys_found)} non-LoRA transformer keys in checkpoint")
|
| 299 |
+
if non_lora_keys_found:
|
| 300 |
+
print(f" Sample non-LoRA keys: {non_lora_keys_found[:10]}")
|
| 301 |
+
|
| 302 |
+
# 合并权重:weight = base_weight + lora_B @ lora_A * (alpha / rank)
|
| 303 |
+
if rank == 0:
|
| 304 |
+
print(f" Merging {len(modules_to_merge)} modules...")
|
| 305 |
+
|
| 306 |
+
import torch
|
| 307 |
+
for module_key, weights in modules_to_merge.items():
|
| 308 |
+
# 处理权重(.weight)
|
| 309 |
+
if weights['base_weight'] is not None:
|
| 310 |
+
base_key, base_weight = weights['base_weight']
|
| 311 |
+
base_weight = base_weight.clone()
|
| 312 |
+
|
| 313 |
+
if weights['lora_A'] is not None and weights['lora_B'] is not None:
|
| 314 |
+
lora_A_key, lora_A = weights['lora_A']
|
| 315 |
+
lora_B_key, lora_B = weights['lora_B']
|
| 316 |
+
|
| 317 |
+
# 检测rank和alpha
|
| 318 |
+
# lora_A: [rank, in_features], lora_B: [out_features, rank]
|
| 319 |
+
rank_value = lora_A.shape[0]
|
| 320 |
+
alpha = rank_value # 通常alpha = rank
|
| 321 |
+
|
| 322 |
+
# 合并:weight = base + (lora_B @ lora_A) * (alpha / rank)
|
| 323 |
+
# lora_B @ lora_A 得到 [out_features, in_features]
|
| 324 |
+
lora_delta = torch.matmul(lora_B, lora_A)
|
| 325 |
+
|
| 326 |
+
if lora_delta.shape == base_weight.shape:
|
| 327 |
+
merged_weight = base_weight + lora_delta * (alpha / rank_value)
|
| 328 |
+
merged_state_dict[module_key] = merged_weight
|
| 329 |
+
if rank == 0 and len(modules_to_merge) <= 20:
|
| 330 |
+
print(f" ✓ Merged {module_key}: {base_weight.shape}")
|
| 331 |
+
else:
|
| 332 |
+
if rank == 0:
|
| 333 |
+
print(f" ⚠️ Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only")
|
| 334 |
+
merged_state_dict[module_key] = base_weight
|
| 335 |
+
else:
|
| 336 |
+
# 只有base权重,没有LoRA
|
| 337 |
+
merged_state_dict[module_key] = base_weight
|
| 338 |
+
|
| 339 |
+
# 处理bias(.bias)- bias通常不需要合并,直接使用base_bias
|
| 340 |
+
if '.bias' in module_key and weights['base_bias'] is not None:
|
| 341 |
+
bias_key, base_bias = weights['base_bias']
|
| 342 |
+
merged_state_dict[module_key] = base_bias.clone()
|
| 343 |
+
|
| 344 |
+
if rank == 0:
|
| 345 |
+
print(f" Merged {len(merged_state_dict)} weights")
|
| 346 |
+
print(f" Sample merged keys: {list(merged_state_dict.keys())[:5]}")
|
| 347 |
+
|
| 348 |
+
# 加载合并后的权重
|
| 349 |
+
try:
|
| 350 |
+
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False)
|
| 351 |
+
|
| 352 |
+
if rank == 0:
|
| 353 |
+
print(f" Loaded merged weights:")
|
| 354 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 355 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 356 |
+
if missing_keys:
|
| 357 |
+
print(f" Missing keys: {missing_keys}")
|
| 358 |
+
# 检查缺失的keys是否关键
|
| 359 |
+
critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out']
|
| 360 |
+
has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys)
|
| 361 |
+
if has_critical:
|
| 362 |
+
print(f" ⚠️ WARNING: Missing critical keys! These should be loaded from pretrained model.")
|
| 363 |
+
print(f" The missing keys will use values from the pretrained model (not fine-tuned).")
|
| 364 |
+
|
| 365 |
+
# 如果缺失的keys太多或包含关键组件,给出警告
|
| 366 |
+
if len(missing_keys) > 0:
|
| 367 |
+
# 这些缺失的keys会使用pretrained model的默认值
|
| 368 |
+
# 这是正常的,因为LoRA只适配了部分层,其他层保持原样
|
| 369 |
+
if rank == 0:
|
| 370 |
+
print(f" Note: Missing keys will use pretrained model weights (not fine-tuned)")
|
| 371 |
+
|
| 372 |
+
if rank == 0:
|
| 373 |
+
print(f" ✓ Successfully loaded merged model weights")
|
| 374 |
+
return True
|
| 375 |
+
|
| 376 |
+
except Exception as e:
|
| 377 |
+
if rank == 0:
|
| 378 |
+
print(f" ❌ Error loading merged weights: {e}")
|
| 379 |
+
import traceback
|
| 380 |
+
traceback.print_exc()
|
| 381 |
+
return False
|
| 382 |
+
|
| 383 |
+
# 如果包含非LoRA的transformer权重(且没有base_layer),说明是合并后的完整模型
|
| 384 |
+
elif len(non_lora_transformer_keys) > 0:
|
| 385 |
+
if rank == 0:
|
| 386 |
+
print(f"✓ Detected merged model weights (contains full transformer weights)")
|
| 387 |
+
print(f" Loading full model weights directly...")
|
| 388 |
+
|
| 389 |
+
# 提取transformer相关的权重(包括LoRA和基础权重)
|
| 390 |
+
transformer_state_dict = {}
|
| 391 |
+
for key, value in state_dict.items():
|
| 392 |
+
# 移除可能的accelerator包装前缀
|
| 393 |
+
new_key = key
|
| 394 |
+
if key.startswith('base_model.model.transformer.'):
|
| 395 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 396 |
+
elif key.startswith('model.transformer.'):
|
| 397 |
+
new_key = key[len('model.transformer.'):]
|
| 398 |
+
elif key.startswith('transformer.'):
|
| 399 |
+
new_key = key[len('transformer.'):]
|
| 400 |
+
|
| 401 |
+
# 只保留transformer相关的权重(包括所有transformer子模块)
|
| 402 |
+
# 检查是否是transformer的权重(不包含text_encoder等)
|
| 403 |
+
if (new_key.startswith('transformer_blocks') or
|
| 404 |
+
new_key.startswith('pos_embed') or
|
| 405 |
+
new_key.startswith('time_text_embed') or
|
| 406 |
+
'lora' in new_key.lower()): # 也包含LoRA权重(如果存在)
|
| 407 |
+
transformer_state_dict[new_key] = value
|
| 408 |
+
|
| 409 |
+
if rank == 0:
|
| 410 |
+
print(f" Extracted {len(transformer_state_dict)} transformer weight keys")
|
| 411 |
+
print(f" Sample keys: {list(transformer_state_dict.keys())[:5]}")
|
| 412 |
+
|
| 413 |
+
# 直接加载到transformer(不使用LoRA适配器)
|
| 414 |
+
try:
|
| 415 |
+
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False)
|
| 416 |
+
|
| 417 |
+
if rank == 0:
|
| 418 |
+
print(f" Loaded full model weights:")
|
| 419 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 420 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 421 |
+
if missing_keys:
|
| 422 |
+
print(f" Sample missing keys: {missing_keys[:5]}")
|
| 423 |
+
if unexpected_keys:
|
| 424 |
+
print(f" Sample unexpected keys: {unexpected_keys[:5]}")
|
| 425 |
+
|
| 426 |
+
# 如果missing keys太多,可能有问题
|
| 427 |
+
if len(missing_keys) > len(transformer_state_dict) * 0.5:
|
| 428 |
+
if rank == 0:
|
| 429 |
+
print(f" ⚠️ WARNING: Too many missing keys, weights may not be fully loaded")
|
| 430 |
+
return False
|
| 431 |
+
|
| 432 |
+
if rank == 0:
|
| 433 |
+
print(f" ✓ Successfully loaded merged model weights")
|
| 434 |
+
return True
|
| 435 |
+
|
| 436 |
+
except Exception as e:
|
| 437 |
+
if rank == 0:
|
| 438 |
+
print(f" ❌ Error loading full model weights: {e}")
|
| 439 |
+
import traceback
|
| 440 |
+
traceback.print_exc()
|
| 441 |
+
return False
|
| 442 |
+
|
| 443 |
+
# 如果只包含LoRA权重,按原来的方式加载
|
| 444 |
+
if rank == 0:
|
| 445 |
+
print(f"Detected LoRA-only weights, loading as LoRA adapter...")
|
| 446 |
+
|
| 447 |
+
# 首先尝试从checkpoint中检测实际的rank
|
| 448 |
+
detected_rank = None
|
| 449 |
+
for key, value in state_dict.items():
|
| 450 |
+
if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
|
| 451 |
+
# lora_A的形状是 [rank, hidden_size]
|
| 452 |
+
detected_rank = value.shape[0]
|
| 453 |
+
if rank == 0:
|
| 454 |
+
print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
|
| 455 |
+
break
|
| 456 |
+
|
| 457 |
+
# 如果检测到rank,使用检测到的rank;否则使用传入的rank
|
| 458 |
+
actual_rank = detected_rank if detected_rank is not None else lora_rank
|
| 459 |
+
if detected_rank is not None and detected_rank != lora_rank:
|
| 460 |
+
if rank == 0:
|
| 461 |
+
print(f"⚠️ Warning: Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank")
|
| 462 |
+
|
| 463 |
+
# 检查适配器是否已存在,如果存在则先卸载
|
| 464 |
+
# SD3Transformer2DModel没有delete_adapter方法,需要使用unload_lora_weights
|
| 465 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 466 |
+
if "default" in pipeline.transformer.peft_config:
|
| 467 |
+
if rank == 0:
|
| 468 |
+
print("Removing existing 'default' adapter before adding new one...")
|
| 469 |
+
try:
|
| 470 |
+
# 使用pipeline的unload_lora_weights方法
|
| 471 |
+
pipeline.unload_lora_weights()
|
| 472 |
+
if rank == 0:
|
| 473 |
+
print("Successfully unloaded existing LoRA adapter")
|
| 474 |
+
except Exception as e:
|
| 475 |
+
if rank == 0:
|
| 476 |
+
print(f"❌ ERROR: Could not unload existing adapter: {e}")
|
| 477 |
+
print("Cannot proceed without cleaning up adapter")
|
| 478 |
+
return False
|
| 479 |
+
|
| 480 |
+
# 先配置LoRA适配器(必须在加载之前配置)
|
| 481 |
+
# 使用检测到的或传入的rank
|
| 482 |
+
transformer_lora_config = LoraConfig(
|
| 483 |
+
r=actual_rank,
|
| 484 |
+
lora_alpha=actual_rank,
|
| 485 |
+
init_lora_weights="gaussian",
|
| 486 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# 为transformer添加LoRA适配器
|
| 490 |
+
pipeline.transformer.add_adapter(transformer_lora_config)
|
| 491 |
+
|
| 492 |
+
if rank == 0:
|
| 493 |
+
print(f"LoRA adapter configured with rank={actual_rank}")
|
| 494 |
+
|
| 495 |
+
# 继续处理LoRA权重加载(state_dict已经在上面加载了)
|
| 496 |
+
|
| 497 |
+
# 提取LoRA权重 - accelerator保存的格式
|
| 498 |
+
# 从accelerator checkpoint的model.safetensors中,键名格式可能是:
|
| 499 |
+
# - transformer_blocks.X.attn.to_q.lora_A.default.weight (PEFT格式,直接可用)
|
| 500 |
+
# - 或者包含其他前缀
|
| 501 |
+
lora_state_dict = {}
|
| 502 |
+
for key, value in state_dict.items():
|
| 503 |
+
if 'lora' in key.lower() and 'transformer' in key.lower():
|
| 504 |
+
# 检查键名格式
|
| 505 |
+
new_key = key
|
| 506 |
+
|
| 507 |
+
# 移除可能的accelerator包装前缀
|
| 508 |
+
# accelerator可能保存为: model.transformer.transformer_blocks...
|
| 509 |
+
# 或者: base_model.model.transformer.transformer_blocks...
|
| 510 |
+
if key.startswith('base_model.model.transformer.'):
|
| 511 |
+
new_key = key[len('base_model.model.transformer.'):]
|
| 512 |
+
elif key.startswith('model.transformer.'):
|
| 513 |
+
new_key = key[len('model.transformer.'):]
|
| 514 |
+
elif key.startswith('transformer.'):
|
| 515 |
+
# 如果已经是transformer_blocks开头,不需要移除transformer.前缀
|
| 516 |
+
# 因为transformer_blocks是transformer的子模块
|
| 517 |
+
if not key[len('transformer.'):].startswith('transformer_blocks'):
|
| 518 |
+
new_key = key[len('transformer.'):]
|
| 519 |
+
else:
|
| 520 |
+
new_key = key[len('transformer.'):]
|
| 521 |
+
|
| 522 |
+
# 只保留transformer相关的LoRA权重
|
| 523 |
+
if 'transformer_blocks' in new_key or 'transformer' in new_key:
|
| 524 |
+
lora_state_dict[new_key] = value
|
| 525 |
+
|
| 526 |
+
if not lora_state_dict:
|
| 527 |
+
if rank == 0:
|
| 528 |
+
print("No LoRA weights found in checkpoint")
|
| 529 |
+
# 打印所有键名用于调试
|
| 530 |
+
all_keys = list(state_dict.keys())
|
| 531 |
+
print(f"Total keys: {len(all_keys)}")
|
| 532 |
+
print(f"First 20 keys: {all_keys[:20]}")
|
| 533 |
+
# 查找包含lora的键
|
| 534 |
+
lora_related = [k for k in all_keys if 'lora' in k.lower()]
|
| 535 |
+
if lora_related:
|
| 536 |
+
print(f"Keys containing 'lora': {lora_related[:10]}")
|
| 537 |
+
return False
|
| 538 |
+
|
| 539 |
+
if rank == 0:
|
| 540 |
+
print(f"Found {len(lora_state_dict)} LoRA weight keys")
|
| 541 |
+
sample_keys = list(lora_state_dict.keys())[:5]
|
| 542 |
+
print(f"Sample LoRA keys: {sample_keys}")
|
| 543 |
+
|
| 544 |
+
# 加载LoRA权重到transformer
|
| 545 |
+
# 注意:从checkpoint提取的键名格式已经是PEFT格式(如:transformer_blocks.0.attn.to_q.lora_A.default.weight)
|
| 546 |
+
# 不需要使用convert_unet_state_dict_to_peft转换,直接使用即可
|
| 547 |
+
try:
|
| 548 |
+
# 检查键名格式
|
| 549 |
+
sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
|
| 550 |
+
|
| 551 |
+
if rank == 0:
|
| 552 |
+
print(f"Original key format: {sample_key}")
|
| 553 |
+
|
| 554 |
+
# 关键问题:set_peft_model_state_dict期望的键名格式
|
| 555 |
+
# 从back/train_dreambooth_lora.py看,需要移除.default后缀
|
| 556 |
+
# 格式应该是:transformer_blocks.X.attn.to_q.lora_A.weight(没有.default)
|
| 557 |
+
# 但accelerator保存的格式是:transformer_blocks.X.attn.to_q.lora_A.default.weight(有.default)
|
| 558 |
+
|
| 559 |
+
# 检查键名格式
|
| 560 |
+
sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
|
| 561 |
+
has_default_suffix = '.default.weight' in sample_key or '.default.bias' in sample_key
|
| 562 |
+
|
| 563 |
+
if rank == 0:
|
| 564 |
+
print(f"Sample key: {sample_key}")
|
| 565 |
+
print(f"Has .default suffix: {has_default_suffix}")
|
| 566 |
+
|
| 567 |
+
# 如果键名包含.default.weight或.default.bias,需要移除.default部分
|
| 568 |
+
# 因为set_peft_model_state_dict期望的格式是:lora_A.weight,而不是lora_A.default.weight
|
| 569 |
+
converted_dict = {}
|
| 570 |
+
for key, value in lora_state_dict.items():
|
| 571 |
+
# 移除.default后缀(如果存在)
|
| 572 |
+
# transformer_blocks.0.attn.to_q.lora_A.default.weight -> transformer_blocks.0.attn.to_q.lora_A.weight
|
| 573 |
+
new_key = key
|
| 574 |
+
if '.default.weight' in new_key:
|
| 575 |
+
new_key = new_key.replace('.default.weight', '.weight')
|
| 576 |
+
elif '.default.bias' in new_key:
|
| 577 |
+
new_key = new_key.replace('.default.bias', '.bias')
|
| 578 |
+
elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')):
|
| 579 |
+
# 处理其他可能的.default位置
|
| 580 |
+
new_key = new_key.replace('.default', '')
|
| 581 |
+
|
| 582 |
+
converted_dict[new_key] = value
|
| 583 |
+
|
| 584 |
+
if rank == 0:
|
| 585 |
+
print(f"Converted {len(converted_dict)} keys (removed .default suffix if present)")
|
| 586 |
+
print(f"Sample converted keys: {list(converted_dict.keys())[:5]}")
|
| 587 |
+
|
| 588 |
+
# 调用set_peft_model_state_dict并检查返回值
|
| 589 |
+
incompatible_keys = set_peft_model_state_dict(
|
| 590 |
+
pipeline.transformer,
|
| 591 |
+
converted_dict,
|
| 592 |
+
adapter_name="default"
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# 检查加载结果
|
| 596 |
+
if incompatible_keys is not None:
|
| 597 |
+
missing_keys = getattr(incompatible_keys, "missing_keys", [])
|
| 598 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", [])
|
| 599 |
+
|
| 600 |
+
if rank == 0:
|
| 601 |
+
print(f"LoRA loading result:")
|
| 602 |
+
print(f" Missing keys: {len(missing_keys)}")
|
| 603 |
+
print(f" Unexpected keys: {len(unexpected_keys)}")
|
| 604 |
+
|
| 605 |
+
if len(missing_keys) > 100:
|
| 606 |
+
print(f" ⚠️ WARNING: Too many missing keys ({len(missing_keys)}), LoRA may not be fully loaded!")
|
| 607 |
+
print(f" Sample missing keys: {missing_keys[:10]}")
|
| 608 |
+
elif missing_keys:
|
| 609 |
+
print(f" Sample missing keys: {missing_keys[:10]}")
|
| 610 |
+
|
| 611 |
+
if unexpected_keys:
|
| 612 |
+
print(f" Unexpected keys: {unexpected_keys[:10]}")
|
| 613 |
+
|
| 614 |
+
# 如果missing keys太多,说明加载失败
|
| 615 |
+
if len(missing_keys) > len(converted_dict) * 0.5: # 超过50%的键缺失
|
| 616 |
+
if rank == 0:
|
| 617 |
+
print("❌ ERROR: Too many missing keys, LoRA weights not loaded correctly!")
|
| 618 |
+
return False
|
| 619 |
+
else:
|
| 620 |
+
if rank == 0:
|
| 621 |
+
print("✓ LoRA weights loaded (no incompatible keys reported)")
|
| 622 |
+
|
| 623 |
+
except RuntimeError as e:
|
| 624 |
+
# 检查是否是size mismatch错误
|
| 625 |
+
error_str = str(e)
|
| 626 |
+
if "size mismatch" in error_str:
|
| 627 |
+
if rank == 0:
|
| 628 |
+
print(f"❌ Size mismatch error: The checkpoint rank doesn't match the adapter rank")
|
| 629 |
+
print(f" This usually means the checkpoint was trained with a different rank")
|
| 630 |
+
# 尝试从错误信息中提取期望的rank
|
| 631 |
+
import re
|
| 632 |
+
# 错误信息格式: "copying a param with shape torch.Size([32, 1536]) from checkpoint"
|
| 633 |
+
match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str)
|
| 634 |
+
if match:
|
| 635 |
+
checkpoint_rank = int(match.group(1))
|
| 636 |
+
if rank == 0:
|
| 637 |
+
print(f" Detected checkpoint rank: {checkpoint_rank}")
|
| 638 |
+
print(f" Adapter was configured with rank: {actual_rank}")
|
| 639 |
+
if checkpoint_rank != actual_rank:
|
| 640 |
+
print(f" ⚠️ Mismatch! Need to recreate adapter with rank={checkpoint_rank}")
|
| 641 |
+
else:
|
| 642 |
+
if rank == 0:
|
| 643 |
+
print(f"❌ Error setting LoRA state dict: {e}")
|
| 644 |
+
import traceback
|
| 645 |
+
traceback.print_exc()
|
| 646 |
+
# 清理适配器以便下次尝试
|
| 647 |
+
try:
|
| 648 |
+
pipeline.unload_lora_weights()
|
| 649 |
+
except:
|
| 650 |
+
pass
|
| 651 |
+
return False
|
| 652 |
+
except Exception as e:
|
| 653 |
+
if rank == 0:
|
| 654 |
+
print(f"❌ Error setting LoRA state dict: {e}")
|
| 655 |
+
import traceback
|
| 656 |
+
traceback.print_exc()
|
| 657 |
+
# 清理适配器以便下次尝试
|
| 658 |
+
try:
|
| 659 |
+
pipeline.unload_lora_weights()
|
| 660 |
+
except:
|
| 661 |
+
pass
|
| 662 |
+
return False
|
| 663 |
+
|
| 664 |
+
# 启用LoRA适配器
|
| 665 |
+
pipeline.transformer.set_adapter("default")
|
| 666 |
+
|
| 667 |
+
# 验证LoRA是否已加载和应用
|
| 668 |
+
if hasattr(pipeline.transformer, 'peft_config'):
|
| 669 |
+
adapters = list(pipeline.transformer.peft_config.keys())
|
| 670 |
+
if rank == 0:
|
| 671 |
+
print(f"LoRA adapters configured: {adapters}")
|
| 672 |
+
# 检查适配器是否启用
|
| 673 |
+
if hasattr(pipeline.transformer, 'active_adapters'):
|
| 674 |
+
# active_adapters 是一个方法,需要调用
|
| 675 |
+
try:
|
| 676 |
+
if callable(pipeline.transformer.active_adapters):
|
| 677 |
+
active = pipeline.transformer.active_adapters()
|
| 678 |
+
else:
|
| 679 |
+
active = pipeline.transformer.active_adapters
|
| 680 |
+
if rank == 0:
|
| 681 |
+
print(f"Active adapters: {active}")
|
| 682 |
+
except:
|
| 683 |
+
if rank == 0:
|
| 684 |
+
print("Could not get active adapters, but LoRA is configured")
|
| 685 |
+
|
| 686 |
+
# 验证LoRA权重是否真的被应用
|
| 687 |
+
# 检查LoRA层的权重是否非零
|
| 688 |
+
lora_layers_found = 0
|
| 689 |
+
nonzero_lora_layers = 0
|
| 690 |
+
total_lora_weight_sum = 0.0
|
| 691 |
+
|
| 692 |
+
for name, module in pipeline.transformer.named_modules():
|
| 693 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 694 |
+
lora_layers_found += 1
|
| 695 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
| 696 |
+
weight_sum = module.weight.abs().sum().item()
|
| 697 |
+
total_lora_weight_sum += weight_sum
|
| 698 |
+
if weight_sum > 1e-6: # 非零阈值
|
| 699 |
+
nonzero_lora_layers += 1
|
| 700 |
+
if rank == 0 and nonzero_lora_layers <= 3: # 只打印前3个
|
| 701 |
+
print(f"✓ Found non-zero LoRA weight in: {name}, sum={weight_sum:.6f}")
|
| 702 |
+
|
| 703 |
+
if rank == 0:
|
| 704 |
+
print(f"LoRA verification:")
|
| 705 |
+
print(f" Total LoRA layers found: {lora_layers_found}")
|
| 706 |
+
print(f" Non-zero LoRA layers: {nonzero_lora_layers}")
|
| 707 |
+
print(f" Total LoRA weight sum: {total_lora_weight_sum:.6f}")
|
| 708 |
+
|
| 709 |
+
if lora_layers_found == 0:
|
| 710 |
+
print("❌ ERROR: No LoRA layers found in transformer!")
|
| 711 |
+
return False
|
| 712 |
+
elif nonzero_lora_layers == 0:
|
| 713 |
+
print("❌ ERROR: All LoRA weights are zero, LoRA not loaded correctly!")
|
| 714 |
+
return False
|
| 715 |
+
elif nonzero_lora_layers < lora_layers_found * 0.5:
|
| 716 |
+
print(f"⚠️ WARNING: Only {nonzero_lora_layers}/{lora_layers_found} LoRA layers have non-zero weights!")
|
| 717 |
+
print("⚠️ LoRA may not be fully applied!")
|
| 718 |
+
else:
|
| 719 |
+
print(f"✓ LoRA weights verified: {nonzero_lora_layers}/{lora_layers_found} layers have non-zero weights")
|
| 720 |
+
|
| 721 |
+
if nonzero_lora_layers == 0:
|
| 722 |
+
return False
|
| 723 |
+
|
| 724 |
+
if rank == 0:
|
| 725 |
+
print("✓ Successfully loaded and verified LoRA weights from checkpoint")
|
| 726 |
+
|
| 727 |
+
return True
|
| 728 |
+
|
| 729 |
+
except Exception as e:
|
| 730 |
+
if rank == 0:
|
| 731 |
+
print(f"Error loading LoRA from checkpoint: {e}")
|
| 732 |
+
import traceback
|
| 733 |
+
traceback.print_exc()
|
| 734 |
+
return False
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def load_captions_from_jsonl(jsonl_path):
|
| 738 |
+
captions = []
|
| 739 |
+
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
| 740 |
+
for line in f:
|
| 741 |
+
line = line.strip()
|
| 742 |
+
if not line:
|
| 743 |
+
continue
|
| 744 |
+
try:
|
| 745 |
+
data = json.loads(line)
|
| 746 |
+
cap = None
|
| 747 |
+
for field in ['caption', 'text', 'prompt', 'description']:
|
| 748 |
+
if field in data and isinstance(data[field], str):
|
| 749 |
+
cap = data[field].strip()
|
| 750 |
+
break
|
| 751 |
+
if cap:
|
| 752 |
+
captions.append(cap)
|
| 753 |
+
except Exception:
|
| 754 |
+
continue
|
| 755 |
+
return captions if captions else ["a beautiful high quality image"]
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def main(args):
|
| 759 |
+
assert torch.cuda.is_available(), "需要GPU运行"
|
| 760 |
+
dist.init_process_group("nccl")
|
| 761 |
+
rank = dist.get_rank()
|
| 762 |
+
world_size = dist.get_world_size()
|
| 763 |
+
device = rank % torch.cuda.device_count()
|
| 764 |
+
torch.cuda.set_device(device)
|
| 765 |
+
seed = args.global_seed * world_size + rank
|
| 766 |
+
torch.manual_seed(seed)
|
| 767 |
+
|
| 768 |
+
# 调试:打印接收到的参数
|
| 769 |
+
if rank == 0:
|
| 770 |
+
print("=" * 80)
|
| 771 |
+
print("参数检查:")
|
| 772 |
+
print(f" lora_path: {args.lora_path}")
|
| 773 |
+
print(f" rectified_weights: {args.rectified_weights}")
|
| 774 |
+
print(f" lora_path is None: {args.lora_path is None}")
|
| 775 |
+
print(f" lora_path is empty: {args.lora_path == '' if args.lora_path else 'N/A'}")
|
| 776 |
+
print(f" rectified_weights is None: {args.rectified_weights is None}")
|
| 777 |
+
print(f" rectified_weights is empty: {args.rectified_weights == '' if args.rectified_weights else 'N/A'}")
|
| 778 |
+
print("=" * 80)
|
| 779 |
+
|
| 780 |
+
lora_source = "baseline"
|
| 781 |
+
|
| 782 |
+
# 导入训练脚本中的类
|
| 783 |
+
RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
|
| 784 |
+
|
| 785 |
+
# 加载 pipeline
|
| 786 |
+
dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
|
| 787 |
+
if rank == 0:
|
| 788 |
+
print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path} (dtype={dtype})")
|
| 789 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 790 |
+
args.pretrained_model_name_or_path,
|
| 791 |
+
revision=args.revision,
|
| 792 |
+
variant=args.variant,
|
| 793 |
+
torch_dtype=dtype,
|
| 794 |
+
).to(device)
|
| 795 |
+
|
| 796 |
+
# 加载 LoRA(可选)
|
| 797 |
+
lora_loaded = False
|
| 798 |
+
if args.lora_path:
|
| 799 |
+
if rank == 0:
|
| 800 |
+
print(f"Attempting to load LoRA weights from: {args.lora_path}")
|
| 801 |
+
print(f"LoRA path exists: {os.path.exists(args.lora_path) if args.lora_path else False}")
|
| 802 |
+
|
| 803 |
+
# 首先检查是否是标准的LoRA权重文件/目录
|
| 804 |
+
if check_lora_weights_exist(args.lora_path):
|
| 805 |
+
if rank == 0:
|
| 806 |
+
print("Found standard LoRA weights, loading...")
|
| 807 |
+
try:
|
| 808 |
+
# 检查加载前的transformer参数(用于验证)
|
| 809 |
+
if rank == 0:
|
| 810 |
+
sample_param_before = next(iter(pipeline.transformer.parameters())).clone()
|
| 811 |
+
print(f"Sample transformer param before LoRA (first 5 values): {sample_param_before.flatten()[:5]}")
|
| 812 |
+
|
| 813 |
+
pipeline.load_lora_weights(args.lora_path)
|
| 814 |
+
lora_loaded = True
|
| 815 |
+
lora_source = os.path.basename(args.lora_path.rstrip('/'))
|
| 816 |
+
|
| 817 |
+
# 验证LoRA是否真的被加载
|
| 818 |
+
if rank == 0:
|
| 819 |
+
sample_param_after = next(iter(pipeline.transformer.parameters())).clone()
|
| 820 |
+
param_diff = (sample_param_after - sample_param_before).abs().max().item()
|
| 821 |
+
print(f"Sample transformer param after LoRA (first 5 values): {sample_param_after.flatten()[:5]}")
|
| 822 |
+
print(f"Max parameter change after LoRA loading: {param_diff}")
|
| 823 |
+
if param_diff < 1e-6:
|
| 824 |
+
print("⚠️ WARNING: LoRA weights may not have been applied (parameter change is very small)")
|
| 825 |
+
else:
|
| 826 |
+
print("✓ LoRA weights appear to have been applied")
|
| 827 |
+
|
| 828 |
+
# 检查是否有peft_config
|
| 829 |
+
if hasattr(pipeline.transformer, 'peft_config'):
|
| 830 |
+
print(f"✓ PEFT config found: {list(pipeline.transformer.peft_config.keys())}")
|
| 831 |
+
else:
|
| 832 |
+
print("⚠️ WARNING: No peft_config found after loading LoRA")
|
| 833 |
+
|
| 834 |
+
if rank == 0:
|
| 835 |
+
print("LoRA loaded successfully from standard format.")
|
| 836 |
+
except Exception as e:
|
| 837 |
+
if rank == 0:
|
| 838 |
+
print(f"Failed to load LoRA from standard format: {e}")
|
| 839 |
+
import traceback
|
| 840 |
+
traceback.print_exc()
|
| 841 |
+
|
| 842 |
+
# 如果不是标准格式,尝试从accelerator checkpoint加载
|
| 843 |
+
if not lora_loaded and os.path.isdir(args.lora_path):
|
| 844 |
+
if rank == 0:
|
| 845 |
+
print("Standard LoRA weights not found, trying accelerator checkpoint format...")
|
| 846 |
+
|
| 847 |
+
# 首先尝试从checkpoint的model.safetensors中检测实际的rank
|
| 848 |
+
# 通过检查LoRA权重的形状来推断rank
|
| 849 |
+
detected_rank = None
|
| 850 |
+
try:
|
| 851 |
+
from safetensors.torch import load_file
|
| 852 |
+
model_file = os.path.join(args.lora_path, "model.safetensors")
|
| 853 |
+
if os.path.exists(model_file):
|
| 854 |
+
state_dict = load_file(model_file)
|
| 855 |
+
# 查找一个LoRA权重来确定rank
|
| 856 |
+
for key, value in state_dict.items():
|
| 857 |
+
if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
|
| 858 |
+
# lora_A的形状是 [rank, hidden_size]
|
| 859 |
+
detected_rank = value.shape[0]
|
| 860 |
+
if rank == 0:
|
| 861 |
+
print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
|
| 862 |
+
break
|
| 863 |
+
except Exception as e:
|
| 864 |
+
if rank == 0:
|
| 865 |
+
print(f"Could not detect rank from checkpoint: {e}")
|
| 866 |
+
|
| 867 |
+
# 构建rank尝试列表
|
| 868 |
+
# 如果检测到rank,优先使用检测到的rank,只尝试一次
|
| 869 |
+
# 如果未检测到,尝试常见的rank值
|
| 870 |
+
if detected_rank is not None:
|
| 871 |
+
rank_list = [detected_rank]
|
| 872 |
+
if rank == 0:
|
| 873 |
+
print(f"Using detected rank: {detected_rank}")
|
| 874 |
+
else:
|
| 875 |
+
# 如果检测失败,尝试常见的rank值(按用户指定的rank优先)
|
| 876 |
+
rank_list = []
|
| 877 |
+
# 如果用户指定了rank(从args.lora_rank),优先尝试
|
| 878 |
+
if hasattr(args, 'lora_rank') and args.lora_rank:
|
| 879 |
+
rank_list.append(args.lora_rank)
|
| 880 |
+
# 添加其他常见的rank值
|
| 881 |
+
for r in [32, 64, 16, 128]:
|
| 882 |
+
if r not in rank_list:
|
| 883 |
+
rank_list.append(r)
|
| 884 |
+
if rank == 0:
|
| 885 |
+
print(f"Rank detection failed, will try ranks in order: {rank_list}")
|
| 886 |
+
|
| 887 |
+
# 尝试不同的rank值
|
| 888 |
+
for lora_rank in rank_list:
|
| 889 |
+
# 在尝试新的rank之前,先清理已存在的适配器
|
| 890 |
+
# 重要:每次尝试前都要清理,否则适配器会保留之前的rank配置
|
| 891 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 892 |
+
if "default" in pipeline.transformer.peft_config:
|
| 893 |
+
try:
|
| 894 |
+
# 使用pipeline的unload_lora_weights方法
|
| 895 |
+
pipeline.unload_lora_weights()
|
| 896 |
+
if rank == 0:
|
| 897 |
+
print(f"Cleaned up existing adapter before trying rank={lora_rank}")
|
| 898 |
+
except Exception as e:
|
| 899 |
+
if rank == 0:
|
| 900 |
+
print(f"Warning: Could not unload adapter: {e}")
|
| 901 |
+
# 如果卸载失败,需要重新创建pipeline
|
| 902 |
+
if rank == 0:
|
| 903 |
+
print("⚠️ WARNING: Cannot unload adapter, will recreate pipeline...")
|
| 904 |
+
# 重新加载pipeline(最后手段)
|
| 905 |
+
try:
|
| 906 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 907 |
+
args.pretrained_model_name_or_path,
|
| 908 |
+
revision=args.revision,
|
| 909 |
+
variant=args.variant,
|
| 910 |
+
torch_dtype=dtype,
|
| 911 |
+
).to(device)
|
| 912 |
+
if rank == 0:
|
| 913 |
+
print("Pipeline recreated to clear adapter state")
|
| 914 |
+
except Exception as e2:
|
| 915 |
+
if rank == 0:
|
| 916 |
+
print(f"Failed to recreate pipeline: {e2}")
|
| 917 |
+
|
| 918 |
+
if rank == 0:
|
| 919 |
+
print(f"Trying to load with LoRA rank={lora_rank}...")
|
| 920 |
+
lora_loaded = load_lora_from_checkpoint(pipeline, args.lora_path, rank=rank, lora_rank=lora_rank)
|
| 921 |
+
if lora_loaded:
|
| 922 |
+
if rank == 0:
|
| 923 |
+
print(f"✓ Successfully loaded LoRA with rank={lora_rank}")
|
| 924 |
+
lora_source = "checkpoint"
|
| 925 |
+
break
|
| 926 |
+
elif rank == 0:
|
| 927 |
+
print(f"✗ Failed to load with rank={lora_rank}, trying next rank...")
|
| 928 |
+
|
| 929 |
+
# 如果checkpoint目录加载失败,尝试从输出目录的根目录加载标准LoRA权重
|
| 930 |
+
if not lora_loaded and os.path.isdir(args.lora_path):
|
| 931 |
+
# 检查输出目录的根目录(checkpoint的父目录)
|
| 932 |
+
output_dir = os.path.dirname(args.lora_path.rstrip('/'))
|
| 933 |
+
if output_dir and os.path.exists(output_dir):
|
| 934 |
+
if rank == 0:
|
| 935 |
+
print(f"Trying to load standard LoRA weights from output directory: {output_dir}")
|
| 936 |
+
if check_lora_weights_exist(output_dir):
|
| 937 |
+
try:
|
| 938 |
+
pipeline.load_lora_weights(output_dir)
|
| 939 |
+
lora_loaded = True
|
| 940 |
+
if rank == 0:
|
| 941 |
+
print("LoRA loaded successfully from output directory.")
|
| 942 |
+
except Exception as e:
|
| 943 |
+
if rank == 0:
|
| 944 |
+
print(f"Failed to load LoRA from output directory: {e}")
|
| 945 |
+
|
| 946 |
+
if not lora_loaded:
|
| 947 |
+
if rank == 0:
|
| 948 |
+
print(f"⚠️ WARNING: Failed to load LoRA weights from {args.lora_path}, using baseline model")
|
| 949 |
+
else:
|
| 950 |
+
# 最终验证LoRA是否真的被启用
|
| 951 |
+
if rank == 0:
|
| 952 |
+
print("=" * 80)
|
| 953 |
+
print("LoRA 加载验证:")
|
| 954 |
+
if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
|
| 955 |
+
print(f" ✓ PEFT config exists: {list(pipeline.transformer.peft_config.keys())}")
|
| 956 |
+
# 检查LoRA层的权重
|
| 957 |
+
lora_layers_found = 0
|
| 958 |
+
for name, module in pipeline.transformer.named_modules():
|
| 959 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 960 |
+
lora_layers_found += 1
|
| 961 |
+
if lora_layers_found <= 3: # 只打印前3个
|
| 962 |
+
if hasattr(module, 'weight'):
|
| 963 |
+
weight_sum = module.weight.abs().sum().item() if module.weight is not None else 0
|
| 964 |
+
print(f" ✓ Found LoRA layer: {name}, weight_sum={weight_sum:.6f}")
|
| 965 |
+
print(f" ✓ Total LoRA layers found: {lora_layers_found}")
|
| 966 |
+
if lora_layers_found == 0:
|
| 967 |
+
print(" ⚠️ WARNING: No LoRA layers found in transformer!")
|
| 968 |
+
else:
|
| 969 |
+
print(" ⚠️ WARNING: No PEFT config found - LoRA may not be active!")
|
| 970 |
+
print("=" * 80)
|
| 971 |
+
|
| 972 |
+
# 构建 RectifiedNoiseModule 并加载权重(仅在提供了 rectified_weights 时)
|
| 973 |
+
# 安全地检查 rectified_weights 是否有效
|
| 974 |
+
use_rectified = False
|
| 975 |
+
rectified_weights_path = None
|
| 976 |
+
if args.rectified_weights:
|
| 977 |
+
rectified_weights_str = str(args.rectified_weights).strip()
|
| 978 |
+
if rectified_weights_str:
|
| 979 |
+
use_rectified = True
|
| 980 |
+
rectified_weights_path = rectified_weights_str
|
| 981 |
+
|
| 982 |
+
if rank == 0:
|
| 983 |
+
print(f"use_rectified: {use_rectified}, rectified_weights_path: {rectified_weights_path}")
|
| 984 |
+
|
| 985 |
+
if use_rectified:
|
| 986 |
+
if rank == 0:
|
| 987 |
+
print(f"Using Rectified Noise module with weights from: {rectified_weights_path}")
|
| 988 |
+
|
| 989 |
+
# 从 transformer 配置推断必要尺寸
|
| 990 |
+
tfm = pipeline.transformer
|
| 991 |
+
if hasattr(tfm.config, 'joint_attention_dim') and tfm.config.joint_attention_dim is not None:
|
| 992 |
+
sit_hidden_size = tfm.config.joint_attention_dim
|
| 993 |
+
elif hasattr(tfm.config, 'inner_dim') and tfm.config.inner_dim is not None:
|
| 994 |
+
sit_hidden_size = tfm.config.inner_dim
|
| 995 |
+
elif hasattr(tfm.config, 'hidden_size') and tfm.config.hidden_size is not None:
|
| 996 |
+
sit_hidden_size = tfm.config.hidden_size
|
| 997 |
+
else:
|
| 998 |
+
sit_hidden_size = 4096
|
| 999 |
+
|
| 1000 |
+
transformer_hidden_size = getattr(tfm.config, 'hidden_size', 1536)
|
| 1001 |
+
num_attention_heads = getattr(tfm.config, 'num_attention_heads', 32)
|
| 1002 |
+
input_dim = getattr(tfm.config, 'in_channels', 16)
|
| 1003 |
+
|
| 1004 |
+
rectified_module = RectifiedNoiseModule(
|
| 1005 |
+
hidden_size=sit_hidden_size,
|
| 1006 |
+
num_sit_layers=args.num_sit_layers,
|
| 1007 |
+
num_attention_heads=num_attention_heads,
|
| 1008 |
+
input_dim=input_dim,
|
| 1009 |
+
transformer_hidden_size=transformer_hidden_size,
|
| 1010 |
+
)
|
| 1011 |
+
# 加载 SIT 权重
|
| 1012 |
+
ok = load_sit_weights(rectified_module, rectified_weights_path, rank=rank)
|
| 1013 |
+
if rank == 0:
|
| 1014 |
+
if not ok:
|
| 1015 |
+
print("⚠️ Warning: Failed to load rectified weights, will use baseline model without rectified noise")
|
| 1016 |
+
else:
|
| 1017 |
+
print("✓ Successfully loaded rectified noise weights")
|
| 1018 |
+
|
| 1019 |
+
# 组装 SD3WithRectifiedNoise
|
| 1020 |
+
# 关键:SD3WithRectifiedNoise 会保留 transformer 的引用
|
| 1021 |
+
# 但是,SD3WithRectifiedNoise 在 __init__ 中会冻结 transformer 参数
|
| 1022 |
+
# 这不应该影响 LoRA,因为 LoRA 是作为适配器添加的,不是原始参数
|
| 1023 |
+
# 我们需要确保在创建 SD3WithRectifiedNoise 之前,LoRA 适配器已经正确加载和启用
|
| 1024 |
+
if lora_loaded and rank == 0:
|
| 1025 |
+
print("Creating SD3WithRectifiedNoise with LoRA-enabled transformer...")
|
| 1026 |
+
elif rank == 0:
|
| 1027 |
+
print("Creating SD3WithRectifiedNoise...")
|
| 1028 |
+
|
| 1029 |
+
model = SD3WithRectifiedNoise(pipeline.transformer, rectified_module).to(device)
|
| 1030 |
+
|
| 1031 |
+
# 重要:SD3WithRectifiedNoise 的 __init__ 会冻结 transformer 参数
|
| 1032 |
+
# 但 LoRA 适配器应该仍然有效,因为它们是独立的模块
|
| 1033 |
+
# 我们需要确保 LoRA 适配器在包装后仍然可以访问
|
| 1034 |
+
|
| 1035 |
+
# 确保 LoRA 适配器在模型替换后仍然启用
|
| 1036 |
+
if lora_loaded:
|
| 1037 |
+
# 通过model.transformer访问,因为SD3WithRectifiedNoise包装了transformer
|
| 1038 |
+
if hasattr(model.transformer, 'peft_config'):
|
| 1039 |
+
try:
|
| 1040 |
+
# 确保适配器处于启用状态
|
| 1041 |
+
model.transformer.set_adapter("default_0")
|
| 1042 |
+
|
| 1043 |
+
# 验证LoRA权重在包装后是否仍然存在
|
| 1044 |
+
lora_layers_after_wrap = 0
|
| 1045 |
+
nonzero_after_wrap = 0
|
| 1046 |
+
for name, module in model.transformer.named_modules():
|
| 1047 |
+
if 'lora_A' in name or 'lora_B' in name:
|
| 1048 |
+
lora_layers_after_wrap += 1
|
| 1049 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
| 1050 |
+
if module.weight.abs().sum().item() > 1e-6:
|
| 1051 |
+
nonzero_after_wrap += 1
|
| 1052 |
+
|
| 1053 |
+
if rank == 0:
|
| 1054 |
+
print(f"LoRA after SD3WithRectifiedNoise wrapping:")
|
| 1055 |
+
print(f" LoRA layers: {lora_layers_after_wrap}, Non-zero: {nonzero_after_wrap}")
|
| 1056 |
+
if nonzero_after_wrap == 0:
|
| 1057 |
+
print(" ❌ ERROR: All LoRA weights are zero after wrapping!")
|
| 1058 |
+
elif nonzero_after_wrap < lora_layers_after_wrap * 0.5:
|
| 1059 |
+
print(f" ⚠️ WARNING: Only {nonzero_after_wrap}/{lora_layers_after_wrap} LoRA layers have weights!")
|
| 1060 |
+
else:
|
| 1061 |
+
print(f" ✓ LoRA weights preserved after wrapping")
|
| 1062 |
+
|
| 1063 |
+
# 验证适配器是否真的启用
|
| 1064 |
+
if hasattr(model.transformer, 'active_adapters'):
|
| 1065 |
+
try:
|
| 1066 |
+
if callable(model.transformer.active_adapters):
|
| 1067 |
+
active = model.transformer.active_adapters()
|
| 1068 |
+
else:
|
| 1069 |
+
active = model.transformer.active_adapters
|
| 1070 |
+
if rank == 0:
|
| 1071 |
+
print(f" Active adapters: {active}")
|
| 1072 |
+
except:
|
| 1073 |
+
if rank == 0:
|
| 1074 |
+
print(" LoRA adapter re-enabled after model wrapping")
|
| 1075 |
+
else:
|
| 1076 |
+
if rank == 0:
|
| 1077 |
+
print(" LoRA adapter re-enabled after model wrapping")
|
| 1078 |
+
except Exception as e:
|
| 1079 |
+
if rank == 0:
|
| 1080 |
+
print(f"❌ ERROR: Could not re-enable LoRA adapter: {e}")
|
| 1081 |
+
import traceback
|
| 1082 |
+
traceback.print_exc()
|
| 1083 |
+
else:
|
| 1084 |
+
# LoRA权重已经合并到transformer的基础权重中(合并加载方式)
|
| 1085 |
+
# 这种情况下没有peft_config是正常的,因为LoRA已经合并了
|
| 1086 |
+
if rank == 0:
|
| 1087 |
+
print("LoRA loaded via merged weights (no PEFT adapter needed)")
|
| 1088 |
+
print(" ✓ LoRA weights are already merged into transformer base weights")
|
| 1089 |
+
print(" Note: This is expected when loading from merged checkpoint format")
|
| 1090 |
+
|
| 1091 |
+
# 注册到 pipeline(pipeline_stable_diffusion_3.py 已支持 external model)
|
| 1092 |
+
pipeline.model = model
|
| 1093 |
+
|
| 1094 |
+
# 确保模型处于评估模式(LoRA在eval模式下也应该工作)
|
| 1095 |
+
model.eval()
|
| 1096 |
+
model.transformer.eval() # 确保transformer也处于eval模式
|
| 1097 |
+
else:
|
| 1098 |
+
if rank == 0:
|
| 1099 |
+
print("Not using Rectified Noise module, using baseline SD3 pipeline")
|
| 1100 |
+
# 不使用 SD3WithRectifiedNoise,保持原始 pipeline
|
| 1101 |
+
# pipeline.model 保持为原始的 transformer
|
| 1102 |
+
|
| 1103 |
+
# 关键:确保LoRA适配器在推理时被使用
|
| 1104 |
+
# PEFT模型在eval模式下,LoRA适配器应该自动启用,但我们需要确保
|
| 1105 |
+
if lora_loaded:
|
| 1106 |
+
# 获取正确的 transformer 引用
|
| 1107 |
+
transformer_ref = model.transformer if use_rectified else pipeline.transformer
|
| 1108 |
+
|
| 1109 |
+
# 确保transformer的LoRA适配器处于启用状态
|
| 1110 |
+
if hasattr(transformer_ref, 'set_adapter'):
|
| 1111 |
+
try:
|
| 1112 |
+
transformer_ref.set_adapter("default")
|
| 1113 |
+
except:
|
| 1114 |
+
pass
|
| 1115 |
+
|
| 1116 |
+
# 验证LoRA是否真的会被使用
|
| 1117 |
+
if rank == 0:
|
| 1118 |
+
# 检查一个LoRA层的权重
|
| 1119 |
+
lora_found = False
|
| 1120 |
+
for name, module in transformer_ref.named_modules():
|
| 1121 |
+
if 'lora_A' in name and 'default' in name and hasattr(module, 'weight'):
|
| 1122 |
+
if module.weight is not None:
|
| 1123 |
+
weight_sum = module.weight.abs().sum().item()
|
| 1124 |
+
if weight_sum > 0:
|
| 1125 |
+
print(f"✓ Verified LoRA weight in {name}: sum={weight_sum:.6f}")
|
| 1126 |
+
lora_found = True
|
| 1127 |
+
break
|
| 1128 |
+
|
| 1129 |
+
if not lora_found:
|
| 1130 |
+
print("⚠ Warning: Could not verify LoRA weights in model")
|
| 1131 |
+
else:
|
| 1132 |
+
# 额外检查:验证LoRA层是否真的会被调用
|
| 1133 |
+
# 检查一个LoRA Linear层
|
| 1134 |
+
for name, module in transformer_ref.named_modules():
|
| 1135 |
+
if hasattr(module, '__class__') and 'lora' in module.__class__.__name__.lower():
|
| 1136 |
+
if hasattr(module, 'lora_enabled'):
|
| 1137 |
+
enabled = module.lora_enabled
|
| 1138 |
+
if rank == 0:
|
| 1139 |
+
print(f"✓ Found LoRA layer {name}, enabled: {enabled}")
|
| 1140 |
+
break
|
| 1141 |
+
|
| 1142 |
+
print("Model set to eval mode, LoRA should be active during inference")
|
| 1143 |
+
|
| 1144 |
+
# 启用内存优化选项
|
| 1145 |
+
if args.enable_attention_slicing:
|
| 1146 |
+
enable_attention_slicing_method = getattr(pipeline, 'enable_attention_slicing', None)
|
| 1147 |
+
if enable_attention_slicing_method is not None and callable(enable_attention_slicing_method):
|
| 1148 |
+
try:
|
| 1149 |
+
if rank == 0:
|
| 1150 |
+
print("Enabling attention slicing to save memory")
|
| 1151 |
+
enable_attention_slicing_method()
|
| 1152 |
+
except Exception as e:
|
| 1153 |
+
if rank == 0:
|
| 1154 |
+
print(f"Warning: Failed to enable attention slicing: {e}")
|
| 1155 |
+
else:
|
| 1156 |
+
if rank == 0:
|
| 1157 |
+
print("Warning: Attention slicing not available for this pipeline")
|
| 1158 |
+
|
| 1159 |
+
if args.enable_vae_slicing:
|
| 1160 |
+
# 使用 getattr 来安全地检查方法是否存在,避免触发 __getattr__ 异常
|
| 1161 |
+
enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
|
| 1162 |
+
if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
|
| 1163 |
+
try:
|
| 1164 |
+
if rank == 0:
|
| 1165 |
+
print("Enabling VAE slicing to save memory")
|
| 1166 |
+
enable_vae_slicing_method()
|
| 1167 |
+
except Exception as e:
|
| 1168 |
+
if rank == 0:
|
| 1169 |
+
print(f"Warning: Failed to enable VAE slicing: {e}")
|
| 1170 |
+
else:
|
| 1171 |
+
if rank == 0:
|
| 1172 |
+
print("Warning: VAE slicing not available for this pipeline (SD3 may not support this)")
|
| 1173 |
+
|
| 1174 |
+
if args.enable_cpu_offload:
|
| 1175 |
+
if rank == 0:
|
| 1176 |
+
print("Enabling CPU offload to save memory")
|
| 1177 |
+
pipeline.enable_model_cpu_offload()
|
| 1178 |
+
|
| 1179 |
+
# 禁用进度条以减少输出
|
| 1180 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 1181 |
+
|
| 1182 |
+
# 读入 captions
|
| 1183 |
+
captions = load_captions_from_jsonl(args.captions_jsonl)
|
| 1184 |
+
total_images_needed = min(len(captions) * args.images_per_caption, args.max_samples)
|
| 1185 |
+
|
| 1186 |
+
# 生成caption和image的映射列表
|
| 1187 |
+
caption_image_pairs = []
|
| 1188 |
+
for i, caption in enumerate(captions):
|
| 1189 |
+
for j in range(args.images_per_caption):
|
| 1190 |
+
caption_image_pairs.append((caption, i, j)) # (caption, caption_idx, image_idx)
|
| 1191 |
+
|
| 1192 |
+
# 输出目录
|
| 1193 |
+
folder_name = f"sd3-rectified-{lora_source}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
|
| 1194 |
+
sample_folder_dir = os.path.join(args.sample_dir, folder_name)
|
| 1195 |
+
|
| 1196 |
+
if rank == 0:
|
| 1197 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
| 1198 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
| 1199 |
+
# 清空caption文件
|
| 1200 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 1201 |
+
if os.path.exists(caption_file):
|
| 1202 |
+
os.remove(caption_file)
|
| 1203 |
+
dist.barrier()
|
| 1204 |
+
|
| 1205 |
+
n = args.per_proc_batch_size
|
| 1206 |
+
global_batch = n * world_size
|
| 1207 |
+
total_samples = int(math.ceil(total_images_needed / global_batch) * global_batch)
|
| 1208 |
+
assert total_samples % world_size == 0
|
| 1209 |
+
samples_per_gpu = total_samples // world_size
|
| 1210 |
+
assert samples_per_gpu % n == 0
|
| 1211 |
+
iterations = samples_per_gpu // n
|
| 1212 |
+
|
| 1213 |
+
if rank == 0:
|
| 1214 |
+
print(f"Sampling total={total_samples}, per_gpu={samples_per_gpu}, iterations={iterations}")
|
| 1215 |
+
|
| 1216 |
+
pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
|
| 1217 |
+
saved = 0
|
| 1218 |
+
|
| 1219 |
+
autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1220 |
+
for it in pbar:
|
| 1221 |
+
# 获取这个batch对应的caption
|
| 1222 |
+
batch_prompts = []
|
| 1223 |
+
batch_caption_info = []
|
| 1224 |
+
|
| 1225 |
+
for j in range(n):
|
| 1226 |
+
global_index = it * global_batch + j * world_size + rank
|
| 1227 |
+
if global_index < len(caption_image_pairs):
|
| 1228 |
+
caption, caption_idx, image_idx = caption_image_pairs[global_index]
|
| 1229 |
+
batch_prompts.append(caption)
|
| 1230 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 1231 |
+
else:
|
| 1232 |
+
# 如果超出范围,使用最后一个caption
|
| 1233 |
+
if caption_image_pairs:
|
| 1234 |
+
caption, caption_idx, image_idx = caption_image_pairs[-1]
|
| 1235 |
+
batch_prompts.append(caption)
|
| 1236 |
+
batch_caption_info.append((caption, caption_idx, image_idx))
|
| 1237 |
+
else:
|
| 1238 |
+
batch_prompts.append("a beautiful high quality image")
|
| 1239 |
+
batch_caption_info.append(("a beautiful high quality image", 0, 0))
|
| 1240 |
+
|
| 1241 |
+
with torch.autocast(autocast_device, dtype=dtype):
|
| 1242 |
+
images = []
|
| 1243 |
+
for k, prompt in enumerate(batch_prompts):
|
| 1244 |
+
image_seed = seed + it * 10000 + k * 1000 + rank
|
| 1245 |
+
generator = torch.Generator(device=device).manual_seed(image_seed)
|
| 1246 |
+
img = pipeline(
|
| 1247 |
+
prompt=prompt,
|
| 1248 |
+
height=args.height,
|
| 1249 |
+
width=args.width,
|
| 1250 |
+
num_inference_steps=args.num_inference_steps,
|
| 1251 |
+
guidance_scale=args.guidance_scale,
|
| 1252 |
+
generator=generator,
|
| 1253 |
+
num_images_per_prompt=1,
|
| 1254 |
+
).images[0]
|
| 1255 |
+
images.append(img)
|
| 1256 |
+
|
| 1257 |
+
# 保存
|
| 1258 |
+
for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
|
| 1259 |
+
global_index = it * global_batch + j * world_size + rank
|
| 1260 |
+
if global_index < len(caption_image_pairs):
|
| 1261 |
+
# 保存图片,文件名包含caption索引和图片索引
|
| 1262 |
+
filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
|
| 1263 |
+
image_path = os.path.join(sample_folder_dir, filename)
|
| 1264 |
+
image.save(image_path)
|
| 1265 |
+
|
| 1266 |
+
# 保存caption信息到文本文件(只在rank 0上操作)
|
| 1267 |
+
if rank == 0:
|
| 1268 |
+
caption_file = os.path.join(sample_folder_dir, "captions.txt")
|
| 1269 |
+
with open(caption_file, "a", encoding="utf-8") as f:
|
| 1270 |
+
f.write(f"{filename}\t{caption}\n")
|
| 1271 |
+
|
| 1272 |
+
total_generated = saved * world_size # 近似值
|
| 1273 |
+
|
| 1274 |
+
dist.barrier()
|
| 1275 |
+
|
| 1276 |
+
if rank == 0:
|
| 1277 |
+
print(f"Done. Saved {saved * world_size} images in total.")
|
| 1278 |
+
# 重新计算实际生成的图片数量
|
| 1279 |
+
actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
|
| 1280 |
+
print(f"Actually generated {actual_num_samples} images")
|
| 1281 |
+
# 使用实际的图片数量或用户指定的数量,取较小值
|
| 1282 |
+
npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
|
| 1283 |
+
create_npz_from_sample_folder(sample_folder_dir, npz_samples)
|
| 1284 |
+
print("Done.")
|
| 1285 |
+
|
| 1286 |
+
dist.barrier()
|
| 1287 |
+
dist.destroy_process_group()
|
| 1288 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA + RectifiedNoise 分布式采样脚本")
|
| 1289 |
+
# 模型
|
| 1290 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
|
| 1291 |
+
parser.add_argument("--revision", type=str, default=None)
|
| 1292 |
+
parser.add_argument("--variant", type=str, default=None)
|
| 1293 |
+
# LoRA 与 Rectified
|
| 1294 |
+
parser.add_argument("--lora_path", type=str, default=None, help="LoRA 权重路径(文件或目录)")
|
| 1295 |
+
parser.add_argument("--rectified_weights", type=str, default=None, help="Rectified(SIT) 权重路径(文件或目录)")
|
| 1296 |
+
parser.add_argument("--num_sit_layers", type=int, default=1, help="与训练一致的 SIT 层数")
|
| 1297 |
+
# 采样
|
| 1298 |
+
parser.add_argument("--num_inference_steps", type=int, default=28)
|
| 1299 |
+
parser.add_argument("--guidance_scale", type=float, default=7.0)
|
| 1300 |
+
parser.add_argument("--height", type=int, default=1024)
|
| 1301 |
+
parser.add_argument("--width", type=int, default=1024)
|
| 1302 |
+
parser.add_argument("--per_proc_batch_size", type=int, default=1)
|
| 1303 |
+
parser.add_argument("--images_per_caption", type=int, default=1)
|
| 1304 |
+
parser.add_argument("--max_samples", type=int, default=10000)
|
| 1305 |
+
parser.add_argument("--captions_jsonl", type=str, required=True)
|
| 1306 |
+
parser.add_argument("--sample_dir", type=str, default="sd3_rectified_samples")
|
| 1307 |
+
parser.add_argument("--global_seed", type=int, default=42)
|
| 1308 |
+
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
| 1309 |
+
# 内存优化选项
|
| 1310 |
+
parser.add_argument("--enable_attention_slicing", action="store_true", help="启用 attention slicing 以节省显存")
|
| 1311 |
+
parser.add_argument("--enable_vae_slicing", action="store_true", help="启用 VAE slicing 以节省显存")
|
| 1312 |
+
parser.add_argument("--enable_cpu_offload", action="store_true", help="启用 CPU offload 以节省显存")
|
| 1313 |
+
|
| 1314 |
+
args = parser.parse_args()
|
| 1315 |
+
main(args)
|
| 1316 |
+
|
| 1317 |
+
|
sd3_rectified_samples_batch2_2200005011.01.01.0cfg_cond_true.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Inception Score: 37.646392822265625
|
| 2 |
+
FID: 21.19386100577333
|
| 3 |
+
sFID: 71.79977998851734
|
| 4 |
+
Precision: 0.690407122136641
|
| 5 |
+
Recall: 0.358997247638176
|
train_lora_sd3.py
ADDED
|
@@ -0,0 +1,1597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""SD3 LoRA fine-tuning script for text2image generation."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import copy
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
import random
|
| 25 |
+
import shutil
|
| 26 |
+
from contextlib import nullcontext
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
import datasets
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
import torch.utils.checkpoint
|
| 34 |
+
import transformers
|
| 35 |
+
from accelerate import Accelerator
|
| 36 |
+
from accelerate.logging import get_logger
|
| 37 |
+
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
|
| 38 |
+
from datasets import load_dataset
|
| 39 |
+
from huggingface_hub import create_repo, upload_folder
|
| 40 |
+
from packaging import version
|
| 41 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 42 |
+
from peft.utils import get_peft_model_state_dict
|
| 43 |
+
from PIL import Image
|
| 44 |
+
from torchvision import transforms
|
| 45 |
+
from torchvision.transforms.functional import crop
|
| 46 |
+
from tqdm.auto import tqdm
|
| 47 |
+
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
|
| 48 |
+
|
| 49 |
+
import diffusers
|
| 50 |
+
from diffusers import (
|
| 51 |
+
AutoencoderKL,
|
| 52 |
+
FlowMatchEulerDiscreteScheduler,
|
| 53 |
+
SD3Transformer2DModel,
|
| 54 |
+
StableDiffusion3Pipeline,
|
| 55 |
+
)
|
| 56 |
+
from diffusers.optimization import get_scheduler
|
| 57 |
+
from diffusers.training_utils import (
|
| 58 |
+
_set_state_dict_into_text_encoder,
|
| 59 |
+
cast_training_params,
|
| 60 |
+
compute_density_for_timestep_sampling,
|
| 61 |
+
compute_loss_weighting_for_sd3,
|
| 62 |
+
free_memory,
|
| 63 |
+
)
|
| 64 |
+
from diffusers.utils import (
|
| 65 |
+
check_min_version,
|
| 66 |
+
convert_unet_state_dict_to_peft,
|
| 67 |
+
is_wandb_available,
|
| 68 |
+
)
|
| 69 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
| 70 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 71 |
+
|
| 72 |
+
if is_wandb_available():
|
| 73 |
+
import wandb
|
| 74 |
+
|
| 75 |
+
# Check minimum diffusers version
|
| 76 |
+
check_min_version("0.30.0")
|
| 77 |
+
|
| 78 |
+
logger = get_logger(__name__)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def save_model_card(
|
| 82 |
+
repo_id: str,
|
| 83 |
+
images: list = None,
|
| 84 |
+
base_model: str = None,
|
| 85 |
+
dataset_name: str = None,
|
| 86 |
+
train_text_encoder: bool = False,
|
| 87 |
+
repo_folder: str = None,
|
| 88 |
+
vae_path: str = None,
|
| 89 |
+
):
|
| 90 |
+
"""Save model card for SD3 LoRA model."""
|
| 91 |
+
img_str = ""
|
| 92 |
+
if images is not None:
|
| 93 |
+
for i, image in enumerate(images):
|
| 94 |
+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 95 |
+
img_str += f"\n"
|
| 96 |
+
|
| 97 |
+
model_description = f"""
|
| 98 |
+
# SD3 LoRA text2image fine-tuning - {repo_id}
|
| 99 |
+
|
| 100 |
+
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
| 101 |
+
{img_str}
|
| 102 |
+
|
| 103 |
+
LoRA for the text encoder was enabled: {train_text_encoder}.
|
| 104 |
+
|
| 105 |
+
Special VAE used for training: {vae_path}.
|
| 106 |
+
"""
|
| 107 |
+
model_card = load_or_create_model_card(
|
| 108 |
+
repo_id_or_path=repo_id,
|
| 109 |
+
from_training=True,
|
| 110 |
+
license="other",
|
| 111 |
+
base_model=base_model,
|
| 112 |
+
model_description=model_description,
|
| 113 |
+
inference=True,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
tags = [
|
| 117 |
+
"stable-diffusion-3",
|
| 118 |
+
"stable-diffusion-3-diffusers",
|
| 119 |
+
"text-to-image",
|
| 120 |
+
"diffusers",
|
| 121 |
+
"diffusers-training",
|
| 122 |
+
"lora",
|
| 123 |
+
"sd3",
|
| 124 |
+
]
|
| 125 |
+
model_card = populate_model_card(model_card, tags=tags)
|
| 126 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def log_validation(
|
| 130 |
+
pipeline,
|
| 131 |
+
args,
|
| 132 |
+
accelerator,
|
| 133 |
+
epoch,
|
| 134 |
+
is_final_validation=False,
|
| 135 |
+
global_step=None,
|
| 136 |
+
):
|
| 137 |
+
"""Run validation and log images."""
|
| 138 |
+
logger.info(
|
| 139 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 140 |
+
f" {args.validation_prompt}."
|
| 141 |
+
)
|
| 142 |
+
pipeline = pipeline.to(accelerator.device)
|
| 143 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 144 |
+
|
| 145 |
+
# run inference
|
| 146 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
| 147 |
+
pipeline_args = {"prompt": args.validation_prompt}
|
| 148 |
+
|
| 149 |
+
if torch.backends.mps.is_available():
|
| 150 |
+
autocast_ctx = nullcontext()
|
| 151 |
+
else:
|
| 152 |
+
autocast_ctx = torch.autocast(accelerator.device.type)
|
| 153 |
+
|
| 154 |
+
with autocast_ctx:
|
| 155 |
+
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
| 156 |
+
|
| 157 |
+
# Save images to output directory
|
| 158 |
+
if accelerator.is_main_process:
|
| 159 |
+
validation_dir = os.path.join(args.output_dir, "validation_images")
|
| 160 |
+
os.makedirs(validation_dir, exist_ok=True)
|
| 161 |
+
for i, image in enumerate(images):
|
| 162 |
+
# Create filename with step and epoch information
|
| 163 |
+
if global_step is not None:
|
| 164 |
+
filename = f"validation_step_{global_step}_epoch_{epoch}_img_{i}.png"
|
| 165 |
+
else:
|
| 166 |
+
filename = f"validation_epoch_{epoch}_img_{i}.png"
|
| 167 |
+
|
| 168 |
+
image_path = os.path.join(validation_dir, filename)
|
| 169 |
+
image.save(image_path)
|
| 170 |
+
logger.info(f"Saved validation image: {image_path}")
|
| 171 |
+
|
| 172 |
+
for tracker in accelerator.trackers if hasattr(accelerator, 'trackers') and accelerator.trackers else []:
|
| 173 |
+
phase_name = "test" if is_final_validation else "validation"
|
| 174 |
+
try:
|
| 175 |
+
if tracker.name == "tensorboard":
|
| 176 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 177 |
+
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
|
| 178 |
+
if tracker.name == "wandb":
|
| 179 |
+
tracker.log(
|
| 180 |
+
{
|
| 181 |
+
phase_name: [
|
| 182 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
|
| 183 |
+
]
|
| 184 |
+
}
|
| 185 |
+
)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
logger.warning(f"Failed to log to {tracker.name}: {e}")
|
| 188 |
+
|
| 189 |
+
del pipeline
|
| 190 |
+
free_memory()
|
| 191 |
+
return images
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def import_model_class_from_model_name_or_path(
|
| 195 |
+
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
| 196 |
+
):
|
| 197 |
+
"""Import the correct text encoder class."""
|
| 198 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 199 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 200 |
+
)
|
| 201 |
+
model_class = text_encoder_config.architectures[0]
|
| 202 |
+
|
| 203 |
+
if model_class == "CLIPTextModelWithProjection":
|
| 204 |
+
from transformers import CLIPTextModelWithProjection
|
| 205 |
+
return CLIPTextModelWithProjection
|
| 206 |
+
elif model_class == "T5EncoderModel":
|
| 207 |
+
from transformers import T5EncoderModel
|
| 208 |
+
return T5EncoderModel
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def parse_args(input_args=None):
|
| 214 |
+
"""Parse command line arguments."""
|
| 215 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA training script.")
|
| 216 |
+
|
| 217 |
+
# Model arguments
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--pretrained_model_name_or_path",
|
| 220 |
+
type=str,
|
| 221 |
+
default=None,
|
| 222 |
+
required=True,
|
| 223 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 224 |
+
)
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--revision",
|
| 227 |
+
type=str,
|
| 228 |
+
default=None,
|
| 229 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 230 |
+
)
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--variant",
|
| 233 |
+
type=str,
|
| 234 |
+
default=None,
|
| 235 |
+
help="Variant of the model files, e.g. fp16",
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Dataset arguments
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--dataset_name",
|
| 241 |
+
type=str,
|
| 242 |
+
default=None,
|
| 243 |
+
help="The name of the Dataset to train on.",
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--dataset_config_name",
|
| 247 |
+
type=str,
|
| 248 |
+
default=None,
|
| 249 |
+
help="The config of the Dataset.",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--train_data_dir",
|
| 253 |
+
type=str,
|
| 254 |
+
default=None,
|
| 255 |
+
help="A folder containing the training data.",
|
| 256 |
+
)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--image_column",
|
| 259 |
+
type=str,
|
| 260 |
+
default="image",
|
| 261 |
+
help="The column of the dataset containing an image."
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--caption_column",
|
| 265 |
+
type=str,
|
| 266 |
+
default="caption",
|
| 267 |
+
help="The column of the dataset containing a caption.",
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Training arguments
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--max_sequence_length",
|
| 273 |
+
type=int,
|
| 274 |
+
default=77,
|
| 275 |
+
help="Maximum sequence length to use with the T5 text encoder",
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
"--validation_prompt",
|
| 279 |
+
type=str,
|
| 280 |
+
default=None,
|
| 281 |
+
help="A prompt used during validation.",
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--num_validation_images",
|
| 285 |
+
type=int,
|
| 286 |
+
default=4,
|
| 287 |
+
help="Number of images for validation.",
|
| 288 |
+
)
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--validation_epochs",
|
| 291 |
+
type=int,
|
| 292 |
+
default=1,
|
| 293 |
+
help="Run validation every X epochs.",
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--max_train_samples",
|
| 297 |
+
type=int,
|
| 298 |
+
default=None,
|
| 299 |
+
help="Truncate the number of training examples.",
|
| 300 |
+
)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--output_dir",
|
| 303 |
+
type=str,
|
| 304 |
+
default="sd3-lora-finetuned",
|
| 305 |
+
help="Output directory for model predictions and checkpoints.",
|
| 306 |
+
)
|
| 307 |
+
parser.add_argument(
|
| 308 |
+
"--cache_dir",
|
| 309 |
+
type=str,
|
| 310 |
+
default=None,
|
| 311 |
+
help="Directory to store downloaded models and datasets.",
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--seed",
|
| 315 |
+
type=int,
|
| 316 |
+
default=None,
|
| 317 |
+
help="A seed for reproducible training."
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--resolution",
|
| 321 |
+
type=int,
|
| 322 |
+
default=1024,
|
| 323 |
+
help="Image resolution for training.",
|
| 324 |
+
)
|
| 325 |
+
parser.add_argument(
|
| 326 |
+
"--center_crop",
|
| 327 |
+
default=False,
|
| 328 |
+
action="store_true",
|
| 329 |
+
help="Whether to center crop input images.",
|
| 330 |
+
)
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--random_flip",
|
| 333 |
+
action="store_true",
|
| 334 |
+
help="Whether to randomly flip images horizontally.",
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--train_text_encoder",
|
| 338 |
+
action="store_true",
|
| 339 |
+
help="Whether to train the text encoder.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--train_batch_size",
|
| 343 |
+
type=int,
|
| 344 |
+
default=16,
|
| 345 |
+
help="Batch size for training dataloader."
|
| 346 |
+
)
|
| 347 |
+
parser.add_argument(
|
| 348 |
+
"--num_train_epochs",
|
| 349 |
+
type=int,
|
| 350 |
+
default=100
|
| 351 |
+
)
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--max_train_steps",
|
| 354 |
+
type=int,
|
| 355 |
+
default=None,
|
| 356 |
+
help="Total number of training steps.",
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--checkpointing_steps",
|
| 360 |
+
type=int,
|
| 361 |
+
default=500,
|
| 362 |
+
help="Save checkpoint every X updates.",
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--checkpoints_total_limit",
|
| 366 |
+
type=int,
|
| 367 |
+
default=None,
|
| 368 |
+
help="Max number of checkpoints to store.",
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--resume_from_checkpoint",
|
| 372 |
+
type=str,
|
| 373 |
+
default=None,
|
| 374 |
+
help="Path to resume training from checkpoint.",
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--gradient_accumulation_steps",
|
| 378 |
+
type=int,
|
| 379 |
+
default=1,
|
| 380 |
+
help="Number of update steps to accumulate.",
|
| 381 |
+
)
|
| 382 |
+
parser.add_argument(
|
| 383 |
+
"--gradient_checkpointing",
|
| 384 |
+
action="store_true",
|
| 385 |
+
help="Use gradient checkpointing to save memory.",
|
| 386 |
+
)
|
| 387 |
+
parser.add_argument(
|
| 388 |
+
"--learning_rate",
|
| 389 |
+
type=float,
|
| 390 |
+
default=1e-4,
|
| 391 |
+
help="Initial learning rate.",
|
| 392 |
+
)
|
| 393 |
+
parser.add_argument(
|
| 394 |
+
"--scale_lr",
|
| 395 |
+
action="store_true",
|
| 396 |
+
default=False,
|
| 397 |
+
help="Scale learning rate by number of GPUs, etc.",
|
| 398 |
+
)
|
| 399 |
+
parser.add_argument(
|
| 400 |
+
"--lr_scheduler",
|
| 401 |
+
type=str,
|
| 402 |
+
default="constant",
|
| 403 |
+
help="Learning rate scheduler type.",
|
| 404 |
+
)
|
| 405 |
+
parser.add_argument(
|
| 406 |
+
"--lr_warmup_steps",
|
| 407 |
+
type=int,
|
| 408 |
+
default=500,
|
| 409 |
+
help="Number of warmup steps."
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
# SD3 specific arguments
|
| 413 |
+
parser.add_argument(
|
| 414 |
+
"--weighting_scheme",
|
| 415 |
+
type=str,
|
| 416 |
+
default="logit_normal",
|
| 417 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
| 418 |
+
help="Weighting scheme for flow matching loss.",
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--logit_mean",
|
| 422 |
+
type=float,
|
| 423 |
+
default=0.0,
|
| 424 |
+
help="Mean for logit_normal weighting."
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
"--logit_std",
|
| 428 |
+
type=float,
|
| 429 |
+
default=1.0,
|
| 430 |
+
help="Std for logit_normal weighting."
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
"--mode_scale",
|
| 434 |
+
type=float,
|
| 435 |
+
default=1.29,
|
| 436 |
+
help="Scale for mode weighting scheme.",
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
"--precondition_outputs",
|
| 440 |
+
type=int,
|
| 441 |
+
default=1,
|
| 442 |
+
help="Whether to precondition model outputs.",
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# Optimization arguments
|
| 446 |
+
parser.add_argument(
|
| 447 |
+
"--allow_tf32",
|
| 448 |
+
action="store_true",
|
| 449 |
+
help="Allow TF32 on Ampere GPUs.",
|
| 450 |
+
)
|
| 451 |
+
parser.add_argument(
|
| 452 |
+
"--dataloader_num_workers",
|
| 453 |
+
type=int,
|
| 454 |
+
default=0,
|
| 455 |
+
help="Number of data loading workers.",
|
| 456 |
+
)
|
| 457 |
+
parser.add_argument(
|
| 458 |
+
"--use_8bit_adam",
|
| 459 |
+
action="store_true",
|
| 460 |
+
help="Use 8-bit Adam optimizer."
|
| 461 |
+
)
|
| 462 |
+
parser.add_argument(
|
| 463 |
+
"--adam_beta1",
|
| 464 |
+
type=float,
|
| 465 |
+
default=0.9,
|
| 466 |
+
help="Beta1 for Adam optimizer."
|
| 467 |
+
)
|
| 468 |
+
parser.add_argument(
|
| 469 |
+
"--adam_beta2",
|
| 470 |
+
type=float,
|
| 471 |
+
default=0.999,
|
| 472 |
+
help="Beta2 for Adam optimizer."
|
| 473 |
+
)
|
| 474 |
+
parser.add_argument(
|
| 475 |
+
"--adam_weight_decay",
|
| 476 |
+
type=float,
|
| 477 |
+
default=1e-2,
|
| 478 |
+
help="Weight decay for Adam."
|
| 479 |
+
)
|
| 480 |
+
parser.add_argument(
|
| 481 |
+
"--adam_epsilon",
|
| 482 |
+
type=float,
|
| 483 |
+
default=1e-08,
|
| 484 |
+
help="Epsilon for Adam optimizer."
|
| 485 |
+
)
|
| 486 |
+
parser.add_argument(
|
| 487 |
+
"--max_grad_norm",
|
| 488 |
+
default=1.0,
|
| 489 |
+
type=float,
|
| 490 |
+
help="Max gradient norm."
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Hub and logging arguments
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--push_to_hub",
|
| 496 |
+
action="store_true",
|
| 497 |
+
help="Push model to the Hub."
|
| 498 |
+
)
|
| 499 |
+
parser.add_argument(
|
| 500 |
+
"--hub_token",
|
| 501 |
+
type=str,
|
| 502 |
+
default=None,
|
| 503 |
+
help="Token for Model Hub."
|
| 504 |
+
)
|
| 505 |
+
parser.add_argument(
|
| 506 |
+
"--hub_model_id",
|
| 507 |
+
type=str,
|
| 508 |
+
default=None,
|
| 509 |
+
help="Repository name for the Hub.",
|
| 510 |
+
)
|
| 511 |
+
parser.add_argument(
|
| 512 |
+
"--logging_dir",
|
| 513 |
+
type=str,
|
| 514 |
+
default="logs",
|
| 515 |
+
help="TensorBoard log directory.",
|
| 516 |
+
)
|
| 517 |
+
parser.add_argument(
|
| 518 |
+
"--report_to",
|
| 519 |
+
type=str,
|
| 520 |
+
default="tensorboard",
|
| 521 |
+
help="Logging integration to use.",
|
| 522 |
+
)
|
| 523 |
+
parser.add_argument(
|
| 524 |
+
"--mixed_precision",
|
| 525 |
+
type=str,
|
| 526 |
+
default=None,
|
| 527 |
+
choices=["no", "fp16", "bf16"],
|
| 528 |
+
help="Mixed precision type.",
|
| 529 |
+
)
|
| 530 |
+
parser.add_argument(
|
| 531 |
+
"--local_rank",
|
| 532 |
+
type=int,
|
| 533 |
+
default=-1,
|
| 534 |
+
help="Local rank for distributed training."
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# LoRA arguments
|
| 538 |
+
parser.add_argument(
|
| 539 |
+
"--rank",
|
| 540 |
+
type=int,
|
| 541 |
+
default=64,
|
| 542 |
+
help="LoRA rank dimension.",
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
if input_args is not None:
|
| 546 |
+
args = parser.parse_args(input_args)
|
| 547 |
+
else:
|
| 548 |
+
args = parser.parse_args()
|
| 549 |
+
|
| 550 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 551 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 552 |
+
args.local_rank = env_local_rank
|
| 553 |
+
|
| 554 |
+
# Sanity checks
|
| 555 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 556 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 557 |
+
|
| 558 |
+
return args
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
DATASET_NAME_MAPPING = {
|
| 562 |
+
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def tokenize_prompt(tokenizer, prompt):
|
| 567 |
+
"""Tokenize prompt using the given tokenizer."""
|
| 568 |
+
text_inputs = tokenizer(
|
| 569 |
+
prompt,
|
| 570 |
+
padding="max_length",
|
| 571 |
+
max_length=77,
|
| 572 |
+
truncation=True,
|
| 573 |
+
return_tensors="pt",
|
| 574 |
+
)
|
| 575 |
+
return text_inputs.input_ids
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _encode_prompt_with_t5(
|
| 579 |
+
text_encoder,
|
| 580 |
+
tokenizer,
|
| 581 |
+
max_sequence_length,
|
| 582 |
+
prompt=None,
|
| 583 |
+
num_images_per_prompt=1,
|
| 584 |
+
device=None,
|
| 585 |
+
text_input_ids=None,
|
| 586 |
+
):
|
| 587 |
+
"""Encode prompt using T5 text encoder."""
|
| 588 |
+
if prompt is not None:
|
| 589 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 590 |
+
batch_size = len(prompt)
|
| 591 |
+
else:
|
| 592 |
+
# When prompt is None, we must have text_input_ids
|
| 593 |
+
if text_input_ids is None:
|
| 594 |
+
raise ValueError("Either prompt or text_input_ids must be provided")
|
| 595 |
+
batch_size = text_input_ids.shape[0]
|
| 596 |
+
|
| 597 |
+
if tokenizer is not None and prompt is not None:
|
| 598 |
+
text_inputs = tokenizer(
|
| 599 |
+
prompt,
|
| 600 |
+
padding="max_length",
|
| 601 |
+
max_length=max_sequence_length,
|
| 602 |
+
truncation=True,
|
| 603 |
+
add_special_tokens=True,
|
| 604 |
+
return_tensors="pt",
|
| 605 |
+
)
|
| 606 |
+
text_input_ids = text_inputs.input_ids
|
| 607 |
+
else:
|
| 608 |
+
if text_input_ids is None:
|
| 609 |
+
raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
|
| 610 |
+
|
| 611 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 612 |
+
dtype = text_encoder.dtype
|
| 613 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 614 |
+
|
| 615 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 616 |
+
# duplicate text embeddings for each generation per prompt
|
| 617 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 618 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 619 |
+
|
| 620 |
+
return prompt_embeds
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def _encode_prompt_with_clip(
|
| 624 |
+
text_encoder,
|
| 625 |
+
tokenizer,
|
| 626 |
+
prompt: str,
|
| 627 |
+
device=None,
|
| 628 |
+
text_input_ids=None,
|
| 629 |
+
num_images_per_prompt: int = 1,
|
| 630 |
+
):
|
| 631 |
+
"""Encode prompt using CLIP text encoder."""
|
| 632 |
+
if prompt is not None:
|
| 633 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 634 |
+
batch_size = len(prompt)
|
| 635 |
+
else:
|
| 636 |
+
# When prompt is None, we must have text_input_ids
|
| 637 |
+
if text_input_ids is None:
|
| 638 |
+
raise ValueError("Either prompt or text_input_ids must be provided")
|
| 639 |
+
batch_size = text_input_ids.shape[0]
|
| 640 |
+
|
| 641 |
+
if tokenizer is not None and prompt is not None:
|
| 642 |
+
text_inputs = tokenizer(
|
| 643 |
+
prompt,
|
| 644 |
+
padding="max_length",
|
| 645 |
+
max_length=77,
|
| 646 |
+
truncation=True,
|
| 647 |
+
return_tensors="pt",
|
| 648 |
+
)
|
| 649 |
+
text_input_ids = text_inputs.input_ids
|
| 650 |
+
else:
|
| 651 |
+
if text_input_ids is None:
|
| 652 |
+
raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
|
| 653 |
+
|
| 654 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 655 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 656 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 657 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
| 658 |
+
|
| 659 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 660 |
+
# duplicate text embeddings for each generation per prompt
|
| 661 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 662 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 663 |
+
|
| 664 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def encode_prompt(
|
| 668 |
+
text_encoders,
|
| 669 |
+
tokenizers,
|
| 670 |
+
prompt: str,
|
| 671 |
+
max_sequence_length,
|
| 672 |
+
device=None,
|
| 673 |
+
num_images_per_prompt: int = 1,
|
| 674 |
+
text_input_ids_list=None,
|
| 675 |
+
):
|
| 676 |
+
"""Encode prompt using all three text encoders (SD3 architecture)."""
|
| 677 |
+
if prompt is not None:
|
| 678 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 679 |
+
|
| 680 |
+
# Process CLIP encoders (first two)
|
| 681 |
+
clip_tokenizers = tokenizers[:2]
|
| 682 |
+
clip_text_encoders = text_encoders[:2]
|
| 683 |
+
|
| 684 |
+
clip_prompt_embeds_list = []
|
| 685 |
+
clip_pooled_prompt_embeds_list = []
|
| 686 |
+
|
| 687 |
+
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
|
| 688 |
+
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
|
| 689 |
+
text_encoder=text_encoder,
|
| 690 |
+
tokenizer=tokenizer,
|
| 691 |
+
prompt=prompt,
|
| 692 |
+
device=device if device is not None else text_encoder.device,
|
| 693 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 694 |
+
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
|
| 695 |
+
)
|
| 696 |
+
clip_prompt_embeds_list.append(prompt_embeds)
|
| 697 |
+
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
| 698 |
+
|
| 699 |
+
# Concatenate CLIP embeddings
|
| 700 |
+
clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
|
| 701 |
+
pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
|
| 702 |
+
|
| 703 |
+
# Process T5 encoder (third encoder)
|
| 704 |
+
t5_prompt_embed = _encode_prompt_with_t5(
|
| 705 |
+
text_encoders[-1],
|
| 706 |
+
tokenizers[-1],
|
| 707 |
+
max_sequence_length,
|
| 708 |
+
prompt=prompt,
|
| 709 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 710 |
+
text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
|
| 711 |
+
device=device if device is not None else text_encoders[-1].device,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
# Pad CLIP embeddings to match T5 embedding dimension
|
| 715 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 716 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 717 |
+
)
|
| 718 |
+
# Concatenate all embeddings
|
| 719 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 720 |
+
|
| 721 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def load_dataset_from_jsonl(metadata_path, data_dir, accelerator=None):
|
| 725 |
+
"""
|
| 726 |
+
从 metadata.jsonl 文件加载数据集,避免扫描所有文件。
|
| 727 |
+
这对于大型数据集在分布式训练中非常重要。
|
| 728 |
+
|
| 729 |
+
注意:只让主进程读取 jsonl 文件,然后创建数据集。
|
| 730 |
+
其他进程会等待主进程完成后再继续。
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
metadata_path: metadata.jsonl 文件路径
|
| 734 |
+
data_dir: 数据集根目录
|
| 735 |
+
accelerator: Accelerator 对象,用于多进程同步
|
| 736 |
+
|
| 737 |
+
Returns:
|
| 738 |
+
datasets.DatasetDict
|
| 739 |
+
"""
|
| 740 |
+
if accelerator is None or accelerator.is_main_process:
|
| 741 |
+
print(f"[INFO] Loading dataset from metadata.jsonl: {metadata_path}", flush=True)
|
| 742 |
+
|
| 743 |
+
# 读取 metadata.jsonl(只让主进程读取,避免多进程竞争)
|
| 744 |
+
data_list = []
|
| 745 |
+
if os.path.exists(metadata_path):
|
| 746 |
+
with open(metadata_path, 'r', encoding='utf-8') as f:
|
| 747 |
+
for line_num, line in enumerate(f):
|
| 748 |
+
try:
|
| 749 |
+
item = json.loads(line.strip())
|
| 750 |
+
file_name = item.get('file_name', '')
|
| 751 |
+
caption = item.get('caption', '')
|
| 752 |
+
|
| 753 |
+
# 构建完整路径
|
| 754 |
+
image_path = os.path.join(data_dir, file_name)
|
| 755 |
+
|
| 756 |
+
# 注意:这里不检查文件是否存在,因为:
|
| 757 |
+
# 1. 检查会非常慢(需要访问文件系统)
|
| 758 |
+
# 2. 在 DataLoader 中加载时会自然处理不存在的文件
|
| 759 |
+
# 3. 可以大大加快数据集加载速度
|
| 760 |
+
data_list.append({
|
| 761 |
+
'image': image_path,
|
| 762 |
+
'text': caption
|
| 763 |
+
})
|
| 764 |
+
|
| 765 |
+
# 每处理 100000 条记录打印一次进度(减少打印频率)
|
| 766 |
+
if (line_num + 1) % 100000 == 0 and (accelerator is None or accelerator.is_main_process):
|
| 767 |
+
print(f"[INFO] Processed {line_num + 1} entries from metadata.jsonl", flush=True)
|
| 768 |
+
|
| 769 |
+
except json.JSONDecodeError as e:
|
| 770 |
+
if accelerator is None or accelerator.is_main_process:
|
| 771 |
+
print(f"[WARNING] Skipping invalid JSON at line {line_num + 1}: {e}", flush=True)
|
| 772 |
+
continue
|
| 773 |
+
|
| 774 |
+
if accelerator is None or accelerator.is_main_process:
|
| 775 |
+
print(f"[INFO] Loaded {len(data_list)} image-caption pairs from metadata.jsonl", flush=True)
|
| 776 |
+
else:
|
| 777 |
+
raise FileNotFoundError(f"metadata.jsonl not found at: {metadata_path}")
|
| 778 |
+
|
| 779 |
+
# 创建数据集
|
| 780 |
+
# 注意:'image' 列存储的是路径字符串,不是 PIL Image 对象
|
| 781 |
+
# 图片会在预处理函数中延迟加载
|
| 782 |
+
dataset = datasets.Dataset.from_list(data_list)
|
| 783 |
+
|
| 784 |
+
return datasets.DatasetDict({'train': dataset})
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def main(args):
|
| 788 |
+
"""Main training function."""
|
| 789 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 790 |
+
raise ValueError(
|
| 791 |
+
"You cannot use both --report_to=wandb and --hub_token due to security risk."
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 795 |
+
|
| 796 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 797 |
+
raise ValueError(
|
| 798 |
+
"Mixed precision training with bfloat16 is not supported on MPS."
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
# GPU多卡训练检查
|
| 802 |
+
if torch.cuda.is_available():
|
| 803 |
+
num_gpus = torch.cuda.device_count()
|
| 804 |
+
print(f"Found {num_gpus} GPUs available")
|
| 805 |
+
if num_gpus > 1:
|
| 806 |
+
print(f"Multi-GPU training enabled with {num_gpus} GPUs")
|
| 807 |
+
else:
|
| 808 |
+
print("No CUDA GPUs found, training on CPU")
|
| 809 |
+
|
| 810 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 811 |
+
# 优化多GPU训练的DDP参数
|
| 812 |
+
kwargs = DistributedDataParallelKwargs(
|
| 813 |
+
find_unused_parameters=True,
|
| 814 |
+
gradient_as_bucket_view=True, # 提高多GPU训练效率
|
| 815 |
+
static_graph=False, # 动态图支持
|
| 816 |
+
)
|
| 817 |
+
accelerator = Accelerator(
|
| 818 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 819 |
+
mixed_precision=args.mixed_precision,
|
| 820 |
+
log_with=args.report_to,
|
| 821 |
+
project_config=accelerator_project_config,
|
| 822 |
+
kwargs_handlers=[kwargs],
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
# Logging setup
|
| 826 |
+
logging.basicConfig(
|
| 827 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 828 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 829 |
+
level=logging.INFO,
|
| 830 |
+
)
|
| 831 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 832 |
+
if accelerator.is_main_process:
|
| 833 |
+
print("[INFO] Accelerator initialized", flush=True)
|
| 834 |
+
|
| 835 |
+
# 记录多GPU训练信息
|
| 836 |
+
if accelerator.is_main_process:
|
| 837 |
+
logger.info(f"Number of processes: {accelerator.num_processes}")
|
| 838 |
+
logger.info(f"Distributed type: {accelerator.distributed_type}")
|
| 839 |
+
logger.info(f"Mixed precision: {accelerator.mixed_precision}")
|
| 840 |
+
if torch.cuda.is_available():
|
| 841 |
+
for i in range(torch.cuda.device_count()):
|
| 842 |
+
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 843 |
+
logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
|
| 844 |
+
|
| 845 |
+
if accelerator.is_local_main_process:
|
| 846 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 847 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 848 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 849 |
+
else:
|
| 850 |
+
datasets.utils.logging.set_verbosity_error()
|
| 851 |
+
transformers.utils.logging.set_verbosity_error()
|
| 852 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 853 |
+
|
| 854 |
+
# Set training seed
|
| 855 |
+
if args.seed is not None:
|
| 856 |
+
set_seed(args.seed)
|
| 857 |
+
if accelerator.is_main_process:
|
| 858 |
+
print(f"[INFO] Seed set to {args.seed}", flush=True)
|
| 859 |
+
|
| 860 |
+
# Create output directory
|
| 861 |
+
if accelerator.is_main_process:
|
| 862 |
+
if args.output_dir is not None:
|
| 863 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 864 |
+
|
| 865 |
+
if args.push_to_hub:
|
| 866 |
+
repo_id = create_repo(
|
| 867 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
| 868 |
+
exist_ok=True,
|
| 869 |
+
token=args.hub_token
|
| 870 |
+
).repo_id
|
| 871 |
+
|
| 872 |
+
if accelerator.is_main_process:
|
| 873 |
+
print("[INFO] Loading tokenizers...", flush=True)
|
| 874 |
+
|
| 875 |
+
# Load tokenizers (three for SD3)
|
| 876 |
+
tokenizer_one = CLIPTokenizer.from_pretrained(
|
| 877 |
+
args.pretrained_model_name_or_path,
|
| 878 |
+
subfolder="tokenizer",
|
| 879 |
+
revision=args.revision,
|
| 880 |
+
)
|
| 881 |
+
tokenizer_two = CLIPTokenizer.from_pretrained(
|
| 882 |
+
args.pretrained_model_name_or_path,
|
| 883 |
+
subfolder="tokenizer_2",
|
| 884 |
+
revision=args.revision,
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
if accelerator.is_main_process:
|
| 888 |
+
print("[INFO] Tokenizers loaded. Loading text encoders, VAE, and transformer...", flush=True)
|
| 889 |
+
tokenizer_three = T5TokenizerFast.from_pretrained(
|
| 890 |
+
args.pretrained_model_name_or_path,
|
| 891 |
+
subfolder="tokenizer_3",
|
| 892 |
+
revision=args.revision,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
# Import text encoder classes
|
| 896 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
| 897 |
+
args.pretrained_model_name_or_path, args.revision
|
| 898 |
+
)
|
| 899 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
| 900 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
| 901 |
+
)
|
| 902 |
+
text_encoder_cls_three = import_model_class_from_model_name_or_path(
|
| 903 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
# Load models
|
| 907 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 908 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
| 909 |
+
)
|
| 910 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 911 |
+
|
| 912 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
| 913 |
+
args.pretrained_model_name_or_path,
|
| 914 |
+
subfolder="text_encoder",
|
| 915 |
+
revision=args.revision,
|
| 916 |
+
variant=args.variant
|
| 917 |
+
)
|
| 918 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
| 919 |
+
args.pretrained_model_name_or_path,
|
| 920 |
+
subfolder="text_encoder_2",
|
| 921 |
+
revision=args.revision,
|
| 922 |
+
variant=args.variant
|
| 923 |
+
)
|
| 924 |
+
text_encoder_three = text_encoder_cls_three.from_pretrained(
|
| 925 |
+
args.pretrained_model_name_or_path,
|
| 926 |
+
subfolder="text_encoder_3",
|
| 927 |
+
revision=args.revision,
|
| 928 |
+
variant=args.variant
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
vae = AutoencoderKL.from_pretrained(
|
| 932 |
+
args.pretrained_model_name_or_path,
|
| 933 |
+
subfolder="vae",
|
| 934 |
+
revision=args.revision,
|
| 935 |
+
variant=args.variant,
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
transformer = SD3Transformer2DModel.from_pretrained(
|
| 939 |
+
args.pretrained_model_name_or_path,
|
| 940 |
+
subfolder="transformer",
|
| 941 |
+
revision=args.revision,
|
| 942 |
+
variant=args.variant
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
if accelerator.is_main_process:
|
| 946 |
+
print("[INFO] Text encoders, VAE, and transformer loaded", flush=True)
|
| 947 |
+
|
| 948 |
+
# Freeze non-trainable weights
|
| 949 |
+
transformer.requires_grad_(False)
|
| 950 |
+
vae.requires_grad_(False)
|
| 951 |
+
text_encoder_one.requires_grad_(False)
|
| 952 |
+
text_encoder_two.requires_grad_(False)
|
| 953 |
+
text_encoder_three.requires_grad_(False)
|
| 954 |
+
|
| 955 |
+
# Set precision
|
| 956 |
+
weight_dtype = torch.float32
|
| 957 |
+
if accelerator.mixed_precision == "fp16":
|
| 958 |
+
weight_dtype = torch.float16
|
| 959 |
+
elif accelerator.mixed_precision == "bf16":
|
| 960 |
+
weight_dtype = torch.bfloat16
|
| 961 |
+
|
| 962 |
+
# Move models to device
|
| 963 |
+
vae.to(accelerator.device, dtype=torch.float32) # VAE stays in fp32
|
| 964 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 965 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 966 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 967 |
+
text_encoder_three.to(accelerator.device, dtype=weight_dtype)
|
| 968 |
+
|
| 969 |
+
# Enable gradient checkpointing
|
| 970 |
+
if args.gradient_checkpointing:
|
| 971 |
+
transformer.enable_gradient_checkpointing()
|
| 972 |
+
if args.train_text_encoder:
|
| 973 |
+
text_encoder_one.gradient_checkpointing_enable()
|
| 974 |
+
text_encoder_two.gradient_checkpointing_enable()
|
| 975 |
+
|
| 976 |
+
# Configure LoRA for transformer
|
| 977 |
+
transformer_lora_config = LoraConfig(
|
| 978 |
+
r=args.rank,
|
| 979 |
+
lora_alpha=args.rank,
|
| 980 |
+
init_lora_weights="gaussian",
|
| 981 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 982 |
+
)
|
| 983 |
+
transformer.add_adapter(transformer_lora_config)
|
| 984 |
+
|
| 985 |
+
# Configure LoRA for text encoders if enabled
|
| 986 |
+
if args.train_text_encoder:
|
| 987 |
+
text_lora_config = LoraConfig(
|
| 988 |
+
r=args.rank,
|
| 989 |
+
lora_alpha=args.rank,
|
| 990 |
+
init_lora_weights="gaussian",
|
| 991 |
+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
| 992 |
+
)
|
| 993 |
+
text_encoder_one.add_adapter(text_lora_config)
|
| 994 |
+
text_encoder_two.add_adapter(text_lora_config)
|
| 995 |
+
# Note: T5 encoder typically doesn't use LoRA
|
| 996 |
+
|
| 997 |
+
def unwrap_model(model):
|
| 998 |
+
model = accelerator.unwrap_model(model)
|
| 999 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 1000 |
+
return model
|
| 1001 |
+
|
| 1002 |
+
# Enable TF32 for faster training
|
| 1003 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
| 1004 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 1005 |
+
|
| 1006 |
+
# Scale learning rate
|
| 1007 |
+
if args.scale_lr:
|
| 1008 |
+
args.learning_rate = (
|
| 1009 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
# Cast trainable parameters to float32
|
| 1013 |
+
if args.mixed_precision == "fp16":
|
| 1014 |
+
models = [transformer]
|
| 1015 |
+
if args.train_text_encoder:
|
| 1016 |
+
models.extend([text_encoder_one, text_encoder_two])
|
| 1017 |
+
cast_training_params(models, dtype=torch.float32)
|
| 1018 |
+
|
| 1019 |
+
# Setup optimizer
|
| 1020 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
| 1021 |
+
if args.train_text_encoder:
|
| 1022 |
+
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
| 1023 |
+
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
| 1024 |
+
params_to_optimize = (
|
| 1025 |
+
transformer_lora_parameters
|
| 1026 |
+
+ text_lora_parameters_one
|
| 1027 |
+
+ text_lora_parameters_two
|
| 1028 |
+
)
|
| 1029 |
+
else:
|
| 1030 |
+
params_to_optimize = transformer_lora_parameters
|
| 1031 |
+
|
| 1032 |
+
# Create optimizer
|
| 1033 |
+
if args.use_8bit_adam:
|
| 1034 |
+
try:
|
| 1035 |
+
import bitsandbytes as bnb
|
| 1036 |
+
except ImportError:
|
| 1037 |
+
raise ImportError("To use 8-bit Adam, install bitsandbytes: pip install bitsandbytes")
|
| 1038 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 1039 |
+
else:
|
| 1040 |
+
optimizer_class = torch.optim.AdamW
|
| 1041 |
+
|
| 1042 |
+
optimizer = optimizer_class(
|
| 1043 |
+
params_to_optimize,
|
| 1044 |
+
lr=args.learning_rate,
|
| 1045 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 1046 |
+
weight_decay=args.adam_weight_decay,
|
| 1047 |
+
eps=args.adam_epsilon,
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
if accelerator.is_main_process:
|
| 1051 |
+
print("[INFO] Optimizer created. Loading dataset...", flush=True)
|
| 1052 |
+
|
| 1053 |
+
# Load dataset - 使用 main_process_first 避免多进程竞争
|
| 1054 |
+
# 优先使用 metadata.jsonl 文件,避免扫描所有文件
|
| 1055 |
+
with accelerator.main_process_first():
|
| 1056 |
+
metadata_path = None
|
| 1057 |
+
if args.train_data_dir is not None:
|
| 1058 |
+
# 检查是否存在 metadata.jsonl
|
| 1059 |
+
potential_metadata = os.path.join(args.train_data_dir, "metadata.jsonl")
|
| 1060 |
+
if os.path.exists(potential_metadata):
|
| 1061 |
+
metadata_path = potential_metadata
|
| 1062 |
+
|
| 1063 |
+
if metadata_path is not None:
|
| 1064 |
+
# 使用 metadata.jsonl 加载数据集(更高效,避免扫描所有文件)
|
| 1065 |
+
if accelerator.is_main_process:
|
| 1066 |
+
print(f"[INFO] Found metadata.jsonl, using efficient loading method", flush=True)
|
| 1067 |
+
dataset = load_dataset_from_jsonl(metadata_path, args.train_data_dir, accelerator)
|
| 1068 |
+
elif args.dataset_name is not None:
|
| 1069 |
+
dataset = load_dataset(
|
| 1070 |
+
args.dataset_name,
|
| 1071 |
+
args.dataset_config_name,
|
| 1072 |
+
cache_dir=args.cache_dir,
|
| 1073 |
+
data_dir=args.train_data_dir
|
| 1074 |
+
)
|
| 1075 |
+
else:
|
| 1076 |
+
# 回退到 imagefolder(可能会很慢)
|
| 1077 |
+
if accelerator.is_main_process:
|
| 1078 |
+
print("[WARNING] No metadata.jsonl found, using imagefolder (may be slow for large datasets)", flush=True)
|
| 1079 |
+
data_files = {}
|
| 1080 |
+
if args.train_data_dir is not None:
|
| 1081 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 1082 |
+
dataset = load_dataset(
|
| 1083 |
+
"imagefolder",
|
| 1084 |
+
data_files=data_files,
|
| 1085 |
+
cache_dir=args.cache_dir,
|
| 1086 |
+
)
|
| 1087 |
+
if accelerator.is_main_process:
|
| 1088 |
+
print("[INFO] Dataset loaded successfully.", flush=True)
|
| 1089 |
+
|
| 1090 |
+
# 确保所有进程等待数据集加载完成
|
| 1091 |
+
accelerator.wait_for_everyone()
|
| 1092 |
+
|
| 1093 |
+
if accelerator.is_main_process:
|
| 1094 |
+
print("[INFO] All processes synchronized. Building transforms and DataLoader...", flush=True)
|
| 1095 |
+
|
| 1096 |
+
# Preprocessing
|
| 1097 |
+
column_names = dataset["train"].column_names
|
| 1098 |
+
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
| 1099 |
+
|
| 1100 |
+
if accelerator.is_main_process:
|
| 1101 |
+
print(f"[INFO] Dataset columns: {column_names}", flush=True)
|
| 1102 |
+
|
| 1103 |
+
# 智能选择 image 列:优先使用指定的列,如果不存在则自动回退
|
| 1104 |
+
if args.image_column is not None and args.image_column in column_names:
|
| 1105 |
+
# 如果指定了列名且存在,使用指定的列
|
| 1106 |
+
image_column = args.image_column
|
| 1107 |
+
else:
|
| 1108 |
+
# 自动选择可用的 image 列
|
| 1109 |
+
if 'image' in column_names:
|
| 1110 |
+
image_column = 'image'
|
| 1111 |
+
else:
|
| 1112 |
+
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
| 1113 |
+
|
| 1114 |
+
# 如果用户指定了列名但不存在,给出警告
|
| 1115 |
+
if args.image_column is not None and args.image_column != image_column:
|
| 1116 |
+
if accelerator.is_main_process:
|
| 1117 |
+
print(f"[WARNING] Specified image_column '{args.image_column}' not found. Using '{image_column}' instead.", flush=True)
|
| 1118 |
+
|
| 1119 |
+
if accelerator.is_main_process:
|
| 1120 |
+
print(f"[INFO] Using image column: {image_column}", flush=True)
|
| 1121 |
+
|
| 1122 |
+
# 智能选择 caption 列:优先使用指定的列,如果不存在则自动回退
|
| 1123 |
+
if args.caption_column is not None and args.caption_column in column_names:
|
| 1124 |
+
# 如果指定了列名且存在,使用指定的列
|
| 1125 |
+
caption_column = args.caption_column
|
| 1126 |
+
else:
|
| 1127 |
+
# 自动选择可用的 caption 列
|
| 1128 |
+
if 'text' in column_names:
|
| 1129 |
+
caption_column = 'text'
|
| 1130 |
+
elif 'caption' in column_names:
|
| 1131 |
+
caption_column = 'caption'
|
| 1132 |
+
else:
|
| 1133 |
+
caption_column = dataset_columns[1] if dataset_columns is not None else (column_names[1] if len(column_names) > 1 else column_names[0])
|
| 1134 |
+
|
| 1135 |
+
# 如果用户指定了列名但不存在,给出警告
|
| 1136 |
+
if args.caption_column is not None and args.caption_column != caption_column:
|
| 1137 |
+
if accelerator.is_main_process:
|
| 1138 |
+
print(f"[WARNING] Specified caption_column '{args.caption_column}' not found. Using '{caption_column}' instead.", flush=True)
|
| 1139 |
+
|
| 1140 |
+
if accelerator.is_main_process:
|
| 1141 |
+
print(f"[INFO] Using caption column: {caption_column}", flush=True)
|
| 1142 |
+
|
| 1143 |
+
def tokenize_captions(examples, is_train=True):
|
| 1144 |
+
captions = []
|
| 1145 |
+
for caption in examples[caption_column]:
|
| 1146 |
+
if isinstance(caption, str):
|
| 1147 |
+
captions.append(caption)
|
| 1148 |
+
elif isinstance(caption, (list, np.ndarray)):
|
| 1149 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
| 1150 |
+
else:
|
| 1151 |
+
raise ValueError(f"Caption column should contain strings or lists of strings.")
|
| 1152 |
+
|
| 1153 |
+
tokens_one = tokenize_prompt(tokenizer_one, captions)
|
| 1154 |
+
tokens_two = tokenize_prompt(tokenizer_two, captions)
|
| 1155 |
+
tokens_three = tokenize_prompt(tokenizer_three, captions)
|
| 1156 |
+
return tokens_one, tokens_two, tokens_three
|
| 1157 |
+
|
| 1158 |
+
# Image transforms
|
| 1159 |
+
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 1160 |
+
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
|
| 1161 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 1162 |
+
train_transforms = transforms.Compose([
|
| 1163 |
+
transforms.ToTensor(),
|
| 1164 |
+
transforms.Normalize([0.5], [0.5]),
|
| 1165 |
+
])
|
| 1166 |
+
|
| 1167 |
+
def preprocess_train(examples):
|
| 1168 |
+
# 处理图片:如果 image_column 中是路径字符串,则加载图片;如果是 PIL Image,则直接使用
|
| 1169 |
+
images = []
|
| 1170 |
+
for img in examples[image_column]:
|
| 1171 |
+
if isinstance(img, str):
|
| 1172 |
+
# 如果是路径字符串,加载图片
|
| 1173 |
+
try:
|
| 1174 |
+
img = Image.open(img).convert("RGB")
|
| 1175 |
+
except Exception as e:
|
| 1176 |
+
# 如果加载失败,创建一个占位符
|
| 1177 |
+
if accelerator.is_main_process:
|
| 1178 |
+
print(f"[WARNING] Failed to load image {img}: {e}", flush=True)
|
| 1179 |
+
img = Image.new('RGB', (args.resolution, args.resolution), color='black')
|
| 1180 |
+
elif hasattr(img, 'convert'):
|
| 1181 |
+
# 如果是 PIL Image,直接使用
|
| 1182 |
+
img = img.convert("RGB")
|
| 1183 |
+
else:
|
| 1184 |
+
raise ValueError(f"Unexpected image type: {type(img)}")
|
| 1185 |
+
images.append(img)
|
| 1186 |
+
original_sizes = []
|
| 1187 |
+
all_images = []
|
| 1188 |
+
crop_top_lefts = []
|
| 1189 |
+
|
| 1190 |
+
for image in images:
|
| 1191 |
+
original_sizes.append((image.height, image.width))
|
| 1192 |
+
image = train_resize(image)
|
| 1193 |
+
if args.random_flip and random.random() < 0.5:
|
| 1194 |
+
image = train_flip(image)
|
| 1195 |
+
if args.center_crop:
|
| 1196 |
+
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
| 1197 |
+
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
| 1198 |
+
image = train_crop(image)
|
| 1199 |
+
else:
|
| 1200 |
+
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
| 1201 |
+
image = crop(image, y1, x1, h, w)
|
| 1202 |
+
crop_top_left = (y1, x1)
|
| 1203 |
+
crop_top_lefts.append(crop_top_left)
|
| 1204 |
+
image = train_transforms(image)
|
| 1205 |
+
all_images.append(image)
|
| 1206 |
+
|
| 1207 |
+
examples["original_sizes"] = original_sizes
|
| 1208 |
+
examples["crop_top_lefts"] = crop_top_lefts
|
| 1209 |
+
examples["pixel_values"] = all_images
|
| 1210 |
+
|
| 1211 |
+
tokens_one, tokens_two, tokens_three = tokenize_captions(examples)
|
| 1212 |
+
examples["input_ids_one"] = tokens_one
|
| 1213 |
+
examples["input_ids_two"] = tokens_two
|
| 1214 |
+
examples["input_ids_three"] = tokens_three
|
| 1215 |
+
return examples
|
| 1216 |
+
|
| 1217 |
+
with accelerator.main_process_first():
|
| 1218 |
+
if args.max_train_samples is not None:
|
| 1219 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 1220 |
+
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
|
| 1221 |
+
|
| 1222 |
+
def collate_fn(examples):
|
| 1223 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 1224 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 1225 |
+
original_sizes = [example["original_sizes"] for example in examples]
|
| 1226 |
+
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
| 1227 |
+
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
| 1228 |
+
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
| 1229 |
+
input_ids_three = torch.stack([example["input_ids_three"] for example in examples])
|
| 1230 |
+
|
| 1231 |
+
return {
|
| 1232 |
+
"pixel_values": pixel_values,
|
| 1233 |
+
"input_ids_one": input_ids_one,
|
| 1234 |
+
"input_ids_two": input_ids_two,
|
| 1235 |
+
"input_ids_three": input_ids_three,
|
| 1236 |
+
"original_sizes": original_sizes,
|
| 1237 |
+
"crop_top_lefts": crop_top_lefts,
|
| 1238 |
+
}
|
| 1239 |
+
|
| 1240 |
+
# 针对多GPU训练优化dataloader设置
|
| 1241 |
+
if args.dataloader_num_workers == 0 and accelerator.num_processes > 1:
|
| 1242 |
+
# 多GPU训练时自动设置数据加载器worker数量
|
| 1243 |
+
args.dataloader_num_workers = min(4, os.cpu_count() // accelerator.num_processes)
|
| 1244 |
+
logger.info(f"Auto-setting dataloader_num_workers to {args.dataloader_num_workers} for multi-GPU training")
|
| 1245 |
+
|
| 1246 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 1247 |
+
train_dataset,
|
| 1248 |
+
shuffle=True,
|
| 1249 |
+
collate_fn=collate_fn,
|
| 1250 |
+
batch_size=args.train_batch_size,
|
| 1251 |
+
num_workers=args.dataloader_num_workers,
|
| 1252 |
+
pin_memory=True, # 提高GPU数据传输效率
|
| 1253 |
+
persistent_workers=args.dataloader_num_workers > 0, # 保持worker进程活跃
|
| 1254 |
+
)
|
| 1255 |
+
|
| 1256 |
+
if accelerator.is_main_process:
|
| 1257 |
+
print("[INFO] DataLoader ready. Computing training steps and scheduler...", flush=True)
|
| 1258 |
+
|
| 1259 |
+
# Scheduler and math around training steps
|
| 1260 |
+
overrode_max_train_steps = False
|
| 1261 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1262 |
+
if args.max_train_steps is None:
|
| 1263 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1264 |
+
overrode_max_train_steps = True
|
| 1265 |
+
|
| 1266 |
+
lr_scheduler = get_scheduler(
|
| 1267 |
+
args.lr_scheduler,
|
| 1268 |
+
optimizer=optimizer,
|
| 1269 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 1270 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 1271 |
+
)
|
| 1272 |
+
|
| 1273 |
+
# Prepare everything with accelerator
|
| 1274 |
+
if args.train_text_encoder:
|
| 1275 |
+
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1276 |
+
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
|
| 1277 |
+
)
|
| 1278 |
+
else:
|
| 1279 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1280 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
+
# Recalculate training steps
|
| 1284 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1285 |
+
if overrode_max_train_steps:
|
| 1286 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1287 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 1288 |
+
|
| 1289 |
+
# Initialize trackers
|
| 1290 |
+
if accelerator.is_main_process:
|
| 1291 |
+
try:
|
| 1292 |
+
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
| 1293 |
+
except Exception as e:
|
| 1294 |
+
logger.warning(f"Failed to initialize trackers: {e}")
|
| 1295 |
+
logger.warning("Continuing without tracking. You can monitor training through console logs.")
|
| 1296 |
+
# Set report_to to None to avoid further tracking attempts
|
| 1297 |
+
args.report_to = None
|
| 1298 |
+
|
| 1299 |
+
# Train!
|
| 1300 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 1301 |
+
logger.info("***** Running training *****")
|
| 1302 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 1303 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 1304 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 1305 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 1306 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 1307 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 1308 |
+
logger.info(f" Number of GPU processes = {accelerator.num_processes}")
|
| 1309 |
+
if accelerator.num_processes > 1:
|
| 1310 |
+
logger.info(f" Effective batch size per GPU = {args.train_batch_size * args.gradient_accumulation_steps}")
|
| 1311 |
+
logger.info(f" Total effective batch size across all GPUs = {total_batch_size}")
|
| 1312 |
+
|
| 1313 |
+
global_step = 0
|
| 1314 |
+
first_epoch = 0
|
| 1315 |
+
if accelerator.is_main_process:
|
| 1316 |
+
print(
|
| 1317 |
+
f"[INFO] Training setup complete. num_examples={len(train_dataset)}, "
|
| 1318 |
+
f"max_train_steps={args.max_train_steps}, num_epochs={args.num_train_epochs}",
|
| 1319 |
+
flush=True,
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
# Resume from checkpoint if specified
|
| 1323 |
+
if args.resume_from_checkpoint:
|
| 1324 |
+
if args.resume_from_checkpoint != "latest":
|
| 1325 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 1326 |
+
else:
|
| 1327 |
+
dirs = os.listdir(args.output_dir)
|
| 1328 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 1329 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 1330 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 1331 |
+
|
| 1332 |
+
if path is None:
|
| 1333 |
+
accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
|
| 1334 |
+
args.resume_from_checkpoint = None
|
| 1335 |
+
initial_global_step = 0
|
| 1336 |
+
else:
|
| 1337 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 1338 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 1339 |
+
global_step = int(path.split("-")[1])
|
| 1340 |
+
initial_global_step = global_step
|
| 1341 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 1342 |
+
else:
|
| 1343 |
+
initial_global_step = 0
|
| 1344 |
+
|
| 1345 |
+
progress_bar = tqdm(
|
| 1346 |
+
range(0, args.max_train_steps),
|
| 1347 |
+
initial=initial_global_step,
|
| 1348 |
+
desc="Steps",
|
| 1349 |
+
disable=not accelerator.is_local_main_process,
|
| 1350 |
+
)
|
| 1351 |
+
|
| 1352 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 1353 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 1354 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 1355 |
+
timesteps = timesteps.to(accelerator.device)
|
| 1356 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 1357 |
+
sigma = sigmas[step_indices].flatten()
|
| 1358 |
+
while len(sigma.shape) < n_dim:
|
| 1359 |
+
sigma = sigma.unsqueeze(-1)
|
| 1360 |
+
return sigma
|
| 1361 |
+
|
| 1362 |
+
# Training loop
|
| 1363 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 1364 |
+
transformer.train()
|
| 1365 |
+
if args.train_text_encoder:
|
| 1366 |
+
text_encoder_one.train()
|
| 1367 |
+
text_encoder_two.train()
|
| 1368 |
+
|
| 1369 |
+
if accelerator.is_main_process:
|
| 1370 |
+
print(
|
| 1371 |
+
f"[INFO] Starting epoch {epoch + 1}/{args.num_train_epochs}, current global_step={global_step}",
|
| 1372 |
+
flush=True,
|
| 1373 |
+
)
|
| 1374 |
+
|
| 1375 |
+
train_loss = 0.0
|
| 1376 |
+
for step, batch in enumerate(train_dataloader):
|
| 1377 |
+
with accelerator.accumulate(transformer):
|
| 1378 |
+
# Convert images to latent space
|
| 1379 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 1380 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 1381 |
+
|
| 1382 |
+
# Apply VAE scaling
|
| 1383 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 1384 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 1385 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 1386 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 1387 |
+
|
| 1388 |
+
# Encode prompts
|
| 1389 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
| 1390 |
+
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
| 1391 |
+
tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
|
| 1392 |
+
prompt=None,
|
| 1393 |
+
max_sequence_length=args.max_sequence_length,
|
| 1394 |
+
text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"], batch["input_ids_three"]],
|
| 1395 |
+
)
|
| 1396 |
+
|
| 1397 |
+
# Sample noise and timesteps
|
| 1398 |
+
noise = torch.randn_like(model_input)
|
| 1399 |
+
bsz = model_input.shape[0]
|
| 1400 |
+
|
| 1401 |
+
# Flow Matching timestep sampling
|
| 1402 |
+
u = compute_density_for_timestep_sampling(
|
| 1403 |
+
weighting_scheme=args.weighting_scheme,
|
| 1404 |
+
batch_size=bsz,
|
| 1405 |
+
logit_mean=args.logit_mean,
|
| 1406 |
+
logit_std=args.logit_std,
|
| 1407 |
+
mode_scale=args.mode_scale,
|
| 1408 |
+
)
|
| 1409 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 1410 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 1411 |
+
|
| 1412 |
+
# Flow Matching interpolation
|
| 1413 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 1414 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 1415 |
+
|
| 1416 |
+
# Predict using SD3 Transformer
|
| 1417 |
+
model_pred = transformer(
|
| 1418 |
+
hidden_states=noisy_model_input,
|
| 1419 |
+
timestep=timesteps,
|
| 1420 |
+
encoder_hidden_states=prompt_embeds,
|
| 1421 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1422 |
+
return_dict=False,
|
| 1423 |
+
)[0]
|
| 1424 |
+
|
| 1425 |
+
# Compute target for Flow Matching
|
| 1426 |
+
if args.precondition_outputs:
|
| 1427 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
| 1428 |
+
target = model_input
|
| 1429 |
+
else:
|
| 1430 |
+
target = noise - model_input
|
| 1431 |
+
|
| 1432 |
+
# Compute loss with weighting
|
| 1433 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 1434 |
+
loss = torch.mean(
|
| 1435 |
+
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
| 1436 |
+
1,
|
| 1437 |
+
)
|
| 1438 |
+
loss = loss.mean()
|
| 1439 |
+
|
| 1440 |
+
# Gather loss across processes
|
| 1441 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 1442 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 1443 |
+
|
| 1444 |
+
# Backpropagate
|
| 1445 |
+
accelerator.backward(loss)
|
| 1446 |
+
if accelerator.sync_gradients:
|
| 1447 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 1448 |
+
optimizer.step()
|
| 1449 |
+
lr_scheduler.step()
|
| 1450 |
+
optimizer.zero_grad()
|
| 1451 |
+
|
| 1452 |
+
# Checks if the accelerator has performed an optimization step
|
| 1453 |
+
if accelerator.sync_gradients:
|
| 1454 |
+
progress_bar.update(1)
|
| 1455 |
+
global_step += 1
|
| 1456 |
+
if hasattr(accelerator, 'trackers') and accelerator.trackers:
|
| 1457 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 1458 |
+
train_loss = 0.0
|
| 1459 |
+
|
| 1460 |
+
if accelerator.is_main_process and global_step % 1000 == 0:
|
| 1461 |
+
print(
|
| 1462 |
+
f"[INFO] Optimization step completed at global_step={global_step}, "
|
| 1463 |
+
f"recent step_loss={loss.detach().item():.4f}",
|
| 1464 |
+
flush=True,
|
| 1465 |
+
)
|
| 1466 |
+
|
| 1467 |
+
# Save checkpoint
|
| 1468 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
| 1469 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1470 |
+
if args.checkpoints_total_limit is not None:
|
| 1471 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1472 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1473 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1474 |
+
|
| 1475 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1476 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1477 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1478 |
+
logger.info(f"Removing {len(removing_checkpoints)} checkpoints")
|
| 1479 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1480 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1481 |
+
shutil.rmtree(removing_checkpoint)
|
| 1482 |
+
|
| 1483 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1484 |
+
accelerator.save_state(save_path)
|
| 1485 |
+
logger.info(f"Saved state to {save_path}")
|
| 1486 |
+
|
| 1487 |
+
# 同时保存标准的LoRA权重格式,方便采样时直接加载
|
| 1488 |
+
try:
|
| 1489 |
+
# 获取当前模型的LoRA权重
|
| 1490 |
+
unwrapped_transformer = unwrap_model(transformer)
|
| 1491 |
+
transformer_lora_layers = get_peft_model_state_dict(unwrapped_transformer)
|
| 1492 |
+
|
| 1493 |
+
text_encoder_lora_layers = None
|
| 1494 |
+
text_encoder_2_lora_layers = None
|
| 1495 |
+
if args.train_text_encoder:
|
| 1496 |
+
unwrapped_text_encoder_one = unwrap_model(text_encoder_one)
|
| 1497 |
+
unwrapped_text_encoder_two = unwrap_model(text_encoder_two)
|
| 1498 |
+
text_encoder_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_one)
|
| 1499 |
+
text_encoder_2_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_two)
|
| 1500 |
+
|
| 1501 |
+
# 保存为标准LoRA格式到checkpoint目录
|
| 1502 |
+
StableDiffusion3Pipeline.save_lora_weights(
|
| 1503 |
+
save_directory=save_path,
|
| 1504 |
+
transformer_lora_layers=transformer_lora_layers,
|
| 1505 |
+
text_encoder_lora_layers=text_encoder_lora_layers,
|
| 1506 |
+
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
| 1507 |
+
)
|
| 1508 |
+
logger.info(f"Saved LoRA weights in standard format to {save_path}")
|
| 1509 |
+
except Exception as e:
|
| 1510 |
+
logger.warning(f"Failed to save LoRA weights in standard format: {e}")
|
| 1511 |
+
logger.warning("Checkpoint saved with accelerator format only. You can extract LoRA weights later.")
|
| 1512 |
+
|
| 1513 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1514 |
+
progress_bar.set_postfix(**logs)
|
| 1515 |
+
|
| 1516 |
+
if global_step >= args.max_train_steps:
|
| 1517 |
+
break
|
| 1518 |
+
|
| 1519 |
+
# Validation
|
| 1520 |
+
if accelerator.is_main_process:
|
| 1521 |
+
if args.validation_prompt is not None :#and epoch % args.validation_epochs == 0:
|
| 1522 |
+
print(f"[INFO] Running validation for epoch {epoch + 1}, global_step={global_step}", flush=True)
|
| 1523 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 1524 |
+
args.pretrained_model_name_or_path,
|
| 1525 |
+
vae=vae,
|
| 1526 |
+
text_encoder=unwrap_model(text_encoder_one),
|
| 1527 |
+
text_encoder_2=unwrap_model(text_encoder_two),
|
| 1528 |
+
text_encoder_3=unwrap_model(text_encoder_three),
|
| 1529 |
+
transformer=unwrap_model(transformer),
|
| 1530 |
+
revision=args.revision,
|
| 1531 |
+
variant=args.variant,
|
| 1532 |
+
torch_dtype=weight_dtype,
|
| 1533 |
+
)
|
| 1534 |
+
images = log_validation(pipeline, args, accelerator, epoch, global_step=global_step)
|
| 1535 |
+
del pipeline
|
| 1536 |
+
torch.cuda.empty_cache()
|
| 1537 |
+
|
| 1538 |
+
# Save final LoRA weights
|
| 1539 |
+
accelerator.wait_for_everyone()
|
| 1540 |
+
if accelerator.is_main_process:
|
| 1541 |
+
transformer = unwrap_model(transformer)
|
| 1542 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
| 1543 |
+
|
| 1544 |
+
if args.train_text_encoder:
|
| 1545 |
+
text_encoder_one = unwrap_model(text_encoder_one)
|
| 1546 |
+
text_encoder_two = unwrap_model(text_encoder_two)
|
| 1547 |
+
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
| 1548 |
+
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
| 1549 |
+
else:
|
| 1550 |
+
text_encoder_lora_layers = None
|
| 1551 |
+
text_encoder_2_lora_layers = None
|
| 1552 |
+
|
| 1553 |
+
StableDiffusion3Pipeline.save_lora_weights(
|
| 1554 |
+
save_directory=args.output_dir,
|
| 1555 |
+
transformer_lora_layers=transformer_lora_layers,
|
| 1556 |
+
text_encoder_lora_layers=text_encoder_lora_layers,
|
| 1557 |
+
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
| 1558 |
+
)
|
| 1559 |
+
|
| 1560 |
+
# Final inference
|
| 1561 |
+
if args.mixed_precision == "fp16":
|
| 1562 |
+
vae.to(weight_dtype)
|
| 1563 |
+
|
| 1564 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 1565 |
+
args.pretrained_model_name_or_path,
|
| 1566 |
+
vae=vae,
|
| 1567 |
+
revision=args.revision,
|
| 1568 |
+
variant=args.variant,
|
| 1569 |
+
torch_dtype=weight_dtype,
|
| 1570 |
+
)
|
| 1571 |
+
pipeline.load_lora_weights(args.output_dir)
|
| 1572 |
+
|
| 1573 |
+
if args.validation_prompt and args.num_validation_images > 0:
|
| 1574 |
+
images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True, global_step=global_step)
|
| 1575 |
+
|
| 1576 |
+
if args.push_to_hub:
|
| 1577 |
+
save_model_card(
|
| 1578 |
+
repo_id,
|
| 1579 |
+
images=images,
|
| 1580 |
+
base_model=args.pretrained_model_name_or_path,
|
| 1581 |
+
dataset_name=args.dataset_name,
|
| 1582 |
+
train_text_encoder=args.train_text_encoder,
|
| 1583 |
+
repo_folder=args.output_dir,
|
| 1584 |
+
)
|
| 1585 |
+
upload_folder(
|
| 1586 |
+
repo_id=repo_id,
|
| 1587 |
+
folder_path=args.output_dir,
|
| 1588 |
+
commit_message="End of training",
|
| 1589 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
accelerator.end_training()
|
| 1593 |
+
|
| 1594 |
+
|
| 1595 |
+
if __name__ == "__main__":
|
| 1596 |
+
args = parse_args()
|
| 1597 |
+
main(args)
|
train_lora_sd3_new.py
ADDED
|
@@ -0,0 +1,1422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""SD3 LoRA fine-tuning script for text2image generation."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import copy
|
| 20 |
+
import logging
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import random
|
| 24 |
+
import shutil
|
| 25 |
+
from contextlib import nullcontext
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import datasets
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import torch.utils.checkpoint
|
| 33 |
+
import transformers
|
| 34 |
+
from accelerate import Accelerator
|
| 35 |
+
from accelerate.logging import get_logger
|
| 36 |
+
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
|
| 37 |
+
from datasets import load_dataset
|
| 38 |
+
from huggingface_hub import create_repo, upload_folder
|
| 39 |
+
from packaging import version
|
| 40 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 41 |
+
from peft.utils import get_peft_model_state_dict
|
| 42 |
+
from torchvision import transforms
|
| 43 |
+
from torchvision.transforms.functional import crop
|
| 44 |
+
from tqdm.auto import tqdm
|
| 45 |
+
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
|
| 46 |
+
|
| 47 |
+
import diffusers
|
| 48 |
+
from diffusers import (
|
| 49 |
+
AutoencoderKL,
|
| 50 |
+
FlowMatchEulerDiscreteScheduler,
|
| 51 |
+
SD3Transformer2DModel,
|
| 52 |
+
StableDiffusion3Pipeline,
|
| 53 |
+
)
|
| 54 |
+
from diffusers.optimization import get_scheduler
|
| 55 |
+
from diffusers.training_utils import (
|
| 56 |
+
_set_state_dict_into_text_encoder,
|
| 57 |
+
cast_training_params,
|
| 58 |
+
compute_density_for_timestep_sampling,
|
| 59 |
+
compute_loss_weighting_for_sd3,
|
| 60 |
+
free_memory,
|
| 61 |
+
)
|
| 62 |
+
from diffusers.utils import (
|
| 63 |
+
check_min_version,
|
| 64 |
+
convert_unet_state_dict_to_peft,
|
| 65 |
+
is_wandb_available,
|
| 66 |
+
)
|
| 67 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
| 68 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 69 |
+
|
| 70 |
+
if is_wandb_available():
|
| 71 |
+
import wandb
|
| 72 |
+
|
| 73 |
+
# Check minimum diffusers version
|
| 74 |
+
check_min_version("0.30.0")
|
| 75 |
+
|
| 76 |
+
logger = get_logger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def save_model_card(
|
| 80 |
+
repo_id: str,
|
| 81 |
+
images: list = None,
|
| 82 |
+
base_model: str = None,
|
| 83 |
+
dataset_name: str = None,
|
| 84 |
+
train_text_encoder: bool = False,
|
| 85 |
+
repo_folder: str = None,
|
| 86 |
+
vae_path: str = None,
|
| 87 |
+
):
|
| 88 |
+
"""Save model card for SD3 LoRA model."""
|
| 89 |
+
img_str = ""
|
| 90 |
+
if images is not None:
|
| 91 |
+
for i, image in enumerate(images):
|
| 92 |
+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 93 |
+
img_str += f"\n"
|
| 94 |
+
|
| 95 |
+
model_description = f"""
|
| 96 |
+
# SD3 LoRA text2image fine-tuning - {repo_id}
|
| 97 |
+
|
| 98 |
+
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
| 99 |
+
{img_str}
|
| 100 |
+
|
| 101 |
+
LoRA for the text encoder was enabled: {train_text_encoder}.
|
| 102 |
+
|
| 103 |
+
Special VAE used for training: {vae_path}.
|
| 104 |
+
"""
|
| 105 |
+
model_card = load_or_create_model_card(
|
| 106 |
+
repo_id_or_path=repo_id,
|
| 107 |
+
from_training=True,
|
| 108 |
+
license="other",
|
| 109 |
+
base_model=base_model,
|
| 110 |
+
model_description=model_description,
|
| 111 |
+
inference=True,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
tags = [
|
| 115 |
+
"stable-diffusion-3",
|
| 116 |
+
"stable-diffusion-3-diffusers",
|
| 117 |
+
"text-to-image",
|
| 118 |
+
"diffusers",
|
| 119 |
+
"diffusers-training",
|
| 120 |
+
"lora",
|
| 121 |
+
"sd3",
|
| 122 |
+
]
|
| 123 |
+
model_card = populate_model_card(model_card, tags=tags)
|
| 124 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def log_validation(
|
| 128 |
+
pipeline,
|
| 129 |
+
args,
|
| 130 |
+
accelerator,
|
| 131 |
+
epoch,
|
| 132 |
+
is_final_validation=False,
|
| 133 |
+
global_step=None,
|
| 134 |
+
):
|
| 135 |
+
"""Run validation and log images."""
|
| 136 |
+
logger.info(
|
| 137 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 138 |
+
f" {args.validation_prompt}."
|
| 139 |
+
)
|
| 140 |
+
pipeline = pipeline.to(accelerator.device)
|
| 141 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 142 |
+
|
| 143 |
+
# run inference
|
| 144 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
| 145 |
+
pipeline_args = {"prompt": args.validation_prompt}
|
| 146 |
+
|
| 147 |
+
if torch.backends.mps.is_available():
|
| 148 |
+
autocast_ctx = nullcontext()
|
| 149 |
+
else:
|
| 150 |
+
autocast_ctx = torch.autocast(accelerator.device.type)
|
| 151 |
+
|
| 152 |
+
with autocast_ctx:
|
| 153 |
+
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
| 154 |
+
|
| 155 |
+
# Save images to output directory
|
| 156 |
+
if accelerator.is_main_process:
|
| 157 |
+
validation_dir = os.path.join(args.output_dir, "validation_images")
|
| 158 |
+
os.makedirs(validation_dir, exist_ok=True)
|
| 159 |
+
for i, image in enumerate(images):
|
| 160 |
+
# Create filename with step and epoch information
|
| 161 |
+
if global_step is not None:
|
| 162 |
+
filename = f"validation_step_{global_step}_epoch_{epoch}_img_{i}.png"
|
| 163 |
+
else:
|
| 164 |
+
filename = f"validation_epoch_{epoch}_img_{i}.png"
|
| 165 |
+
|
| 166 |
+
image_path = os.path.join(validation_dir, filename)
|
| 167 |
+
image.save(image_path)
|
| 168 |
+
logger.info(f"Saved validation image: {image_path}")
|
| 169 |
+
|
| 170 |
+
for tracker in accelerator.trackers if hasattr(accelerator, 'trackers') and accelerator.trackers else []:
|
| 171 |
+
phase_name = "test" if is_final_validation else "validation"
|
| 172 |
+
try:
|
| 173 |
+
if tracker.name == "tensorboard":
|
| 174 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 175 |
+
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
|
| 176 |
+
if tracker.name == "wandb":
|
| 177 |
+
tracker.log(
|
| 178 |
+
{
|
| 179 |
+
phase_name: [
|
| 180 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
|
| 181 |
+
]
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
except Exception as e:
|
| 185 |
+
logger.warning(f"Failed to log to {tracker.name}: {e}")
|
| 186 |
+
|
| 187 |
+
del pipeline
|
| 188 |
+
free_memory()
|
| 189 |
+
return images
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def import_model_class_from_model_name_or_path(
|
| 193 |
+
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
| 194 |
+
):
|
| 195 |
+
"""Import the correct text encoder class."""
|
| 196 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 197 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 198 |
+
)
|
| 199 |
+
model_class = text_encoder_config.architectures[0]
|
| 200 |
+
|
| 201 |
+
if model_class == "CLIPTextModelWithProjection":
|
| 202 |
+
from transformers import CLIPTextModelWithProjection
|
| 203 |
+
return CLIPTextModelWithProjection
|
| 204 |
+
elif model_class == "T5EncoderModel":
|
| 205 |
+
from transformers import T5EncoderModel
|
| 206 |
+
return T5EncoderModel
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def parse_args(input_args=None):
|
| 212 |
+
"""Parse command line arguments."""
|
| 213 |
+
parser = argparse.ArgumentParser(description="SD3 LoRA training script.")
|
| 214 |
+
|
| 215 |
+
# Model arguments
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--pretrained_model_name_or_path",
|
| 218 |
+
type=str,
|
| 219 |
+
default=None,
|
| 220 |
+
required=True,
|
| 221 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--revision",
|
| 225 |
+
type=str,
|
| 226 |
+
default=None,
|
| 227 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--variant",
|
| 231 |
+
type=str,
|
| 232 |
+
default=None,
|
| 233 |
+
help="Variant of the model files, e.g. fp16",
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# Dataset arguments
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--dataset_name",
|
| 239 |
+
type=str,
|
| 240 |
+
default=None,
|
| 241 |
+
help="The name of the Dataset to train on.",
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--dataset_config_name",
|
| 245 |
+
type=str,
|
| 246 |
+
default=None,
|
| 247 |
+
help="The config of the Dataset.",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--train_data_dir",
|
| 251 |
+
type=str,
|
| 252 |
+
default=None,
|
| 253 |
+
help="A folder containing the training data.",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--image_column",
|
| 257 |
+
type=str,
|
| 258 |
+
default="image",
|
| 259 |
+
help="The column of the dataset containing an image."
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--caption_column",
|
| 263 |
+
type=str,
|
| 264 |
+
default="caption",
|
| 265 |
+
help="The column of the dataset containing a caption.",
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Training arguments
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--max_sequence_length",
|
| 271 |
+
type=int,
|
| 272 |
+
default=77,
|
| 273 |
+
help="Maximum sequence length to use with the T5 text encoder",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--validation_prompt",
|
| 277 |
+
type=str,
|
| 278 |
+
default=None,
|
| 279 |
+
help="A prompt used during validation.",
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--num_validation_images",
|
| 283 |
+
type=int,
|
| 284 |
+
default=4,
|
| 285 |
+
help="Number of images for validation.",
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--validation_epochs",
|
| 289 |
+
type=int,
|
| 290 |
+
default=1,
|
| 291 |
+
help="Run validation every X epochs.",
|
| 292 |
+
)
|
| 293 |
+
parser.add_argument(
|
| 294 |
+
"--max_train_samples",
|
| 295 |
+
type=int,
|
| 296 |
+
default=None,
|
| 297 |
+
help="Truncate the number of training examples.",
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--output_dir",
|
| 301 |
+
type=str,
|
| 302 |
+
default="sd3-lora-finetuned",
|
| 303 |
+
help="Output directory for model predictions and checkpoints.",
|
| 304 |
+
)
|
| 305 |
+
parser.add_argument(
|
| 306 |
+
"--cache_dir",
|
| 307 |
+
type=str,
|
| 308 |
+
default=None,
|
| 309 |
+
help="Directory to store downloaded models and datasets.",
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--seed",
|
| 313 |
+
type=int,
|
| 314 |
+
default=None,
|
| 315 |
+
help="A seed for reproducible training."
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--resolution",
|
| 319 |
+
type=int,
|
| 320 |
+
default=1024,
|
| 321 |
+
help="Image resolution for training.",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--center_crop",
|
| 325 |
+
default=False,
|
| 326 |
+
action="store_true",
|
| 327 |
+
help="Whether to center crop input images.",
|
| 328 |
+
)
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--random_flip",
|
| 331 |
+
action="store_true",
|
| 332 |
+
help="Whether to randomly flip images horizontally.",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--train_text_encoder",
|
| 336 |
+
action="store_true",
|
| 337 |
+
help="Whether to train the text encoder.",
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--train_batch_size",
|
| 341 |
+
type=int,
|
| 342 |
+
default=16,
|
| 343 |
+
help="Batch size for training dataloader."
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--num_train_epochs",
|
| 347 |
+
type=int,
|
| 348 |
+
default=100
|
| 349 |
+
)
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--max_train_steps",
|
| 352 |
+
type=int,
|
| 353 |
+
default=None,
|
| 354 |
+
help="Total number of training steps.",
|
| 355 |
+
)
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--checkpointing_steps",
|
| 358 |
+
type=int,
|
| 359 |
+
default=500,
|
| 360 |
+
help="Save checkpoint every X updates.",
|
| 361 |
+
)
|
| 362 |
+
parser.add_argument(
|
| 363 |
+
"--checkpoints_total_limit",
|
| 364 |
+
type=int,
|
| 365 |
+
default=None,
|
| 366 |
+
help="Max number of checkpoints to store.",
|
| 367 |
+
)
|
| 368 |
+
parser.add_argument(
|
| 369 |
+
"--resume_from_checkpoint",
|
| 370 |
+
type=str,
|
| 371 |
+
default=None,
|
| 372 |
+
help="Path to resume training from checkpoint.",
|
| 373 |
+
)
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--gradient_accumulation_steps",
|
| 376 |
+
type=int,
|
| 377 |
+
default=1,
|
| 378 |
+
help="Number of update steps to accumulate.",
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--gradient_checkpointing",
|
| 382 |
+
action="store_true",
|
| 383 |
+
help="Use gradient checkpointing to save memory.",
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument(
|
| 386 |
+
"--learning_rate",
|
| 387 |
+
type=float,
|
| 388 |
+
default=1e-4,
|
| 389 |
+
help="Initial learning rate.",
|
| 390 |
+
)
|
| 391 |
+
parser.add_argument(
|
| 392 |
+
"--scale_lr",
|
| 393 |
+
action="store_true",
|
| 394 |
+
default=False,
|
| 395 |
+
help="Scale learning rate by number of GPUs, etc.",
|
| 396 |
+
)
|
| 397 |
+
parser.add_argument(
|
| 398 |
+
"--lr_scheduler",
|
| 399 |
+
type=str,
|
| 400 |
+
default="constant",
|
| 401 |
+
help="Learning rate scheduler type.",
|
| 402 |
+
)
|
| 403 |
+
parser.add_argument(
|
| 404 |
+
"--lr_warmup_steps",
|
| 405 |
+
type=int,
|
| 406 |
+
default=500,
|
| 407 |
+
help="Number of warmup steps."
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# SD3 specific arguments
|
| 411 |
+
parser.add_argument(
|
| 412 |
+
"--weighting_scheme",
|
| 413 |
+
type=str,
|
| 414 |
+
default="logit_normal",
|
| 415 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
| 416 |
+
help="Weighting scheme for flow matching loss.",
|
| 417 |
+
)
|
| 418 |
+
parser.add_argument(
|
| 419 |
+
"--logit_mean",
|
| 420 |
+
type=float,
|
| 421 |
+
default=0.0,
|
| 422 |
+
help="Mean for logit_normal weighting."
|
| 423 |
+
)
|
| 424 |
+
parser.add_argument(
|
| 425 |
+
"--logit_std",
|
| 426 |
+
type=float,
|
| 427 |
+
default=1.0,
|
| 428 |
+
help="Std for logit_normal weighting."
|
| 429 |
+
)
|
| 430 |
+
parser.add_argument(
|
| 431 |
+
"--mode_scale",
|
| 432 |
+
type=float,
|
| 433 |
+
default=1.29,
|
| 434 |
+
help="Scale for mode weighting scheme.",
|
| 435 |
+
)
|
| 436 |
+
parser.add_argument(
|
| 437 |
+
"--precondition_outputs",
|
| 438 |
+
type=int,
|
| 439 |
+
default=1,
|
| 440 |
+
help="Whether to precondition model outputs.",
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Optimization arguments
|
| 444 |
+
parser.add_argument(
|
| 445 |
+
"--allow_tf32",
|
| 446 |
+
action="store_true",
|
| 447 |
+
help="Allow TF32 on Ampere GPUs.",
|
| 448 |
+
)
|
| 449 |
+
parser.add_argument(
|
| 450 |
+
"--dataloader_num_workers",
|
| 451 |
+
type=int,
|
| 452 |
+
default=0,
|
| 453 |
+
help="Number of data loading workers.",
|
| 454 |
+
)
|
| 455 |
+
parser.add_argument(
|
| 456 |
+
"--use_8bit_adam",
|
| 457 |
+
action="store_true",
|
| 458 |
+
help="Use 8-bit Adam optimizer."
|
| 459 |
+
)
|
| 460 |
+
parser.add_argument(
|
| 461 |
+
"--adam_beta1",
|
| 462 |
+
type=float,
|
| 463 |
+
default=0.9,
|
| 464 |
+
help="Beta1 for Adam optimizer."
|
| 465 |
+
)
|
| 466 |
+
parser.add_argument(
|
| 467 |
+
"--adam_beta2",
|
| 468 |
+
type=float,
|
| 469 |
+
default=0.999,
|
| 470 |
+
help="Beta2 for Adam optimizer."
|
| 471 |
+
)
|
| 472 |
+
parser.add_argument(
|
| 473 |
+
"--adam_weight_decay",
|
| 474 |
+
type=float,
|
| 475 |
+
default=1e-2,
|
| 476 |
+
help="Weight decay for Adam."
|
| 477 |
+
)
|
| 478 |
+
parser.add_argument(
|
| 479 |
+
"--adam_epsilon",
|
| 480 |
+
type=float,
|
| 481 |
+
default=1e-08,
|
| 482 |
+
help="Epsilon for Adam optimizer."
|
| 483 |
+
)
|
| 484 |
+
parser.add_argument(
|
| 485 |
+
"--max_grad_norm",
|
| 486 |
+
default=1.0,
|
| 487 |
+
type=float,
|
| 488 |
+
help="Max gradient norm."
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Hub and logging arguments
|
| 492 |
+
parser.add_argument(
|
| 493 |
+
"--push_to_hub",
|
| 494 |
+
action="store_true",
|
| 495 |
+
help="Push model to the Hub."
|
| 496 |
+
)
|
| 497 |
+
parser.add_argument(
|
| 498 |
+
"--hub_token",
|
| 499 |
+
type=str,
|
| 500 |
+
default=None,
|
| 501 |
+
help="Token for Model Hub."
|
| 502 |
+
)
|
| 503 |
+
parser.add_argument(
|
| 504 |
+
"--hub_model_id",
|
| 505 |
+
type=str,
|
| 506 |
+
default=None,
|
| 507 |
+
help="Repository name for the Hub.",
|
| 508 |
+
)
|
| 509 |
+
parser.add_argument(
|
| 510 |
+
"--logging_dir",
|
| 511 |
+
type=str,
|
| 512 |
+
default="logs",
|
| 513 |
+
help="TensorBoard log directory.",
|
| 514 |
+
)
|
| 515 |
+
parser.add_argument(
|
| 516 |
+
"--report_to",
|
| 517 |
+
type=str,
|
| 518 |
+
default="tensorboard",
|
| 519 |
+
help="Logging integration to use.",
|
| 520 |
+
)
|
| 521 |
+
parser.add_argument(
|
| 522 |
+
"--mixed_precision",
|
| 523 |
+
type=str,
|
| 524 |
+
default=None,
|
| 525 |
+
choices=["no", "fp16", "bf16"],
|
| 526 |
+
help="Mixed precision type.",
|
| 527 |
+
)
|
| 528 |
+
parser.add_argument(
|
| 529 |
+
"--local_rank",
|
| 530 |
+
type=int,
|
| 531 |
+
default=-1,
|
| 532 |
+
help="Local rank for distributed training."
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# LoRA arguments
|
| 536 |
+
parser.add_argument(
|
| 537 |
+
"--rank",
|
| 538 |
+
type=int,
|
| 539 |
+
default=64,
|
| 540 |
+
help="LoRA rank dimension.",
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
if input_args is not None:
|
| 544 |
+
args = parser.parse_args(input_args)
|
| 545 |
+
else:
|
| 546 |
+
args = parser.parse_args()
|
| 547 |
+
|
| 548 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 549 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 550 |
+
args.local_rank = env_local_rank
|
| 551 |
+
|
| 552 |
+
# Sanity checks
|
| 553 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 554 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 555 |
+
|
| 556 |
+
return args
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
DATASET_NAME_MAPPING = {
|
| 560 |
+
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def tokenize_prompt(tokenizer, prompt):
|
| 565 |
+
"""Tokenize prompt using the given tokenizer."""
|
| 566 |
+
text_inputs = tokenizer(
|
| 567 |
+
prompt,
|
| 568 |
+
padding="max_length",
|
| 569 |
+
max_length=77,
|
| 570 |
+
truncation=True,
|
| 571 |
+
return_tensors="pt",
|
| 572 |
+
)
|
| 573 |
+
return text_inputs.input_ids
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def _encode_prompt_with_t5(
|
| 577 |
+
text_encoder,
|
| 578 |
+
tokenizer,
|
| 579 |
+
max_sequence_length,
|
| 580 |
+
prompt=None,
|
| 581 |
+
num_images_per_prompt=1,
|
| 582 |
+
device=None,
|
| 583 |
+
text_input_ids=None,
|
| 584 |
+
):
|
| 585 |
+
"""Encode prompt using T5 text encoder."""
|
| 586 |
+
if prompt is not None:
|
| 587 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 588 |
+
batch_size = len(prompt)
|
| 589 |
+
else:
|
| 590 |
+
# When prompt is None, we must have text_input_ids
|
| 591 |
+
if text_input_ids is None:
|
| 592 |
+
raise ValueError("Either prompt or text_input_ids must be provided")
|
| 593 |
+
batch_size = text_input_ids.shape[0]
|
| 594 |
+
|
| 595 |
+
if tokenizer is not None and prompt is not None:
|
| 596 |
+
text_inputs = tokenizer(
|
| 597 |
+
prompt,
|
| 598 |
+
padding="max_length",
|
| 599 |
+
max_length=max_sequence_length,
|
| 600 |
+
truncation=True,
|
| 601 |
+
add_special_tokens=True,
|
| 602 |
+
return_tensors="pt",
|
| 603 |
+
)
|
| 604 |
+
text_input_ids = text_inputs.input_ids
|
| 605 |
+
else:
|
| 606 |
+
if text_input_ids is None:
|
| 607 |
+
raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
|
| 608 |
+
|
| 609 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 610 |
+
dtype = text_encoder.dtype
|
| 611 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 612 |
+
|
| 613 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 614 |
+
# duplicate text embeddings for each generation per prompt
|
| 615 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 616 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 617 |
+
|
| 618 |
+
return prompt_embeds
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def _encode_prompt_with_clip(
|
| 622 |
+
text_encoder,
|
| 623 |
+
tokenizer,
|
| 624 |
+
prompt: str,
|
| 625 |
+
device=None,
|
| 626 |
+
text_input_ids=None,
|
| 627 |
+
num_images_per_prompt: int = 1,
|
| 628 |
+
):
|
| 629 |
+
"""Encode prompt using CLIP text encoder."""
|
| 630 |
+
if prompt is not None:
|
| 631 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 632 |
+
batch_size = len(prompt)
|
| 633 |
+
else:
|
| 634 |
+
# When prompt is None, we must have text_input_ids
|
| 635 |
+
if text_input_ids is None:
|
| 636 |
+
raise ValueError("Either prompt or text_input_ids must be provided")
|
| 637 |
+
batch_size = text_input_ids.shape[0]
|
| 638 |
+
|
| 639 |
+
if tokenizer is not None and prompt is not None:
|
| 640 |
+
text_inputs = tokenizer(
|
| 641 |
+
prompt,
|
| 642 |
+
padding="max_length",
|
| 643 |
+
max_length=77,
|
| 644 |
+
truncation=True,
|
| 645 |
+
return_tensors="pt",
|
| 646 |
+
)
|
| 647 |
+
text_input_ids = text_inputs.input_ids
|
| 648 |
+
else:
|
| 649 |
+
if text_input_ids is None:
|
| 650 |
+
raise ValueError("text_input_ids must be provided when tokenizer is not specified or prompt is None")
|
| 651 |
+
|
| 652 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 653 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 654 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 655 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
| 656 |
+
|
| 657 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 658 |
+
# duplicate text embeddings for each generation per prompt
|
| 659 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 660 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 661 |
+
|
| 662 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def encode_prompt(
|
| 666 |
+
text_encoders,
|
| 667 |
+
tokenizers,
|
| 668 |
+
prompt: str,
|
| 669 |
+
max_sequence_length,
|
| 670 |
+
device=None,
|
| 671 |
+
num_images_per_prompt: int = 1,
|
| 672 |
+
text_input_ids_list=None,
|
| 673 |
+
):
|
| 674 |
+
"""Encode prompt using all three text encoders (SD3 architecture)."""
|
| 675 |
+
if prompt is not None:
|
| 676 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 677 |
+
|
| 678 |
+
# Process CLIP encoders (first two)
|
| 679 |
+
clip_tokenizers = tokenizers[:2]
|
| 680 |
+
clip_text_encoders = text_encoders[:2]
|
| 681 |
+
|
| 682 |
+
clip_prompt_embeds_list = []
|
| 683 |
+
clip_pooled_prompt_embeds_list = []
|
| 684 |
+
|
| 685 |
+
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
|
| 686 |
+
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
|
| 687 |
+
text_encoder=text_encoder,
|
| 688 |
+
tokenizer=tokenizer,
|
| 689 |
+
prompt=prompt,
|
| 690 |
+
device=device if device is not None else text_encoder.device,
|
| 691 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 692 |
+
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
|
| 693 |
+
)
|
| 694 |
+
clip_prompt_embeds_list.append(prompt_embeds)
|
| 695 |
+
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
| 696 |
+
|
| 697 |
+
# Concatenate CLIP embeddings
|
| 698 |
+
clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
|
| 699 |
+
pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
|
| 700 |
+
|
| 701 |
+
# Process T5 encoder (third encoder)
|
| 702 |
+
t5_prompt_embed = _encode_prompt_with_t5(
|
| 703 |
+
text_encoders[-1],
|
| 704 |
+
tokenizers[-1],
|
| 705 |
+
max_sequence_length,
|
| 706 |
+
prompt=prompt,
|
| 707 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 708 |
+
text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
|
| 709 |
+
device=device if device is not None else text_encoders[-1].device,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
# Pad CLIP embeddings to match T5 embedding dimension
|
| 713 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 714 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 715 |
+
)
|
| 716 |
+
# Concatenate all embeddings
|
| 717 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 718 |
+
|
| 719 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def main(args):
|
| 723 |
+
"""Main training function."""
|
| 724 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 725 |
+
raise ValueError(
|
| 726 |
+
"You cannot use both --report_to=wandb and --hub_token due to security risk."
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 730 |
+
|
| 731 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 732 |
+
raise ValueError(
|
| 733 |
+
"Mixed precision training with bfloat16 is not supported on MPS."
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# GPU多卡训练检查
|
| 737 |
+
if torch.cuda.is_available():
|
| 738 |
+
num_gpus = torch.cuda.device_count()
|
| 739 |
+
print(f"Found {num_gpus} GPUs available")
|
| 740 |
+
if num_gpus > 1:
|
| 741 |
+
print(f"Multi-GPU training enabled with {num_gpus} GPUs")
|
| 742 |
+
else:
|
| 743 |
+
print("No CUDA GPUs found, training on CPU")
|
| 744 |
+
|
| 745 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 746 |
+
# 优化多GPU训练的DDP参数
|
| 747 |
+
kwargs = DistributedDataParallelKwargs(
|
| 748 |
+
find_unused_parameters=True,
|
| 749 |
+
gradient_as_bucket_view=True, # 提高多GPU训练效率
|
| 750 |
+
static_graph=False, # 动态图支持
|
| 751 |
+
)
|
| 752 |
+
accelerator = Accelerator(
|
| 753 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 754 |
+
mixed_precision=args.mixed_precision,
|
| 755 |
+
log_with=args.report_to,
|
| 756 |
+
project_config=accelerator_project_config,
|
| 757 |
+
kwargs_handlers=[kwargs],
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
# Logging setup
|
| 761 |
+
logging.basicConfig(
|
| 762 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 763 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 764 |
+
level=logging.INFO,
|
| 765 |
+
)
|
| 766 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 767 |
+
|
| 768 |
+
# 记录多GPU训练信息
|
| 769 |
+
if accelerator.is_main_process:
|
| 770 |
+
logger.info(f"Number of processes: {accelerator.num_processes}")
|
| 771 |
+
logger.info(f"Distributed type: {accelerator.distributed_type}")
|
| 772 |
+
logger.info(f"Mixed precision: {accelerator.mixed_precision}")
|
| 773 |
+
if torch.cuda.is_available():
|
| 774 |
+
for i in range(torch.cuda.device_count()):
|
| 775 |
+
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 776 |
+
logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
|
| 777 |
+
|
| 778 |
+
if accelerator.is_local_main_process:
|
| 779 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 780 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 781 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 782 |
+
else:
|
| 783 |
+
datasets.utils.logging.set_verbosity_error()
|
| 784 |
+
transformers.utils.logging.set_verbosity_error()
|
| 785 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 786 |
+
|
| 787 |
+
# Set training seed
|
| 788 |
+
if args.seed is not None:
|
| 789 |
+
set_seed(args.seed)
|
| 790 |
+
|
| 791 |
+
# Create output directory
|
| 792 |
+
if accelerator.is_main_process:
|
| 793 |
+
if args.output_dir is not None:
|
| 794 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 795 |
+
|
| 796 |
+
if args.push_to_hub:
|
| 797 |
+
repo_id = create_repo(
|
| 798 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
| 799 |
+
exist_ok=True,
|
| 800 |
+
token=args.hub_token
|
| 801 |
+
).repo_id
|
| 802 |
+
|
| 803 |
+
# Load tokenizers (three for SD3)
|
| 804 |
+
tokenizer_one = CLIPTokenizer.from_pretrained(
|
| 805 |
+
args.pretrained_model_name_or_path,
|
| 806 |
+
subfolder="tokenizer",
|
| 807 |
+
revision=args.revision,
|
| 808 |
+
)
|
| 809 |
+
tokenizer_two = CLIPTokenizer.from_pretrained(
|
| 810 |
+
args.pretrained_model_name_or_path,
|
| 811 |
+
subfolder="tokenizer_2",
|
| 812 |
+
revision=args.revision,
|
| 813 |
+
)
|
| 814 |
+
tokenizer_three = T5TokenizerFast.from_pretrained(
|
| 815 |
+
args.pretrained_model_name_or_path,
|
| 816 |
+
subfolder="tokenizer_3",
|
| 817 |
+
revision=args.revision,
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
# Import text encoder classes
|
| 821 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
| 822 |
+
args.pretrained_model_name_or_path, args.revision
|
| 823 |
+
)
|
| 824 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
| 825 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
| 826 |
+
)
|
| 827 |
+
text_encoder_cls_three = import_model_class_from_model_name_or_path(
|
| 828 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
# Load models
|
| 832 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 833 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
| 834 |
+
)
|
| 835 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 836 |
+
|
| 837 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
| 838 |
+
args.pretrained_model_name_or_path,
|
| 839 |
+
subfolder="text_encoder",
|
| 840 |
+
revision=args.revision,
|
| 841 |
+
variant=args.variant
|
| 842 |
+
)
|
| 843 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
| 844 |
+
args.pretrained_model_name_or_path,
|
| 845 |
+
subfolder="text_encoder_2",
|
| 846 |
+
revision=args.revision,
|
| 847 |
+
variant=args.variant
|
| 848 |
+
)
|
| 849 |
+
text_encoder_three = text_encoder_cls_three.from_pretrained(
|
| 850 |
+
args.pretrained_model_name_or_path,
|
| 851 |
+
subfolder="text_encoder_3",
|
| 852 |
+
revision=args.revision,
|
| 853 |
+
variant=args.variant
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
vae = AutoencoderKL.from_pretrained(
|
| 857 |
+
args.pretrained_model_name_or_path,
|
| 858 |
+
subfolder="vae",
|
| 859 |
+
revision=args.revision,
|
| 860 |
+
variant=args.variant,
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
transformer = SD3Transformer2DModel.from_pretrained(
|
| 864 |
+
args.pretrained_model_name_or_path,
|
| 865 |
+
subfolder="transformer",
|
| 866 |
+
revision=args.revision,
|
| 867 |
+
variant=args.variant
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
# Freeze non-trainable weights
|
| 871 |
+
transformer.requires_grad_(False)
|
| 872 |
+
vae.requires_grad_(False)
|
| 873 |
+
text_encoder_one.requires_grad_(False)
|
| 874 |
+
text_encoder_two.requires_grad_(False)
|
| 875 |
+
text_encoder_three.requires_grad_(False)
|
| 876 |
+
|
| 877 |
+
# Set precision
|
| 878 |
+
weight_dtype = torch.float32
|
| 879 |
+
if accelerator.mixed_precision == "fp16":
|
| 880 |
+
weight_dtype = torch.float16
|
| 881 |
+
elif accelerator.mixed_precision == "bf16":
|
| 882 |
+
weight_dtype = torch.bfloat16
|
| 883 |
+
|
| 884 |
+
# Move models to device
|
| 885 |
+
vae.to(accelerator.device, dtype=torch.float32) # VAE stays in fp32
|
| 886 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 887 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 888 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 889 |
+
text_encoder_three.to(accelerator.device, dtype=weight_dtype)
|
| 890 |
+
|
| 891 |
+
# Enable gradient checkpointing
|
| 892 |
+
if args.gradient_checkpointing:
|
| 893 |
+
transformer.enable_gradient_checkpointing()
|
| 894 |
+
if args.train_text_encoder:
|
| 895 |
+
text_encoder_one.gradient_checkpointing_enable()
|
| 896 |
+
text_encoder_two.gradient_checkpointing_enable()
|
| 897 |
+
|
| 898 |
+
# Configure LoRA for transformer
|
| 899 |
+
transformer_lora_config = LoraConfig(
|
| 900 |
+
r=args.rank,
|
| 901 |
+
lora_alpha=args.rank,
|
| 902 |
+
init_lora_weights="gaussian",
|
| 903 |
+
target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
|
| 904 |
+
)
|
| 905 |
+
transformer.add_adapter(transformer_lora_config)
|
| 906 |
+
|
| 907 |
+
# Configure LoRA for text encoders if enabled
|
| 908 |
+
if args.train_text_encoder:
|
| 909 |
+
text_lora_config = LoraConfig(
|
| 910 |
+
r=args.rank,
|
| 911 |
+
lora_alpha=args.rank,
|
| 912 |
+
init_lora_weights="gaussian",
|
| 913 |
+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
| 914 |
+
)
|
| 915 |
+
text_encoder_one.add_adapter(text_lora_config)
|
| 916 |
+
text_encoder_two.add_adapter(text_lora_config)
|
| 917 |
+
# Note: T5 encoder typically doesn't use LoRA
|
| 918 |
+
|
| 919 |
+
def unwrap_model(model):
|
| 920 |
+
model = accelerator.unwrap_model(model)
|
| 921 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 922 |
+
return model
|
| 923 |
+
|
| 924 |
+
# Enable TF32 for faster training
|
| 925 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
| 926 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 927 |
+
|
| 928 |
+
# Scale learning rate
|
| 929 |
+
if args.scale_lr:
|
| 930 |
+
args.learning_rate = (
|
| 931 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
# Cast trainable parameters to float32
|
| 935 |
+
if args.mixed_precision == "fp16":
|
| 936 |
+
models = [transformer]
|
| 937 |
+
if args.train_text_encoder:
|
| 938 |
+
models.extend([text_encoder_one, text_encoder_two])
|
| 939 |
+
cast_training_params(models, dtype=torch.float32)
|
| 940 |
+
|
| 941 |
+
# Setup optimizer
|
| 942 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
| 943 |
+
if args.train_text_encoder:
|
| 944 |
+
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
| 945 |
+
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
| 946 |
+
params_to_optimize = (
|
| 947 |
+
transformer_lora_parameters
|
| 948 |
+
+ text_lora_parameters_one
|
| 949 |
+
+ text_lora_parameters_two
|
| 950 |
+
)
|
| 951 |
+
else:
|
| 952 |
+
params_to_optimize = transformer_lora_parameters
|
| 953 |
+
|
| 954 |
+
# Create optimizer
|
| 955 |
+
if args.use_8bit_adam:
|
| 956 |
+
try:
|
| 957 |
+
import bitsandbytes as bnb
|
| 958 |
+
except ImportError:
|
| 959 |
+
raise ImportError("To use 8-bit Adam, install bitsandbytes: pip install bitsandbytes")
|
| 960 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 961 |
+
else:
|
| 962 |
+
optimizer_class = torch.optim.AdamW
|
| 963 |
+
|
| 964 |
+
optimizer = optimizer_class(
|
| 965 |
+
params_to_optimize,
|
| 966 |
+
lr=args.learning_rate,
|
| 967 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 968 |
+
weight_decay=args.adam_weight_decay,
|
| 969 |
+
eps=args.adam_epsilon,
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
# Load dataset
|
| 973 |
+
if args.dataset_name is not None:
|
| 974 |
+
dataset = load_dataset(
|
| 975 |
+
args.dataset_name,
|
| 976 |
+
args.dataset_config_name,
|
| 977 |
+
cache_dir=args.cache_dir,
|
| 978 |
+
data_dir=args.train_data_dir
|
| 979 |
+
)
|
| 980 |
+
else:
|
| 981 |
+
data_files = {}
|
| 982 |
+
if args.train_data_dir is not None:
|
| 983 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 984 |
+
dataset = load_dataset(
|
| 985 |
+
"imagefolder",
|
| 986 |
+
data_files=data_files,
|
| 987 |
+
cache_dir=args.cache_dir,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
# Preprocessing
|
| 991 |
+
column_names = dataset["train"].column_names
|
| 992 |
+
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
| 993 |
+
|
| 994 |
+
if args.image_column is None:
|
| 995 |
+
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
| 996 |
+
else:
|
| 997 |
+
image_column = args.image_column
|
| 998 |
+
if image_column not in column_names:
|
| 999 |
+
raise ValueError(f"--image_column '{args.image_column}' not found in: {column_names}")
|
| 1000 |
+
|
| 1001 |
+
if args.caption_column is None:
|
| 1002 |
+
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
| 1003 |
+
else:
|
| 1004 |
+
caption_column = args.caption_column
|
| 1005 |
+
if caption_column not in column_names:
|
| 1006 |
+
raise ValueError(f"--caption_column '{args.caption_column}' not found in: {column_names}")
|
| 1007 |
+
|
| 1008 |
+
def tokenize_captions(examples, is_train=True):
|
| 1009 |
+
captions = []
|
| 1010 |
+
for caption in examples[caption_column]:
|
| 1011 |
+
if isinstance(caption, str):
|
| 1012 |
+
captions.append(caption)
|
| 1013 |
+
elif isinstance(caption, (list, np.ndarray)):
|
| 1014 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
| 1015 |
+
else:
|
| 1016 |
+
raise ValueError(f"Caption column should contain strings or lists of strings.")
|
| 1017 |
+
|
| 1018 |
+
tokens_one = tokenize_prompt(tokenizer_one, captions)
|
| 1019 |
+
tokens_two = tokenize_prompt(tokenizer_two, captions)
|
| 1020 |
+
tokens_three = tokenize_prompt(tokenizer_three, captions)
|
| 1021 |
+
return tokens_one, tokens_two, tokens_three
|
| 1022 |
+
|
| 1023 |
+
# Image transforms
|
| 1024 |
+
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 1025 |
+
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
|
| 1026 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 1027 |
+
train_transforms = transforms.Compose([
|
| 1028 |
+
transforms.ToTensor(),
|
| 1029 |
+
transforms.Normalize([0.5], [0.5]),
|
| 1030 |
+
])
|
| 1031 |
+
|
| 1032 |
+
def preprocess_train(examples):
|
| 1033 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
| 1034 |
+
original_sizes = []
|
| 1035 |
+
all_images = []
|
| 1036 |
+
crop_top_lefts = []
|
| 1037 |
+
|
| 1038 |
+
for image in images:
|
| 1039 |
+
original_sizes.append((image.height, image.width))
|
| 1040 |
+
image = train_resize(image)
|
| 1041 |
+
if args.random_flip and random.random() < 0.5:
|
| 1042 |
+
image = train_flip(image)
|
| 1043 |
+
if args.center_crop:
|
| 1044 |
+
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
| 1045 |
+
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
| 1046 |
+
image = train_crop(image)
|
| 1047 |
+
else:
|
| 1048 |
+
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
| 1049 |
+
image = crop(image, y1, x1, h, w)
|
| 1050 |
+
crop_top_left = (y1, x1)
|
| 1051 |
+
crop_top_lefts.append(crop_top_left)
|
| 1052 |
+
image = train_transforms(image)
|
| 1053 |
+
all_images.append(image)
|
| 1054 |
+
|
| 1055 |
+
examples["original_sizes"] = original_sizes
|
| 1056 |
+
examples["crop_top_lefts"] = crop_top_lefts
|
| 1057 |
+
examples["pixel_values"] = all_images
|
| 1058 |
+
|
| 1059 |
+
tokens_one, tokens_two, tokens_three = tokenize_captions(examples)
|
| 1060 |
+
examples["input_ids_one"] = tokens_one
|
| 1061 |
+
examples["input_ids_two"] = tokens_two
|
| 1062 |
+
examples["input_ids_three"] = tokens_three
|
| 1063 |
+
return examples
|
| 1064 |
+
|
| 1065 |
+
with accelerator.main_process_first():
|
| 1066 |
+
if args.max_train_samples is not None:
|
| 1067 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 1068 |
+
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
|
| 1069 |
+
|
| 1070 |
+
def collate_fn(examples):
|
| 1071 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 1072 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 1073 |
+
original_sizes = [example["original_sizes"] for example in examples]
|
| 1074 |
+
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
| 1075 |
+
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
| 1076 |
+
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
| 1077 |
+
input_ids_three = torch.stack([example["input_ids_three"] for example in examples])
|
| 1078 |
+
|
| 1079 |
+
return {
|
| 1080 |
+
"pixel_values": pixel_values,
|
| 1081 |
+
"input_ids_one": input_ids_one,
|
| 1082 |
+
"input_ids_two": input_ids_two,
|
| 1083 |
+
"input_ids_three": input_ids_three,
|
| 1084 |
+
"original_sizes": original_sizes,
|
| 1085 |
+
"crop_top_lefts": crop_top_lefts,
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
# 针对多GPU训练优化dataloader设置
|
| 1089 |
+
if args.dataloader_num_workers == 0 and accelerator.num_processes > 1:
|
| 1090 |
+
# 多GPU训练时自动设置数据加载器worker数量
|
| 1091 |
+
args.dataloader_num_workers = min(4, os.cpu_count() // accelerator.num_processes)
|
| 1092 |
+
logger.info(f"Auto-setting dataloader_num_workers to {args.dataloader_num_workers} for multi-GPU training")
|
| 1093 |
+
|
| 1094 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 1095 |
+
train_dataset,
|
| 1096 |
+
shuffle=True,
|
| 1097 |
+
collate_fn=collate_fn,
|
| 1098 |
+
batch_size=args.train_batch_size,
|
| 1099 |
+
num_workers=args.dataloader_num_workers,
|
| 1100 |
+
pin_memory=True, # 提高GPU数据传输效率
|
| 1101 |
+
persistent_workers=args.dataloader_num_workers > 0, # 保持worker进程活跃
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
# Scheduler and math around training steps
|
| 1105 |
+
overrode_max_train_steps = False
|
| 1106 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1107 |
+
if args.max_train_steps is None:
|
| 1108 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1109 |
+
overrode_max_train_steps = True
|
| 1110 |
+
|
| 1111 |
+
lr_scheduler = get_scheduler(
|
| 1112 |
+
args.lr_scheduler,
|
| 1113 |
+
optimizer=optimizer,
|
| 1114 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 1115 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
# Prepare everything with accelerator
|
| 1119 |
+
if args.train_text_encoder:
|
| 1120 |
+
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1121 |
+
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
|
| 1122 |
+
)
|
| 1123 |
+
else:
|
| 1124 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1125 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 1126 |
+
)
|
| 1127 |
+
|
| 1128 |
+
# Recalculate training steps
|
| 1129 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1130 |
+
if overrode_max_train_steps:
|
| 1131 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1132 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 1133 |
+
|
| 1134 |
+
# Initialize trackers
|
| 1135 |
+
if accelerator.is_main_process:
|
| 1136 |
+
try:
|
| 1137 |
+
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
| 1138 |
+
except Exception as e:
|
| 1139 |
+
logger.warning(f"Failed to initialize trackers: {e}")
|
| 1140 |
+
logger.warning("Continuing without tracking. You can monitor training through console logs.")
|
| 1141 |
+
# Set report_to to None to avoid further tracking attempts
|
| 1142 |
+
args.report_to = None
|
| 1143 |
+
|
| 1144 |
+
# Train!
|
| 1145 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 1146 |
+
logger.info("***** Running training *****")
|
| 1147 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 1148 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 1149 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 1150 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 1151 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 1152 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 1153 |
+
logger.info(f" Number of GPU processes = {accelerator.num_processes}")
|
| 1154 |
+
if accelerator.num_processes > 1:
|
| 1155 |
+
logger.info(f" Effective batch size per GPU = {args.train_batch_size * args.gradient_accumulation_steps}")
|
| 1156 |
+
logger.info(f" Total effective batch size across all GPUs = {total_batch_size}")
|
| 1157 |
+
|
| 1158 |
+
global_step = 0
|
| 1159 |
+
first_epoch = 0
|
| 1160 |
+
|
| 1161 |
+
# Resume from checkpoint if specified
|
| 1162 |
+
if args.resume_from_checkpoint:
|
| 1163 |
+
if args.resume_from_checkpoint != "latest":
|
| 1164 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 1165 |
+
else:
|
| 1166 |
+
dirs = os.listdir(args.output_dir)
|
| 1167 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 1168 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 1169 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 1170 |
+
|
| 1171 |
+
if path is None:
|
| 1172 |
+
accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
|
| 1173 |
+
args.resume_from_checkpoint = None
|
| 1174 |
+
initial_global_step = 0
|
| 1175 |
+
else:
|
| 1176 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 1177 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 1178 |
+
global_step = int(path.split("-")[1])
|
| 1179 |
+
initial_global_step = global_step
|
| 1180 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 1181 |
+
else:
|
| 1182 |
+
initial_global_step = 0
|
| 1183 |
+
|
| 1184 |
+
progress_bar = tqdm(
|
| 1185 |
+
range(0, args.max_train_steps),
|
| 1186 |
+
initial=initial_global_step,
|
| 1187 |
+
desc="Steps",
|
| 1188 |
+
disable=not accelerator.is_local_main_process,
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 1192 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 1193 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 1194 |
+
timesteps = timesteps.to(accelerator.device)
|
| 1195 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 1196 |
+
sigma = sigmas[step_indices].flatten()
|
| 1197 |
+
while len(sigma.shape) < n_dim:
|
| 1198 |
+
sigma = sigma.unsqueeze(-1)
|
| 1199 |
+
return sigma
|
| 1200 |
+
|
| 1201 |
+
# Training loop
|
| 1202 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 1203 |
+
transformer.train()
|
| 1204 |
+
if args.train_text_encoder:
|
| 1205 |
+
text_encoder_one.train()
|
| 1206 |
+
text_encoder_two.train()
|
| 1207 |
+
|
| 1208 |
+
train_loss = 0.0
|
| 1209 |
+
for step, batch in enumerate(train_dataloader):
|
| 1210 |
+
with accelerator.accumulate(transformer):
|
| 1211 |
+
# Convert images to latent space
|
| 1212 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 1213 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 1214 |
+
|
| 1215 |
+
# Apply VAE scaling
|
| 1216 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 1217 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 1218 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 1219 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 1220 |
+
|
| 1221 |
+
# Encode prompts
|
| 1222 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
| 1223 |
+
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
| 1224 |
+
tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
|
| 1225 |
+
prompt=None,
|
| 1226 |
+
max_sequence_length=args.max_sequence_length,
|
| 1227 |
+
text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"], batch["input_ids_three"]],
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
# Sample noise and timesteps
|
| 1231 |
+
noise = torch.randn_like(model_input)
|
| 1232 |
+
bsz = model_input.shape[0]
|
| 1233 |
+
|
| 1234 |
+
# Flow Matching timestep sampling
|
| 1235 |
+
u = compute_density_for_timestep_sampling(
|
| 1236 |
+
weighting_scheme=args.weighting_scheme,
|
| 1237 |
+
batch_size=bsz,
|
| 1238 |
+
logit_mean=args.logit_mean,
|
| 1239 |
+
logit_std=args.logit_std,
|
| 1240 |
+
mode_scale=args.mode_scale,
|
| 1241 |
+
)
|
| 1242 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 1243 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 1244 |
+
|
| 1245 |
+
# Flow Matching interpolation
|
| 1246 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 1247 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 1248 |
+
|
| 1249 |
+
# Predict using SD3 Transformer
|
| 1250 |
+
model_pred = transformer(
|
| 1251 |
+
hidden_states=noisy_model_input,
|
| 1252 |
+
timestep=timesteps,
|
| 1253 |
+
encoder_hidden_states=prompt_embeds,
|
| 1254 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1255 |
+
return_dict=False,
|
| 1256 |
+
)[0]
|
| 1257 |
+
|
| 1258 |
+
# Compute target for Flow Matching
|
| 1259 |
+
if args.precondition_outputs:
|
| 1260 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
| 1261 |
+
target = model_input
|
| 1262 |
+
else:
|
| 1263 |
+
target = noise - model_input
|
| 1264 |
+
|
| 1265 |
+
# Compute loss with weighting
|
| 1266 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 1267 |
+
loss = torch.mean(
|
| 1268 |
+
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
| 1269 |
+
1,
|
| 1270 |
+
)
|
| 1271 |
+
loss = loss.mean()
|
| 1272 |
+
|
| 1273 |
+
# Gather loss across processes
|
| 1274 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 1275 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 1276 |
+
|
| 1277 |
+
# Backpropagate
|
| 1278 |
+
accelerator.backward(loss)
|
| 1279 |
+
if accelerator.sync_gradients:
|
| 1280 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 1281 |
+
optimizer.step()
|
| 1282 |
+
lr_scheduler.step()
|
| 1283 |
+
optimizer.zero_grad()
|
| 1284 |
+
|
| 1285 |
+
# Checks if the accelerator has performed an optimization step
|
| 1286 |
+
if accelerator.sync_gradients:
|
| 1287 |
+
progress_bar.update(1)
|
| 1288 |
+
global_step += 1
|
| 1289 |
+
if hasattr(accelerator, 'trackers') and accelerator.trackers:
|
| 1290 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 1291 |
+
train_loss = 0.0
|
| 1292 |
+
|
| 1293 |
+
# Save checkpoint
|
| 1294 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
| 1295 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1296 |
+
if args.checkpoints_total_limit is not None:
|
| 1297 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1298 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1299 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1300 |
+
|
| 1301 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1302 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1303 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1304 |
+
logger.info(f"Removing {len(removing_checkpoints)} checkpoints")
|
| 1305 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1306 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1307 |
+
shutil.rmtree(removing_checkpoint)
|
| 1308 |
+
|
| 1309 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1310 |
+
accelerator.save_state(save_path)
|
| 1311 |
+
logger.info(f"Saved state to {save_path}")
|
| 1312 |
+
|
| 1313 |
+
# 同时保存标准的LoRA权重格式,方便采样时直接加载
|
| 1314 |
+
try:
|
| 1315 |
+
# 获取当前模型的LoRA权重
|
| 1316 |
+
unwrapped_transformer = unwrap_model(transformer)
|
| 1317 |
+
transformer_lora_layers = get_peft_model_state_dict(unwrapped_transformer)
|
| 1318 |
+
|
| 1319 |
+
text_encoder_lora_layers = None
|
| 1320 |
+
text_encoder_2_lora_layers = None
|
| 1321 |
+
if args.train_text_encoder:
|
| 1322 |
+
unwrapped_text_encoder_one = unwrap_model(text_encoder_one)
|
| 1323 |
+
unwrapped_text_encoder_two = unwrap_model(text_encoder_two)
|
| 1324 |
+
text_encoder_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_one)
|
| 1325 |
+
text_encoder_2_lora_layers = get_peft_model_state_dict(unwrapped_text_encoder_two)
|
| 1326 |
+
|
| 1327 |
+
# 保存为标准LoRA格式到checkpoint目录
|
| 1328 |
+
StableDiffusion3Pipeline.save_lora_weights(
|
| 1329 |
+
save_directory=save_path,
|
| 1330 |
+
transformer_lora_layers=transformer_lora_layers,
|
| 1331 |
+
text_encoder_lora_layers=text_encoder_lora_layers,
|
| 1332 |
+
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
| 1333 |
+
)
|
| 1334 |
+
logger.info(f"Saved LoRA weights in standard format to {save_path}")
|
| 1335 |
+
except Exception as e:
|
| 1336 |
+
logger.warning(f"Failed to save LoRA weights in standard format: {e}")
|
| 1337 |
+
logger.warning("Checkpoint saved with accelerator format only. You can extract LoRA weights later.")
|
| 1338 |
+
|
| 1339 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1340 |
+
progress_bar.set_postfix(**logs)
|
| 1341 |
+
|
| 1342 |
+
if global_step >= args.max_train_steps:
|
| 1343 |
+
break
|
| 1344 |
+
|
| 1345 |
+
# Validation
|
| 1346 |
+
if accelerator.is_main_process:
|
| 1347 |
+
if args.validation_prompt is not None :#and epoch % args.validation_epochs == 0:
|
| 1348 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 1349 |
+
args.pretrained_model_name_or_path,
|
| 1350 |
+
vae=vae,
|
| 1351 |
+
text_encoder=unwrap_model(text_encoder_one),
|
| 1352 |
+
text_encoder_2=unwrap_model(text_encoder_two),
|
| 1353 |
+
text_encoder_3=unwrap_model(text_encoder_three),
|
| 1354 |
+
transformer=unwrap_model(transformer),
|
| 1355 |
+
revision=args.revision,
|
| 1356 |
+
variant=args.variant,
|
| 1357 |
+
torch_dtype=weight_dtype,
|
| 1358 |
+
)
|
| 1359 |
+
images = log_validation(pipeline, args, accelerator, epoch, global_step=global_step)
|
| 1360 |
+
del pipeline
|
| 1361 |
+
torch.cuda.empty_cache()
|
| 1362 |
+
|
| 1363 |
+
# Save final LoRA weights
|
| 1364 |
+
accelerator.wait_for_everyone()
|
| 1365 |
+
if accelerator.is_main_process:
|
| 1366 |
+
transformer = unwrap_model(transformer)
|
| 1367 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
| 1368 |
+
|
| 1369 |
+
if args.train_text_encoder:
|
| 1370 |
+
text_encoder_one = unwrap_model(text_encoder_one)
|
| 1371 |
+
text_encoder_two = unwrap_model(text_encoder_two)
|
| 1372 |
+
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
| 1373 |
+
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
| 1374 |
+
else:
|
| 1375 |
+
text_encoder_lora_layers = None
|
| 1376 |
+
text_encoder_2_lora_layers = None
|
| 1377 |
+
|
| 1378 |
+
StableDiffusion3Pipeline.save_lora_weights(
|
| 1379 |
+
save_directory=args.output_dir,
|
| 1380 |
+
transformer_lora_layers=transformer_lora_layers,
|
| 1381 |
+
text_encoder_lora_layers=text_encoder_lora_layers,
|
| 1382 |
+
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
| 1383 |
+
)
|
| 1384 |
+
|
| 1385 |
+
# Final inference
|
| 1386 |
+
if args.mixed_precision == "fp16":
|
| 1387 |
+
vae.to(weight_dtype)
|
| 1388 |
+
|
| 1389 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 1390 |
+
args.pretrained_model_name_or_path,
|
| 1391 |
+
vae=vae,
|
| 1392 |
+
revision=args.revision,
|
| 1393 |
+
variant=args.variant,
|
| 1394 |
+
torch_dtype=weight_dtype,
|
| 1395 |
+
)
|
| 1396 |
+
pipeline.load_lora_weights(args.output_dir)
|
| 1397 |
+
|
| 1398 |
+
if args.validation_prompt and args.num_validation_images > 0:
|
| 1399 |
+
images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True, global_step=global_step)
|
| 1400 |
+
|
| 1401 |
+
if args.push_to_hub:
|
| 1402 |
+
save_model_card(
|
| 1403 |
+
repo_id,
|
| 1404 |
+
images=images,
|
| 1405 |
+
base_model=args.pretrained_model_name_or_path,
|
| 1406 |
+
dataset_name=args.dataset_name,
|
| 1407 |
+
train_text_encoder=args.train_text_encoder,
|
| 1408 |
+
repo_folder=args.output_dir,
|
| 1409 |
+
)
|
| 1410 |
+
upload_folder(
|
| 1411 |
+
repo_id=repo_id,
|
| 1412 |
+
folder_path=args.output_dir,
|
| 1413 |
+
commit_message="End of training",
|
| 1414 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1415 |
+
)
|
| 1416 |
+
|
| 1417 |
+
accelerator.end_training()
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
if __name__ == "__main__":
|
| 1421 |
+
args = parse_args()
|
| 1422 |
+
main(args)
|
train_rectified_noise.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_rectified_noise.sh
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# SD3 Rectified Noise Training Script
|
| 4 |
+
# 这个脚本展示了如何使用 train_rectified_noise.py 进行训练
|
| 5 |
+
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
# 激活正确的conda环境
|
| 9 |
+
source /root/miniconda3/etc/profile.d/conda.sh
|
| 10 |
+
conda activate SiT
|
| 11 |
+
|
| 12 |
+
# 基础配置
|
| 13 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 设置使用4个GPU(0,1,2,3)
|
| 14 |
+
#export OMP_NUM_THREADS=1
|
| 15 |
+
|
| 16 |
+
# 内存优化设置
|
| 17 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 18 |
+
export TOKENIZERS_PARALLELISM=false
|
| 19 |
+
|
| 20 |
+
# 模型和数据路径
|
| 21 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 22 |
+
LORA_MODEL_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000" # LoRA微调后的SD3模型路径
|
| 23 |
+
TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 训练数据目录
|
| 24 |
+
OUTPUT_DIR="./rectified-noise-batch-2" # 输出目录
|
| 25 |
+
|
| 26 |
+
# 训练参数
|
| 27 |
+
NUM_SIT_LAYERS=1 # SIT块的层数
|
| 28 |
+
SIT_LEARNING_RATE=1e-5 # SIT块的学习率
|
| 29 |
+
KL_LOSS_WEIGHT=0.5 # KL散度损失权重
|
| 30 |
+
RESOLUTION=512 # 图像分辨率
|
| 31 |
+
BATCH_SIZE=2 # 批次大小
|
| 32 |
+
GRADIENT_ACCUMULATION=2 # 梯度累积步数
|
| 33 |
+
MAX_TRAIN_STEPS=500000 # 最大训练步数
|
| 34 |
+
|
| 35 |
+
# 验证参数
|
| 36 |
+
VALIDATION_PROMPT="A bicycle replica with a clock as the front wheel."
|
| 37 |
+
NUM_VALIDATION_IMAGES=1
|
| 38 |
+
|
| 39 |
+
echo "开始 SD3 Rectified Noise 训练..."
|
| 40 |
+
echo "LoRA模型路径: $LORA_MODEL_PATH"
|
| 41 |
+
echo "SIT层数: $NUM_SIT_LAYERS"
|
| 42 |
+
echo "输出目录: $OUTPUT_DIR"
|
| 43 |
+
|
| 44 |
+
# 检查LoRA模型路径是否存在
|
| 45 |
+
if [ ! -d "$LORA_MODEL_PATH" ]; then
|
| 46 |
+
echo "错误: LoRA模型路径不存在: $LORA_MODEL_PATH"
|
| 47 |
+
echo "请先使用 train_lora_sd3.py 训练LoRA模型"
|
| 48 |
+
exit 1
|
| 49 |
+
fi
|
| 50 |
+
|
| 51 |
+
# 使用accelerate启动训练
|
| 52 |
+
# 注意:移除了命令行中的mixed_precision参数,因为已经在accelerate_config.yaml中设置
|
| 53 |
+
accelerate launch --config_file accelerate_config.yaml train_rectified_noise.py \
|
| 54 |
+
--pretrained_model_name_or_path="$PRETRAINED_MODEL" \
|
| 55 |
+
--lora_model_path="$LORA_MODEL_PATH" \
|
| 56 |
+
--train_data_dir="$TRAIN_DATA_DIR" \
|
| 57 |
+
--num_sit_layers=$NUM_SIT_LAYERS \
|
| 58 |
+
--sit_learning_rate=$SIT_LEARNING_RATE \
|
| 59 |
+
--kl_loss_weight=$KL_LOSS_WEIGHT \
|
| 60 |
+
--resolution=$RESOLUTION \
|
| 61 |
+
--train_batch_size=$BATCH_SIZE \
|
| 62 |
+
--gradient_accumulation_steps=$GRADIENT_ACCUMULATION \
|
| 63 |
+
--gradient_checkpointing \
|
| 64 |
+
--learning_rate=1e-5 \
|
| 65 |
+
--time_weight_alpha=5.0 \
|
| 66 |
+
--lr_scheduler="constant" \
|
| 67 |
+
--lr_warmup_steps=0 \
|
| 68 |
+
--max_train_steps=$MAX_TRAIN_STEPS \
|
| 69 |
+
--output_dir="$OUTPUT_DIR" \
|
| 70 |
+
--validation_prompt="$VALIDATION_PROMPT" \
|
| 71 |
+
--num_validation_images=$NUM_VALIDATION_IMAGES \
|
| 72 |
+
--validation_steps=20000 \
|
| 73 |
+
--seed=42 \
|
| 74 |
+
--dataloader_num_workers=8 \
|
| 75 |
+
--save_sit_weights_only \
|
| 76 |
+
--checkpointing_steps=20000 \
|
| 77 |
+
--checkpoints_total_limit=10 \
|
| 78 |
+
--report_to="tensorboard" \
|
| 79 |
+
--logging_dir="./logs"
|
| 80 |
+
|
| 81 |
+
echo "训练完成!"
|
| 82 |
+
echo "SIT权重保存在: $OUTPUT_DIR/sit_weights/"
|
| 83 |
+
echo "验证图像保存在: $OUTPUT_DIR/validation_images/"
|
| 84 |
+
|
| 85 |
+
# 可选:快速测试训练命令
|
| 86 |
+
# cat << 'EOF'
|
| 87 |
+
|
| 88 |
+
# # 快速测试命令(少量步数):
|
| 89 |
+
# accelerate launch train_rectified_noise.py \
|
| 90 |
+
# --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" \
|
| 91 |
+
# --lora_model_path="./sd3-lora-finetuned" \
|
| 92 |
+
# --train_data_dir="./dataset" \
|
| 93 |
+
# --num_sit_layers=2 \
|
| 94 |
+
# --resolution=256 \
|
| 95 |
+
# --train_batch_size=1 \
|
| 96 |
+
# --gradient_accumulation_steps=4 \
|
| 97 |
+
# --max_train_steps=100 \
|
| 98 |
+
# --output_dir="./test-rectified-noise" \
|
| 99 |
+
# --mixed_precision="fp16" \
|
| 100 |
+
# --save_sit_weights_only
|
| 101 |
+
|
| 102 |
+
# EOF
|
| 103 |
+
|
| 104 |
+
# nohup bash train_rectified_noise.sh > train_rectified_noise.log 2>&1 &
|
train_rectified_noise2.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_sd3_lora.log
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
检测到 4 个GPU
|
| 3 |
+
每个GPU批次大小: 4
|
| 4 |
+
总有效批次大小: 16
|
| 5 |
+
===== SD3 LoRA 多GPU训练开始 =====
|
| 6 |
+
模型: /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671
|
| 7 |
+
输出目录: sd3-lora-finetuned-batch-8
|
| 8 |
+
分辨率: 512
|
| 9 |
+
每个GPU批次大小: 4
|
| 10 |
+
梯度累积步数: 1
|
| 11 |
+
总有效批次大小: 16
|
| 12 |
+
学习率: 1e-5
|
| 13 |
+
最大训练步数: 500000
|
| 14 |
+
LoRA Rank: 32
|
| 15 |
+
使用GPU: 0,1,2,3
|
| 16 |
+
断点重训: latest
|
| 17 |
+
===========================================
|
| 18 |
+
使用 accelerate 启动多GPU训练...
|
| 19 |
+
/root/miniconda3/envs/SiT/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
|
| 20 |
+
warnings.warn(
|
| 21 |
+
Terminated
|
| 22 |
+
===========================================
|
| 23 |
+
训练完成!
|
| 24 |
+
模型保存在: sd3-lora-finetuned-batch-8
|
| 25 |
+
日志保存在: sd3-lora-finetuned-batch-8/logs
|
| 26 |
+
验证图片保存在: sd3-lora-finetuned-batch-8/validation_images
|
| 27 |
+
===========================================
|
train_sd3_lora.sh
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# SD3 LoRA Fine-tuning Training Script
|
| 4 |
+
# 使用 Stable Diffusion 3 进行 LoRA 微调训练 - 多GPU优化版本
|
| 5 |
+
|
| 6 |
+
# 设置环境变量
|
| 7 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 根据可用GPU数量调整
|
| 8 |
+
# export PYTHONPATH=$PYTHONPATH:/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco
|
| 9 |
+
|
| 10 |
+
# 检查GPU数量
|
| 11 |
+
num_gpus=$(nvidia-smi --list-gpus | wc -l)
|
| 12 |
+
echo "检测到 $num_gpus 个GPU"
|
| 13 |
+
|
| 14 |
+
# 训练参数配置
|
| 15 |
+
MODEL_NAME="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 16 |
+
#DATASET_NAME="lambdalabs/naruto-blip-captions" # 或者使用本地数据集路径
|
| 17 |
+
TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 本地数据集路径0
|
| 18 |
+
OUTPUT_DIR="sd3-lora-finetuned-batch-8"
|
| 19 |
+
RESOLUTION=512
|
| 20 |
+
# 多GPU训练时调整批次大小 - 每个GPU的批次大小
|
| 21 |
+
TRAIN_BATCH_SIZE=4 # 每个GPU的批次大小,总批次大小 = TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS
|
| 22 |
+
GRADIENT_ACCUMULATION_STEPS=1 # 梯度累积步数
|
| 23 |
+
LEARNING_RATE=1e-5
|
| 24 |
+
MAX_TRAIN_STEPS=500000
|
| 25 |
+
NUM_TRAIN_EPOCHS=50
|
| 26 |
+
VALIDATION_PROMPT="A photo of a beautiful landscape with mountains and lake"
|
| 27 |
+
NUM_VALIDATION_IMAGES=2
|
| 28 |
+
VALIDATION_EPOCHS=1 # 每10个epoch验证一次
|
| 29 |
+
LORA_RANK=32
|
| 30 |
+
SEED=42
|
| 31 |
+
RESUME_FROM_CHECKPOINT="latest" # 设置为 "latest" 以自动从最新checkpoint恢复,或指定checkpoint路径,或设为 "" 以不恢复
|
| 32 |
+
|
| 33 |
+
# 计算有效批次大小
|
| 34 |
+
effective_batch_size=$((TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS))
|
| 35 |
+
echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
|
| 36 |
+
echo "总有效批次大小: $effective_batch_size"
|
| 37 |
+
|
| 38 |
+
# 创建输出目录
|
| 39 |
+
mkdir -p $OUTPUT_DIR
|
| 40 |
+
|
| 41 |
+
echo "===== SD3 LoRA 多GPU训练开始 ====="
|
| 42 |
+
echo "模型: $MODEL_NAME"
|
| 43 |
+
echo "输出目录: $OUTPUT_DIR"
|
| 44 |
+
echo "分辨率: $RESOLUTION"
|
| 45 |
+
echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
|
| 46 |
+
echo "梯度累积步数: $GRADIENT_ACCUMULATION_STEPS"
|
| 47 |
+
echo "总有效批次大小: $effective_batch_size"
|
| 48 |
+
echo "学习率: $LEARNING_RATE"
|
| 49 |
+
echo "最大训练步数: $MAX_TRAIN_STEPS"
|
| 50 |
+
echo "LoRA Rank: $LORA_RANK"
|
| 51 |
+
echo "使用GPU: $CUDA_VISIBLE_DEVICES"
|
| 52 |
+
echo "断点重训: $RESUME_FROM_CHECKPOINT"
|
| 53 |
+
echo "==========================================="
|
| 54 |
+
|
| 55 |
+
# 检查accelerate配置
|
| 56 |
+
if [ ! -f "accelerate_config.yaml" ]; then
|
| 57 |
+
echo "错误: 未找到 accelerate_config.yaml 文件"
|
| 58 |
+
echo "请运行: accelerate config 来配置多GPU训练"
|
| 59 |
+
exit 1
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
echo "使用 accelerate 启动多GPU训练..."
|
| 63 |
+
|
| 64 |
+
# 使用 accelerate 启动训练
|
| 65 |
+
accelerate launch --config_file=accelerate_config.yaml train_lora_sd3.py \
|
| 66 |
+
--pretrained_model_name_or_path="$MODEL_NAME" \
|
| 67 |
+
--train_data_dir="$TRAIN_DATA_DIR" \
|
| 68 |
+
--output_dir="$OUTPUT_DIR" \
|
| 69 |
+
--mixed_precision="no" \
|
| 70 |
+
--resolution=$RESOLUTION \
|
| 71 |
+
--train_batch_size=$TRAIN_BATCH_SIZE \
|
| 72 |
+
--gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
|
| 73 |
+
--learning_rate=$LEARNING_RATE \
|
| 74 |
+
--scale_lr \
|
| 75 |
+
--lr_scheduler="cosine" \
|
| 76 |
+
--lr_warmup_steps=100 \
|
| 77 |
+
--max_train_steps=$MAX_TRAIN_STEPS \
|
| 78 |
+
--num_train_epochs=$NUM_TRAIN_EPOCHS \
|
| 79 |
+
--validation_prompt="$VALIDATION_PROMPT" \
|
| 80 |
+
--num_validation_images=$NUM_VALIDATION_IMAGES \
|
| 81 |
+
--validation_epochs=$VALIDATION_EPOCHS \
|
| 82 |
+
--checkpointing_steps=20000 \
|
| 83 |
+
--checkpoints_total_limit=10 \
|
| 84 |
+
--seed=$SEED \
|
| 85 |
+
--rank=$LORA_RANK \
|
| 86 |
+
--gradient_checkpointing \
|
| 87 |
+
--use_8bit_adam \
|
| 88 |
+
--dataloader_num_workers=0 \
|
| 89 |
+
--report_to="tensorboard" \
|
| 90 |
+
--logging_dir="logs" \
|
| 91 |
+
--adam_beta1=0.9 \
|
| 92 |
+
--adam_beta2=0.999 \
|
| 93 |
+
--adam_weight_decay=1e-2 \
|
| 94 |
+
--adam_epsilon=1e-8 \
|
| 95 |
+
--max_grad_norm=1.0 \
|
| 96 |
+
--allow_tf32 \
|
| 97 |
+
--weighting_scheme="logit_normal" \
|
| 98 |
+
--logit_mean=0.0 \
|
| 99 |
+
--logit_std=1.0 \
|
| 100 |
+
--precondition_outputs=1
|
| 101 |
+
|
| 102 |
+
echo "==========================================="
|
| 103 |
+
echo "训练完成!"
|
| 104 |
+
echo "模型保存在: $OUTPUT_DIR"
|
| 105 |
+
echo "日志保存在: $OUTPUT_DIR/logs"
|
| 106 |
+
echo "验证图片保存在: $OUTPUT_DIR/validation_images"
|
| 107 |
+
echo "==========================================="
|
| 108 |
+
|
| 109 |
+
# nohup bash train_sd3_lora.sh > train_sd3_lora.log 2>&1 &
|
train_sd3_lora2.log
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
检测到 4 个GPU
|
| 3 |
+
每个GPU批次大小: 8
|
| 4 |
+
总有效批次大小: 32
|
| 5 |
+
===== SD3 LoRA 多GPU训练开始 =====
|
| 6 |
+
模型: /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671
|
| 7 |
+
输出目录: sd3-lora-finetuned-batch-32
|
| 8 |
+
分辨率: 512
|
| 9 |
+
每个GPU批次大小: 8
|
| 10 |
+
梯度累积步数: 1
|
| 11 |
+
总有效批次大小: 32
|
| 12 |
+
学习率: 1e-5
|
| 13 |
+
最大训练步数: 500000
|
| 14 |
+
LoRA Rank: 32
|
| 15 |
+
使用GPU: 0,1,2,3
|
| 16 |
+
===========================================
|
| 17 |
+
使用 accelerate 启动多GPU训练...
|
| 18 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 19 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 20 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 21 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 22 |
+
Found 4 GPUs available
|
| 23 |
+
Multi-GPU training enabled with 4 GPUs
|
| 24 |
+
Found 4 GPUs available
|
| 25 |
+
Multi-GPU training enabled with 4 GPUs
|
| 26 |
+
Found 4 GPUs available
|
| 27 |
+
Multi-GPU training enabled with 4 GPUs
|
| 28 |
+
Found 4 GPUs available
|
| 29 |
+
Multi-GPU training enabled with 4 GPUs
|
| 30 |
+
03/09/2026 10:48:06 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
|
| 31 |
+
Num processes: 4
|
| 32 |
+
Process index: 2
|
| 33 |
+
Local process index: 2
|
| 34 |
+
Device: cuda:2
|
| 35 |
+
|
| 36 |
+
Mixed precision type: no
|
| 37 |
+
|
| 38 |
+
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
|
| 39 |
+
03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
|
| 40 |
+
Num processes: 4
|
| 41 |
+
Process index: 0
|
| 42 |
+
Local process index: 0
|
| 43 |
+
Device: cuda:0
|
| 44 |
+
|
| 45 |
+
Mixed precision type: no
|
| 46 |
+
|
| 47 |
+
[INFO] Accelerator initialized
|
| 48 |
+
03/09/2026 10:48:07 - INFO - __main__ - Number of processes: 4
|
| 49 |
+
03/09/2026 10:48:07 - INFO - __main__ - Distributed type: MULTI_GPU
|
| 50 |
+
03/09/2026 10:48:07 - INFO - __main__ - Mixed precision: no
|
| 51 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 0: NVIDIA A100-PCIE-40GB
|
| 52 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 0 memory: 39.4 GB
|
| 53 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 1: NVIDIA A100-PCIE-40GB
|
| 54 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 1 memory: 39.4 GB
|
| 55 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 2: NVIDIA A100-PCIE-40GB
|
| 56 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 2 memory: 39.4 GB
|
| 57 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 3: NVIDIA A100-PCIE-40GB
|
| 58 |
+
03/09/2026 10:48:07 - INFO - __main__ - GPU 3 memory: 39.4 GB
|
| 59 |
+
[INFO] Seed set to 42
|
| 60 |
+
[INFO] Loading tokenizers...
|
| 61 |
+
03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
|
| 62 |
+
Num processes: 4
|
| 63 |
+
Process index: 1
|
| 64 |
+
Local process index: 1
|
| 65 |
+
Device: cuda:1
|
| 66 |
+
|
| 67 |
+
Mixed precision type: no
|
| 68 |
+
|
| 69 |
+
03/09/2026 10:48:07 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
|
| 70 |
+
Num processes: 4
|
| 71 |
+
Process index: 3
|
| 72 |
+
Local process index: 3
|
| 73 |
+
Device: cuda:3
|
| 74 |
+
|
| 75 |
+
Mixed precision type: no
|
| 76 |
+
|
| 77 |
+
[INFO] Tokenizers loaded. Loading text encoders, VAE, and transformer...
|
| 78 |
+
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
| 79 |
+
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
|
| 80 |
+
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
|
| 81 |
+
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
|
| 82 |
+
{'max_image_seq_len', 'shift_terminal', 'use_beta_sigmas', 'time_shift_type', 'use_dynamic_shifting', 'stochastic_sampling', 'base_shift', 'invert_sigmas', 'use_exponential_sigmas', 'use_karras_sigmas', 'max_shift', 'base_image_seq_len'} was not found in config. Values will be initialized to default values.
|
| 83 |
+
|
| 84 |
+
{'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
If your task is similar to the task the model of the checkpoint was trained on, you can already use AutoencoderKL for predictions without further training.
|
| 89 |
+
|
| 90 |
+
{'qk_norm', 'dual_attention_layers'} was not found in config. Values will be initialized to default values.
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
All model checkpoint weights were used when initializing SD3Transformer2DModel.
|
| 94 |
+
|
| 95 |
+
All the weights of SD3Transformer2DModel were initialized from the model checkpoint at /gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671.
|
| 96 |
+
If your task is similar to the task the model of the checkpoint was trained on, you can already use SD3Transformer2DModel for predictions without further training.
|
| 97 |
+
[INFO] Text encoders, VAE, and transformer loaded
|
| 98 |
+
[rank2]:[W309 10:48:36.300339912 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 99 |
+
[INFO] Optimizer created. Loading dataset...
|
| 100 |
+
[INFO] Found metadata.jsonl, using efficient loading method
|
| 101 |
+
[INFO] Loading dataset from metadata.jsonl: /gemini/space/hsd/project/dataset/cc3m-wds/train/metadata.jsonl
|
| 102 |
+
[rank1]:[W309 10:48:38.132724538 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 103 |
+
[INFO] Processed 100000 entries from metadata.jsonl
|
| 104 |
+
[rank3]:[W309 10:48:39.804821476 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 105 |
+
[INFO] Processed 200000 entries from metadata.jsonl
|
| 106 |
+
[INFO] Processed 300000 entries from metadata.jsonl
|
| 107 |
+
[INFO] Processed 400000 entries from metadata.jsonl
|
| 108 |
+
[INFO] Processed 500000 entries from metadata.jsonl
|
| 109 |
+
[INFO] Processed 600000 entries from metadata.jsonl
|
| 110 |
+
[INFO] Processed 700000 entries from metadata.jsonl
|
| 111 |
+
[INFO] Processed 800000 entries from metadata.jsonl
|
| 112 |
+
[INFO] Processed 900000 entries from metadata.jsonl
|
| 113 |
+
[INFO] Processed 1000000 entries from metadata.jsonl
|
| 114 |
+
[INFO] Processed 1100000 entries from metadata.jsonl
|
| 115 |
+
[INFO] Processed 1200000 entries from metadata.jsonl
|
| 116 |
+
[INFO] Processed 1300000 entries from metadata.jsonl
|
| 117 |
+
[INFO] Processed 1400000 entries from metadata.jsonl
|
| 118 |
+
[INFO] Processed 1500000 entries from metadata.jsonl
|
| 119 |
+
[INFO] Processed 1600000 entries from metadata.jsonl
|
| 120 |
+
[INFO] Processed 1700000 entries from metadata.jsonl
|
| 121 |
+
[INFO] Processed 1800000 entries from metadata.jsonl
|
| 122 |
+
[INFO] Processed 1900000 entries from metadata.jsonl
|
| 123 |
+
[INFO] Processed 2000000 entries from metadata.jsonl
|
| 124 |
+
[INFO] Processed 2100000 entries from metadata.jsonl
|
| 125 |
+
[INFO] Processed 2200000 entries from metadata.jsonl
|
| 126 |
+
[INFO] Processed 2300000 entries from metadata.jsonl
|
| 127 |
+
[INFO] Processed 2400000 entries from metadata.jsonl
|
| 128 |
+
[INFO] Processed 2500000 entries from metadata.jsonl
|
| 129 |
+
[INFO] Processed 2600000 entries from metadata.jsonl
|
| 130 |
+
[INFO] Processed 2700000 entries from metadata.jsonl
|
| 131 |
+
[INFO] Processed 2800000 entries from metadata.jsonl
|
| 132 |
+
[INFO] Processed 2900000 entries from metadata.jsonl
|
| 133 |
+
[INFO] Loaded 2905954 image-caption pairs from metadata.jsonl
|
| 134 |
+
[INFO] Dataset loaded successfully.
|
| 135 |
+
[rank0]:[W309 10:49:04.100729496 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 136 |
+
[INFO] All processes synchronized. Building transforms and DataLoader...
|
| 137 |
+
[INFO] Dataset columns: ['image', 'text']
|
| 138 |
+
[INFO] Using image column: image
|
| 139 |
+
[WARNING] Specified caption_column 'caption' not found. Using 'text' instead.
|
| 140 |
+
[INFO] Using caption column: text
|
| 141 |
+
03/09/2026 10:49:34 - INFO - __main__ - Auto-setting dataloader_num_workers to 4 for multi-GPU training
|
| 142 |
+
[INFO] DataLoader ready. Computing training steps and scheduler...
|
| 143 |
+
03/09/2026 10:49:36 - WARNING - __main__ - Failed to initialize trackers: Descriptors cannot be created directly.
|
| 144 |
+
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
|
| 145 |
+
If you cannot immediately regenerate your protos, some other possible workarounds are:
|
| 146 |
+
1. Downgrade the protobuf package to 3.20.x or lower.
|
| 147 |
+
2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
|
| 148 |
+
|
| 149 |
+
More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
|
| 150 |
+
03/09/2026 10:49:36 - WARNING - __main__ - Continuing without tracking. You can monitor training through console logs.
|
| 151 |
+
03/09/2026 10:49:36 - INFO - __main__ - ***** Running training *****
|
| 152 |
+
03/09/2026 10:49:36 - INFO - __main__ - Num examples = 2905954
|
| 153 |
+
03/09/2026 10:49:36 - INFO - __main__ - Num Epochs = 6
|
| 154 |
+
03/09/2026 10:49:36 - INFO - __main__ - Instantaneous batch size per device = 8
|
| 155 |
+
03/09/2026 10:49:36 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 32
|
| 156 |
+
03/09/2026 10:49:36 - INFO - __main__ - Gradient Accumulation steps = 1
|
| 157 |
+
03/09/2026 10:49:36 - INFO - __main__ - Total optimization steps = 500000
|
| 158 |
+
03/09/2026 10:49:36 - INFO - __main__ - Number of GPU processes = 4
|
| 159 |
+
03/09/2026 10:49:36 - INFO - __main__ - Effective batch size per GPU = 8
|
| 160 |
+
03/09/2026 10:49:36 - INFO - __main__ - Total effective batch size across all GPUs = 32
|
| 161 |
+
[INFO] Training setup complete. num_examples=2905954, max_train_steps=500000, num_epochs=6
|
| 162 |
+
|
| 163 |
+
[rank3]:[W309 10:49:40.986061745 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
|
| 164 |
+
[rank2]:[W309 10:49:40.986962214 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
|
| 165 |
+
[rank1]:[W309 10:49:40.988288933 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
|
| 166 |
+
[rank0]:[W309 10:49:40.989585595 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
|
| 167 |
+
|
| 168 |
+
[rank1]: File "/gemini/space/gzy_new/models/Sida/train_lora_sd3.py", line 1597, in <module>
|
| 169 |
+
[rank1]: main(args)
|
| 170 |
+
[rank1]: File "/gemini/space/gzy_new/models/Sida/train_lora_sd3.py", line 1410, in main
|
| 171 |
+
[rank1]: timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 172 |
+
[rank1]: File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler
|
| 173 |
+
[rank1]: _error_if_any_worker_fails()
|
| 174 |
+
[rank1]: RuntimeError: DataLoader worker (pid 8746) is killed by signal: Terminated.
|
| 175 |
+
W0309 10:54:03.295000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7639 closing signal SIGTERM
|
| 176 |
+
W0309 10:54:03.296000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7640 closing signal SIGTERM
|
| 177 |
+
W0309 10:54:03.296000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 7641 closing signal SIGTERM
|
| 178 |
+
E0309 10:54:03.762000 7448 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -15) local_rank: 0 (pid: 7638) of binary: /root/miniconda3/envs/SiT/bin/python3.10
|
| 179 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 180 |
+
Traceback (most recent call last):
|
| 181 |
+
File "/root/miniconda3/envs/SiT/bin/accelerate", line 6, in <module>
|
| 182 |
+
sys.exit(main())
|
| 183 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
|
| 184 |
+
args.func(args)
|
| 185 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1189, in launch_command
|
| 186 |
+
multi_gpu_launcher(args)
|
| 187 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/accelerate/commands/launch.py", line 815, in multi_gpu_launcher
|
| 188 |
+
distrib_run.run(args)
|
| 189 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
|
| 190 |
+
elastic_launch(
|
| 191 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 192 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 193 |
+
File "/root/miniconda3/envs/SiT/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
|
| 194 |
+
raise ChildFailedError(
|
| 195 |
+
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
| 196 |
+
==========================================================
|
| 197 |
+
train_lora_sd3.py FAILED
|
| 198 |
+
----------------------------------------------------------
|
| 199 |
+
Failures:
|
| 200 |
+
<NO_OTHER_FAILURES>
|
| 201 |
+
----------------------------------------------------------
|
| 202 |
+
Root Cause (first observed failure):
|
| 203 |
+
[0]:
|
| 204 |
+
time : 2026-03-09_10:54:03
|
| 205 |
+
host : 1406241bacf2123cb0cdd1395a94d0f5-taskrole1-0
|
| 206 |
+
rank : 0 (local_rank: 0)
|
| 207 |
+
exitcode : -15 (pid: 7638)
|
| 208 |
+
error_file: <N/A>
|
| 209 |
+
traceback : Signal 15 (SIGTERM) received by PID 7638
|
| 210 |
+
==========================================================
|
| 211 |
+
===========================================
|
| 212 |
+
训练完成!
|
| 213 |
+
模型保存在: sd3-lora-finetuned-batch-32
|
| 214 |
+
日志保存在: sd3-lora-finetuned-batch-32/logs
|
| 215 |
+
验证图片保存在: sd3-lora-finetuned-batch-32/validation_images
|
| 216 |
+
===========================================
|
train_sd3_lora2.sh
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# SD3 LoRA Fine-tuning Training Script
|
| 4 |
+
# 使用 Stable Diffusion 3 进行 LoRA 微调训练 - 多GPU优化版本
|
| 5 |
+
|
| 6 |
+
# 设置环境变量
|
| 7 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 根据可用GPU数量调整
|
| 8 |
+
# export PYTHONPATH=$PYTHONPATH:/gemini/space/gzy_new/Rectified_Noise/Finetune/finetune-coco
|
| 9 |
+
|
| 10 |
+
# 检查GPU数量
|
| 11 |
+
num_gpus=$(nvidia-smi --list-gpus | wc -l)
|
| 12 |
+
echo "检测到 $num_gpus 个GPU"
|
| 13 |
+
|
| 14 |
+
# 训练参数配置
|
| 15 |
+
MODEL_NAME="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 16 |
+
#DATASET_NAME="lambdalabs/naruto-blip-captions" # 或者使用本地数据集路径
|
| 17 |
+
TRAIN_DATA_DIR="/gemini/space/hsd/project/dataset/cc3m-wds/train" # 本地数据集路径0
|
| 18 |
+
OUTPUT_DIR="sd3-lora-finetuned-batch-32"
|
| 19 |
+
RESOLUTION=512
|
| 20 |
+
# 多GPU训练时调整批次大小 - 每个GPU的批次大小
|
| 21 |
+
TRAIN_BATCH_SIZE=8 # 每个GPU的批次大小,总批次大小 = TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS
|
| 22 |
+
GRADIENT_ACCUMULATION_STEPS=1 # 梯度累积步数
|
| 23 |
+
LEARNING_RATE=1e-5
|
| 24 |
+
MAX_TRAIN_STEPS=500000
|
| 25 |
+
NUM_TRAIN_EPOCHS=50
|
| 26 |
+
VALIDATION_PROMPT="A photo of a beautiful landscape with mountains and lake"
|
| 27 |
+
NUM_VALIDATION_IMAGES=2
|
| 28 |
+
VALIDATION_EPOCHS=1 # 每10个epoch验证一次
|
| 29 |
+
LORA_RANK=32
|
| 30 |
+
SEED=42
|
| 31 |
+
|
| 32 |
+
# 计算有效批次大小
|
| 33 |
+
effective_batch_size=$((TRAIN_BATCH_SIZE * num_gpus * GRADIENT_ACCUMULATION_STEPS))
|
| 34 |
+
echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
|
| 35 |
+
echo "总有效批次大小: $effective_batch_size"
|
| 36 |
+
|
| 37 |
+
# 创建输出目录
|
| 38 |
+
mkdir -p $OUTPUT_DIR
|
| 39 |
+
|
| 40 |
+
echo "===== SD3 LoRA 多GPU训练开始 ====="
|
| 41 |
+
echo "模型: $MODEL_NAME"
|
| 42 |
+
echo "输出目录: $OUTPUT_DIR"
|
| 43 |
+
echo "分辨率: $RESOLUTION"
|
| 44 |
+
echo "每个GPU批次大小: $TRAIN_BATCH_SIZE"
|
| 45 |
+
echo "梯度累积步数: $GRADIENT_ACCUMULATION_STEPS"
|
| 46 |
+
echo "总有效批次大小: $effective_batch_size"
|
| 47 |
+
echo "学习率: $LEARNING_RATE"
|
| 48 |
+
echo "最大训练步数: $MAX_TRAIN_STEPS"
|
| 49 |
+
echo "LoRA Rank: $LORA_RANK"
|
| 50 |
+
echo "使用GPU: $CUDA_VISIBLE_DEVICES"
|
| 51 |
+
echo "==========================================="
|
| 52 |
+
|
| 53 |
+
# 检查accelerate配置
|
| 54 |
+
if [ ! -f "accelerate_config.yaml" ]; then
|
| 55 |
+
echo "错误: 未找到 accelerate_config.yaml 文件"
|
| 56 |
+
echo "请运行: accelerate config 来配置多GPU训练"
|
| 57 |
+
exit 1
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
echo "使用 accelerate 启动多GPU训练..."
|
| 61 |
+
|
| 62 |
+
# 使用 accelerate 启动训练
|
| 63 |
+
accelerate launch --config_file=accelerate_config.yaml train_lora_sd3.py \
|
| 64 |
+
--pretrained_model_name_or_path="$MODEL_NAME" \
|
| 65 |
+
--train_data_dir="$TRAIN_DATA_DIR" \
|
| 66 |
+
--output_dir="$OUTPUT_DIR" \
|
| 67 |
+
--mixed_precision="no" \
|
| 68 |
+
--resolution=$RESOLUTION \
|
| 69 |
+
--train_batch_size=$TRAIN_BATCH_SIZE \
|
| 70 |
+
--gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
|
| 71 |
+
--learning_rate=$LEARNING_RATE \
|
| 72 |
+
--scale_lr \
|
| 73 |
+
--lr_scheduler="cosine" \
|
| 74 |
+
--lr_warmup_steps=100 \
|
| 75 |
+
--max_train_steps=$MAX_TRAIN_STEPS \
|
| 76 |
+
--num_train_epochs=$NUM_TRAIN_EPOCHS \
|
| 77 |
+
--validation_prompt="$VALIDATION_PROMPT" \
|
| 78 |
+
--num_validation_images=$NUM_VALIDATION_IMAGES \
|
| 79 |
+
--validation_epochs=$VALIDATION_EPOCHS \
|
| 80 |
+
--checkpointing_steps=50000 \
|
| 81 |
+
--checkpoints_total_limit=10 \
|
| 82 |
+
--seed=$SEED \
|
| 83 |
+
--rank=$LORA_RANK \
|
| 84 |
+
--gradient_checkpointing \
|
| 85 |
+
--use_8bit_adam \
|
| 86 |
+
--dataloader_num_workers=0 \
|
| 87 |
+
--report_to="tensorboard" \
|
| 88 |
+
--logging_dir="logs" \
|
| 89 |
+
--adam_beta1=0.9 \
|
| 90 |
+
--adam_beta2=0.999 \
|
| 91 |
+
--adam_weight_decay=1e-2 \
|
| 92 |
+
--adam_epsilon=1e-8 \
|
| 93 |
+
--max_grad_norm=1.0 \
|
| 94 |
+
--allow_tf32 \
|
| 95 |
+
--weighting_scheme="logit_normal" \
|
| 96 |
+
--logit_mean=0.0 \
|
| 97 |
+
--logit_std=1.0 \
|
| 98 |
+
--precondition_outputs=1
|
| 99 |
+
|
| 100 |
+
echo "==========================================="
|
| 101 |
+
echo "训练完成!"
|
| 102 |
+
echo "模型保存在: $OUTPUT_DIR"
|
| 103 |
+
echo "日志保存在: $OUTPUT_DIR/logs"
|
| 104 |
+
echo "验证图片保存在: $OUTPUT_DIR/validation_images"
|
| 105 |
+
echo "==========================================="
|
| 106 |
+
|
| 107 |
+
# nohup bash train_sd3_lora2.sh > train_sd3_lora2.log 2>&1 &
|
visual.sh
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
# 分两次主模型运行,最后拼图,避免同时加载双模型导致 OOM
|
| 5 |
+
|
| 6 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-/root/miniconda3/envs/SiT/bin/python}"
|
| 8 |
+
|
| 9 |
+
PRETRAINED_MODEL="/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671"
|
| 10 |
+
LOCAL_PIPELINE_PATH="/gemini/space/gzy_new/models/Sida/pipeline_stable_diffusion_3.py"
|
| 11 |
+
LORA_PATH="/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000"
|
| 12 |
+
RECTIFIED_WEIGHTS="/gemini/space/gzy_new/models/Sida/rectified-noise-batch-2/checkpoint-220000/sit_weights"
|
| 13 |
+
|
| 14 |
+
# 可按需修改为你想看的文本
|
| 15 |
+
PROMPT="young man beside a dog wearing sunglasses"
|
| 16 |
+
|
| 17 |
+
OUTPUT_DIR="/gemini/space/gzy_new/models/Sida/sd3_lora_rn_pair_samples"
|
| 18 |
+
OUTPUT_FILE="${OUTPUT_DIR}/lora_rn_4x8_step180.png"
|
| 19 |
+
LORA_NPZ="${OUTPUT_DIR}/lora_trace_4x8.npz"
|
| 20 |
+
RN_NPZ="${OUTPUT_DIR}/rn_trace_4x8.npz"
|
| 21 |
+
|
| 22 |
+
STEPS=180
|
| 23 |
+
GUIDANCE_SCALE=7.0
|
| 24 |
+
HEIGHT=512
|
| 25 |
+
WIDTH=512
|
| 26 |
+
SEED=42
|
| 27 |
+
MIXED_PRECISION="fp16" # no / fp16 / bf16
|
| 28 |
+
NUM_SIT_LAYERS=1 # 需与训练一致
|
| 29 |
+
|
| 30 |
+
mkdir -p "$OUTPUT_DIR"
|
| 31 |
+
|
| 32 |
+
if [ ! -x "$PYTHON_BIN" ]; then
|
| 33 |
+
echo "ERROR: PYTHON_BIN not executable: $PYTHON_BIN"
|
| 34 |
+
echo "Hint: export PYTHON_BIN=/path/to/your/python"
|
| 35 |
+
exit 1
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
if [ ! -e "$PRETRAINED_MODEL" ]; then
|
| 39 |
+
echo "ERROR: PRETRAINED_MODEL not found: $PRETRAINED_MODEL"
|
| 40 |
+
exit 1
|
| 41 |
+
fi
|
| 42 |
+
if [ ! -f "$LOCAL_PIPELINE_PATH" ]; then
|
| 43 |
+
echo "ERROR: LOCAL_PIPELINE_PATH not found: $LOCAL_PIPELINE_PATH"
|
| 44 |
+
exit 1
|
| 45 |
+
fi
|
| 46 |
+
if [ ! -e "$LORA_PATH" ]; then
|
| 47 |
+
echo "ERROR: LORA_PATH not found: $LORA_PATH"
|
| 48 |
+
exit 1
|
| 49 |
+
fi
|
| 50 |
+
if [ ! -e "$RECTIFIED_WEIGHTS" ]; then
|
| 51 |
+
echo "ERROR: RECTIFIED_WEIGHTS not found: $RECTIFIED_WEIGHTS"
|
| 52 |
+
exit 1
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
COMMON_ARGS=(
|
| 56 |
+
--pretrained_model_name_or_path "$PRETRAINED_MODEL"
|
| 57 |
+
--local_pipeline_path "$LOCAL_PIPELINE_PATH"
|
| 58 |
+
--lora_path "$LORA_PATH"
|
| 59 |
+
--rectified_weights "$RECTIFIED_WEIGHTS"
|
| 60 |
+
--num_sit_layers "$NUM_SIT_LAYERS"
|
| 61 |
+
--prompt "$PROMPT"
|
| 62 |
+
--output "$OUTPUT_FILE"
|
| 63 |
+
--lora_npz "$LORA_NPZ"
|
| 64 |
+
--rn_npz "$RN_NPZ"
|
| 65 |
+
--steps "$STEPS"
|
| 66 |
+
--guidance_scale "$GUIDANCE_SCALE"
|
| 67 |
+
--height "$HEIGHT"
|
| 68 |
+
--width "$WIDTH"
|
| 69 |
+
--seed "$SEED"
|
| 70 |
+
--mixed_precision "$MIXED_PRECISION"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
"$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage lora
|
| 74 |
+
"$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage rn
|
| 75 |
+
"$PYTHON_BIN" /gemini/space/gzy_new/models/Sida/visualize_lora_rn_4x8.py "${COMMON_ARGS[@]}" --stage pair
|
| 76 |
+
|
| 77 |
+
echo "Done. Saved to: $OUTPUT_FILE"
|
| 78 |
+
# nohup bash run_sd3_rectified_sampling.sh > run_sd3_rectified_sampling.log 2>&1 &
|
visualize_lora_rn_4x8.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
"""
|
| 4 |
+
分阶段生成 4x8 对比图(默认 180 步):
|
| 5 |
+
- stage=lora: 仅生成 LoRA 轨迹中间结果并保存中间 npz
|
| 6 |
+
- stage=rn: 仅生成 RN 轨迹中间结果并保存中间 npz
|
| 7 |
+
- stage=pair: 读取两阶段 npz,计算第3/4行并拼接总图
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import importlib.util
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
from diffusers import AutoencoderKL
|
| 19 |
+
from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def dynamic_import_training_classes(project_root: str):
|
| 23 |
+
import sys
|
| 24 |
+
sys.path.insert(0, project_root)
|
| 25 |
+
import train_rectified_noise as trn
|
| 26 |
+
return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_local_pipeline_class(local_pipeline_path: str):
|
| 30 |
+
module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
|
| 31 |
+
spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
|
| 32 |
+
if spec is None or spec.loader is None:
|
| 33 |
+
raise ImportError(f"Failed to import local pipeline: {local_pipeline_path}")
|
| 34 |
+
module = importlib.util.module_from_spec(spec)
|
| 35 |
+
spec.loader.exec_module(module)
|
| 36 |
+
return module.StableDiffusion3Pipeline
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_sit_weights(rectified_module, weights_path: str):
|
| 40 |
+
p = Path(weights_path)
|
| 41 |
+
if p.is_dir():
|
| 42 |
+
search_dirs = [p, p / "sit_weights"]
|
| 43 |
+
for d in search_dirs:
|
| 44 |
+
if not d.exists():
|
| 45 |
+
continue
|
| 46 |
+
st = d / "pytorch_sit_weights.safetensors"
|
| 47 |
+
if st.exists():
|
| 48 |
+
from safetensors.torch import load_file
|
| 49 |
+
rectified_module.load_state_dict(load_file(str(st)), strict=False)
|
| 50 |
+
return True
|
| 51 |
+
for name in ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]:
|
| 52 |
+
ck = d / name
|
| 53 |
+
if ck.exists():
|
| 54 |
+
rectified_module.load_state_dict(torch.load(str(ck), map_location="cpu"), strict=False)
|
| 55 |
+
return True
|
| 56 |
+
return False
|
| 57 |
+
if str(p).endswith(".safetensors"):
|
| 58 |
+
from safetensors.torch import load_file
|
| 59 |
+
state = load_file(str(p))
|
| 60 |
+
else:
|
| 61 |
+
state = torch.load(str(p), map_location="cpu")
|
| 62 |
+
rectified_module.load_state_dict(state, strict=False)
|
| 63 |
+
return True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
|
| 67 |
+
RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
|
| 68 |
+
tfm = base_pipeline.transformer
|
| 69 |
+
sit_hidden_size = getattr(tfm.config, "joint_attention_dim", None) or getattr(tfm.config, "inner_dim", 4096)
|
| 70 |
+
rectified_module = RectifiedNoiseModule(
|
| 71 |
+
hidden_size=sit_hidden_size,
|
| 72 |
+
num_sit_layers=num_sit_layers,
|
| 73 |
+
num_attention_heads=getattr(tfm.config, "num_attention_heads", 32),
|
| 74 |
+
input_dim=getattr(tfm.config, "in_channels", 16),
|
| 75 |
+
transformer_hidden_size=getattr(tfm.config, "hidden_size", 1536),
|
| 76 |
+
)
|
| 77 |
+
if not load_sit_weights(rectified_module, rectified_weights):
|
| 78 |
+
raise RuntimeError(f"Failed to load rectified weights: {rectified_weights}")
|
| 79 |
+
model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
|
| 80 |
+
model.eval()
|
| 81 |
+
return model
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def set_pipeline_modules_eval(pipe):
|
| 85 |
+
for name in ["transformer", "vae", "text_encoder", "text_encoder_2", "text_encoder_3", "model"]:
|
| 86 |
+
m = getattr(pipe, name, None)
|
| 87 |
+
if m is not None and hasattr(m, "eval"):
|
| 88 |
+
m.eval()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def align_rn_branch_dtype(pipe, dtype):
|
| 92 |
+
model = getattr(pipe, "model", None)
|
| 93 |
+
if model is None:
|
| 94 |
+
return
|
| 95 |
+
rn_branch = getattr(model, "rectified_noise_module", None)
|
| 96 |
+
if rn_branch is not None:
|
| 97 |
+
rn_branch.to(device=pipe._execution_device, dtype=dtype)
|
| 98 |
+
if hasattr(model, "to"):
|
| 99 |
+
model.to(device=pipe._execution_device)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def decode_latent_to_uint8(vae, latents, normalize=False):
|
| 104 |
+
shift = getattr(vae.config, "shift_factor", 0.0) or 0.0
|
| 105 |
+
scaled = (latents / vae.config.scaling_factor) + shift
|
| 106 |
+
image = vae.decode(scaled, return_dict=False)[0]
|
| 107 |
+
x = image[0].float()
|
| 108 |
+
if normalize:
|
| 109 |
+
x = (x - x.min()) / (x.max() - x.min() + 1e-6)
|
| 110 |
+
else:
|
| 111 |
+
x = (x / 2.0) + 0.5
|
| 112 |
+
x = x.clamp(0, 1)
|
| 113 |
+
return (x.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def encode_prompt_for_pipe(pipe, prompt, guidance_scale):
|
| 117 |
+
do_cfg = guidance_scale > 1.0
|
| 118 |
+
pe, npe, pp, npp = pipe.encode_prompt(
|
| 119 |
+
prompt=prompt,
|
| 120 |
+
prompt_2=prompt,
|
| 121 |
+
prompt_3=prompt,
|
| 122 |
+
do_classifier_free_guidance=do_cfg,
|
| 123 |
+
num_images_per_prompt=1,
|
| 124 |
+
device=pipe._execution_device,
|
| 125 |
+
)
|
| 126 |
+
if do_cfg:
|
| 127 |
+
pe = torch.cat([npe, pe], dim=0)
|
| 128 |
+
pp = torch.cat([npp, pp], dim=0)
|
| 129 |
+
return pe, pp
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
|
| 133 |
+
def run_single_trajectory(pipe, prompt, steps, guidance_scale, height, width, seed, use_rn_model=False, autocast_dtype=None):
|
| 134 |
+
device = pipe._execution_device
|
| 135 |
+
dtype = next(pipe.transformer.parameters()).dtype
|
| 136 |
+
sample_idx = np.linspace(0, steps - 1, 8, dtype=int).tolist()
|
| 137 |
+
|
| 138 |
+
latent_h = height // pipe.vae_scale_factor
|
| 139 |
+
latent_w = width // pipe.vae_scale_factor
|
| 140 |
+
g = torch.Generator(device=device).manual_seed(seed)
|
| 141 |
+
init_latents = torch.randn(
|
| 142 |
+
(1, pipe.transformer.config.in_channels, latent_h, latent_w),
|
| 143 |
+
device=device,
|
| 144 |
+
dtype=dtype,
|
| 145 |
+
generator=g,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
if use_rn_model:
|
| 149 |
+
effective_model = getattr(pipe, "model", None)
|
| 150 |
+
if effective_model is not None:
|
| 151 |
+
pipe.model = effective_model
|
| 152 |
+
|
| 153 |
+
step_latents = {}
|
| 154 |
+
|
| 155 |
+
def _capture_callback(_pipe, i, _t, callback_kwargs):
|
| 156 |
+
if i in sample_idx:
|
| 157 |
+
step_latents[i] = callback_kwargs["latents"].detach().clone()
|
| 158 |
+
return callback_kwargs
|
| 159 |
+
|
| 160 |
+
# 跟参考脚本一致,主路径走 pipeline(...),通过 callback 抓中间 latent
|
| 161 |
+
if autocast_dtype is None:
|
| 162 |
+
_ = pipe(
|
| 163 |
+
prompt=prompt,
|
| 164 |
+
height=height,
|
| 165 |
+
width=width,
|
| 166 |
+
num_inference_steps=steps,
|
| 167 |
+
guidance_scale=guidance_scale,
|
| 168 |
+
latents=init_latents,
|
| 169 |
+
num_images_per_prompt=1,
|
| 170 |
+
callback_on_step_end=_capture_callback,
|
| 171 |
+
callback_on_step_end_tensor_inputs=["latents"],
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
with torch.autocast("cuda", dtype=autocast_dtype):
|
| 175 |
+
_ = pipe(
|
| 176 |
+
prompt=prompt,
|
| 177 |
+
height=height,
|
| 178 |
+
width=width,
|
| 179 |
+
num_inference_steps=steps,
|
| 180 |
+
guidance_scale=guidance_scale,
|
| 181 |
+
latents=init_latents,
|
| 182 |
+
num_images_per_prompt=1,
|
| 183 |
+
callback_on_step_end=_capture_callback,
|
| 184 |
+
callback_on_step_end_tensor_inputs=["latents"],
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
images = []
|
| 188 |
+
latents = []
|
| 189 |
+
noises = []
|
| 190 |
+
prev_lat = init_latents
|
| 191 |
+
pe = pp = None
|
| 192 |
+
do_cfg = guidance_scale > 1.0
|
| 193 |
+
timesteps = None
|
| 194 |
+
if use_rn_model:
|
| 195 |
+
pe, pp = encode_prompt_for_pipe(pipe, prompt, guidance_scale)
|
| 196 |
+
pipe.scheduler.set_timesteps(steps, device=device)
|
| 197 |
+
timesteps = pipe.scheduler.timesteps
|
| 198 |
+
|
| 199 |
+
for i in sample_idx:
|
| 200 |
+
cur_lat = step_latents[i]
|
| 201 |
+
images.append(decode_latent_to_uint8(pipe.vae, cur_lat, normalize=False))
|
| 202 |
+
latents.append(cur_lat.squeeze(0).float().cpu().numpy())
|
| 203 |
+
if use_rn_model:
|
| 204 |
+
# 第3行严格使用 RN 噪声分支(速度场)输出
|
| 205 |
+
model_in = torch.cat([cur_lat] * 2) if do_cfg else cur_lat
|
| 206 |
+
ts = timesteps[i].expand(model_in.shape[0])
|
| 207 |
+
if autocast_dtype is not None and device == "cuda":
|
| 208 |
+
with torch.autocast("cuda", dtype=autocast_dtype):
|
| 209 |
+
rn_out = pipe.model(
|
| 210 |
+
hidden_states=model_in,
|
| 211 |
+
timestep=ts,
|
| 212 |
+
encoder_hidden_states=pe,
|
| 213 |
+
pooled_projections=pp,
|
| 214 |
+
return_dict=False,
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
rn_out = pipe.model(
|
| 218 |
+
hidden_states=model_in,
|
| 219 |
+
timestep=ts,
|
| 220 |
+
encoder_hidden_states=pe,
|
| 221 |
+
pooled_projections=pp,
|
| 222 |
+
return_dict=False,
|
| 223 |
+
)
|
| 224 |
+
# SD3WithRectifiedNoise: (final_output, mean_out, var_out)
|
| 225 |
+
rn_branch = rn_out[1] if isinstance(rn_out, tuple) and len(rn_out) > 1 else rn_out[0]
|
| 226 |
+
if do_cfg:
|
| 227 |
+
ru, rt = rn_branch.chunk(2)
|
| 228 |
+
rn_branch = ru + guidance_scale * (rt - ru)
|
| 229 |
+
noises.append(rn_branch.squeeze(0).float().cpu().numpy())
|
| 230 |
+
else:
|
| 231 |
+
# lora 阶段保留占位,pair 阶段不会用到
|
| 232 |
+
delta = (cur_lat - prev_lat)
|
| 233 |
+
noises.append(delta.squeeze(0).float().cpu().numpy())
|
| 234 |
+
prev_lat = cur_lat
|
| 235 |
+
|
| 236 |
+
return {
|
| 237 |
+
"images": np.stack(images, axis=0),
|
| 238 |
+
"latents": np.stack(latents, axis=0),
|
| 239 |
+
"noises": np.stack(noises, axis=0),
|
| 240 |
+
"sample_idx": np.array(sample_idx, dtype=np.int32),
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def save_grid_4x8(rows, sample_idx, out_path, cell_w=512, cell_h=512):
|
| 245 |
+
cols = 8
|
| 246 |
+
grid = Image.new("RGB", (cols * cell_w, 4 * cell_h), color=(245, 245, 245))
|
| 247 |
+
for r in range(4):
|
| 248 |
+
for c in range(cols):
|
| 249 |
+
img = Image.fromarray(rows[r][c]).resize((cell_w, cell_h), Image.BILINEAR)
|
| 250 |
+
x = c * cell_w
|
| 251 |
+
y = r * cell_h
|
| 252 |
+
grid.paste(img, (x, y))
|
| 253 |
+
grid.save(out_path)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def stage_lora(args, dtype, device):
|
| 257 |
+
pipe = DiffusersStableDiffusion3Pipeline.from_pretrained(
|
| 258 |
+
args.pretrained_model_name_or_path,
|
| 259 |
+
revision=args.revision,
|
| 260 |
+
variant=args.variant,
|
| 261 |
+
torch_dtype=dtype,
|
| 262 |
+
).to(device)
|
| 263 |
+
pipe.load_lora_weights(args.lora_path)
|
| 264 |
+
set_pipeline_modules_eval(pipe)
|
| 265 |
+
pipe.set_progress_bar_config(disable=True)
|
| 266 |
+
data = run_single_trajectory(
|
| 267 |
+
pipe=pipe,
|
| 268 |
+
prompt=args.prompt,
|
| 269 |
+
steps=args.steps,
|
| 270 |
+
guidance_scale=args.guidance_scale,
|
| 271 |
+
height=args.height,
|
| 272 |
+
width=args.width,
|
| 273 |
+
seed=args.seed,
|
| 274 |
+
use_rn_model=False,
|
| 275 |
+
autocast_dtype=dtype if args.mixed_precision != "no" and device == "cuda" else None,
|
| 276 |
+
)
|
| 277 |
+
np.savez_compressed(args.lora_npz, **data)
|
| 278 |
+
print(f"[lora] saved: {args.lora_npz}")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def stage_rn(args, dtype, device):
|
| 282 |
+
LocalPipe = load_local_pipeline_class(args.local_pipeline_path)
|
| 283 |
+
pipe = LocalPipe.from_pretrained(
|
| 284 |
+
args.pretrained_model_name_or_path,
|
| 285 |
+
revision=args.revision,
|
| 286 |
+
variant=args.variant,
|
| 287 |
+
torch_dtype=dtype,
|
| 288 |
+
).to(device)
|
| 289 |
+
pipe.load_lora_weights(args.lora_path)
|
| 290 |
+
pipe.model = build_rn_model(pipe, args.rectified_weights, args.num_sit_layers, device)
|
| 291 |
+
# 避免 RN 速度场额外前向时出现 Half/Float 冲突
|
| 292 |
+
align_rn_branch_dtype(pipe, dtype)
|
| 293 |
+
set_pipeline_modules_eval(pipe)
|
| 294 |
+
pipe.set_progress_bar_config(disable=True)
|
| 295 |
+
data = run_single_trajectory(
|
| 296 |
+
pipe=pipe,
|
| 297 |
+
prompt=args.prompt,
|
| 298 |
+
steps=args.steps,
|
| 299 |
+
guidance_scale=args.guidance_scale,
|
| 300 |
+
height=args.height,
|
| 301 |
+
width=args.width,
|
| 302 |
+
seed=args.seed,
|
| 303 |
+
use_rn_model=True,
|
| 304 |
+
autocast_dtype=dtype if args.mixed_precision != "no" and device == "cuda" else None,
|
| 305 |
+
)
|
| 306 |
+
np.savez_compressed(args.rn_npz, **data)
|
| 307 |
+
print(f"[rn] saved: {args.rn_npz}")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def stage_pair(args, dtype, device):
|
| 311 |
+
lora = np.load(args.lora_npz)
|
| 312 |
+
rn = np.load(args.rn_npz)
|
| 313 |
+
lora_images = lora["images"]
|
| 314 |
+
rn_images = rn["images"]
|
| 315 |
+
sample_idx = lora["sample_idx"]
|
| 316 |
+
rn_noises = rn["noises"] # RN 速度场,shape: [8, C, H, W]
|
| 317 |
+
|
| 318 |
+
def _velocity_to_sparse_points(vel_chw, out_h, out_w, q=99.6, point_color=(245, 245, 245)):
|
| 319 |
+
# 通道聚合成强度图,再转成黑底稀疏点图
|
| 320 |
+
mag = np.sqrt(np.sum(np.square(vel_chw.astype(np.float32)), axis=0)) # [H, W]
|
| 321 |
+
thr = np.percentile(mag, q)
|
| 322 |
+
mask = mag >= thr
|
| 323 |
+
# 最近邻放大到输出分辨率
|
| 324 |
+
sy = max(1, out_h // mask.shape[0])
|
| 325 |
+
sx = max(1, out_w // mask.shape[1])
|
| 326 |
+
up = np.repeat(np.repeat(mask, sy, axis=0), sx, axis=1)
|
| 327 |
+
up = up[:out_h, :out_w]
|
| 328 |
+
canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
|
| 329 |
+
canvas[up] = np.array(point_color, dtype=np.uint8)
|
| 330 |
+
return canvas, mag
|
| 331 |
+
|
| 332 |
+
# 第3行:黑底 + 稀疏速度点(接近你最初图风格)
|
| 333 |
+
step_noise_imgs = []
|
| 334 |
+
# 第4行:速度场累积后再做稀疏点图
|
| 335 |
+
sum_noise_imgs = []
|
| 336 |
+
running_mag = None
|
| 337 |
+
for i in range(8):
|
| 338 |
+
step_vis, mag = _velocity_to_sparse_points(rn_noises[i], args.height, args.width, q=99.6)
|
| 339 |
+
if running_mag is None:
|
| 340 |
+
running_mag = mag
|
| 341 |
+
else:
|
| 342 |
+
running_mag = running_mag + mag
|
| 343 |
+
thr_sum = np.percentile(running_mag, 99.2)
|
| 344 |
+
mask_sum = running_mag >= thr_sum
|
| 345 |
+
sy = max(1, args.height // mask_sum.shape[0])
|
| 346 |
+
sx = max(1, args.width // mask_sum.shape[1])
|
| 347 |
+
up_sum = np.repeat(np.repeat(mask_sum, sy, axis=0), sx, axis=1)[:args.height, :args.width]
|
| 348 |
+
sum_vis = np.zeros((args.height, args.width, 3), dtype=np.uint8)
|
| 349 |
+
sum_vis[up_sum] = np.array([245, 245, 245], dtype=np.uint8)
|
| 350 |
+
|
| 351 |
+
step_noise_imgs.append(step_vis)
|
| 352 |
+
sum_noise_imgs.append(sum_vis)
|
| 353 |
+
|
| 354 |
+
rows = [
|
| 355 |
+
[lora_images[i] for i in range(8)],
|
| 356 |
+
[rn_images[i] for i in range(8)],
|
| 357 |
+
step_noise_imgs,
|
| 358 |
+
sum_noise_imgs,
|
| 359 |
+
]
|
| 360 |
+
save_grid_4x8(rows, sample_idx, args.output, cell_w=args.width, cell_h=args.height)
|
| 361 |
+
print(f"[pair] saved: {args.output}")
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def main():
|
| 365 |
+
parser = argparse.ArgumentParser()
|
| 366 |
+
parser.add_argument("--stage", type=str, default="all", choices=["all", "lora", "rn", "pair"])
|
| 367 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
|
| 368 |
+
parser.add_argument("--revision", type=str, default=None)
|
| 369 |
+
parser.add_argument("--variant", type=str, default=None)
|
| 370 |
+
parser.add_argument("--local_pipeline_path", type=str, required=True)
|
| 371 |
+
parser.add_argument("--lora_path", type=str, required=True)
|
| 372 |
+
parser.add_argument("--rectified_weights", type=str, required=True)
|
| 373 |
+
parser.add_argument("--num_sit_layers", type=int, default=1)
|
| 374 |
+
parser.add_argument("--prompt", type=str, required=True)
|
| 375 |
+
parser.add_argument("--output", type=str, default="lora_rn_4x8.png")
|
| 376 |
+
parser.add_argument("--lora_npz", type=str, default="lora_trace_4x8.npz")
|
| 377 |
+
parser.add_argument("--rn_npz", type=str, default="rn_trace_4x8.npz")
|
| 378 |
+
parser.add_argument("--steps", type=int, default=180)
|
| 379 |
+
parser.add_argument("--guidance_scale", type=float, default=7.0)
|
| 380 |
+
parser.add_argument("--height", type=int, default=512)
|
| 381 |
+
parser.add_argument("--width", type=int, default=512)
|
| 382 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 383 |
+
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
| 384 |
+
args = parser.parse_args()
|
| 385 |
+
|
| 386 |
+
dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)
|
| 387 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 388 |
+
|
| 389 |
+
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
|
| 390 |
+
Path(args.lora_npz).parent.mkdir(parents=True, exist_ok=True)
|
| 391 |
+
Path(args.rn_npz).parent.mkdir(parents=True, exist_ok=True)
|
| 392 |
+
|
| 393 |
+
if args.stage in ("all", "lora"):
|
| 394 |
+
stage_lora(args, dtype, device)
|
| 395 |
+
if device == "cuda":
|
| 396 |
+
torch.cuda.empty_cache()
|
| 397 |
+
if args.stage in ("all", "rn"):
|
| 398 |
+
stage_rn(args, dtype, device)
|
| 399 |
+
if device == "cuda":
|
| 400 |
+
torch.cuda.empty_cache()
|
| 401 |
+
if args.stage in ("all", "pair"):
|
| 402 |
+
stage_pair(args, dtype, device)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
if __name__ == "__main__":
|
| 406 |
+
main()
|
visualize_sitf2_noise_evolution.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import imageio
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from models import SiTF1, SiTF2, SiT, CombinedModel
|
| 7 |
+
from download import find_model
|
| 8 |
+
from diffusers.models import AutoencoderKL
|
| 9 |
+
|
| 10 |
+
def tensor_to_img(tensor):
|
| 11 |
+
arr = tensor.detach().cpu().numpy()
|
| 12 |
+
if arr.ndim == 3:
|
| 13 |
+
arr = arr[0]
|
| 14 |
+
arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255
|
| 15 |
+
return arr.astype(np.uint8)
|
| 16 |
+
|
| 17 |
+
def main(
|
| 18 |
+
sit_ckpt, sitf2_ckpt,
|
| 19 |
+
image_size=256,
|
| 20 |
+
patch_size=2,
|
| 21 |
+
hidden_size=1152,
|
| 22 |
+
out_channels=8,
|
| 23 |
+
steps=50,
|
| 24 |
+
gif_path='sitf2_noise_evolution.gif',
|
| 25 |
+
device='cuda'
|
| 26 |
+
):
|
| 27 |
+
dist.init_process_group("nccl")
|
| 28 |
+
rank = dist.get_rank()
|
| 29 |
+
device = rank % torch.cuda.device_count()
|
| 30 |
+
latent_size = image_size // 8
|
| 31 |
+
sitf1 = SiTF1(
|
| 32 |
+
input_size=latent_size,
|
| 33 |
+
patch_size=2,
|
| 34 |
+
in_channels=4,
|
| 35 |
+
hidden_size=768,
|
| 36 |
+
depth=12,
|
| 37 |
+
num_heads=12,
|
| 38 |
+
mlp_ratio=4.0,
|
| 39 |
+
class_dropout_prob=0.1,
|
| 40 |
+
num_classes=3,
|
| 41 |
+
learn_sigma=False
|
| 42 |
+
).to(device)
|
| 43 |
+
sitf1_state = find_model(sit_ckpt)
|
| 44 |
+
try:
|
| 45 |
+
sitf1.load_state_dict(sitf1_state["model"], strict=False)
|
| 46 |
+
except Exception:
|
| 47 |
+
sitf1.load_state_dict(sitf1_state, strict=False)
|
| 48 |
+
sitf1.eval()
|
| 49 |
+
|
| 50 |
+
sitf2 = SiTF2(
|
| 51 |
+
hidden_size=768,
|
| 52 |
+
out_channels=8,
|
| 53 |
+
patch_size=2,
|
| 54 |
+
num_heads=12,
|
| 55 |
+
mlp_ratio=4.0,
|
| 56 |
+
depth=2,
|
| 57 |
+
learn_sigma=False,
|
| 58 |
+
learn_mu=True
|
| 59 |
+
).to(device)
|
| 60 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 61 |
+
sitf2 = DDP(sitf2, device_ids=[device])
|
| 62 |
+
sitf2_state = find_model(args.sitf2_ckpt)
|
| 63 |
+
try:
|
| 64 |
+
sitf2.load_state_dict(sitf2_state["ema"])
|
| 65 |
+
except Exception:
|
| 66 |
+
sitf2.load_state_dict(sitf2_state)
|
| 67 |
+
sitf2.eval()
|
| 68 |
+
|
| 69 |
+
batch = 1
|
| 70 |
+
x = torch.randn(batch, 4, latent_size, latent_size, device=device)
|
| 71 |
+
x0= x
|
| 72 |
+
y = torch.zeros(batch, dtype=torch.long, device=device)
|
| 73 |
+
t = torch.ones(batch, device=device)
|
| 74 |
+
|
| 75 |
+
imgs = []
|
| 76 |
+
imgs1 = []
|
| 77 |
+
imgs2 = []
|
| 78 |
+
img_original=[]
|
| 79 |
+
sit = SiT(
|
| 80 |
+
input_size=latent_size,
|
| 81 |
+
patch_size=2,
|
| 82 |
+
in_channels=4,
|
| 83 |
+
hidden_size=768,
|
| 84 |
+
depth=12,
|
| 85 |
+
num_heads=12,
|
| 86 |
+
mlp_ratio=4.0,
|
| 87 |
+
class_dropout_prob=0.1,
|
| 88 |
+
num_classes=3,
|
| 89 |
+
learn_sigma=False
|
| 90 |
+
).to(device)
|
| 91 |
+
try:
|
| 92 |
+
sit.load_state_dict(sitf1_state["model"])
|
| 93 |
+
except Exception:
|
| 94 |
+
sit.load_state_dict(sitf1_state)
|
| 95 |
+
sit.eval()
|
| 96 |
+
combined_model = CombinedModel(sitf1, sitf2).to(device)
|
| 97 |
+
combined_model.eval()
|
| 98 |
+
|
| 99 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
|
| 100 |
+
|
| 101 |
+
for step in tqdm(range(steps)):
|
| 102 |
+
t = torch.full((batch,), step / (steps - 1), device=device)
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
patch_tokens = sitf1(x, t, y)
|
| 105 |
+
t_emb = sitf1.t_embedder(t)
|
| 106 |
+
y_emb = sitf1.y_embedder(y, False)
|
| 107 |
+
c = t_emb + y_emb
|
| 108 |
+
std = sitf2.module.forward_noise(patch_tokens, c)
|
| 109 |
+
x1=x
|
| 110 |
+
drift = sit(x1, t, y)
|
| 111 |
+
delta_t = 1.0 / steps
|
| 112 |
+
x1 = x1 + drift * delta_t
|
| 113 |
+
x_dec = vae.decode(x1 / 0.18215).sample
|
| 114 |
+
img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 115 |
+
img1 = img[0]
|
| 116 |
+
drift = sit(x, t, y)
|
| 117 |
+
delta_t = 1.0 / steps
|
| 118 |
+
noise = torch.randn_like(x)
|
| 119 |
+
x = x + drift * delta_t
|
| 120 |
+
x_dec = vae.decode(x / 0.18215).sample
|
| 121 |
+
img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 122 |
+
img2 = img[0]
|
| 123 |
+
imgs.append(img2-img1)
|
| 124 |
+
imgs1.append(img1)
|
| 125 |
+
imgs2.append(img2)
|
| 126 |
+
x=x0
|
| 127 |
+
for step in tqdm(range(steps)):
|
| 128 |
+
t = torch.full((batch,), step / (steps - 1), device=device)
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
patch_tokens = sitf1(x, t, y)
|
| 131 |
+
t_emb = sitf1.t_embedder(t)
|
| 132 |
+
y_emb = sitf1.y_embedder(y, False)
|
| 133 |
+
c = t_emb + y_emb
|
| 134 |
+
std = sitf2.module.forward_noise(patch_tokens, c)
|
| 135 |
+
drift = sit(x, t, y)
|
| 136 |
+
delta_t = 1.0 / steps
|
| 137 |
+
x = x + drift * delta_t
|
| 138 |
+
x_dec = vae.decode(x / 0.18215).sample
|
| 139 |
+
img = torch.clamp(127.5 * x_dec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 140 |
+
img1 = img[0]
|
| 141 |
+
img_original.append(img1)
|
| 142 |
+
|
| 143 |
+
imageio.mimsave(gif_path, imgs, duration=0.1)
|
| 144 |
+
print(f"Saved gif to {gif_path}")
|
| 145 |
+
imageio.mimsave('noise.gif', imgs1, duration=0.1)
|
| 146 |
+
print(f"Saved gif to {gif_path}")
|
| 147 |
+
imageio.mimsave('std.gif', imgs2, duration=0.1)
|
| 148 |
+
imageio.mimsave('nothing.gif', img_original, duration=0.1)
|
| 149 |
+
print(f"Saved gif to {gif_path}")
|
| 150 |
+
if __name__ == '__main__':
|
| 151 |
+
import argparse
|
| 152 |
+
parser = argparse.ArgumentParser()
|
| 153 |
+
parser.add_argument('--ckpt', type=str, default='/gemini/space/gzy/w_w_last/Celeba/w_w_sit_1/0200000.pt')
|
| 154 |
+
parser.add_argument('--sitf2-ckpt', type=str,default='/gemini/space/gzy/w_w_last/Celeba/w_w_sit_1/results/depth-mu-2-014-SiT-XL-2-Linear-velocity-None/checkpoints/0010000.pt')
|
| 155 |
+
parser.add_argument('--steps', type=int, default=100)
|
| 156 |
+
parser.add_argument('--gif-path', type=str, default='sitf2_noise_evolution.gif')
|
| 157 |
+
parser.add_argument('--gif-path2', type=str, default='noise.gif')
|
| 158 |
+
parser.add_argument('--gif-path1', type=str, default='std.gif')
|
| 159 |
+
parser.add_argument('--image-size', type=int, default=256)
|
| 160 |
+
parser.add_argument('--device', type=str, default='cuda')
|
| 161 |
+
args = parser.parse_args()
|
| 162 |
+
main(
|
| 163 |
+
sit_ckpt=args.ckpt,
|
| 164 |
+
sitf2_ckpt=args.sitf2_ckpt,
|
| 165 |
+
image_size=args.image_size,
|
| 166 |
+
steps=args.steps,
|
| 167 |
+
gif_path=args.gif_path,
|
| 168 |
+
device=args.device
|
| 169 |
+
)
|