| import argparse |
| import os |
| from pathlib import Path |
| import tempfile |
| import traceback |
|
|
| import numpy as np |
| import torch |
| from tqdm import tqdm |
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the batch encoder.

    Two positional paths (``input_dir`` and ``output_dir``) are required;
    every other setting is optional and falls back to the default listed
    in the option table below.

    Returns:
        argparse.Namespace: the parsed arguments, read from ``sys.argv``.
    """
    # By default the checkpoint is looked up next to this script.
    default_ckpt = Path(__file__).resolve().parent / "ckpt" / "mucodec.pt"

    cli = argparse.ArgumentParser(
        description="Batch encode MP3 files to MuCodec codes (recursive)."
    )

    positionals = (
        ("input_dir", "Input folder (recursive scan)"),
        ("output_dir", "Output folder for saved codes"),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, type=Path, help=help_text)

    # Registered in this order so the generated --help listing is unchanged.
    options = (
        (
            "--ckpt",
            {
                "type": Path,
                "default": default_ckpt,
                "help": "Path to MuCodec checkpoint",
            },
        ),
        (
            "--layer-num",
            {
                "type": int,
                "default": 7,
                "help": "MuCodec layer num (default follows generate.py)",
            },
        ),
        (
            "--device",
            {
                "default": "cuda:0",
                "help": "Torch device, e.g. cuda:0",
            },
        ),
        (
            "--ext",
            {
                "nargs": "+",
                "default": [".mp3"],
                "help": "Audio extensions to include, e.g. .mp3 .wav .flac",
            },
        ),
        (
            "--format",
            {
                "choices": ["npz", "pt", "npy", "both", "all"],
                "default": "npz",
                "help": "Output format for code files",
            },
        ),
        (
            "--overwrite",
            {
                "action": "store_true",
                "help": "Recompute files even if output already exists (disable resume)",
            },
        ),
        (
            "--strict",
            {
                "action": "store_true",
                "help": "Stop immediately on first failed file",
            },
        ),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
|
|
|
|
def list_audio_files(root: Path, exts):
    """Return the audio files found anywhere below *root*, sorted.

    Args:
        root: Directory that is scanned recursively.
        exts: Extensions to accept, with or without a leading dot; matching
            is case-insensitive.

    Returns:
        A sorted list of paths to regular files whose suffix is one of
        *exts*.
    """
    # Normalise every extension to the lower-case ".ext" form.
    wanted = set()
    for ext in exts:
        lowered = ext.lower()
        wanted.add(lowered if ext.startswith(".") else f".{lowered}")

    return sorted(
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in wanted
    )
|
|
|
|
# File suffixes produced by each --format choice, in reporting order.
_FORMAT_SUFFIXES = {
    "npz": (".npz",),
    "pt": (".pt",),
    "npy": (".npy",),
    "both": (".pt", ".npy"),
    "all": (".npz", ".pt", ".npy"),
}


def _suffixes_for_format(fmt: str):
    """Return the suffixes written for *fmt*; raise ValueError if unknown."""
    try:
        return _FORMAT_SUFFIXES[fmt]
    except KeyError:
        raise ValueError(f"Unsupported format: {fmt}") from None


def _append_suffix(output_stem: Path, suffix: str) -> Path:
    """Return *output_stem* with *suffix* appended to its final component.

    ``Path.with_suffix`` is deliberately not used: it replaces everything
    after the last dot of the name, so a stem such as ``01. Track`` or
    ``song.v2`` would be truncated and distinct inputs would collide on the
    same output file.
    """
    return output_stem.with_name(output_stem.name + suffix)


def _write_atomic(output_path: Path, write_fn):
    """Write *output_path* via a temp file in the same directory.

    *write_fn* receives the open binary temp file and must write the full
    payload to it. The temp file is moved over *output_path* with
    ``os.replace`` only after it has been written and closed, so an
    interrupted run never leaves a partial file under the final name.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb",
            suffix=output_path.suffix,
            dir=output_path.parent,
            delete=False,
        ) as tmp_file:
            tmp_path = Path(tmp_file.name)
            write_fn(tmp_file)
        os.replace(tmp_path, output_path)
    except BaseException:
        # BaseException so that Ctrl-C also removes the temp file.
        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except FileNotFoundError:
                pass
        raise


def expected_output_paths(output_stem: Path, fmt: str):
    """Return the files that ``save_codes`` writes for *output_stem*.

    Args:
        output_stem: Output path without the format extension.
        fmt: One of ``"npz"``, ``"pt"``, ``"npy"``, ``"both"`` (pt + npy)
            or ``"all"`` (npz + pt + npy).

    Returns:
        A list of paths, each being *output_stem* with one extension
        appended (dots already present in the stem are preserved).

    Raises:
        ValueError: If *fmt* is not a supported format.
    """
    return [
        _append_suffix(output_stem, suffix)
        for suffix in _suffixes_for_format(fmt)
    ]


def save_npz_atomic(codes_np: np.ndarray, output_path: Path):
    """Atomically save *codes_np* as a compressed ``.npz`` at *output_path*.

    The array is stored under the key ``"codes"``. Missing parent
    directories are created.
    """
    _write_atomic(
        output_path,
        lambda handle: np.savez_compressed(handle, codes=codes_np),
    )


def save_codes(codes: torch.Tensor, output_stem: Path, fmt: str):
    """Save *codes* next to *output_stem* in every file type implied by *fmt*.

    The tensor is detached and moved to the CPU first. Each file is written
    atomically, so a file that exists under its final name is complete.

    Args:
        codes: Tensor of codes to store.
        output_stem: Output path without the format extension.
        fmt: Same choices as for ``expected_output_paths``.

    Raises:
        ValueError: If *fmt* is not a supported format.
    """
    codes_cpu = codes.detach().cpu()
    codes_np = codes_cpu.numpy()

    # One writer per suffix; each gets the open temp file to write into.
    writers = {
        ".npz": lambda handle: np.savez_compressed(handle, codes=codes_np),
        ".pt": lambda handle: torch.save(codes_cpu, handle),
        ".npy": lambda handle: np.save(handle, codes_np),
    }
    for suffix in _suffixes_for_format(fmt):
        _write_atomic(_append_suffix(output_stem, suffix), writers[suffix])
|
|
|
|
def main():
    """Encode every matching audio file under ``input_dir`` to MuCodec codes.

    Steps: parse the CLI, validate the input directory, checkpoint and
    device, collect the audio files, build the MuCodec model, then encode
    each file and save its codes under ``output_dir`` with the same relative
    path. Files whose outputs all exist are skipped unless ``--overwrite``
    is given. A summary, plus the list of failed files, is printed last.

    Raises:
        ValueError: If ``input_dir`` is missing or not a directory.
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    args = parse_args()

    # Imported here, after argument parsing, so that parsing the CLI
    # (e.g. --help) does not depend on the generate module.
    from generate import MuCodec

    if not (args.input_dir.exists() and args.input_dir.is_dir()):
        raise ValueError(f"input_dir does not exist or is not a directory: {args.input_dir}")

    if not args.ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}")

    wants_cuda = args.device.startswith("cuda")
    if wants_cuda and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False")

    sources = list_audio_files(args.input_dir, args.ext)
    if not sources:
        print("No audio files found.")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)

    codec = MuCodec(
        model_path=str(args.ckpt),
        layer_num=args.layer_num,
        load_main_model=True,
        device=args.device,
    )

    skip_existing = not args.overwrite
    encoded_count = 0
    skipped_count = 0
    failures = []

    for source_path in tqdm(sources, desc="Encoding", unit="file"):
        # Mirror the input tree below output_dir, dropping the audio suffix.
        relative_path = source_path.relative_to(args.input_dir)
        stem = (args.output_dir / relative_path).with_suffix("")
        targets = expected_output_paths(stem, args.format)

        if skip_existing and all(target.exists() for target in targets):
            skipped_count += 1
            continue

        stem.parent.mkdir(parents=True, exist_ok=True)

        try:
            file_codes = codec.file2code(str(source_path))
            save_codes(file_codes, stem, args.format)
        except Exception as exc:
            # Record the failure and keep going unless --strict was given.
            failures.append((source_path, str(exc)))
            print(f"[FAILED] {source_path}: {exc}")
            if args.strict:
                print("--strict enabled, stopping on first failure.")
                traceback.print_exc()
                break
        else:
            encoded_count += 1

    summary = (
        f"success={encoded_count}, skipped={skipped_count}, "
        f"failed={len(failures)}, total={len(sources)}"
    )
    print("Done. " + summary)
    if failures:
        print("Failed files:")
        for failed_path, message in failures:
            print(f"- {failed_path}: {message}")
|
|
|
|
# Run the batch encoder only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|