| """ |
| This is just an example of what people would submit for inference. |
| """ |
| import os |
| from typing import Dict, List |
|
|
| import torch |
| from s3prl.downstream.runner import Runner |
|
|
class PreTrainedModel(Runner):
    """Callable wrapper that restores a ``Runner`` from a saved checkpoint.

    The checkpoint supplies both the ``Args`` and ``Config`` objects handed to
    ``Runner.__init__``; calling the instance pushes one waveform through the
    upstream, featurizer and downstream models in turn.
    """

    def __init__(self, path=""):
        """Restore the runner from ``hubert_sd.ckpt`` located under *path*.

        Args:
            path: Directory holding ``hubert_sd.ckpt``. With the default empty
                string the file name is resolved relative to the current
                working directory.
        """
        checkpoint_path = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Override the stored arguments so the runner re-reads this very
        # checkpoint, runs in inference mode and stays on the CPU.
        args = checkpoint["Args"]
        args.init_ckpt = checkpoint_path
        args.mode = "inference"
        args.device = "cpu"

        Runner.__init__(self, args, checkpoint["Config"])

    def __call__(self, inputs) -> List[int]:
        """Run inference on a single waveform.

        Args:
            inputs (:obj:`np.array`): The raw waveform of the received audio,
                by default sampled at 16 kHz.

        Returns:
            The first element of the result produced by the downstream
            model's ``inference`` method (described upstream as a list with
            logits).
        """
        # Every registered entry is switched to evaluation mode on each call.
        for component in self.all_entries:
            component.model.eval()

        # The models are fed a list of tensors; here it holds one waveform.
        wavs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            hidden = self.upstream.model(wavs)
            hidden = self.featurizer.model(wavs, hidden)
            predictions = self.downstream.model.inference(hidden, [])

        return predictions[0]
|
|
| """ |
| import io |
| import soundfile as sf |
| from urllib.request import urlopen |
| model = PreTrainedModel() |
| url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav" |
| data, samplerate = sf.read(io.BytesIO(urlopen(url).read())) |
| print(model(data)) |
| """ |