| """ |
| Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints |
| Model: allenai/Molmo2-8B |
| |
| For ProofPath video assessment - video pointing, tracking, and grounded analysis. |
| Unique capability: Returns pixel-level coordinates for objects in videos. |
| """ |
|
|
from typing import Any, Dict, List

import base64
import io
import os
import re
import tempfile

import torch
|
|
|
|
| class EndpointHandler: |
| def __init__(self, path: str = ""): |
| """ |
| Initialize Molmo 2 model for video pointing and tracking. |
| |
| Args: |
| path: Path to the model directory (ignored - we always load from HF hub) |
| """ |
| |
        model_id = "allenai/Molmo2-8B"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # trust_remote_code lets the repo's custom Molmo code load when the
        # installed transformers version does not ship it natively.
        from transformers import AutoProcessor, AutoModelForImageTextToText
| |
| self.processor = AutoProcessor.from_pretrained( |
| model_id, |
| trust_remote_code=True, |
| ) |
| |
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        # device_map is None on CPU, so move the model explicitly.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)
| |
        self.model.eval()

        # Video sampling defaults. Frame extraction is delegated to
        # molmo_utils.process_vision_info, so these are currently advisory.
        self.max_frames = 128
        self.default_fps = 2.0

        # Regexes for Molmo's coordinate markup, e.g.
        #   <points coords="8.5 0 183 216; 8.5 1 245 198"/>
        # COORD_REGEX captures the coords attribute, FRAME_REGEX splits it
        # into per-timestamp groups, and POINTS_REGEX pulls
        # (instance_id, x, y) triples, with x/y on a 0-1000 scale.
        self.COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
        self.FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
        self.POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
| |
| def _parse_video_points(self, text: str, image_w: int, image_h: int) -> List[Dict]: |
| """ |
| Extract pointing coordinates from Molmo output. |
| |
| Molmo outputs coordinates in format: |
| <points coords="8.5 0 183 216; 8.5 1 245 198"/> |
| Where: timestamp instance_id x y (coords scaled by 1000) |
| """ |
| all_points = [] |
| |
| for coord_match in self.COORD_REGEX.finditer(text): |
| for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)): |
| timestamp = float(frame_match.group(1)) |
| |
| for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)): |
| instance_id = int(point_match.group(1)) |
| |
| x = float(point_match.group(2)) / 1000 * image_w |
| y = float(point_match.group(3)) / 1000 * image_h |
| |
| if 0 <= x <= image_w and 0 <= y <= image_h: |
| all_points.append({ |
| "timestamp": timestamp, |
| "instance_id": instance_id, |
| "x": x, |
| "y": y |
| }) |
| |
| return all_points |
| |
| def _load_image(self, image_data: Any): |
| """Load a single image from various formats.""" |
| from PIL import Image |
| import requests |
| |
| if isinstance(image_data, Image.Image): |
| return image_data |
| elif isinstance(image_data, str): |
| if image_data.startswith(('http://', 'https://')): |
                response = requests.get(image_data, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert('RGB')
| elif image_data.startswith('data:'): |
| header, encoded = image_data.split(',', 1) |
| image_bytes = base64.b64decode(encoded) |
| return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
| else: |
| image_bytes = base64.b64decode(image_data) |
| return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
| elif isinstance(image_data, bytes): |
| return Image.open(io.BytesIO(image_data)).convert('RGB') |
| else: |
| raise ValueError(f"Unsupported image input type: {type(image_data)}") |
| |
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
| """ |
| Process video or images with Molmo 2. |
| |
| Expected input formats: |
| |
| 1. Image analysis with pointing: |
| { |
| "inputs": <image_url_or_base64>, |
| "parameters": { |
| "prompt": "Point to the Excel cell B2.", |
| "max_new_tokens": 1024 |
| } |
| } |
| |
| 2. Video analysis: |
| { |
| "inputs": <video_url>, |
| "parameters": { |
| "prompt": "What happens in this video?", |
| "max_new_tokens": 2048 |
| } |
| } |
| |
| 3. Multi-image comparison: |
| { |
| "inputs": [<image1>, <image2>], |
| "parameters": { |
| "prompt": "Compare these screenshots." |
| } |
| } |
| |
| Returns: |
| { |
| "generated_text": "...", |
| "points": [{"timestamp": 0, "x": 123, "y": 456, ...}], # If pointing detected |
| "image_size": {...} |
| } |
| """ |
| inputs = data.get("inputs") |
| if inputs is None: |
| inputs = data.get("video") or data.get("image") or data.get("images") |
| if inputs is None: |
| raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.") |
| |
| params = data.get("parameters", {}) |
| prompt = params.get("prompt", "Describe this image.") |
| max_new_tokens = params.get("max_new_tokens", 1024) |
| |
| try: |
| if isinstance(inputs, list): |
| return self._process_multi_image(inputs, prompt, max_new_tokens) |
| elif self._is_video(inputs, params): |
| return self._process_video(inputs, prompt, params, max_new_tokens) |
| else: |
| return self._process_image(inputs, prompt, max_new_tokens) |
| |
        except Exception as e:
            # Return errors as JSON so callers get diagnostics instead of a bare 5xx.
            import traceback
            return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}
| |
| def _is_video(self, inputs: Any, params: Dict) -> bool: |
| """Determine if input is video.""" |
| if params.get("input_type") == "video": |
| return True |
| if params.get("input_type") == "image": |
| return False |
| |
        if isinstance(inputs, str):
            lower = inputs.lower()
            video_exts = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')
            if lower.startswith(('http://', 'https://')):
                # Compare against the URL path only, so query strings cannot
                # produce false extension matches.
                from urllib.parse import urlparse
                lower = urlparse(lower).path
            return lower.endswith(video_exts)
| |
| return False |
| |
| def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]: |
| """Process a single image.""" |
| from PIL import Image |
| |
| image = self._load_image(image_data) |
| |
| |
| messages = [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": image}, |
| {"type": "text", "text": prompt}, |
| ], |
| } |
| ] |
| |
| |
| inputs = self.processor.apply_chat_template( |
| messages, |
| tokenize=True, |
| add_generation_prompt=True, |
| return_tensors="pt", |
| return_dict=True, |
| ) |
| inputs = {k: v.to(self.model.device) for k, v in inputs.items()} |
| |
| |
| with torch.inference_mode(): |
| output = self.model.generate( |
| **inputs, |
| max_new_tokens=max_new_tokens, |
| do_sample=False, |
| ) |
| |
| |
| generated_tokens = output[0, inputs['input_ids'].size(1):] |
| generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) |
| |
        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }

        # Attach parsed points when the model produced pointing markup.
        points = self._parse_video_points(generated_text, image.width, image.height)
        if points:
            result["points"] = points
            result["num_points"] = len(points)
| |
| return result |
| |
| def _process_video( |
| self, |
| video_data: Any, |
| prompt: str, |
| params: Dict, |
| max_new_tokens: int |
| ) -> Dict[str, Any]: |
| """Process video using molmo_utils.""" |
| from molmo_utils import process_vision_info |
| |
| |
| if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')): |
| video_source = video_data |
| temp_path = None |
        else:
            # Base64 string or raw bytes: write to a temp file so the video
            # decoder can read it from disk.
            if isinstance(video_data, str):
                video_bytes = base64.b64decode(video_data)
            else:
                video_bytes = video_data
| |
| with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
| f.write(video_bytes) |
| video_source = f.name |
| temp_path = f.name |
| |
| try: |
| |
| messages = [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "text", "text": prompt}, |
| {"type": "video", "video": video_source}, |
| ], |
| } |
| ] |
| |
| |
| _, videos, video_kwargs = process_vision_info(messages) |
| videos, video_metadatas = zip(*videos) |
| videos, video_metadatas = list(videos), list(video_metadatas) |
| |
| |
| text = self.processor.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True |
| ) |
| |
| |
| inputs = self.processor( |
| videos=videos, |
| video_metadata=video_metadatas, |
| text=text, |
| padding=True, |
| return_tensors="pt", |
| **video_kwargs, |
| ) |
| inputs = {k: v.to(self.model.device) for k, v in inputs.items()} |
| |
| |
| with torch.inference_mode(): |
| output = self.model.generate( |
| **inputs, |
| max_new_tokens=max_new_tokens, |
| do_sample=False, |
| ) |
| |
| |
| generated_tokens = output[0, inputs['input_ids'].size(1):] |
| generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) |
| |
| |
| video_w = video_metadatas[0].get("width", 1920) |
| video_h = video_metadatas[0].get("height", 1080) |
| |
| result = { |
| "generated_text": generated_text, |
| "video_metadata": { |
| "width": video_w, |
| "height": video_h, |
| } |
| } |
| |
| |
| points = self._parse_video_points(generated_text, video_w, video_h) |
| if points: |
| result["points"] = points |
| result["num_points"] = len(points) |
| |
| return result |
| |
        finally:
            # Clean up the temp file created for base64/bytes input.
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)
| |
| def _process_multi_image( |
| self, |
| images_data: List, |
| prompt: str, |
| max_new_tokens: int |
| ) -> Dict[str, Any]: |
| """Process multiple images.""" |
| from PIL import Image |
| |
| images = [self._load_image(img) for img in images_data] |
| |
| |
| content = [] |
| for image in images: |
| content.append({"type": "image", "image": image}) |
| content.append({"type": "text", "text": prompt}) |
| |
        messages = [{"role": "user", "content": content}]

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.inference_mode():
            output = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        generated_tokens = output[0, inputs['input_ids'].size(1):]
        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
| |
| result = { |
| "generated_text": generated_text, |
| "num_images": len(images), |
| "image_sizes": [{"width": img.width, "height": img.height} for img in images] |
| } |
| |
| |
| if images: |
| points = self._parse_video_points(generated_text, images[0].width, images[0].height) |
| if points: |
| result["points"] = points |
| result["num_points"] = len(points) |
| |
| return result |
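

if __name__ == "__main__":
    # Minimal local smoke test (a sketch: assumes enough memory to load the
    # model and network access; the image URL below is a placeholder).
    handler = EndpointHandler()
    demo = handler({
        "inputs": "https://picsum.photos/640/480",  # placeholder test image
        "parameters": {"prompt": "Point to the most prominent object.",
                       "input_type": "image"},
    })
    print(demo)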
|
|