| import comfy_extras.nodes_model_merging |
|
|
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one ratio per top-level UNet block.

    The actual merging is inherited from ``ModelMergeBlocks`` (not shown
    here); this class only declares which ratio inputs the node exposes.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs.

        Returns:
            dict: ``{"required": {...}}`` whose entries are, in order,
            ``model1`` and ``model2`` (type ``MODEL``) followed by one FLOAT
            ratio input (default 1.0, range 0.0-1.0, step 0.01) per key:
            ``time_embed.``, ``label_emb.``, ``input_blocks.0``-``8``,
            ``middle_block.0``-``2``, ``output_blocks.0``-``8`` and ``out.``.
            Presumably the base class matches model weight names against these
            keys as prefixes -- confirm in ``ModelMergeBlocks``.
        """
        # One shared spec object is reused for every ratio input.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # Insertion order here determines the order of the node's widgets.
        prefixes = ["time_embed.", "label_emb."]
        prefixes += ["input_blocks.{}".format(n) for n in range(9)]
        prefixes += ["middle_block.{}".format(n) for n in range(3)]
        prefixes += ["output_blocks.{}".format(n) for n in range(9)]
        prefixes.append("out.")

        required = {"model1": ("MODEL",), "model2": ("MODEL",)}
        required.update((prefix, ratio) for prefix in prefixes)
        return {"required": required}
|
|
|
|
class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with per-sub-module and per-transformer-block ratios.

    Compared with a plain per-block merge, each UNet block is split into its
    ``.0.`` sub-module and, where present, its ``.1.`` sub-module. Every
    ``transformer_blocks.N.`` inside a ``.1.`` sub-module gets its own ratio.
    The merging itself is inherited from ``ModelMergeBlocks`` (not shown here).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs.

        Returns:
            dict: ``{"required": ..., "optional": ...}``. ``required`` holds
            ``model1`` / ``model2`` (type ``MODEL``) followed by one FLOAT ratio
            (default 1.0, range 0.0-1.0, step 0.01) per weight-name prefix.
            ``optional`` holds ratios for the upsamplers of output blocks 2
            and 5. Presumably the base class matches weight names against these
            prefixes (longest match) -- confirm in ``ModelMergeBlocks``.
        """
        arg_dict = {"model1": ("MODEL",),
                    "model2": ("MODEL",)}

        # One shared spec object is reused for every ratio input.
        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["time_embed."] = argument
        arg_dict["label_emb."] = argument

        # input block index -> number of transformer_blocks in its ".1." sub-module
        transformers = {4: 2, 5: 2, 7: 10, 8: 10}
        for i in range(9):
            arg_dict["input_blocks.{}.0.".format(i)] = argument
            if i in transformers:
                arg_dict["input_blocks.{}.1.".format(i)] = argument
                for j in range(transformers[i]):
                    arg_dict["input_blocks.{}.1.transformer_blocks.{}.".format(i, j)] = argument

        for i in range(3):
            arg_dict["middle_block.{}.".format(i)] = argument
            # Only middle_block.1 gets per-transformer-block ratios (10 of them).
            if i == 1:
                for j in range(10):
                    arg_dict["middle_block.{}.transformer_blocks.{}.".format(i, j)] = argument

        # Output blocks are looked up in reverse: output block i uses the
        # depth entry for t = 8 - i.
        transformers = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for i in range(9):
            arg_dict["output_blocks.{}.0.".format(i)] = argument
            t = 8 - i
            if t in transformers:
                arg_dict["output_blocks.{}.1.".format(i)] = argument
                for j in range(transformers[t]):
                    arg_dict["output_blocks.{}.1.transformer_blocks.{}.".format(i, j)] = argument

        arg_dict["out."] = argument

        # In the reference SDXL UNet, output blocks 2 and 5 end with an
        # upsampling layer at sub-index 2 (weights "output_blocks.2.2.*" and
        # "output_blocks.5.2.*"). Neither matches any ".0." / ".1." prefix
        # registered above, so they previously had no ratio of their own.
        # They are declared optional so existing prompts/workflows that lack
        # them still validate and keep their widget positions.
        # NOTE(review): verify against checkpoint keys if the layout differs.
        optional = {"output_blocks.{}.2.".format(i): argument for i in (2, 5)}

        return {"required": arg_dict, "optional": optional}
|
|
class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with a ratio per transformer sub-layer.

    This is like a per-transformer-block merge, but each
    ``transformer_blocks.N`` is further split into the sub-layer prefixes
    listed in ``transformers_args`` (norms, attention projections,
    feed-forward). The merging itself is inherited from ``ModelMergeBlocks``
    (not shown here).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's inputs.

        Returns:
            dict: ``{"required": ..., "optional": ...}``. ``required`` holds
            ``model1`` / ``model2`` (type ``MODEL``) followed by one FLOAT ratio
            (default 1.0, range 0.0-1.0, step 0.01) per weight-name prefix.
            ``optional`` holds ratios for the upsamplers of output blocks 2
            and 5. Presumably the base class matches weight names against these
            prefixes (longest match) -- confirm in ``ModelMergeBlocks``.
        """
        arg_dict = {"model1": ("MODEL",),
                    "model2": ("MODEL",)}

        # One shared spec object is reused for every ratio input.
        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["time_embed."] = argument
        arg_dict["label_emb."] = argument

        # input block index -> number of transformer_blocks in its ".1." sub-module
        transformers = {4: 2, 5: 2, 7: 10, 8: 10}
        # Sub-layer name prefixes registered inside every transformer block.
        transformers_args = ["norm1", "attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out", "ff.net", "norm2", "attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out", "norm3"]

        for i in range(9):
            arg_dict["input_blocks.{}.0.".format(i)] = argument
            if i in transformers:
                arg_dict["input_blocks.{}.1.".format(i)] = argument
                for j in range(transformers[i]):
                    for x in transformers_args:
                        arg_dict["input_blocks.{}.1.transformer_blocks.{}.{}".format(i, j, x)] = argument

        for i in range(3):
            arg_dict["middle_block.{}.".format(i)] = argument
            # Only middle_block.1 gets per-transformer-block ratios (10 of them).
            if i == 1:
                for j in range(10):
                    for x in transformers_args:
                        arg_dict["middle_block.{}.transformer_blocks.{}.{}".format(i, j, x)] = argument

        # Output blocks are looked up in reverse: output block i uses the
        # depth entry for t = 8 - i.
        transformers = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for i in range(9):
            arg_dict["output_blocks.{}.0.".format(i)] = argument
            t = 8 - i
            if t in transformers:
                arg_dict["output_blocks.{}.1.".format(i)] = argument
                for j in range(transformers[t]):
                    for x in transformers_args:
                        arg_dict["output_blocks.{}.1.transformer_blocks.{}.{}".format(i, j, x)] = argument

        arg_dict["out."] = argument

        # In the reference SDXL UNet, output blocks 2 and 5 end with an
        # upsampling layer at sub-index 2 (weights "output_blocks.2.2.*" and
        # "output_blocks.5.2.*"). Neither matches any ".0." / ".1." prefix
        # registered above, so they previously had no ratio of their own.
        # They are declared optional so existing prompts/workflows that lack
        # them still validate and keep their widget positions.
        # NOTE(review): verify against checkpoint keys if the layout differs.
        optional = {"output_blocks.{}.2.".format(i): argument for i in (2, 5)}

        return {"required": arg_dict, "optional": optional}
|
|
# Node type name -> implementing class; presumably read by the host
# application's node loader. Keys are spelled out explicitly so that renaming
# a class never silently changes a registered node name.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSDXL=ModelMergeSDXL,
    ModelMergeSDXLTransformers=ModelMergeSDXLTransformers,
    ModelMergeSDXLDetailedTransformers=ModelMergeSDXLDetailedTransformers,
)
|
|