| import base64 |
| import io |
| import json |
| import os |
| from typing import Any, Dict, List, Optional |
|
|
| from PIL import Image |
|
|
| import torch |
| from transformers import AutoModelForCausalLM |
|
|
|
|
| def _b64_to_pil(data_url: str) -> Image.Image: |
| if not isinstance(data_url, str) or not data_url.startswith("data:"): |
| raise ValueError("Expected a data URL starting with 'data:'") |
| header, b64data = data_url.split(",", 1) |
| raw = base64.b64decode(b64data) |
| img = Image.open(io.BytesIO(raw)) |
| img.load() |
| return img |
|
|
|
|
class EndpointHandler:
    """HF Inference Endpoint handler for Moondream3 Preview (point only).

    Input contract (OpenAI-style, simplified):
        {
          "system": "<system prompt>",
          "user": "<user prompt>",
          "image": "data:<mime>;base64,<...>",
          "prioritize_accuracy": true | false   // optional (default true)
        }

    Output (point only):
        { points: [{x, y}], raw: <debug payload> }
    Coordinates are normalized [0,1].
    """

    def __init__(self, path: str = "") -> None:
        """Load the model once at endpoint startup.

        Args:
            path: Unused; present to satisfy the HF endpoint handler API.
        """
        model_id = os.environ.get("MODEL_ID", "moondream/moondream3-preview")

        # Conservative runtime defaults: avoid CPU oversubscription and
        # reduce CUDA allocator fragmentation.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Accept the common env-var spellings for a Hub access token.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        if torch.cuda.is_available():
            device_map = {"": "cuda"}
            dtype = torch.bfloat16
        else:
            device_map = {"": "cpu"}
            dtype = torch.float32
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
            "device_map": device_map,
        }
        if hub_token:
            load_kwargs["token"] = hub_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            **load_kwargs,
        )

        # Best-effort: the remote-code model may expose a custom compile();
        # compilation is only an optimization, so never fail startup on it.
        try:
            compile_fn = getattr(self.model, "compile", None)
            if callable(compile_fn):
                compile_fn()
        except Exception:
            pass

        # Best-effort workaround: disable flex decoding when the remote code
        # exposes the flag (it misbehaves on some runtimes).
        try:
            if hasattr(self.model, "model") and hasattr(self.model.model, "use_flex_decoding"):
                self.model.model.use_flex_decoding = False
        except Exception:
            pass

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one request: validate, decode the image, run pointing.

        Args:
            data: Request payload, optionally wrapped in an HF
                ``{"inputs": ...}`` envelope.

        Returns:
            ``{"points": [{x, y}], "raw": ...}`` with coordinates in [0,1],
            or ``{"error": "<message>"}`` on any failure.
        """
        data = self._unwrap_inputs(data)

        prioritize_accuracy = bool(data.get("prioritize_accuracy", True))
        system_prompt: Optional[str] = data.get("system")
        text_piece: Optional[str] = data.get("user")
        image_data_url: Optional[str] = data.get("image")
        if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
            return {"error": "image must be a data URL (data:...)"}
        if not text_piece:
            return {"error": "user text must be provided"}

        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image data URL: {e}"}

        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)

        try:
            if prioritize_accuracy:
                # Test-time augmentation: query the original and a mirrored
                # image, then un-mirror the second point set and merge.
                flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
                res_orig = self.model.point(pil, text_piece)
                res_flip = self.model.point(flipped, text_piece)
                points = self._tta_points(res_orig.get("points", []), res_flip.get("points", []))
                out: Dict[str, Any] = {"points": points}
            else:
                result = self.model.point(pil, text_piece)
                out = {"points": result.get("points", [])}
        except Exception as e:
            return {"error": f"Model inference failed: {e}"}

        self._log_debug(system_prompt, text_piece, width, height, out)

        if not isinstance(out.get("points"), list) or not out["points"]:
            return {"error": "No points returned"}
        points_01 = [self._to_unit(p) for p in out["points"]]
        return {"points": points_01, "raw": out}

    @staticmethod
    def _unwrap_inputs(data: Any) -> Dict[str, Any]:
        """Unwrap an HF ``{"inputs": ...}`` envelope, if present.

        ``inputs`` may be a dict, or a JSON string / bytes; on any parse
        failure the original payload is returned unchanged (best-effort).
        """
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                return inputs_val
            if isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        return parsed
                except Exception:
                    pass  # fall through: treat the envelope itself as payload
        return data

    @staticmethod
    def _log_debug(
        system_prompt: Optional[str],
        text_piece: Optional[str],
        width: Optional[int],
        height: Optional[int],
        out: Dict[str, Any],
    ) -> None:
        """Print request/response debug info; logging must never raise."""

        def _start_end(s: Optional[str], n: int = 120):
            # (first n chars, last n chars) — short strings echo in full.
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)

        sys_start, sys_end = _start_end(system_prompt)
        try:
            print(f"[moondream-endpoint] System prompt (start): {sys_start}")
            print(f"[moondream-endpoint] System prompt (end): {sys_end}")
            print(f"[moondream-endpoint] User prompt (full): {text_piece}")
            if width and height:
                print(f"[moondream-endpoint] Received image size: {width}x{height}")
            print(f"[moondream-endpoint] Raw output: {json.dumps(out)}")
        except Exception:
            pass  # best-effort only

    @staticmethod
    def _to_unit(p: Dict[str, Any]) -> Dict[str, float]:
        """Normalize one point to [0,1].

        Values > 1 are assumed to be on the model's 0-1000 grid and are
        divided by 1000; values already in [0,1] pass through unchanged.
        """
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        if x > 1.0 or y > 1.0:
            return {"x": x / 1000.0, "y": y / 1000.0}
        return {"x": x, "y": y}

    @staticmethod
    def _flip_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized point horizontally (x -> 1 - x), clamping
        both coordinates into [0, 1]."""
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        x = 1.0 - x
        return {"x": max(0.0, min(1.0, x)), "y": max(0.0, min(1.0, y))}

    @classmethod
    def _deduplicate_and_average_points(cls, points: List[Dict[str, Any]], tol: float = 0.03) -> List[Dict[str, float]]:
        """Greedily cluster points within ``tol`` (Euclidean, normalized
        coords) and return each cluster's running average.

        Args:
            points: Points as ``{"x": ..., "y": ...}`` dicts.
            tol: Merge radius in normalized units (default 3% of the image).

        Returns:
            One averaged point per cluster, in first-seen order.
        """
        clusters: List[Dict[str, float]] = []
        counts: List[int] = []
        for p in points:
            px = float(p.get("x", 0.0))
            py = float(p.get("y", 0.0))
            matched = False
            for i, c in enumerate(clusters):
                dx = px - c["x"]
                dy = py - c["y"]
                # Compare squared distances to avoid a sqrt per pair.
                if dx * dx + dy * dy <= tol * tol:
                    n = counts[i]
                    c["x"] = (c["x"] * n + px) / (n + 1)
                    c["y"] = (c["y"] * n + py) / (n + 1)
                    counts[i] = n + 1
                    matched = True
                    break
            if not matched:
                clusters.append({"x": px, "y": py})
                counts.append(1)
        return clusters

    @classmethod
    def _tta_points(cls, points_a: List[Dict[str, Any]], points_b_flipped: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        """Merge original predictions with predictions from the mirrored
        image: un-mirror the latter, concatenate, then dedupe/average."""
        unflipped_b = [cls._flip_point(p) for p in points_b_flipped]
        merged = list(points_a) + unflipped_b
        return cls._deduplicate_and_average_points(merged)
|
|
| |
|
|
|
|
|
|