| import torch |
| import nltk |
| import io |
| import base64 |
| from torchvision import transforms |
| from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample |
class PreTrainedPipeline():
    """
    Text-to-image pipeline: turns a class name into a picture generated by BigGAN.
    """

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                value forwarded unchanged to ``BigGAN.from_pretrained``.
        """
        # Fetch the WordNet corpus up front; presumably `one_hot_from_names`
        # needs it to resolve class names -- TODO confirm against the library.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # The same truncation value is used both to sample the noise and as the
        # third argument of the generator call, so the two always agree.
        self.truncation = 0.1

    @staticmethod
    def _to_unit_range(img):
        """
        Rescale a tensor from [-1, 1] to [0, 1] and clamp to that interval.

        Args:
            img (:obj:`torch.Tensor`):
                tensor to rescale; assumed to hold values in [-1, 1].
        Return:
            A new :obj:`torch.Tensor` of the same shape with values in [0, 1].
        """
        # Clamping is defensive: any value outside [0, 1] after the rescale
        # would otherwise be handed to the float -> image conversion as-is.
        return ((img + 1) / 2.0).clamp(0.0, 1.0)

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        Raises:
            ValueError: if no class vector could be built for ``inputs``.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        # The helper yields None when it cannot produce a class vector.
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)

        # Both helpers return NumPy arrays; the model is called with tensors.
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # batch_size is 1, so the first element is the only generated image.
        img = self._to_unit_range(output[0])
        return transforms.ToPILImage()(img)