#!/usr/bin/env bash
#
# Runs generate.py in ODE mode at several step counts, feeding every run the
# same saved noise bank, then stitches per-step images side by side.
#
# Optional environment overrides:
#   CKPT_PATH        checkpoint handed to generate.py via --ckpt
#   OUTPUT_BASE      root directory for every output of this script
#   NOISE_BANK_PATH  location of the saved noise bank (.pt file)

set -euo pipefail

# Work from the script's own directory so the relative generate.py path and
# the relative default OUTPUT_BASE resolve the same way from any caller cwd.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# Sampling workload.
NUM_GPUS=1
NUM_SAMPLES=100
PER_PROC_BATCH_SIZE=10

# Guidance settings, forwarded unchanged to generate.py.
CFG_SCALE=1
CLS_CFG_SCALE=1
GH=0.85

# The default is a machine-specific absolute path; override CKPT_PATH in the
# environment instead of editing this file.
CKPT_PATH="${CKPT_PATH:-/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/REG-XLarge-256/REG-XLarge-256/2400000.pt}"

OUTPUT_BASE="${OUTPUT_BASE:-./reg_xl_2400000_ode_multistep}"
NOISE_BANK_PATH="${NOISE_BANK_PATH:-${OUTPUT_BASE}/noise_bank.pt}"

export NCCL_P2P_DISABLE=1
# OUTPUT_BASE is read from os.environ by the embedded Python pairing step.
export OUTPUT_BASE

# Round NUM_SAMPLES up to a whole number of global batches; the noise bank is
# sized to this rounded count.
# NOTE(review): presumably generate.py consumes whole batches - confirm.
GLOBAL_BATCH_SIZE=$((PER_PROC_BATCH_SIZE * NUM_GPUS))
TOTAL_SAMPLES=$((((NUM_SAMPLES + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE) * GLOBAL_BATCH_SIZE))
# Both are read from os.environ by the embedded Python noise-bank snippets.
export NOISE_BANK_PATH TOTAL_SAMPLES

# Decide whether the noise bank has to be (re)built: it is rebuilt when the
# file is missing, or when the first dimension of its "z" tensor differs from
# TOTAL_SAMPLES. Only "z" is inspected; "y" and "cls_z" are not checked.
REBUILD_NOISE_BANK=0
if [[ ! -f "${NOISE_BANK_PATH}" ]]; then
  REBUILD_NOISE_BANK=1
else
  # The heredoc body must stay at column 0: it is top-level Python, and the
  # quoted PY delimiter keeps the shell from expanding anything inside it.
  EXISTING_SAMPLES=$(python - <<'PY'
import os
import torch

p = os.environ["NOISE_BANK_PATH"]
try:
    d = torch.load(p, map_location="cpu", weights_only=True)
except TypeError:
    # Presumably a torch whose load() does not accept weights_only;
    # retry without the keyword.
    d = torch.load(p, map_location="cpu")
print(int(d["z"].shape[0]))
PY
  )
  if [[ "${EXISTING_SAMPLES}" != "${TOTAL_SAMPLES}" ]]; then
    echo "noise_bank size mismatch: existing=${EXISTING_SAMPLES}, required=${TOTAL_SAMPLES}. Rebuilding..."
    REBUILD_NOISE_BANK=1
  fi
fi

# Build the noise bank when REBUILD_NOISE_BANK is 1. The Python snippet reads
# NOISE_BANK_PATH and TOTAL_SAMPLES from the environment.
if [[ "${REBUILD_NOISE_BANK}" -eq 1 ]]; then
  mkdir -p "$(dirname "${NOISE_BANK_PATH}")"
  echo "Creating fixed noise bank: ${NOISE_BANK_PATH} (total_samples=${TOTAL_SAMPLES})"
  # Heredoc body stays at column 0 because it is top-level Python.
  python - <<'PY'
import os
import torch

noise_path = os.environ["NOISE_BANK_PATH"]
total_samples = int(os.environ["TOTAL_SAMPLES"])
num_classes = 1000
in_channels = 4
latent_size = 32
# NOTE(review): presumably has to agree with the --cls value given to
# generate.py - confirm.
cls_dim = 768

# z: (total_samples, in_channels, latent_size, latent_size) standard normal.
z = torch.randn(total_samples, in_channels, latent_size, latent_size)
# y: one integer label in [0, num_classes) per sample.
y = torch.randint(0, num_classes, (total_samples,), dtype=torch.long)
# cls_z: (total_samples, cls_dim) standard normal.
cls_z = torch.randn(total_samples, cls_dim)

# Save to a temporary sibling and rename into place, so an interrupted save
# never leaves a truncated file at noise_path for a later run to load.
tmp_path = noise_path + ".tmp"
torch.save({"z": z, "y": y, "cls_z": cls_z}, tmp_path)
os.replace(tmp_path, noise_path)
print(f"Saved fixed noise bank to {noise_path}")
PY
fi

# Run generate.py once per ODE step count. Every run receives the same
# --fixed-noise-file and writes into its own per-step directory.
for NUM_STEP in 5 10 20 50 100; do
  # Fresh master port in 1200-1299 for each launch.
  RANDOM_PORT=$((RANDOM % 100 + 1200))
  OUTPUT_DIR="${OUTPUT_BASE}/ode_steps_${NUM_STEP}/checkpoints"
  echo "=== ODE num_steps=${NUM_STEP} -> ${OUTPUT_DIR} ==="

  # NOTE(review): torch.distributed.launch is deprecated upstream in favour of
  # torchrun. It is kept because the two pass the local rank differently and
  # generate.py's expectation is not visible from this script.
  python -m torch.distributed.launch \
    --master_port="${RANDOM_PORT}" \
    --nproc_per_node="${NUM_GPUS}" \
    generate.py \
    --model "SiT-XL/2" \
    --num-fid-samples "${NUM_SAMPLES}" \
    --ckpt "${CKPT_PATH}" \
    --path-type=linear \
    --encoder-depth=8 \
    --projector-embed-dims=768 \
    --per-proc-batch-size="${PER_PROC_BATCH_SIZE}" \
    --mode=ode \
    --num-steps="${NUM_STEP}" \
    --fixed-noise-file "${NOISE_BANK_PATH}" \
    --cfg-scale="${CFG_SCALE}" \
    --cls-cfg-scale="${CLS_CFG_SCALE}" \
    --guidance-high="${GH}" \
    --sample-dir "${OUTPUT_DIR}" \
    --cls=768
done

echo "All ODE runs finished. Building side-by-side pairs..."

# For every PNG in the first listed step's run folder, paste the same-named
# image from each listed step left to right and save the strip under
# ${OUTPUT_BASE}/pair. Heredoc body stays at column 0 (top-level Python).
python - <<'PY'
import os
from PIL import Image

output_base = os.environ.get("OUTPUT_BASE", "./reg_xl_2400000_ode_multistep")
# NOTE(review): only these three step counts are paired, although the shell
# loop also samples 20 and 100 steps - confirm this subset is intended.
steps = [5, 10, 50]


def find_single_subdir(step):
    """Return the only sub-directory of <output_base>/ode_steps_<step>/checkpoints.

    Raises RuntimeError when that directory is missing or does not contain
    exactly one sub-directory.
    """
    root = os.path.join(output_base, f"ode_steps_{step}", "checkpoints")
    if not os.path.isdir(root):
        raise RuntimeError(f"Missing directory: {root}")
    subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
    if len(subdirs) != 1:
        raise RuntimeError(f"Expected exactly 1 run folder under {root}, got {len(subdirs)}: {subdirs}")
    return os.path.join(root, subdirs[0])


run_dirs = {s: find_single_subdir(s) for s in steps}
pair_dir = os.path.join(output_base, "pair")
os.makedirs(pair_dir, exist_ok=True)

# File names come from the first step's folder; the other folders are expected
# to hold a PNG of the same name (Image.open raises if one is missing).
files = sorted([f for f in os.listdir(run_dirs[steps[0]]) if f.endswith(".png")])
for name in files:
    imgs = [Image.open(os.path.join(run_dirs[s], name)).convert("RGB") for s in steps]
    # Canvas is as wide as all images together and as tall as the tallest one.
    w = sum(im.width for im in imgs)
    h = max(im.height for im in imgs)
    canvas = Image.new("RGB", (w, h))
    x = 0
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += im.width
    canvas.save(os.path.join(pair_dir, name))

print(f"Saved paired comparisons to: {pair_dir}")
PY

echo "Done. Outputs under: ${OUTPUT_BASE}/ode_steps_{5,10,50}/checkpoints ; pairs under: ${OUTPUT_BASE}/pair"