File size: 1,516 Bytes
ddc5f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import annotations

import argparse

from bandtok import BandTokPipeline


def main() -> None:
    parser = argparse.ArgumentParser(description="Run BandTok text-to-music inference.")
    parser.add_argument("--repo_id", required=True, help="Hugging Face repo id or local repo directory.")
    parser.add_argument("--prompt", required=True, help="Text prompt.")
    parser.add_argument("--output", default="output.wav", help="Output wav path.")
    parser.add_argument("--duration", type=float, default=10.0, help="Generated duration in seconds.")
    parser.add_argument("--device", default="cuda", help="cuda, cuda:0, or cpu.")
    parser.add_argument("--cfg-scale", type=float, default=None)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top-k", type=int, default=None)
    parser.add_argument("--top-p", type=float, default=None)
    args = parser.parse_args()

    kwargs = {}
    if args.cfg_scale is not None:
        kwargs["cfg_scale"] = args.cfg_scale
    if args.temperature is not None:
        kwargs["temperature"] = args.temperature
    if args.top_k is not None:
        kwargs["top_k"] = args.top_k
    if args.top_p is not None:
        kwargs["top_p"] = args.top_p

    pipe = BandTokPipeline.from_pretrained(args.repo_id, device=args.device)
    audio = pipe.generate(args.prompt, duration=args.duration, **kwargs)
    pipe.save(audio, args.output)
    print(f"Saved audio to {args.output}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()