import argparse

import torch
from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
from ldm.modules.extra_condition.api import ExtraCondition
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict

DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                          'fewer digits, cropped, worst quality, low quality'


def get_base_argument_parser() -> argparse.ArgumentParser:
| """get the base argument parser for inference scripts""" |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| '--outdir', |
| type=str, |
| help='dir to write results to', |
| default=None, |
| ) |
|
|
| parser.add_argument( |
| '--prompt', |
| type=str, |
| nargs='?', |
| default=None, |
| help='positive prompt', |
| ) |
|
|
| parser.add_argument( |
| '--neg_prompt', |
| type=str, |
| default=DEFAULT_NEGATIVE_PROMPT, |
| help='negative prompt', |
| ) |
|
|
| parser.add_argument( |
| '--cond_path', |
| type=str, |
| default=None, |
| help='condition image path', |
| ) |
|
|
| parser.add_argument( |
        '--cond_inp_type',
        type=str,
        default='image',
        help='the type of the input condition image; taking depth T2I as an example, the input can be a raw '
             'image, from which the depth map will be computed, or directly a depth map image',
    )

    parser.add_argument(
        '--sampler',
        type=str,
        default='ddim',
        choices=['ddim', 'plms'],
        help='sampling algorithm; currently only ddim and plms are supported, more are on the way',
    )

    parser.add_argument(
        '--steps',
        type=int,
        default=50,
        help='number of sampling steps',
    )

    parser.add_argument(
        '--sd_ckpt',
        type=str,
        default='models/sd-v1-4.ckpt',
        help='path to checkpoint of the stable diffusion model; both .ckpt and .safetensors are supported',
    )

    parser.add_argument(
        '--vae_ckpt',
        type=str,
        default=None,
        help='vae checkpoint; anime SD models usually have a separate vae ckpt that needs to be loaded',
    )

    parser.add_argument(
        '--adapter_ckpt',
        type=str,
        default=None,
        help='path to checkpoint of adapter',
    )

    parser.add_argument(
        '--config',
        type=str,
        default='configs/stable-diffusion/sd-v1-inference.yaml',
        help='path to config which constructs SD model',
    )

    parser.add_argument(
        '--max_resolution',
        type=float,
        default=512 * 512,
        help='max image height * width, only for computers with limited VRAM',
    )

    parser.add_argument(
        '--resize_short_edge',
        type=int,
        default=None,
        help='resize the short edge of the input image; if this arg is set, max_resolution will not be used',
    )

    parser.add_argument(
        '--C',
        type=int,
        default=4,
        help='latent channels',
    )

    parser.add_argument(
        '--f',
        type=int,
        default=8,
        help='downsampling factor',
    )

    parser.add_argument(
        '--scale',
        type=float,
        default=7.5,
        help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    )

    parser.add_argument(
        '--cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the adapter is applied, '
             'similar to the Prompt-to-Prompt tau',
    )

    parser.add_argument(
        '--style_cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the style adapter is applied, '
             'similar to the Prompt-to-Prompt tau',
    )

    parser.add_argument(
        '--cond_weight',
        type=float,
        default=1.0,
        help='the adapter features are multiplied by cond_weight; the larger the cond_weight, the more closely '
             'the generated image will align with the condition, but the generation quality may be reduced',
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='random seed',
    )

    parser.add_argument(
        '--n_samples',
        type=int,
        default=4,
        help='# of samples to generate',
    )

    return parser
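

# Usage sketch (hedged): inference scripts typically extend this parser and set
# `opt.device` themselves; the builder functions below read it from `opt.device`.
#
#     opt = get_base_argument_parser().parse_args()
#     opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     sd_model, sampler = get_sd_models(opt)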
def get_sd_models(opt):
    """
    build stable diffusion model, sampler
    """
    # SD; fall back to an auto-detected device if the caller has not set opt.device
    if not hasattr(opt, 'device'):
        opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    model = model.half()
    sd_model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError

    return sd_model, sampler


def get_t2i_adapter_models(opt):
    """build stable diffusion model with the adapter weights merged in, plus sampler"""
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    # prefer a condition-specific adapter checkpoint, fall back to the generic one
    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter_ckpt = read_state_dict(adapter_ckpt_path)
    # make sure every adapter weight is namespaced under the 'adapter.' prefix
    new_state_dict = {}
    for k, v in adapter_ckpt.items():
        if not k.startswith('adapter.'):
            new_state_dict[f'adapter.{k}'] = v
        else:
            new_state_dict[k] = v
    m, u = model.load_state_dict(new_state_dict, strict=False)
    if len(u) > 0:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(u)

    model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError

    return model, sampler
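

# Usage sketch (hedged): `opt.which_cond` is assumed to name the condition type
# (e.g. 'sketch') so a condition-specific `<which_cond>_adapter_ckpt` attribute
# can override the generic `adapter_ckpt`; the checkpoint path is illustrative.
#
#     opt.which_cond = 'sketch'
#     opt.adapter_ckpt = 'models/t2iadapter_sketch_sd14v1.pth'  # illustrative
#     model, sampler = get_t2i_adapter_models(opt)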
def get_cond_ch(cond_type: ExtraCondition):
    # sketch and canny conditions are single-channel maps, the rest are RGB
    if cond_type == ExtraCondition.sketch or cond_type == ExtraCondition.canny:
        return 1
    return 3


def get_adapters(opt, cond_type: ExtraCondition):
    adapter = {}
    # prefer a condition-specific weight, fall back to the generic cond_weight
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        # n_layes is the parameter name as defined in StyleAdapter (sic)
        adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)
    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = getattr(opt, 'adapter_ckpt')
    # read_state_dict handles both .ckpt and .safetensors checkpoints
    adapter['model'].load_state_dict(read_state_dict(ckpt_path))

    return adapter
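

# Usage sketch (hedged): builds a standalone adapter for a sketch condition;
# assumes `ExtraCondition.sketch` exists (get_cond_ch above relies on it) and
# uses an illustrative checkpoint path.
#
#     opt.cond_weight = 1.0
#     opt.adapter_ckpt = 'models/t2iadapter_sketch_sd14v1.pth'  # illustrative
#     adapter = get_adapters(opt, ExtraCondition.sketch)
#     # adapter['model'] is the nn.Module; adapter['cond_weight'] scales its features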
def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    # get text embeddings for the positive and (if CFG is enabled) negative prompts
    c = model.get_learned_conditioning([opt.prompt])
    if opt.scale != 1.0:
        uc = model.get_learned_conditioning([opt.neg_prompt])
    else:
        uc = None
    c, uc = fix_cond_shapes(model, c, uc)

    if not hasattr(opt, 'H'):
        opt.H = 512
        opt.W = 512
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    samples_latents, _ = sampler.sample(
        S=opt.steps,
        conditioning=c,
        batch_size=1,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
        style_cond_tau=opt.style_cond_tau,
    )

    # decode latents and map images from [-1, 1] to [0, 1]
    x_samples = model.decode_first_stage(samples_latents)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

    return x_samples
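

# End-to-end sketch (hedged): `adapter_features` would normally be computed by
# running the condition image through adapter['model']; that step depends on the
# extra_condition API and is elided here. Saving assumes `from PIL import Image`.
#
#     x_samples = diffusion_inference(opt, model, sampler, adapter_features)
#     img = (255. * x_samples[0].permute(1, 2, 0).cpu().float().numpy()).astype('uint8')
#     Image.fromarray(img).save('sample.png')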