| import pathlib |
| from os import path |
|
|
| import torch |
| from diffusers import ( |
| AutoPipelineForText2Image, |
| LCMScheduler, |
| StableDiffusionPipeline, |
| StableDiffusionXLPipeline, |
| ) |
|
|
|
|
def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    """Load LCM LoRA weights into *pipeline* under the adapter name ``"lcm"``.

    Args:
        pipeline: A diffusers pipeline exposing ``load_lora_weights``.
        use_local_model: When loading by repo id, restrict the lookup to
            locally cached files only.
        lcm_lora_id: Either a filesystem path to a ``.safetensors`` file or a
            Hugging Face Hub repo id.
    """
    # NOTE: renamed from `path` to avoid shadowing the module-level
    # `from os import path` import.
    lora_path = pathlib.Path(lcm_lora_id)
    if lora_path.suffix == ".safetensors":
        # A concrete file on disk: load from its parent directory by file name.
        pipeline.load_lora_weights(
            lora_path.parent,
            local_files_only=True,
            weight_name=lora_path.name,
            adapter_name="lcm",
        )
    else:
        # A repo id: expect the conventional diffusers LoRA weight file name.
        pipeline.load_lora_weights(
            lcm_lora_id,
            local_files_only=use_local_model,
            weight_name="pytorch_lora_weights.safetensors",
            adapter_name="lcm",
        )
|
|
|
|
def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args=None,
):
    """Build a text-to-image pipeline and attach LCM LoRA weights.

    Args:
        base_model_id: Hub repo id, or a path to a single ``.safetensors``
            checkpoint file.
        lcm_lora_id: LoRA weights location, forwarded to ``load_lcm_weights``.
        use_local_model: Restrict hub lookups to locally cached files.
        torch_data_type: Torch dtype used when loading the base model.
        pipeline_args: Optional extra keyword arguments passed to the pipeline
            constructor. Defaults to no extra arguments.

    Returns:
        An ``AutoPipelineForText2Image`` with the LoRA loaded; the scheduler is
        swapped to ``LCMScheduler`` for LCM/Hyper-SD/DMD2 style LoRAs.

    Raises:
        FileNotFoundError: If *base_model_id* points to a missing file.
    """
    # `None` default avoids the mutable-default-argument pitfall; normalize.
    if pipeline_args is None:
        pipeline_args = {}

    model_path = pathlib.Path(base_model_id)
    if model_path.suffix == ".safetensors":
        if not model_path.exists():
            raise FileNotFoundError(
                f"Model file not found. Please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model")

        # SDXL checkpoints need the XL pipeline class; otherwise fall back to
        # the standard Stable Diffusion pipeline.
        if "xl" in base_model_id.lower():
            dummy_pipeline = StableDiffusionXLPipeline.from_single_file(
                base_model_id,
                torch_dtype=torch_data_type,
                safety_checker=None,
                local_files_only=use_local_model,
                use_safetensors=True,
            )
        else:
            dummy_pipeline = StableDiffusionPipeline.from_single_file(
                base_model_id,
                torch_dtype=torch_data_type,
                safety_checker=None,
                local_files_only=use_local_model,
                use_safetensors=True,
            )

        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        # Release the intermediate pipeline reference promptly.
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )

    # LCM-family LoRAs (LCM / Hyper-SD / DMD2) are trained for few-step
    # sampling with the LCM scheduler.
    lora_id_lower = lcm_lora_id.lower()
    if any(tag in lora_id_lower for tag in ("lcm", "hypersd", "dmd2")):
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    return pipeline
|
|