| from typing import Dict, List, Any |
| import torch |
| from transformers import AutoProcessor, Pix2StructVisionModel |
| from PIL import Image |
| import pdb |
| import requests |
|
|
# Hub checkpoint name; EndpointHandler loads both the processor and the
# Pix2Struct vision encoder from it.
MODEL: str = "google/pix2struct-screen2words-large"
|
|
class EndpointHandler:
    """Encode an image into a single mean-pooled embedding vector.

    The processor and the Pix2Struct vision encoder are both loaded from the
    checkpoint named by the module-level ``MODEL``. The encoder's last hidden
    state is averaged over axis 1 to give one vector per image.

    Args:
        path: Accepted for interface compatibility but not used; weights are
            always loaded from ``MODEL``.
        mask_padding: When True, positions whose processor ``attention_mask``
            is 0 (presumably padding patches -- confirm for the checkpoint) are
            excluded from the average. Defaults to False, which keeps the
            original behaviour of averaging over every position, so embeddings
            computed before this option existed remain comparable.
        request_timeout: Timeout in seconds forwarded to ``requests.get``.
    """

    def __init__(
        self,
        path="",
        *,
        mask_padding: bool = False,
        request_timeout: float = 30.0,
    ):
        self.mask_padding = mask_padding
        self.request_timeout = request_timeout
        # Pick the device once so __init__ and __call__ always agree, and fall
        # back to CPU instead of crashing when CUDA is unavailable.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.processor = AutoProcessor.from_pretrained(MODEL)
        # NOTE(review): presumably forces the non-VQA preprocessing path (no
        # header text rendered onto the image) -- confirm it is still needed.
        self.processor.image_processor.is_vqa = False
        self.model = Pix2StructVisionModel.from_pretrained(MODEL).to(self.device)

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Return ``{"embedding": [...]}`` for the image referenced by *data*.

        Args:
            data: A payload dict whose ``"inputs"`` value is an image URL (or
                an already-decoded ``PIL.Image.Image``), or that value itself.
                The payload dict is not modified.

        Returns:
            A dict with one key, ``"embedding"``, holding the pooled last
            hidden state flattened into a list of floats.

        Raises:
            ValueError: If a dict payload has no ``"inputs"`` key.
            TypeError: If the input is neither a URL string nor a PIL image.
            requests.HTTPError: If the download returns an error status.
        """
        source = self._extract_input(data)
        image = self._load_image(source, self.request_timeout)
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        mask = inputs.get("attention_mask") if self.mask_padding else None
        embedding = self._pool(outputs["last_hidden_state"], mask)
        return {"embedding": embedding}

    @staticmethod
    def _extract_input(data: Any) -> Any:
        """Return the ``"inputs"`` value of a dict payload, else *data* itself.

        Uses a read instead of ``pop`` so the caller's payload is untouched.
        """
        if isinstance(data, dict):
            if "inputs" not in data:
                raise ValueError("request payload must contain an 'inputs' key")
            return data["inputs"]
        return data

    @staticmethod
    def _load_image(source: Any, timeout: float) -> Any:
        """Return an RGB PIL image from a URL string or an existing PIL image."""
        if isinstance(source, Image.Image):
            return source.convert("RGB")
        if not isinstance(source, str):
            raise TypeError(
                f"expected an image URL string, got {type(source).__name__}"
            )
        # The context manager releases the connection; raise_for_status stops
        # an HTTP error body from reaching PIL as an undecodable "image".
        with requests.get(source, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Undo any Content-Encoding (e.g. gzip) before PIL reads the bytes.
            response.raw.decode_content = True
            # convert() forces PIL's lazy decode while the stream is still open.
            return Image.open(response.raw).convert("RGB")

    @staticmethod
    def _pool(last_hidden_state: torch.Tensor, attention_mask: Any = None) -> List[float]:
        """Mean-pool *last_hidden_state* over axis 1 and flatten to a list.

        If *attention_mask* (same leading two dims as the hidden state) is
        given, only positions where it is non-zero contribute. A row whose
        mask is all zeros yields zeros instead of dividing by zero.
        """
        if attention_mask is None:
            return torch.mean(last_hidden_state, dim=1).flatten().tolist()
        weights = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return (summed / counts).flatten().tolist()
|
|
if __name__ == "__main__":
    # Manual smoke test, replacing a dead string literal that held this
    # example: python <this file> [IMAGE_URL]
    # Builds the handler, embeds the image and prints the resulting dict.
    import sys

    _DEFAULT_URL = "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"
    url = sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_URL
    handler = EndpointHandler()
    output = handler({"inputs": url})
    print(output)
|
|