| """ |
| SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints |
| Model: facebook/sam3 |
| |
Uses the official sam3 package from Meta (pip install sam3),
NOT the transformers integration.
| |
| For ProofPath video assessment - text-prompted segmentation to find UI elements. |
| Supports text prompts like "Save button", "dropdown menu", "text input field". |
| |
| KEY CAPABILITIES: |
| - Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons) |
| - Promptable Concept Segmentation (PCS): 270K unique concepts |
| - Video tracking: Consistent object IDs across frames |
| - Presence token: Discriminates similar elements ("player in white" vs "player in red") |
| |
REQUIREMENTS:
1. Set the HF_TOKEN environment variable (the model is gated).
2. Accept the license at https://huggingface.co/facebook/sam3
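
EXAMPLE REQUEST (client side; the endpoint URL and token are placeholders):

    import requests

    resp = requests.post(
        "https://<your-endpoint>.endpoints.huggingface.cloud",
        headers={"Authorization": "Bearer <HF_TOKEN>"},
        json={
            "inputs": "https://example.com/screenshot.png",
            "parameters": {"prompt": "Save button", "return_masks": False},
        },
    )
    print(resp.json()["results"][0]["count"])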
| """ |
|
|
import base64
import io
import os
from typing import Any, Dict

import numpy as np
import torch


| class EndpointHandler: |
| def __init__(self, path: str = ""): |
| """ |
| Initialize SAM 3 model for text-prompted segmentation. |
| Uses the official sam3 package from Meta. |
| |
| Args: |
| path: Path to the model directory (ignored - we load from HF hub) |
| """ |
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Import lazily so a missing or gated sam3 install fails here, at
        # init time, with a clear error rather than at module import.
        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Build the image model and wrap it in the processor that handles
        # text-prompt encoding and post-processing. Moving to self.device
        # assumes build_sam3_image_model returns a standard torch.nn.Module.
        self.model = build_sam3_image_model().to(self.device)
        self.processor = Sam3Processor(self.model)

        # The video predictor is heavy; build it lazily on first video request.
        self._video_predictor = None
| |
| def _get_video_predictor(self): |
| """Lazy load video predictor only when needed.""" |
| if self._video_predictor is None: |
| from sam3.model_builder import build_sam3_video_predictor |
| self._video_predictor = build_sam3_video_predictor() |
| return self._video_predictor |
| |
| def _load_image(self, image_data: Any): |
| """Load image from various formats.""" |
| from PIL import Image |
| import requests |
| |
| if isinstance(image_data, Image.Image): |
| return image_data.convert('RGB') |
| elif isinstance(image_data, str): |
| if image_data.startswith(('http://', 'https://')): |
                response = requests.get(image_data, stream=True)
                response.raise_for_status()
                return Image.open(response.raw).convert('RGB')
| elif image_data.startswith('data:'): |
| header, encoded = image_data.split(',', 1) |
| image_bytes = base64.b64decode(encoded) |
| return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
            else:
                # Treat any other string as a raw base64-encoded image.
                image_bytes = base64.b64decode(image_data)
| return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
| elif isinstance(image_data, bytes): |
| return Image.open(io.BytesIO(image_data)).convert('RGB') |
| else: |
| raise ValueError(f"Unsupported image input type: {type(image_data)}") |
| |
    def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple:
        """Persist the video to a temp file and compute sampling metadata.

        Returns (video_path, metadata). The sam3 video predictor reads the
        file by path, so frames are decoded here only to count them.
        """
        import cv2
        import tempfile

        # Write the video bytes to a local temp file so OpenCV and the
        # predictor can read it by path.
| if isinstance(video_data, str): |
| if video_data.startswith(('http://', 'https://')): |
| import requests |
                response = requests.get(video_data, stream=True)
                response.raise_for_status()
| with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
| for chunk in response.iter_content(chunk_size=8192): |
| f.write(chunk) |
| video_path = f.name |
| elif video_data.startswith('data:'): |
| header, encoded = video_data.split(',', 1) |
| video_bytes = base64.b64decode(encoded) |
| with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
| f.write(video_bytes) |
| video_path = f.name |
            else:
                # Treat any other string as raw base64-encoded video bytes.
                video_bytes = base64.b64decode(video_data)
| with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
| f.write(video_bytes) |
| video_path = f.name |
| elif isinstance(video_data, bytes): |
| with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
| f.write(video_data) |
| video_path = f.name |
| else: |
| raise ValueError(f"Unsupported video input type: {type(video_data)}") |
| |
| try: |
| cap = cv2.VideoCapture(video_path) |
| video_fps = cap.get(cv2.CAP_PROP_FPS) |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| duration = total_frames / video_fps if video_fps > 0 else 0 |
| |
| |
            # Sample uniformly, capped by max_frames and the target sampling fps.
            target_frames = min(max_frames, int(duration * fps), total_frames)
            if target_frames <= 0:
                target_frames = min(max_frames, total_frames)

            frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)

            # Decode the sampled frames only to confirm they are readable;
            # the sam3 video predictor consumes the file directly by path.
            sampled = 0
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, _ = cap.read()
                if ret:
                    sampled += 1

            cap.release()

            metadata = {
                "duration": duration,
                "total_frames": total_frames,
                "sampled_frames": sampled,
                "video_fps": video_fps
            }
| |
| return video_path, metadata |
| |
        except Exception:
            # Remove the temp file if decoding fails, then re-raise.
            if os.path.exists(video_path):
                os.unlink(video_path)
            raise
| |
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
| """ |
| Process image or video with SAM 3 for text-prompted segmentation. |
| |
| INPUT FORMATS: |
| |
| 1. Single image with text prompt (find all instances): |
| { |
| "inputs": <image_url_or_base64>, |
| "parameters": { |
| "prompt": "Save button", |
| "return_masks": true |
| } |
| } |
| |
| 2. Single image with multiple text prompts: |
| { |
| "inputs": <image_url_or_base64>, |
| "parameters": { |
| "prompts": ["button", "text field", "dropdown"] |
| } |
| } |
| |
| 3. Video with text prompt (track all instances): |
| { |
| "inputs": <video_url_or_base64>, |
| "parameters": { |
| "mode": "video", |
| "prompt": "Submit button", |
| "max_frames": 100 |
| } |
| } |
| |
| 4. ProofPath UI element detection: |
| { |
| "inputs": <screenshot_base64>, |
| "parameters": { |
| "mode": "ui_elements", |
| "elements": ["Save button", "Cancel button", "text input"] |
| } |
| } |
| |
| OUTPUT FORMAT: |
| { |
| "results": [ |
| { |
| "prompt": "Save button", |
| "instances": [ |
| { |
| "box": [x1, y1, x2, y2], |
| "score": 0.95, |
| "mask": "<base64_png>" // if return_masks=true |
| } |
| ] |
| } |
| ], |
| "image_size": {"width": 1920, "height": 1080} |
| } |
| """ |
| inputs = data.get("inputs") |
| params = data.get("parameters", {}) |
| |
| if inputs is None: |
| raise ValueError("No inputs provided") |
| |
| mode = params.get("mode", "image") |
| |
| if mode == "video": |
| return self._process_video(inputs, params) |
| elif mode == "ui_elements": |
| return self._process_ui_elements(inputs, params) |
| else: |
| return self._process_single_image(inputs, params) |
| |
    def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """Process a single image with text prompts using the official sam3 API."""
        from PIL import Image as PILImage

        image = self._load_image(image_data)
| |
        return_masks = params.get("return_masks", True)

        # Accept either a single "prompt" or a list of "prompts".
        prompt = params.get("prompt")
        prompts = params.get("prompts", [prompt] if prompt else [])

        if not prompts:
            raise ValueError("No text prompt(s) provided")

        # Encode the image once; every text prompt reuses the same state.
        inference_state = self.processor.set_image(image)
| |
| results = [] |
| |
        for text_prompt in prompts:
            # Open-vocabulary detection: returns ALL instances of the concept.
            output = self.processor.set_text_prompt(
| state=inference_state, |
| prompt=text_prompt |
| ) |
| |
            masks = output.get("masks", [])
            boxes = output.get("boxes", [])
            scores = output.get("scores", [])

            # Convert tensors to plain Python lists for JSON serialization.
            if hasattr(boxes, 'tolist'):
                boxes = boxes.tolist()
            if hasattr(scores, 'tolist'):
                scores = scores.tolist()

            instances = []
            for i in range(len(boxes)):
                instance = {
                    "box": boxes[i],
                    "score": float(scores[i]) if i < len(scores) else 0.0
                }

                if return_masks and masks is not None and i < len(masks):
                    # Encode the binary mask as a base64 PNG (white = object).
                    mask = masks[i]
                    if hasattr(mask, 'cpu'):
                        mask = mask.cpu().numpy()
                    # Drop singleton dims (e.g. 1xHxW) so PIL can build a
                    # grayscale image from a 2-D array.
                    mask = np.asarray(mask).squeeze()
                    mask_uint8 = (mask * 255).astype(np.uint8)
                    mask_img = PILImage.fromarray(mask_uint8)
                    buffer = io.BytesIO()
                    mask_img.save(buffer, format='PNG')
                    instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')

                instances.append(instance)
| |
| results.append({ |
| "prompt": text_prompt, |
| "instances": instances, |
| "count": len(instances) |
| }) |
| |
| return { |
| "results": results, |
| "image_size": {"width": image.width, "height": image.height} |
| } |
| |
| def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]: |
| """ |
| ProofPath-specific mode: Detect multiple UI element types in a screenshot. |
| Returns structured data for each element type with bounding boxes. |
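
        Example output shape (illustrative values):
            {
                "ui_elements": {
                    "button": {
                        "count": 1,
                        "instances": [{"box": [10, 20, 110, 60],
                                       "score": 0.91,
                                       "center": [60.0, 40.0]}]
                    }
                },
                "image_size": {"width": 1280, "height": 720},
                "total_elements": 1
            }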
| """ |
| image = self._load_image(image_data) |
| |
        elements = params.get("elements", [])
        if not elements:
            # Fall back to a generic set of common UI element types.
            elements = ["button", "text input", "dropdown", "checkbox", "link"]

        # Encode the screenshot once; every element type reuses the state.
        inference_state = self.processor.set_image(image)
| |
| all_detections = {} |
| |
| for element_type in elements: |
| output = self.processor.set_text_prompt( |
| state=inference_state, |
| prompt=element_type |
| ) |
| |
| boxes = output.get("boxes", []) |
            scores = output.get("scores", [])

            # Convert tensors to plain Python lists for JSON serialization.
            if hasattr(boxes, 'tolist'):
                boxes = boxes.tolist()
            if hasattr(scores, 'tolist'):
                scores = scores.tolist()

            detections = []
            for i in range(len(boxes)):
                box = boxes[i]
                detections.append({
                    "box": box,
                    "score": float(scores[i]) if i < len(scores) else 0.0,
                    # Box centers double as click targets for UI automation.
                    "center": [
                        (box[0] + box[2]) / 2,
                        (box[1] + box[3]) / 2
                    ] if len(box) >= 4 else None
                })
| |
| all_detections[element_type] = { |
| "count": len(detections), |
| "instances": detections |
| } |
| |
| return { |
| "ui_elements": all_detections, |
| "image_size": {"width": image.width, "height": image.height}, |
| "total_elements": sum(d["count"] for d in all_detections.values()) |
| } |
| |
| def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]: |
| """ |
| Process video with SAM3 Video for text-prompted tracking. |
| Uses the official sam3 video predictor API. |
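
        Example response shape (illustrative values):
            {
                "prompt": "Submit button",
                "video_metadata": {"duration": 12.0, "total_frames": 360, ...},
                "objects_tracked": 1,
                "tracks": [{"object_id": 0,
                            "frames": [{"frame_index": 0,
                                        "box": [5, 5, 50, 30],
                                        "score": 0.88}]}],
                "session_id": "<session id>"
            }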
| """ |
| video_predictor = self._get_video_predictor() |
| |
| prompt = params.get("prompt") |
| if not prompt: |
| raise ValueError("Text prompt required for video mode") |
| |
        max_frames = params.get("max_frames", 100)

        # Persist the video to a local temp file; the predictor reads it by path.
        video_path, video_metadata = self._load_video_frames(video_data, max_frames)

        try:
            # Open a tracking session on the downloaded video.
            response = video_predictor.handle_request(
| request=dict( |
| type="start_session", |
| resource_path=video_path, |
| ) |
| ) |
            session_id = response.get("session_id")

            # Add the text prompt on frame 0; SAM 3 detects all matching
            # instances and assigns each a persistent object ID.
            response = video_predictor.handle_request(
| request=dict( |
| type="add_prompt", |
| session_id=session_id, |
| frame_index=0, |
| text=prompt, |
| ) |
| ) |
| |
            output = response.get("outputs", {})

            # IDs assigned to the instances detected from the text prompt.
            object_ids = output.get("object_ids", [])
            if hasattr(object_ids, 'tolist'):
                object_ids = object_ids.tolist()

            # Propagate tracking through the remaining frames.
            propagate_response = video_predictor.handle_request(
| request=dict( |
| type="propagate", |
| session_id=session_id, |
| ) |
| ) |
| |
| |
            per_frame_results = propagate_response.get("per_frame_outputs", {})

            # Assemble per-object tracks from the propagation outputs. NOTE:
            # the per-frame schema assumed here (frame_index -> {"object_ids",
            # "boxes", "scores"}) mirrors the add_prompt output above; adjust
            # it if the installed sam3 version returns a different structure.
            tracks = {}
            for obj_id in object_ids:
                key = int(obj_id) if hasattr(obj_id, 'item') else obj_id
                tracks[key] = {"object_id": key, "frames": []}

            for frame_idx, frame_out in per_frame_results.items():
                frame_ids = frame_out.get("object_ids", [])
                frame_boxes = frame_out.get("boxes", [])
                frame_scores = frame_out.get("scores", [])
                if hasattr(frame_ids, 'tolist'):
                    frame_ids = frame_ids.tolist()
                if hasattr(frame_boxes, 'tolist'):
                    frame_boxes = frame_boxes.tolist()
                if hasattr(frame_scores, 'tolist'):
                    frame_scores = frame_scores.tolist()
                for i, oid in enumerate(frame_ids):
                    key = int(oid) if hasattr(oid, 'item') else oid
                    if key in tracks:
                        tracks[key]["frames"].append({
                            "frame_index": int(frame_idx),
                            "box": frame_boxes[i] if i < len(frame_boxes) else None,
                            "score": float(frame_scores[i]) if i < len(frame_scores) else 0.0,
                        })

            return {
                "prompt": prompt,
                "video_metadata": video_metadata,
                "objects_tracked": len(object_ids),
                "tracks": list(tracks.values()),
                "session_id": session_id
            }
| |
        finally:
            # Always remove the temp video file, even if inference failed.
            if os.path.exists(video_path):
                os.unlink(video_path)


if __name__ == "__main__":
    handler = EndpointHandler()

    # Smoke test: text-prompted detection on a public COCO image.
    test_data = {
| "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg", |
| "parameters": { |
| "prompt": "ear", |
| "return_masks": False |
| } |
| } |
| |
| result = handler(test_data) |
| print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'") |
| for inst in result['results'][0]['instances']: |
| print(f" Box: {inst['box']}, Score: {inst['score']:.3f}") |
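
    # Additional smoke tests (hedged sketches; the URLs below are
    # placeholders - substitute real assets before uncommenting):
    #
    # 1. ProofPath UI-element detection on a screenshot:
    # ui_result = handler({
    #     "inputs": "https://example.com/app_screenshot.png",
    #     "parameters": {"mode": "ui_elements",
    #                    "elements": ["button", "text input"]}
    # })
    # print(f"Total UI elements: {ui_result['total_elements']}")
    #
    # 2. Decoding a returned base64 mask back into a boolean numpy array
    #    (requires rerunning the test above with return_masks=True):
    # from PIL import Image
    # mask_b64 = result['results'][0]['instances'][0]['mask']
    # mask = np.array(Image.open(io.BytesIO(base64.b64decode(mask_b64)))) > 127
    # print(f"Mask pixels: {mask.sum()}")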