#!/usr/bin/env bash
# ODE sampling with euler_sampler (same as the REG ODE sampler in samplers.py).
# Runs generate.py once per step count listed in ODE_STEPS and writes each run
# to its own subdirectory: ${OUTPUT_BASE}/ode_steps_<N>/checkpoints.
# Every run is pointed at the same fixed noise bank (--fixed-noise-file), so the
# runs differ only in the number of ODE steps.
#
# Environment overrides:
#   OUTPUT_BASE      root output directory (default: ./reg_xl_2400000_ode_multistep)
#   NOISE_BANK_PATH  fixed noise bank file (default: ${OUTPUT_BASE}/noise_bank.pt)
#   ODE_STEPS        space-separated ODE step counts (default: "5 10 20 50 100")
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

NUM_GPUS=1
NUM_SAMPLES=100
PER_PROC_BATCH_SIZE=10
CFG_SCALE=1
CLS_CFG_SCALE=1
GH=0.85
CKPT_PATH="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/REG-XLarge-256/REG-XLarge-256/2400000.pt"
#CKPT_PATH="/your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8/checkpoints/0040000.pt"

# Root output directory (overridable via env); ode_steps_<N> subdirectories
# are created under it, one per entry in ODE_STEPS.
OUTPUT_BASE="${OUTPUT_BASE:-./reg_xl_2400000_ode_multistep}"
NOISE_BANK_PATH="${NOISE_BANK_PATH:-${OUTPUT_BASE}/noise_bank.pt}"

# Single source of truth for the step counts to sample with. It is exported
# below so the side-by-side pairing step reads exactly the same list.
ODE_STEPS="${ODE_STEPS:-5 10 20 50 100}"
read -r -a ode_steps <<<"${ODE_STEPS}"
if (( ${#ode_steps[@]} == 0 )); then
  echo "ODE_STEPS is empty; nothing to sample" >&2
  exit 1
fi
for step in "${ode_steps[@]}"; do
  if [[ ! "${step}" =~ ^[1-9][0-9]*$ ]]; then
    echo "Invalid entry in ODE_STEPS: '${step}' (expected a positive integer)" >&2
    exit 1
  fi
done

export NCCL_P2P_DISABLE=1
export OUTPUT_BASE ODE_STEPS

# Round NUM_SAMPLES up to a whole number of global batches; the noise bank
# must hold exactly this many entries.
GLOBAL_BATCH_SIZE=$((PER_PROC_BATCH_SIZE * NUM_GPUS))
TOTAL_SAMPLES=$((((NUM_SAMPLES + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE) * GLOBAL_BATCH_SIZE))
export NOISE_BANK_PATH TOTAL_SAMPLES

# Reuse an existing noise bank only if its "z" tensor has exactly
# TOTAL_SAMPLES rows; otherwise (or if it is missing) rebuild it.
REBUILD_NOISE_BANK=0
if [[ ! -f "${NOISE_BANK_PATH}" ]]; then
  REBUILD_NOISE_BANK=1
else
  EXISTING_SAMPLES=$(python - <<'PY'
import os
import torch

p = os.environ["NOISE_BANK_PATH"]
try:
    d = torch.load(p, map_location="cpu", weights_only=True)
except TypeError:
    # Older torch versions do not accept the weights_only keyword.
    d = torch.load(p, map_location="cpu")
print(int(d["z"].shape[0]))
PY
  )
  if [[ "${EXISTING_SAMPLES}" != "${TOTAL_SAMPLES}" ]]; then
    echo "noise_bank size mismatch: existing=${EXISTING_SAMPLES}, required=${TOTAL_SAMPLES}. Rebuilding..."
    REBUILD_NOISE_BANK=1
  fi
fi

if [[ "${REBUILD_NOISE_BANK}" -eq 1 ]]; then
  mkdir -p "$(dirname "${NOISE_BANK_PATH}")"
  echo "Creating fixed noise bank: ${NOISE_BANK_PATH} (total_samples=${TOTAL_SAMPLES})"
  python - <<'PY'
import os
import torch

noise_path = os.environ["NOISE_BANK_PATH"]
total_samples = int(os.environ["TOTAL_SAMPLES"])
num_classes = 1000
in_channels = 4
latent_size = 32  # 256 / 8
cls_dim = 768  # matches --cls / --projector-embed-dims passed to generate.py

z = torch.randn(total_samples, in_channels, latent_size, latent_size)
y = torch.randint(0, num_classes, (total_samples,), dtype=torch.long)
cls_z = torch.randn(total_samples, cls_dim)
torch.save({"z": z, "y": y, "cls_z": cls_z}, noise_path)
print(f"Saved fixed noise bank to {noise_path}")
PY
fi

for NUM_STEP in "${ode_steps[@]}"; do
  RANDOM_PORT=$((RANDOM % 100 + 1200))
  OUTPUT_DIR="${OUTPUT_BASE}/ode_steps_${NUM_STEP}/checkpoints"
  echo "=== ODE num_steps=${NUM_STEP} -> ${OUTPUT_DIR} ==="
  # NOTE(review): torch.distributed.launch is deprecated in favour of torchrun;
  # kept as-is because generate.py may rely on the --local_rank argument it
  # injects. Confirm before switching.
  python -m torch.distributed.launch \
    --master_port="${RANDOM_PORT}" \
    --nproc_per_node="${NUM_GPUS}" \
    generate.py \
    --model "SiT-XL/2" \
    --num-fid-samples "${NUM_SAMPLES}" \
    --ckpt "${CKPT_PATH}" \
    --path-type=linear \
    --encoder-depth=8 \
    --projector-embed-dims=768 \
    --per-proc-batch-size="${PER_PROC_BATCH_SIZE}" \
    --mode=ode \
    --num-steps="${NUM_STEP}" \
    --fixed-noise-file "${NOISE_BANK_PATH}" \
    --cfg-scale="${CFG_SCALE}" \
    --cls-cfg-scale="${CLS_CFG_SCALE}" \
    --guidance-high="${GH}" \
    --sample-dir "${OUTPUT_DIR}" \
    --cls=768
done

echo "All ODE runs finished. Building side-by-side pairs..."
# Stitch same-named PNGs from every ode_steps_<N> run into one horizontal strip
# (left to right in ODE_STEPS order) under ${OUTPUT_BASE}/pair.
python - <<'PY'
import os

from PIL import Image

output_base = os.environ.get("OUTPUT_BASE", "./reg_xl_2400000_ode_multistep")
# Step counts to pair; the default matches the step list the sampling loop runs.
steps = [int(tok) for tok in os.environ.get("ODE_STEPS", "5 10 20 50 100").split()]
if not steps:
    raise RuntimeError("ODE_STEPS is empty; nothing to pair")


def find_single_subdir(step):
    """Return the only run folder under ode_steps_<step>/checkpoints, or raise."""
    root = os.path.join(output_base, f"ode_steps_{step}", "checkpoints")
    if not os.path.isdir(root):
        raise RuntimeError(f"Missing directory: {root}")
    subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
    if len(subdirs) != 1:
        raise RuntimeError(f"Expected exactly 1 run folder under {root}, got {len(subdirs)}: {subdirs}")
    return os.path.join(root, subdirs[0])


def list_pngs(directory):
    """Return the set of .png file names directly inside directory."""
    return {f for f in os.listdir(directory) if f.endswith(".png")}


run_dirs = {s: find_single_subdir(s) for s in steps}
pair_dir = os.path.join(output_base, "pair")
os.makedirs(pair_dir, exist_ok=True)

# Only pair images present in every run, so one incomplete run does not abort
# the whole script with FileNotFoundError; report what was skipped.
pngs_per_step = {s: list_pngs(run_dirs[s]) for s in steps}
common = set.intersection(*pngs_per_step.values())
skipped = set.union(*pngs_per_step.values()) - common
if skipped:
    print(f"Warning: skipping {len(skipped)} image(s) missing from at least one run, e.g. {sorted(skipped)[:5]}")
if not common:
    print("Warning: no PNG files are present in every run; nothing to pair")

for name in sorted(common):
    imgs = []
    for s in steps:
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(os.path.join(run_dirs[s], name)) as im:
            imgs.append(im.convert("RGB"))
    w = sum(im.width for im in imgs)
    h = max(im.height for im in imgs)
    canvas = Image.new("RGB", (w, h))
    x = 0
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += im.width
    canvas.save(os.path.join(pair_dir, name))

print(f"Saved paired comparisons to: {pair_dir}")
PY

echo "Done. Outputs under: ${OUTPUT_BASE}/ode_steps_<N>/checkpoints for N in: ${ODE_STEPS:-5 10 20 50 100} ; pairs under: ${OUTPUT_BASE}/pair"