# bandtok-model/examples/local_test_infer.py
# Uploaded by xlbhzz via huggingface_hub ("Upload folder using huggingface_hub"), commit ddc5f7d (verified).
from __future__ import annotations
import argparse
import random
import sys
from pathlib import Path
import numpy as np
import torch
# Repository root: two directory levels above this file (the directory that
# contains ``examples/``).
REPO_ROOT = Path(__file__).resolve().parent.parent
# Prepend the repo root to the import path so the local ``bandtok`` package
# imported below resolves when this script is run directly from a checkout.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from bandtok import BandTokPipeline # noqa: E402
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the local inference smoke test."""
    parser = argparse.ArgumentParser(description="Local BandTok inference smoke test.")
    parser.add_argument(
        "--repo_dir",
        default=str(REPO_ROOT),
        help="Local model repo directory containing config.yaml and bandtoklm.safetensors.",
    )
    parser.add_argument("--prompt", default="A calm piano piece", help="Text prompt.")
    parser.add_argument("--output", default="local_test_output.wav", help="Output wav path.")
    parser.add_argument("--duration", type=float, default=10.0, help="Generated duration in seconds.")
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    parser.add_argument("--seed", type=int, default=0, help="Random seed. Use -1 to disable seeding.")
    parser.add_argument("--cfg-scale", type=float, default=2.0)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.6)
    parser.add_argument(
        "--max-gen-len",
        type=int,
        default=None,
        help="Optional token limit for quick debugging. Overrides duration-derived token length.",
    )
    return parser


def _seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and torch's RNGs; a negative *seed* leaves them untouched."""
    if seed < 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main() -> None:
    """Run one local BandTok generation as a smoke test.

    The function:

    1. Parses the command-line options.
    2. Checks that ``--repo_dir`` contains ``config.yaml`` and
       ``bandtoklm.safetensors``.
    3. Optionally seeds the RNGs.
    4. Loads the pipeline with ``BandTokPipeline.from_pretrained`` on
       ``--device``.
    5. Generates ``--duration`` seconds of audio for ``--prompt``.
    6. Writes the result to ``--output`` via ``pipe.save``.

    Raises:
        FileNotFoundError: if a required model file is missing from the repo dir.
        SystemExit: raised by ``argparse`` for invalid command-line values.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Reject values that cannot describe a real generation request before doing
    # any expensive work. The chained comparison also rejects NaN and +inf.
    if not 0.0 < args.duration < float("inf"):
        parser.error(f"--duration must be a positive, finite number of seconds (got {args.duration}).")
    if args.max_gen_len is not None and args.max_gen_len <= 0:
        parser.error(f"--max-gen-len must be a positive integer (got {args.max_gen_len}).")

    repo_dir = Path(args.repo_dir).expanduser().resolve()
    missing = [name for name in ("config.yaml", "bandtoklm.safetensors") if not (repo_dir / name).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing required file(s) in {repo_dir}: {', '.join(missing)}")

    # Expand "~" the same way as --repo_dir. Create the parent directory up front
    # so that a bad --output path fails before model loading and generation,
    # not after them.
    output_path = Path(args.output).expanduser()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    _seed_everything(args.seed)

    kwargs = {
        "cfg_scale": args.cfg_scale,
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
    }
    # Only forward max_gen_len when given. Otherwise the pipeline keeps its
    # duration-derived token length.
    if args.max_gen_len is not None:
        kwargs["max_gen_len"] = args.max_gen_len

    print(f"Loading BandTok from: {repo_dir}")
    print(f"Device: {args.device}")
    pipe = BandTokPipeline.from_pretrained(str(repo_dir), device=args.device)
    print(f"Generating {args.duration:.2f}s for prompt: {args.prompt!r}")
    audio = pipe.generate(args.prompt, duration=args.duration, **kwargs)
    pipe.save(audio, str(output_path))
    print(f"Saved: {output_path.resolve()}")
# Entry point: run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()