kawaiithug committed on
Commit
6af38a3
·
verified ·
1 Parent(s): accdaae

Upload 2 files

Browse files
Kohya_ss/sdxl_merge_lora.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Replace kohya_ss/sd-scripts/networks/sdxl_merge_lora.py
3
+ '''
4
+
5
+ import math
6
+ import argparse
7
+ import os
8
+ import time
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ from tqdm import tqdm
12
+ from library import sai_model_spec, sdxl_model_util, train_util
13
+ import library.model_util as model_util
14
+ import lora
15
+ from library.utils import setup_logging
16
+ setup_logging()
17
+ import logging
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def load_state_dict(file_name, dtype):
    """Load a state dict from disk and cast its tensors to ``dtype``.

    Args:
        file_name: Path of the file to read. Files ending in ``.safetensors``
            are read with ``load_file``; anything else goes through
            ``torch.load``.
        dtype: ``torch.dtype`` every tensor value is converted to. ``None``
            leaves the tensors exactly as loaded.

    Returns:
        A ``(state_dict, metadata)`` tuple. ``metadata`` is whatever
        ``train_util.load_metadata_from_safetensors`` returns for safetensors
        files and an empty dict for every other format.
    """
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        metadata = train_util.load_metadata_from_safetensors(file_name)
    else:
        # Always deserialize onto the CPU: forcing "cuda" here crashes on
        # machines without a GPU and leaves tensors on a device that the
        # LoRA-to-LoRA merge path cannot read scalars from via numpy.
        # Callers that need the weights on an accelerator move them explicitly.
        # NOTE(review): torch.load unpickles the file -- only load trusted files.
        sd = torch.load(file_name, map_location="cpu")
        metadata = {}

    if dtype is not None:
        for key, value in list(sd.items()):
            # Non-tensor entries (ints, nested dicts, ...) are left untouched.
            if isinstance(value, torch.Tensor):
                sd[key] = value.to(dtype)

    return sd, metadata
34
+
35
+
36
def save_to_file(file_name, model, state_dict, dtype, metadata):
    """Cast ``state_dict`` in place and write ``model`` to ``file_name``.

    Args:
        file_name: Destination path. A ``.safetensors`` extension selects
            ``save_file``; any other extension uses ``torch.save``.
        model: The object that is actually serialized.
        state_dict: Dict whose plain ``torch.Tensor`` values are converted to
            ``dtype`` in place before saving. The caller in this file passes
            the same dict as ``model``, which is how the cast reaches the
            saved data.
        dtype: Target ``torch.dtype``; ``None`` skips the conversion.
        metadata: Metadata handed to ``save_file``; unused for other formats.
    """
    if dtype is not None:
        for name in list(state_dict):
            entry = state_dict[name]
            # Exact type match on purpose: tensor subclasses are not converted.
            if type(entry) is torch.Tensor:
                state_dict[name] = entry.to(dtype)

    extension = os.path.splitext(file_name)[1]
    if extension != ".safetensors":
        torch.save(model, file_name)
        return

    save_file(model, file_name, metadata=metadata)
46
+
47
+
48
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
    """Bake LoRA weights directly into the SDXL text encoders and U-Net.

    For every ``lora_down`` weight whose module can be located, the module
    weight is replaced by ``W + ratio * (up x down) * (alpha / dim)``.

    Args:
        text_encoder1: First text encoder module; modified in place.
        text_encoder2: Second text encoder module; modified in place.
        unet: U-Net module; modified in place.
        models: Paths of the LoRA files to merge.
        ratios: Merge ratio for each entry of ``models`` (paired with ``zip``).
        merge_dtype: dtype the LoRA tensors are cast to when loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for root in (text_encoder1, text_encoder2, unet):
        root.to(device)

    # One row per root module: (module, LoRA key prefix, class names to descend into).
    network_cls = lora.LoRANetwork
    roots = (
        (text_encoder1, network_cls.LORA_PREFIX_TEXT_ENCODER1, network_cls.TEXT_ENCODER_TARGET_REPLACE_MODULE),
        (text_encoder2, network_cls.LORA_PREFIX_TEXT_ENCODER2, network_cls.TEXT_ENCODER_TARGET_REPLACE_MODULE),
        (
            unet,
            network_cls.LORA_PREFIX_UNET,
            network_cls.UNET_TARGET_REPLACE_MODULE + network_cls.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
        ),
    )

    # Map "prefix_path_to_child" (dots replaced by underscores) to the Linear/Conv2d module.
    name_to_module = {}
    for root_module, prefix, target_replace_modules in roots:
        for name, module in root_module.named_modules():
            if module.__class__.__name__ not in target_replace_modules:
                continue
            for child_name, child_module in module.named_modules():
                if child_module.__class__.__name__ in ("Linear", "Conv2d"):
                    lora_name = f"{prefix}.{name}.{child_name}".replace(".", "_")
                    name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        # Put the LoRA tensors on the same device as the target modules.
        for entry_name, entry in list(lora_sd.items()):
            if isinstance(entry, torch.Tensor):
                lora_sd[entry_name] = entry.to(device)

        logger.info("merging...")
        for key in tqdm(lora_sd.keys()):
            # Each LoRA pair is processed once, driven by its "down" weight.
            if "lora_down" not in key:
                continue

            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[: key.index("lora_down")] + "alpha"

            # Strip the trailing ".lora_down.weight" to get the module name.
            module_name = ".".join(key.split(".")[:-2])
            if module_name not in name_to_module:
                logger.info(f"no module found for LoRA weight: {key}")
                continue
            target = name_to_module[module_name]

            down_weight = lora_sd[key]
            up_weight = lora_sd[up_key]

            rank = down_weight.size()[0]
            # A missing alpha falls back to the rank, i.e. a scale of 1.
            alpha = lora_sd.get(alpha_key, rank)
            scale = alpha / rank

            # W <- W + U * D
            weight = target.weight
            if weight.dim() == 2:
                # linear
                delta = up_weight @ down_weight
            elif down_weight.size()[2:4] == (1, 1):
                # conv2d 1x1: multiply as matrices, then restore the two kernel axes
                up_2d = up_weight.squeeze(3).squeeze(2)
                down_2d = down_weight.squeeze(3).squeeze(2)
                delta = (up_2d @ down_2d).unsqueeze(2).unsqueeze(3)
            else:
                # conv2d 3x3
                delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

            target.weight = torch.nn.Parameter(weight + ratio * delta * scale)
128
+
129
+
130
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    """Merge several LoRA files into a single LoRA state dict.

    The first model that defines a module fixes that module's alpha (and the
    dim reported in the metadata). Every weight is scaled by
    ``sqrt(alpha / base_alpha) * ratio`` and then either summed with, or
    concatenated to, what has been merged so far.

    Args:
        models: Paths of the LoRA files to merge.
        ratios: Merge ratio for each entry of ``models`` (paired with ``zip``).
        merge_dtype: dtype the LoRA tensors are cast to when loaded.
        concat: Concatenate along the rank axis instead of summing
            (``lora_up`` on dim 1, ``lora_down`` on dim 0).
        shuffle: Apply one random permutation to the rank axis of each merged
            down/up pair.

    Returns:
        A ``(merged_state_dict, metadata)`` tuple, where ``metadata`` comes
        from ``train_util.build_minimum_network_metadata``.

    Raises:
        AssertionError: If two models have differently shaped weights for the
            same key and ``concat`` does not apply to that key.
    """
    base_alphas = {}  # alpha for merged model
    base_dims = {}

    merged_sd = {}
    v2 = None
    base_model = None
    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
            if v2 is None:
                # returns a string; SDXL has no v2, so this should be False
                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        # get alpha and dim
        alphas = {}  # alpha for current model
        dims = {}  # dims for current model
        for key in lora_sd.keys():
            if "alpha" in key:
                lora_module_name = key[: key.rfind(".alpha")]
                # .item() instead of .numpy(): numpy cannot represent bfloat16
                # and cannot read tensors that live on a CUDA device.
                alpha = float(lora_sd[key].item())
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha
            elif "lora_down" in key:
                lora_module_name = key[: key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[lora_module_name] = dim
                if lora_module_name not in base_dims:
                    base_dims[lora_module_name] = dim

        # Modules without an explicit alpha use their dim as alpha.
        for lora_module_name in dims.keys():
            if lora_module_name not in alphas:
                alpha = dims[lora_module_name]
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha

        logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
        logger.info("merging...")
        for key in tqdm(lora_sd.keys()):
            if "alpha" in key:
                continue

            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
                concat_dim = 0
            else:
                concat_dim = None

            lora_module_name = key[: key.rfind(".lora_")]

            base_alpha = base_alphas[lora_module_name]
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            # Support negative ratios: keep the sign on lora_down only, so the
            # product up @ down carries it exactly once.
            scale = abs(scale) if "lora_up" in key else scale

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
                if concat_dim is not None:
                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
                else:
                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
            else:
                merged_sd[key] = lora_sd[key] * scale

    # set alpha to sd
    for lora_module_name, alpha in base_alphas.items():
        key = lora_module_name + ".alpha"
        merged_sd[key] = torch.tensor(alpha)
        if shuffle:
            key_down = lora_module_name + ".lora_down.weight"
            key_up = lora_module_name + ".lora_up.weight"
            dim = merged_sd[key_down].shape[0]
            # The same permutation goes to the rows of down and the columns of up.
            perm = torch.randperm(dim)
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:, perm]

    logger.info("merged model")
    logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
    alphas_list = list(set(base_alphas.values()))
    all_same_dims = all(value == dims_list[0] for value in dims_list)
    all_same_alphas = all(value == alphas_list[0] for value in alphas_list)

    # build minimum metadata
    dims_label = f"{dims_list[0]}" if all_same_dims else "Dynamic"
    alphas_label = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims_label, alphas_label, None)

    return merged_sd, metadata
240
+
241
+
242
def merge(args):
    """Run the merge described by the parsed command-line arguments.

    With ``args.sd_model`` set, the LoRAs are merged into that SDXL checkpoint
    and the resulting checkpoint is saved. Otherwise the LoRAs are merged with
    each other and saved as a new LoRA file.

    Args:
        args: Namespace produced by ``setup_parser().parse_args()``. Reads
            ``models``, ``ratios``, ``precision``, ``save_precision``,
            ``sd_model``, ``save_to``, ``no_metadata``, ``concat`` and
            ``shuffle``.

    Raises:
        AssertionError: If the numbers of models and ratios differ.
    """
    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        # Translate a command-line precision name into a torch dtype (None if unknown or omitted).
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        # No explicit save precision: save in the precision used for merging.
        save_dtype = merge_dtype

    if args.sd_model is not None:
        logger.info(f"loading SD model: {args.sd_model}")

        # Pick the device the same way merge_to_sd_model does, so this path
        # also works on machines without CUDA instead of failing at load time.
        load_device = "cuda" if torch.cuda.is_available() else "cpu"
        (
            text_model1,
            text_model2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(
            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, load_device
        )

        merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)

        if args.no_metadata:
            sai_metadata = None
        else:
            merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
            )

        logger.info(f"saving SD model to: {args.save_to}")
        sdxl_model_util.save_stable_diffusion_checkpoint(
            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
        )
    else:
        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

        logger.info("calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

        if not args.no_metadata:
            merged_from = sai_model_spec.build_merged_from(args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
            )
            metadata.update(sai_metadata)

        logger.info(f"saving model to: {args.save_to}")
        # The same dict is passed as both `model` and `state_dict`, so the
        # dtype cast applied by save_to_file reaches the saved tensors.
        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
305
+
306
+
307
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SDXL LoRA merge script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--save_precision``,
        ``--precision``, ``--sd_model``, ``--save_to``, ``--models``,
        ``--ratios``, ``--no_metadata``, ``--concat`` and ``--shuffle``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = [
        (
            "--save_precision",
            {
                "type": str,
                "default": None,
                "choices": [None, "float", "fp16", "bf16"],
                "help": "precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
            },
        ),
        (
            "--precision",
            {
                "type": str,
                "default": "float",
                "choices": ["float", "fp16", "bf16"],
                "help": "precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
            },
        ),
        (
            "--sd_model",
            {
                "type": str,
                "default": None,
                "help": "Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
            },
        ),
        (
            "--save_to",
            {
                "type": str,
                "default": None,
                "help": "destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
            },
        ),
        (
            "--models",
            {
                "type": str,
                "nargs": "*",
                "help": "LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
            },
        ),
        (
            "--ratios",
            {
                "type": float,
                "nargs": "*",
                "help": "ratios for each model / それぞれのLoRAモデルの比率",
            },
        ),
        (
            "--no_metadata",
            {
                "action": "store_true",
                "help": "do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
                "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
            },
        ),
        (
            "--concat",
            {
                "action": "store_true",
                "help": "concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
                "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
            },
        ),
        (
            "--shuffle",
            {
                "action": "store_true",
                "help": "shuffle lora weight./ " "LoRAの重みをシャッフルする",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
356
+
357
+
358
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the requested merge.
    cli_args = setup_parser().parse_args()
    merge(cli_args)
LoRA Block Weight Presets/presets.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## These presets were made to alleviate overfitted/overbaked LoRAs. They prioritize strong base characteristics and significant middle-layer style transfer while minimizing excessive style bleed-through.
2
+
3
+ ## SINE = More dramatic drop-off in OUT blocks with strong M00 peak and lower final weight. Best for when you want subject matter and minimal style influence.
4
+
5
+ ## PLUM = More balanced OUT block transition and slightly softer M00 peak. Consistent OUT blocks. Better for maintaining subjects to the highest degree of faithfulness while allowing for a tiny touch of style to come in.
6
+
7
+ ## PEAR = Similar results to PLUM but made only for LyCORIS.
8
+
9
+ SINE:1,0.8,0.6,0.4,0.5,1,0.8,0.5,0.3,0.2,0.1,0.3
10
+
11
+ PLUM:1,0.8,0.6,0.4,0.5,0.95,0.6,0.5,0.3,0.4,0.3,0.5
12
+
13
+ PEAR:1,0.3,0.3,0.3,0.6,0.6,0.6,0.9,0.9,0.9,1,0.3,0.3,0.3,0.6,0.6,0.6,1,1,1
14
+