# NOTE(review): this file's newlines have been collapsed — the single line below
# contains the module imports, constants, model/processor loading,
# `_parse_rgb`/`COLORS_BGR`, `is_tracking_output`, and the head of
# `_interpolate_keyframes` run together.  A span also appears TRUNCATED between
# the `is_tracking_output` docstring ("...checking for") and the interpolation
# condition ("1 and (max_gap is None or span <= max_gap)"): helper functions
# referenced later in the file (`cast_float_bf16`, `format_points_list`,
# `draw_points`, `draw_points_colored`) are not visible here and should be
# restored from version control.  Left byte-identical below.
import functools import math import os import tempfile from collections import defaultdict import cv2 import numpy as np import PIL import torch from PIL import Image, ImageDraw, ImageFile from transformers import AutoModelForImageTextToText, AutoProcessor import gradio as gr import spaces from molmo_utils import process_vision_info from typing import Iterable from gradio.themes import Soft from gradio.themes.utils import colors, fonts, sizes Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True # ── Constants ────────────────────────────────────────────────────────────────── MODEL_ID = "allenai/MolmoPoint-8B" MAX_IMAGE_SIZE = 512 MAX_VIDEO_HEIGHT = 512 POINT_SIZE = 0.01 KEYFRAME_HOLD_FRAMES = 3 SHOW_TRAILS = True MAX_NEW_TOKENS = 2048 MAX_FPS = 10 COLORS = [ "rgb(255, 100, 180)", "rgb(100, 180, 255)", "rgb(180, 255, 100)", "rgb(255, 180, 100)", "rgb(100, 255, 180)", "rgb(180, 100, 255)", "rgb(255, 255, 100)", "rgb(100, 255, 255)", "rgb(255, 120, 120)", "rgb(120, 255, 255)", "rgb(255, 255, 120)", "rgb(255, 120, 255)", ] # ── Model loading ────────────────────────────────────────────────────────────── print(f"Loading {MODEL_ID}...") processor = AutoProcessor.from_pretrained( MODEL_ID, trust_remote_code=True, padding_side="left", ) model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, trust_remote_code=True, dtype="bfloat16", device_map="auto", ) print("Model loaded successfully.") # ── Helper functions ─────────────────────────────────────────────────────────── def _parse_rgb(color_str): """Parse 'rgb(r, g, b)' to (r, g, b) tuple.""" nums = color_str.replace("rgb(", "").replace(")", "").split(",") return tuple(int(n.strip()) for n in nums) COLORS_BGR = [(_parse_rgb(c)[2], _parse_rgb(c)[1], _parse_rgb(c)[0]) for c in COLORS] def is_tracking_output(generated_text: str) -> bool: """Detect tracking from model output by checking for 1 and (max_gap is None or span <= max_gap): for t in range(1, span): alpha = t / span positions[f_idx + t] = (x + 
# NOTE(review): the source arrived with its newlines collapsed; the code below is
# the same logic re-formatted.  The head of `_interpolate_keyframes` falls in a
# truncated span of the preamble, so its signature is reconstructed from its
# call site: `_interpolate_keyframes(kfs, total_frames, max_gap=max_gap_frames)`.


def _interpolate_keyframes(keyframes, total_frames, max_gap=None):
    """Expand sorted (frame_idx, x, y) keyframes into per-frame positions.

    Returns {frame_idx: (x, y)} containing every keyframe plus linearly
    interpolated positions for frames between consecutive keyframes whose gap
    is <= max_gap frames (no limit when max_gap is None).  Larger gaps are left
    empty so the object disappears instead of sliding across the screen.
    `total_frames` is kept for signature compatibility with the call site.
    """
    positions = {}
    for i, (f_idx, x, y) in enumerate(keyframes):
        positions[f_idx] = (x, y)
        if i + 1 >= len(keyframes):
            continue
        next_f, nx, ny = keyframes[i + 1]
        span = next_f - f_idx
        # Surviving fragment of the original condition:
        # "... 1 and (max_gap is None or span <= max_gap)"
        if span > 1 and (max_gap is None or span <= max_gap):
            for t in range(1, span):
                alpha = t / span
                positions[f_idx + t] = (x + alpha * (nx - x), y + alpha * (ny - y))
    return positions


def create_annotated_video(video_path, points, metadata, tracking):
    """Draw points on the original video with interpolation and fading trails.

    Points format: [(object_id, timestamp, x, y), ...]
    Coordinates are in the processed frame space (metadata["video_size"]).
    Returns the path of the annotated .mp4 written to a temporary file.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    # FIX: some containers report fps == 0, which would break timestamp->frame
    # mapping (int(fps) gap, trail length) and the VideoWriter; use a fallback.
    if not fps or fps <= 0:
        fps = 30.0
    vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    proc_w, proc_h = metadata["video_size"]
    scale_x = vid_w / proc_w
    scale_y = vid_h / proc_h

    # Build per-object keyframes: {obj_id: [(frame_idx, x, y), ...]}
    obj_keyframes = defaultdict(list)
    for object_id, ts, x, y in points:
        f_idx = int(round(float(ts) * fps))
        sx, sy = float(x) * scale_x, float(y) * scale_y
        obj_keyframes[int(object_id)].append((f_idx, sx, sy))

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    obj_positions = {}
    obj_keyframe_set = {}
    max_gap_frames = int(fps)  # gaps > 1 second: no interpolation, object disappears
    for obj_id, kfs in obj_keyframes.items():
        kfs.sort(key=lambda k: k[0])
        obj_positions[obj_id] = _interpolate_keyframes(kfs, total_frames, max_gap=max_gap_frames)
        raw_kf = {f_idx for f_idx, _, _ in kfs}
        # Keyframe markers stay "solid" for a few frames on either side.
        obj_keyframe_set[obj_id] = {
            f
            for kf in raw_kf
            for f in range(kf - KEYFRAME_HOLD_FRAMES, kf + KEYFRAME_HOLD_FRAMES + 1)
        }

    # FIX: tempfile.mktemp() is deprecated and race-prone; create the file
    # atomically and let the VideoWriter overwrite it (delete=False keeps it
    # around for Gradio to serve).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (vid_w, vid_h))
    radius = max(5, int(max(vid_w, vid_h) * POINT_SIZE))
    trail_length = int(fps * 2)  # trail covers roughly the last 2 seconds
    obj_history = defaultdict(list)
    current_frame = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        for obj_id, positions in obj_positions.items():
            if current_frame in positions:
                px, py = positions[current_frame]
                # Clear trail if the object was absent in the previous frame (gap reappearance)
                if current_frame > 0 and (current_frame - 1) not in positions:
                    obj_history[obj_id] = []
                obj_history[obj_id].append((px, py))
                if len(obj_history[obj_id]) > trail_length:
                    obj_history[obj_id] = obj_history[obj_id][-trail_length:]
                if tracking:
                    color = COLORS_BGR[(obj_id - 1) % len(COLORS_BGR)]
                else:
                    color = COLORS_BGR[0]
                # Draw fading trail (older segments dimmer and thinner)
                trail = obj_history[obj_id]
                n_trail = len(trail)
                if SHOW_TRAILS and n_trail >= 2:
                    for i in range(n_trail - 1):
                        alpha = (i + 1) / n_trail
                        trail_color = tuple(int(c * alpha) for c in color)
                        thickness = max(1, int(radius * 0.6 * alpha))
                        pt1 = (int(trail[i][0]), int(trail[i][1]))
                        pt2 = (int(trail[i + 1][0]), int(trail[i + 1][1]))
                        cv2.line(frame, pt1, pt2, trail_color, thickness)
                # Solid on keyframes, outline-only on interpolated frames
                if current_frame in obj_keyframe_set[obj_id]:
                    cv2.circle(frame, (int(px), int(py)), radius, color, -1)
                    cv2.circle(frame, (int(px), int(py)), radius + 2, (255, 255, 255), 2)
                else:
                    cv2.circle(frame, (int(px), int(py)), radius, color, 2)
        out.write(frame)
        current_frame += 1
    cap.release()
    out.release()
    return out_path


# ── Inference functions ────────────────────────────────────────────────────────
@spaces.GPU
def process_images(user_text, input_images, max_tokens):
    """Run pointing on one or more images.

    Returns (generated_text, gallery_images, points_table_str).  When points are
    extracted, gallery_images are the annotated inputs; otherwise the originals.
    """
    if not input_images:
        return "Please upload at least one image.", [], "[]"
    pil_images = []
    for img_path in input_images:
        # gr.Gallery may hand back (path, caption) tuples
        if isinstance(img_path, tuple):
            img_path = img_path[0]
        pil_images.append(Image.open(img_path).convert("RGB"))

    # Build messages: one text part followed by every image
    content = [dict(type="text", text=user_text)]
    for img in pil_images:
        content.append(dict(type="image", image=img))
    messages = [{"role": "user", "content": content}]

    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
    )
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}

    # Generate
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0,
            )
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(
        generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Extract points: [(object_id, image_index, x, y), ...]
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )
    points_table = format_points_list(points, is_video=False)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        group_by_index = defaultdict(list)
        for object_id, ix, x, y in points:
            group_by_index[ix].append((x, y))
        annotated = []
        for ix, pts in group_by_index.items():
            annotated.append(draw_points(images[ix], pts))
        return generated_text, annotated, points_table
    return generated_text, pil_images, points_table


@spaces.GPU
def process_video(user_text, video_path, frame_sample_mode, max_frames, max_fps, max_tokens):
    """Run pointing/tracking on a video.

    Returns (generated_text, annotated_video_path_or_None, annotated_frames,
    points_table_str).
    """
    if not video_path:
        return "Please upload a video.", None, [], "[]"

    # Build messages with per-message video sampling kwargs
    video_kwargs_msg = {
        "num_frames": int(max_frames),
        "frame_sample_mode": frame_sample_mode,
    }
    if max_fps is not None and max_fps > 0:
        video_kwargs_msg["max_fps"] = int(max_fps)
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path, **video_kwargs_msg),
            ],
        }
    ]

    # Process vision info
    _, videos, video_kwargs = process_vision_info(messages)
    videos, video_metadatas = zip(*videos)
    videos, video_metadatas = list(videos), list(video_metadatas)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
        **video_kwargs,
    )
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}

    # Generate
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0,
            )
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(
        generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Extract points: [(object_id, timestamp, x, y), ...]
    points = model.extract_video_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["timestamps"],
        metadata["video_size"],
    )
    tracking = is_tracking_output(generated_text)
    annotated_video = None
    annotated_frames = []
    points_table = format_points_list(points, is_video=True)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        print(f"Extracted {len(points)} points. Tracking={tracking}")
        # Build annotated frames on sampled video frames; group points by
        # timestamp, then map each timestamp to its nearest sampled frame.
        if tracking:
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((object_id, x, y))
            group_by_frame = defaultdict(list)
            for ts, pts_with_ids in group_by_time.items():
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts_with_ids
            for ix, pts_with_ids in sorted(group_by_frame.items()):
                frame_img = draw_points_colored(videos[0][ix], pts_with_ids)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        else:
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((x, y))
            group_by_frame = defaultdict(list)
            for ts, pts in group_by_time.items():
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts
            for ix, pts in sorted(group_by_frame.items()):
                frame_img = draw_points(videos[0][ix], pts)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        # Annotated video with interpolation + trails
        annotated_video = create_annotated_video(video_path, points, metadata, tracking)
    return generated_text, annotated_video, annotated_frames, points_table


# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Read processor defaults for video settings
_default_frame_sample_mode = processor.video_processor.frame_sample_mode
_default_max_frames = processor.video_processor.num_frames

css = """
#col-container { margin: 0 auto; max-width: 960px; }
#main-title h1 {font-size: 2.3em !important;}
#input_image image { object-fit: contain !important; }
#input_video video { object-fit: contain !important; }
.gallery-item img { border: none !important; outline: none !important; }
"""

# FIX: custom CSS belongs to gr.Blocks(css=...); Blocks.launch() has no `css`
# parameter (it was previously passed to demo.launch and would be rejected).
with gr.Blocks(css=css) as demo:
    gr.Markdown("# **MolmoPoint-8B Demo**", elem_id="main-title")
    gr.Markdown(
        "Image & video pointing and tracking using the "
        "[MolmoPoint-8B](https://huggingface.co/allenai/MolmoPoint-8B) pointing model."
    )
    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.TabItem("Video Tracking", id="video_tracking_tab") as video_tracking_tab:
                    video_tracking = gr.Video(
                        label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT
                    )
                with gr.TabItem("Video Pointing", id="video_pointing_tab") as video_pointing_tab:
                    video_pointing = gr.Video(
                        label="Input Video", elem_id="input_video_pointing", height=MAX_VIDEO_HEIGHT
                    )
                with gr.TabItem("Image(s) Pointing", id="image_tab") as image_tab:
                    images_input = gr.Gallery(
                        label="Input Images",
                        elem_id="input_image",
                        type="filepath",
                        height=MAX_IMAGE_SIZE,
                    )
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
            with gr.Row(visible=True) as video_params_row:
                frame_sample_mode = gr.Dropdown(
                    choices=[_default_frame_sample_mode, "fps"],
                    value=_default_frame_sample_mode,
                    label="frame_sample_mode",
                )
                max_frames = gr.Number(value=_default_max_frames, label="max_frames")
                max_fps = gr.Number(value=MAX_FPS, label="max_fps")
            max_tok_slider = gr.Slider(
                label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS
            )
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[video_tracking, video_pointing, images_input, input_text],
                    value="Clear All",
                    scale=1,
                )
        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, time/index, x, y]])",
                        lines=15,
                    )
            output_warning = gr.HTML(visible=False)
            with gr.Tabs(visible=True) as video_output_tabs:
                with gr.TabItem("Annotated Video"):
                    output_video = gr.Video(label="Annotated Video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Annotated Frames"):
                    gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                    output_annotations = gr.Gallery(
                        label="Annotated Frames (Video)", height=MAX_IMAGE_SIZE
                    )
            with gr.Group(visible=False) as image_output_group:
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)

    # ── Examples ──
    with gr.Group(visible=True) as video_tracking_examples_group:
        gr.Markdown("### Video Tracking Examples")
        gr.Examples(
            examples=[
                ["example-videos/us_canada_hockey.mp4", "Track the U.S. hockey players."],
                ["example-videos/penguins.mp4", "Track all the penguins."],
                ["example-videos/arena_basketball.mp4", "Track the players in yellow uniform in 1 fps."],
            ],
            inputs=[video_tracking, input_text],
            label="Video Tracking Examples",
        )
    with gr.Group(visible=False) as video_pointing_examples_group:
        gr.Markdown("### Video Pointing Examples")
        gr.Examples(
            examples=[
                ["example-videos/sports.mp4", "Point to players in white/blue jersey"],
                ["example-videos/travel.mp4", "Point to standalone spiky towers"],
            ],
            inputs=[video_pointing, input_text],
            label="Video Pointing Examples",
        )
    with gr.Group(visible=False) as image_examples_group:
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/boat1.jpeg", "example-images/boat2.jpeg"], "Point to the boats."],
                [["example-images/messy1.jpg", "example-images/messy2.jpg", "example-images/messy3.jpg", "example-images/messy4.jpg"], "Point to the scissors."],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )

    # ── Tab switching: toggle visibility + track active tab ──
    active_tab = gr.State("video_tracking")

    def _select_video_tracking_tab():
        return (
            "video_tracking",
            gr.update(value=10),       # max_fps
            gr.update(visible=True),   # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),   # video_params_row
            gr.update(visible=True),   # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_video_pointing_tab():
        return (
            "video_pointing",
            gr.update(value=2),        # max_fps
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=True),   # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),   # video_params_row
            gr.update(visible=True),   # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_image_tab():
        return (
            "image",
            gr.update(),               # max_fps unchanged
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=True),   # image_examples_group
            gr.update(visible=False),  # video_params_row
            gr.update(visible=False),  # video_output_tabs
            gr.update(visible=True),   # image_output_group
        )

    tab_outputs = [
        active_tab,
        max_fps,
        video_tracking_examples_group,
        video_pointing_examples_group,
        image_examples_group,
        video_params_row,
        video_output_tabs,
        image_output_group,
    ]
    video_tracking_tab.select(fn=_select_video_tracking_tab, outputs=tab_outputs)
    video_pointing_tab.select(fn=_select_video_pointing_tab, outputs=tab_outputs)
    image_tab.select(fn=_select_image_tab, outputs=tab_outputs)

    _WARNING_STYLE = (
        'style="background:#fef2f2; border:1px solid #fca5a5; border-radius:6px; '
        'padding:8px 12px; color:#991b1b; font-size:14px;"'
    )

    # NOTE(review): this function's body was garbled in the received source (the
    # f-string HTML wrappers were dropped, leaving an unterminated string).  It
    # is reconstructed from the surviving fragments and `_WARNING_STYLE`:
    # pointing wants max_fps=2, tracking wants max_fps=10 — confirm against VCS.
    def _fps_warning(generated_text, current_max_fps):
        """Return gr.update for the warning HTML block."""
        tracking = is_tracking_output(generated_text)
        if not tracking and int(current_max_fps) != 2:
            html = f'<div {_WARNING_STYLE}>⚠️ For best video pointing results, set max_fps=2.</div>'
            return gr.update(value=html, visible=True)
        if tracking and int(current_max_fps) != 10:
            html = f'<div {_WARNING_STYLE}>⚠️ For best tracking results, set max_fps=10.</div>'
            return gr.update(value=html, visible=True)
        return gr.update(value="", visible=False)

    def dispatch_submit(tab, user_text, video_tracking_path, video_pointing_path, input_images, fsm, mf, mfps, max_tok):
        """Route Submit to the image or video pipeline based on the active tab.

        Returns values for (output_text, output_points, output_warning,
        output_video, output_annotations, output_annotations_img).
        """
        if tab == "image":
            text_out, img_gallery, pts = process_images(user_text, input_images, max_tok)
            return text_out, pts, gr.update(value="", visible=False), None, [], img_gallery
        else:
            video_path = video_tracking_path if tab == "video_tracking" else video_pointing_path
            text_out, ann_video, ann_frames, pts = process_video(
                user_text, video_path, fsm, mf, mfps, max_tok,
            )
            warning = _fps_warning(text_out, mfps)
            return text_out, pts, warning, ann_video, ann_frames, []

    submit_button.click(
        fn=dispatch_submit,
        inputs=[active_tab, input_text, video_tracking, video_pointing, images_input, frame_sample_mode, max_frames, max_fps, max_tok_slider],
        outputs=[output_text, output_points, output_warning, output_video, output_annotations, output_annotations_img],
    )

if __name__ == "__main__":
    # css moved into gr.Blocks(css=css) above; launch() takes no css parameter.
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True, share=True)