File size: 8,124 Bytes
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb3bdf4
7c2acc2
eb3bdf4
7c2acc2
eb3bdf4
 
 
 
7c2acc2
 
eb3bdf4
 
 
7c2acc2
 
 
5daee26
7c2acc2
 
 
 
 
7cc210d
34b89db
7cc210d
 
 
 
 
 
34b89db
 
7cc210d
 
34b89db
 
 
7c2acc2
 
34b89db
7c2acc2
 
 
 
 
 
 
 
 
 
6b2981b
 
 
 
 
 
 
 
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb3bdf4
7c2acc2
 
eb3bdf4
 
 
 
 
7c2acc2
eb3bdf4
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb3bdf4
7c2acc2
eb3bdf4
 
 
 
 
 
7c2acc2
eb3bdf4
 
7c2acc2
 
 
5dca6b2
eb3bdf4
 
 
 
 
 
 
5dca6b2
eb3bdf4
 
 
5dca6b2
 
 
 
 
 
 
 
 
 
 
 
eb3bdf4
 
 
 
 
 
 
 
 
 
 
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb3bdf4
7c2acc2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import base64
import io
import json
import os
from typing import Any, Dict, List, Optional

from PIL import Image

import torch
from transformers import AutoModelForCausalLM


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a 'data:<mime>;base64,<payload>' URL into a fully-loaded PIL image.

    Raises ValueError when *data_url* is not a string beginning with 'data:',
    and propagates decoding errors from base64/PIL otherwise.
    """
    is_data_url = isinstance(data_url, str) and data_url.startswith("data:")
    if not is_data_url:
        raise ValueError("Expected a data URL starting with 'data:'")
    _, payload = data_url.split(",", 1)
    decoded = base64.b64decode(payload)
    image = Image.open(io.BytesIO(decoded))
    # Force the lazy decoder to read the whole image now, so corrupt
    # payloads surface here rather than deep inside model inference.
    image.load()
    return image


class EndpointHandler:
    """HF Inference Endpoint handler for Moondream3 Preview (point only).

    Input contract (OpenAI-style, simplified):
    {
      "system": "<system prompt>",
      "user": "<user prompt>",
      "image": "data:<mime>;base64,<...>",
      "prioritize_accuracy": true | false // optional (default true)
    }

    Output (point only):
      { points: [{x, y}], raw: <debug payload> }
    Coordinates are normalized [0,1].
    """

    def __init__(self, path: str = "") -> None:
        """Load the model once at endpoint startup.

        ``path`` is supplied by the HF endpoint toolkit but unused here; the
        model comes from the MODEL_ID env var (local path or hub repo id).
        """
        model_id = os.environ.get("MODEL_ID", "moondream/moondream3-preview")

        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Load local repo (or remote if MODEL_ID points to hub id).
        # Pass a token when accessing gated repos, and pin all weights to a
        # single device so remote code does not shard them unexpectedly.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        if torch.cuda.is_available():
            device_map = {"": "cuda"}
            dtype = torch.bfloat16
        else:
            device_map = {"": "cpu"}
            dtype = torch.float32
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
            "device_map": device_map,
        }
        if hub_token:
            load_kwargs["token"] = hub_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            **load_kwargs,
        )

        # Optional compilation for speed if exposed by the remote code.
        try:
            compile_fn = getattr(self.model, "compile", None)
            if callable(compile_fn):
                compile_fn()
        except Exception:
            pass

        # Prefer robustness over speed on HF endpoints: disable FlexAttention
        # decoding to avoid BlockMask attribute mismatches in some torch builds.
        try:
            if hasattr(self.model, "model") and hasattr(self.model.model, "use_flex_decoding"):
                self.model.model.use_flex_decoding = False
        except Exception:
            pass

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request (see class docstring for the contract).

        Returns {"points": [...], "raw": {...}} on success, or
        {"error": "<message>"} on any failure — never raises to the caller.
        """
        # Accept HF toolkit shapes: { inputs: {...} } or a JSON string payload.
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                data = inputs_val
            elif isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        data = parsed
                except Exception:
                    # Malformed wrappers fall through to the field checks below.
                    pass

        # New input contract: expect 'system', 'user', 'image' (point task only)
        prioritize_accuracy = bool(data.get("prioritize_accuracy", True))

        system_prompt: Optional[str] = data.get("system")
        text_piece: Optional[str] = data.get("user")
        image_data_url: Optional[str] = data.get("image")
        if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
            return {"error": "image must be a data URL (data:...)"}
        if not text_piece:
            return {"error": "user text must be provided"}

        # Decode once: used for dimension logging and passed to the model.
        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image data URL: {e}"}

        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)
        if width and height:
            try:
                print(f"[moondream-endpoint] Received image size: {width}x{height}")
            except Exception:
                pass

        # Point-only inference; optional horizontal-flip TTA when accuracy
        # is prioritized (two forward passes instead of one).
        try:
            if prioritize_accuracy:
                flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
                res_orig = self.model.point(pil, text_piece)
                res_flip = self.model.point(flipped, text_piece)
                points = self._tta_points(res_orig.get("points", []), res_flip.get("points", []))
                out: Dict[str, Any] = {"points": points}
            else:
                result = self.model.point(pil, text_piece)
                out = {"points": result.get("points", [])}
        except Exception as e:
            return {"error": f"Model inference failed: {e}"}

        # Best-effort debug logging of prompts and raw output; logging I/O
        # errors must never fail the request.
        def _se(s: Optional[str], n: int = 120):
            # (first n chars, last n chars) for compact one-line logging.
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)
        sys_start, sys_end = _se(system_prompt)
        try:
            print(f"[moondream-endpoint] System prompt (start): {sys_start}")
            print(f"[moondream-endpoint] System prompt (end): {sys_end}")
            print(f"[moondream-endpoint] User prompt (full): {text_piece}")
        except Exception:
            pass
        try:
            print(f"[moondream-endpoint] Raw output: {json.dumps(out)}")
        except Exception:
            pass

        # Ensure a non-empty points array, normalized to [0, 1].
        if not isinstance(out.get("points"), list) or not out["points"]:
            return {"error": "No points returned"}
        def _to_01(p):
            # Some builds emit 0-1000 coordinate space; rescale those to [0, 1].
            x = float(p.get("x", 0.0))
            y = float(p.get("y", 0.0))
            if x > 1.0 or y > 1.0:
                return {"x": x / 1000.0, "y": y / 1000.0}
            return {"x": x, "y": y}
        points_01 = [_to_01(p) for p in out["points"]]
        return {"points": points_01, "raw": out}

    @staticmethod
    def _flip_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized point across the vertical axis, clamped to [0, 1]."""
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        x = 1.0 - x
        return {"x": max(0.0, min(1.0, x)), "y": max(0.0, min(1.0, y))}

    @classmethod
    def _deduplicate_and_average_points(cls, points: List[Dict[str, Any]], tol: float = 0.03) -> List[Dict[str, float]]:
        """Greedy single-pass clustering of normalized points.

        A point within *tol* (Euclidean distance) of an existing cluster's
        running mean is folded into that cluster; otherwise it starts a new
        one. Returns the cluster centers.
        """
        clusters: List[Dict[str, float]] = []
        counts: List[int] = []
        for p in points:
            px = float(p.get("x", 0.0))
            py = float(p.get("y", 0.0))
            matched = False
            for i, c in enumerate(clusters):
                dx = px - c["x"]
                dy = py - c["y"]
                if dx * dx + dy * dy <= tol * tol:
                    # Incremental mean keeps the center stable as members join.
                    n = counts[i]
                    c["x"] = (c["x"] * n + px) / (n + 1)
                    c["y"] = (c["y"] * n + py) / (n + 1)
                    counts[i] = n + 1
                    matched = True
                    break
            if not matched:
                clusters.append({"x": px, "y": py})
                counts.append(1)
        return clusters

    @classmethod
    def _tta_points(cls, points_a: List[Dict[str, Any]], points_b_flipped: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        """Merge original and flipped-image predictions into one point set.

        The flipped predictions are mirrored back into the original frame,
        then both sets are deduplicated/averaged together.
        """
        unflipped_b = [cls._flip_point(p) for p in points_b_flipped]
        merged = list(points_a) + unflipped_b
        return cls._deduplicate_and_average_points(merged)

    # Box-related utilities removed (endpoint is point-only)