| import torch, math |
| from ..core.loader import load_state_dict |
| from typing import Union |
|
|
class GeneralLoRALoader:
    """Merge LoRA weight deltas directly into the weights of a ``torch.nn.Module``.

    The loader pairs every ``lora_B`` tensor in a LoRA state dict with its
    ``lora_A`` counterpart, multiplies them, and adds the scaled product to the
    ``weight`` of the sub-module whose name matches the LoRA key.
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        # Device / dtype used for the merge arithmetic in ``load``.
        self.device, self.torch_dtype = device, torch_dtype

    def get_name_dict(self, lora_state_dict):
        """Map target module names to their ``(lora_B key, lora_A key)`` pair.

        Only keys containing ``".lora_B."`` are inspected. The module name is
        the key with these pieces removed: the ``lora_B`` segment, an optional
        extra segment right after it (e.g. an adapter name such as
        ``default``), a leading ``diffusion_model`` segment, and the final
        segment (presumably ``weight``).
        """
        pairs = {}
        for up_key in lora_state_dict:
            if ".lora_B." not in up_key:
                continue
            segments = up_key.split(".")
            marker = segments.index("lora_B")
            # More than one segment after "lora_B" means an adapter name sits
            # between the marker and the trailing segment; drop it.
            if len(segments) > marker + 2:
                del segments[marker + 1]
            del segments[marker]
            if segments[0] == "diffusion_model":
                del segments[0]
            segments.pop()
            down_key = up_key.replace(".lora_B.", ".lora_A.")
            pairs[".".join(segments)] = (up_key, down_key)
        return pairs

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Add ``alpha * (lora_B @ lora_A)`` to every matching module weight.

        Four-dimensional LoRA tensors have their two trailing dimensions
        squeezed before the matrix product and restored afterwards
        (presumably 1x1 convolution kernels). The number of updated modules
        is printed once all modules have been visited.
        """
        pairs = self.get_name_dict(state_dict_lora)
        merged = 0
        for module_name, submodule in model.named_modules():
            pair = pairs.get(module_name)
            if pair is None:
                continue
            up_key, down_key = pair
            up = state_dict_lora[up_key].to(device=self.device, dtype=self.torch_dtype)
            down = state_dict_lora[down_key].to(device=self.device, dtype=self.torch_dtype)
            four_dim = len(up.shape) == 4
            if four_dim:
                up = up.squeeze(3).squeeze(2)
                down = down.squeeze(3).squeeze(2)
            product = torch.mm(up, down)
            if four_dim:
                product = product.unsqueeze(2).unsqueeze(3)
            delta = alpha * product
            # Round-trip through the module's own state dict so the update is
            # copied into the existing parameter.
            weights = submodule.state_dict()
            weights["weight"] = weights["weight"].to(device=self.device, dtype=self.torch_dtype) + delta
            submodule.load_state_dict(weights)
            merged += 1
        print(f"{merged} tensors are updated by LoRA.")
|
|
class FluxLoRALoader(GeneralLoRALoader):
    """LoRA loader for FLUX that first normalizes third-party key layouts.

    Two source layouts are recognized by ``convert_state_dict``:

    * ``"diffusers"`` - keys starting with ``"transformer."``, e.g.
      ``transformer.transformer_blocks.<id>.attn.to_q.lora_A.weight``;
    * ``"civitai"`` - keys containing ``"lora_unet_"``, e.g.
      ``lora_unet_double_blocks_<id>_img_attn_qkv.lora_down.weight``.

    Both are renamed to ``<module path>.lora_{A,B}.default.weight`` using
    ``diffusers_rename_dict`` / ``civitai_rename_dict``, whose keys and values
    carry the literal placeholder ``blockid`` where the block index goes.
    """

    # (source module path without the "transformer." prefix, target module path)
    _DIFFUSERS_MODULE_MAP = (
        ("single_transformer_blocks.blockid.attn.to_k", "single_blocks.blockid.a_to_k"),
        ("single_transformer_blocks.blockid.attn.to_q", "single_blocks.blockid.a_to_q"),
        ("single_transformer_blocks.blockid.attn.to_v", "single_blocks.blockid.a_to_v"),
        ("single_transformer_blocks.blockid.norm.linear", "single_blocks.blockid.norm.linear"),
        ("single_transformer_blocks.blockid.proj_mlp", "single_blocks.blockid.proj_in_besides_attn"),
        ("single_transformer_blocks.blockid.proj_out", "single_blocks.blockid.proj_out"),
        ("transformer_blocks.blockid.attn.add_k_proj", "blocks.blockid.attn.b_to_k"),
        ("transformer_blocks.blockid.attn.add_q_proj", "blocks.blockid.attn.b_to_q"),
        ("transformer_blocks.blockid.attn.add_v_proj", "blocks.blockid.attn.b_to_v"),
        ("transformer_blocks.blockid.attn.to_add_out", "blocks.blockid.attn.b_to_out"),
        ("transformer_blocks.blockid.attn.to_k", "blocks.blockid.attn.a_to_k"),
        ("transformer_blocks.blockid.attn.to_out.0", "blocks.blockid.attn.a_to_out"),
        ("transformer_blocks.blockid.attn.to_q", "blocks.blockid.attn.a_to_q"),
        ("transformer_blocks.blockid.attn.to_v", "blocks.blockid.attn.a_to_v"),
        ("transformer_blocks.blockid.ff.net.0.proj", "blocks.blockid.ff_a.0"),
        ("transformer_blocks.blockid.ff.net.2", "blocks.blockid.ff_a.2"),
        ("transformer_blocks.blockid.ff_context.net.0.proj", "blocks.blockid.ff_b.0"),
        ("transformer_blocks.blockid.ff_context.net.2", "blocks.blockid.ff_b.2"),
        ("transformer_blocks.blockid.norm1.linear", "blocks.blockid.norm1_a.linear"),
        ("transformer_blocks.blockid.norm1_context.linear", "blocks.blockid.norm1_b.linear"),
    )

    # (source module name, target module path)
    _CIVITAI_MODULE_MAP = (
        ("lora_unet_double_blocks_blockid_img_mod_lin", "blocks.blockid.norm1_a.linear"),
        ("lora_unet_double_blocks_blockid_txt_mod_lin", "blocks.blockid.norm1_b.linear"),
        ("lora_unet_double_blocks_blockid_img_attn_qkv", "blocks.blockid.attn.a_to_qkv"),
        ("lora_unet_double_blocks_blockid_txt_attn_qkv", "blocks.blockid.attn.b_to_qkv"),
        ("lora_unet_double_blocks_blockid_img_attn_proj", "blocks.blockid.attn.a_to_out"),
        ("lora_unet_double_blocks_blockid_txt_attn_proj", "blocks.blockid.attn.b_to_out"),
        ("lora_unet_double_blocks_blockid_img_mlp_0", "blocks.blockid.ff_a.0"),
        ("lora_unet_double_blocks_blockid_img_mlp_2", "blocks.blockid.ff_a.2"),
        ("lora_unet_double_blocks_blockid_txt_mlp_0", "blocks.blockid.ff_b.0"),
        ("lora_unet_double_blocks_blockid_txt_mlp_2", "blocks.blockid.ff_b.2"),
        ("lora_unet_single_blocks_blockid_modulation_lin", "single_blocks.blockid.norm.linear"),
        ("lora_unet_single_blocks_blockid_linear1", "single_blocks.blockid.to_qkv_mlp"),
        ("lora_unet_single_blocks_blockid_linear2", "single_blocks.blockid.proj_out"),
    )

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)

        # "transformer.<src>.lora_X.weight" -> "<dst>.lora_X.default.weight"
        self.diffusers_rename_dict = {}
        for source, target in self._DIFFUSERS_MODULE_MAP:
            for tag in ("lora_A", "lora_B"):
                self.diffusers_rename_dict[f"transformer.{source}.{tag}.weight"] = f"{target}.{tag}.default.weight"

        # "<src>.lora_down.weight" -> "<dst>.lora_A.default.weight"
        # "<src>.lora_up.weight"   -> "<dst>.lora_B.default.weight"
        self.civitai_rename_dict = {}
        for source, target in self._CIVITAI_MODULE_MAP:
            self.civitai_rename_dict[f"{source}.lora_down.weight"] = f"{target}.lora_A.default.weight"
            self.civitai_rename_dict[f"{source}.lora_up.weight"] = f"{target}.lora_B.default.weight"

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Merge the LoRA into ``model``; delegates unchanged to the base class."""
        super().load(model, state_dict_lora, alpha)

    def convert_state_dict(self, state_dict):
        """Return ``state_dict`` with its keys renamed to this loader's layout.

        * If neither known layout is detected, the input dict is returned
          as-is (same object).
        * Otherwise a new dict is built. Every tensor is multiplied by the
          factor from ``_guess_alpha`` when that factor is not 1, and keys
          found in the rename table are renamed; other keys are kept.
        * For the diffusers layout, separate q/k/v (and, for single blocks,
          the MLP input projection) are fused into ``to_qkv_mlp`` /
          ``{a,b}_to_qkv`` tensors.

        The tensors of the input ``state_dict`` are never modified in place.

        Raises:
            KeyError: if a diffusers ``*_to_q`` entry has no matching
                ``*_to_k`` / ``*_to_v`` entry.
        """
        model_resource = self._guess_resource(state_dict)
        if model_resource is None:
            return state_dict

        if model_resource == "diffusers":
            rename_dict = self.diffusers_rename_dict
        else:
            rename_dict = self.civitai_rename_dict
        alpha = self._guess_alpha(state_dict)

        converted = {}
        for name, param in state_dict.items():
            block_id, source_name = self._guess_block_id(name, model_resource)
            if alpha != 1:
                # Out-of-place on purpose: an in-place ``*=`` would rescale the
                # caller's tensors and compound on repeated conversions.
                param = param * alpha
            if source_name in rename_dict:
                target_name = rename_dict[source_name].replace(".blockid.", f".{block_id}.")
                converted[target_name] = param
            else:
                converted[name] = param

        if model_resource == "diffusers":
            self._fuse_single_block_projections(converted)
            self._fuse_double_block_qkv(converted)
        return converted

    @staticmethod
    def _guess_resource(state_dict):
        """Return ``'civitai'``, ``'diffusers'`` or ``None`` from the first recognizable key."""
        for key in state_dict:
            if "lora_unet_" in key:
                return "civitai"
            if key.startswith("transformer."):
                return "diffusers"
        return None

    @staticmethod
    def _guess_block_id(name, model_resource):
        """Return ``(block id, name with the id replaced by 'blockid')`` or ``(None, None)``.

        The block id is the first all-digit token of ``name`` when split on
        ``"_"`` (civitai) or ``"."`` (diffusers).
        """
        if model_resource == "civitai":
            for token in name.split("_"):
                if token.isdigit():
                    return token, name.replace(f"_{token}_", "_blockid_")
        if model_resource == "diffusers":
            for token in name.split("."):
                if token.isdigit():
                    return token, name.replace(f"transformer_blocks.{token}.", "transformer_blocks.blockid.")
        return None, None

    @staticmethod
    def _guess_alpha(state_dict):
        """Return the per-tensor scale implied by the first usable ``.alpha`` entry, else 1.

        The scale is ``sqrt(alpha / n)`` where ``n`` is ``shape[0]`` of the
        matching ``.lora_down.weight`` / ``.lora_A.weight`` tensor (presumably
        the LoRA rank). It is later applied to every tensor, so an up/down
        product ends up scaled by ``alpha / n``.
        """
        for name, param in state_dict.items():
            if ".alpha" not in name:
                continue
            for suffix in (".lora_down.weight", ".lora_A.weight"):
                down_name = name.replace(".alpha", suffix)
                if down_name in state_dict:
                    return math.sqrt(param.item() / state_dict[down_name].shape[0])
        return 1

    @staticmethod
    def _block_diagonal(blocks):
        """Place 2-D ``blocks`` on the diagonal of a zero matrix (dtype/device of the first block)."""
        first = blocks[0]
        rows = sum(block.shape[0] for block in blocks)
        cols = sum(block.shape[1] for block in blocks)
        fused = torch.zeros((rows, cols), dtype=first.dtype, device=first.device)
        row = col = 0
        for block in blocks:
            next_row, next_col = row + block.shape[0], col + block.shape[1]
            fused[row:next_row, col:next_col] = block
            row, col = next_row, next_col
        return fused

    def _fuse_single_block_projections(self, tensors):
        """Fuse single-block q/k/v + MLP-in LoRA tensors into ``to_qkv_mlp`` (mutates ``tensors``)."""
        for name in list(tensors):
            if "single_blocks." not in name or ".a_to_q." not in name:
                continue
            is_down = "lora_A" in name
            is_up = not is_down and "lora_B" in name
            reference = tensors[name]

            mlp = tensors.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
            if mlp is None:
                # No MLP LoRA: use a zero placeholder so the fused tensor keeps
                # its layout. lora_A gets 1x the first dim, otherwise 4x.
                # Created on the same device as the other parts so the
                # concatenation below cannot mix devices.
                rows = (1 if is_down else 4) * reference.shape[0]
                mlp = torch.zeros(rows, *reference.shape[1:], dtype=reference.dtype, device=reference.device)

            parts = [tensors.pop(name.replace(".a_to_q.", f".a_to_{part}.")) for part in "qkv"]
            parts.append(mlp)
            if is_up:
                # Block-diagonal so each up matrix only sees its own down rows.
                fused = self._block_diagonal(parts)
            else:
                fused = torch.concat(parts, dim=0)
            tensors[name.replace(".a_to_q.", ".to_qkv_mlp.")] = fused

    def _fuse_double_block_qkv(self, tensors):
        """Fuse ``{a,b}_to_q/k/v`` LoRA tensors into ``{a,b}_to_qkv`` (mutates ``tensors``)."""
        for name in list(tensors):
            for component in ("a", "b"):
                q_tag = f".{component}_to_q."
                if q_tag not in name:
                    continue
                parts = [tensors.pop(name.replace(q_tag, f".{component}_to_{part}.")) for part in "qkv"]
                is_up = "lora_B" in name and "lora_A" not in name
                if is_up:
                    fused = self._block_diagonal(parts)
                else:
                    fused = torch.concat(parts, dim=0)
                tensors[name.replace(q_tag, f".{component}_to_qkv.")] = fused
|
|
|
|
class LoraMerger(torch.nn.Module):
    """Gated combination of a base output with LoRA outputs.

    A sigmoid gate is computed from the layer-normalized base and LoRA
    outputs. The gated, per-channel scaled LoRA outputs are summed over
    dimension 0 and added to the base output.
    """

    def __init__(self, dim):
        super().__init__()

        def random_vector():
            return torch.nn.Parameter(torch.randn((dim,)))

        # Creation order is kept fixed so random initialization and the
        # state-dict key order stay stable.
        self.weight_base = random_vector()
        self.weight_lora = random_vector()
        self.weight_cross = random_vector()
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = random_vector()
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def forward(self, base_output, lora_outputs):
        """Return ``base_output`` plus the gated LoRA contribution.

        ``lora_outputs`` is reduced with ``sum(dim=0)``; presumably it stacks
        the outputs of several LoRAs along dimension 0 (confirm with caller).
        """
        base_normed = self.norm_base(base_output)
        lora_normed = self.norm_lora(lora_outputs)
        gate_logits = (
            base_normed * self.weight_base
            + lora_normed * self.weight_lora
            + base_normed * lora_normed * self.weight_cross
            + self.bias
        )
        gate = self.activation(gate_logits)
        gated_lora = self.weight_out * gate * lora_outputs
        return base_output + gated_lora.sum(dim=0)
|
|
class FluxLoraPatcher(torch.nn.Module):
    """Collection of ``LoraMerger`` modules, one per named layer.

    ``torch.nn.ModuleDict`` keys cannot contain ``"."``, so layer names are
    stored with every ``"."`` replaced by ``"___"``.
    """

    def __init__(self, lora_patterns=None):
        super().__init__()
        patterns = self.default_lora_patterns() if lora_patterns is None else lora_patterns
        mergers = {
            pattern["name"].replace(".", "___"): LoraMerger(pattern["dim"])
            for pattern in patterns
        }
        self.model_dict = torch.nn.ModuleDict(mergers)

    def default_lora_patterns(self):
        """Return the built-in ``{"name", "dim"}`` entries.

        19 ``blocks.<i>`` groups of ten layers each are listed first, followed
        by 38 ``single_blocks.<i>`` groups of three layers each.
        """
        double_block_dims = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        single_block_dims = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}

        patterns = [
            {"name": f"blocks.{index}.{suffix}", "dim": dim}
            for index in range(19)
            for suffix, dim in double_block_dims.items()
        ]
        patterns.extend(
            {"name": f"single_blocks.{index}.{suffix}", "dim": dim}
            for index in range(38)
            for suffix, dim in single_block_dims.items()
        )
        return patterns

    def forward(self, base_output, lora_outputs, name):
        """Apply the merger registered for layer ``name`` to the two outputs."""
        merger = self.model_dict[name.replace(".", "___")]
        return merger(base_output, lora_outputs)
|
|