| from __future__ import annotations |
|
|
| import argparse |
| import random |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
# Repository root: the parent of the directory holding this script. It is
# prepended to sys.path (once) presumably so that the ``bandtok`` import below
# resolves from this checkout without installing the package.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from bandtok import BandTokPipeline |
|
|
|
|
def _parse_args(argv: list[str] | None) -> argparse.Namespace:
    """Parse and sanity-check the smoke-test command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace.

    Exits through ``parser.error`` (``SystemExit`` with status 2) on values
    that cannot produce a meaningful run, before any model is loaded.
    """
    parser = argparse.ArgumentParser(description="Local BandTok inference smoke test.")
    parser.add_argument(
        "--repo_dir",
        "--repo-dir",  # hyphenated alias, consistent with the other flags
        dest="repo_dir",
        default=str(REPO_ROOT),
        help="Local model repo directory containing config.yaml and bandtoklm.safetensors.",
    )
    parser.add_argument("--prompt", default="A calm piano piece", help="Text prompt.")
    parser.add_argument("--output", default="local_test_output.wav", help="Output wav path.")
    parser.add_argument("--duration", type=float, default=10.0, help="Generated duration in seconds.")
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    parser.add_argument("--seed", type=int, default=0, help="Random seed. Use -1 to disable seeding.")
    parser.add_argument("--cfg-scale", type=float, default=2.0)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.6)
    parser.add_argument(
        "--max-gen-len",
        type=int,
        default=None,
        help="Optional token limit for quick debugging. Overrides duration-derived token length.",
    )
    args = parser.parse_args(argv)

    # `not x > 0` (rather than `x <= 0`) also rejects NaN.
    if not args.duration > 0:
        parser.error(f"--duration must be a positive number of seconds, got {args.duration}")
    if args.max_gen_len is not None and args.max_gen_len <= 0:
        parser.error(f"--max-gen-len must be a positive integer, got {args.max_gen_len}")

    # Validate the device string up front; otherwise a typo or a missing GPU
    # would presumably only surface inside the pipeline, after the slow load.
    try:
        device_type = torch.device(args.device).type
    except (RuntimeError, ValueError) as err:
        parser.error(f"invalid --device {args.device!r}: {err}")
    if device_type == "cuda" and not torch.cuda.is_available():
        parser.error(f"--device {args.device!r} requested but CUDA is not available; try --device cpu")
    return args


def _check_repo_dir(repo_dir: Path) -> None:
    """Raise ``FileNotFoundError`` if ``repo_dir`` lacks the required model files."""
    missing = [name for name in ("config.yaml", "bandtoklm.safetensors") if not (repo_dir / name).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing required file(s) in {repo_dir}: {', '.join(missing)}")


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and torch RNGs; a negative ``seed`` leaves them unseeded."""
    if seed < 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main(argv: list[str] | None = None) -> None:
    """Run the local BandTok inference smoke test.

    Parses the command line, checks that the model repo contains
    ``config.yaml`` and ``bandtoklm.safetensors``, optionally seeds the RNGs,
    loads ``BandTokPipeline`` from the repo, generates audio for the prompt and
    saves it to ``--output``.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``, so ``main()`` behaves exactly like a script run.

    Raises:
        FileNotFoundError: If a required model file is missing from the repo.
        SystemExit: On invalid command-line arguments (raised by argparse).
    """
    args = _parse_args(argv)

    repo_dir = Path(args.repo_dir).expanduser().resolve()
    _check_repo_dir(repo_dir)
    _seed_everything(args.seed)

    # Sampling options forwarded verbatim to `pipe.generate`; `max_gen_len` is
    # only passed when given so the pipeline's own default applies otherwise.
    kwargs: dict[str, float | int] = {
        "cfg_scale": args.cfg_scale,
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
    }
    if args.max_gen_len is not None:
        kwargs["max_gen_len"] = args.max_gen_len

    print(f"Loading BandTok from: {repo_dir}")
    print(f"Device: {args.device}")
    pipe = BandTokPipeline.from_pretrained(str(repo_dir), device=args.device)

    print(f"Generating {args.duration:.2f}s for prompt: {args.prompt!r}")
    audio = pipe.generate(args.prompt, duration=args.duration, **kwargs)

    # Make sure the destination directory exists before handing the path to
    # the pipeline, so a path like `out/run1/test.wav` works out of the box.
    output_path = Path(args.output).expanduser()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pipe.save(audio, str(output_path))
    print(f"Saved: {output_path.resolve()}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing. `main`
    # returns None, so the process exits with status 0 on success.
    sys.exit(main())
|
|