# moondream3-endpoint / handler.py
# Author: Forrest Wargo
# Revision: unity v2 (commit eb3bdf4)
import base64
import io
import json
import os
from typing import Any, Dict, List, Optional
from PIL import Image
import torch
from transformers import AutoModelForCausalLM
def _b64_to_pil(data_url: str) -> Image.Image:
if not isinstance(data_url, str) or not data_url.startswith("data:"):
raise ValueError("Expected a data URL starting with 'data:'")
header, b64data = data_url.split(",", 1)
raw = base64.b64decode(b64data)
img = Image.open(io.BytesIO(raw))
img.load()
return img
class EndpointHandler:
    """HF Inference Endpoint handler for Moondream3 Preview (point only).

    Input contract (OpenAI-style, simplified):
        {
          "system": "<system prompt>",
          "user": "<user prompt>",
          "image": "data:<mime>;base64,<...>",
          "prioritize_accuracy": true | false   // optional (default true)
        }

    Output (point only):
        { points: [{x, y}], raw: <debug payload> }

    Coordinates are normalized [0,1].
    """

    def __init__(self, path: str = "") -> None:
        """Load the model once at endpoint startup.

        ``path`` is supplied by the HF endpoint runtime; the model id is
        taken from the MODEL_ID env var instead (defaults to the hub id).
        """
        model_id = os.environ.get("MODEL_ID", "moondream/moondream3-preview")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        # Load local repo (or remote if MODEL_ID points to hub id).
        # Pass token when accessing gated repos and ensure consistent device placement.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        if torch.cuda.is_available():
            device_map = {"": "cuda"}
            dtype = torch.bfloat16
        else:
            device_map = {"": "cpu"}
            dtype = torch.float32
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
            "device_map": device_map,
        }
        if hub_token:
            load_kwargs["token"] = hub_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            **load_kwargs,
        )
        # Optional compilation for speed if exposed by remote code.
        try:
            compile_fn = getattr(self.model, "compile", None)
            if callable(compile_fn):
                compile_fn()
        except Exception:
            pass  # best-effort: compilation is purely an optimization
        # Prefer robustness over speed on HF endpoints: disable FlexAttention decoding
        # to avoid BlockMask attribute mismatches in some torch builds.
        try:
            if hasattr(self.model, "model") and hasattr(self.model.model, "use_flex_decoding"):
                self.model.model.use_flex_decoding = False
        except Exception:
            pass

    @staticmethod
    def _unwrap_inputs(data: Any) -> Any:
        # Accept HF toolkit shapes: { inputs: {...} } or { inputs: "<json string>" }.
        # On any parse failure, fall through and treat the payload as-is.
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                return inputs_val
            if isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        return parsed
                except Exception:
                    pass
        return data

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run point inference for one request.

        Returns ``{"points": [...], "raw": {...}}`` on success, or
        ``{"error": "..."}`` on any validation/inference failure.
        """
        data = self._unwrap_inputs(data)
        # New input contract: expect 'system', 'user', 'image' (point task only).
        prioritize_accuracy = bool(data.get("prioritize_accuracy", True))
        system_prompt: Optional[str] = data.get("system")
        text_piece: Optional[str] = data.get("user")
        image_data_url: Optional[str] = data.get("image")
        if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
            return {"error": "image must be a data URL (data:...)"}
        if not text_piece:
            return {"error": "user text must be provided"}
        # Decode for dimensions and pass PIL to model.
        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image data URL: {e}"}
        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)
        # Logged once here (was previously duplicated after inference).
        if width and height:
            try:
                print(f"[moondream-endpoint] Received image size: {width}x{height}")
            except Exception:
                pass
        # Point-only inference. When accuracy is prioritized, run a second pass
        # on a horizontally flipped image (test-time augmentation) and merge.
        try:
            if prioritize_accuracy:
                flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
                res_orig = self.model.point(pil, text_piece)
                res_flip = self.model.point(flipped, text_piece)
                points = self._tta_points(res_orig.get("points", []), res_flip.get("points", []))
                out: Dict[str, Any] = {"points": points}
            else:
                result = self.model.point(pil, text_piece)
                out = {"points": result.get("points", [])}
        except Exception as e:
            return {"error": f"Model inference failed: {e}"}

        # Log prompts and raw output (best-effort; logging must never fail a request).
        def _se(s: Optional[str], n: int = 120):
            # First/last n characters of s, for compact logging of long prompts.
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)

        sys_start, sys_end = _se(system_prompt)
        try:
            print(f"[moondream-endpoint] System prompt (start): {sys_start}")
            print(f"[moondream-endpoint] System prompt (end): {sys_end}")
            print(f"[moondream-endpoint] User prompt (full): {text_piece}")
        except Exception:
            pass
        try:
            print(f"[moondream-endpoint] Raw output: {json.dumps(out)}")
        except Exception:
            pass
        # Ensure points array exists and normalize coordinates into [0, 1].
        if not isinstance(out.get("points"), list) or not out["points"]:
            return {"error": "No points returned"}

        def _to_01(p):
            # NOTE(review): assumes the remote code emits 0-1000 scaled coords
            # whenever a value exceeds 1.0 — confirm against the model card.
            x = float(p.get("x", 0.0))
            y = float(p.get("y", 0.0))
            if x > 1.0 or y > 1.0:
                return {"x": x / 1000.0, "y": y / 1000.0}
            return {"x": x, "y": y}

        points_01 = [_to_01(p) for p in out["points"]]
        return {"points": points_01, "raw": out}

    @staticmethod
    def _flip_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized point across the vertical axis, clamped to [0,1].

        Assumes coordinates are already normalized to [0,1] (the flip formula
        ``1 - x`` is only meaningful at that scale).
        """
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        x = 1.0 - x
        return {"x": max(0.0, min(1.0, x)), "y": max(0.0, min(1.0, y))}

    @classmethod
    def _deduplicate_and_average_points(cls, points: List[Dict[str, Any]], tol: float = 0.03) -> List[Dict[str, float]]:
        """Greedy cluster points within Euclidean distance ``tol``; return cluster means.

        Each point joins the first existing cluster whose running mean is
        within ``tol`` (incremental average), otherwise it starts a new one.
        """
        clusters: List[Dict[str, float]] = []
        counts: List[int] = []
        for p in points:
            px = float(p.get("x", 0.0))
            py = float(p.get("y", 0.0))
            matched = False
            for i, c in enumerate(clusters):
                dx = px - c["x"]
                dy = py - c["y"]
                # Compare squared distances to avoid a sqrt per pair.
                if dx * dx + dy * dy <= tol * tol:
                    n = counts[i]
                    c["x"] = (c["x"] * n + px) / (n + 1)
                    c["y"] = (c["y"] * n + py) / (n + 1)
                    counts[i] = n + 1
                    matched = True
                    break
            if not matched:
                clusters.append({"x": px, "y": py})
                counts.append(1)
        return clusters

    @classmethod
    def _tta_points(cls, points_a: List[Dict[str, Any]], points_b_flipped: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        """Merge original-frame points with flip-augmented points.

        ``points_b_flipped`` came from a mirrored image, so they are flipped
        back into the original frame before clustering with ``points_a``.
        """
        unflipped_b = [cls._flip_point(p) for p in points_b_flipped]
        merged = list(points_a) + unflipped_b
        return cls._deduplicate_and_average_points(merged)


# Box-related utilities removed (endpoint is point-only)