File size: 5,344 Bytes
4189926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path


class _lora_info:
    """
    A basic class to keep track of the currently loaded LoRAs and their weights.

    The diffusers function _get_active_adapters()_ returns a list of adapter
    names but not their weights so we need a way to keep track of the current
    LoRA weights to set whenever a new LoRA is loaded.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        # Full filesystem path of the .safetensors file.
        self.path = path
        # Adapter name derived from the file name; presumably get_file_name
        # strips directory and extension -- confirm against the paths module.
        self.adapter_name = get_file_name(path)
        # Current adapter weight, re-applied on every update_lora_weights() call.
        self.weight = weight

    # NOTE(review): the previous __del__ only rebound attributes to None during
    # teardown, which has no effect on resource release and can raise during
    # interpreter shutdown; plain garbage collection is sufficient here.


# Module-level registry of _lora_info entries for the LoRAs loaded so far;
# needed because diffusers exposes active adapter names but not their weights.
_loaded_loras = []
# The pipeline the entries above were loaded into; compared on each call so
# stale LoRA records are discarded when the pipeline is rebuilt.
_current_pipeline = None


def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """
    Loads a LoRA from the LoRA path setting.

    This function loads a LoRA from the LoRA path stored in the settings so
    it's possible to load multiple LoRAs by calling this function more than
    once with a different LoRA path setting; note that if you plan to load
    multiple LoRAs and dynamically change their weights, you might want to
    set the LoRA fuse option to _False_.

    Args:
        pipeline: The diffusers pipeline to load the LoRA into.
        lcm_diffusion_setting: Global settings; reads ``lora.path``,
            ``lora.weight``, ``lora.enabled`` and ``lora.fuse``.

    Raises:
        ValueError: If the LoRA path setting is empty.
        FileNotFoundError: If the LoRA path does not point to an existing file.
    """
    if not lcm_diffusion_setting.lora.path:
        raise ValueError("Empty lora model path")

    if not path.exists(lcm_diffusion_setting.lora.path):
        raise FileNotFoundError("Lora model path is invalid")

    # If the pipeline has been rebuilt since the last call, remove all
    # references to previously loaded LoRAs and store the new pipeline.
    # Identity comparison: a rebuilt pipeline is a new object even if it
    # happened to compare equal.
    global _loaded_loras
    global _current_pipeline
    if pipeline is not _current_pipeline:
        reset_active_lora_weights()
        _current_pipeline = pipeline

    if not lcm_diffusion_setting.lora.enabled:
        return

    current_lora = _lora_info(
        lcm_diffusion_setting.lora.path,
        lcm_diffusion_setting.lora.weight,
    )
    # Track the LoRA only when it is actually loaded into the pipeline;
    # previously a disabled LoRA was appended too, which made
    # update_lora_weights() pass an adapter name that was never loaded.
    _loaded_loras.append(current_lora)

    print(f"LoRA adapter name : {current_lora.adapter_name}")
    pipeline.load_lora_weights(
        FastStableDiffusionPaths.get_lora_models_path(),
        weight_name=Path(lcm_diffusion_setting.lora.path).name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lcm_diffusion_setting.lora.fuse:
        pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    """
    Recursively scans *root_dir* for ``.safetensors`` files.

    Returns a dict mapping each LoRA's display name (as produced by
    get_file_name) to its file path; files whose name resolves to None
    are skipped.
    """
    found = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
    named = ((get_file_name(model_path), model_path) for model_path in found)
    return {name: model_path for name, model_path in named if name is not None}


def get_active_lora_weights():
    """
    Returns a list of _(adapter_name, weight)_ tuples for the currently loaded LoRAs.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]


def reset_active_lora_weights():
    """
    Clears the global list of active LoRA weights.

    This method clears the list of active LoRA weights but it doesn't actually
    remove the active LoRA weights from the current generation pipeline.
    This method is only meant to be called when rebuilding the generation pipeline
    as it will also clear the _current_pipeline_ variable; setting the
    _current_pipeline_ variable to _None_ is safe here since the active LoRA weights
    list is being reset, but it also helps to remove the pipeline reference that
    might prevent the garbage collector from releasing the current pipeline memory.
    """
    global _loaded_loras
    global _current_pipeline
    # Rebinding to a fresh list drops this module's references to the old
    # _lora_info objects. The previous `for lora in ...: del lora` loop only
    # deleted the loop variable (a no-op), and `del _loaded_loras` immediately
    # before reassignment was redundant.
    _loaded_loras = []
    _current_pipeline = None


def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """
    Updates the LoRA weights for the currently active LoRAs.

    Args:
        pipeline: The currently active pipeline.
        lcm_diffusion_setting: The global settings, needed to verify if the
            pipeline is running in LCM-LoRA mode.
        lora_weights: An optional list of updated _(adapter_name, weight)_ tuples.
    """
    global _loaded_loras
    global _current_pipeline
    # Identity check: a rebuilt pipeline is a different object even if it
    # compares equal; `!=` could also invoke an unrelated __eq__.
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        for idx, (adapter_name, weight) in enumerate(lora_weights):
            # Guard against callers passing more tuples than loaded LoRAs,
            # which previously raised an IndexError.
            if idx >= len(_loaded_loras):
                print("Wrong adapter name in LoRA enumeration!")
                break
            if _loaded_loras[idx].adapter_name != adapter_name:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = weight

    adapter_names = []
    adapter_weights = []
    # In LCM-LoRA mode the "lcm" adapter must stay active at full strength.
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")