jsflow / New /REG /generate_2400000_100.sh
xiangzai's picture
Add files using upload-large-folder tool
db5d40d verified
#!/usr/bin/env bash
# ODE sampling sweep for the REG SiT-XL/2 checkpoint at 2.4M iterations.
#
# Samples with generate.py --mode=ode (euler_sampler, matching the REG ODE in
# samplers.py) at 5 / 10 / 20 / 50 / 100 steps. Each step count is written to
# its own subdirectory, and all runs start from one shared fixed noise bank so
# the results are directly comparable. Finally the 5 / 10 / 50 step images are
# pasted side by side into a pair/ directory.
#
# Environment overrides:
#   CKPT_PATH        checkpoint to sample from
#   OUTPUT_BASE      root output dir (default ./reg_xl_2400000_ode_multistep)
#   NOISE_BANK_PATH  fixed-noise file (default ${OUTPUT_BASE}/noise_bank.pt)
set -euo pipefail

# Run from the script's own directory so generate.py and relative paths resolve.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# --- Sampling configuration -------------------------------------------------
NUM_GPUS=1
NUM_SAMPLES=100
PER_PROC_BATCH_SIZE=10
CFG_SCALE=1        # forwarded as --cfg-scale
CLS_CFG_SCALE=1    # forwarded as --cls-cfg-scale
GH=0.85            # forwarded as --guidance-high
CKPT_PATH="${CKPT_PATH:-/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/REG-XLarge-256/REG-XLarge-256/2400000.pt}"
#CKPT_PATH="/your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8/checkpoints/0040000.pt"

# Root output directory (overridable via env). Each step count gets its own
# ode_steps_<N> subdirectory below it, plus pair/ for the comparison images.
OUTPUT_BASE="${OUTPUT_BASE:-./reg_xl_2400000_ode_multistep}"
NOISE_BANK_PATH="${NOISE_BANK_PATH:-${OUTPUT_BASE}/noise_bank.pt}"
export NCCL_P2P_DISABLE=1
export OUTPUT_BASE

# Round NUM_SAMPLES up to a whole number of global batches. The fixed noise
# bank is built with exactly this many entries.
# NOTE(review): assumes generate.py pads the final batch to full size — confirm.
GLOBAL_BATCH_SIZE=$((PER_PROC_BATCH_SIZE * NUM_GPUS))
TOTAL_SAMPLES=$(( (NUM_SAMPLES + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE * GLOBAL_BATCH_SIZE ))
export NOISE_BANK_PATH TOTAL_SAMPLES
# --- Fixed noise bank -------------------------------------------------------
# One shared set of initial latents (z), class labels (y) and class-token noise
# (cls_z), so every step count starts from identical noise. An existing bank is
# reused only if it loads cleanly, contains all three keys with the expected
# shapes, and holds exactly TOTAL_SAMPLES entries. Otherwise it is (re)built.
#
# Bank layout, shared by the validation and creation snippets below.
# NOTE(review): must match what generate.py expects for SiT-XL/2 at 256px — confirm.
export NOISE_NUM_CLASSES=1000 NOISE_IN_CHANNELS=4 NOISE_LATENT_SIZE=32 NOISE_CLS_DIM=768
REBUILD_NOISE_BANK=0
if [[ ! -f "${NOISE_BANK_PATH}" ]]; then
  REBUILD_NOISE_BANK=1
else
  # Prints the bank's sample count, or a short reason why it is unusable.
  EXISTING_SAMPLES=$(python - <<'PY'
import os
import torch

path = os.environ["NOISE_BANK_PATH"]
channels = int(os.environ["NOISE_IN_CHANNELS"])
size = int(os.environ["NOISE_LATENT_SIZE"])
cls_dim = int(os.environ["NOISE_CLS_DIM"])
try:
    try:
        bank = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:  # older torch without the weights_only argument
        bank = torch.load(path, map_location="cpu")
except Exception as exc:  # truncated or corrupt file: report so the shell rebuilds
    print(f"unreadable:{type(exc).__name__}")
    raise SystemExit(0)
if not isinstance(bank, dict) or any(k not in bank for k in ("z", "y", "cls_z")):
    print("missing-keys")
    raise SystemExit(0)
try:
    n = int(bank["z"].shape[0])
    ok = (
        tuple(bank["z"].shape) == (n, channels, size, size)
        and tuple(bank["y"].shape) == (n,)
        and tuple(bank["cls_z"].shape) == (n, cls_dim)
    )
except (AttributeError, IndexError, TypeError):
    ok = False
print(n if ok else "bad-shapes")
PY
)
  if [[ "${EXISTING_SAMPLES}" != "${TOTAL_SAMPLES}" ]]; then
    echo "noise_bank size mismatch: existing=${EXISTING_SAMPLES}, required=${TOTAL_SAMPLES}. Rebuilding..."
    REBUILD_NOISE_BANK=1
  fi
fi
if [[ "${REBUILD_NOISE_BANK}" -eq 1 ]]; then
  mkdir -p "$(dirname "${NOISE_BANK_PATH}")"
  echo "Creating fixed noise bank: ${NOISE_BANK_PATH} (total_samples=${TOTAL_SAMPLES})"
  python - <<'PY'
import os
import torch

noise_path = os.environ["NOISE_BANK_PATH"]
total_samples = int(os.environ["TOTAL_SAMPLES"])
num_classes = int(os.environ["NOISE_NUM_CLASSES"])
in_channels = int(os.environ["NOISE_IN_CHANNELS"])
latent_size = int(os.environ["NOISE_LATENT_SIZE"])  # 256 / 8 by default
cls_dim = int(os.environ["NOISE_CLS_DIM"])
# Optional: set NOISE_SEED to make a rebuilt bank reproducible.
seed = os.environ.get("NOISE_SEED")
if seed:
    torch.manual_seed(int(seed))
z = torch.randn(total_samples, in_channels, latent_size, latent_size)
y = torch.randint(0, num_classes, (total_samples,), dtype=torch.long)
cls_z = torch.randn(total_samples, cls_dim)
# Save to a temp file, then atomically rename, so an interrupted save never
# leaves a truncated bank at noise_path.
tmp_path = noise_path + ".tmp"
torch.save({"z": z, "y": y, "cls_z": cls_z}, tmp_path)
os.replace(tmp_path, noise_path)
print(f"Saved fixed noise bank to {noise_path}")
PY
fi
# --- Sampling sweep ---------------------------------------------------------
# Launch generate.py once per ODE step count. Every launch reads the same fixed
# noise bank and writes to ${OUTPUT_BASE}/ode_steps_<N>/checkpoints.
ode_step_list=(5 10 20 50 100)
for NUM_STEP in "${ode_step_list[@]}"; do
  # Pick a fresh master port in [1200, 1299] for each launch.
  RANDOM_PORT=$(( 1200 + RANDOM % 100 ))
  OUTPUT_DIR="${OUTPUT_BASE}/ode_steps_${NUM_STEP}/checkpoints"
  echo "=== ODE num_steps=${NUM_STEP} -> ${OUTPUT_DIR} ==="

  launcher_opts=(
    --master_port="${RANDOM_PORT}"
    --nproc_per_node="${NUM_GPUS}"
  )
  sampler_opts=(
    --model "SiT-XL/2"
    --num-fid-samples "${NUM_SAMPLES}"
    --ckpt "${CKPT_PATH}"
    --path-type=linear
    --encoder-depth=8
    --projector-embed-dims=768
    --per-proc-batch-size="${PER_PROC_BATCH_SIZE}"
    --mode=ode
    --num-steps="${NUM_STEP}"
    --fixed-noise-file "${NOISE_BANK_PATH}"
    --cfg-scale="${CFG_SCALE}"
    --cls-cfg-scale="${CLS_CFG_SCALE}"
    --guidance-high="${GH}"
    --sample-dir "${OUTPUT_DIR}"
    --cls=768
  )
  python -m torch.distributed.launch "${launcher_opts[@]}" generate.py "${sampler_opts[@]}"
done
# --- Side-by-side comparison images -----------------------------------------
# For every image name produced by all selected runs, paste the per-step
# results left to right (in PAIR_STEPS order) into ${OUTPUT_BASE}/pair/<name>.
# PAIR_STEPS (space- or comma-separated, default "5 10 50") chooses which runs
# to pair. Each chosen step count must have been sampled by the loop above.
echo "All ODE runs finished. Building side-by-side pairs..."
PAIR_STEPS="${PAIR_STEPS:-5 10 50}"
export PAIR_STEPS
python - <<'PY'
import os
import sys
from PIL import Image

output_base = os.environ.get("OUTPUT_BASE", "./reg_xl_2400000_ode_multistep")
steps = [int(tok) for tok in os.environ.get("PAIR_STEPS", "5 10 50").replace(",", " ").split()]
if not steps:
    raise RuntimeError("PAIR_STEPS is empty; nothing to pair")


def find_single_subdir(step):
    """Return the only subdirectory of <output_base>/ode_steps_<step>/checkpoints."""
    root = os.path.join(output_base, f"ode_steps_{step}", "checkpoints")
    if not os.path.isdir(root):
        raise RuntimeError(f"Missing directory: {root}")
    subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
    if len(subdirs) != 1:
        raise RuntimeError(f"Expected exactly 1 run folder under {root}, got {len(subdirs)}: {subdirs}")
    return os.path.join(root, subdirs[0])


run_dirs = {s: find_single_subdir(s) for s in steps}
pngs = {s: {f for f in os.listdir(d) if f.endswith(".png")} for s, d in run_dirs.items()}
# Only names present in every run can be paired. Report the rest instead of
# crashing half-way through with FileNotFoundError.
common = set.intersection(*pngs.values())
skipped = set.union(*pngs.values()) - common
if skipped:
    print(
        f"Warning: skipping {len(skipped)} image(s) missing from at least one run, "
        f"e.g. {sorted(skipped)[:5]}",
        file=sys.stderr,
    )

pair_dir = os.path.join(output_base, "pair")
os.makedirs(pair_dir, exist_ok=True)
for name in sorted(common):
    tiles = []
    for s in steps:
        # The context manager releases the file handle even if decoding fails.
        with Image.open(os.path.join(run_dirs[s], name)) as im:
            tiles.append(im.convert("RGB"))
    canvas = Image.new("RGB", (sum(t.width for t in tiles), max(t.height for t in tiles)))
    x = 0
    for t in tiles:
        canvas.paste(t, (x, 0))
        x += t.width
    canvas.save(os.path.join(pair_dir, name))
print(f"Saved paired comparisons to: {pair_dir}")
PY
echo "Done. Outputs under: ${OUTPUT_BASE}/ode_steps_*/checkpoints ; pairs under: ${OUTPUT_BASE}/pair"