| from typing import Any, Dict |
|
|
| import torch |
| from transformers import AutoModel, AutoProcessor |
|
|
|
|
class EndpointHandler:
    """Callable request handler that runs the ``suno/bark`` model.

    The processor and model are loaded once in ``__init__``; every call
    turns a text prompt into the model's generated output, serialised as
    nested Python lists.
    """

    def __init__(self, path=""):
        """Load the Bark processor and model.

        Args:
            path: Accepted for interface compatibility with the caller.
                NOTE: it is not used -- the weights are always loaded from
                the ``suno/bark`` identifier.
        """
        # Pick the device once so the model and every request's inputs agree.
        # Falls back to CPU instead of crashing when CUDA is unavailable.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained("suno/bark")
        self.model = AutoModel.from_pretrained(
            "suno/bark",
        ).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate output for a single text prompt.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                ``data["inputs"]`` must be the prompt string. An optional
                truthy ``data["voice_preset"]`` is forwarded to the
                processor. The payload is not modified.

        Returns:
            ``{"generated_audio": prediction}`` where ``prediction`` is the
            output of ``model.generate`` converted to nested lists.

        Raises:
            ValueError: If ``"inputs"`` is missing or is not a string.
        """
        # Read instead of pop so the caller's payload is left untouched.
        text = data.get("inputs")
        if not isinstance(text, str):
            raise ValueError(
                'Payload must contain an "inputs" key holding the text prompt '
                f"as a string, got {type(text).__name__}."
            )

        processor_kwargs: Dict[str, Any] = {
            "text": [text],
            "return_tensors": "pt",
        }
        # Only forward the preset when one was actually supplied.
        voice_preset = data.get("voice_preset")
        if voice_preset:
            processor_kwargs["voice_preset"] = voice_preset

        inputs = self.processor(**processor_kwargs).to(self.device)

        # Mixed precision is only enabled on CUDA; on the CPU fallback the
        # context manager is a no-op.
        with torch.autocast(self.device, enabled=self.device == "cuda"):
            outputs = self.model.generate(**inputs)

        # Move to host memory and convert to plain lists for serialisation.
        prediction = outputs.cpu().numpy().tolist()

        return {"generated_audio": prediction}
|
|