| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from tqdm.auto import tqdm |
|
|
|
|
# Repository root: the parent of the directory that contains this script.
# It is pushed to the front of sys.path (once) before the ``bandtok`` imports
# that follow, so the script can be run straight from a checkout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not sys.path.count(str(REPO_ROOT)):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from bandtok import BandTokTokenizer |
| from bandtok.audio_utils import save_audio |
| from bandtok.config import DEFAULT_TOKENIZER_WEIGHTS_NAME |
|
|
|
|
# Lower-case file suffixes (leading dot included) scanned by default when the
# input is a directory; also used to build the default of ``--extensions``.
DEFAULT_EXTENSIONS = (
    ".wav",
    ".flac",
    ".mp3",
    ".ogg",
    ".m4a",
    ".aac",
)
|
|
|
|
def parse_extensions(value: str) -> tuple[str, ...]:
    """Turn a comma-separated extension list into normalized suffixes.

    Each entry is stripped of surrounding whitespace, lower-cased and given a
    leading dot when it lacks one. Empty entries are dropped.

    Args:
        value: Comma-separated extensions, e.g. ``"wav, .FLAC"``.

    Returns:
        The normalized extensions, in the order they were given.

    Raises:
        ValueError: If no non-empty extension remains after cleaning.
    """
    cleaned = (piece.strip().lower() for piece in value.split(","))
    normalized = tuple(
        ext if ext.startswith(".") else f".{ext}" for ext in cleaned if ext
    )
    if normalized:
        return normalized
    raise ValueError("At least one audio extension is required")
|
|
|
|
def find_audio(input_path: Path, extensions: tuple[str, ...]) -> list[Path]:
    """Collect the audio file(s) designated by *input_path*.

    Args:
        input_path: A single file, or a directory to scan recursively.
        extensions: Lower-case suffixes (with leading dot) accepted when
            scanning a directory.

    Returns:
        ``[input_path]`` when it is a file; otherwise the sorted list of
        regular files below the directory whose lower-cased suffix is in
        *extensions* (possibly empty).

    Raises:
        FileNotFoundError: If *input_path* is neither a file nor a directory.
    """
    if input_path.is_file():
        # An explicitly named file is returned without an extension check.
        return [input_path]
    if not input_path.is_dir():
        raise FileNotFoundError(f"Input path not found: {input_path}")

    matches = []
    for candidate in input_path.rglob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() in extensions:
            matches.append(candidate)
    matches.sort()
    return matches
|
|
|
|
def _resolve_mirrored_path(
    audio_path: Path,
    input_root: Path,
    target_root: Path,
    single_file: bool,
    suffix: str,
) -> Path:
    """Map *audio_path* to its destination under *target_root*.

    Shared implementation for the output and token path resolvers, which
    differ only in the suffix they apply.
    """
    # In single-file mode a target that carries a suffix is taken to be the
    # destination file itself rather than a directory.
    if single_file and target_root.suffix:
        return target_root
    try:
        rel = audio_path.relative_to(input_root)
    except ValueError:
        # audio_path is not located under input_root: keep only its file name.
        rel = Path(audio_path.name)
    return target_root / rel.with_suffix(suffix)


def resolve_output_path(audio_path: Path, input_root: Path, output_root: Path, single_file: bool) -> Path:
    """Return where the reconstructed audio for *audio_path* is written.

    Args:
        audio_path: Source audio file.
        input_root: Directory that relative output paths are computed against.
        output_root: Output directory, or the output file itself when
            *single_file* is true and it has a suffix.
        single_file: Whether the run processes exactly one input file.

    Returns:
        *output_root* unchanged in the single-file/explicit-file case,
        otherwise *output_root* joined with the path of *audio_path* relative
        to *input_root* (or just its name when it is not under *input_root*),
        with the suffix replaced by ``.wav``.
    """
    return _resolve_mirrored_path(audio_path, input_root, output_root, single_file, ".wav")


def resolve_tokens_path(audio_path: Path, input_root: Path, tokens_root: Path, single_file: bool) -> Path:
    """Return where the tokens for *audio_path* are written.

    Same rules as :func:`resolve_output_path`, with *tokens_root* as the
    target and ``.pt`` as the suffix.
    """
    return _resolve_mirrored_path(audio_path, input_root, tokens_root, single_file, ".pt")
|
|
|
|
def main() -> None:
    """Run BandTok encode/decode reconstruction from the command line.

    Parses the CLI arguments, collects the input audio file(s), loads the
    tokenizer once, then for every input: encodes it to tokens, optionally
    saves the tokens with ``torch.save``, decodes the tokens back to audio and
    writes the result with ``save_audio``.

    Raises:
        FileNotFoundError: If ``--input`` is neither a file nor a directory,
            or if no audio file is found under it.
        ValueError: If ``--extensions`` contains no usable extension.
    """
    parser = argparse.ArgumentParser(description="Run BandTok tokenizer reconstruction inference.")
    parser.add_argument(
        # ``--repo-id`` is accepted as an alias so the flag matches the
        # hyphenated spelling of the other options; dest stays ``repo_id``.
        "--repo_id",
        "--repo-id",
        dest="repo_id",
        default=str(REPO_ROOT),
        help="Hugging Face repo id or local repo directory containing config.yaml and bandtok.safetensors.",
    )
    parser.add_argument("--input", required=True, help="Input audio file or directory.")
    parser.add_argument("--output", default="tokenizer_reconstructions", help="Output wav file or directory.")
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    parser.add_argument("--weights-name", default=DEFAULT_TOKENIZER_WEIGHTS_NAME, help="Tokenizer safetensors filename.")
    parser.add_argument("--save-tokens", default=None, help="Optional token .pt file or directory.")
    parser.add_argument(
        "--extensions",
        default=",".join(DEFAULT_EXTENSIONS),
        help="Comma-separated extensions to scan when --input is a directory.",
    )
    args = parser.parse_args()

    # A value that exists on disk is treated as a local repo directory and
    # made absolute; anything else is passed through untouched as a repo id.
    repo_candidate = Path(args.repo_id).expanduser()
    repo_id = str(repo_candidate.resolve()) if repo_candidate.exists() else args.repo_id
    input_path = Path(args.input).expanduser().resolve()
    output_root = Path(args.output).expanduser()
    tokens_root = Path(args.save_tokens).expanduser() if args.save_tokens else None
    extensions = parse_extensions(args.extensions)
    audio_paths = find_audio(input_path, extensions)
    if not audio_paths:
        raise FileNotFoundError(f"No audio files found under: {input_path}")

    print(f"Loading BandTok tokenizer from: {repo_id}")
    print(f"Tokenizer weights: {args.weights_name}")
    print(f"Device: {args.device}")
    tokenizer = BandTokTokenizer.from_pretrained(repo_id, device=args.device, weights_name=args.weights_name)

    single_file = input_path.is_file()
    input_root = input_path.parent if single_file else input_path

    # Pure inference: no autograd graph is needed. NOTE(review): the tokenizer
    # may already disable gradients internally; nesting no_grad is harmless.
    with torch.no_grad():
        for audio_path in tqdm(audio_paths, desc="Reconstructing"):
            out_path = resolve_output_path(audio_path, input_root, output_root, single_file)

            tokens = tokenizer.encode(str(audio_path))
            # A Path is always truthy, so test against None explicitly.
            if tokens_root is not None:
                tokens_path = resolve_tokens_path(audio_path, input_root, tokens_root, single_file)
                tokens_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(tokens, tokens_path)

            audio = tokenizer.decode(tokens)
            # Create the destination directory here, as is done for the token
            # files, instead of relying on save_audio to do it.
            out_path.parent.mkdir(parents=True, exist_ok=True)
            save_audio(audio, out_path, tokenizer.sample_rate)
            # tqdm.write keeps the progress bar intact; a bare print inside
            # the loop would interleave with it.
            tqdm.write(f"Saved {out_path} tokens_shape={tuple(tokens.shape)}")


if __name__ == "__main__":
    main()
|
|