import base64
import io
import json
import os
from typing import Any, Dict, List, Optional

from PIL import Image
import torch
from transformers import AutoModelForCausalLM


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a base64 ``data:`` URL into a fully-loaded PIL image.

    Args:
        data_url: String of the form ``data:<mime>;base64,<payload>``.

    Returns:
        The decoded image, eagerly loaded so corrupt payloads fail here
        rather than at first pixel access.

    Raises:
        ValueError: If *data_url* is not a string starting with ``data:``.
    """
    if not isinstance(data_url, str) or not data_url.startswith("data:"):
        raise ValueError("Expected a data URL starting with 'data:'")
    _header, b64data = data_url.split(",", 1)
    raw = base64.b64decode(b64data)
    img = Image.open(io.BytesIO(raw))
    img.load()  # force the lazy decode now
    return img


class EndpointHandler:
    """HF Inference Endpoint handler for Moondream3 Preview (point only).

    Input contract (OpenAI-style, simplified)::

        {
          "system": "<system prompt>",
          "user": "<object to point at>",
          "image": "data:<mime>;base64,<...>",
          "prioritize_accuracy": true | false   // optional (default true)
        }

    Output (point only)::

        { "points": [{"x": ..., "y": ...}], "raw": <pre-normalization result> }

    Coordinates are normalized to [0, 1].
    """

    def __init__(self, path: str = "") -> None:
        """Load the model once at endpoint startup.

        Args:
            path: Unused; kept for the HF Inference Toolkit calling
                convention. The model id comes from the ``MODEL_ID`` env var
                (default ``moondream/moondream3-preview``).
        """
        model_id = os.environ.get("MODEL_ID", "moondream/moondream3-preview")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Load local repo (or remote if MODEL_ID points to a hub id).
        # Pass a token when accessing gated repos; accept the common env var
        # spellings in priority order.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )

        # Single-device placement keeps device handling consistent.
        if torch.cuda.is_available():
            device_map = {"": "cuda"}
            dtype = torch.bfloat16
        else:
            device_map = {"": "cpu"}
            dtype = torch.float32

        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
            "device_map": device_map,
        }
        if hub_token:
            load_kwargs["token"] = hub_token

        self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

        # Optional compilation for speed if exposed by the remote code.
        try:
            compile_fn = getattr(self.model, "compile", None)
            if callable(compile_fn):
                compile_fn()
        except Exception:
            pass  # best effort: an uncompiled model is still usable

        # Prefer robustness over speed on HF endpoints: disable FlexAttention
        # decoding to avoid BlockMask attribute mismatches in some torch builds.
        try:
            inner = getattr(self.model, "model", None)
            if inner is not None and hasattr(inner, "use_flex_decoding"):
                inner.use_flex_decoding = False
        except Exception:
            pass  # attribute layout varies across remote-code revisions

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request.

        Returns a dict — either ``{"points": [...], "raw": {...}}`` on
        success or ``{"error": "..."}`` on failure; never raises.
        """
        # Accept HF toolkit shapes: {"inputs": {...}} or a JSON string payload.
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                data = inputs_val
            elif isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        data = parsed
                except Exception:
                    pass  # fall through; validation below reports the problem

        # Input contract: expect 'system', 'user', 'image' (point task only).
        prioritize_accuracy = bool(data.get("prioritize_accuracy", True))
        system_prompt: Optional[str] = data.get("system")
        text_piece: Optional[str] = data.get("user")
        image_data_url: Optional[str] = data.get("image")
        if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
            return {"error": "image must be a data URL (data:...)"}
        if not text_piece:
            return {"error": "user text must be provided"}

        # Decode for dimensions and pass PIL to model.
        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image data URL: {e}"}
        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)

        # Point-only inference. With prioritize_accuracy, run test-time
        # augmentation: also query a horizontally flipped copy and merge.
        try:
            if prioritize_accuracy:
                flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
                res_orig = self.model.point(pil, text_piece)
                res_flip = self.model.point(flipped, text_piece)
                points = self._tta_points(
                    res_orig.get("points", []), res_flip.get("points", [])
                )
                out: Dict[str, Any] = {"points": points}
            else:
                result = self.model.point(pil, text_piece)
                out = {"points": result.get("points", [])}
        except Exception as e:
            return {"error": f"Model inference failed: {e}"}

        self._log_request(system_prompt, text_piece, width, height, out)

        # Ensure a non-empty points array, normalized to [0, 1].
        if not isinstance(out.get("points"), list) or not out["points"]:
            return {"error": "No points returned"}
        points_01 = [self._normalize_point(p) for p in out["points"]]
        return {"points": points_01, "raw": out}

    @staticmethod
    def _log_request(
        system_prompt: Optional[str],
        text_piece: Optional[str],
        width: Optional[int],
        height: Optional[int],
        out: Dict[str, Any],
    ) -> None:
        """Best-effort stdout logging of prompts, image size, and raw output.

        Logging must never take the endpoint down, so all prints are wrapped.
        """
        def _edges(s: Optional[str], n: int = 120):
            # (first n chars, last n chars) — whole string if shorter than n.
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)

        sys_start, sys_end = _edges(system_prompt)
        try:
            print(f"[moondream-endpoint] System prompt (start): {sys_start}")
            print(f"[moondream-endpoint] System prompt (end): {sys_end}")
            print(f"[moondream-endpoint] User prompt (full): {text_piece}")
            if width and height:
                print(f"[moondream-endpoint] Received image size: {width}x{height}")
            print(f"[moondream-endpoint] Raw output: {json.dumps(out)}")
        except Exception:
            pass

    @staticmethod
    def _normalize_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Coerce a point dict to normalized [0, 1] coordinates.

        Values > 1 are assumed to be on the model's 0-1000 grid and are
        divided by 1000; already-normalized points pass through unchanged.
        """
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        if x > 1.0 or y > 1.0:
            return {"x": x / 1000.0, "y": y / 1000.0}
        return {"x": x, "y": y}

    @staticmethod
    def _flip_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized point across the vertical axis, clamped to [0, 1]."""
        x = 1.0 - float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        return {"x": max(0.0, min(1.0, x)), "y": max(0.0, min(1.0, y))}

    @classmethod
    def _deduplicate_and_average_points(
        cls, points: List[Dict[str, Any]], tol: float = 0.03
    ) -> List[Dict[str, float]]:
        """Greedily cluster points within *tol* Euclidean distance and average.

        Each point joins the first existing cluster whose running centroid
        lies within *tol*; otherwise it starts a new cluster. Returns one
        centroid per cluster.
        """
        clusters: List[Dict[str, float]] = []
        counts: List[int] = []
        for p in points:
            px = float(p.get("x", 0.0))
            py = float(p.get("y", 0.0))
            for i, c in enumerate(clusters):
                dx = px - c["x"]
                dy = py - c["y"]
                if dx * dx + dy * dy <= tol * tol:
                    n = counts[i]
                    # Incremental mean: exact centroid without storing members.
                    c["x"] = (c["x"] * n + px) / (n + 1)
                    c["y"] = (c["y"] * n + py) / (n + 1)
                    counts[i] = n + 1
                    break
            else:
                clusters.append({"x": px, "y": py})
                counts.append(1)
        return clusters

    @classmethod
    def _tta_points(
        cls,
        points_a: List[Dict[str, Any]],
        points_b_flipped: List[Dict[str, Any]],
    ) -> List[Dict[str, float]]:
        """Merge predictions from the original and horizontally-flipped image.

        Both sets are normalized to [0, 1] *before* un-flipping: the
        ``1 - x`` mirror in ``_flip_point`` is only valid for normalized
        coordinates (a 0-1000 grid value would be clamped to ~0 otherwise).
        """
        norm_a = [cls._normalize_point(p) for p in points_a]
        unflipped_b = [
            cls._flip_point(cls._normalize_point(p)) for p in points_b_flipped
        ]
        return cls._deduplicate_and_average_points(norm_a + unflipped_b)


# Box-related utilities removed (endpoint is point-only)