| """ |
| OWLv2 Custom Handler for HuggingFace Inference Endpoints |
| |
| Supports: |
| - Image-conditioned detection (find objects similar to a reference image) |
| - Text-conditioned detection (find objects matching text descriptions) |
| """ |
|
|
| from typing import Dict, Any, List, Union |
| import torch |
| from transformers import Owlv2Processor, Owlv2ForObjectDetection |
| from PIL import Image |
| import base64 |
| import io |
|
|
|
|
class EndpointHandler:
    """Custom endpoint handler exposing OWLv2 open-vocabulary detection.

    Supports text-conditioned detection (``queries``), single image-conditioned
    detection (``query_image``), and multi-reference image-conditioned
    detection (``query_images``).
    """

    def __init__(self, path=""):
        """Load model and processor on endpoint startup.

        Args:
            path: Model directory supplied by the endpoint runtime. Unused;
                weights are fetched from the Hub by model id instead.
        """
        model_id = "google/owlv2-large-patch14-ensemble"

        self.processor = Owlv2Processor.from_pretrained(model_id)
        self.model = Owlv2ForObjectDetection.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"OWLv2 loaded on {self.device}")

    def _decode_image(self, image_data: str) -> Image.Image:
        """Decode a base64 string (optionally a data URL) into an RGB PIL Image.

        Args:
            image_data: Raw base64 payload, or a ``data:image/...;base64,``
                data URL.

        Returns:
            The decoded image, converted to RGB.
        """
        # Strip a data-URL prefix if present. maxsplit=1 so only the first
        # comma (the prefix separator) is honored.
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]

        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process detection request.

        === Image-Conditioned Detection ===
        Find objects similar to a reference image.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "query_image": "base64...",
                    "threshold": 0.5,
                    "nms_threshold": 0.3
                }
            }

        === Text-Conditioned Detection ===
        Find objects matching text descriptions.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "queries": ["a button", "an icon"],
                    "threshold": 0.1
                }
            }

        === Multiple Query Images ===
        Find multiple different objects by image.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "query_images": ["base64...", "base64..."],
                    "threshold": 0.5,
                    "nms_threshold": 0.3
                }
            }

        Response:
            {
                "detections": [
                    {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"}
                ]
            }

        Errors are reported as {"error": "..."} rather than raised, so the
        endpoint always returns a JSON body.
        """
        try:
            # Accept both {"inputs": {...}} and a bare payload.
            inputs = data.get("inputs", data)

            if "target_image" not in inputs:
                return {"error": "Missing required field: target_image"}

            target_image = self._decode_image(inputs["target_image"])
            threshold = float(inputs.get("threshold", 0.5))
            nms_threshold = float(inputs.get("nms_threshold", 0.3))

            if "query_image" in inputs:
                # Single reference image -> one-element query list.
                query_image = self._decode_image(inputs["query_image"])
                return self._detect_with_image(
                    target_image, [query_image], threshold, nms_threshold
                )

            elif "query_images" in inputs:
                query_images = [
                    self._decode_image(img) for img in inputs["query_images"]
                ]
                if not query_images:
                    return {"error": "'query_images' must contain at least one image"}
                return self._detect_with_image(
                    target_image, query_images, threshold, nms_threshold
                )

            elif "queries" in inputs:
                queries = inputs["queries"]
                # Tolerate a single query passed as a bare string.
                if isinstance(queries, str):
                    queries = [queries]
                if not queries:
                    return {"error": "'queries' must contain at least one text query"}
                return self._detect_with_text(target_image, queries, threshold)

            else:
                return {
                    "error": "Provide 'query_image', 'query_images', or 'queries'"
                }

        except Exception as e:
            # Top-level boundary: surface the failure as a JSON error payload.
            return {"error": str(e)}

    def _detect_with_image(
        self,
        target: Image.Image,
        query_images: List[Image.Image],
        threshold: float,
        nms_threshold: float
    ) -> Dict[str, Any]:
        """Image-conditioned detection.

        Args:
            target: Image to search in.
            query_images: Reference image(s) describing what to find.
            threshold: Minimum confidence for a detection to be kept.
            nms_threshold: IoU threshold for non-maximum suppression.

        Returns:
            {"detections": [{"box": [...], "confidence": ..., "label": "query_i"}]}
        """
        inputs = self.processor(
            images=target,
            query_images=query_images,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.image_guided_detection(**inputs)

        # post-processing expects (height, width); PIL size is (width, height).
        target_sizes = torch.tensor([target.size[::-1]])
        results = self.processor.post_process_image_guided_detection(
            outputs=outputs,
            threshold=threshold,
            nms_threshold=nms_threshold,
            target_sizes=target_sizes
        )[0]

        detections = []
        for i, (box, score) in enumerate(zip(results["boxes"], results["scores"])):
            det = {
                "box": [round(c, 2) for c in box.tolist()],
                "confidence": round(score.item(), 4)
            }
            # Always emit a label so the response matches the documented
            # schema; fall back to query_0 when the post-processor does not
            # report per-query labels (single-reference case).
            if "labels" in results and results["labels"] is not None and i < len(results["labels"]):
                det["label"] = f"query_{results['labels'][i].item()}"
            else:
                det["label"] = "query_0"
            detections.append(det)

        return {"detections": detections}

    def _detect_with_text(
        self,
        target: Image.Image,
        queries: List[str],
        threshold: float
    ) -> Dict[str, Any]:
        """Text-conditioned detection.

        Args:
            target: Image to search in.
            queries: Text descriptions of objects to find.
            threshold: Minimum confidence for a detection to be kept.

        Returns:
            {"detections": [{"box": [...], "confidence": ..., "label": <query text>}]}
        """
        # The processor expects one list of queries per image; we batch a
        # single image, hence the extra nesting.
        inputs = self.processor(
            text=[queries],
            images=target,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # post-processing expects (height, width); PIL size is (width, height).
        target_sizes = torch.tensor([target.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

        detections = []
        for box, score, label_idx in zip(
            results["boxes"], results["scores"], results["labels"]
        ):
            detections.append({
                "box": [round(c, 2) for c in box.tolist()],
                "confidence": round(score.item(), 4),
                # Map the label index back to the caller's query string.
                "label": queries[label_idx.item()]
            })

        return {"detections": detections}
|
|