| import logging |
| from typing import Any, Dict |
|
|
| import numpy as np |
| import torch |
| from audiocraft.models import MusicGen |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler around ``facebook/musicgen-large``.

    Lifts the model's native 30-second generation limit by optionally
    generating audio in overlapping chunks: the tail of each generated
    chunk is fed back as an audio prompt (``generate_continuation``)
    for the next chunk.
    """

    def __init__(self, path=""):
        """Load the pretrained MusicGen model on GPU when available.

        Args:
            path: Unused; kept for compatibility with the inference
                endpoint runtime, which passes a model directory path.
        """
        # Prefer CUDA; CPU generation works but is very slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # MusicGen produces mono audio.
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        # Model-native output sample rate (samples per second).
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Generate audio for a text prompt, chunking when duration > 30s.

        This call function is invoked by the endpoint. Durations up to 30
        seconds are generated in one model call; longer durations are
        produced chunk by chunk, each chunk conditioned on the last
        ``audio_window`` seconds of audio generated so far.

        The payload is a dictionary with the following keys:
            inputs: The text prompt to generate audio for.
            generation_params: Optional dict forwarded to the model's
                ``set_generation_params``. Supported keys include
                ``duration`` (seconds of audio to generate, default 30 here),
                ``temperature``, ``top_p``, ``top_k`` and ``cfg_coef``
                (classifier-free guidance weight). When omitted, the
                model's own defaults apply — see audiocraft's
                ``set_generation_params``; the exact default values are
                not visible from this file.
            audio_window: Seconds of previously generated audio used as
                the prompt for the next chunk. Default: 20.
            chunk_size: Size of each generated chunk in seconds. Default: 30.

        Args:
            data (Dict[str, Any]): The payload described above.

        Raises:
            ValueError: If ``chunk_size`` is not strictly greater than
                ``audio_window``, or if ``duration - chunk_size`` is not
                a multiple of ``chunk_size - audio_window``.

        Returns:
            Dict[str, np.ndarray]: ``{"generated_audio": array}`` where
            the array is shaped (samples, channels).
        """
        prompt = data["inputs"]

        # Copy so overriding "duration" for chunked generation below never
        # mutates the caller's payload.
        generation_params = dict(data.get("generation_params", {}))

        duration = generation_params.get("duration", 30)

        if duration <= 30:
            # Fits in a single model call.
            logger.info(f"Generating audio with duration {duration} in one go.")
            self.model.set_generation_params(**generation_params)
            # generate() returns (batch, channels, samples); drop the batch
            # dim so both branches produce audio shaped (channels, samples).
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info(f"Generating audio with duration {duration} in chunks.")

            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            # Seconds of *new* audio each hop contributes.
            continuation = chunk_size - audio_window
            final_duration = duration

            # Must be strictly greater (`<=`, not `<`): equal values would
            # make `continuation` zero and divide by zero below.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )

            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )

            # Each model call produces exactly one chunk of audio.
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)

            logger.info(
                f"Generating total audio {final_duration} secs with chunks of {chunk_size} secs "
                f"and continuation of {continuation} secs."
            )

            # Pre-allocate the full output buffer and fill it in place.
            logger.info(f"Initializing final audio with {chunk_size} secs of audio.")
            final_audio = torch.zeros(
                (
                    self.channels,
                    self.sample_rate * final_duration,
                ),
                dtype=torch.float,
            ).to(self.device)

            # First chunk is generated from the text prompt alone; the
            # (1, channels, T) result broadcasts into the (channels, T) slice.
            final_audio[
                :,
                : chunk_size * self.sample_rate,
            ] = self.model.generate([prompt], progress=True)

            # Each hop re-prompts with the last `audio_window` seconds and
            # keeps only the `continuation` seconds of newly generated audio.
            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info(f"Generating audio for hop {i_hop}")

                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window

                # Reshape to the (batch, channels, samples) layout the
                # model expects for continuation prompts.
                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)

                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )

                # The first `audio_window` seconds of `output` repeat the
                # prompt audio; keep only the newly generated tail.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[..., audio_window * self.sample_rate :]
                logger.info(
                    f"finished generating audio till {(prompt_stop + continuation)} secs."
                )

        # (channels, samples) -> (samples, channels) for the response.
        return {"generated_audio": final_audio.cpu().numpy().transpose()}
|
|