| from __future__ import annotations | |
| import argparse | |
| from bandtok import BandTokPipeline | |
def _positive_float(text: str) -> float:
    """Parse *text* as a finite float strictly greater than zero.

    Intended as an ``argparse`` ``type=`` callable, so that an unusable
    ``--duration`` is reported as a usage error before any model is loaded.

    Args:
        text: Raw command-line token.

    Returns:
        The parsed value.

    Raises:
        argparse.ArgumentTypeError: If *text* is not a number, or is zero,
            negative, NaN, or infinite.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # The chained comparison is False for NaN as well as for infinities.
    if not 0.0 < value < float("inf"):
        raise argparse.ArgumentTypeError(
            f"must be a positive, finite number of seconds, got {text!r}"
        )
    return value


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, generate audio from a prompt, and save it.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            makes ``argparse`` fall back to ``sys.argv[1:]``, which is what a
            plain ``main()`` call does.

    Raises:
        SystemExit: Raised by ``argparse`` on ``--help`` or on invalid or
            missing arguments; this happens before the pipeline is loaded.
    """
    # Sampling overrides: (keyword passed to ``pipe.generate``, value type).
    sampling_options = (
        ("cfg_scale", float),
        ("temperature", float),
        ("top_k", int),
        ("top_p", float),
    )

    parser = argparse.ArgumentParser(description="Run BandTok text-to-music inference.")
    # Both the hyphenated and the underscored spelling are accepted, so the
    # historical ``--repo_id`` keeps working next to ``--cfg-scale`` & co.
    parser.add_argument(
        "--repo-id",
        "--repo_id",
        dest="repo_id",
        required=True,
        help="Hugging Face repo id or local repo directory.",
    )
    parser.add_argument("--prompt", required=True, help="Text prompt.")
    parser.add_argument("--output", default="output.wav", help="Output wav path.")
    parser.add_argument(
        "--duration",
        type=_positive_float,
        default=10.0,
        help="Generated duration in seconds.",
    )
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    for name, value_type in sampling_options:
        # dict.fromkeys de-duplicates while keeping order; names without an
        # underscore (e.g. "temperature") have only one spelling.
        flags = dict.fromkeys((f"--{name.replace('_', '-')}", f"--{name}"))
        parser.add_argument(
            *flags,
            dest=name,
            type=value_type,
            default=None,
            help=f"Override the pipeline's default {name} (omit to keep the default).",
        )
    args = parser.parse_args(argv)

    # Forward only the options the user actually supplied, so that the
    # pipeline's own defaults apply to everything else.
    parsed = vars(args)
    overrides = {
        name: parsed[name] for name, _ in sampling_options if parsed[name] is not None
    }

    pipe = BandTokPipeline.from_pretrained(args.repo_id, device=args.device)
    audio = pipe.generate(args.prompt, duration=args.duration, **overrides)
    pipe.save(audio, args.output)
    print(f"Saved audio to {args.output}")
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()