Spaces:
Running on Zero
Running on Zero
| from .general import GeneralLoRALoader | |
| import torch, math | |
class FluxLoRALoader(GeneralLoRALoader):
    """LoRA loader for FLUX transformer checkpoints.

    Recognises LoRA state dicts in two external layouts and renames them to
    the ``blocks.N.*`` / ``single_blocks.N.*`` names targeted by the rename
    tables below:

    * diffusers-style keys, e.g. ``transformer.transformer_blocks.N....lora_A.weight``;
    * civitai-style keys, e.g. ``lora_unet_double_blocks_N_....lora_down.weight``.

    Fusing into a model is delegated to ``GeneralLoRALoader``.
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Diffusers key -> target key. "blockid" is a placeholder for the
        # numeric block index; convert_state_dict substitutes the real index.
        # q/k/v (and the single-block MLP input) map to separate targets here
        # and are merged into fused entries afterwards.
        self.diffusers_rename_dict = {
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight": "single_blocks.blockid.a_to_k.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight": "single_blocks.blockid.a_to_k.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight": "single_blocks.blockid.a_to_q.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight": "single_blocks.blockid.a_to_q.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight": "single_blocks.blockid.a_to_v.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight": "single_blocks.blockid.a_to_v.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight": "single_blocks.blockid.norm.linear.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight": "single_blocks.blockid.norm.linear.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight": "single_blocks.blockid.proj_in_besides_attn.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight": "single_blocks.blockid.proj_in_besides_attn.lora_B.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight": "single_blocks.blockid.proj_out.lora_A.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight": "single_blocks.blockid.proj_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight": "blocks.blockid.attn.b_to_k.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight": "blocks.blockid.attn.b_to_k.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight": "blocks.blockid.attn.b_to_q.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight": "blocks.blockid.attn.b_to_q.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight": "blocks.blockid.attn.b_to_v.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight": "blocks.blockid.attn.b_to_v.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight": "blocks.blockid.attn.b_to_out.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight": "blocks.blockid.attn.b_to_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight": "blocks.blockid.attn.a_to_k.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight": "blocks.blockid.attn.a_to_k.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight": "blocks.blockid.attn.a_to_out.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight": "blocks.blockid.attn.a_to_out.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight": "blocks.blockid.attn.a_to_q.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight": "blocks.blockid.attn.a_to_q.lora_B.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight": "blocks.blockid.attn.a_to_v.lora_A.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight": "blocks.blockid.attn.a_to_v.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight": "blocks.blockid.ff_a.0.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight": "blocks.blockid.ff_a.0.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight": "blocks.blockid.ff_a.2.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight": "blocks.blockid.ff_a.2.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight": "blocks.blockid.ff_b.0.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight": "blocks.blockid.ff_b.0.lora_B.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight": "blocks.blockid.ff_b.2.lora_A.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight": "blocks.blockid.ff_b.2.lora_B.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight": "blocks.blockid.norm1_a.linear.lora_A.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight": "blocks.blockid.norm1_a.linear.lora_B.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight": "blocks.blockid.norm1_b.linear.lora_A.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight": "blocks.blockid.norm1_b.linear.lora_B.weight",
        }
        # Civitai ("lora_unet_*") key -> target key, same "blockid" convention.
        # These keys map directly onto the fused a_to_qkv / b_to_qkv /
        # to_qkv_mlp targets, so no post-merge step is needed for them.
        self.civitai_rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight",
        }

    def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Fuse ``state_dict_lora`` into ``model``.

        Delegates unchanged to ``GeneralLoRALoader.fuse_lora_to_base_model``,
        forwarding ``alpha`` as given.
        """
        super().fuse_lora_to_base_model(model, state_dict_lora, alpha)

    def convert_state_dict(self, state_dict):
        """Rename a diffusers- or civitai-style FLUX LoRA to the target layout.

        Args:
            state_dict: Mapping from parameter name to tensor. Neither the
                mapping nor its tensors are modified in place. The returned
                dict may share tensor objects with it.

        Returns:
            A new dict with renamed keys. For diffusers input, the separate
            q/k/v LoRA factors (plus the MLP input for single blocks) are also
            merged into fused ``a_to_qkv`` / ``b_to_qkv`` / ``to_qkv_mlp``
            entries. Keys that match no rename rule keep their original name.
            If neither layout is recognised, ``state_dict`` itself is returned
            unchanged.

        Note:
            If an ``.alpha`` entry is present, every tensor is multiplied by
            ``sqrt(alpha / rank)``, including the ``.alpha`` entries
            themselves. Here ``rank`` is ``shape[0]`` of the matching
            down-projection, so each down/up pair carries ``alpha / rank``.
            Only the first ``.alpha`` entry with a matching down-projection is
            used for the whole dict; per-layer alphas are not honoured.
        """

        def guess_block_id(name, model_resource):
            # Return (block_index, name with that index replaced by "blockid"),
            # or (None, None) when no purely numeric component is found.
            if model_resource == "civitai":
                for part in name.split("_"):
                    if part.isdigit():
                        return part, name.replace(f"_{part}_", "_blockid_")
            if model_resource == "diffusers":
                for part in name.split("."):
                    if part.isdigit():
                        return part, name.replace(
                            f"transformer_blocks.{part}.", "transformer_blocks.blockid."
                        )
            return None, None

        def guess_resource(sd):
            # The first key that identifies a layout decides it. Keys that
            # identify neither layout are skipped.
            for key in sd:
                if "lora_unet_" in key:
                    return "civitai"
                if key.startswith("transformer."):
                    return "diffusers"
            return None

        def guess_alpha(sd):
            # sqrt(alpha / rank) from the first ".alpha" entry that has a
            # matching down-projection; 1 when there is none.
            for key, value in sd.items():
                if ".alpha" not in key:
                    continue
                for suffix in (".lora_down.weight", ".lora_A.weight"):
                    down_key = key.replace(".alpha", suffix)
                    if down_key in sd:
                        return math.sqrt(value.item() / sd[down_key].shape[0])
            return 1

        def merge_factors(name, parts):
            # Merge the LoRA factors of several projections into one fused
            # projection. Up-projections (lora_B) are placed block-diagonally,
            # so part i's up-projection only sees part i's slice of the fused
            # rank and writes only part i's slice of the fused output.
            # Down-projections (and anything else) are stacked along dim 0.
            if "lora_A" not in name and "lora_B" in name:
                ref = parts[0]
                return torch.block_diag(
                    *(p.to(device=ref.device, dtype=ref.dtype) for p in parts)
                )
            return torch.concat(parts, dim=0)

        def fuse_single_block_qkv_mlp(sd):
            # single_blocks.N.a_to_{q,k,v} + proj_in_besides_attn -> to_qkv_mlp.
            for name in list(sd):
                if "single_blocks." not in name or ".a_to_q." not in name:
                    continue
                q = sd[name]
                mlp = sd.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
                if mlp is None:
                    # No MLP LoRA: pad with zeros so the fused shape still
                    # matches. The up-projection gets 4x the rows of q's.
                    rows = q.shape[0] if "lora_A" in name else 4 * q.shape[0]
                    mlp = torch.zeros(rows, *q.shape[1:], dtype=q.dtype, device=q.device)
                else:
                    mlp = mlp.to(device=q.device)
                parts = [sd.pop(name.replace(".a_to_q.", f".a_to_{p}.")) for p in "qkv"]
                parts.append(mlp)
                sd[name.replace(".a_to_q.", ".to_qkv_mlp.")] = merge_factors(name, parts)

        def fuse_attention_qkv(sd):
            # {a,b}_to_{q,k,v} -> {a,b}_to_qkv for the remaining (double) blocks.
            for name in list(sd):
                for component in ("a", "b"):
                    q_tag = f".{component}_to_q."
                    if q_tag not in name:
                        continue
                    parts = [
                        sd.pop(name.replace(q_tag, f".{component}_to_{p}."))
                        for p in "qkv"
                    ]
                    sd[name.replace(q_tag, f".{component}_to_qkv.")] = merge_factors(name, parts)

        model_resource = guess_resource(state_dict)
        if model_resource is None:
            return state_dict
        rename_dict = (
            self.diffusers_rename_dict
            if model_resource == "diffusers"
            else self.civitai_rename_dict
        )
        alpha = guess_alpha(state_dict)

        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name, model_resource)
            if alpha != 1:
                # Out-of-place on purpose. `param *= alpha` would rescale the
                # caller's tensors and compound the scale if the same dict is
                # converted again. It also fails on integer-typed tensors.
                param = param * alpha
            target_name = name
            if source_name in rename_dict:
                target_name = rename_dict[source_name].replace(".blockid.", f".{block_id}.")
            state_dict_[target_name] = param

        if model_resource == "diffusers":
            fuse_single_block_qkv_mlp(state_dict_)
            fuse_attention_qkv(state_dict_)
        return state_dict_
class FluxLoRAConverter:
    """Translate FLUX LoRA state dicts between two naming schemes.

    The two schemes are the ``blocks.*`` / ``single_blocks.*`` naming and the
    civitai-style ``lora_unet_*`` naming. Both converters are static methods,
    so they can be called on the class or on an instance.
    """

    def __init__(self):
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, alpha=None):
        """Convert ``blocks.*`` / ``single_blocks.*`` LoRA keys to ``lora_unet_*`` names.

        Args:
            state_dict: Mapping from key to tensor. Keys look like
                ``<prefix>.<block_id>.<module>.lora_A[.<adapter>].weight``.
                An extra component between ``lora_A``/``lora_B`` and
                ``weight`` (e.g. ``default``) is dropped.
            alpha: Value stored in each generated ``.alpha`` entry. Defaults
                to ``param.shape[-1]`` of the up-projection, which is the rank
                for a standard ``(out, rank)`` weight, so ``alpha / rank == 1``.

        Returns:
            A new dict with converted keys. Every ``lora_up.weight`` gets a
            matching scalar ``.alpha`` tensor. Keys whose prefix, module or
            suffix is not recognised are dropped instead of raising.
        """
        prefix_rename_dict = {
            "single_blocks": "lora_unet_single_blocks",
            "blocks": "lora_unet_double_blocks",
        }
        middle_rename_dict = {
            "norm.linear": "modulation_lin",
            "to_qkv_mlp": "linear1",
            "proj_out": "linear2",
            "norm1_a.linear": "img_mod_lin",
            "norm1_b.linear": "txt_mod_lin",
            "attn.a_to_qkv": "img_attn_qkv",
            "attn.b_to_qkv": "txt_attn_qkv",
            "attn.a_to_out": "img_attn_proj",
            "attn.b_to_out": "txt_attn_proj",
            "ff_a.0": "img_mlp_0",
            "ff_a.2": "img_mlp_2",
            "ff_b.0": "txt_mlp_0",
            "ff_b.2": "txt_mlp_2",
        }
        suffix_rename_dict = {
            "lora_B.weight": "lora_up.weight",
            "lora_A.weight": "lora_down.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            names = name.split(".")
            # Drop the adapter-name component, e.g. "default" in
            # "...lora_A.default.weight", so the suffix becomes "lora_X.weight".
            if len(names) >= 2 and names[-2] not in ("lora_A", "lora_B"):
                names.pop(-2)
            # A convertible key needs a prefix, a block id, at least one
            # module component and a two-part suffix.
            if len(names) < 5:
                continue
            prefix, block_id = names[0], names[1]
            middle = ".".join(names[2:-2])
            suffix = ".".join(names[-2:])
            if (
                prefix not in prefix_rename_dict
                or middle not in middle_rename_dict
                or suffix not in suffix_rename_dict
            ):
                continue
            rename = (
                f"{prefix_rename_dict[prefix]}_{block_id}_"
                f"{middle_rename_dict[middle]}.{suffix_rename_dict[suffix]}"
            )
            state_dict_[rename] = param
            if rename.endswith("lora_up.weight"):
                lora_alpha = alpha if alpha is not None else param.shape[-1]
                state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
        return state_dict_

    @staticmethod
    def align_to_diffsynth_format(state_dict):
        """Convert civitai-style ``lora_unet_*`` FLUX LoRA keys to ``blocks.*`` names.

        Targets include a ``default`` adapter component, e.g.
        ``blocks.N.attn.a_to_qkv.lora_A.default.weight``. Keys that are not
        recognised, including ``.alpha`` entries, are copied through
        unchanged. Tensors are passed through as the same objects.
        """
        # "blockid" is a placeholder for the numeric block index.
        rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
        }

        def guess_block_id(name):
            # The first purely numeric "_"-separated component is the block index.
            for part in name.split("_"):
                if part.isdigit():
                    return part, name.replace(f"_{part}_", "_blockid_")
            return None, None

        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name)
            if source_name in rename_dict:
                target_name = rename_dict[source_name].replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                state_dict_[name] = param
        return state_dict_