import base64
from io import BytesIO
from typing import Any, AnyStr, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler that embeds text and images with CLIP."""

    def __init__(self, path="") -> None:
        """Preload the CLIP model and processor so per-request inference is fast.

        Args:
            path: Model directory supplied by the endpoint runtime. Kept for
                interface compatibility; the weights are loaded from the Hub.
        """
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.path = path

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[float]]:
        """Run CLIP inference on a single request.

        Args:
            data: Request payload shaped like
                ``{"inputs": {"text": <str or list[str]>, "image": <base64 image bytes>}}``.

        Returns:
            Dict with flattened float lists under the keys
            ``image_embeddings``, ``text_embeddings`` and ``logits_per_image``.

        Raises:
            KeyError: If ``inputs``, ``text`` or ``image`` is missing from the
                payload (clearer than the AttributeError the old chained
                ``.get()`` calls produced on a partial payload).
        """
        request = data['inputs']
        text = request['text']
        image_b64 = request['image']
        image = Image.open(BytesIO(base64.b64decode(image_b64)))
        model_inputs = self.processor(text=text, images=image, return_tensors="pt", padding=True)
        # Inference only: torch.no_grad() skips autograd bookkeeping, saving
        # memory and time, and makes the explicit .detach() calls unnecessary.
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        return {
            'image_embeddings': outputs.image_embeds.numpy().flatten().tolist(),
            'text_embeddings': outputs.text_embeds.numpy().flatten().tolist(),
            'logits_per_image': outputs.logits_per_image.numpy().flatten().tolist(),
        }
| |