from __future__ import annotations import argparse import random import sys from pathlib import Path import numpy as np import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from bandtok import BandTokPipeline # noqa: E402 def main() -> None: parser = argparse.ArgumentParser(description="Local BandTok inference smoke test.") parser.add_argument( "--repo_dir", default=str(REPO_ROOT), help="Local model repo directory containing config.yaml and bandtoklm.safetensors.", ) parser.add_argument("--prompt", default="A calm piano piece", help="Text prompt.") parser.add_argument("--output", default="local_test_output.wav", help="Output wav path.") parser.add_argument("--duration", type=float, default=10.0, help="Generated duration in seconds.") parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.") parser.add_argument("--seed", type=int, default=0, help="Random seed. Use -1 to disable seeding.") parser.add_argument("--cfg-scale", type=float, default=2.0) parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top-k", type=int, default=50) parser.add_argument("--top-p", type=float, default=0.6) parser.add_argument( "--max-gen-len", type=int, default=None, help="Optional token limit for quick debugging. Overrides duration-derived token length.", ) args = parser.parse_args() repo_dir = Path(args.repo_dir).expanduser().resolve() missing = [name for name in ("config.yaml", "bandtoklm.safetensors") if not (repo_dir / name).is_file()] if missing: raise FileNotFoundError(f"Missing required file(s) in {repo_dir}: {', '.join(missing)}") if args.seed >= 0: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) kwargs = { "cfg_scale": args.cfg_scale, "temperature": args.temperature, "top_k": args.top_k, "top_p": args.top_p, } if args.max_gen_len is not None: kwargs["max_gen_len"] = args.max_gen_len print(f"Loading BandTok from: {repo_dir}") print(f"Device: {args.device}") pipe = BandTokPipeline.from_pretrained(str(repo_dir), device=args.device) print(f"Generating {args.duration:.2f}s for prompt: {args.prompt!r}") audio = pipe.generate(args.prompt, duration=args.duration, **kwargs) pipe.save(audio, args.output) print(f"Saved: {Path(args.output).resolve()}") if __name__ == "__main__": main()