Image Classification
Transformers
Safetensors
dinov2
font_classifier_v4 / handler.py
dchen0's picture
Add merged model + processor
e8208a0 verified
raw
history blame
2.1 kB
# to be bundled with the model on upload to HF Inference Endpoints
import base64
import io
from typing import Any, Dict
import torch
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification
from train_model import get_inference_transform
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point.

    The model and image processor are loaded once, at construction time.
    Every request is then preprocessed with the transform returned by the
    imported ``get_inference_transform`` and classified by the model.
    """

    def __init__(self, path: str = "", image_size: int = 224):
        """Load the weights, the processor and the inference transform.

        Args:
            path: Directory holding the model and processor files. An empty
                value falls back to the current directory (``"."``).
            image_size: Forwarded unchanged to ``get_inference_transform``
                (presumably the input resolution -- confirm against
                ``train_model``).
        """
        model_dir = path or "."

        # Weights + processor; eval() switches the model to inference mode.
        self.processor = AutoImageProcessor.from_pretrained(model_dir)
        self.model = Dinov2ForImageClassification.from_pretrained(model_dir).eval()

        # Re-use the exact transform exported by the training code.
        self.transform = get_inference_transform(self.processor, image_size)
        self.id2label = self.model.config.id2label

    @staticmethod
    def _load_image(payload: Any) -> Any:
        """Turn the ``inputs`` payload into an RGB image.

        Accepted payloads:
            * ``str``: base64-encoded image bytes, optionally wrapped in a
              ``data:<mime>;base64,`` URI prefix;
            * ``bytes`` / ``bytearray``: the raw encoded image file;
            * an already-decoded image object exposing ``convert`` (e.g. a
              ``PIL.Image.Image``; the endpoint runtime may pass one for
              ``image/*`` requests -- confirm against the deployed toolkit).

        Raises:
            ValueError: If the payload is empty, is not valid base64, or has
                an unsupported type.
        """
        if isinstance(payload, str):
            if payload.startswith("data:"):
                # Keep only what follows the "data:...;base64," header.
                payload = payload.partition(",")[2]
            # binascii.Error raised here is a ValueError subclass.
            payload = base64.b64decode(payload)

        if isinstance(payload, (bytes, bytearray)):
            if not payload:
                raise ValueError("The 'inputs' field contains no image data.")
            image = Image.open(io.BytesIO(payload))
        elif hasattr(payload, "convert"):
            image = payload
        else:
            raise ValueError(
                "The 'inputs' field must be a base64 string, raw image bytes "
                "or a PIL image."
            )

        # RGBA / palette / greyscale uploads would otherwise reach the
        # transform with the wrong number of channels.
        return image.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify one image.

        Args:
            data: Request payload; must contain an ``"inputs"`` entry in one
                of the forms accepted by ``_load_image`` (typically
                ``{"inputs": "<base64-encoded image>"}``).

        Returns:
            ``{"predicted_label": <top label>, "scores": {label: probability}}``
            with one score per entry of the softmax over the model logits.

        Raises:
            ValueError: If ``"inputs"`` is missing, empty, or cannot be
                interpreted as an image payload.
        """
        if "inputs" not in data:
            raise ValueError("Request JSON must contain an 'inputs' field.")

        image = self._load_image(data["inputs"])

        # Preprocess with the training-time transform and add a batch axis.
        pixel_values = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            logits = self.model(pixel_values).logits

        probs = logits.softmax(dim=-1)[0]
        top_idx = int(probs.argmax())

        return {
            "predicted_label": self.id2label[top_idx],
            "scores": {
                self.id2label[i]: p for i, p in enumerate(probs.tolist())
            },
        }