Create handler.py
Browse files- handler.py +45 -0
handler.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stable_audio_tools import get_pretrained_model
|
| 2 |
+
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
| 3 |
+
import torch
|
| 4 |
+
import base64
|
| 5 |
+
import io
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
|
| 8 |
+
# Load the pretrained diffusion model once at module import time so every
# request served by handler() reuses the same weights (IMPORTANT: avoids
# reloading the checkpoint on each call).
model, cfg = get_pretrained_model("bharatverse11/BeatGeneration")
# Inference only: disable dropout/batch-norm training behavior, move to GPU.
# NOTE(review): device is hard-coded to "cuda" — this module fails to import
# on a CPU-only host; confirm the deployment target always has a GPU.
model.eval().to("cuda")

# Output sample rate from the model config; 44.1 kHz fallback if absent.
# NOTE(review): assumes `cfg` is a dict-like model config — confirm against
# the stable_audio_tools return value.
SAMPLE_RATE = cfg.get("sample_rate", 44100)
|
| 13 |
+
|
| 14 |
+
def handler(data):
    """Generate audio from a text prompt and return it as a base64 WAV.

    Expects ``data["inputs"]`` to be a dict with optional keys:
      - ``prompt`` (str): text description of the audio (default ``""``).
      - ``duration`` (number): clip length in seconds (default ``10``).
      - ``steps`` (int): diffusion sampling steps (default ``50``).
      - ``cfg_scale`` (number): classifier-free guidance scale (default ``7``).

    Returns:
        dict: ``{"audio": <base64-encoded WAV bytes as str>}``.

    Raises:
        KeyError: if ``data`` has no ``"inputs"`` key.
        ValueError: if ``duration`` or ``steps`` is not a positive number
            (or any numeric field cannot be parsed as a number).
    """
    inputs = data["inputs"]

    prompt = inputs.get("prompt", "")
    # Coerce numeric parameters explicitly: JSON payloads frequently deliver
    # them as strings, and a str `duration` would otherwise make
    # `duration * SAMPLE_RATE` perform *string repetition* below, producing a
    # pathological sample_size instead of a clear error.
    duration = float(inputs.get("duration", 10))
    steps = int(inputs.get("steps", 50))
    cfg_scale = float(inputs.get("cfg_scale", 7))

    if duration <= 0:
        raise ValueError(f"duration must be positive, got {duration}")
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    # Conditioning format expected by generate_diffusion_cond: one dict per
    # batch element.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": duration,
    }]

    # Inference only — no gradients, keeps GPU memory usage down.
    with torch.no_grad():
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=conditioning,
            sample_size=int(duration * SAMPLE_RATE),
            device="cuda",
        )

    # Take the first batch item and transpose channels-first -> samples-first,
    # the layout soundfile expects.  NOTE(review): assumes output is shaped
    # (batch, channels, samples) — confirm against stable_audio_tools.
    audio = output.cpu().numpy()[0].T

    # Encode as WAV entirely in memory so the result is JSON-safe base64.
    buffer = io.BytesIO()
    sf.write(buffer, audio, SAMPLE_RATE, format="WAV")

    return {
        "audio": base64.b64encode(buffer.getvalue()).decode()
    }
|