| """ |
| SigLIP 2 Custom Inference Handler for Hugging Face Inference Endpoints |
| Model: google/siglip2-so400m-patch14-384 (Best balance of performance/quality) |
| |
| For ProofPath video assessment - identifies objects, tools, and actions in video frames. |
| """ |
|
|
import base64
import io
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
|
|
|
|
class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the SigLIP 2 model for image/frame classification and embedding.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints).
        """
        from transformers import AutoProcessor, AutoModel

        # The model id is pinned explicitly rather than loaded from `path`.
        model_id = "google/siglip2-so400m-patch14-384"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa",
        )

        # device_map="auto" already places the model on GPU; only move it
        # manually when running on CPU.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

    def _decode_image(self, image_data: Any) -> Image.Image:
        """
        Decode an image from various input formats.

        Supports:
        - Base64-encoded string (with or without a data-URI prefix)
        - URL to an image
        - PIL Image
        - Raw bytes
        """
        import requests

        if isinstance(image_data, Image.Image):
            return image_data
        elif isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                # Fetch over HTTP; fail fast on network errors instead of
                # handing a broken stream to PIL.
                response = requests.get(image_data, stream=True, timeout=10)
                response.raise_for_status()
                return Image.open(response.raw).convert('RGB')
            elif image_data.startswith('data:'):
                # Data URI: strip the "data:image/...;base64," header.
                header, encoded = image_data.split(',', 1)
                image_bytes = base64.b64decode(encoded)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
            else:
                # Assume a plain base64-encoded string.
                image_bytes = base64.b64decode(image_data)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        elif isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

    def _process_batch(
        self,
        images: List[Image.Image],
        texts: Optional[List[str]] = None
    ) -> Dict[str, torch.Tensor]:
        """Process a batch of images and optional texts."""
        if texts:
            # SigLIP is trained with text padded to 64 tokens, so keep the
            # same padding settings at inference time.
            inputs = self.processor(
                images=images,
                text=texts,
                padding="max_length",
                max_length=64,
                return_tensors="pt"
            )
        else:
            inputs = self.processor(
                images=images,
                return_tensors="pt"
            )

        # Move tensors to the model's device, casting floating-point inputs
        # (pixel_values) to the model dtype so fp16 weights receive fp16 inputs.
        return {
            k: v.to(self.model.device, dtype=self.model.dtype)
            if v.is_floating_point() else v.to(self.model.device)
            for k, v in inputs.items()
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process image(s) for classification or embedding extraction.

        Expected input formats:

        1. Zero-shot classification:
            {
                "inputs": <image_data>,  # single image or list of images
                "parameters": {
                    "candidate_labels": ["label1", "label2", ...],
                    "hypothesis_template": "This is a photo of {}."  # optional
                }
            }

        2. Image embedding only:
            {
                "inputs": <image_data>,
                "parameters": {
                    "mode": "embedding"
                }
            }

        3. Image-text similarity:
            {
                "inputs": {
                    "images": [<image1>, <image2>, ...],
                    "texts": ["text1", "text2", ...]
                },
                "parameters": {
                    "mode": "similarity"
                }
            }

        Returns for classification:
            {
                "labels": ["label1", "label2"],
                "scores": [0.85, 0.12],
                "predictions": [{"label": "label1", "score": 0.85}, ...]
            }

        Returns for embedding:
            {
                "image_embeddings": [[...], ...],
                "embedding_shape": [batch, hidden_dim]
            }

        Returns for similarity:
            {
                "similarity_matrix": [[...], ...],
                "shape": [num_images, num_texts]
            }
        """
| inputs = data.get("inputs") |
| if inputs is None: |
| inputs = data.get("image") or data.get("images") |
| if inputs is None: |
| raise ValueError("No input provided. Use 'inputs', 'image', or 'images' key.") |
| |
| params = data.get("parameters", {}) |
| mode = params.get("mode", "classification") |
| |
| try: |
| |
| if mode == "embedding": |
| return self._extract_embeddings(inputs) |
| elif mode == "similarity": |
| return self._compute_similarity(inputs, params) |
| else: |
| |
| return self._classify(inputs, params) |
| |
| except Exception as e: |
| return {"error": str(e), "error_type": type(e).__name__} |

    def _classify(self, inputs: Any, params: Dict) -> Dict[str, Any]:
        """Zero-shot image classification."""
        candidate_labels = params.get("candidate_labels", [])
        if not candidate_labels:
            raise ValueError("candidate_labels required for classification mode")

        hypothesis_template = params.get("hypothesis_template", "This is a photo of {}.")

        # Normalize to a list of PIL images.
        if isinstance(inputs, list):
            images = [self._decode_image(img) for img in inputs]
        else:
            images = [self._decode_image(inputs)]

        # Expand each label with the hypothesis template.
        texts = [hypothesis_template.format(label) for label in candidate_labels]

        results = []
        for image in images:
            # Score one image against all candidate texts;
            # logits_per_image has shape (1, num_labels).
            processed = self._process_batch([image], texts)

            with torch.no_grad():
                outputs = self.model(**processed)

            # SigLIP scores each (image, text) pair independently with a
            # sigmoid, so scores do not sum to 1 across labels.
            logits_per_image = outputs.logits_per_image
            probs = torch.sigmoid(logits_per_image[0])

            # Sort labels by descending score.
            sorted_indices = probs.argsort(descending=True)

            predictions = []
            for idx in sorted_indices:
                predictions.append({
                    "label": candidate_labels[idx.item()],
                    "score": float(probs[idx].item())
                })

            results.append({
                "labels": [p["label"] for p in predictions],
                "scores": [p["score"] for p in predictions],
                "predictions": predictions
            })

        # Single image: return the result object directly.
        if len(results) == 1:
            return results[0]
        return {"results": results}

    def _extract_embeddings(self, inputs: Any) -> Dict[str, Any]:
        """Extract image embeddings only."""
        # Normalize to a list of PIL images.
        if isinstance(inputs, list):
            images = [self._decode_image(img) for img in inputs]
        else:
            images = [self._decode_image(inputs)]

        processed = self._process_batch(images)

        with torch.no_grad():
            # get_image_features returns pooled image embeddings.
            vision_outputs = self.model.get_image_features(**processed)

        embeddings = vision_outputs.cpu().numpy().tolist()

        return {
            "image_embeddings": embeddings,
            "embedding_shape": list(vision_outputs.shape)
        }

    def _compute_similarity(self, inputs: Dict, params: Dict) -> Dict[str, Any]:
        """Compute an image-text similarity matrix."""
        images_data = inputs.get("images", [])
        texts = inputs.get("texts", [])

        if not images_data or not texts:
            raise ValueError("Both 'images' and 'texts' required for similarity mode")

        images = [self._decode_image(img) for img in images_data]

        processed = self._process_batch(images, texts)

        with torch.no_grad():
            outputs = self.model(**processed)

        # Rows are images, columns are texts; the sigmoid turns logits into
        # independent per-pair probabilities.
        similarity = outputs.logits_per_image
        probs = torch.sigmoid(similarity)

        return {
            "similarity_matrix": probs.cpu().numpy().tolist(),
            "shape": list(probs.shape),
            "logits": similarity.cpu().numpy().tolist()
        }
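

# Minimal local smoke test: an illustrative sketch, not part of the deployed
# endpoint contract. It builds a solid-color image in memory, so no input files
# are needed; the first run will still download the model weights from the Hub.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Encode a synthetic 384x384 red image as base64, as a client would.
    buffer = io.BytesIO()
    Image.new("RGB", (384, 384), color=(200, 30, 30)).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()

    # Zero-shot classification.
    result = handler({
        "inputs": encoded,
        "parameters": {"candidate_labels": ["a red image", "a blue image"]},
    })
    print(result)

    # Embedding extraction.
    emb = handler({"inputs": encoded, "parameters": {"mode": "embedding"}})
    print(emb["embedding_shape"])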
|
|