| from typing import Tuple |
| import subprocess |
|
|
| from torch import no_grad, package |
| import numpy as np |
| import os |
|
|
|
|
|
|
|
|
|
|
class PreTrainedPipeline:
    """Text-to-speech pipeline wrapping a model stored in a ``torch.package`` archive.

    Instances are callable: ``pipeline(text)`` returns ``(waveform, sampling_rate)``.
    """

    def __init__(self, path: str):
        """Prepare the ``espeak-ng`` system dependency and load the packaged model.

        Args:
            path: Directory containing the ``model.pt`` torch.package archive.
        """
        self._ensure_espeak_ng()

        model_path = os.path.join(path, "model.pt")
        # NOTE(review): torch.package unpickles objects (and may import code)
        # from the archive, so only load trusted model files here.
        importer = package.PackageImporter(model_path)
        self.synt = importer.load_pickle("tts_models", "model")

        # Keyword arguments forwarded to every ``self.synt.tts`` call.
        self.tts_kwargs = {
            "speaker_name": "uk",
            "language_name": "uk",
        }

        # Sampling rate as reported by the loaded model.
        self.sampling_rate = self.synt.output_sample_rate

    @staticmethod
    def _ensure_espeak_ng() -> None:
        """Best-effort install of the ``espeak-ng`` system package via apt-get.

        The slow ``apt-get update``/``install`` run is skipped when an
        ``espeak-ng`` executable is already on ``PATH``. A failed install emits
        a ``RuntimeWarning`` instead of raising, keeping the original
        best-effort behaviour while no longer hiding the failure.
        """
        import shutil
        import warnings

        if shutil.which("espeak-ng") is not None:
            return

        result = subprocess.run(
            "apt-get update -y && apt-get install espeak-ng -y",
            shell=True,  # constant command string; the shell handles the ``&&`` chain
            start_new_session=True,
            check=False,
        )
        if result.returncode != 0:
            warnings.warn(
                "Installing espeak-ng via apt-get failed "
                f"(exit code {result.returncode}); text-to-speech may not work.",
                RuntimeWarning,
                stacklevel=3,  # point at the code constructing the pipeline
            )

    def __call__(self, inputs: str) -> Tuple[np.ndarray, int]:
        """Synthesize speech for ``inputs``.

        Args:
            inputs: The text to generate audio from.

        Returns:
            The raw waveform as a ``float32`` numpy array, and the sampling
            rate as an int.
        """
        # Inference only: disable autograd tracking while the model runs.
        with no_grad():
            waveform = self.synt.tts(inputs, **self.tts_kwargs)
        return np.array(waveform, dtype=np.float32), self.sampling_rate