AuralSAM2 / avs.code /v1s.code /tools /remap_aural_ckpt_keys.py
yyliu01's picture
Upload folder using huggingface_hub
c6dfc69 verified
raw
history blame
3.44 kB
#!/usr/bin/env python3
"""
Remap legacy checkpoint keys: rename audio_prompter.* to the current AuralFuser layout (aural_fuser.*),
and drop duplicate weights under training_layers / finetuning_layers.
Usage:
    python tools/remap_aural_ckpt_keys.py /path/to/model.pth [-o OUTPUT] [--in-place] [--no-backup]

By default writes <stem>_remapped.pth next to the input, or to OUTPUT when -o/--output is given;
--in-place overwrites the input instead (after a <name>.pth.bak backup unless --no-backup is passed).
"""
from __future__ import annotations
import argparse
import shutil
from pathlib import Path
import torch
# Legacy attribute names paired with the AuralFuser ModuleList paths that replace them
# (old train_* suffixes start at 1; the new ModuleList indices are 0-based).
# Each old name is substituted wherever it occurs inside an audio_prompter.* key.
_REPLACEMENTS: list[tuple[str, str]] = [
    ("train_f_patch_embed1", "patch_embeds.0"),
    ("train_f_patch_embed2", "patch_embeds.1"),
    ("train_f_patch_embed3", "patch_embeds.2"),
    ("train_f_a_block1", "fusion_modules.0"),
    ("train_f_a_block2", "fusion_modules.1"),
    ("train_f_a_block3", "fusion_modules.2"),
    ("train_f_block1", "f_blocks.0"),
    ("train_f_block2", "f_blocks.1"),
    ("train_f_block3", "f_blocks.2"),
    ("train_a_block1", "a_blocks.0"),
    ("train_a_block2", "a_blocks.1"),
    ("train_a_block3", "a_blocks.2"),
    ("train_smooth1", "smooth_convs.0"),
    ("train_smooth2", "smooth_convs.1"),
]


def remap_state_dict(sd: dict) -> dict:
    """Return a new dict with legacy ``audio_prompter.*`` keys renamed to ``aural_fuser.*``.

    Keys that start with ``audio_prompter.`` are handled as follows:

    * keys containing ``.training_layers.`` or ``.finetuning_layers.`` are dropped
      (and counted; the count is printed when non-zero);
    * all other keys get the leading ``audio_prompter.`` replaced by ``aural_fuser.``
      and every old name listed in ``_REPLACEMENTS`` substituted by its new path.

    Every other key is passed through unchanged. Values are not copied; the result
    references the same objects as *sd*, in the same iteration order.

    Args:
        sd: Mapping of string keys to checkpoint values.

    Returns:
        A new ``dict`` holding the remapped entries.

    Raises:
        ValueError: If two input keys end up with the same output key (for example a
            checkpoint that contains both ``audio_prompter.x`` and ``aural_fuser.x``).
            Raising avoids silently overwriting one set of weights with the other.
    """
    out: dict = {}
    dropped = 0
    for k, v in sd.items():
        if k.startswith("audio_prompter."):
            if ".training_layers." in k or ".finetuning_layers." in k:
                dropped += 1
                continue
            # Only the leading prefix is swapped (count=1); module names are rewritten below.
            nk = k.replace("audio_prompter.", "aural_fuser.", 1)
            for old, new in _REPLACEMENTS:
                nk = nk.replace(old, new)
        else:
            nk = k
        if nk in out:
            raise ValueError(
                f"Key collision while remapping: {k!r} -> {nk!r} "
                "would overwrite an entry that is already present"
            )
        out[nk] = v
    if dropped:
        print(f"Dropped duplicate keys: {dropped} (training_layers / finetuning_layers)")
    return out
def main() -> None:
    """Parse the command line, remap one checkpoint's keys, and write the result.

    The checkpoint is loaded onto the CPU, passed through ``remap_state_dict``, and
    saved either to a new file (``-o/--output``, or ``<stem>_remapped.pth`` beside the
    input by default) or over the input itself (``--in-place``, preceded by a
    ``<name><suffix>.bak`` copy unless ``--no-backup`` is given).

    Raises:
        SystemExit: If the input is not an existing file, if ``-o`` resolves to the
            input file itself, or if the loaded object is not a ``dict``. argparse
            also exits when ``-o`` and ``--in-place`` are combined.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("ckpt", type=Path, help="Input .pth (full-model state_dict)")
    # -o and --in-place name two different destinations; let argparse reject the
    # combination instead of silently ignoring -o.
    destination = ap.add_mutually_exclusive_group()
    destination.add_argument(
        "-o", "--output", type=Path, default=None,
        help="Output path; default <stem>_remapped.pth",
    )
    destination.add_argument("--in-place", action="store_true", help="Overwrite input file")
    ap.add_argument("--no-backup", action="store_true", help="Skip .bak when using --in-place")
    args = ap.parse_args()

    ckpt_path: Path = args.ckpt.resolve()
    if not ckpt_path.is_file():
        raise SystemExit(f"File not found: {ckpt_path}")

    # Decide the destination before the (potentially slow) load so bad arguments fail fast.
    if args.in_place:
        out = ckpt_path
    elif args.output is not None:
        out = args.output
        # Writing onto the input through -o would bypass the --in-place backup step.
        if out.resolve() == ckpt_path:
            raise SystemExit(
                "Output path is the input checkpoint; use --in-place to overwrite it"
            )
    else:
        out = ckpt_path.with_name(ckpt_path.stem + "_remapped.pth")

    print(f"Loading: {ckpt_path}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only run this tool on checkpoints from a trusted source (or consider
    # weights_only=True where the installed torch version supports it).
    sd = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(sd, dict):
        raise SystemExit("Expected top-level checkpoint to be a state_dict dict")

    n_old_ap = sum(1 for k in sd if k.startswith("audio_prompter."))
    if n_old_ap == 0:
        print("Warning: no audio_prompter.* keys found; checkpoint may already be remapped.")

    new_sd = remap_state_dict(sd)
    n_af = sum(1 for k in new_sd if k.startswith("aural_fuser."))
    print(f"aural_fuser key count: {n_af}")

    if args.in_place and not args.no_backup:
        # Copy the untouched original aside before it is overwritten below.
        bak = ckpt_path.with_suffix(ckpt_path.suffix + ".bak")
        print(f"Backup -> {bak}")
        shutil.copy2(ckpt_path, bak)

    torch.save(new_sd, out)
    print(f"Saved: {out} ({len(new_sd)} tensor keys)")
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()