Image Classification
Transformers
Safetensors
dinov2
font_classifier_v4 / handler.py
dchen0's picture
Add merged model + processor
2e97025 verified
raw
history blame
2.8 kB
# to be bundled with the model on upload to HF Inference Endpoints
import base64
import io
from typing import Any, Dict
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification
def get_inference_transform(processor: AutoImageProcessor, size: int):
    """Build the deterministic preprocessing pipeline used at inference time.

    Steps, in order: convert to RGB, centre the image on a black square
    canvas, resize with ``T.Resize(size)``, convert to a tensor, and
    normalise with the processor's ``image_mean`` / ``image_std``.

    Args:
        processor: image processor providing ``image_mean`` and
            ``image_std`` for normalisation.
        size: value passed to ``T.Resize``; since the image is already
            square at that point, the output is ``size`` x ``size``.

    Returns:
        A ``T.Compose`` callable mapping a PIL image to a normalised tensor.
    """

    def _square_pad(image):
        # Split the missing width/height across both sides; when the gap is
        # odd, the extra pixel goes to the right/bottom edge.
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return T.Pad((left, top, right, bottom), fill=0)(image)

    steps = [
        T.Lambda(lambda image: image.convert('RGB')),
        _square_pad,
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]
    return T.Compose(steps)
class EndpointHandler:
    """
    HF Inference Endpoints entry-point.

    Loads the model and image processor once at start-up, then applies this
    module's ``get_inference_transform`` pipeline to every request image.
    """

    def __init__(self, path: str = "", image_size: int = 224):
        """Load weights/processor from ``path`` and build the transform.

        Args:
            path: model directory; an empty string means the current
                directory.
            image_size: size passed to ``get_inference_transform``.
        """
        model_dir = path or "."
        # Weights + processor --------------------------------------------------
        self.processor = AutoImageProcessor.from_pretrained(model_dir)
        # Run on the GPU when the endpoint has one; otherwise stay on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = (
            Dinov2ForImageClassification.from_pretrained(model_dir)
            .to(self.device)
            .eval()
        )
        # Preprocessing pipeline defined in this module -------------------------
        self.transform = get_inference_transform(self.processor, image_size)
        self.id2label = self.model.config.id2label

    @staticmethod
    def _decode_image(inputs: Any) -> Image.Image:
        """Turn the request's ``inputs`` value into a PIL image.

        Accepts an already-decoded PIL image, a base64 string (optionally
        prefixed with a ``data:<mime>;base64,`` URI header), or base64 bytes.
        """
        if isinstance(inputs, Image.Image):
            return inputs
        if isinstance(inputs, str) and inputs.startswith("data:"):
            # Drop the data-URI header: b64decode ignores ':', ';' and ',' but
            # would decode the header's letters ("data", "image", ...) as
            # payload and corrupt the image bytes.
            _, sep, payload = inputs.partition(",")
            if sep:
                inputs = payload
        img_bytes = base64.b64decode(inputs)
        return Image.open(io.BytesIO(img_bytes))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Classify a single image.

        Expects ``{"inputs": <image>}`` where ``<image>`` is a base64-encoded
        image string (a ``data:`` URI prefix is tolerated), base64 bytes, or
        a PIL image.

        Returns:
            ``{"predicted_label": <top label>, "scores": {label: probability}}``
            with one softmax probability per class index in ``id2label``.

        Raises:
            ValueError: if ``data`` has no ``"inputs"`` key. Malformed base64
                may raise ``binascii.Error`` (a ``ValueError`` subclass), and
                undecodable image bytes raise PIL's ``UnidentifiedImageError``.
        """
        if "inputs" not in data:
            raise ValueError("Request JSON must contain an 'inputs' field.")

        image = self._decode_image(data["inputs"])

        # Batch of one, [1, C, H, W], on the same device as the model.
        pixel_values = self.transform(image).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            logits = self.model(pixel_values).logits
            probs = logits.softmax(dim=-1)[0].cpu()
            top_idx = int(probs.argmax())
            # One bulk conversion instead of a float() per 0-d tensor.
            prob_list = probs.tolist()

        return {
            "predicted_label": self.id2label[top_idx],
            "scores": {
                self.id2label[i]: p
                for i, p in enumerate(prob_list)
            },
        }