#!/usr/bin/env python3 """ DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。 用法(4 卡示例): torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \ --ckpt exps/jsflow-experiment/checkpoints/0290000.pt \ --out-dir ./my_samples_ddp \ --num-images 50000 \ --batch-size 16 \ --t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \ --sampler em_image_noise_before_tc """ from __future__ import annotations import argparse import math import os import sys import types import numpy as np import torch import torch.distributed as dist from diffusers.models import AutoencoderKL from PIL import Image from tqdm import tqdm from models.sit import SiT_models from samplers import ( euler_maruyama_image_noise_before_tc_sampler, euler_maruyama_image_noise_sampler, euler_maruyama_sampler, euler_ode_sampler, ) def create_npz_from_sample_folder(sample_dir: str, num: int): """ 将 sample_dir 下 000000.png... 组装为单个 .npz(arr_0)。 """ samples = [] for i in tqdm(range(num), desc="Building .npz file from samples"): sample_pil = Image.open(os.path.join(sample_dir, f"{i:06d}.png")) sample_np = np.asarray(sample_pil).astype(np.uint8) samples.append(sample_np) samples = np.stack(samples) npz_path = f"{sample_dir}.npz" np.savez(npz_path, arr_0=samples) print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") return npz_path def semantic_dim_from_enc_type(enc_type): if enc_type is None: return 768 s = str(enc_type).lower() if "vit-g" in s or "vitg" in s: return 1536 if "vit-l" in s or "vitl" in s: return 1024 if "vit-s" in s or "vits" in s: return 384 return 768 def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None: a = ckpt.get("args") if a is None: return None if isinstance(a, argparse.Namespace): return a if isinstance(a, dict): return argparse.Namespace(**a) if isinstance(a, types.SimpleNamespace): return argparse.Namespace(**vars(a)) return None def load_vae(device: torch.device): try: from preprocessing import dnnlib cache_dir = dnnlib.make_cache_dir_path("diffusers") os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") os.environ["HF_HOME"] = cache_dir try: vae = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", cache_dir=cache_dir, local_files_only=True, ).to(device) vae.eval() return vae except Exception: pass candidate_dir = None for root_dir in [ cache_dir, os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"), ]: if not os.path.isdir(root_dir): continue for root, _, files in os.walk(root_dir): if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"): candidate_dir = root break if candidate_dir is not None: break if candidate_dir is not None: vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device) vae.eval() return vae except Exception: pass vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) vae.eval() return vae def build_model_from_train_args(ta: argparse.Namespace, device: torch.device): res = int(getattr(ta, "resolution", 256)) latent_size = res // 8 enc_type = getattr(ta, "enc_type", "dinov2-vit-b") z_dims = [semantic_dim_from_enc_type(enc_type)] block_kwargs = { "fused_attn": getattr(ta, "fused_attn", True), "qk_norm": getattr(ta, "qk_norm", False), } cfg_prob = float(getattr(ta, "cfg_prob", 0.1)) if ta.model not in SiT_models: raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}") model = SiT_models[ta.model]( input_size=latent_size, num_classes=int(getattr(ta, "num_classes", 1000)), use_cfg=(cfg_prob > 0), z_dims=z_dims, encoder_depth=int(getattr(ta, "encoder_depth", 8)), **block_kwargs, ).to(device) return model, z_dims[0] def resolve_tc_schedule(cli, ta): sb = cli.steps_before_tc sa = cli.steps_after_tc tc = cli.t_c if sb is None and sa is None: return None, None, None if sb is None or sa is None: print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr) sys.exit(1) if tc is None: tc = getattr(ta, "t_c", None) if ta is not None else None if tc is None: print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr) sys.exit(1) return float(tc), int(sb), int(sa) def parse_cli(): p = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)") p.add_argument("--ckpt", type=str, required=True) p.add_argument("--out-dir", type=str, required=True) p.add_argument("--num-images", type=int, required=True) p.add_argument("--batch-size", type=int, default=16) p.add_argument("--seed", type=int, default=0) p.add_argument("--weights", type=str, choices=("ema", "model"), default="ema") p.add_argument("--device", type=str, default="cuda") p.add_argument("--num-steps", type=int, default=50) p.add_argument("--t-c", type=float, default=None) p.add_argument("--steps-before-tc", type=int, default=None) p.add_argument("--steps-after-tc", type=int, default=None) p.add_argument("--cfg-scale", type=float, default=1.0) p.add_argument("--cls-cfg-scale", type=float, default=0.0) p.add_argument("--guidance-low", type=float, default=0.0) p.add_argument("--guidance-high", type=float, default=1.0) p.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"]) p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) p.add_argument("--model", type=str, default=None) p.add_argument("--resolution", type=int, default=None, choices=[256, 512]) p.add_argument("--num-classes", type=int, default=1000) p.add_argument("--encoder-depth", type=int, default=None) p.add_argument("--enc-type", type=str, default=None) p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None) p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None) p.add_argument("--cfg-prob", type=float, default=None) p.add_argument( "--sampler", type=str, default="em_image_noise_before_tc", choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"], ) p.add_argument( "--save-fixed-trajectory", action="store_true", help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}", ) p.add_argument( "--save-npz", action=argparse.BooleanOptionalAction, default=True, help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz", ) return p.parse_args() def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae): imgs = vae.decode((latents - latents_bias) / latents_scale).sample imgs = (imgs + 1) / 2.0 imgs = torch.clamp(imgs, 0, 1) return ( (imgs * 255.0) .round() .to(torch.uint8) .permute(0, 2, 3, 1) .cpu() .numpy() ) def init_ddp(): if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) dist.init_process_group(backend="nccl", init_method="env://") torch.cuda.set_device(local_rank) return True, rank, world_size, local_rank return False, 0, 1, 0 def main(): cli = parse_cli() use_ddp, rank, world_size, local_rank = init_ddp() if torch.cuda.is_available(): device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device) torch.backends.cuda.matmul.allow_tf32 = True else: device = torch.device("cpu") try: ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False) except TypeError: ckpt = torch.load(cli.ckpt, map_location="cpu") ta = load_train_args_from_ckpt(ckpt) if ta is None: if cli.model is None or cli.resolution is None or cli.enc_type is None: print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr) sys.exit(1) ta = argparse.Namespace( model=cli.model, resolution=cli.resolution, num_classes=cli.num_classes if cli.num_classes is not None else 1000, encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8, enc_type=cli.enc_type, fused_attn=cli.fused_attn if cli.fused_attn is not None else True, qk_norm=cli.qk_norm if cli.qk_norm is not None else False, cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1, ) else: if cli.model is not None: ta.model = cli.model if cli.resolution is not None: ta.resolution = cli.resolution if cli.num_classes is not None: ta.num_classes = cli.num_classes if cli.encoder_depth is not None: ta.encoder_depth = cli.encoder_depth if cli.enc_type is not None: ta.enc_type = cli.enc_type if cli.fused_attn is not None: ta.fused_attn = cli.fused_attn if cli.qk_norm is not None: ta.qk_norm = cli.qk_norm if cli.cfg_prob is not None: ta.cfg_prob = cli.cfg_prob path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear") tc_split = resolve_tc_schedule(cli, ta) if rank == 0: if tc_split[0] is not None: print( f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}" ) else: print(f"时间网格:均匀 num_steps={cli.num_steps}") if cli.sampler == "ode": sampler_fn = euler_ode_sampler elif cli.sampler == "em": sampler_fn = euler_maruyama_sampler elif cli.sampler == "em_image_noise_before_tc": sampler_fn = euler_maruyama_image_noise_before_tc_sampler else: sampler_fn = euler_maruyama_image_noise_sampler model, cls_dim = build_model_from_train_args(ta, device) wkey = cli.weights if wkey not in ckpt: raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}") state = ckpt[wkey] if cli.legacy: from utils import load_legacy_checkpoints state = load_legacy_checkpoints( state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8)) ) model.load_state_dict(state, strict=True) model.eval() vae = load_vae(device) latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1) latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1) sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale)) os.makedirs(cli.out_dir, exist_ok=True) traj_dir = None if cli.save_fixed_trajectory and cli.sampler != "em": traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}") os.makedirs(traj_dir, exist_ok=True) latent_size = int(getattr(ta, "resolution", 256)) // 8 n_total = int(cli.num_images) b = max(1, int(cli.batch_size)) global_batch_size = b * world_size total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size) samples_needed_this_gpu = int(total_samples // world_size) if samples_needed_this_gpu % b != 0: raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size") iterations = int(samples_needed_this_gpu // b) seed_rank = int(cli.seed) + int(rank) torch.manual_seed(seed_rank) if device.type == "cuda": torch.cuda.manual_seed_all(seed_rank) if rank == 0: print(f"Total number of images that will be sampled: {total_samples}") pbar = range(iterations) pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar total = 0 written_local = 0 for _ in pbar: cur = b z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device) y = torch.randint(0, int(ta.num_classes), (cur,), device=device) cls_z = torch.randn(cur, cls_dim, device=device) with torch.no_grad(): em_kw = dict( num_steps=cli.num_steps, cfg_scale=cli.cfg_scale, guidance_low=cli.guidance_low, guidance_high=cli.guidance_high, path_type=path_type, cls_latents=cls_z, args=sampler_args, ) if tc_split[0] is not None: em_kw["t_c"] = tc_split[0] em_kw["num_steps_before_tc"] = tc_split[1] em_kw["num_steps_after_tc"] = tc_split[2] if cli.save_fixed_trajectory and cli.sampler != "em": if cli.sampler == "em_image_noise_before_tc": latents, traj = sampler_fn( model, z, y, **em_kw, return_trajectory=True ) else: latents, traj = sampler_fn( model, z, y, **em_kw, return_trajectory=True ) else: latents = sampler_fn(model, z, y, **em_kw) traj = None latents = latents.to(torch.float32) imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae) for i, img in enumerate(imgs): gidx = i * world_size + rank + total if gidx < n_total: Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png")) written_local += 1 if traj is not None and traj_dir is not None: traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy() first_idx = rank + total if first_idx < n_total: np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np) total += global_batch_size if use_ddp: dist.barrier() if rank == 0 and hasattr(pbar, "close"): pbar.close() if use_ddp: dist.barrier() if rank == 0: if cli.save_npz: create_npz_from_sample_folder(cli.out_dir, n_total) print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).") if use_ddp: dist.destroy_process_group() if __name__ == "__main__": main()