| import glob |
| from os import path |
| from paths import get_file_name, FastStableDiffusionPaths |
| from pathlib import Path |
|
|
|
|
| class _lora_info: |
| """ |
| A basic class to keep track of the currently loaded LoRAs and their weights. |
| |
| The diffusers function _get_active_adapters()_ returns a list of adapter |
| names but not their weights so we need a way to keep track of the current |
| LoRA weights to set whenever a new LoRA is loaded. |
| """ |
|
|
| def __init__( |
| self, |
| path: str, |
| weight: float, |
| ): |
| self.path = path |
| self.adapter_name = get_file_name(path) |
| self.weight = weight |
|
|
| def __del__(self): |
| self.path = None |
| self.adapter_name = None |
|
|
|
|
# One _lora_info record per LoRA loaded into _current_pipeline, in load order.
_loaded_loras = list()
# Pipeline the records above belong to; reset to None by
# reset_active_lora_weights().
_current_pipeline = None
|
|
|
|
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """
    Loads a LoRA from the LoRA path setting.

    This function loads a LoRA from the LoRA path stored in the settings so
    it's possible to load multiple LoRAs by calling this function more than
    once with a different LoRA path setting; note that if you plan to load
    multiple LoRAs and dynamically change their weights, you might want to
    set the LoRA fuse option to _False_.

    The LoRA is recorded in the module's list of active LoRAs only after it
    has actually been loaded into the pipeline (i.e. when _lora.enabled_ is
    true and loading succeeded). This keeps later _set_adapters_ calls from
    referring to adapters the pipeline does not have.

    Args:
        pipeline: The generation pipeline to load the LoRA into.
        lcm_diffusion_setting: Settings object; reads _lora.path_,
            _lora.weight_, _lora.enabled_ and _lora.fuse_.

    Raises:
        Exception: If the LoRA path is empty or does not exist.
    """
    lora_setting = lcm_diffusion_setting.lora
    if not lora_setting.path:
        raise Exception("Empty lora model path")

    if not path.exists(lora_setting.path):
        raise Exception("Lora model path is invalid")

    global _current_pipeline
    # Bookkeeping for a different pipeline is stale; start over for this one.
    if pipeline is not _current_pipeline:
        reset_active_lora_weights()
        _current_pipeline = pipeline

    current_lora = _lora_info(
        lora_setting.path,
        lora_setting.weight,
    )

    if lora_setting.enabled:
        print(f"LoRA adapter name : {current_lora.adapter_name}")
        # Load the exact file that was validated above. Its parent directory
        # is used as the location, so LoRAs in subdirectories also work.
        lora_file = Path(lora_setting.path)
        pipeline.load_lora_weights(
            str(lora_file.parent),
            weight_name=lora_file.name,
            local_files_only=True,
            adapter_name=current_lora.adapter_name,
        )
        # Register only after a successful load so set_adapters() never sees
        # an adapter name the pipeline doesn't know about.
        _loaded_loras.append(current_lora)
        update_lora_weights(
            pipeline,
            lcm_diffusion_setting,
        )

    if lora_setting.fuse:
        pipeline.fuse_lora()
|
|
|
|
def get_lora_models(root_dir: str):
    """
    Map LoRA names to the ``.safetensors`` files found under *root_dir*.

    The directory is searched recursively. Each file's key comes from
    get_file_name(), and files for which it returns None are skipped. Paths
    are visited in sorted order, so when two files yield the same name the
    result is deterministic: the path that sorts last wins.

    Args:
        root_dir: Directory to search. It is glob-escaped first, so directory
            names containing ``[``, ``*`` or ``?`` are matched literally.

    Returns:
        A dict mapping LoRA name to file path; empty if nothing is found.
    """
    # Escape the directory part so only the suffix acts as a glob pattern.
    pattern = f"{glob.escape(str(root_dir))}/**/*.safetensors"
    lora_models_map = {}
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
|
|
|
|
def get_active_lora_weights():
    """
    Build the _(adapter_name, weight)_ pairs for every loaded LoRA.

    The pairs follow the order in which the LoRAs were recorded.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]
|
|
|
|
def reset_active_lora_weights():
    """
    Clears the global list of active LoRA weights.

    This method clears the list of active LoRA weights but it doesn't actually
    remove the active LoRA weights from the current generation pipeline.
    This method is only meant to be called when rebuilding the generation pipeline
    as it will also clear the _current_pipeline_ variable; setting the
    _current_pipeline_ variable to _None_ is safe here since the active LoRA weights
    list is being reset, but it also helps to remove the pipeline reference that
    might prevent the garbage collector from releasing the current pipeline memory.
    """
    global _loaded_loras
    global _current_pipeline
    # Rebind to a fresh list (not an in-place clear). The previous list and its
    # _lora_info records are released unless referenced elsewhere.
    _loaded_loras = []
    _current_pipeline = None
|
|
|
|
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """
    Updates the LoRA weights for the currently active LoRAs.

    Args:
        pipeline: The currently active pipeline. If it is not the pipeline the
            LoRAs were loaded into, a message is printed and nothing changes.
        lcm_diffusion_setting: The global settings, needed to verify if the
            pipeline is running in LCM-LoRA mode; when _use_lcm_lora_ is set
            the "lcm" adapter is applied first with weight 1.0.
        lora_weights: An optional list of updated _(adapter_name, weight)_ tuples,
            matched by position against the loaded LoRAs. An entry whose name
            doesn't match is skipped, and entries beyond the number of loaded
            LoRAs are ignored with a warning.
    """
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        # zip() stops at the shorter sequence, so extra entries can no longer
        # raise IndexError.
        if len(lora_weights) > len(_loaded_loras):
            print("More LoRA weights than loaded LoRAs, ignoring the extra entries!")
        for lora_info, lora in zip(_loaded_loras, lora_weights):
            if lora_info.adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            lora_info.weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora_info in _loaded_loras:
        adapter_names.append(lora_info.adapter_name)
        adapter_weights.append(lora_info.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")
|
|