# bandtok-model/examples/infer.py
# Hugging Face Hub page header (not part of the program):
#   xlbhzz's picture
#   Upload folder using huggingface_hub
#   ddc5f7d verified
from __future__ import annotations
import argparse
from bandtok import BandTokPipeline
def main() -> None:
parser = argparse.ArgumentParser(description="Run BandTok text-to-music inference.")
parser.add_argument("--repo_id", required=True, help="Hugging Face repo id or local repo directory.")
parser.add_argument("--prompt", required=True, help="Text prompt.")
parser.add_argument("--output", default="output.wav", help="Output wav path.")
parser.add_argument("--duration", type=float, default=10.0, help="Generated duration in seconds.")
parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
parser.add_argument("--cfg-scale", type=float, default=None)
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--top-k", type=int, default=None)
parser.add_argument("--top-p", type=float, default=None)
args = parser.parse_args()
kwargs = {}
if args.cfg_scale is not None:
kwargs["cfg_scale"] = args.cfg_scale
if args.temperature is not None:
kwargs["temperature"] = args.temperature
if args.top_k is not None:
kwargs["top_k"] = args.top_k
if args.top_p is not None:
kwargs["top_p"] = args.top_p
pipe = BandTokPipeline.from_pretrained(args.repo_id, device=args.device)
audio = pipe.generate(args.prompt, duration=args.duration, **kwargs)
pipe.save(audio, args.output)
print(f"Saved audio to {args.output}")
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()