#!/usr/bin/env python
# coding=utf-8
"""
DDP paired sampling: for the same caption and the same initial noise, generate one
image with the LoRA pipeline and one with the RN (rectified-noise) pipeline, and write
side-by-side pair images plus JSONL metadata.

Run the stages in order via ``--stage``:

* ``lora`` -- sample LoRA images and record ``file_name``/``caption``/``seed`` for each
  image in ``<sample_dir>/metadata.jsonl``.
* ``rn``   -- replay the recorded captions and seeds through the local RN pipeline.
* ``pair`` -- concatenate each LoRA|RN image pair horizontally and annotate the root
  metadata with ``rn_file``/``pair_file``.
"""
import argparse
import importlib.util
import json
import math
import os
import sys
import warnings
from pathlib import Path

import torch
import torch.distributed as dist
from PIL import Image
from tqdm import tqdm

from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline

# Candidate rectified-noise (SiT) checkpoint file names, in lookup priority order.
_SIT_WEIGHT_FILE_NAMES = (
    "pytorch_sit_weights.safetensors",
    "pytorch_sit_weights.bin",
    "pytorch_sit_weights.pt",
    "sit_weights.pt",
    "sit.pt",
)

# Per-image noise seed = global_seed * _SEED_STRIDE + global sample index. Seeds are
# unique within a run as long as fewer than _SEED_STRIDE samples are generated.
_SEED_STRIDE = 10_000_000


def dynamic_import_training_classes(project_root: str):
    """Import the RN model classes from ``train_rectified_noise`` in *project_root*.

    *project_root* is prepended to ``sys.path`` so that module is found first.

    Returns:
        ``(RectifiedNoiseModule, SD3WithRectifiedNoise)``.
    """
    sys.path.insert(0, project_root)
    import train_rectified_noise as trn

    return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise


def load_local_pipeline_class(local_pipeline_path: str):
    """Load ``StableDiffusion3Pipeline`` from a local pipeline source file.

    The module is named as a child of ``diffusers.pipelines.stable_diffusion_3`` so
    that relative imports inside the file (``from .xxx import ...``) resolve against
    the installed diffusers package.

    Raises:
        ImportError: if no import spec can be built, or the file does not define
            ``StableDiffusion3Pipeline``.
    """
    module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
    spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Failed to build import spec from: {local_pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing (importlib "import a source file directly" recipe) so
    # code that looks its own module up in sys.modules works during/after execution.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    if not hasattr(module, "StableDiffusion3Pipeline"):
        raise ImportError("Local pipeline file has no StableDiffusion3Pipeline symbol.")
    return module.StableDiffusion3Pipeline


def load_captions_from_jsonl(jsonl_path):
    """Read one caption per JSONL line from *jsonl_path*.

    For each line, the first string-valued field among ``caption``, ``text``,
    ``prompt`` and ``description`` is taken (whitespace-stripped). Several kinds of
    lines are skipped:

    * blank lines;
    * invalid JSON;
    * JSON values that are not objects;
    * lines whose chosen caption is empty.

    Returns:
        The list of captions, or ``["a beautiful high quality image"]`` when no
        caption was found.
    """
    captions = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            for field in ("caption", "text", "prompt", "description"):
                value = data.get(field)
                if isinstance(value, str):
                    cap = value.strip()
                    if cap:
                        captions.append(cap)
                    break
    return captions if captions else ["a beautiful high quality image"]


def _load_sit_state_file(rectified_module, path):
    """Load one checkpoint file into *rectified_module* non-strictly, warning on key mismatches."""
    if path.endswith(".safetensors"):
        from safetensors.torch import load_file

        state = load_file(path)
    else:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(path, map_location="cpu")
    result = rectified_module.load_state_dict(state, strict=False)
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    if missing or unexpected:
        # strict=False would otherwise hide a checkpoint whose keys do not match at all.
        warnings.warn(
            f"Rectified weights from {path}: {len(missing)} missing keys, "
            f"{len(unexpected)} unexpected keys (e.g. missing={missing[:3]}, unexpected={unexpected[:3]})."
        )


def load_sit_weights(rectified_module, weights_path: str):
    """Load rectified-noise (SiT) weights into *rectified_module* (``strict=False``).

    Directory case: *weights_path* itself is searched first, then its ``sit_weights``
    subdirectory. Within each directory the names in ``_SIT_WEIGHT_FILE_NAMES`` are
    tried in order, and the first existing file is loaded.

    File case: *weights_path* is loaded with safetensors if it ends in
    ``.safetensors``, otherwise with ``torch.load``.

    Returns:
        True if a checkpoint was loaded; False if a directory held no candidate file.
    """
    if not os.path.isdir(weights_path):
        _load_sit_state_file(rectified_module, weights_path)
        return True
    for search_dir in (weights_path, os.path.join(weights_path, "sit_weights")):
        if not os.path.exists(search_dir):
            continue
        for name in _SIT_WEIGHT_FILE_NAMES:
            candidate = os.path.join(search_dir, name)
            if os.path.exists(candidate):
                _load_sit_state_file(rectified_module, candidate)
                return True
    return False


def save_jsonl_line(path, obj):
    """Append *obj* as one JSON line (UTF-8, non-ASCII kept) to *path*."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def load_jsonl(path):
    """Return the parsed non-blank lines of a JSONL file, or ``[]`` if *path* does not exist."""
    if not os.path.exists(path):
        return []
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows


def merge_rank_metadata(out_path, rank_paths):
    """Concatenate per-rank JSONL files, sort rows by ``file_name``, and overwrite *out_path*."""
    rows = []
    for rp in rank_paths:
        rows.extend(load_jsonl(rp))
    rows.sort(key=lambda x: x.get("file_name", ""))
    with open(out_path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")


def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
    """Wrap ``base_pipeline.transformer`` with a rectified-noise module.

    The RN module is sized from the transformer config, falling back to fixed
    defaults when an attribute is absent. Its weights are loaded from
    *rectified_weights*. The wrapped model is moved to *device* and set to eval mode.

    Raises:
        RuntimeError: if no rectified-noise checkpoint could be found.
    """
    RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
    cfg = base_pipeline.transformer.config
    # NOTE(review): the SiT hidden size prefers joint_attention_dim, then inner_dim —
    # presumably mirroring how train_rectified_noise sized the module; confirm.
    if getattr(cfg, "joint_attention_dim", None) is not None:
        sit_hidden_size = cfg.joint_attention_dim
    elif getattr(cfg, "inner_dim", None) is not None:
        sit_hidden_size = cfg.inner_dim
    else:
        sit_hidden_size = 4096
    rectified_module = RectifiedNoiseModule(
        hidden_size=sit_hidden_size,
        num_sit_layers=num_sit_layers,
        num_attention_heads=getattr(cfg, "num_attention_heads", 32),
        input_dim=getattr(cfg, "in_channels", 16),
        transformer_hidden_size=getattr(cfg, "hidden_size", 1536),
    )
    if not load_sit_weights(rectified_module, rectified_weights):
        raise RuntimeError(f"Failed to load rectified weights from: {rectified_weights}")
    model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
    model.eval()
    return model


def create_npz_from_dir(sample_dir, max_samples):
    """Stack up to *max_samples* digit-named PNGs from *sample_dir* into ``<sample_dir>.npz``.

    Files are taken in sorted file-name order and stored under key ``arr_0`` as uint8.

    Returns:
        The ``.npz`` path, or None if no matching PNG was found.
    """
    import numpy as np

    files = sorted(x for x in os.listdir(sample_dir) if x.endswith(".png") and x[:-4].isdigit())
    files = files[:max_samples]
    if not files:
        return None
    arr = []
    for fn in tqdm(files, desc=f"npz:{os.path.basename(sample_dir)}"):
        # Close each file handle promptly; thousands of images may be read here.
        with Image.open(os.path.join(sample_dir, fn)) as im:
            arr.append(np.asarray(im).astype(np.uint8))
    arr = np.stack(arr)
    out = f"{sample_dir}.npz"
    np.savez(out, arr_0=arr)
    return out


def set_pipeline_modules_eval(pipe):
    """Put the pipeline's nn.Module components in eval mode.

    A diffusers pipeline has no ``.eval()`` of its own. Each known component
    attribute that is present and has ``eval`` is switched individually.
    """
    for name in ["transformer", "vae", "text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "model"]:
        module = getattr(pipe, name, None)
        if module is not None and hasattr(module, "eval"):
            module.eval()


def _rank_metadata_path(directory, rank):
    """Path of the per-rank metadata shard inside *directory*."""
    return Path(directory) / f"metadata.rank{rank}.jsonl"


def _resolve_dtype(mixed_precision):
    """Map the ``--mixed_precision`` choice to a torch dtype (``"no"`` -> float32)."""
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)


def _initial_latents(pipe, args, device, dtype, image_seed):
    """Draw the initial latent noise for one image from a generator seeded with *image_seed*.

    Both the LoRA and RN stages call this, so a recorded seed reproduces the same noise.
    """
    g = torch.Generator(device=device).manual_seed(image_seed)
    latent_h = args.height // pipe.vae_scale_factor
    latent_w = args.width // pipe.vae_scale_factor
    return torch.randn(
        (1, pipe.transformer.config.in_channels, latent_h, latent_w),
        device=device,
        dtype=dtype,
        generator=g,
    )


def _sample_image(pipe, prompt, latents, args, dtype):
    """Run *pipe* for one prompt from fixed *latents* and return the first output image."""
    # CUDA autocast only supports reduced precision; keep it off for fp32 runs.
    with torch.autocast("cuda", dtype=dtype, enabled=dtype != torch.float32):
        return pipe(
            prompt=prompt,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            latents=latents,
            num_images_per_prompt=1,
        ).images[0]


def _run_lora_stage(args, rank, world, device, dtype, root):
    """Sample LoRA images and write root/lora metadata, then build ``lora.npz`` on rank 0."""
    lora_dir = root / "lora"
    metadata_path = root / "metadata.jsonl"

    pipe_lora = DiffusersStableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=dtype,
    ).to(device)
    if args.lora_path:
        pipe_lora.load_lora_weights(args.lora_path)
    pipe_lora.set_progress_bar_config(disable=True)
    set_pipeline_modules_eval(pipe_lora)

    captions = load_captions_from_jsonl(args.captions_jsonl)
    total_needed = min(len(captions) * args.images_per_caption, args.max_samples)
    n = args.per_proc_batch_size
    global_batch = n * world
    iters = math.ceil(total_needed / global_batch)

    # Start each rank's shards fresh so re-runs do not append to stale records.
    rank_meta_path = _rank_metadata_path(root, rank)
    rank_lora_meta_path = _rank_metadata_path(lora_dir, rank)
    rank_meta_path.unlink(missing_ok=True)
    rank_lora_meta_path.unlink(missing_ok=True)

    pbar = tqdm(range(iters)) if rank == 0 else range(iters)
    for it in pbar:
        for k in range(n):
            # Round-robin over ranks: each global index is produced by exactly one rank.
            global_idx = it * global_batch + k * world + rank
            if global_idx >= total_needed:
                continue
            prompt = captions[global_idx // args.images_per_caption]
            # Derived from the unique global index so no two samples share initial noise.
            image_seed = args.global_seed * _SEED_STRIDE + global_idx
            latents = _initial_latents(pipe_lora, args, device, dtype, image_seed)
            img_lora = _sample_image(pipe_lora, prompt, latents, args, dtype)

            fn = f"{global_idx:07d}.png"
            img_lora.save(lora_dir / fn)
            save_jsonl_line(
                str(rank_meta_path),
                {"file_name": fn, "caption": prompt, "seed": int(image_seed), "lora_file": f"lora/{fn}"},
            )
            save_jsonl_line(str(rank_lora_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed)})
        dist.barrier()

    dist.barrier()
    if rank == 0:
        merge_rank_metadata(str(metadata_path), [str(_rank_metadata_path(root, r)) for r in range(world)])
        merge_rank_metadata(
            str(lora_dir / "metadata.jsonl"), [str(_rank_metadata_path(lora_dir, r)) for r in range(world)]
        )
        records = load_jsonl(str(metadata_path))
        create_npz_from_dir(str(lora_dir), len(records))


def _run_rn_stage(args, rank, world, device, dtype, root):
    """Replay recorded captions/seeds through the RN pipeline, then build ``rn.npz`` on rank 0.

    Raises:
        RuntimeError: if the root metadata from the LoRA stage is missing or empty.
    """
    rn_dir = root / "rn"
    metadata_path = root / "metadata.jsonl"

    records = load_jsonl(str(metadata_path))
    if not records:
        raise RuntimeError(f"metadata not found or empty: {metadata_path}. Run --stage lora first.")
    total_needed = min(len(records), args.max_samples)

    LocalStableDiffusion3Pipeline = load_local_pipeline_class(args.local_pipeline_path)
    pipe_rn = LocalStableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=dtype,
    ).to(device)
    if args.lora_path:
        pipe_rn.load_lora_weights(args.lora_path)
    pipe_rn.model = build_rn_model(pipe_rn, args.rectified_weights, args.num_sit_layers, device)
    pipe_rn.set_progress_bar_config(disable=True)
    set_pipeline_modules_eval(pipe_rn)

    rank_rn_meta_path = _rank_metadata_path(rn_dir, rank)
    rank_rn_meta_path.unlink(missing_ok=True)

    # Record i goes to rank i % world.
    assigned = records[:total_needed][rank::world]
    pbar = tqdm(assigned) if rank == 0 else assigned
    for rec in pbar:
        fn = rec["file_name"]
        prompt = rec["caption"]
        image_seed = int(rec["seed"])
        latents = _initial_latents(pipe_rn, args, device, dtype, image_seed)
        img_rn = _sample_image(pipe_rn, prompt, latents, args, dtype)
        img_rn.save(rn_dir / fn)
        save_jsonl_line(str(rank_rn_meta_path), {"file_name": fn, "caption": prompt, "seed": image_seed})

    dist.barrier()
    if rank == 0:
        merge_rank_metadata(str(rn_dir / "metadata.jsonl"), [str(_rank_metadata_path(rn_dir, r)) for r in range(world)])
        create_npz_from_dir(str(rn_dir), total_needed)


def _run_pair_stage(args, rank, world, root):
    """Write LoRA|RN side-by-side images and annotate the root metadata on rank 0.

    Raises:
        RuntimeError: if the root metadata is missing or empty.
    """
    lora_dir = root / "lora"
    rn_dir = root / "rn"
    pair_dir = root / "pair"
    metadata_path = root / "metadata.jsonl"
    pair_meta = pair_dir / "metadata.jsonl"

    records = load_jsonl(str(metadata_path))
    if not records:
        raise RuntimeError(f"metadata not found: {metadata_path}")
    total_needed = min(len(records), args.max_samples)

    rank_pair_meta_path = _rank_metadata_path(pair_dir, rank)
    rank_pair_meta_path.unlink(missing_ok=True)

    for rec in records[:total_needed][rank::world]:
        fn = rec["file_name"]
        lora_img_path = lora_dir / fn
        rn_img_path = rn_dir / fn
        if not lora_img_path.exists() or not rn_img_path.exists():
            continue
        # convert() returns an independent copy, so the source files can be closed here.
        with Image.open(lora_img_path) as f_lora, Image.open(rn_img_path) as f_rn:
            img_lora = f_lora.convert("RGB")
            img_rn = f_rn.convert("RGB")
        pair = Image.new("RGB", (img_lora.width + img_rn.width, max(img_lora.height, img_rn.height)))
        pair.paste(img_lora, (0, 0))
        pair.paste(img_rn, (img_lora.width, 0))
        pair.save(pair_dir / fn)
        save_jsonl_line(
            str(rank_pair_meta_path),
            {"file_name": fn, "caption": rec["caption"], "seed": int(rec["seed"]), "pair_file": f"pair/{fn}"},
        )

    dist.barrier()
    if rank == 0:
        merge_rank_metadata(str(pair_meta), [str(_rank_metadata_path(pair_dir, r)) for r in range(world)])
        # Update the root metadata, adding rn/pair paths. Iterate over ALL records so a
        # smaller --max_samples does not drop entries from metadata.jsonl.
        rn_set = {r["file_name"] for r in load_jsonl(str(rn_dir / "metadata.jsonl"))}
        pair_set = {r["file_name"] for r in load_jsonl(str(pair_meta))}
        merged = []
        for r in records:
            fn = r["file_name"]
            out = dict(r)
            if fn in rn_set:
                out["rn_file"] = f"rn/{fn}"
            if fn in pair_set:
                out["pair_file"] = f"pair/{fn}"
            merged.append(out)
        with open(metadata_path, "w", encoding="utf-8") as f:
            for r in merged:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")


def main(args):
    """Initialise NCCL DDP, run the stage selected by ``args.stage``, then tear down.

    Raises:
        RuntimeError: if CUDA is unavailable or a stage's prerequisites are missing.
        ValueError: for an unknown stage name.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("Need GPU")
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()
    # Prefer the launcher's LOCAL_RANK (correct across nodes); modulo keeps it valid
    # when each process sees fewer GPUs (e.g. one GPU via CUDA_VISIBLE_DEVICES).
    local_index = int(os.environ.get("LOCAL_RANK", rank)) % torch.cuda.device_count()
    torch.cuda.set_device(local_index)
    device = torch.device("cuda", local_index)
    seed = args.global_seed * world + rank
    torch.manual_seed(seed)
    dtype = _resolve_dtype(args.mixed_precision)

    root = Path(args.sample_dir)
    if rank == 0:
        for sub in ("lora", "rn", "pair"):
            (root / sub).mkdir(parents=True, exist_ok=True)
    dist.barrier()

    if args.stage == "lora":
        _run_lora_stage(args, rank, world, device, dtype, root)
    elif args.stage == "rn":
        _run_rn_stage(args, rank, world, device, dtype, root)
    elif args.stage == "pair":
        _run_pair_stage(args, rank, world, root)
    else:
        raise ValueError(f"Unknown stage: {args.stage}")

    dist.barrier()
    if rank == 0:
        print(f"Stage {args.stage} done. Output root: {root}")
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DDP compare sampling: LoRA vs RN with same latent/prompt.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument(
        "--local_pipeline_path",
        type=str,
        default=str(Path(__file__).parent / "pipeline_stable_diffusion_3.py"),
        help="RN 分支使用的本地 pipeline 文件路径",
    )
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--rectified_weights", type=str, required=True)
    parser.add_argument("--num_sit_layers", type=int, default=1)
    parser.add_argument("--captions_jsonl", type=str, required=True)
    parser.add_argument("--sample_dir", type=str, default="./sd3_lora_rn_compare")
    parser.add_argument("--num_inference_steps", type=int, default=40)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--per_proc_batch_size", type=int, default=4)
    parser.add_argument("--images_per_caption", type=int, default=1)
    parser.add_argument("--max_samples", type=int, default=10000)
    parser.add_argument("--global_seed", type=int, default=42)
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    parser.add_argument("--stage", type=str, default="lora", choices=["lora", "rn", "pair"])
    args = parser.parse_args()
    main(args)