| import glob |
| from os import path |
| from paths import get_file_name, FastStableDiffusionPaths |
| from pathlib import Path |
|
|
|
|
| |
| |
| |
| |
| |
class _lora_info:
    """Bookkeeping record for one LoRA file registered with a pipeline.

    Attributes:
        path: Path of the LoRA weights file, exactly as supplied by the caller.
        adapter_name: Name derived from ``path`` via ``get_file_name``; it is
            passed to the pipeline as the adapter name when loading and when
            activating adapters.
        weight: Blend weight applied to this adapter when adapters are set.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"adapter_name={self.adapter_name!r}, weight={self.weight!r})"
        )
|
|
|
|
# Registry of ``_lora_info`` records for the LoRAs loaded into
# ``_current_pipeline``, kept in the order they were loaded.
_loaded_loras = list()
# The pipeline object the registry above belongs to; None before any load.
_current_pipeline = None
|
|
|
|
| |
| |
| |
| |
| |
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA described by ``lcm_diffusion_setting.lora`` into ``pipeline``.

    The module keeps a registry of LoRAs loaded into the most recently used
    pipeline. If ``pipeline`` is a different object, the registry is reset,
    because its adapters belonged to the previous pipeline.

    The following happens only when ``lora.enabled`` is true, in this order:
    the weights file is loaded into the pipeline, the LoRA is added to the
    registry, all registered adapters are activated (``update_lora_weights``),
    and, if ``lora.fuse`` is set, the adapters are fused.

    Args:
        pipeline: Pipeline object providing ``load_lora_weights`` and
            ``fuse_lora``.
        lcm_diffusion_setting: Settings object whose ``lora`` attribute provides
            ``path``, ``weight``, ``enabled`` and ``fuse``.

    Raises:
        ValueError: If ``lora.path`` is empty.
        FileNotFoundError: If ``lora.path`` does not exist.
    """
    global _loaded_loras
    global _current_pipeline

    lora_setting = lcm_diffusion_setting.lora
    if not lora_setting.path:
        raise ValueError("Empty lora model path")
    if not path.exists(lora_setting.path):
        raise FileNotFoundError("Lora model path is invalid")

    if pipeline is not _current_pipeline:
        # Adapters registered for a previous pipeline do not exist in this one.
        _loaded_loras = []
        _current_pipeline = pipeline

    if not lora_setting.enabled:
        return

    current_lora = _lora_info(lora_setting.path, lora_setting.weight)
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    lora_file = Path(lora_setting.path)
    # Load from the file's own directory: get_lora_models scans sub-folders
    # recursively, so the file is not necessarily at the models-dir root.
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Register only after a successful load so update_lora_weights never asks
    # the pipeline to activate an adapter it does not have.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lora_setting.fuse:
        pipeline.fuse_lora()
|
|
|
|
def get_lora_models(root_dir: str):
    """Map LoRA names to the ``.safetensors`` files found under ``root_dir``.

    The scan is recursive. Each file is keyed by ``get_file_name(file_path)``,
    and files for which that returns ``None`` are skipped. Paths are visited in
    sorted order, so if two files in different sub-folders share a name, the
    lexicographically last path wins every time.

    Args:
        root_dir: Directory to scan.

    Returns:
        dict mapping LoRA name to file path.
    """
    # Escape root_dir so characters such as "[" or "*" in the directory name
    # are matched literally rather than treated as glob wildcards.
    pattern = f"{glob.escape(str(root_dir))}/**/*.safetensors"
    lora_models_map = {}
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
|
|
|
|
| |
| |
def get_active_lora_weights():
    """Return an ``(adapter_name, weight)`` tuple for every registered LoRA.

    The tuples follow registration order. ``update_lora_weights`` matches its
    ``lora_weights`` argument against the registry position by position, so
    it expects this same order.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]
|
|
|
|
| |
| |
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Optionally update registered LoRA weights, then activate all adapters.

    Args:
        pipeline: Must be the pipeline the registry was built for. Otherwise a
            message is printed and nothing changes.
        lcm_diffusion_setting: If its ``use_lcm_lora`` is true, an adapter named
            ``"lcm"`` with weight 1.0 is activated ahead of the registered LoRAs.
        lora_weights: Optional sequence of ``(adapter_name, weight)`` pairs that
            match the registry by position (see ``get_active_lora_weights``).
            A pair whose name differs from the registered adapter at the same
            position is skipped. Pairs past the end of the registry are ignored
            and a message is printed.
    """
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            if idx >= len(_loaded_loras):
                # No registered adapter exists at this position to update.
                print("More LoRA weights than loaded LoRA adapters; ignoring the rest")
                break
            registered = _loaded_loras[idx]
            if registered.adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            registered.weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for registered in _loaded_loras:
        adapter_names.append(registered.adapter_name)
        adapter_weights.append(registered.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")
|
|