from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
from tqdm.auto import tqdm
# Directory one level above the folder that contains this (resolved) file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Make REPO_ROOT importable when the script is run directly; presumably this is
# what lets the local `bandtok` imports below resolve without installation.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from bandtok import BandTokTokenizer # noqa: E402
from bandtok.audio_utils import save_audio # noqa: E402
from bandtok.config import DEFAULT_TOKENIZER_WEIGHTS_NAME # noqa: E402
# Dot-prefixed, lowercase audio suffixes used as the default value of the
# --extensions option (the set scanned when --input is a directory).
DEFAULT_EXTENSIONS = (".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac")
def parse_extensions(value: str) -> tuple[str, ...]:
    """Parse a comma-separated list of audio extensions.

    Each entry is stripped of surrounding whitespace, lower-cased and given a
    leading dot if it lacks one. Empty entries are skipped and repeated
    entries are collapsed, keeping the first occurrence.

    Args:
        value: Comma-separated extensions, e.g. ``"wav, .FLAC,mp3"``.

    Returns:
        The normalized extensions in their original order, e.g.
        ``(".wav", ".flac", ".mp3")``.

    Raises:
        ValueError: If ``value`` contains no non-empty entry.
    """
    stripped = (item.strip().lower() for item in value.split(","))
    dotted = (item if item.startswith(".") else f".{item}" for item in stripped if item)
    # dict.fromkeys drops duplicates while preserving insertion order.
    extensions = tuple(dict.fromkeys(dotted))
    if not extensions:
        raise ValueError("At least one audio extension is required")
    return extensions
def find_audio(input_path: Path, extensions: tuple[str, ...]) -> list[Path]:
    """Collect the audio files designated by ``input_path``.

    Args:
        input_path: A single file or a directory to scan recursively.
        extensions: Dot-prefixed suffixes to accept when scanning a directory.
            Matching is case-insensitive.

    Returns:
        ``[input_path]`` when it is a file (its suffix is not checked);
        otherwise the sorted list of files below the directory whose suffix
        is in ``extensions``.

    Raises:
        FileNotFoundError: If ``input_path`` is neither a file nor a directory.
    """
    if input_path.is_file():
        return [input_path]
    if not input_path.is_dir():
        raise FileNotFoundError(f"Input path not found: {input_path}")
    # File suffixes are lower-cased before comparison, so the accepted set must
    # be lower-cased too or an entry such as ".WAV" could never match.
    wanted = frozenset(ext.lower() for ext in extensions)
    return sorted(
        path
        for path in input_path.rglob("*")
        if path.suffix.lower() in wanted and path.is_file()
    )
def resolve_output_path(audio_path: Path, input_root: Path, output_root: Path, single_file: bool) -> Path:
    """Return the path the reconstructed ``.wav`` for ``audio_path`` is written to.

    Args:
        audio_path: Source audio file.
        input_root: Directory that ``audio_path`` is made relative to.
        output_root: Output file (single-file mode only) or output directory.
        single_file: Whether the run processes exactly one input file.

    Returns:
        ``output_root`` itself when ``single_file`` is set and ``output_root``
        looks like a file name (has a suffix and is not an existing
        directory); otherwise ``output_root`` joined with the path of
        ``audio_path`` relative to ``input_root``, with the suffix replaced by
        ``.wav``.
    """
    # A suffix normally marks output_root as an explicit file name, but an
    # existing directory such as "runs.v2" has a suffix too and must still be
    # treated as a directory.
    if single_file and output_root.suffix and not output_root.is_dir():
        return output_root
    try:
        rel = audio_path.relative_to(input_root)
    except ValueError:
        # audio_path is not located under input_root: keep only its file name.
        rel = Path(audio_path.name)
    # NOTE: inputs that differ only by extension (a.flac / a.mp3) map to the
    # same output path.
    return output_root / rel.with_suffix(".wav")
def resolve_tokens_path(audio_path: Path, input_root: Path, tokens_root: Path, single_file: bool) -> Path:
    """Return the path the ``.pt`` token file for ``audio_path`` is written to.

    Args:
        audio_path: Source audio file.
        input_root: Directory that ``audio_path`` is made relative to.
        tokens_root: Token file (single-file mode only) or token directory.
        single_file: Whether the run processes exactly one input file.

    Returns:
        ``tokens_root`` itself when ``single_file`` is set and ``tokens_root``
        looks like a file name (has a suffix and is not an existing
        directory); otherwise ``tokens_root`` joined with the path of
        ``audio_path`` relative to ``input_root``, with the suffix replaced by
        ``.pt``.
    """
    # A suffix normally marks tokens_root as an explicit file name, but an
    # existing directory such as "tokens.v2" has a suffix too and must still
    # be treated as a directory.
    if single_file and tokens_root.suffix and not tokens_root.is_dir():
        return tokens_root
    try:
        rel = audio_path.relative_to(input_root)
    except ValueError:
        # audio_path is not located under input_root: keep only its file name.
        rel = Path(audio_path.name)
    # NOTE: inputs that differ only by extension (a.flac / a.mp3) map to the
    # same token path.
    return tokens_root / rel.with_suffix(".pt")
def main() -> None:
    """Command-line entry point for BandTok tokenizer reconstruction.

    Parses the CLI arguments, collects the input audio files, loads the
    tokenizer, and then for every file: encodes it to tokens, optionally
    stores the tokens with ``torch.save``, decodes the tokens back to audio
    and hands the result to ``save_audio``.

    Raises:
        FileNotFoundError: If ``--input`` does not exist or no audio file is
            found under it.
        ValueError: If ``--extensions`` contains no usable entry.
    """
    parser = argparse.ArgumentParser(description="Run BandTok tokenizer reconstruction inference.")
    parser.add_argument(
        "--repo_id",
        "--repo-id",  # hyphenated alias, consistent with the other options
        dest="repo_id",
        default=str(REPO_ROOT),
        help="Hugging Face repo id or local repo directory containing config.yaml and bandtok.safetensors.",
    )
    parser.add_argument("--input", required=True, help="Input audio file or directory.")
    parser.add_argument("--output", default="tokenizer_reconstructions", help="Output wav file or directory.")
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    parser.add_argument("--weights-name", default=DEFAULT_TOKENIZER_WEIGHTS_NAME, help="Tokenizer safetensors filename.")
    parser.add_argument("--save-tokens", default=None, help="Optional token .pt file or directory.")
    parser.add_argument(
        "--extensions",
        default=",".join(DEFAULT_EXTENSIONS),
        help="Comma-separated extensions to scan when --input is a directory.",
    )
    args = parser.parse_args()

    # An existing local path is passed on in absolute form; anything else is
    # forwarded untouched (e.g. a Hugging Face repo id).
    local_repo = Path(args.repo_id).expanduser()
    repo_id = str(local_repo.resolve()) if local_repo.exists() else args.repo_id

    input_path = Path(args.input).expanduser().resolve()
    output_root = Path(args.output).expanduser()
    tokens_root = Path(args.save_tokens).expanduser() if args.save_tokens else None
    extensions = parse_extensions(args.extensions)

    audio_paths = find_audio(input_path, extensions)
    if not audio_paths:
        raise FileNotFoundError(f"No audio files found under: {input_path}")

    print(f"Loading BandTok tokenizer from: {repo_id}")
    print(f"Tokenizer weights: {args.weights_name}")
    print(f"Device: {args.device}")
    tokenizer = BandTokTokenizer.from_pretrained(repo_id, device=args.device, weights_name=args.weights_name)

    single_file = input_path.is_file()
    # Relative output layout is computed against the scanned directory, or the
    # file's own directory in single-file mode.
    input_root = input_path.parent if single_file else input_path

    for audio_path in tqdm(audio_paths, desc="Reconstructing"):
        out_path = resolve_output_path(audio_path, input_root, output_root, single_file)
        tokens_path = (
            resolve_tokens_path(audio_path, input_root, tokens_root, single_file)
            if tokens_root is not None
            else None
        )

        tokens = tokenizer.encode(str(audio_path))
        if tokens_path is not None:
            tokens_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(tokens, tokens_path)

        audio = tokenizer.decode(tokens)
        # Whether save_audio creates missing parent directories is not visible
        # here; creating them up front is harmless and mirrors the token path.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        save_audio(audio, out_path, tokenizer.sample_rate)
        # tqdm.write keeps the per-file message from corrupting the progress bar.
        tqdm.write(f"Saved {out_path} tokens_shape={tuple(tokens.shape)}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()