File size: 7,287 Bytes
4189926 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | from os import path
from PIL import Image
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (
QApplication,
QHBoxLayout,
QVBoxLayout,
QGridLayout,
QPushButton,
QSlider,
QLabel,
QFrame,
QComboBox,
QWidget,
QSizePolicy,
QMessageBox,
)
from backend.lora import (
get_lora_models,
get_active_lora_weights,
update_lora_weights,
load_lora_weight,
)
from frontend.gui.common_widgets import LabeledSlider
from app_settings import AppSettings
from paths import FastStableDiffusionPaths
if __name__ != "__main__":
from state import get_settings, get_context
from models.interface_types import InterfaceType
# Module-level bookkeeping shared by every LoraModelsWidget instance.
# (Settings come from the AppSettings object handed to the widget, not from a
# module-level get_settings() call.)
_MAX_LORA_WEIGHTS: int = 5  # cap on how many LoRAs on_load_lora will load
_current_lora_count: int = 0  # number of LoRAs loaded so far via on_load_lora
_active_lora_widgets: list = []  # _LoraWidget rows, appended in load order
class _LoraWidget(QWidget):
    """Row widget showing one loaded LoRA's name next to a slider for its weight."""

    def __init__(self):
        super().__init__()
        row = QHBoxLayout()
        self.name_label = QLabel()
        self.strength_slider = LabeledSlider(True)
        # Name on the left, weight slider on the right.
        for child in (self.name_label, self.strength_slider):
            row.addWidget(child)
        self.setLayout(row)

    def setValues(self, name: str, weight: float):
        """Display *name* in the label and move the slider to *weight*."""
        self.name_label.setText(name)
        self.strength_slider.setValue(weight)

    def getValues(self):
        """Return ``(name, weight)`` as currently shown by this row."""
        label_text = self.name_label.text()
        slider_value = self.strength_slider.getValue()
        return label_text, slider_value
class LoraModelsWidget(QWidget):
    """Panel for loading LoRA models into the current pipeline and tuning their weights.

    The top row lets the user pick a LoRA from the configured models directory
    and an initial weight. Each successfully loaded LoRA gets its own
    ``_LoraWidget`` row. The sliders on those rows are pushed to the pipeline
    with the "Update LoRA weights" button.

    NOTE(review): ``self.parent`` is assigned in ``__init__`` and therefore
    shadows ``QWidget.parent()`` on instances. It is kept because the handlers
    read ``self.parent.context`` and outside code may rely on the attribute.
    """

    def __init__(self, config: AppSettings, parent):
        """Build the widget.

        Args:
            config: application settings. May be ``None`` (the stand-alone GUI
                test at the bottom of this file passes ``None``), in which case
                no LoRA models are listed.
            parent: parent for the message boxes shown by the handlers. Its
                ``context.lcm_text_to_image.pipeline`` receives loaded LoRAs.
        """
        super().__init__()
        self.parent = parent
        self.config = config

        lora_models_map = {}
        if config is not None:
            lora_models_map = get_lora_models(
                config.settings.lcm_diffusion_setting.lora.models_dir
            )

        self.models_combobox = QComboBox()
        self.models_combobox.addItems(list(lora_models_map.keys()))
        self.models_combobox.setToolTip(
            "<p style='white-space:pre'>Place LoRA models in the <b>lora_models</b> folder</p>"
        )
        self.weight_slider = LabeledSlider(True)

        self.load_button = QPushButton("Load selected LoRA")
        # Loading only makes sense when at least one model was found.
        self.load_button.setEnabled(len(lora_models_map) > 0)
        self.load_button.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
        self.load_button.setStyleSheet("padding: 10px")
        self.load_button.clicked.connect(self.on_load_lora)

        self.loaded_label = QLabel("Loaded LoRA models:")
        self.update_button = QPushButton("Update LoRA weights")
        # Stays disabled until the first LoRA has been loaded.
        self.update_button.setEnabled(False)
        self.update_button.clicked.connect(self.on_update_weights)

        # QLabel derives from QFrame, so it can draw a horizontal rule.
        self.separator = QLabel()
        self.separator.setFrameShape(QFrame.HLine)

        # Model picker (column 0), initial weight (column 1), and a tall load
        # button spanning both rows (column 2).
        glayout = QGridLayout()
        glayout.setVerticalSpacing(0)
        glayout.addWidget(QLabel("LoRA model:"), 0, 0)
        glayout.addWidget(QLabel("Initial LoRA weight:"), 0, 1)
        glayout.addWidget(self.models_combobox, 1, 0)
        glayout.addWidget(self.weight_slider, 1, 1)
        glayout.addWidget(self.load_button, 0, 2, 2, 1)

        hlayout = QHBoxLayout()
        hlayout.addWidget(self.loaded_label)
        hlayout.addWidget(self.update_button)

        # The item order matters: on_load_lora inserts LoRA rows at index 3,
        # i.e. after hlayout and before the trailing stretch.
        vlayout = QVBoxLayout()
        vlayout.addLayout(glayout, 10)
        vlayout.addWidget(self.separator, 1)
        vlayout.addLayout(hlayout, 10)
        vlayout.addStretch(80)
        self.setLayout(vlayout)

    def on_load_lora(self):
        """Load the LoRA selected in the combo box into the current pipeline.

        On success, a new ``_LoraWidget`` row is added and the module-level
        counters are updated. Each of the following is reported with a message
        box and nothing is loaded:

        - the LoRA limit has been reached;
        - OpenVINO is enabled;
        - the selection is unknown or its file is missing;
        - the pipeline has not been built yet.
        """
        # Code for testing the GUI; ignore when running FastSD CPU
        if __name__ == "__main__":
            self.layout().insertWidget(3, _LoraWidget())
            return
        # End of code for testing the GUI

        global _current_lora_count

        if self.config is None or self.config.settings is None:
            return
        if _current_lora_count >= _MAX_LORA_WEIGHTS:
            QMessageBox.information(
                self.parent,
                "Error",
                f"A maximum of {_MAX_LORA_WEIGHTS} LoRA models can be loaded.",
            )
            return

        settings = self.config.settings.lcm_diffusion_setting
        if settings.use_openvino:
            QMessageBox.information(
                self.parent,
                "Error",
                "LoRA support is currently not implemented for OpenVINO.",
            )
            return

        # Re-read the model map instead of reusing the one built in __init__
        # (presumably so the lookup reflects the directory's current contents).
        lora_models_map = get_lora_models(settings.lora.models_dir)

        # Load a new LoRA
        settings.lora.fuse = False
        settings.lora.enabled = False
        current_lora = self.models_combobox.currentText()
        current_weight = self.weight_slider.getValue()
        print(f"Selected Lora Model :{current_lora}")
        print(f"Lora weight :{current_weight}")

        # .get(): the combo box may be empty, or the selection may no longer be
        # listed. A plain [] lookup would raise KeyError inside this Qt slot.
        lora_path = lora_models_map.get(current_lora)
        if lora_path is None:
            QMessageBox.information(self.parent, "Error", "Invalid LoRA model path!")
            return
        settings.lora.path = lora_path
        settings.lora.weight = current_weight
        if not path.exists(settings.lora.path):
            QMessageBox.information(self.parent, "Error", "Invalid LoRA model path!")
            return

        if not self.parent.context.lcm_text_to_image.pipeline:
            QMessageBox.information(
                self.parent,
                "Error",
                "Pipeline not initialized. Please generate an image first.",
            )
            return

        settings.lora.enabled = True
        load_lora_weight(
            self.parent.context.lcm_text_to_image.pipeline,
            settings,
        )

        lora_widget = _LoraWidget()
        lora_widget.setValues(current_lora, current_weight)
        # Index 3 is directly below the "Loaded LoRA models:" row, so the most
        # recently loaded LoRA is listed first.
        self.layout().insertWidget(3, lora_widget)
        self.update_button.setEnabled(True)
        _active_lora_widgets.append(lora_widget)
        _current_lora_count += 1

    def on_update_weights(self):
        """Apply the weights shown on the per-LoRA rows to the pipeline.

        The i-th active LoRA reported by ``get_active_lora_weights()`` is paired
        with the i-th row in ``_active_lora_widgets``.
        """
        active_weights = get_active_lora_weights()
        if not active_weights:
            return
        # Rows are appended in load order (see on_load_lora). This assumes
        # get_active_lora_weights() reports adapters in that same order.
        if len(active_weights) > len(_active_lora_widgets):
            # There is no slider row to take a weight from for some adapters.
            print(
                "LoRA weight widgets are out of sync with the active LoRA "
                "weights; skipping update"
            )
            return
        update_weights = [
            (lora[0], widget.getValues()[1])
            for lora, widget in zip(active_weights, _active_lora_widgets)
        ]
        update_lora_weights(
            self.parent.context.lcm_text_to_image.pipeline,
            self.config.settings.lcm_diffusion_setting,
            update_weights,
        )

    def reset_active_lora_widgets(self):
        """Remove every LoRA row when the rows no longer match the active LoRAs.

        This code assumes the active LoRA count only differs from the GUI row
        count after a pipeline rebuild. At that point the active count is zero,
        so all LoRA rows are simply removed with no further action.
        """
        global _current_lora_count
        global _active_lora_widgets
        if len(get_active_lora_weights()) == _current_lora_count:
            return
        layout = self.layout()
        for lora_widget in _active_lora_widgets:
            layout.removeWidget(lora_widget)
            # removeWidget() only detaches the widget from the layout. It would
            # stay a visible child of this widget unless it is also deleted.
            lora_widget.deleteLater()
        _current_lora_count = 0
        _active_lora_widgets = []
        # No rows left, so there is nothing to update until a LoRA is loaded.
        self.update_button.setEnabled(False)
# Stand-alone GUI test: show the widget without settings or a pipeline.
if __name__ == "__main__":
    import sys

    app = QApplication(sys.argv)
    widget = LoraModelsWidget(None, None)
    widget.show()
    # Return Qt's event-loop exit status to the shell.
    sys.exit(app.exec())
|