# File size: 4,521 Bytes
# ddc5f7d
from __future__ import annotations

import argparse
import sys
from pathlib import Path

import torch
from tqdm.auto import tqdm


# Repository root: the directory one level above the folder holding this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of the import path, once, presumably so
# the project imports that follow resolve when this file is run as a script.
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))

from bandtok import BandTokTokenizer  # noqa: E402
from bandtok.audio_utils import save_audio  # noqa: E402
from bandtok.config import DEFAULT_TOKENIZER_WEIGHTS_NAME  # noqa: E402


# Audio file suffixes (lowercase, leading dot) used as the default for --extensions.
DEFAULT_EXTENSIONS = (
    ".wav",
    ".flac",
    ".mp3",
    ".ogg",
    ".m4a",
    ".aac",
)


def parse_extensions(value: str) -> tuple[str, ...]:
    """Turn a comma-separated extension string into a tuple of dotted suffixes.

    Each entry is stripped of surrounding whitespace and lower-cased. Empty
    entries are dropped, and a leading "." is added where it is missing.
    Order and duplicates are preserved.

    Raises:
        ValueError: if no non-empty entry remains.
    """
    cleaned = (piece.strip().lower() for piece in value.split(","))
    normalized = tuple(
        ext if ext.startswith(".") else f".{ext}"
        for ext in cleaned
        if ext
    )
    if not normalized:
        raise ValueError("At least one audio extension is required")
    return normalized


def find_audio(input_path: Path, extensions: tuple[str, ...]) -> list[Path]:
    """Collect the audio files designated by ``input_path``.

    A file path is returned as-is in a one-element list; its suffix is not
    checked. A directory is walked recursively, and every regular file whose
    lower-cased suffix appears in ``extensions`` is returned in sorted order.

    Raises:
        FileNotFoundError: if ``input_path`` is neither a file nor a directory.
    """
    if input_path.is_file():
        return [input_path]
    if input_path.is_dir():
        matches = [
            candidate
            for candidate in input_path.rglob("*")
            if candidate.is_file() and candidate.suffix.lower() in extensions
        ]
        matches.sort()
        return matches
    raise FileNotFoundError(f"Input path not found: {input_path}")


def resolve_output_path(audio_path: Path, input_root: Path, output_root: Path, single_file: bool) -> Path:
    """Work out where the reconstructed wav for ``audio_path`` is written.

    When ``single_file`` is true and ``output_root`` carries a suffix,
    ``output_root`` is taken to be the destination file itself. Otherwise the
    position of ``audio_path`` relative to ``input_root`` is mirrored under
    ``output_root`` with the suffix replaced by ``.wav``. If ``audio_path``
    is not located under ``input_root``, only its file name is used.
    """
    if not (single_file and output_root.suffix):
        try:
            relative = audio_path.relative_to(input_root)
        except ValueError:
            # audio_path is outside input_root: keep only the bare file name.
            relative = Path(audio_path.name)
        return output_root.joinpath(relative.with_suffix(".wav"))
    return output_root


def resolve_tokens_path(audio_path: Path, input_root: Path, tokens_root: Path, single_file: bool) -> Path:
    """Work out where the token file for ``audio_path`` is written.

    When ``single_file`` is true and ``tokens_root`` carries a suffix,
    ``tokens_root`` is taken to be the destination file itself. Otherwise the
    position of ``audio_path`` relative to ``input_root`` is mirrored under
    ``tokens_root`` with the suffix replaced by ``.pt``. If ``audio_path`` is
    not located under ``input_root``, only its file name is used.
    """
    explicit_file = single_file and tokens_root.suffix != ""
    if explicit_file:
        return tokens_root

    try:
        sub_path = audio_path.relative_to(input_root)
    except ValueError:
        # audio_path is outside input_root: keep only the bare file name.
        sub_path = Path(audio_path.name)
    token_name = sub_path.with_suffix(".pt")
    return tokens_root / token_name


def main() -> None:
    """Command-line entry point: reconstruct audio through the BandTok tokenizer.

    Parses the CLI arguments, collects the input audio files, and loads the
    tokenizer once. Each file is then encoded to tokens, the tokens are
    optionally stored with ``torch.save``, and the decoded audio is written
    with ``save_audio``.

    Raises:
        FileNotFoundError: if the input path does not exist or yields no audio files.
        ValueError: if ``--extensions`` contains no usable entry.
    """
    parser = argparse.ArgumentParser(description="Run BandTok tokenizer reconstruction inference.")
    parser.add_argument(
        "--repo_id",
        "--repo-id",  # hyphenated alias, consistent with the other flags
        dest="repo_id",
        default=str(REPO_ROOT),
        help="Hugging Face repo id or local repo directory containing config.yaml and bandtok.safetensors.",
    )
    parser.add_argument("--input", required=True, help="Input audio file or directory.")
    parser.add_argument("--output", default="tokenizer_reconstructions", help="Output wav file or directory.")
    parser.add_argument(
        "--device",
        default=None,
        help="cuda, cuda:0, or cpu. Defaults to cuda when available, otherwise cpu.",
    )
    parser.add_argument("--weights-name", default=DEFAULT_TOKENIZER_WEIGHTS_NAME, help="Tokenizer safetensors filename.")
    parser.add_argument("--save-tokens", default=None, help="Optional token .pt file or directory.")
    parser.add_argument(
        "--extensions",
        default=",".join(DEFAULT_EXTENSIONS),
        help="Comma-separated extensions to scan when --input is a directory.",
    )
    args = parser.parse_args()

    # An explicit --device is always honoured. Only the implicit default falls
    # back to CPU, so machines without CUDA work out of the box.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # A repo_id naming an existing local path is made absolute; anything else
    # is passed through unchanged (e.g. a Hugging Face repo id).
    repo_path = Path(args.repo_id).expanduser()
    repo_id = str(repo_path.resolve()) if repo_path.exists() else args.repo_id

    input_path = Path(args.input).expanduser().resolve()
    output_root = Path(args.output).expanduser()
    tokens_root = Path(args.save_tokens).expanduser() if args.save_tokens else None
    extensions = parse_extensions(args.extensions)
    audio_paths = find_audio(input_path, extensions)
    if not audio_paths:
        raise FileNotFoundError(f"No audio files found under: {input_path}")

    print(f"Loading BandTok tokenizer from: {repo_id}")
    print(f"Tokenizer weights: {args.weights_name}")
    print(f"Device: {device}")
    tokenizer = BandTokTokenizer.from_pretrained(repo_id, device=device, weights_name=args.weights_name)

    single_file = input_path.is_file()
    # Relative output layout is computed against this root.
    input_root = input_path.parent if single_file else input_path
    for audio_path in tqdm(audio_paths, desc="Reconstructing"):
        out_path = resolve_output_path(audio_path, input_root, output_root, single_file)
        tokens_path = (
            resolve_tokens_path(audio_path, input_root, tokens_root, single_file)
            if tokens_root is not None
            else None
        )

        tokens = tokenizer.encode(str(audio_path))
        if tokens_path is not None:
            tokens_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(tokens, tokens_path)

        audio = tokenizer.decode(tokens)
        # Create the destination directory here, mirroring the token branch,
        # rather than relying on save_audio to do it.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        save_audio(audio, out_path, tokenizer.sample_rate)
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write(f"Saved {out_path} tokens_shape={tuple(tokens.shape)}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()