#!/usr/bin/env python3
"""
Convert LongCat-Image transformer weights from HuggingFace Diffusers format
to ComfyUI format.

Usage:
    python conversion.py input.safetensors output.safetensors

The input file is the Diffusers-format transformer, typically:
    meituan-longcat/LongCat-Image/transformer/diffusion_pytorch_model.safetensors

The output file will contain ComfyUI-format keys with fused QKV tensors,
ready for zero-copy loading via UNETLoader.
"""
import argparse
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Per-block rename tables.
#
# "exact" tables map the full per-block key remainder to its ComfyUI name.
# "prefix" tables map a remainder prefix to a ComfyUI prefix; only the last
# key component (e.g. "weight"/"bias") is appended to the ComfyUI prefix.
# "fused" tables map a remainder prefix to (fused output name, piece name);
# those pieces are concatenated along dim 0 once all are collected.
# None of the prefixes/exact names within one block kind overlap, so the
# lookup order between tables does not affect the result.
# ---------------------------------------------------------------------------

# Double-stream blocks: "transformer_blocks.N." -> "double_blocks.N."
_DOUBLE_EXACT_RENAMES = {
    "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
    "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
    "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
    "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
_DOUBLE_PREFIX_RENAMES = {
    "norm1.linear.": "img_mod.lin.",
    "norm1_context.linear.": "txt_mod.lin.",
    "attn.to_out.0.": "img_attn.proj.",
    "attn.to_add_out.": "txt_attn.proj.",
    "ff.net.0.proj.": "img_mlp.0.",
    "ff.net.2.": "img_mlp.2.",
    "ff_context.net.0.proj.": "txt_mlp.0.",
    "ff_context.net.2.": "txt_mlp.2.",
}
_DOUBLE_FUSED_PIECES = {
    "attn.to_q.": ("img_attn.qkv", "q"),
    "attn.to_k.": ("img_attn.qkv", "k"),
    "attn.to_v.": ("img_attn.qkv", "v"),
    "attn.add_q_proj.": ("txt_attn.qkv", "q"),
    "attn.add_k_proj.": ("txt_attn.qkv", "k"),
    "attn.add_v_proj.": ("txt_attn.qkv", "v"),
}

# Single-stream blocks: "single_transformer_blocks.N." -> "single_blocks.N."
_SINGLE_EXACT_RENAMES = {
    "attn.norm_q.weight": "norm.query_norm.weight",
    "attn.norm_k.weight": "norm.key_norm.weight",
}
_SINGLE_PREFIX_RENAMES = {
    "norm.linear.": "modulation.lin.",
    "proj_out.": "linear2.",
}
_SINGLE_FUSED_PIECES = {
    "attn.to_q.": ("linear1", "q"),
    "attn.to_k.": ("linear1", "k"),
    "attn.to_v.": ("linear1", "v"),
    "proj_mlp.": ("linear1", "mlp"),
}

# Concatenation order (along dim 0) of the pieces of each fused output.
_FUSED_ORDER = {
    "img_attn.qkv": ("q", "k", "v"),
    "txt_attn.qkv": ("q", "k", "v"),
    "linear1": ("q", "k", "v", "mlp"),
}
# Only these parameter kinds can be meaningfully concatenated.
_FUSABLE_SUFFIXES = ("weight", "bias")


def _swap_halves(tensor):
    """Return *tensor* with the two halves of its first dimension swapped.

    HF ``AdaLayerNormContinuous`` stores ``[scale | shift]`` while ComfyUI's
    ``LastLayer`` expects ``[shift | scale]``.

    Raises:
        ValueError: if the first dimension has odd length, so it cannot be
            split into two equal halves.
    """
    rows = tensor.shape[0]
    if rows % 2:
        raise ValueError(
            "cannot swap halves of a tensor with odd leading dimension {}".format(rows)
        )
    half = rows // 2
    return torch.cat([tensor[half:], tensor[:half]], dim=0)


def _convert_global_key(key, tensor):
    """Map a key outside the transformer blocks; return ``(new_key, tensor)``.

    Unrecognized keys are passed through unchanged.
    """
    last = key.split(".")[-1]
    if key in ("x_embedder.weight", "x_embedder.bias"):
        return "img_in." + last, tensor
    if key in ("context_embedder.weight", "context_embedder.bias"):
        return "txt_in." + last, tensor
    if key.startswith("time_embed.timestep_embedder.linear_1."):
        return "time_in.in_layer." + last, tensor
    if key.startswith("time_embed.timestep_embedder.linear_2."):
        return "time_in.out_layer." + last, tensor
    if key.startswith("norm_out.linear."):
        return "final_layer.adaLN_modulation.1." + last, _swap_halves(tensor)
    if key in ("proj_out.weight", "proj_out.bias"):
        return "final_layer.linear." + last, tensor
    return key, tensor


def _convert_block_key(key, tensor, comfy_root, exact, prefix_renames,
                       fused_pieces, out_sd, pending):
    """Route one per-block Diffusers key into *out_sd* or the *pending* fusions.

    Args:
        key: full Diffusers key, ``"<diffusers_root>.<idx>.<rest>"``.
        tensor: the tensor stored under *key*.
        comfy_root: ComfyUI block root (``"double_blocks"``/``"single_blocks"``).
        exact, prefix_renames, fused_pieces: the rename tables for this
            block kind (see the module-level tables).
        out_sd: output state dict, updated in place for direct renames.
        pending: ``(block_prefix, fused_name, suffix) -> {piece: tensor}``,
            updated in place for tensors that must be concatenated later.
    """
    parts = key.split(".")
    idx = parts[1]
    rest = ".".join(parts[2:])
    block_prefix = "{}.{}.".format(comfy_root, idx)

    if rest in exact:
        out_sd[block_prefix + exact[rest]] = tensor
        return

    # Like the original mapping, only the final key component is carried over.
    suffix = rest.rsplit(".", 1)[-1]
    for diffusers_prefix, comfy_prefix in prefix_renames.items():
        if rest.startswith(diffusers_prefix):
            out_sd[block_prefix + comfy_prefix + suffix] = tensor
            return
    for diffusers_prefix, (fused_name, piece) in fused_pieces.items():
        if rest.startswith(diffusers_prefix):
            pending.setdefault((block_prefix, fused_name, suffix), {})[piece] = tensor
            return

    # Unknown per-block tensor: keep its remainder under the ComfyUI block root.
    out_sd[block_prefix + rest] = tensor


def _fuse_pending(pending, out_sd):
    """Concatenate every collected fused group into *out_sd*.

    Raises:
        ValueError: if a group is missing one of its pieces, or holds a
            parameter kind other than weight/bias. Dropping such tensors
            silently would produce an incomplete checkpoint.
    """
    for (block_prefix, fused_name, suffix), pieces in pending.items():
        out_key = "{}{}.{}".format(block_prefix, fused_name, suffix)
        if suffix not in _FUSABLE_SUFFIXES:
            raise ValueError(
                "cannot fuse '{}' tensors into {}: only {} are supported".format(
                    suffix, out_key, "/".join(_FUSABLE_SUFFIXES)
                )
            )
        order = _FUSED_ORDER[fused_name]
        missing = [name for name in order if name not in pieces]
        if missing:
            raise ValueError(
                "incomplete projection group for {}: missing {} (have {})".format(
                    out_key, ", ".join(missing), ", ".join(sorted(pieces))
                )
            )
        out_sd[out_key] = torch.cat([pieces[name] for name in order], dim=0)


def convert_longcat_image(state_dict):
    """Convert a Diffusers LongCat-Image transformer state dict to ComfyUI keys.

    Separate Q/K/V projections are fused into ComfyUI's ``*.qkv`` tensors.
    Single-stream blocks also fold ``proj_mlp`` into ``linear1`` (Q, K, V,
    MLP order). The final ``norm_out`` modulation has its scale/shift halves
    swapped. Keys this function does not recognize are passed through
    (per-block ones re-rooted under ``double_blocks``/``single_blocks``).

    Args:
        state_dict: mapping of Diffusers key -> tensor.

    Returns:
        A new dict of ComfyUI key -> tensor.

    Raises:
        ValueError: if a Q/K/V(/MLP) group is incomplete, a fused piece is
            neither a weight nor a bias, or the ``norm_out`` modulation
            cannot be split in half.
    """
    out_sd = {}
    pending = {}

    for key, tensor in state_dict.items():
        if key.startswith("transformer_blocks."):
            _convert_block_key(
                key, tensor, "double_blocks",
                _DOUBLE_EXACT_RENAMES, _DOUBLE_PREFIX_RENAMES,
                _DOUBLE_FUSED_PIECES, out_sd, pending,
            )
        elif key.startswith("single_transformer_blocks."):
            _convert_block_key(
                key, tensor, "single_blocks",
                _SINGLE_EXACT_RENAMES, _SINGLE_PREFIX_RENAMES,
                _SINGLE_FUSED_PIECES, out_sd, pending,
            )
        else:
            new_key, new_tensor = _convert_global_key(key, tensor)
            out_sd[new_key] = new_tensor

    # Fused tensors are written last, as before, so they win any key clash.
    _fuse_pending(pending, out_sd)
    return out_sd


def main():
    """Command-line entry point: read, convert and write a safetensors file."""
    # Imported here so convert_longcat_image() itself only depends on torch.
    from safetensors.torch import load_file, save_file

    parser = argparse.ArgumentParser(
        description="Convert LongCat-Image weights from Diffusers to ComfyUI format"
    )
    parser.add_argument("input", help="Path to Diffusers-format safetensors file")
    parser.add_argument("output", help="Path to write ComfyUI-format safetensors file")
    args = parser.parse_args()

    logger.info(f"Loading {args.input}...")
    sd = load_file(args.input)

    logger.info(f"Converting {len(sd)} keys...")
    converted = convert_longcat_image(sd)

    logger.info(f"Saving {len(converted)} keys to {args.output}...")
    save_file(converted, args.output)
    logger.info("Done.")


if __name__ == "__main__":
    main()