"""
DDP对照采样:同一文本+同一初始噪声,分别生成 LoRA 与 RN 两类图像,并输出 pair 拼接图与 metadata。

DDP paired sampling: for the same prompt and the same initial noise, generate
one LoRA image and one RN image, then emit side-by-side pair images plus
metadata.jsonl records linking them.
"""
|
|
| import argparse |
| import importlib.util |
| import json |
| import math |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline |
|
|
|
|
def dynamic_import_training_classes(project_root: str):
    """Import the RN training classes from train_rectified_noise.py.

    Args:
        project_root: Directory containing train_rectified_noise.py.

    Returns:
        Tuple (RectifiedNoiseModule, SD3WithRectifiedNoise) from that module.
    """
    # Prepend only once: this helper is called per build_rn_model invocation,
    # and unconditional insert(0, ...) would keep growing sys.path.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    import train_rectified_noise as trn

    return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
|
|
|
|
def load_local_pipeline_class(local_pipeline_path: str):
    """Load StableDiffusion3Pipeline from a local source file.

    The module is given a name under diffusers.pipelines.stable_diffusion_3 so
    that relative imports inside the file resolve against that package.

    Args:
        local_pipeline_path: Path to the local pipeline .py file.

    Returns:
        The StableDiffusion3Pipeline class object defined in that file.

    Raises:
        ImportError: If the spec cannot be built or the file does not define
            a StableDiffusion3Pipeline symbol.
    """
    module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
    spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Failed to build import spec from: {local_pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    # Register the module BEFORE executing it, per the importlib recipe for
    # importing a source file directly: code in the file that imports itself
    # (directly or via dataclass/pickle machinery) needs the sys.modules entry.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    if not hasattr(module, "StableDiffusion3Pipeline"):
        raise ImportError("Local pipeline file has no StableDiffusion3Pipeline symbol.")
    return module.StableDiffusion3Pipeline
|
|
|
|
def load_captions_from_jsonl(jsonl_path):
    """Read captions from a JSONL file (one JSON object per line).

    For each line, the first of the fields caption/text/prompt/description
    that holds a string is used (stripped; empty results are dropped).
    Malformed or non-object lines are skipped rather than aborting the run.

    Args:
        jsonl_path: Path to the captions JSONL file.

    Returns:
        List of caption strings; a one-element generic fallback if none found.
    """
    caption_fields = ("caption", "text", "prompt", "description")
    captions = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Best-effort: skip lines that are not valid JSON.
                continue
            if not isinstance(data, dict):
                continue
            cap = None
            for field in caption_fields:
                if isinstance(data.get(field), str):
                    cap = data[field].strip()
                    break
            if cap:
                captions.append(cap)
    return captions if captions else ["a beautiful high quality image"]
|
|
|
|
def load_sit_weights(rectified_module, weights_path: str):
    """Load SiT / rectified-noise weights into `rectified_module`.

    `weights_path` may be either a direct weights file (.safetensors or a
    torch-loadable file) or a checkpoint directory, in which case the
    directory itself and its `sit_weights` subdirectory are searched for
    known file names (safetensors preferred).

    Returns:
        True when a state dict was loaded (non-strict), False when a
        directory search found nothing.
    """

    def _apply(state_dict):
        # Non-strict load: tolerate missing/unexpected keys across versions.
        rectified_module.load_state_dict(state_dict, strict=False)
        return True

    if not os.path.isdir(weights_path):
        # Direct file path.
        if weights_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            return _apply(load_file(weights_path))
        return _apply(torch.load(weights_path, map_location="cpu"))

    fallback_names = ["pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"]
    for directory in (weights_path, os.path.join(weights_path, "sit_weights")):
        if not os.path.exists(directory):
            continue
        st_path = os.path.join(directory, "pytorch_sit_weights.safetensors")
        if os.path.exists(st_path):
            from safetensors.torch import load_file

            return _apply(load_file(st_path))
        for name in fallback_names:
            candidate = os.path.join(directory, name)
            if os.path.exists(candidate):
                return _apply(torch.load(candidate, map_location="cpu"))
    return False
|
|
|
|
def save_jsonl_line(path, obj):
    """Append `obj` as a single JSON line to `path` (UTF-8, non-ASCII kept)."""
    serialized = json.dumps(obj, ensure_ascii=False)
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(serialized + "\n")
|
|
|
|
def load_jsonl(path):
    """Read a JSONL file into a list of parsed objects.

    Blank lines are ignored; a missing file yields an empty list.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as fh:
        stripped = (raw.strip() for raw in fh)
        return [json.loads(s) for s in stripped if s]
|
|
|
|
def merge_rank_metadata(out_path, rank_paths):
    """Concatenate per-rank JSONL metadata files into one sorted file.

    Rows are sorted by their "file_name" field (missing key sorts first)
    and written to `out_path`, overwriting any existing file.
    """
    rows = [row for rank_path in rank_paths for row in load_jsonl(rank_path)]
    rows.sort(key=lambda row: row.get("file_name", ""))
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
|
|
|
|
def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
    """Build the SD3WithRectifiedNoise wrapper around the pipeline transformer.

    Infers module dimensions from the transformer config (with fallbacks),
    loads the rectified-noise (SiT) weights, and returns the wrapper in eval
    mode on `device`.

    Raises:
        RuntimeError: If no rectified weights could be loaded.
    """
    RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
    cfg = base_pipeline.transformer.config

    # SiT hidden size preference: joint_attention_dim, then inner_dim,
    # then a hard-coded default.
    sit_hidden_size = getattr(cfg, "joint_attention_dim", None)
    if sit_hidden_size is None:
        sit_hidden_size = getattr(cfg, "inner_dim", None)
    if sit_hidden_size is None:
        sit_hidden_size = 4096

    rectified_module = RectifiedNoiseModule(
        hidden_size=sit_hidden_size,
        num_sit_layers=num_sit_layers,
        num_attention_heads=getattr(cfg, "num_attention_heads", 32),
        input_dim=getattr(cfg, "in_channels", 16),
        transformer_hidden_size=getattr(cfg, "hidden_size", 1536),
    )
    if not load_sit_weights(rectified_module, rectified_weights):
        raise RuntimeError(f"Failed to load rectified weights from: {rectified_weights}")

    model = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
    model.eval()
    return model
|
|
|
|
def create_npz_from_dir(sample_dir, max_samples):
    """Pack numerically-named PNGs from `sample_dir` into <sample_dir>.npz.

    Only files named like 0000123.png (digits + .png) are included, sorted
    by name and capped at `max_samples`. Images are stored under key arr_0
    as uint8 arrays.

    Returns:
        The .npz path, or None when the directory holds no matching images.
    """
    import numpy as np

    names = [f for f in os.listdir(sample_dir) if f.endswith(".png") and f[:-4].isdigit()]
    names.sort()
    del names[max_samples:]
    if not names:
        return None
    images = np.stack(
        [
            np.asarray(Image.open(os.path.join(sample_dir, name))).astype(np.uint8)
            for name in tqdm(names, desc=f"npz:{os.path.basename(sample_dir)}")
        ]
    )
    out_path = f"{sample_dir}.npz"
    np.savez(out_path, arr_0=images)
    return out_path
|
|
|
|
def set_pipeline_modules_eval(pipe):
    """Switch every known sub-module of the pipeline into eval mode.

    Diffusers pipelines expose no .eval() themselves, so each candidate
    attribute is handled individually; attributes that are absent, None, or
    lack an eval() method are skipped.
    """
    candidate_attrs = (
        "transformer",
        "vae",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "image_encoder",
        "model",
    )
    for attr in candidate_attrs:
        sub_module = getattr(pipe, attr, None)
        if sub_module is not None and hasattr(sub_module, "eval"):
            sub_module.eval()
|
|
|
|
def main(args):
    """DDP entry point running one of three stages (selected by --stage).

    "lora": sample images with the base (optionally LoRA-patched) SD3
            pipeline, recording file name / caption / per-image seed.
    "rn":   regenerate every recorded sample with the local RN pipeline,
            reusing the stored seed so the initial latent is identical.
    "pair": paste matching lora/rn images side by side and merge metadata.
    """
    assert torch.cuda.is_available(), "Need GPU"
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()
    # One process per visible GPU; map rank to a local device index.
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    # Per-rank base seed so RNG streams differ across processes.
    seed = args.global_seed * world + rank
    torch.manual_seed(seed)

    dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32)

    # Output layout: <sample_dir>/{lora,rn,pair}/ plus per-stage metadata.
    root = Path(args.sample_dir)
    lora_dir = root / "lora"
    rn_dir = root / "rn"
    pair_dir = root / "pair"
    metadata_path = root / "metadata.jsonl"
    lora_meta = lora_dir / "metadata.jsonl"
    rn_meta = rn_dir / "metadata.jsonl"
    pair_meta = pair_dir / "metadata.jsonl"
    if rank == 0:
        lora_dir.mkdir(parents=True, exist_ok=True)
        rn_dir.mkdir(parents=True, exist_ok=True)
        pair_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()  # wait until rank 0 has created the directories

    if args.stage == "lora":
        pipe_lora = DiffusersStableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=dtype,
        ).to(device)
        if args.lora_path:
            pipe_lora.load_lora_weights(args.lora_path)
        pipe_lora.set_progress_bar_config(disable=True)
        set_pipeline_modules_eval(pipe_lora)

        captions = load_captions_from_jsonl(args.captions_jsonl)
        total_needed = min(len(captions) * args.images_per_caption, args.max_samples)
        n = args.per_proc_batch_size
        global_batch = n * world
        # Round up so every rank runs the same number of iterations;
        # indices beyond total_needed are skipped inside the loop.
        total_samples = int(math.ceil(total_needed / global_batch) * global_batch)
        iters = total_samples // global_batch
        pbar = tqdm(range(iters)) if rank == 0 else range(iters)

        # Fresh per-rank metadata files (run may be a retry of a failed job).
        rank_meta_path = root / f"metadata.rank{rank}.jsonl"
        if rank_meta_path.exists():
            rank_meta_path.unlink()
        rank_lora_meta_path = lora_dir / f"metadata.rank{rank}.jsonl"
        if rank_lora_meta_path.exists():
            rank_lora_meta_path.unlink()

        for it in pbar:
            for k in range(n):
                # Interleaved global index: consecutive samples go to
                # different ranks.
                global_idx = it * global_batch + k * world + rank
                if global_idx >= total_needed:
                    continue
                cap_idx = global_idx // args.images_per_caption
                prompt = captions[cap_idx]
                # Deterministic per-image seed, stored in metadata so the
                # "rn" stage can rebuild the exact same initial latent.
                image_seed = seed + it * 10000 + k * 1000

                g = torch.Generator(device=device).manual_seed(image_seed)
                latent_h = args.height // pipe_lora.vae_scale_factor
                latent_w = args.width // pipe_lora.vae_scale_factor
                latents = torch.randn(
                    (1, pipe_lora.transformer.config.in_channels, latent_h, latent_w),
                    device=device,
                    dtype=dtype,
                    generator=g,
                )

                with torch.autocast("cuda", dtype=dtype):
                    img_lora = pipe_lora(
                        prompt=prompt,
                        height=args.height,
                        width=args.width,
                        num_inference_steps=args.num_inference_steps,
                        guidance_scale=args.guidance_scale,
                        latents=latents,
                        num_images_per_prompt=1,
                    ).images[0]

                fn = f"{global_idx:07d}.png"
                img_lora.save(lora_dir / fn)
                save_jsonl_line(str(rank_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed), "lora_file": f"lora/{fn}"})
                save_jsonl_line(str(rank_lora_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed)})
            dist.barrier()  # keep ranks in lockstep after each global batch

        dist.barrier()
        if rank == 0:
            # Merge per-rank metadata and bundle the images into an .npz.
            merge_rank_metadata(str(metadata_path), [str(root / f"metadata.rank{r}.jsonl") for r in range(world)])
            merge_rank_metadata(str(lora_meta), [str(lora_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
            records = load_jsonl(str(metadata_path))
            create_npz_from_dir(str(lora_dir), len(records))

    elif args.stage == "rn":
        records = load_jsonl(str(metadata_path))
        if not records:
            raise RuntimeError(f"metadata not found or empty: {metadata_path}. Run --stage lora first.")
        total_needed = min(len(records), args.max_samples)

        LocalStableDiffusion3Pipeline = load_local_pipeline_class(args.local_pipeline_path)
        pipe_rn = LocalStableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=dtype,
        ).to(device)
        if args.lora_path:
            pipe_rn.load_lora_weights(args.lora_path)
        # NOTE(review): the local pipeline appears to read the RN wrapper from
        # its .model attribute — confirm against the local pipeline file.
        pipe_rn.model = build_rn_model(pipe_rn, args.rectified_weights, args.num_sit_layers, device)
        pipe_rn.set_progress_bar_config(disable=True)
        set_pipeline_modules_eval(pipe_rn)

        rank_rn_meta_path = rn_dir / f"metadata.rank{rank}.jsonl"
        if rank_rn_meta_path.exists():
            rank_rn_meta_path.unlink()

        # Static round-robin shard of the metadata records across ranks.
        assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
        pbar = tqdm(assigned) if rank == 0 else assigned
        for rec in pbar:
            fn = rec["file_name"]
            prompt = rec["caption"]
            image_seed = int(rec["seed"])
            # Same seed as the lora stage -> identical initial latent.
            g = torch.Generator(device=device).manual_seed(image_seed)
            latent_h = args.height // pipe_rn.vae_scale_factor
            latent_w = args.width // pipe_rn.vae_scale_factor
            latents = torch.randn(
                (1, pipe_rn.transformer.config.in_channels, latent_h, latent_w),
                device=device,
                dtype=dtype,
                generator=g,
            )
            with torch.autocast("cuda", dtype=dtype):
                img_rn = pipe_rn(
                    prompt=prompt,
                    height=args.height,
                    width=args.width,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    latents=latents,
                    num_images_per_prompt=1,
                ).images[0]
            img_rn.save(rn_dir / fn)
            save_jsonl_line(str(rank_rn_meta_path), {"file_name": fn, "caption": prompt, "seed": image_seed})
        dist.barrier()
        if rank == 0:
            merge_rank_metadata(str(rn_meta), [str(rn_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
            create_npz_from_dir(str(rn_dir), total_needed)

    elif args.stage == "pair":
        records = load_jsonl(str(metadata_path))
        if not records:
            raise RuntimeError(f"metadata not found: {metadata_path}")
        total_needed = min(len(records), args.max_samples)
        rank_pair_meta_path = pair_dir / f"metadata.rank{rank}.jsonl"
        if rank_pair_meta_path.exists():
            rank_pair_meta_path.unlink()
        assigned = [r for i, r in enumerate(records[:total_needed]) if i % world == rank]
        for rec in assigned:
            fn = rec["file_name"]
            lora_img_path = lora_dir / fn
            rn_img_path = rn_dir / fn
            # Best-effort join: skip pairs where either side is missing.
            if not lora_img_path.exists() or not rn_img_path.exists():
                continue
            img_lora = Image.open(lora_img_path).convert("RGB")
            img_rn = Image.open(rn_img_path).convert("RGB")
            # Side-by-side canvas: LoRA image left, RN image right.
            pair = Image.new("RGB", (img_lora.width + img_rn.width, max(img_lora.height, img_rn.height)))
            pair.paste(img_lora, (0, 0))
            pair.paste(img_rn, (img_lora.width, 0))
            pair.save(pair_dir / fn)
            save_jsonl_line(
                str(rank_pair_meta_path),
                {"file_name": fn, "caption": rec["caption"], "seed": int(rec["seed"]), "pair_file": f"pair/{fn}"},
            )
        dist.barrier()
        if rank == 0:
            merge_rank_metadata(str(pair_meta), [str(pair_dir / f"metadata.rank{r}.jsonl") for r in range(world)])

            # Rewrite the top-level metadata with rn/pair file references
            # for every record that actually produced those outputs.
            rn_set = {r["file_name"] for r in load_jsonl(str(rn_meta))}
            pair_set = {r["file_name"] for r in load_jsonl(str(pair_meta))}
            merged = []
            for r in records[:total_needed]:
                fn = r["file_name"]
                out = dict(r)
                if fn in rn_set:
                    out["rn_file"] = f"rn/{fn}"
                if fn in pair_set:
                    out["pair_file"] = f"pair/{fn}"
                merged.append(out)
            with open(metadata_path, "w", encoding="utf-8") as f:
                for r in merged:
                    f.write(json.dumps(r, ensure_ascii=False) + "\n")

    else:
        raise ValueError(f"Unknown stage: {args.stage}")

    dist.barrier()
    if rank == 0:
        print(f"Stage {args.stage} done. Output root: {root}")
    dist.barrier()
    dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # CLI for the DDP compare-sampling script; launch under torchrun with one
    # process per GPU, running one --stage at a time (lora -> rn -> pair).
    parser = argparse.ArgumentParser(description="DDP compare sampling: LoRA vs RN with same latent/prompt.")
    # Model sources and weights
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument(
        "--local_pipeline_path",
        type=str,
        default=str(Path(__file__).parent / "pipeline_stable_diffusion_3.py"),
        help="RN 分支使用的本地 pipeline 文件路径",
    )
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--rectified_weights", type=str, required=True)
    parser.add_argument("--num_sit_layers", type=int, default=1)
    # Inputs / outputs
    parser.add_argument("--captions_jsonl", type=str, required=True)
    parser.add_argument("--sample_dir", type=str, default="./sd3_lora_rn_compare")
    # Sampling hyper-parameters
    parser.add_argument("--num_inference_steps", type=int, default=40)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--per_proc_batch_size", type=int, default=4)
    parser.add_argument("--images_per_caption", type=int, default=1)
    parser.add_argument("--max_samples", type=int, default=10000)
    parser.add_argument("--global_seed", type=int, default=42)
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    parser.add_argument("--stage", type=str, default="lora", choices=["lora", "rn", "pair"])

    args = parser.parse_args()
    main(args)
|
|
|
|