# cond_gen / MuCodec /mp3_to_code.py
# Leon299's picture
# Add files using upload-large-folder tool
# 8337fa0 verified
import argparse
import os
from pathlib import Path
import tempfile
import traceback
import numpy as np
import torch
from tqdm import tqdm
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``ckpt``,
        ``layer_num``, ``device``, ``ext``, ``format``, ``overwrite`` and
        ``strict`` attributes.
    """
    # The default checkpoint is resolved relative to this script's own folder.
    script_dir = Path(__file__).resolve().parent

    # (flags, keyword options) pairs, registered in order so the generated
    # help text lists the arguments in this sequence.
    argument_table = [
        (("input_dir",), {"type": Path, "help": "Input folder (recursive scan)"}),
        (("output_dir",), {"type": Path, "help": "Output folder for saved codes"}),
        (
            ("--ckpt",),
            {
                "type": Path,
                "default": script_dir / "ckpt" / "mucodec.pt",
                "help": "Path to MuCodec checkpoint",
            },
        ),
        (
            ("--layer-num",),
            {
                "type": int,
                "default": 7,
                "help": "MuCodec layer num (default follows generate.py)",
            },
        ),
        (
            ("--device",),
            {"default": "cuda:0", "help": "Torch device, e.g. cuda:0"},
        ),
        (
            ("--ext",),
            {
                "nargs": "+",
                "default": [".mp3"],
                "help": "Audio extensions to include, e.g. .mp3 .wav .flac",
            },
        ),
        (
            ("--format",),
            {
                "choices": ["npz", "pt", "npy", "both", "all"],
                "default": "npz",
                "help": "Output format for code files",
            },
        ),
        (
            ("--overwrite",),
            {
                "action": "store_true",
                "help": "Recompute files even if output already exists (disable resume)",
            },
        ),
        (
            ("--strict",),
            {
                "action": "store_true",
                "help": "Stop immediately on first failed file",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description="Batch encode MP3 files to MuCodec codes (recursive)."
    )
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def list_audio_files(root: Path, exts):
    """Return the sorted list of files under ``root`` with a wanted extension.

    Args:
        root: Directory that is scanned recursively.
        exts: Iterable of extensions, with or without the leading dot.
            Matching is case-insensitive.

    Returns:
        A sorted list of ``Path`` objects for regular files whose suffix is
        one of the requested extensions.
    """
    # Normalise every requested extension to the lower-case ".ext" form.
    wanted = set()
    for ext in exts:
        lowered = ext.lower()
        wanted.add(lowered if ext.startswith(".") else "." + lowered)

    matches = []
    for candidate in root.rglob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() in wanted:
            matches.append(candidate)
    return sorted(matches)
# File extensions produced by each ``--format`` choice, in write order.
_FORMAT_EXTENSIONS = {
    "npz": (".npz",),
    "pt": (".pt",),
    "npy": (".npy",),
    "both": (".pt", ".npy"),
    "all": (".npz", ".pt", ".npy"),
}


def _extensions_for(fmt: str):
    """Return the tuple of extensions for ``fmt`` or raise ``ValueError``."""
    try:
        return _FORMAT_EXTENSIONS[fmt]
    except KeyError:
        raise ValueError(f"Unsupported format: {fmt}") from None


def _with_extension(output_stem: Path, extension: str) -> Path:
    """Append ``extension`` to the final component of ``output_stem``.

    The extension is appended instead of using ``Path.with_suffix`` because
    ``with_suffix`` replaces everything after the last dot of the name: a stem
    such as ``01. Intro`` would collapse to ``01.npz`` and collide with every
    other stem that starts with ``01.``.
    """
    return output_stem.with_name(output_stem.name + extension)


def _write_atomic(output_path: Path, write_to_file):
    """Write a file through a temporary sibling, then rename it into place.

    Args:
        output_path: Final destination; its parent directory is created.
        write_to_file: Callable receiving the open binary temporary file and
            writing the complete payload into it.

    The destination only ever appears fully written, so a later existence
    check cannot mistake an interrupted write for a finished file.
    NOTE: the final file keeps the permissions of the file created by
    ``tempfile.NamedTemporaryFile``.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb",
            # ".tmp" keeps a leftover temp file from looking like a real output.
            suffix=output_path.suffix + ".tmp",
            dir=output_path.parent,
            delete=False,
        ) as tmp_file:
            tmp_path = Path(tmp_file.name)
            write_to_file(tmp_file)
        # The temp file lives in the destination directory, so this is a
        # same-directory rename.
        os.replace(tmp_path, output_path)
    except BaseException:
        # BaseException so that Ctrl-C also removes the partial temp file.
        if tmp_path is not None and tmp_path.exists():
            tmp_path.unlink()
        raise


def expected_output_paths(output_stem: Path, fmt: str):
    """Return the output files that ``save_codes`` writes for ``fmt``.

    Args:
        output_stem: Output path without extension (may itself contain dots).
        fmt: One of ``npz``, ``pt``, ``npy``, ``both`` (pt + npy) or ``all``.

    Returns:
        List of ``Path`` objects, one per file of the format.

    Raises:
        ValueError: If ``fmt`` is not a supported format.
    """
    return [_with_extension(output_stem, ext) for ext in _extensions_for(fmt)]


def save_npz_atomic(codes_np: np.ndarray, output_path: Path):
    """Atomically save ``codes_np`` as a compressed ``.npz`` under key ``codes``.

    Args:
        codes_np: Array to store.
        output_path: Destination file; parent directories are created.
    """
    _write_atomic(
        output_path,
        lambda handle: np.savez_compressed(handle, codes=codes_np),
    )


def save_codes(codes: torch.Tensor, output_stem: Path, fmt: str):
    """Save ``codes`` next to ``output_stem`` in every file type of ``fmt``.

    Args:
        codes: Tensor of codes; it is detached and moved to CPU before saving.
        output_stem: Output path without extension (may itself contain dots).
        fmt: One of ``npz``, ``pt``, ``npy``, ``both`` (pt + npy) or ``all``.

    Raises:
        ValueError: If ``fmt`` is not a supported format.

    Every file is written atomically and lands exactly on the paths reported
    by ``expected_output_paths``.
    """
    # Validate the format before doing any tensor work.
    extensions = _extensions_for(fmt)

    codes_cpu = codes.detach().cpu()
    codes_np = codes_cpu.numpy()

    for extension in extensions:
        target = _with_extension(output_stem, extension)
        if extension == ".npz":
            save_npz_atomic(codes_np, target)
        elif extension == ".pt":
            _write_atomic(target, lambda handle: torch.save(codes_cpu, handle))
        else:  # ".npy"
            _write_atomic(target, lambda handle: np.save(handle, codes_np))
def main():
    """Encode every matching audio file under ``input_dir`` and save its codes.

    Files whose outputs all exist already are skipped unless ``--overwrite``
    is given. A failing file is recorded and reported in the final summary;
    with ``--strict`` the run stops at the first failure.

    Raises:
        ValueError: If ``input_dir`` is not an existing directory.
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
        SystemExit: With status 1, after the summary, when ``--strict``
            stopped the run because a file failed.
    """
    args = parse_args()

    # Imported after argument parsing, so ``--help`` and usage errors are
    # handled before the ``generate`` module is loaded.
    from generate import MuCodec

    # ``is_dir()`` is already False for a path that does not exist.
    if not args.input_dir.is_dir():
        raise ValueError(f"input_dir does not exist or is not a directory: {args.input_dir}")
    if not args.ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}")
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False")

    audio_files = list_audio_files(args.input_dir, args.ext)
    if not audio_files:
        print("No audio files found.")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)
    mucodec = MuCodec(
        model_path=str(args.ckpt),
        layer_num=args.layer_num,
        load_main_model=True,
        device=args.device,
    )

    resume_enabled = not args.overwrite
    ok = 0
    skipped = 0
    failed = []
    aborted = False

    for src in tqdm(audio_files, desc="Encoding", unit="file"):
        # Mirror the input tree below output_dir, dropping only the audio
        # extension from the file name.
        rel = src.relative_to(args.input_dir)
        output_stem = (args.output_dir / rel).with_suffix("")
        output_paths = expected_output_paths(output_stem, args.format)

        if resume_enabled and all(p.exists() for p in output_paths):
            skipped += 1
            continue

        output_stem.parent.mkdir(parents=True, exist_ok=True)
        try:
            codes = mucodec.file2code(str(src))
            save_codes(codes, output_stem, args.format)
            ok += 1
        except Exception as e:
            # Per-file boundary: record the failure and keep going, unless
            # the user asked for strict behaviour.
            failed.append((src, str(e)))
            print(f"[FAILED] {src}: {e}")
            if args.strict:
                print("--strict enabled, stopping on first failure.")
                traceback.print_exc()
                aborted = True
                break

    print(
        "Done. "
        f"success={ok}, skipped={skipped}, failed={len(failed)}, total={len(audio_files)}"
    )
    if failed:
        print("Failed files:")
        for path, err in failed:
            print(f"- {path}: {err}")

    if aborted:
        # A strict abort must be visible to the calling shell or pipeline;
        # non-strict runs keep their exit status.
        raise SystemExit(1)
if __name__ == "__main__":
    # Run the batch encoder only when executed as a script, not on import.
    main()