| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline |
| from transformers import CLIPVisionModelWithProjection |
| import torch |
| from copy import deepcopy |
|
|
# Master switch for the in-process model cache. The cache_model and
# copied_cache_model decorators check this flag on every call: when it is
# True they store loader results in `cached_models`; when it is False every
# call runs the wrapped loader again.
| ENABLE_CPU_CACHE = False |
# Default value of the `base_model` argument of load_base_model_components
# and load_common_sd15_pipe.
| DEFAULT_BASE_MODEL = "benjamin-paine/stable-diffusion-v1-5" |
|
|
# Registry of loaded models shared by the caching decorators in this module,
# keyed by loader name plus a string rendering of the call arguments.
cached_models = {}


def cache_model(func):
    """Memoize a model loader in ``cached_models`` while caching is enabled.

    ``ENABLE_CPU_CACHE`` is checked on every call. When it is false the
    wrapped loader is simply invoked; when it is true the first result for a
    given argument set is stored and the *same object* is returned on every
    later call with those arguments.

    Args:
        func: Loader function to wrap.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """
    from functools import wraps

    @wraps(func)  # keep the loader's __name__ / docstring on the wrapper
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sort keyword arguments so that f(a=1, b=2) and f(b=2, a=1) share one
        # cache entry instead of loading the same model twice.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        return cached_models[model_name]

    return wrapper
|
|
def copied_cache_model(func):
    """Memoize a loader in ``cached_models`` and hand out deep copies.

    ``ENABLE_CPU_CACHE`` is checked on every call. When it is false the
    wrapped loader is simply invoked. When it is true the first result for a
    given argument set is stored in ``cached_models`` and every call
    (including the first) returns ``deepcopy`` of the stored value, so
    callers can mutate what they receive without touching the cached original.

    Args:
        func: Loader function to wrap.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """
    from functools import wraps

    @wraps(func)  # keep the loader's __name__ / docstring on the wrapper
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sort keyword arguments so that f(a=1, b=2) and f(b=2, a=1) share one
        # cache entry instead of loading the same model twice.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        return deepcopy(cached_models[model_name])

    return wrapper
|
|
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Instantiate ``model_cls`` from a single checkpoint file or a pretrained source.

    A source whose name ends in ``.safetensors`` or ``.ckpt`` (case-insensitive)
    is loaded with ``model_cls.from_single_file``; anything else is passed to
    ``model_cls.from_pretrained``.

    Args:
        ckpt_or_pretrained: Checkpoint file path, or the identifier/directory
            accepted by ``model_cls.from_pretrained``. ``str`` or path-like.
        model_cls: Class exposing ``from_single_file`` and ``from_pretrained``.
        original_config_file: Forwarded only on the single-file path.
        torch_dtype: Forwarded to whichever loader is used.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        Whatever the selected ``model_cls`` loader returns.
    """
    import os

    # Accept pathlib.Path as well as str; a plain str is returned unchanged.
    source = os.fspath(ckpt_or_pretrained)
    if source.lower().endswith((".safetensors", ".ckpt")):
        return model_cls.from_single_file(
            source,
            original_config_file=original_config_file,
            torch_dtype=torch_dtype,
            **kwargs,
        )
    return model_cls.from_pretrained(source, torch_dtype=torch_dtype, **kwargs)
|
|
@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load a Stable Diffusion pipeline on CPU and return its components.

    The pipeline is built through ``model_from_ckpt_or_pretrained`` with the
    safety checker disabled, moved to ``"cpu"``, and its ``components``
    attribute is returned. Because of the ``copied_cache_model`` decorator,
    callers receive a deep copy of the cached components whenever
    ``ENABLE_CPU_CACHE`` is on.

    Args:
        base_model: Checkpoint path or pretrained identifier to load.
        torch_dtype: dtype forwarded to the loader.

    Returns:
        The ``components`` of the loaded ``StableDiffusionPipeline``.
    """
    base_pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    base_pipe.to("cpu")
    return base_pipe.components
|
|
@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` from ``controlnet_path``.

    Decorated with ``cache_model``, so with ``ENABLE_CPU_CACHE`` on, repeated
    calls with the same arguments return the same model object.

    Args:
        controlnet_path: Source passed to ``ControlNetModel.from_pretrained``.
        torch_dtype: dtype forwarded to ``from_pretrained``.

    Returns:
        The loaded ``ControlNetModel``.
    """
    return ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
|
|
@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the CLIP vision encoder published in the ``h94/IP-Adapter`` repo.

    Decorated with ``cache_model``, so with ``ENABLE_CPU_CACHE`` on, repeated
    calls with the same arguments return the same model object.

    Args:
        torch_dtype: dtype forwarded to ``from_pretrained``. Defaults to
            ``torch.float16``, the value that was previously hard-coded, so
            existing zero-argument calls behave as before.

    Returns:
        The loaded ``CLIPVisionModelWithProjection``.
    """
    return CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )
|
|
def _load_controlnets(controlnet, torch_dtype):
    """Load one ControlNet, or a list of them for a list/tuple of paths."""
    if isinstance(controlnet, (list, tuple)):
        return [load_controlnet(path, torch_dtype=torch_dtype) for path in controlnet]
    return load_controlnet(controlnet, torch_dtype=torch_dtype)


def _configure_ip_adapter(pipe, ip_adapter, plus_model):
    """Attach the IP-Adapter (plus or base weights) to ``pipe``, or unload it."""
    if not ip_adapter:
        pipe.unload_ip_adapter()
        return
    pipe.image_encoder = load_image_encoder()
    weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
    pipe.set_ip_adapter_scale(1.0)


def _set_cpu_offload_seq(pipe, model_cpu_offload_seq):
    """Set ``pipe.model_cpu_offload_seq`` from the argument or a ControlNet default."""
    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    # Any other pipeline class keeps whatever sequence it already has.


def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    """Build a Stable Diffusion 1.5 pipeline with optional ControlNet and IP-Adapter.

    The base components come from ``load_base_model_components`` and are
    passed, together with ``kwargs`` and any loaded ControlNet, to
    ``model_from_ckpt_or_pretrained``. The scheduler is then replaced by an
    ``EulerAncestralDiscreteScheduler`` built from the existing scheduler
    config, and CPU offloading is enabled before the pipeline is returned.

    Args:
        base_model: Checkpoint path or pretrained identifier.
        device: Not referenced by this function; kept so existing callers
            that pass it keep working.
        controlnet: ``None``, a single ControlNet path, or a list/tuple of
            paths (each is loaded with ``load_controlnet``).
        ip_adapter: When true, attach the image encoder and load IP-Adapter
            weights with scale 1.0; otherwise ``unload_ip_adapter`` is called.
        plus_model: Chooses ``ip-adapter-plus_sd15.safetensors`` over
            ``ip-adapter_sd15.safetensors`` when ``ip_adapter`` is true.
        torch_dtype: dtype forwarded to the component and ControlNet loaders.
        model_cpu_offload_seq: Explicit offload sequence; when ``None`` a
            default is set for the two ControlNet pipeline classes only.
        enable_sequential_cpu_offload: Use sequential CPU offload instead of
            model CPU offload.
        vae_slicing: When true, call ``enable_vae_slicing`` on the pipeline.
        pipeline_class: Pipeline class to instantiate. Defaults to
            ``StableDiffusionControlNetPipeline`` when a ControlNet is given,
            else ``StableDiffusionPipeline``.
        **kwargs: Extra loader keyword arguments; they override same-named
            base components.

    Returns:
        The configured pipeline instance.
    """
    import gc

    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Update order matters: components override the defaults above, caller
    # kwargs override components, and the loaded ControlNet is applied last.
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    if controlnet is not None:
        controlnet = _load_controlnets(controlnet, torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline if controlnet is None else StableDiffusionControlNetPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    _configure_ip_adapter(pipe, ip_adapter, plus_model)

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    _set_cpu_offload_seq(pipe, model_cpu_offload_seq)

    # NOTE(review): the two offload modes are treated as mutually exclusive,
    # i.e. model offload is the fallback branch - confirm this is intended.
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()
    if vae_slicing:
        pipe.enable_vae_slicing()

    gc.collect()
    return pipe
|
|
|
|