solrz commited on
Commit
82e2bb0
·
1 Parent(s): b9c8e7d

Add video pointmap sequence endpoint

Browse files
Files changed (1) hide show
  1. app.py +213 -32
app.py CHANGED
@@ -13,8 +13,10 @@ import sys
13
  import os
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
 
 
16
  import tempfile
17
  import time as _t
 
18
 
19
  import cv2
20
  import gradio as gr
@@ -59,13 +61,17 @@ POINTMAP_MODELS = {
59
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
60
  },
61
  }
62
- DEFAULT_SIZE = "1B"
63
 
64
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
65
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
66
 
67
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
68
  MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
 
 
 
 
69
 
70
  _fg_transform = transforms.Compose([
71
  transforms.Resize((1024, 768)),
@@ -99,11 +105,7 @@ def _get_fg_model():
99
  return _fg_model
100
 
101
 
102
- print("[startup] pre-loading all pointmap sizes + fg/bg ...")
103
- for _size in POINTMAP_MODELS:
104
- _get_pointmap_model(_size)
105
- _get_fg_model()
106
- print("[startup] ready.")
107
 
108
 
109
  # -----------------------------------------------------------------------------
@@ -255,6 +257,129 @@ def _make_glb(image_pil_texture: Image.Image, pointmap_hwc: np.ndarray,
255
  return out_path
256
 
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # -----------------------------------------------------------------------------
259
  # Gradio handler
260
 
@@ -295,6 +420,27 @@ def predict(image: Image.Image, size: str):
295
  return depth_pil, glb_path
296
 
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  # -----------------------------------------------------------------------------
299
  # UI
300
 
@@ -358,33 +504,68 @@ HEADER_HTML = """
358
  with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
359
  gr.HTML(HEADER_HTML)
360
 
361
- # Row 1: input ↔ 3D mesh, equal height
362
- with gr.Row(equal_height=True):
363
- inp = gr.Image(label="Input", type="pil", height=640, scale=2)
364
- out_glb = gr.Model3D(
365
- label="Pointmap",
366
- height=640,
367
- clear_color=[0.97, 0.97, 0.97, 1.0], # cinematic studio white
368
- camera_position=(35, 70, 1.6), # closer, since scene is centered on the human
369
- zoom_speed=0.7,
370
- pan_speed=0.5,
371
- scale=3,
372
- )
373
-
374
- # Row 2: controls (with examples below them) on the left | depth heatmap on the right.
375
- with gr.Row():
376
- with gr.Column(scale=2, min_width=320):
377
- size = gr.Radio(
378
- choices=list(POINTMAP_MODELS.keys()),
379
- value=DEFAULT_SIZE,
380
- label="Model",
381
- container=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  )
383
- run = gr.Button("Run", variant="primary")
384
- gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
385
- out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)
386
-
387
- run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])
388
 
389
 
390
  if __name__ == "__main__":
 
13
  import os
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
 
16
+ import json
17
  import tempfile
18
  import time as _t
19
+ import zipfile
20
 
21
  import cv2
22
  import gradio as gr
 
61
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
62
  },
63
  }
64
+ DEFAULT_SIZE = "0.4B"
65
 
66
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
67
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
68
 
69
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
70
  MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
71
+ VIDEO_MAX_HEIGHT = 512
72
+ VIDEO_DEFAULT_FRAMES = 36
73
+ VIDEO_DEFAULT_STRIDE = 5
74
+ VIDEO_MAX_POINTS = 45_000
75
 
76
  _fg_transform = transforms.Compose([
77
  transforms.Resize((1024, 768)),
 
105
  return _fg_model
106
 
107
 
108
+ print("[startup] ready; models will load lazily on first request.")
 
 
 
 
109
 
110
 
111
  # -----------------------------------------------------------------------------
 
257
  return out_path
258
 
259
 
260
+ def _sample_video_frames(video_path: str, max_frames: int) -> tuple[list[Image.Image], float, int, int]:
261
+ cap = cv2.VideoCapture(video_path)
262
+ if not cap.isOpened():
263
+ raise gr.Error("Could not open the uploaded video.")
264
+
265
+ fps = float(cap.get(cv2.CAP_PROP_FPS) or 0) or 24.0
266
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
267
+ frame_limit = int(max(1, min(max_frames, 120)))
268
+ if total_frames > 0:
269
+ indices = np.linspace(0, max(total_frames - 1, 0), min(frame_limit, total_frames), dtype=np.int32)
270
+ else:
271
+ indices = np.arange(frame_limit, dtype=np.int32)
272
+
273
+ frames: list[Image.Image] = []
274
+ source_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
275
+ source_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
276
+ for index in indices:
277
+ if total_frames > 0:
278
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(index))
279
+ ok, frame_bgr = cap.read()
280
+ if not ok:
281
+ if total_frames <= 0:
282
+ break
283
+ continue
284
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
285
+ frames.append(Image.fromarray(frame_rgb))
286
+ if total_frames <= 0 and len(frames) >= frame_limit:
287
+ break
288
+
289
+ cap.release()
290
+ if not frames:
291
+ raise gr.Error("No readable frames were found in the uploaded video.")
292
+ return frames, fps, source_w, source_h
293
+
294
+
295
+ def _select_point_coords(mask: np.ndarray, stride: int, max_points: int) -> tuple[np.ndarray, np.ndarray]:
296
+ stride = int(max(1, min(stride, 16)))
297
+ grid = np.zeros_like(mask, dtype=bool)
298
+ grid[::stride, ::stride] = True
299
+ selected = mask & grid
300
+ yy, xx = np.where(selected)
301
+ if len(yy) < 128:
302
+ yy, xx = np.where(grid)
303
+ if len(yy) > max_points:
304
+ keep = np.linspace(0, len(yy) - 1, max_points, dtype=np.int64)
305
+ yy = yy[keep]
306
+ xx = xx[keep]
307
+ return yy.astype(np.int32), xx.astype(np.int32)
308
+
309
+
310
+ def _point_sequence_zip(
311
+ frames: list[Image.Image],
312
+ size: str,
313
+ max_frames: int,
314
+ point_stride: int,
315
+ fps: float,
316
+ source_w: int,
317
+ source_h: int,
318
+ ) -> str:
319
+ model = _get_pointmap_model(size)
320
+ sampled = frames[: int(max(1, min(max_frames, len(frames))))]
321
+ positions_frames: list[np.ndarray] = []
322
+ colors_frames: list[np.ndarray] = []
323
+ sample_y: np.ndarray | None = None
324
+ sample_x: np.ndarray | None = None
325
+ native_w = 0
326
+ native_h = 0
327
+
328
+ for frame_index, frame in enumerate(sampled):
329
+ t = _t.perf_counter()
330
+ image_pil = _cap_height(frame.convert("RGB"), VIDEO_MAX_HEIGHT)
331
+ image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
332
+ pointmap = _estimate_pointmap(image_bgr, model)
333
+ native_h, native_w = pointmap.shape[:2]
334
+
335
+ if sample_y is None or sample_x is None:
336
+ mask = _foreground_mask(image_pil, native_h, native_w)
337
+ sample_y, sample_x = _select_point_coords(mask, point_stride, VIDEO_MAX_POINTS)
338
+ print(f"[video] selected {len(sample_y)} points at native {native_w}x{native_h}")
339
+
340
+ image_native = np.array(image_pil.resize((native_w, native_h), Image.LANCZOS))
341
+ points = pointmap[sample_y, sample_x].astype(np.float32)
342
+ finite = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.05) & (points[:, 2] < 25.0)
343
+ if finite.any():
344
+ centroid = points[finite].mean(axis=0).astype(np.float32)
345
+ else:
346
+ centroid = np.zeros(3, dtype=np.float32)
347
+ points = (points - centroid) * np.array([1.0, -1.0, -1.0], dtype=np.float32)
348
+ points[~finite] = 0
349
+ colors = image_native[sample_y, sample_x, :3].astype(np.uint8)
350
+ colors[~finite] = 0
351
+
352
+ positions_frames.append(points)
353
+ colors_frames.append(colors)
354
+ print(f"[video] frame {frame_index + 1}/{len(sampled)} {(_t.perf_counter() - t) * 1000:.0f} ms")
355
+
356
+ if not positions_frames:
357
+ raise gr.Error("Pointmap inference did not produce any frames.")
358
+
359
+ positions = np.stack(positions_frames, axis=0).astype(np.float32)
360
+ colors = np.stack(colors_frames, axis=0).astype(np.uint8)
361
+ metadata = {
362
+ "format": "fpbox-sapiens-pointmap-sequence-v1",
363
+ "model": f"sapiens2-pointmap-{size}",
364
+ "frameCount": int(positions.shape[0]),
365
+ "fps": float(min(fps, max(1, positions.shape[0]))),
366
+ "pointCount": int(positions.shape[1]),
367
+ "width": int(native_w),
368
+ "height": int(native_h),
369
+ "sourceWidth": int(source_w),
370
+ "sourceHeight": int(source_h),
371
+ "coordinateSystem": "x, -y, -z, centered per frame",
372
+ "dtype": {"positions": "float32", "colors": "uint8"},
373
+ }
374
+
375
+ out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
376
+ with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
377
+ zf.writestr("metadata.json", json.dumps(metadata, indent=2))
378
+ zf.writestr("positions_f32.bin", positions.tobytes(order="C"))
379
+ zf.writestr("colors_u8.bin", colors.tobytes(order="C"))
380
+ return out_path
381
+
382
+
383
  # -----------------------------------------------------------------------------
384
  # Gradio handler
385
 
 
420
  return depth_pil, glb_path
421
 
422
 
423
+ @spaces.GPU(duration=300)
424
+ def predict_video(video_path: str, size: str, max_frames: int, point_stride: int):
425
+ if video_path is None:
426
+ return None
427
+
428
+ t0 = _t.perf_counter()
429
+ frames, fps, source_w, source_h = _sample_video_frames(video_path, max_frames)
430
+ print(f"[video] sampled {len(frames)} frames from {source_w}x{source_h} video at {fps:.2f} fps")
431
+ zip_path = _point_sequence_zip(
432
+ frames=frames,
433
+ size=size,
434
+ max_frames=max_frames,
435
+ point_stride=point_stride,
436
+ fps=fps,
437
+ source_w=source_w,
438
+ source_h=source_h,
439
+ )
440
+ print(f"[video] TOTAL {(_t.perf_counter() - t0) * 1000:.0f} ms")
441
+ return zip_path
442
+
443
+
444
  # -----------------------------------------------------------------------------
445
  # UI
446
 
 
504
  with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
505
  gr.HTML(HEADER_HTML)
506
 
507
+ with gr.Tabs():
508
+ with gr.Tab("Image"):
509
+ # Row 1: input ↔ 3D mesh, equal height
510
+ with gr.Row(equal_height=True):
511
+ inp = gr.Image(label="Input", type="pil", height=640, scale=2)
512
+ out_glb = gr.Model3D(
513
+ label="Pointmap",
514
+ height=640,
515
+ clear_color=[0.97, 0.97, 0.97, 1.0], # cinematic studio white
516
+ camera_position=(35, 70, 1.6), # closer, since scene is centered on the human
517
+ zoom_speed=0.7,
518
+ pan_speed=0.5,
519
+ scale=3,
520
+ )
521
+
522
+ # Row 2: controls (with examples below them) on the left | depth heatmap on the right.
523
+ with gr.Row():
524
+ with gr.Column(scale=2, min_width=320):
525
+ size = gr.Radio(
526
+ choices=list(POINTMAP_MODELS.keys()),
527
+ value=DEFAULT_SIZE,
528
+ label="Model",
529
+ container=False,
530
+ )
531
+ run = gr.Button("Run", variant="primary")
532
+ gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
533
+ out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)
534
+
535
+ run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])
536
+
537
+ with gr.Tab("Video"):
538
+ with gr.Row():
539
+ video_inp = gr.Video(label="Input Video", height=420)
540
+ sequence_zip = gr.File(label="Pointmap Sequence (.zip)")
541
+ with gr.Row():
542
+ video_size = gr.Radio(
543
+ choices=list(POINTMAP_MODELS.keys()),
544
+ value=DEFAULT_SIZE,
545
+ label="Model",
546
+ container=False,
547
+ )
548
+ video_frames = gr.Slider(
549
+ minimum=1,
550
+ maximum=120,
551
+ step=1,
552
+ value=VIDEO_DEFAULT_FRAMES,
553
+ label="Sampled Frames",
554
+ )
555
+ video_stride = gr.Slider(
556
+ minimum=1,
557
+ maximum=16,
558
+ step=1,
559
+ value=VIDEO_DEFAULT_STRIDE,
560
+ label="Point Stride",
561
+ )
562
+ run_video = gr.Button("Run Video Pointmap", variant="primary")
563
+ run_video.click(
564
+ predict_video,
565
+ inputs=[video_inp, video_size, video_frames, video_stride],
566
+ outputs=[sequence_zip],
567
+ api_name="video_pointmap",
568
  )
 
 
 
 
 
569
 
570
 
571
  if __name__ == "__main__":