Spaces:
Running on Zero
Running on Zero
solrz committed on
Commit ·
82e2bb0
1
Parent(s): b9c8e7d
Add video pointmap sequence endpoint
Browse files
app.py
CHANGED
|
@@ -13,8 +13,10 @@ import sys
|
|
| 13 |
import os
|
| 14 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
|
|
|
|
| 16 |
import tempfile
|
| 17 |
import time as _t
|
|
|
|
| 18 |
|
| 19 |
import cv2
|
| 20 |
import gradio as gr
|
|
@@ -59,13 +61,17 @@ POINTMAP_MODELS = {
|
|
| 59 |
"config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
|
| 60 |
},
|
| 61 |
}
|
| 62 |
-
DEFAULT_SIZE = "
|
| 63 |
|
| 64 |
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
|
| 65 |
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
|
| 66 |
|
| 67 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 68 |
MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
_fg_transform = transforms.Compose([
|
| 71 |
transforms.Resize((1024, 768)),
|
|
@@ -99,11 +105,7 @@ def _get_fg_model():
|
|
| 99 |
return _fg_model
|
| 100 |
|
| 101 |
|
| 102 |
-
print("[startup]
|
| 103 |
-
for _size in POINTMAP_MODELS:
|
| 104 |
-
_get_pointmap_model(_size)
|
| 105 |
-
_get_fg_model()
|
| 106 |
-
print("[startup] ready.")
|
| 107 |
|
| 108 |
|
| 109 |
# -----------------------------------------------------------------------------
|
|
@@ -255,6 +257,129 @@ def _make_glb(image_pil_texture: Image.Image, pointmap_hwc: np.ndarray,
|
|
| 255 |
return out_path
|
| 256 |
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# -----------------------------------------------------------------------------
|
| 259 |
# Gradio handler
|
| 260 |
|
|
@@ -295,6 +420,27 @@ def predict(image: Image.Image, size: str):
|
|
| 295 |
return depth_pil, glb_path
|
| 296 |
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
# -----------------------------------------------------------------------------
|
| 299 |
# UI
|
| 300 |
|
|
@@ -358,33 +504,68 @@ HEADER_HTML = """
|
|
| 358 |
with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
| 359 |
gr.HTML(HEADER_HTML)
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
)
|
| 383 |
-
run = gr.Button("Run", variant="primary")
|
| 384 |
-
gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
|
| 385 |
-
out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)
|
| 386 |
-
|
| 387 |
-
run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])
|
| 388 |
|
| 389 |
|
| 390 |
if __name__ == "__main__":
|
|
|
|
| 13 |
import os
|
| 14 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
|
| 16 |
+
import json
|
| 17 |
import tempfile
|
| 18 |
import time as _t
|
| 19 |
+
import zipfile
|
| 20 |
|
| 21 |
import cv2
|
| 22 |
import gradio as gr
|
|
|
|
| 61 |
"config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
|
| 62 |
},
|
| 63 |
}
|
| 64 |
+
# Model size preselected in the "Model" radio of both the Image and Video tabs.
DEFAULT_SIZE = "0.4B"

# Foreground-segmentation checkpoint location.
# NOTE(review): presumably a Hugging Face Hub repo id + filename — confirm
# against the loader in _get_fg_model().
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"

# Use the GPU when torch reports one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_HEIGHT = 1024  # cap input height before processing — keeps everything fast
# Height cap applied to every video frame before pointmap inference.
VIDEO_MAX_HEIGHT = 512
# Initial value of the "Sampled Frames" slider in the Video tab.
VIDEO_DEFAULT_FRAMES = 36
# Initial value of the "Point Stride" slider in the Video tab.
VIDEO_DEFAULT_STRIDE = 5
# Upper bound on the number of points exported per video frame.
VIDEO_MAX_POINTS = 45_000
|
| 75 |
|
| 76 |
_fg_transform = transforms.Compose([
|
| 77 |
transforms.Resize((1024, 768)),
|
|
|
|
| 105 |
return _fg_model
|
| 106 |
|
| 107 |
|
| 108 |
+
print("[startup] ready; models will load lazily on first request.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
# -----------------------------------------------------------------------------
|
|
|
|
| 257 |
return out_path
|
| 258 |
|
| 259 |
|
| 260 |
+
def _sample_video_frames(video_path: str, max_frames: int) -> tuple[list[Image.Image], float, int, int]:
    """Decode up to ``max_frames`` evenly spaced RGB frames from a video file.

    When the container reports a frame count, frame indices are spread
    uniformly over the whole clip and each one is reached by seeking.  When it
    does not, frames are read sequentially from the start instead.

    Args:
        video_path: Path of the video to open with OpenCV.
        max_frames: Requested number of frames; clamped to the range 1..120.

    Returns:
        ``(frames, fps, source_w, source_h)`` where ``frames`` is a non-empty
        list of RGB PIL images, ``fps`` is the frame rate reported by the
        container (24.0 when it reports none or an unusable value), and
        ``source_w`` / ``source_h`` are the reported frame dimensions
        (0 when the container does not report them).

    Raises:
        gr.Error: If the video cannot be opened or no frame could be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released even when opening
    # fails or decoding / colour conversion raises part-way through.
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")

        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0)
        if not fps > 0:  # also catches NaN, which would slip through `or`
            fps = 24.0

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        frame_limit = int(max(1, min(max_frames, 120)))
        seekable = total_frames > 0
        if seekable:
            indices = np.linspace(0, total_frames - 1, min(frame_limit, total_frames), dtype=np.int32)
        else:
            # Unknown length: just try to read the first `frame_limit` frames.
            indices = np.arange(frame_limit, dtype=np.int32)

        source_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        source_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)

        frames: list[Image.Image] = []
        for index in indices:
            if seekable:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(index))
            ok, frame_bgr = cap.read()
            if not ok:
                if not seekable:
                    # Sequential read hit end-of-stream; nothing more to get.
                    break
                # A single bad seek should not abort the whole sampling pass.
                continue
            frames.append(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()

    if not frames:
        raise gr.Error("No readable frames were found in the uploaded video.")
    return frames, fps, source_w, source_h
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _select_point_coords(mask: np.ndarray, stride: int, max_points: int) -> tuple[np.ndarray, np.ndarray]:
|
| 296 |
+
stride = int(max(1, min(stride, 16)))
|
| 297 |
+
grid = np.zeros_like(mask, dtype=bool)
|
| 298 |
+
grid[::stride, ::stride] = True
|
| 299 |
+
selected = mask & grid
|
| 300 |
+
yy, xx = np.where(selected)
|
| 301 |
+
if len(yy) < 128:
|
| 302 |
+
yy, xx = np.where(grid)
|
| 303 |
+
if len(yy) > max_points:
|
| 304 |
+
keep = np.linspace(0, len(yy) - 1, max_points, dtype=np.int64)
|
| 305 |
+
yy = yy[keep]
|
| 306 |
+
xx = xx[keep]
|
| 307 |
+
return yy.astype(np.int32), xx.astype(np.int32)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _point_sequence_zip(
    frames: list[Image.Image],
    size: str,
    max_frames: int,
    point_stride: int,
    fps: float,
    source_w: int,
    source_h: int,
) -> str:
    """Run pointmap inference on each frame and pack the results into a zip.

    The pixel coordinates to export are chosen once, from the foreground mask
    of the first frame, and reused for every later frame so that all frames
    contribute the same number of points.  Per frame, the sampled points are
    centred on the centroid of their valid entries and the Y and Z axes are
    negated; entries that are non-finite or whose Z lies outside
    (0.05, 25.0) are zeroed in both positions and colours.

    Args:
        frames: Decoded RGB frames, in playback order.
        size: Key of the pointmap model to use.
        max_frames: Upper bound on how many leading frames are processed.
        point_stride: Pixel spacing of the sampling lattice.
        fps: Source frame rate; written to the metadata capped at the number
            of exported frames.
        source_w: Width of the source video, recorded in the metadata.
        source_h: Height of the source video, recorded in the metadata.

    Returns:
        Path of a temporary ``.zip`` file (not deleted automatically) holding
        ``metadata.json``, ``positions_f32.bin`` (frames x points x 3 float32,
        C order) and ``colors_u8.bin`` (frames x points x 3 uint8, C order).

    Raises:
        gr.Error: If no frame produced any output.
    """
    model = _get_pointmap_model(size)
    sampled = frames[: int(max(1, min(max_frames, len(frames))))]
    positions_frames: list[np.ndarray] = []
    colors_frames: list[np.ndarray] = []
    sample_y: np.ndarray | None = None
    sample_x: np.ndarray | None = None
    native_w = 0
    native_h = 0
    # Loop-invariant axis flip; matches the "coordinateSystem" metadata entry.
    axis_flip = np.array([1.0, -1.0, -1.0], dtype=np.float32)

    for frame_index, frame in enumerate(sampled):
        t = _t.perf_counter()
        image_pil = _cap_height(frame.convert("RGB"), VIDEO_MAX_HEIGHT)
        image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        pointmap = _estimate_pointmap(image_bgr, model)
        native_h, native_w = pointmap.shape[:2]

        if sample_y is None or sample_x is None:
            # Fix the sampling pattern on the first frame only, so every frame
            # exports the same pixel locations and the arrays stack cleanly.
            mask = _foreground_mask(image_pil, native_h, native_w)
            sample_y, sample_x = _select_point_coords(mask, point_stride, VIDEO_MAX_POINTS)
            print(f"[video] selected {len(sample_y)} points at native {native_w}x{native_h}")

        # Resize the colour image to the pointmap grid so indices line up.
        image_native = np.array(image_pil.resize((native_w, native_h), Image.LANCZOS))
        points = pointmap[sample_y, sample_x].astype(np.float32)
        finite = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.05) & (points[:, 2] < 25.0)
        if finite.any():
            centroid = points[finite].mean(axis=0).astype(np.float32)
        else:
            centroid = np.zeros(3, dtype=np.float32)
        points = (points - centroid) * axis_flip
        points[~finite] = 0
        colors = image_native[sample_y, sample_x, :3].astype(np.uint8)
        colors[~finite] = 0

        positions_frames.append(points)
        colors_frames.append(colors)
        print(f"[video] frame {frame_index + 1}/{len(sampled)} {(_t.perf_counter() - t) * 1000:.0f} ms")

    if not positions_frames:
        raise gr.Error("Pointmap inference did not produce any frames.")

    # The per-frame arrays already carry the target dtypes; copy=False avoids
    # duplicating the stacked buffers.
    positions = np.stack(positions_frames, axis=0).astype(np.float32, copy=False)
    colors = np.stack(colors_frames, axis=0).astype(np.uint8, copy=False)
    metadata = {
        "format": "fpbox-sapiens-pointmap-sequence-v1",
        "model": f"sapiens2-pointmap-{size}",
        "frameCount": int(positions.shape[0]),
        "fps": float(min(fps, max(1, positions.shape[0]))),
        "pointCount": int(positions.shape[1]),
        "width": int(native_w),
        "height": int(native_h),
        "sourceWidth": int(source_w),
        "sourceHeight": int(source_h),
        "coordinateSystem": "x, -y, -z, centered per frame",
        "dtype": {"positions": "float32", "colors": "uint8"},
    }

    # Write through the temp file's own handle and close it deterministically,
    # instead of leaking the handle and reopening the path a second time.
    # delete=False keeps the archive on disk for the caller to serve.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
        out_path = tmp.name
        with zipfile.ZipFile(tmp, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("metadata.json", json.dumps(metadata, indent=2))
            zf.writestr("positions_f32.bin", positions.tobytes(order="C"))
            zf.writestr("colors_u8.bin", colors.tobytes(order="C"))
    return out_path
|
| 381 |
+
|
| 382 |
+
|
| 383 |
# -----------------------------------------------------------------------------
|
| 384 |
# Gradio handler
|
| 385 |
|
|
|
|
| 420 |
return depth_pil, glb_path
|
| 421 |
|
| 422 |
|
| 423 |
+
@spaces.GPU(duration=300)
def predict_video(video_path: str, size: str, max_frames: int, point_stride: int):
    """Gradio handler: turn an uploaded video into a pointmap-sequence zip.

    Samples frames from ``video_path``, runs the pointmap pipeline on them
    with the model selected by ``size``, and logs the sampling summary and
    the total wall-clock time.

    Returns:
        Path of the generated ``.zip`` archive, or ``None`` when no video was
        supplied.
    """
    if video_path is None:
        return None

    t0 = _t.perf_counter()

    sampling = _sample_video_frames(video_path, max_frames)
    frames, fps, source_w, source_h = sampling
    print(f"[video] sampled {len(frames)} frames from {source_w}x{source_h} video at {fps:.2f} fps")

    archive = _point_sequence_zip(frames, size, max_frames, point_stride, fps, source_w, source_h)

    print(f"[video] TOTAL {(_t.perf_counter() - t0) * 1000:.0f} ms")
    return archive
|
| 442 |
+
|
| 443 |
+
|
| 444 |
# -----------------------------------------------------------------------------
|
| 445 |
# UI
|
| 446 |
|
|
|
|
| 504 |
# Top-level Gradio UI: an "Image" tab (single image -> depth map + 3D model)
# and a "Video" tab (video -> downloadable pointmap-sequence zip).
# NOTE(review): nesting below was reconstructed from a flattened diff view —
# confirm the indentation against app.py before applying.
with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)

    with gr.Tabs():
        with gr.Tab("Image"):
            # Row 1: input ↔ 3D mesh, equal height
            with gr.Row(equal_height=True):
                inp = gr.Image(label="Input", type="pil", height=640, scale=2)
                out_glb = gr.Model3D(
                    label="Pointmap",
                    height=640,
                    clear_color=[0.97, 0.97, 0.97, 1.0],  # cinematic studio white
                    camera_position=(35, 70, 1.6),  # closer, since scene is centered on the human
                    zoom_speed=0.7,
                    pan_speed=0.5,
                    scale=3,
                )

            # Row 2: controls (with examples below them) on the left | depth heatmap on the right.
            with gr.Row():
                with gr.Column(scale=2, min_width=320):
                    size = gr.Radio(
                        choices=list(POINTMAP_MODELS.keys()),
                        value=DEFAULT_SIZE,
                        label="Model",
                        container=False,
                    )
                    run = gr.Button("Run", variant="primary")
                    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
                out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)

            # predict returns (depth image, GLB path), in that output order.
            run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])

        with gr.Tab("Video"):
            # Row 1: uploaded video on the left, resulting archive on the right.
            with gr.Row():
                video_inp = gr.Video(label="Input Video", height=420)
                sequence_zip = gr.File(label="Pointmap Sequence (.zip)")
            # Row 2: model choice and sampling controls.
            with gr.Row():
                video_size = gr.Radio(
                    choices=list(POINTMAP_MODELS.keys()),
                    value=DEFAULT_SIZE,
                    label="Model",
                    container=False,
                )
                # Slider maximum matches the 120-frame clamp in _sample_video_frames.
                video_frames = gr.Slider(
                    minimum=1,
                    maximum=120,
                    step=1,
                    value=VIDEO_DEFAULT_FRAMES,
                    label="Sampled Frames",
                )
                # Slider maximum matches the stride clamp (16) in _select_point_coords.
                video_stride = gr.Slider(
                    minimum=1,
                    maximum=16,
                    step=1,
                    value=VIDEO_DEFAULT_STRIDE,
                    label="Point Stride",
                )
            run_video = gr.Button("Run Video Pointmap", variant="primary")
            # api_name gives this event an explicit endpoint name, "video_pointmap".
            run_video.click(
                predict_video,
                inputs=[video_inp, video_size, video_frames, video_stride],
                outputs=[sequence_zip],
                api_name="video_pointmap",
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
|
| 570 |
|
| 571 |
if __name__ == "__main__":
|