# Custom handler: bundle this file with the model when uploading to HF Inference Endpoints.

import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification

from train_model import get_inference_transform


class EndpointHandler:
    """
    HF Inference Endpoints entry‑point.
    Loads model/processor once, then uses your *imported* preprocessing
    on every request.
    """

    def __init__(self, path: str = "", image_size: int = 224):
        # Weights + processor --------------------------------------------------------
        self.processor = AutoImageProcessor.from_pretrained(path or ".")
        self.model     = (
            Dinov2ForImageClassification.from_pretrained(path or ".")
            .eval()
        )

        # Re‑use the exact transform from your code ---------------------------------
        self.transform = get_inference_transform(self.processor, image_size)

        self.id2label = self.model.config.id2label

    # -------------------------------------------------------------------------------
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expects {"inputs": "<base64‑encoded image>"}.
        Returns the top prediction + per‑class probabilities.
        """
        if "inputs" not in data:
            raise ValueError("Request JSON must contain an 'inputs' field.")

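        # Decode base64 → PIL. Force RGB so grayscale, palette, or RGBA uploads
        # still yield the 3-channel input the model expects.
        img_bytes = base64.b64decode(data["inputs"])
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")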

        # Preprocess with your own transform
        pixel_values = self.transform(image).unsqueeze(0)   # [1, C, H, W]

        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
            probs  = logits.softmax(dim=-1)[0]

        top_idx = int(probs.argmax())
        top_label = self.id2label[top_idx]

        return {
            "predicted_label": top_label,
            "scores": {
                self.id2label[i]: float(p)
                for i, p in enumerate(probs)
            }
        }
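

# -----------------------------------------------------------------------------------
# Client-side sketch (not used by the endpoint itself).
# A minimal example of building the {"inputs": "<base64-encoded image>"} payload
# that EndpointHandler.__call__ expects, assuming the caller already holds the raw
# image bytes. `build_request_payload` is a hypothetical helper name introduced
# here for illustration; it is not part of any Hugging Face API.
import base64  # repeated so the sketch stays self-contained if copied elsewhere


def build_request_payload(image_bytes: bytes) -> dict:
    """Wrap raw image bytes in the JSON-serializable dict the handler decodes."""
    # b64encode returns bytes; decode to str so the dict can be passed to json.dumps
    return {"inputs": base64.b64encode(image_bytes).decode("ascii")}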