perorina committed
Commit 36f7261 · 1 Parent(s): 847e2b9

Create scripts/runtime_block_merge.py

Files changed (1)
scripts/runtime_block_merge.py +734 -0
scripts/runtime_block_merge.py ADDED
@@ -0,0 +1,734 @@
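+ # Runtime Block Merge: interpolates the UNet of the currently loaded checkpoint
+ # (model A) with a second checkpoint (model B) in memory, one weight per UNet
+ # block (12 IN, 1 MID, 12 OUT, plus time_embed and the final out head = 27),
+ # so merges can be previewed at generation time without writing a merged file.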
+ import copy
+ import json
+ from datetime import datetime
+
+ import modules.scripts as scripts
+ import gradio as gr
+
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel
+ from modules import sd_models, shared, devices
+ from scripts.mbw_util.preset_weights import PresetWeights
+ import torch
+ from natsort import natsorted
+
+ from pathlib import Path
+ import safetensors.torch
+
+ presetWeights = PresetWeights()
+
+ shared.UNetBManager = None
+
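+ # The 27 merge targets in natsorted key order; 'out.' (the final conv head)
+ # sorts between 'middle_block.' and 'output_blocks.' alphabetically, which is
+ # why it sits mid-list rather than at the end.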
+ known_block_prefixes = [
+     'input_blocks.0.',
+     'input_blocks.1.',
+     'input_blocks.2.',
+     'input_blocks.3.',
+     'input_blocks.4.',
+     'input_blocks.5.',
+     'input_blocks.6.',
+     'input_blocks.7.',
+     'input_blocks.8.',
+     'input_blocks.9.',
+     'input_blocks.10.',
+     'input_blocks.11.',
+     'middle_block.',
+     'out.',
+     'output_blocks.0.',
+     'output_blocks.1.',
+     'output_blocks.2.',
+     'output_blocks.3.',
+     'output_blocks.4.',
+     'output_blocks.5.',
+     'output_blocks.6.',
+     'output_blocks.7.',
+     'output_blocks.8.',
+     'output_blocks.9.',
+     'output_blocks.10.',
+     'output_blocks.11.',
+     'time_embed.'
+ ]
+
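+ # Caches both models' UNet state_dicts split into 27 per-block dicts and writes
+ # interpolated tensors back into the live UNet modules block by block.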
+ class UNetStateManager(object):
+     def __init__(self, org_unet: UNetModel = None):
+         super().__init__()
+         self.modelB_state_dict_by_blocks = []
+         self.torch_unet = org_unet
+         # self.modelA_state_dict = copy.deepcopy(org_unet.state_dict())
+         self.modelA_state_dict = None
+         self.dtype = devices.dtype
+         self.modelA_state_dict_by_blocks = []
+         # self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
+         self.modelB_state_dict = None
+         # self.unet_block_module_list = []
+         self.unet_block_module_list = [*self.torch_unet.input_blocks, self.torch_unet.middle_block, self.torch_unet.out,
+                                        *self.torch_unet.output_blocks, self.torch_unet.time_embed]
+         self.applied_weights = [0] * 27
+         # self.gui_weights = [0.5] * 27
+         self.enabled = False
+         self.modelA_path = shared.sd_model.sd_model_checkpoint
+         self.modelB_path = ''
+         self.force_cpu = False
+         self.modelA_dtype = None
+         self.modelB_dtype = None
+         self.device = devices.get_cuda_device_string() if (torch.cuda.is_available() and not shared.cmd_opts.lowvram) else "cpu"
+
+     # def set_gui_weights(self, current_weights):
+     #     self.gui_weights = current_weights
+
+     def reload_modelA(self):
+         if not self.enabled:
+             return
+
+         if self.modelA_path == shared.sd_model.sd_model_checkpoint and self.modelA_state_dict is not None:
+             return
+         self.modelA_path = shared.sd_model.sd_model_checkpoint
+
+         del self.modelA_state_dict_by_blocks
+         self.modelA_state_dict_by_blocks = []
+         # orig_modelA_state_dict_keys = list(self.modelA_state_dict.keys())
+         # for key in orig_modelA_state_dict_keys:
+         #     del self.modelA_state_dict[key]
+         del self.modelA_state_dict
+         torch.cuda.empty_cache()
+         if self.force_cpu:
+             self.modelA_state_dict = self.filter_unet_state_dict(
+                 sd_models.read_state_dict(self.modelA_path, map_location="cpu"))
+             self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
+             self.modelA_dtype = next(iter(self.modelA_state_dict.values())).dtype
+         else:
+             self.modelA_state_dict = copy.deepcopy(self.torch_unet.state_dict())
+             self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
+         # if self.enabled:
+         #     self.model_state_apply(self.gui_weights)
+         self.model_state_apply(self.applied_weights)
+         print('model A reloaded')
+
+     def load_modelB(self, modelB_path, force_cpu_checkbox, current_weights):
+         self.force_cpu = force_cpu_checkbox
+         self.device = devices.get_cuda_device_string() if (torch.cuda.is_available() and not shared.cmd_opts.lowvram) else "cpu"
+         if self.force_cpu:
+             self.device = "cpu"
+         model_info = sd_models.get_closet_checkpoint_match(modelB_path)
+         checkpoint_file = model_info.filename
+         self.modelB_path = checkpoint_file
+
+         if self.modelA_path == checkpoint_file:
+             if not self.modelB_state_dict:
+                 self.enabled = False
+                 # self.gui_weights = current_weights
+             return False
+
+         # move initialization of model A to here
+         if not self.modelA_state_dict:
+             if self.force_cpu:
+                 self.modelA_path = shared.sd_model.sd_model_checkpoint
+                 self.modelA_state_dict = self.filter_unet_state_dict(
+                     sd_models.read_state_dict(self.modelA_path, map_location="cpu"))
+                 self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
+
+             else:
+                 self.modelA_state_dict = copy.deepcopy(self.torch_unet.state_dict())
+                 self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
+             # self.modelA_dtype = self.torch_unet.dtype
+             self.modelA_dtype = next(iter(self.modelA_state_dict.values())).dtype
+         sd_model_hash = model_info.hash
+         cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+         # if cache_enabled and model_info in sd_models.checkpoints_loaded:
+         #     # use checkpoint cache
+         #     print(f"Loading weights [{sd_model_hash}] from cache")
+         #     self.modelB_state_dict = sd_models.checkpoints_loaded[model_info]
+
+         if self.modelB_state_dict:
+             # orig_modelB_state_dict_keys = list(self.modelB_state_dict.keys())
+             # for key in orig_modelB_state_dict_keys:
+             #     del self.modelB_state_dict[key]
+             del self.modelB_state_dict_by_blocks
+             del self.modelB_state_dict
+             torch.cuda.empty_cache()
+         self.modelB_state_dict_by_blocks = []
+         self.modelB_state_dict = self.filter_unet_state_dict(
+             sd_models.read_state_dict(checkpoint_file, map_location=self.device))
+         self.modelB_dtype = next(iter(self.modelB_state_dict.values())).dtype
+         if len(self.modelA_state_dict) != len(self.modelB_state_dict):
+             print('modelA and modelB state dicts have different lengths, aborting')
+             return False
+         self.map_blocks(self.modelB_state_dict, self.modelB_state_dict_by_blocks)
+         # verify self.modelA_state_dict and self.modelB_state_dict have same structure
+         self.model_state_apply(current_weights)
+
+         print('model B loaded')
+         self.enabled = True
+         return True
+
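+     # Each block is a linear interpolation: torch.lerp(a, b, w) = a + w*(b - a),
+     # so w=0 keeps model A and w=1 yields model B. For example,
+     # torch.lerp(torch.tensor([0.0]), torch.tensor([10.0]), 0.3) -> tensor([3.]).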
+     def model_state_apply(self, current_weights):
+         # self.gui_weights = current_weights
+         # ensuring maximum precision
+         operation_dtype = torch.float32 if self.modelA_dtype == torch.float32 or self.modelB_dtype == torch.float32 else torch.float16
+         for i in range(27):
+             cur_block_state_dict = {}
+             for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
+                 if operation_dtype == torch.float32:
+                     # try:
+                     curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                  self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                  current_weights[i]).to(self.dtype)
+                     # except RuntimeError:
+                     #     # self.modelB_state_dict_by_blocks[i][cur_layer_key] = self.modelB_state_dict_by_blocks[i][cur_layer_key].to('cpu')
+                     #     self.modelA_state_dict_by_blocks[i][cur_layer_key] = self.modelA_state_dict_by_blocks[i][
+                     #         cur_layer_key].to('cpu')
+                     #     curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                     #                                  self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                     #                                  current_weights[i]).to(self.dtype)
+                 else:
+                     if self.force_cpu:
+                         curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                      self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                      current_weights[i]).to(self.dtype)
+                     else:
+                         # try:
+                         curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key],
+                                                      self.modelB_state_dict_by_blocks[i][cur_layer_key], current_weights[i])
+                         # except RuntimeError:
+                         #     # self.modelB_state_dict_by_blocks[i][cur_layer_key] = self.modelB_state_dict_by_blocks[i][cur_layer_key].to('cpu')
+                         #     self.modelA_state_dict_by_blocks[i][cur_layer_key] = self.modelA_state_dict_by_blocks[i][cur_layer_key].to('cpu')
+                         #     curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key],
+                         #                                  self.modelB_state_dict_by_blocks[i][cur_layer_key],
+                         #                                  current_weights[i])
+                 if str(shared.device) != self.device:
+                     curlayer_tensor = curlayer_tensor.to(shared.device)
+                 cur_block_state_dict[cur_layer_key] = curlayer_tensor
+             self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
+         # store as a list so later equality checks against slider tuples behave
+         self.applied_weights = list(current_weights)
+
+     def model_state_construct(self, current_weights):
+         precision_dtype = torch.float32 if self.modelA_dtype == torch.float32 or self.modelB_dtype == torch.float32 else torch.float16
+         result_state_dict = {}
+         for i in range(27):
+             for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
+                 if precision_dtype == torch.float32:
+                     curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                  self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                  current_weights[i])
+                 else:
+                     if self.force_cpu:
+                         curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                      self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                      current_weights[i]).to(torch.float16)
+                     else:
+                         curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key],
+                                                      self.modelB_state_dict_by_blocks[i][cur_layer_key], current_weights[i])
+
+                 result_state_dict[known_block_prefixes[i] + cur_layer_key] = curlayer_tensor
+         return result_state_dict
+
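+     # Incremental variant used during generation: only blocks whose weight has
+     # changed since the last apply are re-interpolated and reloaded.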
+     def model_state_apply_modified_blocks(self, current_weights, current_model_B):
+         if not self.enabled:
+             return
+         modelB_info = sd_models.get_closet_checkpoint_match(current_model_B)
+         checkpoint_file_B = modelB_info.filename
+         if checkpoint_file_B != self.modelB_path:
+             print("model B changed, shouldn't happen")
+             # load_modelB takes (path, force_cpu, weights); pass the current force_cpu setting
+             self.load_modelB(current_model_B, self.force_cpu, current_weights)
+             return
+         if self.applied_weights == list(current_weights):
+             return
+         operation_dtype = torch.float32 if self.modelA_dtype == torch.float32 or self.modelB_dtype == torch.float32 else torch.float16
+         for i in range(27):
+             if current_weights[i] != self.applied_weights[i]:
+                 cur_block_state_dict = {}
+                 for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
+                     if operation_dtype == torch.float32:
+                         curlayer_tensor = torch.lerp(
+                             self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                             self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                             current_weights[i]).to(self.dtype)
+                     else:
+                         if self.force_cpu:
+                             curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                          self.modelB_state_dict_by_blocks[i][cur_layer_key].to(torch.float32),
+                                                          current_weights[i]).to(torch.float16)
+                         else:
+                             curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key],
+                                                          self.modelB_state_dict_by_blocks[i][cur_layer_key],
+                                                          current_weights[i])
+                     if str(shared.device) != self.device:
+                         curlayer_tensor = curlayer_tensor.to(shared.device)
+                     cur_block_state_dict[cur_layer_key] = curlayer_tensor
+                 self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
+         self.applied_weights = list(current_weights)
+
+     # diff current_weights and self.applied_weights, apply only the difference
+     def model_state_apply_block(self, current_weights):
+         # self.gui_weights = current_weights
+         if not self.enabled:
+             return self.applied_weights
+         for i in range(27):
+             if current_weights[i] != self.applied_weights[i]:
+                 cur_block_state_dict = {}
+                 for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
+                     curlayer_tensor = torch.lerp(self.modelA_state_dict_by_blocks[i][cur_layer_key],
+                                                  self.modelB_state_dict_by_blocks[i][cur_layer_key], current_weights[i])
+                     cur_block_state_dict[cur_layer_key] = curlayer_tensor
+                 self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
+         self.applied_weights = list(current_weights)
+         return self.applied_weights
+
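+     # 'model.diffusion_model.' is 22 characters, hence key[22:] strips the full
+     # prefix including its trailing dot.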
+     # filter input_dict to include only keys starting with 'model.diffusion_model'
+     def filter_unet_state_dict(self, input_dict):
+         filtered_dict = {}
+         for key, value in input_dict.items():
+             if key.startswith('model.diffusion_model'):
+                 filtered_dict[key[22:]] = value
+         filtered_dict_keys = natsorted(filtered_dict.keys())
+         filtered_dict = {k: filtered_dict[k] for k in filtered_dict_keys}
+
+         return filtered_dict
+
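+     # Walks the natsorted keys once and cuts the flat state_dict into 27
+     # per-block dicts, keyed by the layer name relative to its block prefix.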
+     def map_blocks(self, model_state_dict_input, model_state_dict_by_blocks):
+         if model_state_dict_by_blocks:
+             print('mapping to non-empty list, skipping')
+             return
+         model_state_dict_sorted_keys = natsorted(model_state_dict_input.keys())
+         # sort model_state_dict by model_state_dict_sorted_keys
+         model_state_dict = {k: model_state_dict_input[k] for k in model_state_dict_sorted_keys}
+
+         current_block_index = 0
+         processing_block_dict = {}
+         for key in model_state_dict:
+             # print(key)
+             if not key.startswith(known_block_prefixes[current_block_index]):
+                 # guard the +1 lookup so a stray key after the last block cannot raise IndexError
+                 if current_block_index + 1 >= len(known_block_prefixes) \
+                         or not key.startswith(known_block_prefixes[current_block_index + 1]):
+                     print(
+                         f"unknown key {key} in statedict after block {known_block_prefixes[current_block_index]}, possible UNet structure deviation"
+                     )
+                     continue
+                 else:
+                     model_state_dict_by_blocks.append(processing_block_dict)
+                     processing_block_dict = {}
+                     current_block_index += 1
+             block_local_key = key[len(known_block_prefixes[current_block_index]):]
+             processing_block_dict[block_local_key] = model_state_dict[key]
+
+         model_state_dict_by_blocks.append(processing_block_dict)
+         print('mapping complete')
+         return
+
+     def restore_original_unet(self):
+         self.torch_unet.load_state_dict(self.modelA_state_dict)
+         return
+
+     def unload_all(self):
+         self.modelA_path = ''
+         self.modelB_path = ''
+         self.applied_weights = [0.0] * 27
+         del self.modelA_state_dict
+         self.modelA_state_dict = None
+         del self.modelA_state_dict_by_blocks
+         self.modelA_state_dict_by_blocks = []
+         del self.modelB_state_dict
+         self.modelB_state_dict = None
+         del self.modelB_state_dict_by_blocks
+         self.modelB_state_dict_by_blocks = []
+         # self.unet_block_module_list = []
+         self.enabled = False
+
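+ # webui script entry point: rendered as an always-visible accordion in both
+ # txt2img and img2img, and re-merges model A whenever the main checkpoint
+ # selection changes.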
+ class Script(scripts.Script):
+     def __init__(self) -> None:
+         super().__init__()
+         if shared.UNetBManager is None:
+             try:
+                 shared.UNetBManager = UNetStateManager(shared.sd_model.model.diffusion_model)
+             except AttributeError:
+                 shared.UNetBManager = None
+         from modules.call_queue import wrap_queued_call
+
+         def reload_modelA_checkpoint():
+             if shared.opts.sd_model_checkpoint == shared.sd_model.sd_checkpoint_info.title:
+                 return
+             sd_models.reload_model_weights()
+             shared.UNetBManager.reload_modelA()
+
+         shared.opts.onchange("sd_model_checkpoint",
+                              wrap_queued_call(reload_modelA_checkpoint), call=False)
+
+     def title(self):
+         return "Runtime block merging for UNet"
+
+     def show(self, is_img2img):
+         return scripts.AlwaysVisible
+
+     def ui(self, is_img2img):
+         process_script_params = []
+         with gr.Accordion('Runtime Block Merge', open=False):
+             hidden_title = gr.Textbox(label='Runtime Block Merge Title', value='Runtime Block Merge',
+                                       visible=False, interactive=False)
+             with gr.Row():
+                 enabled = gr.Checkbox(label='Enable', value=False, interactive=False)
+                 unload_button = gr.Button(value='Unload and Disable', elem_id="rbm_unload", visible=False)
+                 experimental_range_checkbox = gr.Checkbox(label='Enable Experimental Range', value=False)
+                 force_cpu_checkbox = gr.Checkbox(label='Force CPU (Max Precision)', value=True, interactive=True)
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         dd_preset_weight = gr.Dropdown(label="Preset Weights",
+                                                        choices=presetWeights.get_preset_name_list())
+                         config_paste_button = gr.Button(value='Generate Merge Block Weighted Config\u2199\ufe0f',
+                                                         elem_id="rbm_config_paste",
+                                                         title="Paste Current Block Configs Into Weight Command. Useful for copying to \"Merge Block Weighted\" extension")
+                     weight_command_textbox = gr.Textbox(label="Weight Command",
+                                                         placeholder="Input weight command, then press enter. \nExample: base:0.5, in00:1, out09:0.8, time_embed:0, out:0")
+                     # weight_config_textbox_readonly = gr.Textbox(label="Weight Config For Merge Block Weighted", interactive=False)
+
+                 # btn_apply_block_weight_from_txt = gr.Button(value="Apply block weight from text")
+                 # with gr.Row():
+                 #     sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0)
+                 #     chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
+                 # with gr.Row():
+                 #     with gr.Column(scale=3):
+                 #         with gr.Row():
+                 #             chk_save_as_half = gr.Checkbox(label="Save as half", value=False)
+                 #             chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False)
+                 #     with gr.Column(scale=4):
+                 #         radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids",
+                 #                                       choices=["None", "Skip", "Force Reset"], value="None",
+                 #                                       type="index")
+                 with gr.Row():
+                     # model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
+                     model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
+                     refresh_button = gr.Button(variant='tool', value='\U0001f504', elem_id='rbm_modelb_refresh')
+
+                 # txt_model_O = gr.Text(label="Output Model Name")
+                 with gr.Row():
+                     sl_TIME_EMBED = gr.Slider(label="TIME_EMBED", minimum=0, maximum=1, step=0.01, value=0)
+                     sl_OUT = gr.Slider(label="OUT", minimum=0, maximum=1, step=0.01, value=0)
+                 with gr.Row():
+                     with gr.Column(min_width=100):
+                         sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5)
+                     with gr.Column(min_width=100):
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         gr.Slider(visible=False)
+                         sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5,
+                                             elem_id="mbw_sl_M00")
+                     with gr.Column(min_width=100):
+                         sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5)
+                         sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5)
+
+             sl_INPUT = [
+                 sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
+                 sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
+             sl_MID = [sl_M_00]
+             sl_OUTPUT = [
+                 sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
+                 sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
+             # sl_ALL_nat follows the natsorted block order of known_block_prefixes;
+             # sl_ALL follows the GUI / Merge Block Weighted order
+             sl_ALL_nat = [*sl_INPUT, *sl_MID, sl_OUT, *sl_OUTPUT, sl_TIME_EMBED]
+             sl_ALL = [*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_TIME_EMBED, sl_OUT]
+
+             def handle_modelB_load(modelB, force_cpu_checkbox, *slALL):
+                 if modelB is None:
+                     return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)
+                 load_flag = shared.UNetBManager.load_modelB(modelB, force_cpu_checkbox, slALL)
+                 if load_flag:
+                     return modelB, True, gr.update(interactive=False), gr.update(visible=True), gr.update(visible=True)
+                 else:
+                     return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)
+
+             def handle_unload():
+                 shared.UNetBManager.restore_original_unet()
+                 shared.UNetBManager.unload_all()
+                 return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)
+
+             def handle_weight_change(*slALL):
+                 # convert float list to string
+                 slALL_str = [str(sl) for sl in slALL]
+                 old_config_str = ','.join(slALL_str[:25])
+                 return old_config_str
+
+             # for slider in sl_ALL:
+             #     # slider.change(fn=handle_weight_change, inputs=sl_ALL, outputs=sl_ALL)
+             #     slider.change(fn=handle_weight_change, inputs=sl_ALL, outputs=[weight_config_textbox_readonly])
+
+             def on_weight_command_submit(command_str, *current_weights):
+                 weight_list = parse_weight_str_to_list(command_str, list(current_weights))
+                 if not weight_list:
+                     return [gr.update() for _ in range(27)]
+                 if len(weight_list) == 25:
+                     # noinspection PyTypeChecker
+                     weight_list.extend([gr.update(), gr.update()])
+                 return weight_list
+
+             weight_command_textbox.submit(
+                 fn=on_weight_command_submit,
+                 inputs=[weight_command_textbox, *sl_ALL],
+                 outputs=sl_ALL
+             )
+
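+             # Accepts either named commands ("base:0.5, in00:1, out09:0.8",
+             # where BASE first sets all 27 weights) or a bare comma-separated
+             # list of 25 or 27 floats; 25-value Merge Block Weighted presets
+             # leave TIME_EMBED and OUT unchanged. Returns None on any error.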
+             def parse_weight_str_to_list(weightstr, current_weights):
+                 weightstr = weightstr[:500]
+                 if ':' in weightstr:
+                     # parse as json
+                     weightstr = weightstr.replace(' ', '')
+                     cmd_segments = weightstr.split(',')
+                     constructed_json_segments = [f'"{key.upper()}":{value}' for key, value in
+                                                  [x.split(':') for x in cmd_segments]]
+                     constructed_json = '{' + ','.join(constructed_json_segments) + '}'
+                     try:
+                         parsed_json = json.loads(constructed_json)
+
+                     except Exception as e:
+                         print(e)
+                         return None
+                     weight_name_map = {
+                         'IN00': 0,
+                         'IN01': 1,
+                         'IN02': 2,
+                         'IN03': 3,
+                         'IN04': 4,
+                         'IN05': 5,
+                         'IN06': 6,
+                         'IN07': 7,
+                         'IN08': 8,
+                         'IN09': 9,
+                         'IN10': 10,
+                         'IN11': 11,
+                         'M00': 12,
+                         'OUT00': 13,
+                         'OUT01': 14,
+                         'OUT02': 15,
+                         'OUT03': 16,
+                         'OUT04': 17,
+                         'OUT05': 18,
+                         'OUT06': 19,
+                         'OUT07': 20,
+                         'OUT08': 21,
+                         'OUT09': 22,
+                         'OUT10': 23,
+                         'OUT11': 24,
+                         'TIME_EMBED': 25,
+                         'OUT': 26
+                     }
+                     extra_commands = ['BASE']
+                     # type check
+                     for key, value in parsed_json.items():
+                         if key not in weight_name_map and key not in extra_commands:
+                             print(f'invalid key: {key}')
+                             return None
+                         if not isinstance(value, (float, int)) or value < -1 or value > 2:
+                             print(f'{key} value {value} out of range')
+                             return None
+
+                     weight_list = current_weights
+                     if 'BASE' in parsed_json:
+                         weight_list = [float(parsed_json['BASE'])] * 27
+                         del parsed_json['BASE']
+                     for key, value in parsed_json.items():
+                         weight_list[weight_name_map[key]] = value
+                     return weight_list
+                 else:
+                     # parse as list
+                     _list = [x.strip() for x in weightstr.split(",")]
+                     if len(_list) != 25 and len(_list) != 27:
+                         return None
+                     validated_float_weight_list = []
+                     for x in _list:
+                         try:
+                             validated_float_weight_list.append(float(x))
+                         except ValueError:
+                             return None
+                     return validated_float_weight_list
+
+             def on_change_dd_preset_weight(preset_weight_name, *current_weights):
+                 _weights = presetWeights.find_weight_by_name(preset_weight_name)
+                 weight_list = parse_weight_str_to_list(_weights, list(current_weights))
+                 if not weight_list:
+                     return [gr.update() for _ in range(27)]
+                 if len(weight_list) == 25:
+                     # noinspection PyTypeChecker
+                     weight_list.extend([gr.update(), gr.update()])
+                 return weight_list
+
+             dd_preset_weight.change(
+                 fn=on_change_dd_preset_weight,
+                 inputs=[dd_preset_weight, *sl_ALL],
+                 outputs=sl_ALL
+             )
+
+             def update_slider_range(experimental_range_flag):
+                 if experimental_range_flag:
+                     return [gr.update(minimum=-1, maximum=2) for _ in sl_ALL]
+                 else:
+                     return [gr.update(minimum=0, maximum=1) for _ in sl_ALL]
+
+             experimental_range_checkbox.change(fn=update_slider_range, inputs=[experimental_range_checkbox],
+                                                outputs=sl_ALL)
+
+             def on_config_paste(*current_weights):
+                 slALL_str = [str(sl) for sl in current_weights]
+                 old_config_str = ','.join(slALL_str[:25])
+                 return old_config_str
+
+             config_paste_button.click(fn=on_config_paste, inputs=[*sl_ALL], outputs=[weight_command_textbox])
+
+             def refresh_modelB_dropdown():
+                 return gr.update(choices=sd_models.checkpoint_tiles())
+
+             refresh_button.click(
+                 fn=refresh_modelB_dropdown,
+                 inputs=None,
+                 outputs=[model_B]
+             )
+
+             # process_script_params.append(hidden_title)
+             process_script_params.extend(sl_ALL_nat)
+             process_script_params.append(model_B)
+             process_script_params.append(enabled)
+
+             with gr.Row():
+                 output_mode_radio = gr.Radio(label="Output Mode", choices=["Max Precision", "Runtime Snapshot"],
+                                              value="Max Precision", type="value", interactive=True)
+                 position_id_fix_radio = gr.Radio(label="Skip/Reset CLIP position_ids",
+                                                  choices=["Keep Original", "Fix"], value="Keep Original", type="value", interactive=True)
+
+                 output_format_radio = gr.Radio(label="Output Format",
+                                                choices=[".ckpt", ".safetensors"], value=".ckpt", type="value",
+                                                interactive=True)
+             with gr.Row():
+                 output_recipe_checkbox = gr.Checkbox(label="Output Recipe", value=True, interactive=True)
+
+             # with gr.Row():
+             #     save_snapshot_checkbox = gr.Checkbox(label="Save Snapshot", value=False)
+             with gr.Row():
+                 save_checkpoint_name_textbox = gr.Textbox(label="New Checkpoint Name")
+                 save_checkpoint_button = gr.Button(value="Save Runtime Checkpoint", elem_id="mbw_save_checkpoint_button", variant='primary', interactive=True, visible=False)
+
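+             # "Max Precision" rebuilds the merge from the cached state_dicts via
+             # model_state_construct (in float32 where either source is float32);
+             # "Runtime Snapshot" dumps the live, possibly fp16, UNet as-is. The
+             # merged UNet is then spliced back into model A's full state_dict.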
+             def on_save_checkpoint(output_mode_radio, position_id_fix_radio, output_format_radio,
+                                    save_checkpoint_name, output_recipe_checkbox, *weights):
+                 current_weights_nat = weights[:27]
+                 weights_output_recipe = weights[27:]
+                 if not save_checkpoint_name:
+                     # current timestamp
+                     timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+                     save_checkpoint_name = f"mbw_{timestamp_str}"
+                 save_checkpoint_namewext = save_checkpoint_name + output_format_radio
+                 loaded_sd_model_path = Path(shared.sd_model.sd_model_checkpoint)
+                 model_ext = loaded_sd_model_path.suffix
+                 if model_ext == '.ckpt':
+                     model_A_raw_state_dict = torch.load(shared.sd_model.sd_model_checkpoint, map_location='cpu')
+                     if 'state_dict' in model_A_raw_state_dict:
+                         model_A_raw_state_dict = model_A_raw_state_dict['state_dict']
+                 elif model_ext == '.safetensors':
+                     model_A_raw_state_dict = safetensors.torch.load_file(shared.sd_model.sd_model_checkpoint, device="cpu")
+                 else:
+                     # neither branch above would bind model_A_raw_state_dict; bail out early
+                     print(f'unsupported model A extension: {model_ext}')
+                     return gr.update()
+                 save_checkpoint_path = Path(shared.sd_model.sd_model_checkpoint).parent / save_checkpoint_namewext
+
+                 if output_mode_radio == 'Runtime Snapshot':
+                     snapshot_state_dict = shared.sd_model.model.diffusion_model.state_dict()
+
+                 elif output_mode_radio == 'Max Precision':
+                     snapshot_state_dict = shared.UNetBManager.model_state_construct(current_weights_nat)
+
+                 snapshot_state_dict_prefixed = {'model.diffusion_model.' + key: value for key, value in
+                                                 snapshot_state_dict.items()}
+                 if not set(snapshot_state_dict_prefixed.keys()).issubset(set(model_A_raw_state_dict.keys())):
+                     print('warning: snapshot state_dict keys are not a subset of model A state_dict keys, possible structural deviation')
+
+                 combined_state_dict = {**model_A_raw_state_dict, **snapshot_state_dict_prefixed}
+                 if position_id_fix_radio == 'Fix':
+                     combined_state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.tensor([list(range(77))], dtype=torch.int64)
+
+                 if output_format_radio == '.ckpt':
+                     state_dict_save = {'state_dict': combined_state_dict}
+                     torch.save(state_dict_save, save_checkpoint_path)
+                 elif output_format_radio == '.safetensors':
+                     safetensors.torch.save_file(combined_state_dict, save_checkpoint_path)
+
+                 if output_recipe_checkbox:
+                     recipe_path = Path(shared.sd_model.sd_model_checkpoint).parent / f"{save_checkpoint_name}.recipe.txt"
+                     with open(recipe_path, 'w') as f:
+                         f.write(f"modelA={shared.sd_model.sd_model_checkpoint}\n")
+                         f.write(f"modelB={shared.UNetBManager.modelB_path}\n")
+                         f.write(f"position_id_fix={position_id_fix_radio}\n")
+                         f.write(f"output_mode={output_mode_radio}\n")
+                         f.write(f"{','.join([str(w) for w in weights_output_recipe])}\n")
+
+                 return gr.update(value=save_checkpoint_name)
+
+             def on_change_force_cpu(force_cpu_flag):
+                 if not force_cpu_flag:
+                     return gr.update(choices=["Runtime Snapshot"], value="Runtime Snapshot")
+                 else:
+                     return gr.update(choices=["Max Precision", "Runtime Snapshot"], value="Max Precision")
+
+             save_checkpoint_button.click(
+                 fn=on_save_checkpoint,
+                 inputs=[output_mode_radio, position_id_fix_radio, output_format_radio, save_checkpoint_name_textbox, output_recipe_checkbox, *sl_ALL_nat, *sl_ALL],
+                 outputs=[save_checkpoint_name_textbox],
+                 show_progress=True
+             )
+             force_cpu_checkbox.change(fn=on_change_force_cpu, inputs=[force_cpu_checkbox], outputs=[output_mode_radio])
+             model_B.change(fn=handle_modelB_load, inputs=[model_B, force_cpu_checkbox, *sl_ALL_nat],
+                            outputs=[model_B, enabled, force_cpu_checkbox, save_checkpoint_button, unload_button])
+             unload_button.click(fn=handle_unload, inputs=[], outputs=[model_B, enabled, force_cpu_checkbox, save_checkpoint_button, unload_button])
+
+         return process_script_params
+
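+     # Runs once per generation. args arrive in the order registered in ui():
+     # 27 weights in natsorted block order (sl_ALL_nat), then model B, then the
+     # enabled flag.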
+     def process(self, p, *args):
+         gui_weights = args[:27]
+         modelB = args[27]
+         enabled = args[28]
+         if not enabled:
+             return
+         if not shared.UNetBManager:
+             shared.UNetBManager = UNetStateManager(shared.sd_model.model.diffusion_model)
+         shared.UNetBManager.model_state_apply_modified_blocks(gui_weights, modelB)