|
|
| from transformers import Pipeline |
| import requests |
| from PIL import Image |
| import torchvision.transforms as transforms |
| import torch |
|
|
class MnistPipe(Pipeline):
    """Hugging Face ``Pipeline`` that classifies a single digit image.

    The input is a local image path, an opened ``PIL.Image.Image``, or, with
    ``download=True``, an image URL. The image is converted to grayscale,
    turned into a tensor, resized to 28x28, given a batch dimension and fed
    to ``self.model``.

    Call-time options (routed by ``_sanitize_parameters``):
        download: treat the input as a URL and download it first.
        clean_output: return the arg-max class index (default) instead of
            the raw model output.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # ToTensor turns the grayscale PIL image into a float (1, H, W) tensor;
        # Resize then brings the spatial size to 28x28 (MNIST resolution).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Returns:
            The ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            triple that ``Pipeline`` passes to ``preprocess``, ``_forward``
            and ``postprocess`` respectively.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load an image and turn it into a model-ready batch tensor.

        Args:
            inputs: path or file object of an image, an already-opened
                ``PIL.Image.Image``, or (when ``download`` is true) an image URL.
            download: if true, fetch ``inputs`` over HTTP via ``download_img``
                and read the saved file.

        Returns:
            A float tensor of shape ``(1, 1, 28, 28)``.
        """
        if download:
            inputs = self.download_img(inputs)

        if isinstance(inputs, Image.Image):
            gray = inputs.convert("L")
        else:
            # The context manager closes the underlying file handle; convert()
            # returns a fully loaded copy, so ``gray`` remains usable afterwards.
            with Image.open(inputs) as img:
                gray = img.convert("L")

        tensor = self.transform(gray)
        # Add the batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run the model on the preprocessed batch without tracking gradients."""
        with torch.no_grad():
            return self.model(tensor)

    def postprocess(self, out, clean_output=True):
        """Convert the model output into the pipeline result.

        Args:
            out: raw model output. Assumed to be a logits tensor whose last
                dimension indexes the classes (e.g. shape ``(1, 10)``); confirm
                against the model in use.
            clean_output: if true, return the predicted class index as an
                ``int``; otherwise return ``out`` unchanged.
        """
        if not clean_output:
            return out
        label = torch.argmax(out, dim=-1)
        # A 1-D output reduces to a 0-d tensor, whose tolist() is a plain int.
        if label.dim() == 0:
            return label.item()
        # preprocess always builds a batch of one, so take the first entry.
        return label.tolist()[0]

    def download_img(self, url, path="image.png", timeout=30):
        """Download ``url`` to ``path`` and return ``path``.

        Args:
            url: address of the image to fetch.
            path: destination file. It is overwritten, so concurrent callers
                sharing the default name will clobber each other.
            timeout: seconds passed to ``requests.get``; bounds the connect
                and per-read waits so a dead server cannot hang the pipeline.

        Raises:
            requests.HTTPError: on a non-2xx response. Without this, an error
                page would be saved and later fail to open as an image.
        """
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"image saved as {path}")
        return path
|
|