"""Batch-encode audio files under a directory tree into MuCodec code files.

Every matching file below ``input_dir`` is encoded with ``MuCodec.file2code``
and the resulting codes are written under ``output_dir`` with the same
relative layout. Each output file is written to a temporary file and then
renamed into place, so an interrupted run never leaves a truncated file that a
later resumed run would mistake for a finished one.
"""

import argparse
import os
from pathlib import Path
import tempfile
import traceback
from typing import BinaryIO, Callable

import numpy as np
import torch

try:
    from tqdm import tqdm
except ImportError:  # the progress bar is cosmetic; degrade to a plain loop

    def tqdm(iterable, **_kwargs):
        """Stand-in for ``tqdm.tqdm`` that returns *iterable* unchanged."""
        return iterable


# Output file extensions written for each ``--format`` choice, in write order.
_FORMAT_SUFFIXES = {
    "npz": (".npz",),
    "pt": (".pt",),
    "npy": (".npy",),
    "both": (".pt", ".npy"),
    "all": (".npz", ".pt", ".npy"),
}


def parse_args():
    """Parse and return the command-line arguments for the batch encoder."""
    parser = argparse.ArgumentParser(
        description="Batch encode MP3 files to MuCodec codes (recursive)."
    )
    parser.add_argument("input_dir", type=Path, help="Input folder (recursive scan)")
    parser.add_argument("output_dir", type=Path, help="Output folder for saved codes")
    parser.add_argument(
        "--ckpt",
        type=Path,
        default=Path(__file__).resolve().parent / "ckpt" / "mucodec.pt",
        help="Path to MuCodec checkpoint",
    )
    parser.add_argument(
        "--layer-num",
        type=int,
        default=7,
        help="MuCodec layer num (default follows generate.py)",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Torch device, e.g. cuda:0",
    )
    parser.add_argument(
        "--ext",
        nargs="+",
        default=[".mp3"],
        help="Audio extensions to include, e.g. .mp3 .wav .flac",
    )
    parser.add_argument(
        "--format",
        choices=["npz", "pt", "npy", "both", "all"],
        default="npz",
        help="Output format for code files",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Recompute files even if output already exists (disable resume)",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Stop immediately on first failed file",
    )
    return parser.parse_args()


def list_audio_files(root: Path, exts):
    """Return all files under *root* (recursive) whose extension is in *exts*.

    Extensions are matched case-insensitively and may be given with or
    without the leading dot (``"mp3"`` and ``".MP3"`` are equivalent).
    The result is sorted so runs are reproducible.
    """
    ext_set = {e.lower() if e.startswith(".") else f".{e.lower()}" for e in exts}
    files = [
        p for p in root.rglob("*") if p.is_file() and p.suffix.lower() in ext_set
    ]
    files.sort()
    return files


def _suffixes_for(fmt: str):
    """Return the output extensions for *fmt*; raise ValueError if unknown."""
    try:
        return _FORMAT_SUFFIXES[fmt]
    except KeyError:
        raise ValueError(f"Unsupported format: {fmt}") from None


def _with_extension(output_stem: Path, suffix: str) -> Path:
    """Return *output_stem* with *suffix* appended (not substituted).

    ``Path.with_suffix`` would replace any dot-part still present in the
    stem, so ``song.v2`` would become ``song.npz`` and collide with the
    output for ``song.v3``.
    """
    return output_stem.with_name(output_stem.name + suffix)


def expected_output_paths(output_stem: Path, fmt: str):
    """Return the list of files that ``save_codes`` writes for *fmt*.

    Raises:
        ValueError: if *fmt* is not one of the supported format choices.
    """
    return [_with_extension(output_stem, s) for s in _suffixes_for(fmt)]


def _save_atomic(output_path: Path, write: Callable[[BinaryIO], None]) -> None:
    """Write *output_path* via *write* into a temp file, then rename it into place.

    The final path only ever holds a complete file. A crash or Ctrl-C part
    way through removes the temp file and leaves any previous output
    untouched.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        prefix=f".{output_path.name}.", suffix=".tmp", dir=output_path.parent
    )
    tmp_path = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            write(tmp_file)
        # mkstemp creates files as 0600; give the result the permissions a
        # plain open() would have produced (0666 minus the process umask).
        umask = os.umask(0)
        os.umask(umask)
        os.chmod(tmp_path, 0o666 & ~umask)
        os.replace(tmp_path, output_path)
    except BaseException:
        # BaseException so that KeyboardInterrupt also cleans up the temp file.
        tmp_path.unlink(missing_ok=True)
        raise


def save_npz_atomic(codes_np: np.ndarray, output_path: Path):
    """Atomically save *codes_np* to *output_path* as a compressed ``.npz`` (key ``codes``)."""
    _save_atomic(output_path, lambda fh: np.savez_compressed(fh, codes=codes_np))


def save_codes(codes: torch.Tensor, output_stem: Path, fmt: str):
    """Save *codes* next to *output_stem* in every file format implied by *fmt*.

    Each file is written atomically, so the resume check in ``main``
    (existence of every expected path) never accepts a partial write.

    Raises:
        ValueError: if *fmt* is not one of the supported format choices.
    """
    codes_cpu = codes.detach().cpu()
    codes_np = codes_cpu.numpy()
    writers = {
        ".npz": lambda fh: np.savez_compressed(fh, codes=codes_np),
        ".pt": lambda fh: torch.save(codes_cpu, fh),
        ".npy": lambda fh: np.save(fh, codes_np),
    }
    for suffix in _suffixes_for(fmt):
        _save_atomic(_with_extension(output_stem, suffix), writers[suffix])


def _warn_on_output_collisions(jobs):
    """Print a warning for every source whose output stem is already claimed.

    Sources that differ only by extension (``a.mp3`` / ``a.wav``) share an
    output stem, so one would overwrite, or be resume-skipped because of,
    the other.
    """
    owner = {}
    for src, stem in jobs:
        first = owner.setdefault(stem, src)
        if first != src:
            print(
                f"[WARNING] {src} and {first} share output stem {stem}; "
                "their outputs will collide."
            )


def main():
    """CLI entry point: validate inputs, load MuCodec once, encode every file."""
    args = parse_args()
    # Imported after parse_args() so --help does not import the model code.
    from generate import MuCodec

    if not args.input_dir.exists() or not args.input_dir.is_dir():
        raise ValueError(f"input_dir does not exist or is not a directory: {args.input_dir}")
    if not args.ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}")
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False")

    audio_files = list_audio_files(args.input_dir, args.ext)
    if not audio_files:
        print("No audio files found.")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Output stem = same relative path under output_dir with the audio
    # extension removed; format extensions are appended to it later.
    jobs = [
        (src, (args.output_dir / src.relative_to(args.input_dir)).with_suffix(""))
        for src in audio_files
    ]
    _warn_on_output_collisions(jobs)

    mucodec = MuCodec(
        model_path=str(args.ckpt),
        layer_num=args.layer_num,
        load_main_model=True,
        device=args.device,
    )

    resume_enabled = not args.overwrite
    ok = 0
    skipped = 0
    failed = []

    for src, output_stem in tqdm(jobs, desc="Encoding", unit="file"):
        output_paths = expected_output_paths(output_stem, args.format)

        # Resume: skip only when every expected output is present. Outputs
        # are written atomically, so an existing file is always complete.
        if resume_enabled and all(p.exists() for p in output_paths):
            skipped += 1
            continue

        output_stem.parent.mkdir(parents=True, exist_ok=True)
        try:
            codes = mucodec.file2code(str(src))
            save_codes(codes, output_stem, args.format)
            ok += 1
        except Exception as e:
            failed.append((src, str(e)))
            print(f"[FAILED] {src}: {e}")
            if args.strict:
                print("--strict enabled, stopping on first failure.")
                traceback.print_exc()
                break

    print(
        "Done. "
        f"success={ok}, skipped={skipped}, failed={len(failed)}, total={len(audio_files)}"
    )
    if failed:
        print("Failed files:")
        for path, err in failed:
            print(f"- {path}: {err}")


if __name__ == "__main__":
    main()