# cccode / merge_experts.py
# Uploaded by WayneW ("Upload folder using huggingface_hub"), commit 705a8fd.
import torch
import yaml
import argparse
from models import AVCDiT_models
def add_exact_keys(mapping, keys):
    """Register every name in *keys* as an identity entry (name -> name).

    *mapping* is updated in place and nothing is returned. Existing entries
    with the same name are overwritten.
    """
    mapping.update((name, name) for name in keys)
def add_mlp_block_keys(mapping, mlp_name, num_blocks):
    """Add identity entries for the fc1/fc2 weight and bias of one MLP per block.

    Keys have the form ``blocks.{i}.{mlp_name}.{fc1|fc2}.{weight|bias}`` for
    ``i`` in ``range(num_blocks)``. They are inserted in block order, then layer
    order, then weight-before-bias. *mapping* is modified in place.
    """
    names = (
        f"blocks.{block}.{mlp_name}.{layer}.{kind}"
        for block in range(num_blocks)
        for layer in ("fc1", "fc2")
        for kind in ("weight", "bias")
    )
    for name in names:
        mapping[name] = name
def _load_ema_state(path, device):
    """Return ``ckpt["ema"]`` from the checkpoint at *path* with keys de-prefixed.

    Every ``_orig_mod.`` substring is removed from the keys. That prefix is
    presumably left by a ``torch.compile`` wrapper; stripping it makes the keys
    line up with the plain model. Only the EMA dict is returned; the rest of the
    loaded checkpoint is not kept.

    Raises:
        KeyError: if the checkpoint has no ``"ema"`` entry.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only pass checkpoint files from a trusted source.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    return {k.replace('_orig_mod.', ''): v for k, v in ckpt["ema"].items()}


def _copy_mapped(key_map, src_state, model_state, new_state, source_info, label):
    """Copy mapped tensors from *src_state* into *new_state* (in place).

    A ``model_key -> ckpt_key`` pair from *key_map* is copied only when
    ``ckpt_key`` is in *src_state*, ``model_key`` is in *model_state*, and both
    tensors have the same shape. On a copy, ``source_info[model_key]`` is set
    to *label*.

    Returns:
        list of the model keys from *key_map* that were NOT copied.
    """
    skipped = []
    for model_key, ckpt_key in key_map.items():
        if (
            ckpt_key in src_state
            and model_key in model_state
            and src_state[ckpt_key].shape == model_state[model_key].shape
        ):
            new_state[model_key] = src_state[ckpt_key]
            source_info[model_key] = label
        else:
            skipped.append(model_key)
    return skipped


def load_from_two_checkpoints(model, ckpt1_path, ckpt2_path, map1=None, map2=None, device='cuda'):
    """Assemble *model*'s weights from the ``"ema"`` dicts of two checkpoints.

    Each entry of ``model.state_dict()`` is resolved in this order:
      1. ``map1`` (model key -> ckpt1 key) entries are copied from ckpt1;
      2. ``map2`` entries are copied from ckpt2 and override step 1 on overlap;
      3. any entry still unset is taken from ckpt1 under the same name.
    A source tensor is used only if its key exists and its shape equals the
    model entry's shape. Entries resolved by none of these steps keep the
    values *model* already holds and are left out of the returned dict. A
    warning is printed for them, and for mapped keys that could not be copied
    from their designated checkpoint.

    Args:
        model: module whose ``state_dict()`` defines the target names/shapes;
            the resolved tensors are loaded into it with ``strict=False``.
        ckpt1_path: path to the first checkpoint (``map1`` source + fallback).
        ckpt2_path: path to the second checkpoint (``map2`` source).
        map1: optional dict, model key -> key in ckpt1's ``"ema"`` dict.
        map2: optional dict, model key -> key in ckpt2's ``"ema"`` dict.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        dict mapping model state-dict names to the tensors that were loaded.
    """
    state1 = _load_ema_state(ckpt1_path, device)
    state2 = _load_ema_state(ckpt2_path, device)
    model_state = model.state_dict()
    new_state = {}
    source_info = {}  # model key -> "ckpt1" | "ckpt2" | "fallback_ckpt1"

    skipped1 = _copy_mapped(map1 or {}, state1, model_state, new_state, source_info, "ckpt1")
    skipped2 = _copy_mapped(map2 or {}, state2, model_state, new_state, source_info, "ckpt2")

    # Everything not explicitly mapped (e.g. weights shared by both experts)
    # is taken from ckpt1 by identical name.
    for k_model, tensor in model_state.items():
        if k_model not in new_state and k_model in state1 and state1[k_model].shape == tensor.shape:
            new_state[k_model] = state1[k_model]
            source_info[k_model] = "fallback_ckpt1"

    model.load_state_dict(new_state, strict=False)
    print(f"Loaded {len(new_state)} / {len(model_state)} parameters")
    for label in ("ckpt1", "ckpt2", "fallback_ckpt1"):
        count = sum(1 for src in source_info.values() if src == label)
        print(f"  from {label}: {count}")

    # A skipped mapped key may still have been filled from the other
    # checkpoint (e.g. a map2 key via the ckpt1 fallback); show where each
    # one finally came from so such substitutions are not silent.
    for label, skipped in (("ckpt1", skipped1), ("ckpt2", skipped2)):
        if skipped:
            details = {k: source_info.get(k, "not loaded") for k in skipped}
            print(
                f"WARNING: {len(skipped)} key(s) mapped to {label} were absent or "
                f"shape-mismatched there; final source per key: {details}"
            )

    unresolved = [k for k in model_state if k not in new_state]
    if unresolved:
        print(
            f"WARNING: {len(unresolved)} model entry(ies) were found in neither "
            f"checkpoint; they keep the model's current values and are omitted "
            f"from the returned state dict: {unresolved}"
        )
    return new_state
def main(args):
    """Build an AVCDiT model and save a checkpoint merged from two experts.

    The model is instantiated to supply target parameter names and shapes.
    Weights are then assembled by ``load_from_two_checkpoints`` and written
    to ``args.output`` as ``{"ema": state_dict}``.

    Args:
        args: parsed CLI namespace with ``config`` (YAML path), ``v_expert``
            and ``a_expert`` (checkpoint paths) and ``output`` (save path).

    Raises:
        ValueError: if the YAML config is non-empty but not a mapping.
    """
    with open(args.config, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no settings"
        # so the default model name below applies instead of crashing on .get().
        config = yaml.safe_load(f) or {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {args.config!r} must be a YAML mapping, got {type(config).__name__}"
        )
    model_name = config.get("model", "AVCDiT-B/2")
    print(f"Using model: {model_name}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): these constructor arguments are hard-coded rather than read
    # from the config; confirm they match how both experts were trained.
    model = AVCDiT_models[model_name](
        context_size=4,
        input_size=28,
        in_channels=4,
        mode="av",
    ).to(device)
    depth = len(model.blocks)

    # Entries copied from --v_expert (ckpt1): the *_v embeddings, the
    # final_layer head and every block's mlp_v. Other model entries not
    # covered by either map are also taken from ckpt1 by identical name.
    map1 = {}
    add_exact_keys(map1, [
        "pos_embed_v",
        "x_embedder_v.proj.weight",
        "x_embedder_v.proj.bias",
        "final_layer.linear.weight",
        "final_layer.linear.bias",
        "final_layer.adaLN_modulation.1.weight",
        "final_layer.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map1, "mlp_v", depth)

    # Entries copied from --a_expert (ckpt2): the *_a embeddings, the
    # final_layer_a head and every block's mlp_a.
    map2 = {}
    add_exact_keys(map2, [
        "pos_embed_a_cond",
        "pos_embed_a_pred",
        "x_embedder_a.weight",
        "x_embedder_a.bias",
        "final_layer_a.linear.weight",
        "final_layer_a.linear.bias",
        "final_layer_a.adaLN_modulation.1.weight",
        "final_layer_a.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map2, "mlp_a", depth)

    merged_state_dict = load_from_two_checkpoints(
        model,
        ckpt1_path=args.v_expert,
        ckpt2_path=args.a_expert,
        map1=map1,
        map2=map2,
        device=device,
    )
    torch.save({"ema": merged_state_dict}, args.output)
    print(f"Merged model saved to {args.output}")
if __name__ == "__main__":
    # Command-line entry point: parse the paths, then run main().
    parser = argparse.ArgumentParser(
        description="Merge the 'ema' weights of two AVCDiT expert checkpoints into one checkpoint."
    )
    parser.add_argument(
        "--config", type=str, required=True,
        help="YAML config; its optional 'model' key selects the AVCDiT variant "
             "(default AVCDiT-B/2).",
    )
    parser.add_argument(
        "--v_expert", type=str, required=True,
        help="Checkpoint supplying pos_embed_v, x_embedder_v, final_layer and mlp_v "
             "weights, plus every other shape-compatible entry not taken from --a_expert.",
    )
    parser.add_argument(
        "--a_expert", type=str, required=True,
        help="Checkpoint supplying pos_embed_a_cond/pred, x_embedder_a, final_layer_a "
             "and mlp_a weights.",
    )
    parser.add_argument(
        "--output", type=str, default="experts_merged.pth",
        help="Where to save the merged {'ema': state_dict} (default: %(default)s).",
    )
    args = parser.parse_args()
    main(args)