jamepark3922 commited on
Commit
d0f93f4
·
0 Parent(s):

Initial Molmo-Point HF Spaces app

Browse files
.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
3
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
4
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Molmo-Point Demo
3
+ emoji: 👆
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 6.3.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Molmo-Point - Image & Video Pointing & Tracking
12
+ ---
app.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+ import os
4
+ import tempfile
5
+ from collections import defaultdict
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import PIL
10
+ import torch
11
+ from PIL import Image, ImageDraw, ImageFile
12
+ from transformers import AutoModelForImageTextToText, AutoProcessor
13
+
14
+ import gradio as gr
15
+ import spaces
16
+ from molmo_utils import process_vision_info
17
+
18
+ from typing import Iterable
19
+ from gradio.themes import Soft
20
+ from gradio.themes.utils import colors, fonts, sizes
21
+
22
Image.MAX_IMAGE_PIXELS = None  # disable Pillow's decompression-bomb guard for large uploads
ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate partially-written/truncated image files

# ── Constants ──────────────────────────────────────────────────────────────────

MODEL_ID = "allenai/MolmoPoint-8B"  # HF Hub id of the pointing model
MAX_IMAGE_SIZE = 512  # display height (px) for image galleries in the UI
MAX_VIDEO_HEIGHT = 512  # display height (px) for video widgets in the UI
POINT_SIZE = 0.01  # point radius as a fraction of the larger image dimension
KEYFRAME_HOLD_FRAMES = 3  # frames around each keyframe rendered as a solid dot
SHOW_TRAILS = True  # draw fading motion trails in annotated videos
MAX_NEW_TOKENS = 2048  # default generation token budget
MAX_FPS = 10  # default max_fps offered in the video-sampling controls

# Point colors; tracks are colored by index (object_id - 1) % len(COLORS).
COLORS = [
    "rgb(255, 100, 180)",
    "rgb(100, 180, 255)",
    "rgb(180, 255, 100)",
    "rgb(255, 180, 100)",
    "rgb(100, 255, 180)",
    "rgb(180, 100, 255)",
    "rgb(255, 255, 100)",
    "rgb(100, 255, 255)",
    "rgb(255, 120, 120)",
    "rgb(120, 255, 255)",
    "rgb(255, 255, 120)",
    "rgb(255, 120, 255)",
]
50
+
51
+ # ── Model loading ──────────────────────────────────────────────────────────────
52
+
53
+ print(f"Loading {MODEL_ID}...")
54
+ processor = AutoProcessor.from_pretrained(
55
+ MODEL_ID,
56
+ trust_remote_code=True,
57
+ padding_side="left",
58
+ )
59
+
60
+ model = AutoModelForImageTextToText.from_pretrained(
61
+ MODEL_ID,
62
+ trust_remote_code=True,
63
+ dtype="bfloat16",
64
+ device_map="auto",
65
+ )
66
+ print("Model loaded successfully.")
67
+
68
+ # ── Helper functions ───────────────────────────────────────────────────────────
69
+
70
+
71
+ def _parse_rgb(color_str):
72
+ """Parse 'rgb(r, g, b)' to (r, g, b) tuple."""
73
+ nums = color_str.replace("rgb(", "").replace(")", "").split(",")
74
+ return tuple(int(n.strip()) for n in nums)
75
+
76
+
77
# Pre-computed BGR tuples (OpenCV channel order). Parse each color string once
# and reverse it, instead of calling _parse_rgb three times per color.
COLORS_BGR = [_parse_rgb(c)[::-1] for c in COLORS]
78
+
79
+
80
def is_tracking_output(generated_text: str) -> bool:
    """Return True when the model output opens with a ``<tracks`` tag."""
    stripped = generated_text.strip()
    return stripped[:7] == "<tracks"
83
+
84
+
85
def cast_float_bf16(t: torch.Tensor):
    """Downcast floating-point tensors to bfloat16; return other tensors unchanged."""
    if not torch.is_floating_point(t):
        return t
    return t.to(torch.bfloat16)
89
+
90
+
91
def draw_points(image, points):
    """Draw each (x, y) point on *image* as a filled dot in the first palette color.

    Args:
        image: np.ndarray or PIL.Image; never mutated (a copy is annotated).
        points: iterable of (x, y) pixel coordinates.

    Returns:
        Annotated PIL.Image.
    """
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    w, h = annotation.size
    # Dot radius scales with the larger image dimension, with a 5px floor.
    size = max(5, int(max(w, h) * POINT_SIZE))
    # Single color for plain pointing (no per-instance ids); hoisted out of the
    # loop — the original recomputed it per point and left an unused enumerate index.
    color = COLORS[0]
    for x, y in points:
        draw.ellipse((x - size, y - size, x + size, y + size), fill=color, outline=None)
    return annotation
103
+
104
+
105
def draw_points_colored(image, points_with_ids):
    """Draw (object_id, x, y) points, cycling colors by instance id (tracking view)."""
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    width, height = annotation.size
    # Dot radius scales with the larger image dimension, with a 5px floor.
    radius = max(5, int(max(width, height) * POINT_SIZE))
    for object_id, x, y in points_with_ids:
        fill = COLORS[(object_id - 1) % len(COLORS)]
        bbox = (x - radius, y - radius, x + radius, y + radius)
        draw.ellipse(bbox, fill=fill, outline=None)
    return annotation
118
+
119
+
120
def format_points_list(points, is_video=False):
    """Render extracted points as a string of nested lists.

    Video rows are [id, timestamp, x, y] (timestamp to 2 decimals); image rows
    are [id, image_index, x, y]. Coordinates are rendered to 1 decimal.
    """
    if not points:
        return "[]"
    if is_video:
        rows = [
            f"[{int(object_id)}, {float(ts):.2f}, {float(x):.1f}, {float(y):.1f}]"
            for object_id, ts, x, y in points
        ]
    else:
        rows = [
            f"[{int(object_id)}, {int(ix)}, {float(x):.1f}, {float(y):.1f}]"
            for object_id, ix, x, y in points
        ]
    return "[" + ", ".join(rows) + "]"
132
+
133
+
134
+ def _interpolate_keyframes(keyframes, total_frames):
135
+ """Linearly interpolate positions between keyframes.
136
+
137
+ keyframes: sorted list of (frame_idx, x, y)
138
+ Returns dict {frame_idx: (x, y)} for every frame from first to last keyframe.
139
+ """
140
+ if not keyframes:
141
+ return {}
142
+ positions = {}
143
+ for i in range(len(keyframes)):
144
+ f_idx, x, y = keyframes[i]
145
+ positions[f_idx] = (x, y)
146
+ if i + 1 < len(keyframes):
147
+ nf, nx, ny = keyframes[i + 1]
148
+ span = nf - f_idx
149
+ if span > 1:
150
+ for t in range(1, span):
151
+ alpha = t / span
152
+ positions[f_idx + t] = (x + alpha * (nx - x), y + alpha * (ny - y))
153
+ return positions
154
+
155
+
156
def create_annotated_video(video_path, points, metadata, tracking):
    """Draw points on the original video with interpolation and fading trails.

    Args:
        video_path: path to the source video.
        points: [(object_id, timestamp_seconds, x, y), ...] with coordinates in
            the processed frame space (metadata["video_size"]).
        metadata: processor pointing metadata; only "video_size" is read here.
        tracking: True -> one color per object id; False -> single color.

    Returns:
        Path to the annotated .mp4 written to a fresh temp file.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Scale point coordinates from processed-frame space back to the source resolution.
    proc_w, proc_h = metadata["video_size"]
    scale_x = vid_w / proc_w
    scale_y = vid_h / proc_h

    # Build per-object keyframes: {obj_id: [(frame_idx, x, y), ...]}
    obj_keyframes = defaultdict(list)
    for object_id, ts, x, y in points:
        f_idx = int(round(float(ts) * fps))
        sx, sy = float(x) * scale_x, float(y) * scale_y
        obj_keyframes[int(object_id)].append((f_idx, sx, sy))

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    obj_positions = {}
    obj_keyframe_set = {}
    for obj_id, kfs in obj_keyframes.items():
        kfs.sort(key=lambda k: k[0])
        obj_positions[obj_id] = _interpolate_keyframes(kfs, total_frames)
        # Render the "solid dot" style for a few frames around each keyframe.
        raw_kf = set(f_idx for f_idx, _, _ in kfs)
        obj_keyframe_set[obj_id] = set(
            f for kf in raw_kf for f in range(kf - KEYFRAME_HOLD_FRAMES, kf + KEYFRAME_HOLD_FRAMES + 1)
        )

    # tempfile.mktemp is deprecated and race-prone — create the file atomically
    # with mkstemp and close the descriptor so VideoWriter can reopen the path.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (vid_w, vid_h))

    radius = max(5, int(max(vid_w, vid_h) * POINT_SIZE))
    trail_length = int(fps * 2)  # trail spans roughly 2 seconds of motion
    obj_history = defaultdict(list)

    try:
        current_frame = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            for obj_id, positions in obj_positions.items():
                if current_frame in positions:
                    px, py = positions[current_frame]
                    obj_history[obj_id].append((px, py))
                    if len(obj_history[obj_id]) > trail_length:
                        obj_history[obj_id] = obj_history[obj_id][-trail_length:]

                    if tracking:
                        color = COLORS_BGR[(obj_id - 1) % len(COLORS_BGR)]
                    else:
                        color = COLORS_BGR[0]

                    # Draw fading trail (older segments darker and thinner)
                    trail = obj_history[obj_id]
                    n_trail = len(trail)
                    if SHOW_TRAILS and n_trail >= 2:
                        for i in range(n_trail - 1):
                            alpha = (i + 1) / n_trail
                            trail_color = tuple(int(c * alpha) for c in color)
                            thickness = max(1, int(radius * 0.6 * alpha))
                            pt1 = (int(trail[i][0]), int(trail[i][1]))
                            pt2 = (int(trail[i + 1][0]), int(trail[i + 1][1]))
                            cv2.line(frame, pt1, pt2, trail_color, thickness)

                    # Solid on keyframes, outline-only on interpolated frames
                    if current_frame in obj_keyframe_set[obj_id]:
                        cv2.circle(frame, (int(px), int(py)), radius, color, -1)
                        cv2.circle(frame, (int(px), int(py)), radius + 2, (255, 255, 255), 2)
                    else:
                        cv2.circle(frame, (int(px), int(py)), radius, color, 2)

            out.write(frame)
            current_frame += 1
    finally:
        # Always release codec handles, even if drawing raises mid-stream.
        cap.release()
        out.release()
    return out_path
240
+
241
+
242
+ # ── Inference functions ────────────────────────────────────────────────────────
243
+
244
+
245
@spaces.GPU
def process_images(user_text, input_images, max_tokens):
    """Run image pointing on one or more uploaded images.

    Args:
        user_text: the user prompt (e.g. "Point to the boats.").
        input_images: gallery items — file paths, or (path, caption) tuples.
        max_tokens: generation budget; cast to int before use.

    Returns:
        (generated_text, gallery_images, points_table) where gallery_images
        are annotated copies when points were extracted, else the raw inputs,
        and points_table is the string produced by format_points_list.
    """
    if not input_images:
        return "Please upload at least one image.", [], "[]"

    # gr.Gallery items may arrive as (path, caption) tuples — keep only the path.
    pil_images = []
    for img_path in input_images:
        if isinstance(img_path, tuple):
            img_path = img_path[0]
        pil_images.append(Image.open(img_path).convert("RGB"))

    # Build a single user message: prompt text followed by every image.
    content = [dict(type="text", text=user_text)]
    for img in pil_images:
        content.append(dict(type="image", image=img))
    messages = [{"role": "user", "content": content}]

    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")

    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,  # metadata consumed by extract_image_points below
    )
    # `metadata` is not a model input — pop it before moving tensors to device.
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}

    # Generate under bf16 autocast; temperature=0 is passed explicitly.
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0
            )

    # Decode only the newly generated suffix (strip the prompt tokens).
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Map generated point tokens back to per-image pixel coordinates.
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )

    points_table = format_points_list(points, is_video=False)

    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)

    if points:
        # Group points by image index and annotate each referenced image.
        group_by_index = defaultdict(list)
        for object_id, ix, x, y in points:
            group_by_index[ix].append((x, y))
        annotated = []
        for ix, pts in group_by_index.items():
            annotated.append(draw_points(images[ix], pts))
        return generated_text, annotated, points_table

    # No points extracted — return the original images unannotated.
    return generated_text, pil_images, points_table
313
+
314
+
315
@spaces.GPU
def process_video(user_text, video_path, frame_sample_mode, max_frames, max_fps, max_tokens):
    """Run video pointing/tracking on an uploaded video.

    Args:
        user_text: the user prompt (e.g. "Track all the penguins.").
        video_path: path to the uploaded video file.
        frame_sample_mode: processor frame-sampling strategy (e.g. "fps").
        max_frames: maximum number of frames to sample; cast to int.
        max_fps: cap on sampling rate; ignored when None or <= 0.
        max_tokens: generation budget; cast to int before use.

    Returns:
        (generated_text, annotated_video_path_or_None, annotated_frames,
        points_table) — annotated_frames is a list of (PIL image, caption).
    """
    if not video_path:
        return "Please upload a video.", None, [], "[]"

    # Per-message video sampling options forwarded via the message content.
    video_kwargs_msg = {
        "num_frames": int(max_frames),
        "frame_sample_mode": frame_sample_mode,
    }
    if max_fps is not None and max_fps > 0:
        video_kwargs_msg["max_fps"] = int(max_fps)

    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path, **video_kwargs_msg),
            ],
        }
    ]

    # Process vision info; each video entry is (frames, metadata) — split them.
    _, videos, video_kwargs = process_vision_info(messages)
    videos, video_metadatas = zip(*videos)
    videos, video_metadatas = list(videos), list(video_metadatas)

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")

    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,  # metadata consumed by extract_video_points below
        **video_kwargs,
    )
    # `metadata` is not a model input — pop it before moving tensors to device.
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}

    # Generate under bf16 autocast.
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
            )

    # Decode only the newly generated suffix (strip the prompt tokens).
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Map generated point tokens to (object_id, timestamp, x, y) tuples.
    points = model.extract_video_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["timestamps"],
        metadata["video_size"],
    )

    tracking = is_tracking_output(generated_text)
    annotated_video = None
    annotated_frames = []
    points_table = format_points_list(points, is_video=True)

    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)

    if points:
        print(f"Extracted {len(points)} points. Tracking={tracking}")

        # Build annotated frames on sampled video frames.
        # NOTE(review): the np.abs(...) subtraction below assumes
        # metadata["timestamps"] is a numpy array — confirm with the processor.
        if tracking:
            # Tracking: keep object ids so each track gets a stable color.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((object_id, x, y))
            group_by_frame = defaultdict(list)
            for ts, pts_with_ids in group_by_time.items():
                # Snap each timestamp to the nearest sampled frame index.
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts_with_ids
            for ix, pts_with_ids in sorted(group_by_frame.items()):
                frame_img = draw_points_colored(videos[0][ix], pts_with_ids)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        else:
            # Plain pointing: one color, object ids not needed.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((x, y))
            group_by_frame = defaultdict(list)
            for ts, pts in group_by_time.items():
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts
            for ix, pts in sorted(group_by_frame.items()):
                frame_img = draw_points(videos[0][ix], pts)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))

        # Annotated video with interpolation + trails
        annotated_video = create_annotated_video(video_path, points, metadata, tracking)

    return generated_text, annotated_video, annotated_frames, points_table
420
+
421
+
422
+ # ── Gradio UI ──────────────────────────────────────────────────────────────────
423
+
424
# Read processor defaults for video settings so the UI controls start from the
# model's own configuration.
_default_frame_sample_mode = processor.video_processor.frame_sample_mode
_default_max_frames = processor.video_processor.num_frames

# Custom CSS for the demo layout.
# NOTE(review): this is passed to demo.launch(css=...) at the bottom of the
# file; Gradio documents custom CSS on gr.Blocks(css=...) — verify the pinned
# Gradio version accepts it on launch().
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
#main-title h1 {font-size: 2.3em !important;}
#input_image image {
    object-fit: contain !important;
}
#input_video video {
    object-fit: contain !important;
}
.gallery-item img {
    border: none !important;
    outline: none !important;
}
"""
445
+
446
with gr.Blocks() as demo:
    gr.Markdown("# **Molmo-Point Demo**", elem_id="main-title")
    gr.Markdown(
        "Image & video pointing and tracking using the "
        "[MolmoPoint-8B](https://huggingface.co/allenai/MolmoPoint-8B) pointing model."
    )

    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.TabItem("Video Pointing & Tracking", id="video_tab") as video_tab:
                    video = gr.Video(label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Image(s) Pointing", id="image_tab") as image_tab:
                    images_input = gr.Gallery(
                        label="Input Images", elem_id="input_image", type="filepath", height=MAX_IMAGE_SIZE,
                    )

            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")

            # Video-only sampling controls; hidden when the image tab is active.
            with gr.Row(visible=True) as video_params_row:
                frame_sample_mode = gr.Dropdown(choices=[_default_frame_sample_mode, "fps"], value=_default_frame_sample_mode, label="frame_sample_mode")
                max_frames = gr.Number(value=_default_max_frames, label="max_frames")
                max_fps = gr.Number(value=MAX_FPS, label="max_fps")
                max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS)

            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[video, images_input, input_text], value="Clear All", scale=1,
                )

        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, time/index, x, y]])", lines=15,
                    )

            # Video outputs (annotated video + per-frame gallery).
            with gr.Tabs(visible=True) as video_output_tabs:
                with gr.TabItem("Annotated Video"):
                    output_video = gr.Video(label="Annotated Video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Annotated Frames"):
                    gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                    output_annotations = gr.Gallery(label="Annotated Frames (Video)", height=MAX_IMAGE_SIZE)

            # Image outputs; hidden until the image tab is selected.
            with gr.Group(visible=False) as image_output_group:
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)

    # ── Examples ──
    with gr.Group(visible=True) as video_examples_group:
        gr.Markdown("### Video Examples")
        gr.Examples(
            examples=[
                ["example-videos/penguins.mp4", "Track all the penguins."],
                ["example-videos/arena_basketball.mp4", "Track the players in yellow uniform in 1 fps."],
            ],
            inputs=[video, input_text],
            label="Video Pointing & Tracking Examples",
        )

    with gr.Group(visible=False) as image_examples_group:
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/boat1.jpeg", "example-images/boat2.jpeg"], "Point to the boats."],
                [["example-images/messy1.jpg", "example-images/messy2.jpg", "example-images/messy3.jpg", "example-images/messy4.jpg"], "Point to the scissors."],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )

    # ── Tab switching: toggle visibility + track active tab ──
    # active_tab drives dispatch_submit's routing between image and video paths.
    active_tab = gr.State("video")

    def _select_video_tab():
        """Show video-specific controls/outputs; hide image ones."""
        return (
            "video",
            gr.update(visible=True),  # video_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),  # video_params_row
            gr.update(visible=True),  # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_image_tab():
        """Show image-specific controls/outputs; hide video ones."""
        return (
            "image",
            gr.update(visible=False),  # video_examples_group
            gr.update(visible=True),  # image_examples_group
            gr.update(visible=False),  # video_params_row
            gr.update(visible=False),  # video_output_tabs
            gr.update(visible=True),  # image_output_group
        )

    tab_outputs = [active_tab, video_examples_group, image_examples_group, video_params_row, video_output_tabs, image_output_group]
    video_tab.select(fn=_select_video_tab, outputs=tab_outputs)
    image_tab.select(fn=_select_image_tab, outputs=tab_outputs)

    def _show_fps_tip(generated_text, current_max_fps):
        """Show a toast notification if max_fps doesn't match the detected task type."""
        tracking = "<tracks" in generated_text
        pointing = "<point" in generated_text
        if pointing and int(current_max_fps) != 2:
            gr.Info("Tip: For best video pointing results, set max_fps=2.")
        elif tracking and int(current_max_fps) != 10:
            gr.Info("Tip: For best tracking results, set max_fps=10.")

    def dispatch_submit(tab, user_text, video_path, input_images,
                        fsm, mf, mfps, max_tok):
        """Route Submit to the image or video pipeline based on the active tab.

        Returns values for (output_text, output_points, output_video,
        output_annotations, output_annotations_img); the unused side of the
        UI receives None / [] so it clears.
        """
        if tab == "image":
            text_out, img_gallery, pts = process_images(user_text, input_images, max_tok)
            return text_out, pts, None, [], img_gallery
        else:
            text_out, ann_video, ann_frames, pts = process_video(
                user_text, video_path, fsm, mf, mfps, max_tok,
            )
            _show_fps_tip(text_out, mfps)
            return text_out, pts, ann_video, ann_frames, []

    submit_button.click(
        fn=dispatch_submit,
        inputs=[active_tab, input_text, video, images_input,
                frame_sample_mode, max_frames, max_fps, max_tok_slider],
        outputs=[output_text, output_points, output_video, output_annotations, output_annotations_img],
    )
576
+
577
+ if __name__ == "__main__":
578
+ demo.launch(css=css, ssr_mode=False, show_error=True, share=True)
example-images/boat1.jpeg ADDED

Git LFS Details

  • SHA256: 1652fe0880f870989ac390c704ece359acef36ecde5b423fcc32f9181e7c374f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
example-images/boat2.jpeg ADDED

Git LFS Details

  • SHA256: 2a974235aed23bef201d15f32de63c25098c020d61e3625663fb515af8acbe3c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.12 MB
example-images/messy1.jpg ADDED

Git LFS Details

  • SHA256: 0810f7ed899a9a90e49923241cf577f3712eac9e8e5d52360d54a9a0d11b079b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
example-images/messy2.jpg ADDED

Git LFS Details

  • SHA256: 12b2e240b935d23644e311b7249410afb8025a842d38d06ebee014195d6da6a9
  • Pointer size: 131 Bytes
  • Size of remote file: 255 kB
example-images/messy3.jpg ADDED

Git LFS Details

  • SHA256: 7ef063cf698a947efd4224f4d95b36e23c7605a27b6a21bbd0d496cbd95afbfc
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
example-images/messy4.jpg ADDED

Git LFS Details

  • SHA256: b3776aadcc39769a233bf185797b4984aa0fffdcef087fec32e4e1fb75712dec
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
example-videos/arena_basketball.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a965ceced2053d1e456b2ce4e4a3fc87a64e4520af7743e91885a2ae11dc237
3
+ size 12297652
example-videos/backflip.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10ac0f73fc374bd6ebb63f3d8d145bb11ef1c713b71f433e31c98b1b0f536018
3
+ size 11171759
example-videos/penguins.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:856bfacc3de618a5154fc6dd0240ad375a8f76faa486070a996007f17f9d3624
3
+ size 1689459
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip>=23.0.0
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@v4.57.1
2
+ git+https://github.com/huggingface/accelerate.git
3
+ torch==2.8.0
4
+ torchvision
5
+ pillow
6
+ einops
7
+ decord2
8
+ molmo_utils
9
+ opencv-python
10
+ numpy
11
+ gradio
12
+ spaces
13
+ kernels
14
+ hf_xet