| import torch |
| import yaml |
| import argparse |
| from models import AVCDiT_models |
|
|
|
|
def add_exact_keys(mapping, keys):
    """Register each name in *keys* in *mapping* as an identity entry (name -> name).

    Args:
        mapping: Mutable mapping that is updated in place.
        keys: Iterable of names to register; visited once, in order.
    """
    mapping.update((name, name) for name in keys)
|
|
|
|
def add_mlp_block_keys(mapping, mlp_name, num_blocks):
    """Register identity entries for the fc1/fc2 weight and bias of every block's MLP.

    For each block index in ``range(num_blocks)`` the four names
    ``blocks.<i>.<mlp_name>.{fc1,fc2}.{weight,bias}`` are added to *mapping*,
    each mapped to itself.

    Args:
        mapping: Mutable mapping that is updated in place.
        mlp_name: Attribute name of the MLP inside each block.
        num_blocks: Number of blocks to enumerate.
    """
    names = (
        f"blocks.{block_idx}.{mlp_name}.{layer}.{suffix}"
        for block_idx in range(num_blocks)
        for layer in ("fc1", "fc2")
        for suffix in ("weight", "bias")
    )
    mapping.update((name, name) for name in names)
|
|
|
|
def _strip_orig_mod(state):
    """Return a copy of *state* with every ``_orig_mod.`` substring removed from its keys."""
    return {key.replace('_orig_mod.', ''): value for key, value in state.items()}


def _copy_mapped(mapping, source_state, model_state, new_state, source_info, tag):
    """Copy mapped tensors that exist and match in shape; return the model keys that were skipped."""
    skipped = []
    for k_model, k_ckpt in (mapping or {}).items():
        if (
            k_ckpt in source_state
            and k_model in model_state
            and source_state[k_ckpt].shape == model_state[k_model].shape
        ):
            new_state[k_model] = source_state[k_ckpt]
            source_info[k_model] = tag
        else:
            skipped.append(k_model)
    return skipped


def load_from_two_checkpoints(model, ckpt1_path, ckpt2_path, map1=None, map2=None, device='cuda'):
    """Build a merged state dict from two checkpoints and load it into *model*.

    Selection order for every model key:

    1. entries of *map1* are taken from checkpoint 1;
    2. entries of *map2* are taken from checkpoint 2 (overriding step 1 when
       both maps name the same model key);
    3. any model key still unset is taken from checkpoint 1 under the same
       name.

    A tensor is only taken when its shape equals the shape of the model's
    entry. Keys that cannot be filled keep the model's current values,
    because the state dict is loaded with ``strict=False``.

    Args:
        model: Module whose ``state_dict()`` defines the target keys and shapes.
        ckpt1_path: Path of the first checkpoint; must contain an ``"ema"`` dict.
        ckpt2_path: Path of the second checkpoint; must contain an ``"ema"`` dict.
        map1: Optional ``{model_key: ckpt1_key}`` mapping.
        map2: Optional ``{model_key: ckpt2_key}`` mapping.
        device: ``map_location`` used when reading both checkpoints.

    Returns:
        dict: The merged ``{model_key: tensor}`` state that was loaded.

    Raises:
        KeyError: If either checkpoint has no ``"ema"`` entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints from a trusted source.
    ckpt1 = torch.load(ckpt1_path, map_location=device, weights_only=False)
    ckpt2 = torch.load(ckpt2_path, map_location=device, weights_only=False)

    # Keys may carry an ``_orig_mod.`` marker (presumably left by torch.compile
    # wrappers - confirm against the training code); drop it so names line up
    # with the plain model.
    state1 = _strip_orig_mod(ckpt1["ema"])
    state2 = _strip_orig_mod(ckpt2["ema"])

    model_state = model.state_dict()

    new_state = {}
    source_info = {}

    skipped = []
    skipped += _copy_mapped(map1, state1, model_state, new_state, source_info, "ckpt1")
    skipped += _copy_mapped(map2, state2, model_state, new_state, source_info, "ckpt2")

    # Everything not explicitly mapped falls back to checkpoint 1 by name.
    for k_model, tensor in model_state.items():
        if k_model in new_state:
            continue
        if k_model in state1 and state1[k_model].shape == tensor.shape:
            new_state[k_model] = state1[k_model]
            source_info[k_model] = "fallback_ckpt1"

    model.load_state_dict(new_state, strict=False)
    print(f"Loaded {len(new_state)} / {len(model_state)} parameters")

    # Report where the tensors came from so a bad merge is visible.
    counts = {}
    for origin in source_info.values():
        counts[origin] = counts.get(origin, 0) + 1
    for origin in sorted(counts):
        print(f"  {origin}: {counts[origin]}")

    if skipped:
        print(
            f"Warning: {len(skipped)} mapped keys were missing or had a "
            f"mismatched shape: {skipped}"
        )

    missing = [k_model for k_model in model_state if k_model not in new_state]
    if missing:
        print(
            f"Warning: {len(missing)} model keys were not loaded from any "
            f"checkpoint: {missing}"
        )

    return new_state
|
|
|
|
def main(args):
    """Merge two expert checkpoints into one AVCDiT checkpoint and save it.

    Reads the model name from the YAML config (default ``"AVCDiT-B/2"``),
    builds that model in ``"av"`` mode, fills it from ``args.v_expert`` and
    ``args.a_expert`` via ``load_from_two_checkpoints``, and writes the merged
    state to ``args.output`` under the ``"ema"`` key.

    Args:
        args: Namespace with ``config``, ``v_expert``, ``a_expert`` and
            ``output`` path attributes.

    Raises:
        KeyError: If the configured model name is not in ``AVCDiT_models``.
    """
    with open(args.config, "r", encoding="utf-8") as f:
        # An empty YAML file parses to None; fall back to an empty config so
        # the default model name below still applies.
        config = yaml.safe_load(f) or {}

    model_name = config.get("model", "AVCDiT-B/2")
    print(f"Using model: {model_name}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    try:
        build_model = AVCDiT_models[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r} requested by config {args.config!r}"
        ) from None

    model = build_model(
        context_size=4,
        input_size=28,
        in_channels=4,
        mode="av"
    ).to(device)

    # The number of transformer blocks decides how many per-block MLP keys
    # each mapping needs.
    depth = len(model.blocks)

    # Keys taken from the checkpoint passed as --v_expert
    # (presumably the video expert - confirm against the model definition).
    map1 = {}
    add_exact_keys(map1, [
        "pos_embed_v",
        "x_embedder_v.proj.weight",
        "x_embedder_v.proj.bias",
        "final_layer.linear.weight",
        "final_layer.linear.bias",
        "final_layer.adaLN_modulation.1.weight",
        "final_layer.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map1, "mlp_v", depth)

    # Keys taken from the checkpoint passed as --a_expert
    # (presumably the audio expert - confirm against the model definition).
    map2 = {}
    add_exact_keys(map2, [
        "pos_embed_a_cond",
        "pos_embed_a_pred",
        "x_embedder_a.weight",
        "x_embedder_a.bias",
        "final_layer_a.linear.weight",
        "final_layer_a.linear.bias",
        "final_layer_a.adaLN_modulation.1.weight",
        "final_layer_a.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map2, "mlp_a", depth)

    merged_state_dict = load_from_two_checkpoints(
        model,
        ckpt1_path=args.v_expert,
        ckpt2_path=args.a_expert,
        map1=map1,
        map2=map2,
        device=device
    )

    # Stored under "ema", the same key the loader reads from its inputs.
    torch.save({"ema": merged_state_dict}, args.output)
    print(f"Merged model saved to {args.output}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: three required input paths plus an optional
    # output path, all handed to main() as one namespace.
    parser = argparse.ArgumentParser()
    for required_flag in ("--config", "--v_expert", "--a_expert"):
        parser.add_argument(required_flag, type=str, required=True)
    parser.add_argument("--output", type=str, default="experts_merged.pth")
    args = parser.parse_args()

    main(args)
|
|