#!/usr/bin/env python3 """ 从 REG/train.py 保存的检查点加载权重,在指定目录生成若干 PNG。 示例: python sample_from_checkpoint.py \\ --ckpt exps/jsflow-experiment/checkpoints/0050000.pt \\ --out-dir ./samples_gen \\ --num-images 64 \\ --batch-size 8 # 按训练 t_c 分段分配步数(t=1→t_c 与 t_c→0;--t-c 可省略若检查点含 t_c): python sample_from_checkpoint.py ... \\ --steps-before-tc 150 --steps-after-tc 100 --t-c 0.5 # 同一批初始噪声连跑两种 t_c 后段步数(输出到 out-dir 下子目录): python sample_from_checkpoint.py ... \\ --steps-before-tc 150 --steps-after-tc 5 --dual-compare-after # 分段时会在 at_tc/(或 at_tc/after_input、at_tc/after_equal_before)额外保存 t≈t_c 的解码图。 检查点需包含 train.py 写入的键:ema(或 model)、args(推荐,用于自动还原结构)。 若缺少 args,需通过命令行显式传入 --model、--resolution、--enc-type 等。 """ from __future__ import annotations import argparse import os import sys import types import numpy as np import torch from diffusers.models import AutoencoderKL from PIL import Image from tqdm import tqdm from models.sit import SiT_models from samplers import ( euler_maruyama_image_noise_before_tc_sampler, euler_maruyama_image_noise_sampler, euler_maruyama_sampler, euler_ode_sampler, ) def semantic_dim_from_enc_type(enc_type): """与 train.py 一致:按 enc_type 推断语义/class token 维度。""" if enc_type is None: return 768 s = str(enc_type).lower() if "vit-g" in s or "vitg" in s: return 1536 if "vit-l" in s or "vitl" in s: return 1024 if "vit-s" in s or "vits" in s: return 384 return 768 def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None: a = ckpt.get("args") if a is None: return None if isinstance(a, argparse.Namespace): return a if isinstance(a, dict): return argparse.Namespace(**a) if isinstance(a, types.SimpleNamespace): return argparse.Namespace(**vars(a)) return None def load_vae(device: torch.device): """与 train.py 相同策略:优先本地 diffusers 缓存中的 sd-vae-ft-mse。""" try: from preprocessing import dnnlib cache_dir = dnnlib.make_cache_dir_path("diffusers") os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") os.environ["HF_HOME"] = cache_dir try: vae = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", cache_dir=cache_dir, local_files_only=True, ).to(device) vae.eval() print(f"Loaded VAE from local cache: {cache_dir}") return vae except Exception: pass candidate_dir = None for root_dir in [ cache_dir, os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"), ]: if not os.path.isdir(root_dir): continue for root, _, files in os.walk(root_dir): if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"): candidate_dir = root break if candidate_dir is not None: break if candidate_dir is not None: vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device) vae.eval() print(f"Loaded VAE from {candidate_dir}") return vae except Exception as e: print(f"VAE local cache search failed: {e}", file=sys.stderr) try: vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) vae.eval() print("Loaded VAE from Hub: stabilityai/sd-vae-ft-mse") return vae except Exception as e: raise RuntimeError( "无法加载 VAE stabilityai/sd-vae-ft-mse,请确认已下载或网络可用。" ) from e def build_model_from_train_args(ta: argparse.Namespace, device: torch.device): res = int(getattr(ta, "resolution", 256)) latent_size = res // 8 enc_type = getattr(ta, "enc_type", "dinov2-vit-b") z_dims = [semantic_dim_from_enc_type(enc_type)] block_kwargs = { "fused_attn": getattr(ta, "fused_attn", True), "qk_norm": getattr(ta, "qk_norm", False), } cfg_prob = float(getattr(ta, "cfg_prob", 0.1)) if ta.model not in SiT_models: raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}") model = SiT_models[ta.model]( input_size=latent_size, num_classes=int(getattr(ta, "num_classes", 1000)), use_cfg=(cfg_prob > 0), z_dims=z_dims, encoder_depth=int(getattr(ta, "encoder_depth", 8)), **block_kwargs, ).to(device) return model, z_dims[0] def resolve_tc_schedule(cli, ta): """ 若同时给出 --steps-before-tc 与 --steps-after-tc:在 t_c 处分段(--t-c 缺省则用检查点 args.t_c)。 否则使用均匀 --num-steps(与旧版一致)。 """ sb = cli.steps_before_tc sa = cli.steps_after_tc tc = cli.t_c if sb is None and sa is None: return None, None, None if sb is None or sa is None: print( "使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr, ) sys.exit(1) if tc is None: tc = getattr(ta, "t_c", None) if ta is not None else None if tc is None: print( "分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr, ) sys.exit(1) return float(tc), int(sb), int(sa) def parse_cli(): p = argparse.ArgumentParser(description="REG 检查点采样出图(可选 ODE/EM/EM-图像噪声)") p.add_argument("--ckpt", type=str, required=True, help="train.py 保存的 .pt 路径") p.add_argument("--out-dir", type=str, required=True, help="输出 PNG 目录(会创建)") p.add_argument("--num-images", type=int, required=True, help="生成图片总数") p.add_argument("--batch-size", type=int, default=16) p.add_argument("--seed", type=int, default=0) p.add_argument( "--weights", type=str, choices=("ema", "model"), default="ema", help="使用检查点中的 ema 或 model 权重", ) p.add_argument("--device", type=str, default="cuda", help="如 cuda 或 cuda:0") p.add_argument( "--num-steps", type=int, default=50, help="均匀时间网格时的欧拉步数(未使用 --steps-before-tc/--steps-after-tc 时生效)", ) p.add_argument( "--t-c", type=float, default=None, help="分段时刻:t∈(t_c,1] 与 t∈[0,t_c] 两段;缺省可用检查点 args.t_c(需配合两段步数)", ) p.add_argument( "--steps-before-tc", type=int, default=None, help="从 t=1 积分到 t=t_c 的步数(与 --steps-after-tc 成对使用)", ) p.add_argument( "--steps-after-tc", type=int, default=None, help="从 t=t_c 积分到 t=0(经 t_floor=0.04)的步数", ) p.add_argument("--cfg-scale", type=float, default=1.0) p.add_argument("--cls-cfg-scale", type=float, default=0.0, help="cls 分支 CFG(>0 时需 cfg-scale>1)") p.add_argument("--guidance-low", type=float, default=0.0) p.add_argument("--guidance-high", type=float, default=1.0) p.add_argument( "--path-type", type=str, default=None, choices=["linear", "cosine"], help="默认从检查点 args 读取;可覆盖", ) p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # 无 args 时的兜底 p.add_argument("--model", type=str, default=None, help="无检查点 args 时必填;与 SiT_models 键一致,如 SiT-XL/2") p.add_argument("--resolution", type=int, default=None, choices=[256, 512]) p.add_argument("--num-classes", type=int, default=None) p.add_argument("--encoder-depth", type=int, default=None) p.add_argument("--enc-type", type=str, default=None) p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None) p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None) p.add_argument("--cfg-prob", type=float, default=None) p.add_argument( "--sampler", type=str, default="em_image_noise", choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"], help="采样器:ode=euler_sampler 确定性漂移(linspace 1→0 或 t_c 分段直连 0,无 t_floor;与 EM 网格不同)," "em=标准EM(含图像+cls噪声),em_image_noise=仅图像噪声," "em_image_noise_before_tc=t<=t_c时图像去随机+cls全程去随机", ) p.add_argument( "--dual-compare-after", action="store_true", help="需配合分段步数:同批 z/y/cls 连跑两次;after_input 用 --steps-after-tc," "after_equal_before 将 after 步数设为与 --steps-before-tc 相同", ) p.add_argument( "--save-fixed-trajectory", action="store_true", help="保存固定步采样轨迹(npy);仅对非 em 采样器启用,输出在 out-dir/trajectory", ) return p.parse_args() def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae): imgs = vae.decode((latents - latents_bias) / latents_scale).sample imgs = (imgs + 1) / 2.0 imgs = torch.clamp(imgs, 0, 1) return ( (imgs * 255.0) .round() .to(torch.uint8) .permute(0, 2, 3, 1) .cpu() .numpy() ) def main(): cli = parse_cli() device = torch.device(cli.device if torch.cuda.is_available() else "cpu") if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True try: ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False) except TypeError: ckpt = torch.load(cli.ckpt, map_location="cpu") ta = load_train_args_from_ckpt(ckpt) if ta is None: if cli.model is None or cli.resolution is None or cli.enc_type is None: print( "检查点中无 args,请至少指定:--model --resolution --enc-type " "(以及按需 --num-classes --encoder-depth)", file=sys.stderr, ) sys.exit(1) ta = argparse.Namespace( model=cli.model, resolution=cli.resolution, num_classes=cli.num_classes if cli.num_classes is not None else 1000, encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8, enc_type=cli.enc_type, fused_attn=cli.fused_attn if cli.fused_attn is not None else True, qk_norm=cli.qk_norm if cli.qk_norm is not None else False, cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1, ) else: if cli.model is not None: ta.model = cli.model if cli.resolution is not None: ta.resolution = cli.resolution if cli.num_classes is not None: ta.num_classes = cli.num_classes if cli.encoder_depth is not None: ta.encoder_depth = cli.encoder_depth if cli.enc_type is not None: ta.enc_type = cli.enc_type if cli.fused_attn is not None: ta.fused_attn = cli.fused_attn if cli.qk_norm is not None: ta.qk_norm = cli.qk_norm if cli.cfg_prob is not None: ta.cfg_prob = cli.cfg_prob path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear") tc_split = resolve_tc_schedule(cli, ta) if cli.dual_compare_after and tc_split[0] is None: print("--dual-compare-after 必须配合 --steps-before-tc 与 --steps-after-tc(分段采样)", file=sys.stderr) sys.exit(1) if tc_split[0] is not None: if cli.dual_compare_after: print( f"双次对比:t_c={tc_split[0]}, before={tc_split[1]}, " f"after_input={tc_split[2]}, after_equal_before={tc_split[1]}" ) else: print( f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]} " f"(总模型前向约 {tc_split[1] + tc_split[2] + 1} 次)" ) else: print(f"时间网格:均匀 num_steps={cli.num_steps}") if cli.sampler == "ode": sampler_fn = euler_ode_sampler elif cli.sampler == "em": sampler_fn = euler_maruyama_sampler elif cli.sampler == "em_image_noise_before_tc": sampler_fn = euler_maruyama_image_noise_before_tc_sampler else: sampler_fn = euler_maruyama_image_noise_sampler model, cls_dim = build_model_from_train_args(ta, device) wkey = cli.weights if wkey not in ckpt: raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}") state = ckpt[wkey] if cli.legacy: from utils import load_legacy_checkpoints state = load_legacy_checkpoints( state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8)) ) model.load_state_dict(state, strict=True) model.eval() vae = load_vae(device) latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1) latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1) sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale)) at_tc_dir = at_tc_a = at_tc_b = None pair_dir = None traj_dir = traj_a = traj_b = None if cli.dual_compare_after: out_a = os.path.join(cli.out_dir, "after_input") out_b = os.path.join(cli.out_dir, "after_equal_before") pair_dir = os.path.join(cli.out_dir, "pair") os.makedirs(out_a, exist_ok=True) os.makedirs(out_b, exist_ok=True) os.makedirs(pair_dir, exist_ok=True) if tc_split[0] is not None: at_tc_a = os.path.join(cli.out_dir, "at_tc", "after_input") at_tc_b = os.path.join(cli.out_dir, "at_tc", "after_equal_before") os.makedirs(at_tc_a, exist_ok=True) os.makedirs(at_tc_b, exist_ok=True) if cli.save_fixed_trajectory and cli.sampler != "em": traj_a = os.path.join(cli.out_dir, "trajectory", "after_input") traj_b = os.path.join(cli.out_dir, "trajectory", "after_equal_before") os.makedirs(traj_a, exist_ok=True) os.makedirs(traj_b, exist_ok=True) else: os.makedirs(cli.out_dir, exist_ok=True) if tc_split[0] is not None: at_tc_dir = os.path.join(cli.out_dir, "at_tc") os.makedirs(at_tc_dir, exist_ok=True) if cli.save_fixed_trajectory and cli.sampler != "em": traj_dir = os.path.join(cli.out_dir, "trajectory") os.makedirs(traj_dir, exist_ok=True) latent_size = int(getattr(ta, "resolution", 256)) // 8 n_total = int(cli.num_images) b = max(1, int(cli.batch_size)) torch.manual_seed(cli.seed) if device.type == "cuda": torch.cuda.manual_seed_all(cli.seed) written = 0 pbar = tqdm(total=n_total, desc="sampling") while written < n_total: cur = min(b, n_total - written) z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device) y = torch.randint(0, int(ta.num_classes), (cur,), device=device) cls_z = torch.randn(cur, cls_dim, device=device) with torch.no_grad(): base_kw = dict( num_steps=cli.num_steps, cfg_scale=cli.cfg_scale, guidance_low=cli.guidance_low, guidance_high=cli.guidance_high, path_type=path_type, cls_latents=cls_z, args=sampler_args, ) if cli.dual_compare_after: tc_v, sb, sa_in = tc_split # 两次完整采样会各自消耗 RNG;不重置则第二条的 1→t_c 噪声与第一条不同,z_tc/at_tc 会对不齐。 # 在固定 z/y/cls_z 之后打快照,第二条运行前恢复,使 t_c 中间态一致(仅后段步数不同)。 _rng_cpu_dual = torch.random.get_rng_state() _rng_cuda_dual = ( torch.cuda.get_rng_state_all() if device.type == "cuda" else None ) batch_imgs = {} for _run_i, (subdir, sa, tc_save_dir) in enumerate( ( (out_a, sa_in, at_tc_a), (out_b, sb, at_tc_b), ) ): if _run_i > 0: torch.random.set_rng_state(_rng_cpu_dual) if _rng_cuda_dual is not None: torch.cuda.set_rng_state_all(_rng_cuda_dual) em_kw = dict(base_kw) em_kw["t_c"] = tc_v em_kw["num_steps_before_tc"] = sb em_kw["num_steps_after_tc"] = sa if cli.sampler == "em_image_noise_before_tc": if cli.save_fixed_trajectory and cli.sampler != "em": latents, z_tc, cls_tc, cls_t0, traj = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_v), return_cls_final=True, return_trajectory=True, ) else: latents, z_tc, cls_tc, cls_t0 = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_v), return_cls_final=True, ) traj = None else: if cli.save_fixed_trajectory and cli.sampler != "em": latents, z_tc, cls_tc, traj = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_v), return_trajectory=True, ) else: latents, z_tc, cls_tc = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_v), ) traj = None cls_t0 = None latents = latents.to(torch.float32) imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae) batch_imgs[subdir] = imgs for i in range(cur): Image.fromarray(imgs[i]).save( os.path.join(subdir, f"{written + i:06d}.png") ) if tc_save_dir is not None and z_tc is not None: imgs_tc = _decode_to_uint8_hwc( z_tc.to(torch.float32), latents_bias, latents_scale, vae ) for i in range(cur): Image.fromarray(imgs_tc[i]).save( os.path.join(tc_save_dir, f"{written + i:06d}.png") ) if traj is not None: traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy() save_traj_dir = traj_a if subdir == out_a else traj_b np.save(os.path.join(save_traj_dir, f"{written:06d}_traj.npy"), traj_np) imgs_a = batch_imgs.get(out_a) imgs_b = batch_imgs.get(out_b) if pair_dir is not None and imgs_a is not None and imgs_b is not None: for i in range(cur): pair_img = np.concatenate([imgs_a[i], imgs_b[i]], axis=1) Image.fromarray(pair_img).save( os.path.join(pair_dir, f"{written + i:06d}.png") ) else: em_kw = dict(base_kw) if tc_split[0] is not None: em_kw["t_c"] = tc_split[0] em_kw["num_steps_before_tc"] = tc_split[1] em_kw["num_steps_after_tc"] = tc_split[2] if cli.sampler == "em_image_noise_before_tc": if cli.save_fixed_trajectory and cli.sampler != "em": latents, z_tc, cls_tc, cls_t0, traj = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_split[0]), return_cls_final=True, return_trajectory=True, ) else: latents, z_tc, cls_tc, cls_t0 = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_split[0]), return_cls_final=True, ) traj = None else: if cli.save_fixed_trajectory and cli.sampler != "em": latents, z_tc, cls_tc, traj = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_split[0]), return_trajectory=True, ) else: latents, z_tc, cls_tc = sampler_fn( model, z, y, **em_kw, return_mid_state=True, t_mid=float(tc_split[0]), ) traj = None cls_t0 = None latents = latents.to(torch.float32) if z_tc is not None and at_tc_dir is not None: imgs_tc = _decode_to_uint8_hwc( z_tc.to(torch.float32), latents_bias, latents_scale, vae ) for i in range(cur): Image.fromarray(imgs_tc[i]).save( os.path.join(at_tc_dir, f"{written + i:06d}.png") ) if traj is not None and traj_dir is not None: traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy() np.save(os.path.join(traj_dir, f"{written:06d}_traj.npy"), traj_np) else: latents = sampler_fn(model, z, y, **em_kw).to(torch.float32) imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae) for i in range(cur): Image.fromarray(imgs[i]).save( os.path.join(cli.out_dir, f"{written + i:06d}.png") ) written += cur pbar.update(cur) pbar.close() if cli.dual_compare_after: msg = ( f"Done. Saved {written} images per run under {out_a} and {out_b} " f"(parent: {cli.out_dir})" ) if pair_dir is not None: msg += f"; paired comparisons under {pair_dir}" if tc_split[0] is not None and at_tc_a is not None: msg += f"; t≈t_c decoded under {at_tc_a} and {at_tc_b}" print(msg) else: msg = f"Done. Saved {written} images under {cli.out_dir}" if tc_split[0] is not None and at_tc_dir is not None: msg += f"; t≈t_c decoded under {at_tc_dir}" print(msg) if __name__ == "__main__": main()