| import pathlib
|
| from os import path
|
|
|
| import torch
|
| from diffusers import (
|
| AutoPipelineForText2Image,
|
| LCMScheduler,
|
| StableDiffusionPipeline,
|
| )
|
|
|
|
|
def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    """Load LoRA weights into ``pipeline`` and register them as adapter "lcm".

    Args:
        pipeline: Object exposing ``load_lora_weights``; the weights are
            loaded into it in place.
        use_local_model: Forwarded as ``local_files_only`` to
            ``load_lora_weights``.
        lcm_lora_id: Identifier of the LoRA weights, passed as the first
            positional argument to ``load_lora_weights``.

    Returns:
        None. The result of ``load_lora_weights`` is discarded.
    """
    # The weight file name is fixed; only the source id and the
    # local-only flag vary between calls.
    pipeline.load_lora_weights(
        lcm_lora_id,
        local_files_only=use_local_model,
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="lcm",
    )
|
|
|
|
|
def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args=None,
):
    """Build a text-to-image pipeline and load LoRA weights into it.

    Two loading paths are supported, selected by ``base_model_id``:

    * If its suffix is exactly ``.safetensors`` it is treated as a path to a
      single-file checkpoint: it is loaded with
      ``StableDiffusionPipeline.from_single_file`` and then converted with
      ``AutoPipelineForText2Image.from_pipe``.
    * Otherwise it is passed to ``AutoPipelineForText2Image.from_pretrained``.

    After loading, the LoRA weights identified by ``lcm_lora_id`` are attached
    via ``load_lcm_weights``. If the LoRA id contains "lcm" or "hypersd"
    (case-insensitive), the pipeline's scheduler is replaced by an
    ``LCMScheduler`` built from the current scheduler's config.

    Args:
        base_model_id: Model identifier or path to a ``.safetensors`` file.
        lcm_lora_id: Identifier of the LoRA weights to load.
        use_local_model: Forwarded as ``local_files_only`` to the loaders.
        torch_data_type: Forwarded as ``torch_dtype`` to the loaders.
        pipeline_args: Optional mapping of extra keyword arguments for
            ``from_pipe`` (single-file path) or ``from_pretrained``.
            ``None`` (the default) means no extra arguments.

    Returns:
        The configured pipeline object.

    Raises:
        FileNotFoundError: If ``base_model_id`` ends in ``.safetensors`` but
            no such path exists.
    """
    # Use a None sentinel instead of a shared mutable ``{}`` default so that
    # no state can leak between calls.
    extra_args = {} if pipeline_args is None else pipeline_args

    if pathlib.Path(base_model_id).suffix == ".safetensors":
        if not path.exists(base_model_id):
            raise FileNotFoundError(
                f"Model file not found,Please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model (Supported models - SD 1.5 models)")

        # Load the checkpoint with the concrete SD pipeline first, then
        # convert it into the auto text-to-image pipeline.
        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            base_model_id,
            torch_dtype=torch_data_type,
            safety_checker=None,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **extra_args,
        )
        # Drop the intermediate reference; only the converted pipeline is used.
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **extra_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )

    # Lower-case once; both markers are matched case-insensitively.
    lora_id_lower = lcm_lora_id.lower()
    if "lcm" in lora_id_lower or "hypersd" in lora_id_lower:
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    return pipeline
|
|
|