from __future__ import annotations import argparse import sys from pathlib import Path import torch from tqdm.auto import tqdm REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from bandtok import BandTokTokenizer # noqa: E402 from bandtok.audio_utils import save_audio # noqa: E402 from bandtok.config import DEFAULT_TOKENIZER_WEIGHTS_NAME # noqa: E402 DEFAULT_EXTENSIONS = (".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac") def parse_extensions(value: str) -> tuple[str, ...]: extensions = [] for item in value.split(","): item = item.strip().lower() if not item: continue extensions.append(item if item.startswith(".") else f".{item}") if not extensions: raise ValueError("At least one audio extension is required") return tuple(extensions) def find_audio(input_path: Path, extensions: tuple[str, ...]) -> list[Path]: if input_path.is_file(): return [input_path] if not input_path.is_dir(): raise FileNotFoundError(f"Input path not found: {input_path}") return sorted(path for path in input_path.rglob("*") if path.is_file() and path.suffix.lower() in extensions) def resolve_output_path(audio_path: Path, input_root: Path, output_root: Path, single_file: bool) -> Path: if single_file and output_root.suffix: return output_root try: rel = audio_path.relative_to(input_root) except ValueError: rel = Path(audio_path.name) return output_root / rel.with_suffix(".wav") def resolve_tokens_path(audio_path: Path, input_root: Path, tokens_root: Path, single_file: bool) -> Path: if single_file and tokens_root.suffix: return tokens_root try: rel = audio_path.relative_to(input_root) except ValueError: rel = Path(audio_path.name) return tokens_root / rel.with_suffix(".pt") def main() -> None: parser = argparse.ArgumentParser(description="Run BandTok tokenizer reconstruction inference.") parser.add_argument( "--repo_id", default=str(REPO_ROOT), help="Hugging Face repo id or local repo directory containing config.yaml and bandtok.safetensors.", ) parser.add_argument("--input", required=True, help="Input audio file or directory.") parser.add_argument("--output", default="tokenizer_reconstructions", help="Output wav file or directory.") parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.") parser.add_argument("--weights-name", default=DEFAULT_TOKENIZER_WEIGHTS_NAME, help="Tokenizer safetensors filename.") parser.add_argument("--save-tokens", default=None, help="Optional token .pt file or directory.") parser.add_argument( "--extensions", default=",".join(DEFAULT_EXTENSIONS), help="Comma-separated extensions to scan when --input is a directory.", ) args = parser.parse_args() repo_id = str(Path(args.repo_id).expanduser().resolve()) if Path(args.repo_id).expanduser().exists() else args.repo_id input_path = Path(args.input).expanduser().resolve() output_root = Path(args.output).expanduser() tokens_root = Path(args.save_tokens).expanduser() if args.save_tokens else None extensions = parse_extensions(args.extensions) audio_paths = find_audio(input_path, extensions) if not audio_paths: raise FileNotFoundError(f"No audio files found under: {input_path}") print(f"Loading BandTok tokenizer from: {repo_id}") print(f"Tokenizer weights: {args.weights_name}") print(f"Device: {args.device}") tokenizer = BandTokTokenizer.from_pretrained(repo_id, device=args.device, weights_name=args.weights_name) single_file = input_path.is_file() input_root = input_path.parent if single_file else input_path for audio_path in tqdm(audio_paths, desc="Reconstructing"): out_path = resolve_output_path(audio_path, input_root, output_root, single_file) tokens_path = resolve_tokens_path(audio_path, input_root, tokens_root, single_file) if tokens_root else None tokens = tokenizer.encode(str(audio_path)) if tokens_path is not None: tokens_path.parent.mkdir(parents=True, exist_ok=True) torch.save(tokens, tokens_path) audio = tokenizer.decode(tokens) save_audio(audio, out_path, tokenizer.sample_rate) print(f"Saved {out_path} tokens_shape={tuple(tokens.shape)}") if __name__ == "__main__": main()