bharatverse11 commited on
Commit
5c04df5
·
verified ·
1 Parent(s): d5555ab

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +45 -0
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_audio_tools import get_pretrained_model
2
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
3
+ import torch
4
+ import base64
5
+ import io
6
+ import soundfile as sf
7
+
8
# Load the model once at module import (IMPORTANT: inference endpoints call
# handler() per request; re-loading the model on every call would be
# prohibitively slow).
# Robustness fix: fall back to CPU when CUDA is unavailable so the module can
# still be imported on GPU-less hosts — the original hard-coded "cuda" and
# crashed at import time there. On CUDA hosts behavior is unchanged.
_device = "cuda" if torch.cuda.is_available() else "cpu"
model, cfg = get_pretrained_model("bharatverse11/BeatGeneration")
model = model.to(_device)
model.eval()

# cfg is the model-config dict returned by get_pretrained_model; 44100 Hz is
# used as a fallback if the config omits the sample rate.
SAMPLE_RATE = cfg.get("sample_rate", 44100)
13
+
14
def handler(data):
    """Inference-endpoint entry point: generate audio for a text prompt.

    Args:
        data: Request payload, normally ``{"inputs": {...}}`` with optional
            keys ``"prompt"`` (str), ``"duration"`` (seconds, default 10),
            ``"steps"`` (diffusion steps, default 50) and ``"cfg_scale"``
            (guidance strength, default 7). A flat payload without the
            ``"inputs"`` wrapper is also accepted.

    Returns:
        dict with a single key ``"audio"``: base64-encoded WAV bytes.
    """
    # Robustness fix: the original did data["inputs"] and raised KeyError on
    # flat payloads; fall back to treating the payload itself as the inputs.
    inputs = data.get("inputs", data)

    prompt = inputs.get("prompt", "")
    duration = inputs.get("duration", 10)
    steps = inputs.get("steps", 50)
    cfg_scale = inputs.get("cfg_scale", 7)

    # Conditioning schema expected by generate_diffusion_cond: one dict per
    # batch item with prompt text and the requested time window.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": duration,
    }]

    # Consistency fix: derive the device from where the model actually lives
    # instead of hard-coding "cuda" a second time. Identical behavior when
    # the model was placed on CUDA at module load.
    device = str(next(model.parameters()).device)

    with torch.no_grad():
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=conditioning,
            # sample_size is in samples, so scale the duration by the rate.
            sample_size=int(duration * SAMPLE_RATE),
            device=device,
        )

    # output is (batch, channels, samples); take the first item and transpose
    # to (samples, channels), the layout soundfile expects.
    audio = output.cpu().numpy()[0].T

    # Encode as WAV in memory; no temp file needed.
    buffer = io.BytesIO()
    sf.write(buffer, audio, SAMPLE_RATE, format="WAV")

    return {
        "audio": base64.b64encode(buffer.getvalue()).decode()
    }