sida / sample_sd3_lora_rn_pair_ddp.py
xiangzai's picture
Add files using upload-large-folder tool
7803bdf verified
#!/usr/bin/env python
# coding=utf-8
"""
DDP对照采样:同一文本+同一初始噪声,分别生成 LoRA 与 RN 两类图像,并输出 pair 拼接图与 metadata。
"""
import argparse
import importlib.util
import json
import math
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist
from PIL import Image
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline as DiffusersStableDiffusion3Pipeline
def dynamic_import_training_classes(project_root: str):
    """Import the rectified-noise classes from ``train_rectified_noise``.

    ``project_root`` is pushed to the front of ``sys.path`` so the training
    script can be imported as a top-level module.

    Returns:
        tuple: ``(RectifiedNoiseModule, SD3WithRectifiedNoise)``.
    """
    sys.path.insert(0, project_root)
    training_module = importlib.import_module("train_rectified_noise")
    return training_module.RectifiedNoiseModule, training_module.SD3WithRectifiedNoise
def load_local_pipeline_class(local_pipeline_path: str):
    """Load ``StableDiffusion3Pipeline`` from a local source file.

    The module is named inside ``diffusers.pipelines.stable_diffusion_3`` so that
    relative imports in the file (``from .xxx import ...``) resolve against that
    package. Following the standard importlib recipe, it is also registered in
    ``sys.modules`` *before* execution, so code that looks its own module up via
    ``sys.modules[cls.__module__]`` (e.g. ``dataclasses`` with string annotations,
    pickling) works. On any failure the registration is removed again.

    Args:
        local_pipeline_path: path of the pipeline source file to execute.

    Returns:
        The ``StableDiffusion3Pipeline`` class defined in that file.

    Raises:
        ImportError: if no import spec can be built, or the file defines no
            ``StableDiffusion3Pipeline``. Errors raised while executing the file
            propagate unchanged.
    """
    module_name = "diffusers.pipelines.stable_diffusion_3.local_pipeline_stable_diffusion_3"
    spec = importlib.util.spec_from_file_location(module_name, local_pipeline_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Failed to build import spec from: {local_pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before exec_module, as importlib itself does for normal imports.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
        if not hasattr(module, "StableDiffusion3Pipeline"):
            raise ImportError("Local pipeline file has no StableDiffusion3Pipeline symbol.")
    except BaseException:
        # Do not leave a half-initialised module registered (mirrors importlib).
        sys.modules.pop(module_name, None)
        raise
    return module.StableDiffusion3Pipeline
def load_captions_from_jsonl(jsonl_path):
    """Read one caption per line of a JSONL file.

    For each line that parses as a JSON object, the first of the fields
    ``caption``, ``text``, ``prompt``, ``description`` holding a non-empty string
    (after stripping) is used. Blank lines, malformed JSON, non-object rows and
    rows without a usable field are skipped (best-effort parsing).

    Args:
        jsonl_path: path of the JSONL file.

    Returns:
        list[str]: captions in file order, or ``["a beautiful high quality image"]``
        when no caption could be extracted.
    """
    captions = []
    # utf-8-sig transparently drops a leading BOM; with plain utf-8, json.loads
    # rejects the BOM-prefixed first line and its caption is silently lost.
    with open(jsonl_path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip malformed rows instead of aborting the run
            if not isinstance(data, dict):
                continue
            for field in ("caption", "text", "prompt", "description"):
                value = data.get(field)
                # An empty/whitespace-only field falls through to the next candidate.
                if isinstance(value, str) and value.strip():
                    captions.append(value.strip())
                    break
    return captions if captions else ["a beautiful high quality image"]
def _read_state_file(path):
    """Load a state dict from ``path`` onto the CPU.

    ``.safetensors`` files go through ``safetensors.torch.load_file``; any other
    file through ``torch.load(map_location="cpu")``.
    """
    if str(path).endswith(".safetensors"):
        from safetensors.torch import load_file
        return load_file(path)
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    return torch.load(path, map_location="cpu")


def _apply_state(rectified_module, state):
    """Load ``state`` into ``rectified_module`` with ``strict=False``.

    Returns False when not a single checkpoint key was accepted by the module
    (including an empty state dict): with ``strict=False`` that would otherwise
    "succeed" and leave the module at its random initialisation. Warns when some
    module keys are missing from the checkpoint. Returns True otherwise.
    """
    result = rectified_module.load_state_dict(state, strict=False)
    unexpected = set(result.unexpected_keys)
    if all(key in unexpected for key in state):
        return False
    if result.missing_keys:
        import warnings
        warnings.warn(
            f"{len(result.missing_keys)} rectified-module keys were not found in the "
            "checkpoint and keep their initial values."
        )
    return True


def load_sit_weights(rectified_module, weights_path: str):
    """Load rectified-noise (SiT) weights into ``rectified_module``.

    ``weights_path`` is either a checkpoint file or a directory. For a directory,
    the directory itself and then its ``sit_weights/`` subdirectory are searched,
    each for (in order) ``pytorch_sit_weights.safetensors``,
    ``pytorch_sit_weights.bin``, ``pytorch_sit_weights.pt``, ``sit_weights.pt``,
    ``sit.pt``; the first file found is loaded.

    Args:
        rectified_module: the ``nn.Module`` to fill.
        weights_path: checkpoint file or directory containing one.

    Returns:
        bool: True if a checkpoint was loaded and at least one of its keys matched
        the module; False if no candidate file exists in a directory, or none of the
        checkpoint keys matched.
    """
    if not os.path.isdir(weights_path):
        return _apply_state(rectified_module, _read_state_file(weights_path))
    candidate_names = (
        "pytorch_sit_weights.safetensors",
        "pytorch_sit_weights.bin",
        "pytorch_sit_weights.pt",
        "sit_weights.pt",
        "sit.pt",
    )
    for search_dir in (weights_path, os.path.join(weights_path, "sit_weights")):
        if not os.path.isdir(search_dir):
            continue
        for name in candidate_names:
            cand = os.path.join(search_dir, name)
            if os.path.exists(cand):
                return _apply_state(rectified_module, _read_state_file(cand))
    return False
def save_jsonl_line(path, obj):
    """Append ``obj`` to ``path`` as one UTF-8 JSON line (non-ASCII kept as-is)."""
    serialized = json.dumps(obj, ensure_ascii=False)
    with open(path, mode="a", encoding="utf-8") as handle:
        handle.write(f"{serialized}\n")
def load_jsonl(path):
    """Parse a JSONL file into a list of objects.

    Returns ``[]`` when ``path`` does not exist. Whitespace-only lines are
    skipped; every other line must be valid JSON (errors propagate).
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
def merge_rank_metadata(out_path, rank_paths):
    """Concatenate per-rank JSONL shards into ``out_path``, sorted by ``file_name``.

    Shards that do not exist contribute no rows (``load_jsonl`` returns ``[]``).
    The sort is stable, so rows with equal ``file_name`` keep shard order; rows
    without ``file_name`` sort first. ``out_path`` is overwritten.
    """
    merged = [row for shard in rank_paths for row in load_jsonl(shard)]
    merged = sorted(merged, key=lambda row: row.get("file_name", ""))
    with open(out_path, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in merged)
def build_rn_model(base_pipeline, rectified_weights, num_sit_layers, device):
    """Wrap the pipeline's transformer with a rectified-noise (SiT) module.

    Sizes are taken from ``base_pipeline.transformer.config`` with fallbacks:
    the SiT hidden size is ``joint_attention_dim``, else ``inner_dim``, else 4096;
    ``hidden_size`` defaults to 1536, ``num_attention_heads`` to 32 and
    ``in_channels`` to 16. Weights are loaded with ``load_sit_weights``.

    Returns:
        ``SD3WithRectifiedNoise`` moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: if ``load_sit_weights`` reports failure.
    """
    RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent))
    config = base_pipeline.transformer.config

    # First non-None candidate wins; 4096 if neither attribute is usable.
    sit_hidden_size = 4096
    for attr_name in ("joint_attention_dim", "inner_dim"):
        candidate = getattr(config, attr_name, None)
        if candidate is not None:
            sit_hidden_size = candidate
            break

    rectified_module = RectifiedNoiseModule(
        hidden_size=sit_hidden_size,
        num_sit_layers=num_sit_layers,
        num_attention_heads=getattr(config, "num_attention_heads", 32),
        input_dim=getattr(config, "in_channels", 16),
        transformer_hidden_size=getattr(config, "hidden_size", 1536),
    )
    if not load_sit_weights(rectified_module, rectified_weights):
        raise RuntimeError(f"Failed to load rectified weights from: {rectified_weights}")

    wrapped = SD3WithRectifiedNoise(base_pipeline.transformer, rectified_module).to(device)
    wrapped.eval()
    return wrapped
def create_npz_from_dir(sample_dir, max_samples):
    """Stack up to ``max_samples`` numbered PNGs into ``<sample_dir>.npz``.

    Only files named ``<digits>.png`` are used, in sorted filename order (the
    zero-padded names this script writes therefore sort numerically). Each image
    is converted to RGB, so the stacked uint8 array is ``(N, H, W, 3)`` provided
    all images share H and W. The array is stored under key ``arr_0``.

    Args:
        sample_dir: directory containing the PNG samples.
        max_samples: maximum number of images to include.

    Returns:
        str | None: path of the written ``.npz``, or None if no matching file exists.
    """
    import numpy as np

    files = sorted(fn for fn in os.listdir(sample_dir) if fn.endswith(".png") and fn[:-4].isdigit())
    files = files[:max_samples]
    if not files:
        return None
    arrays = []
    for fn in tqdm(files, desc=f"npz:{os.path.basename(sample_dir)}"):
        # `with` closes the file handle deterministically; RGB conversion keeps
        # the channel layout uniform so np.stack cannot fail on mixed modes.
        with Image.open(os.path.join(sample_dir, fn)) as img:
            arrays.append(np.asarray(img.convert("RGB"), dtype=np.uint8))
    out = f"{sample_dir}.npz"
    np.savez(out, arr_0=np.stack(arrays))
    return out
def set_pipeline_modules_eval(pipe):
    """Switch every known sub-module of ``pipe`` to eval mode.

    A diffusers pipeline has no ``.eval()`` of its own, so each component
    attribute that exists, is not None and exposes ``eval`` is handled
    individually. ``model`` covers the rectified-noise wrapper this script
    attaches to the RN pipeline.
    """
    component_names = (
        "transformer",
        "vae",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "image_encoder",
        "model",
    )
    components = (getattr(pipe, attr, None) for attr in component_names)
    for component in components:
        if component is None or not hasattr(component, "eval"):
            continue
        component.eval()
def _initial_latents(pipe, image_seed, height, width, device, dtype):
    """Draw the seeded initial latent noise for one image.

    Both the lora and rn stages build latents through this helper from the same
    per-image seed, so (with both pipelines loaded from the same pretrained model)
    they start from identical noise of shape
    ``(1, transformer.config.in_channels, height // vae_scale_factor, width // vae_scale_factor)``.
    """
    generator = torch.Generator(device=device).manual_seed(image_seed)
    latent_h = height // pipe.vae_scale_factor
    latent_w = width // pipe.vae_scale_factor
    return torch.randn(
        (1, pipe.transformer.config.in_channels, latent_h, latent_w),
        device=device,
        dtype=dtype,
        generator=generator,
    )


def _generate_image(pipe, prompt, latents, args, dtype):
    """Run one pipeline call for ``prompt`` from fixed ``latents``; return the first output image."""
    # CUDA autocast only supports fp16/bf16; for fp32 torch would warn and disable
    # it anyway, so disable it explicitly.
    use_autocast = dtype in (torch.float16, torch.bfloat16)
    with torch.autocast("cuda", dtype=dtype, enabled=use_autocast):
        output = pipe(
            prompt=prompt,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            latents=latents,
            num_images_per_prompt=1,
        )
    return output.images[0]


def _annotate_root_metadata(metadata_path, records, rn_meta, pair_meta):
    """Rewrite the root metadata from ``records``, adding ``rn_file``/``pair_file``.

    A record gets ``rn_file`` / ``pair_file`` only if its ``file_name`` appears in
    the merged rn / pair metadata. All rows are built before the file is opened,
    so a failure while building leaves the existing file untouched.
    """
    rn_names = {r["file_name"] for r in load_jsonl(str(rn_meta))}
    pair_names = {r["file_name"] for r in load_jsonl(str(pair_meta))}
    merged = []
    for rec in records:
        fn = rec["file_name"]
        out = dict(rec)
        if fn in rn_names:
            out["rn_file"] = f"rn/{fn}"
        if fn in pair_names:
            out["pair_file"] = f"pair/{fn}"
        merged.append(out)
    with open(metadata_path, "w", encoding="utf-8") as f:
        for row in merged:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def main(args):
    """Run one stage of the paired LoRA-vs-RN sampling job under torch.distributed (NCCL).

    Stages (``args.stage``), meant to run in this order on the same ``sample_dir``:

    * ``lora``: sample with the (optionally LoRA-patched) diffusers SD3 pipeline,
      write ``lora/<idx>.png`` plus root ``metadata.jsonl`` with prompt and seed,
      then (rank 0) ``lora.npz``.
    * ``rn``: re-read the root metadata and regenerate each entry with the local
      rectified-noise pipeline from the same prompt and the same seeded latents,
      then (rank 0) ``rn.npz``.
    * ``pair``: paste ``lora/`` and ``rn/`` images side by side into ``pair/`` and
      annotate the root metadata with ``rn_file``/``pair_file``.

    Each rank writes its own ``metadata.rank{r}.jsonl`` shards; rank 0 merges them
    after a barrier. Every rank executes the same sequence of ``dist.barrier()`` calls.

    Raises:
        RuntimeError: no CUDA device, or root metadata missing for ``rn``/``pair``.
        ValueError: unknown stage.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("Need GPU")
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    seed = args.global_seed * world + rank
    torch.manual_seed(seed)
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(args.mixed_precision, torch.float32)

    root = Path(args.sample_dir)
    lora_dir = root / "lora"
    rn_dir = root / "rn"
    pair_dir = root / "pair"
    metadata_path = root / "metadata.jsonl"
    lora_meta = lora_dir / "metadata.jsonl"
    rn_meta = rn_dir / "metadata.jsonl"
    pair_meta = pair_dir / "metadata.jsonl"
    if rank == 0:
        lora_dir.mkdir(parents=True, exist_ok=True)
        rn_dir.mkdir(parents=True, exist_ok=True)
        pair_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()

    if args.stage == "lora":
        pipe_lora = DiffusersStableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=dtype,
        ).to(device)
        if args.lora_path:
            pipe_lora.load_lora_weights(args.lora_path)
        pipe_lora.set_progress_bar_config(disable=True)
        set_pipeline_modules_eval(pipe_lora)

        captions = load_captions_from_jsonl(args.captions_jsonl)
        total_needed = min(len(captions) * args.images_per_caption, args.max_samples)
        n = args.per_proc_batch_size
        global_batch = n * world
        # Same iteration count on every rank, so the per-iteration barriers match up.
        iters = math.ceil(total_needed / global_batch)
        pbar = tqdm(range(iters)) if rank == 0 else range(iters)

        rank_meta_path = root / f"metadata.rank{rank}.jsonl"
        rank_lora_meta_path = lora_dir / f"metadata.rank{rank}.jsonl"
        # Shards are appended to below, so drop this rank's leftovers from earlier runs.
        for stale in (rank_meta_path, rank_lora_meta_path):
            if stale.exists():
                stale.unlink()

        for it in pbar:
            for k in range(n):
                global_idx = it * global_batch + k * world + rank
                if global_idx >= total_needed:
                    continue  # padding slot in the last global batch
                prompt = captions[global_idx // args.images_per_caption]
                # Derived from the global index, so it is unique across ranks and
                # iterations (an offset scheme like it*10000 + k*1000 collides once
                # per_proc_batch_size > 10). Stored in metadata for the rn stage.
                image_seed = args.global_seed * world + global_idx
                latents = _initial_latents(pipe_lora, image_seed, args.height, args.width, device, dtype)
                img_lora = _generate_image(pipe_lora, prompt, latents, args, dtype)
                fn = f"{global_idx:07d}.png"
                img_lora.save(lora_dir / fn)
                save_jsonl_line(str(rank_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed), "lora_file": f"lora/{fn}"})
                save_jsonl_line(str(rank_lora_meta_path), {"file_name": fn, "caption": prompt, "seed": int(image_seed)})
            dist.barrier()
        dist.barrier()
        if rank == 0:
            merge_rank_metadata(str(metadata_path), [str(root / f"metadata.rank{r}.jsonl") for r in range(world)])
            merge_rank_metadata(str(lora_meta), [str(lora_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
            records = load_jsonl(str(metadata_path))
            create_npz_from_dir(str(lora_dir), len(records))

    elif args.stage == "rn":
        records = load_jsonl(str(metadata_path))
        if not records:
            raise RuntimeError(f"metadata not found or empty: {metadata_path}. Run --stage lora first.")
        total_needed = min(len(records), args.max_samples)
        LocalStableDiffusion3Pipeline = load_local_pipeline_class(args.local_pipeline_path)
        pipe_rn = LocalStableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=dtype,
        ).to(device)
        if args.lora_path:
            pipe_rn.load_lora_weights(args.lora_path)
        pipe_rn.model = build_rn_model(pipe_rn, args.rectified_weights, args.num_sit_layers, device)
        pipe_rn.set_progress_bar_config(disable=True)
        set_pipeline_modules_eval(pipe_rn)

        rank_rn_meta_path = rn_dir / f"metadata.rank{rank}.jsonl"
        if rank_rn_meta_path.exists():
            rank_rn_meta_path.unlink()
        # Round-robin split of the first total_needed records across ranks.
        assigned = records[:total_needed][rank::world]
        pbar = tqdm(assigned) if rank == 0 else assigned
        for rec in pbar:
            fn = rec["file_name"]
            prompt = rec["caption"]
            image_seed = int(rec["seed"])
            # Same seed + same helper as the lora stage => identical starting noise.
            latents = _initial_latents(pipe_rn, image_seed, args.height, args.width, device, dtype)
            img_rn = _generate_image(pipe_rn, prompt, latents, args, dtype)
            img_rn.save(rn_dir / fn)
            save_jsonl_line(str(rank_rn_meta_path), {"file_name": fn, "caption": prompt, "seed": image_seed})
        dist.barrier()
        if rank == 0:
            merge_rank_metadata(str(rn_meta), [str(rn_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
            create_npz_from_dir(str(rn_dir), total_needed)

    elif args.stage == "pair":
        records = load_jsonl(str(metadata_path))
        if not records:
            raise RuntimeError(f"metadata not found: {metadata_path}")
        total_needed = min(len(records), args.max_samples)
        rank_pair_meta_path = pair_dir / f"metadata.rank{rank}.jsonl"
        if rank_pair_meta_path.exists():
            rank_pair_meta_path.unlink()
        for rec in records[:total_needed][rank::world]:
            fn = rec["file_name"]
            lora_img_path = lora_dir / fn
            rn_img_path = rn_dir / fn
            if not (lora_img_path.exists() and rn_img_path.exists()):
                continue  # one side missing: no pair for this entry
            with Image.open(lora_img_path) as im:
                img_lora = im.convert("RGB")
            with Image.open(rn_img_path) as im:
                img_rn = im.convert("RGB")
            # LoRA on the left, RN on the right.
            pair = Image.new("RGB", (img_lora.width + img_rn.width, max(img_lora.height, img_rn.height)))
            pair.paste(img_lora, (0, 0))
            pair.paste(img_rn, (img_lora.width, 0))
            pair.save(pair_dir / fn)
            save_jsonl_line(
                str(rank_pair_meta_path),
                {"file_name": fn, "caption": rec["caption"], "seed": int(rec["seed"]), "pair_file": f"pair/{fn}"},
            )
        dist.barrier()
        if rank == 0:
            merge_rank_metadata(str(pair_meta), [str(pair_dir / f"metadata.rank{r}.jsonl") for r in range(world)])
            _annotate_root_metadata(metadata_path, records[:total_needed], rn_meta, pair_meta)

    else:
        raise ValueError(f"Unknown stage: {args.stage}")

    dist.barrier()
    if rank == 0:
        print(f"Stage {args.stage} done. Output root: {root}")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DDP compare sampling: LoRA vs RN with same latent/prompt.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument(
        "--local_pipeline_path",
        type=str,
        default=str(Path(__file__).parent / "pipeline_stable_diffusion_3.py"),
        # Local pipeline file used by the RN stage.
        help="RN 分支使用的本地 pipeline 文件路径",
    )
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
    parser.add_argument("--lora_path", type=str, default=None)
    # Required only by the stage that reads it (validated after parsing), so the
    # lora and pair stages can run without an RN checkpoint.
    parser.add_argument("--rectified_weights", type=str, default=None)
    parser.add_argument("--num_sit_layers", type=int, default=1)
    # Required only by the lora stage; rn/pair read prompts from metadata.jsonl.
    parser.add_argument("--captions_jsonl", type=str, default=None)
    parser.add_argument("--sample_dir", type=str, default="./sd3_lora_rn_compare")
    parser.add_argument("--num_inference_steps", type=int, default=40)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--per_proc_batch_size", type=int, default=4)
    parser.add_argument("--images_per_caption", type=int, default=1)
    parser.add_argument("--max_samples", type=int, default=10000)
    parser.add_argument("--global_seed", type=int, default=42)
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    parser.add_argument("--stage", type=str, default="lora", choices=["lora", "rn", "pair"])
    args = parser.parse_args()
    if args.stage == "lora" and not args.captions_jsonl:
        parser.error("--captions_jsonl is required for --stage lora")
    if args.stage == "rn" and not args.rectified_weights:
        parser.error("--rectified_weights is required for --stage rn")
    main(args)