| import json |
| import numpy as np |
| import math |
| import csv |
| import random |
| import argparse |
| import torch |
| import os |
| import torch.distributed as dist |
|
|
| from PIL import Image |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| from accelerate.utils import set_seed |
|
|
| from diffusion_pipeline.sd35_pipeline import StableDiffusion3Pipeline, FlowMatchEulerInverseScheduler |
| from diffusion_pipeline.sdxl_pipeline import StableDiffusionXLPipeline |
| from diffusers import BitsAndBytesConfig, SD3Transformer2DModel |
| from diffusers import FlowMatchEulerDiscreteScheduler, DDIMInverseScheduler, DDIMScheduler |
|
|
# All models and latents are placed on the default CUDA device; this script
# assumes a GPU is available (no CPU fallback path).
device = torch.device('cuda')
|
|
def get_args(argv=None):
    """Parse command-line options for the sampling script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before; passing a
            list makes the parser testable without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='sd35', choices=['sdxl', 'sd35'], type=str,
                        help="Base diffusion model to sample from.")
    parser.add_argument("--inference-step", default=30, type=int,
                        help="Number of denoising steps.")
    parser.add_argument("--size", default=1024, type=int,
                        help="Output image resolution in pixels (square).")
    parser.add_argument("--seed", default=33, type=int,
                        help="Random seed for reproducible sampling.")
    parser.add_argument("--cfg", default=3.5, type=float,
                        help="Classifier-free guidance scale for the forward pass.")
    parser.add_argument("--inv-cfg", default=0.5, type=float,
                        help="Guidance scale used during inversion (zigzag / z-core).")
    parser.add_argument("--w2s-guidance", default=1.5, type=float,
                        help="Weak-to-strong guidance scale (core / z-core).")
    # Hyphenated spelling added for consistency with every other flag; the
    # original "--end_timesteps" spelling is kept as an alias, and the parsed
    # attribute remains `args.end_timesteps` (dest derives from the first
    # option string), so existing command lines and code keep working.
    parser.add_argument("--end-timesteps", "--end_timesteps", default=28, type=int,
                        help="Timestep index at which refinement stops.")
    parser.add_argument("--prompt", default='Mickey Mouse painting by Frank Frazetta.', type=str,
                        help="Text prompt to generate an image for.")
    parser.add_argument("--method", default='standard',
                        choices=['standard', 'core', 'zigzag', 'z-core'], type=str,
                        help="Sampling strategy.")
    args = parser.parse_args(argv)
    return args
|
|
|
|
if __name__ == '__main__':
    torch.cuda.empty_cache()
    # NOTE(review): latents are drawn in fp16 even though the SD3.5 pipeline
    # runs in bf16 — presumably the pipeline casts internally; confirm.
    dtype = torch.float16
    args = get_args()
    print("args.seed: ", args.seed)
    set_seed(args.seed)

    # --- Build the pipeline for the selected base model --------------------
    if args.model == 'sd35':
        # Load the SD3.5 transformer in 4-bit NF4 quantization to fit the
        # large model in GPU memory; compute happens in bf16.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model_nf4 = SD3Transformer2DModel.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large",
            subfolder="transformer",
            quantization_config=nf4_config,
            torch_dtype=torch.bfloat16
        )
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large",
            transformer=model_nf4,
            torch_dtype=torch.bfloat16,
        )
        # Forward scheduler plus a matching inverse scheduler; the inverse
        # scheduler is used by the zigzag / z-core inversion passes.
        pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.inv_scheduler = FlowMatchEulerInverseScheduler.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large", subfolder='scheduler')
    elif args.model == "sdxl":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", subfolder='scheduler')

    # NOTE(review): pipe.to(device) immediately followed by
    # enable_model_cpu_offload() is redundant — offloading manages device
    # placement itself and expects modules on CPU; kept as in the original.
    pipe.to(device)
    pipe.enable_model_cpu_offload()

    # --- Optional LoRA-adapted refinement ("noise") model ------------------
    # Only the core / z-core methods use it, so it is loaded lazily here.
    if args.method in ('core', 'z-core'):
        from diffusion_pipeline.refine_model import PromptSD35Net, PromptSDXLNet
        from diffusion_pipeline.lora import replace_linear_with_lora, lora_true

        if args.model == 'sd35':
            refine_model = PromptSD35Net()
            replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
            lora_true(refine_model, lora_idx=0)
            checkpoint = torch.load('./weights/sd35_noise_model.pth', map_location='cpu')
            refine_model.load_state_dict(checkpoint)
        elif args.model == 'sdxl':
            refine_model = PromptSDXLNet()
            replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
            lora_true(refine_model, lora_idx=0)
            checkpoint = torch.load('./weights/sdxl_noise_model.pth', map_location='cpu')
            refine_model.load_state_dict(checkpoint)

        print("Load Lora Success")
        refine_model = refine_model.to(device)
        # NOTE(review): cast to bf16 even for the fp16 SDXL pipeline —
        # confirm this matches the refine model's training dtype.
        refine_model = refine_model.to(torch.bfloat16)

    # --- Initial latents ---------------------------------------------------
    # SDXL's VAE latent space has 4 channels, SD3.5's has 16; both use an
    # 8x spatial downsampling factor relative to the pixel resolution.
    size = args.size
    if args.model == 'sdxl':
        shape = (1, 4, size // 8, size // 8)
    else:
        shape = (1, 16, size // 8, size // 8)

    num_steps = args.inference_step
    end_timesteps = args.end_timesteps
    guidance_scale = args.cfg
    w2s_guidance = args.w2s_guidance
    inv_cfg = args.inv_cfg
    prompt = args.prompt

    print("pass this prompt: ", prompt)
    start_latents = torch.randn(shape, dtype=dtype).to(device)

    # --- Sampling dispatch -------------------------------------------------
    # The per-model branches previously duplicated the entire method
    # dispatch; the only differences were the extra max_sequence_length for
    # SD3.5 and the inv_cfg handling noted below.
    common_kwargs = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        latents=start_latents,
        return_dict=False,
    )
    if args.model != 'sdxl':
        # SD3.5 prompts go through a T5 encoder that supports long inputs.
        common_kwargs['max_sequence_length'] = 512

    if args.method == 'core':
        output = pipe.core(
            refine_model=refine_model,
            lora_true=lora_true,
            end_timesteps=end_timesteps,
            w2s_guidance=w2s_guidance,
            **common_kwargs)[0][0]
    elif args.method == 'zigzag':
        output = pipe.zigzag(
            inv_cfg=inv_cfg,
            **common_kwargs)[0][0]
    elif args.method == 'z-core':
        z_core_kwargs = dict(
            common_kwargs,
            refine_model=refine_model,
            lora_true=lora_true,
            end_timesteps=end_timesteps,
            w2s_guidance=w2s_guidance,
        )
        # NOTE(review): the original code forwarded inv_cfg to z_core only
        # for SDXL, not for SD3.5 — preserved here, but possibly an
        # oversight worth confirming upstream.
        if args.model == 'sdxl':
            z_core_kwargs['inv_cfg'] = inv_cfg
        output = pipe.z_core(**z_core_kwargs)[0][0]
    elif args.method == 'standard':
        output = pipe(**common_kwargs)[0][0]
    else:
        # Unreachable in practice: argparse `choices` already restricts
        # --method, but kept as a defensive guard.
        raise ValueError("Invalid method")

    output.save(f'{args.model}_{args.method}.png')