solrz committed on
Commit
da43251
·
1 Parent(s): 82e2bb0

Emit meshable video pointmap grids

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -70,8 +70,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
70
  MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
71
  VIDEO_MAX_HEIGHT = 512
72
  VIDEO_DEFAULT_FRAMES = 36
73
- VIDEO_DEFAULT_STRIDE = 5
74
- VIDEO_MAX_POINTS = 45_000
75
 
76
  _fg_transform = transforms.Compose([
77
  transforms.Resize((1024, 768)),
@@ -292,19 +291,19 @@ def _sample_video_frames(video_path: str, max_frames: int) -> tuple[list[Image.I
292
  return frames, fps, source_w, source_h
293
 
294
 
295
- def _select_point_coords(mask: np.ndarray, stride: int, max_points: int) -> tuple[np.ndarray, np.ndarray]:
296
  stride = int(max(1, min(stride, 16)))
297
- grid = np.zeros_like(mask, dtype=bool)
298
- grid[::stride, ::stride] = True
299
- selected = mask & grid
300
- yy, xx = np.where(selected)
301
- if len(yy) < 128:
302
- yy, xx = np.where(grid)
303
- if len(yy) > max_points:
304
- keep = np.linspace(0, len(yy) - 1, max_points, dtype=np.int64)
305
- yy = yy[keep]
306
- xx = xx[keep]
307
- return yy.astype(np.int32), xx.astype(np.int32)
308
 
309
 
310
  def _point_sequence_zip(
@@ -322,6 +321,9 @@ def _point_sequence_zip(
322
  colors_frames: list[np.ndarray] = []
323
  sample_y: np.ndarray | None = None
324
  sample_x: np.ndarray | None = None
 
 
 
325
  native_w = 0
326
  native_h = 0
327
 
@@ -332,14 +334,15 @@ def _point_sequence_zip(
332
  pointmap = _estimate_pointmap(image_bgr, model)
333
  native_h, native_w = pointmap.shape[:2]
334
 
335
- if sample_y is None or sample_x is None:
336
  mask = _foreground_mask(image_pil, native_h, native_w)
337
- sample_y, sample_x = _select_point_coords(mask, point_stride, VIDEO_MAX_POINTS)
338
- print(f"[video] selected {len(sample_y)} points at native {native_w}x{native_h}")
339
 
340
  image_native = np.array(image_pil.resize((native_w, native_h), Image.LANCZOS))
341
  points = pointmap[sample_y, sample_x].astype(np.float32)
342
  finite = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.05) & (points[:, 2] < 25.0)
 
343
  if finite.any():
344
  centroid = points[finite].mean(axis=0).astype(np.float32)
345
  else:
@@ -359,17 +362,19 @@ def _point_sequence_zip(
359
  positions = np.stack(positions_frames, axis=0).astype(np.float32)
360
  colors = np.stack(colors_frames, axis=0).astype(np.uint8)
361
  metadata = {
362
- "format": "fpbox-sapiens-pointmap-sequence-v1",
363
  "model": f"sapiens2-pointmap-{size}",
364
  "frameCount": int(positions.shape[0]),
365
  "fps": float(min(fps, max(1, positions.shape[0]))),
366
  "pointCount": int(positions.shape[1]),
 
 
367
  "width": int(native_w),
368
  "height": int(native_h),
369
  "sourceWidth": int(source_w),
370
  "sourceHeight": int(source_h),
371
- "coordinateSystem": "x, -y, -z, centered per frame",
372
- "dtype": {"positions": "float32", "colors": "uint8"},
373
  }
374
 
375
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
@@ -377,6 +382,7 @@ def _point_sequence_zip(
377
  zf.writestr("metadata.json", json.dumps(metadata, indent=2))
378
  zf.writestr("positions_f32.bin", positions.tobytes(order="C"))
379
  zf.writestr("colors_u8.bin", colors.tobytes(order="C"))
 
380
  return out_path
381
 
382
 
 
70
  MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
71
  VIDEO_MAX_HEIGHT = 512
72
  VIDEO_DEFAULT_FRAMES = 36
73
+ VIDEO_DEFAULT_STRIDE = 3
 
74
 
75
  _fg_transform = transforms.Compose([
76
  transforms.Resize((1024, 768)),
 
291
  return frames, fps, source_w, source_h
292
 
293
 
294
+ def _sample_grid(mask: np.ndarray, stride: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, int, int]:
295
  stride = int(max(1, min(stride, 16)))
296
+ ys = np.arange(0, mask.shape[0], stride, dtype=np.int32)
297
+ xs = np.arange(0, mask.shape[1], stride, dtype=np.int32)
298
+ grid_y, grid_x = np.meshgrid(ys, xs, indexing="ij")
299
+ valid = mask[grid_y, grid_x].reshape(-1)
300
+ return (
301
+ grid_y.reshape(-1).astype(np.int32),
302
+ grid_x.reshape(-1).astype(np.int32),
303
+ valid.astype(np.uint8),
304
+ int(len(ys)),
305
+ int(len(xs)),
306
+ )
307
 
308
 
309
  def _point_sequence_zip(
 
321
  colors_frames: list[np.ndarray] = []
322
  sample_y: np.ndarray | None = None
323
  sample_x: np.ndarray | None = None
324
+ valid_mask: np.ndarray | None = None
325
+ grid_rows = 0
326
+ grid_cols = 0
327
  native_w = 0
328
  native_h = 0
329
 
 
334
  pointmap = _estimate_pointmap(image_bgr, model)
335
  native_h, native_w = pointmap.shape[:2]
336
 
337
+ if sample_y is None or sample_x is None or valid_mask is None:
338
  mask = _foreground_mask(image_pil, native_h, native_w)
339
+ sample_y, sample_x, valid_mask, grid_rows, grid_cols = _sample_grid(mask, point_stride)
340
+ print(f"[video] sampled {grid_cols}x{grid_rows} grid ({int(valid_mask.sum())} valid foreground points) at native {native_w}x{native_h}")
341
 
342
  image_native = np.array(image_pil.resize((native_w, native_h), Image.LANCZOS))
343
  points = pointmap[sample_y, sample_x].astype(np.float32)
344
  finite = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.05) & (points[:, 2] < 25.0)
345
+ finite &= valid_mask.astype(bool)
346
  if finite.any():
347
  centroid = points[finite].mean(axis=0).astype(np.float32)
348
  else:
 
362
  positions = np.stack(positions_frames, axis=0).astype(np.float32)
363
  colors = np.stack(colors_frames, axis=0).astype(np.uint8)
364
  metadata = {
365
+ "format": "fpbox-sapiens-pointmap-sequence-v2",
366
  "model": f"sapiens2-pointmap-{size}",
367
  "frameCount": int(positions.shape[0]),
368
  "fps": float(min(fps, max(1, positions.shape[0]))),
369
  "pointCount": int(positions.shape[1]),
370
+ "gridRows": int(grid_rows),
371
+ "gridCols": int(grid_cols),
372
  "width": int(native_w),
373
  "height": int(native_h),
374
  "sourceWidth": int(source_w),
375
  "sourceHeight": int(source_h),
376
+ "coordinateSystem": "x, -y, -z, centered per frame, regular sampled image grid",
377
+ "dtype": {"positions": "float32", "colors": "uint8", "valid": "uint8"},
378
  }
379
 
380
  out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
 
382
  zf.writestr("metadata.json", json.dumps(metadata, indent=2))
383
  zf.writestr("positions_f32.bin", positions.tobytes(order="C"))
384
  zf.writestr("colors_u8.bin", colors.tobytes(order="C"))
385
+ zf.writestr("valid_u8.bin", valid_mask.astype(np.uint8).tobytes(order="C"))
386
  return out_path
387
 
388