Spaces:
Running on Zero
Running on Zero
Rawal Khirodkar committed on
Commit ·
2c70f2e
1
Parent(s): 380dd37
Pointmap: bring back .ply + Model3D, but native-res only (max 1024×768 grid → 200K pts)
Browse files- app.py +71 -43
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
"""Sapiens2 pointmap Gradio Space.
|
| 2 |
|
| 3 |
-
Image → per-pixel 3D pointmap (camera frame, metric units).
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import sys
|
|
@@ -14,6 +17,7 @@ import tempfile
|
|
| 14 |
import cv2
|
| 15 |
import gradio as gr
|
| 16 |
import numpy as np
|
|
|
|
| 17 |
import spaces
|
| 18 |
import torch
|
| 19 |
import torch.nn.functional as F
|
|
@@ -92,7 +96,6 @@ def _get_fg_model():
|
|
| 92 |
return _fg_model
|
| 93 |
|
| 94 |
|
| 95 |
-
# Iteration mode: only preload the default (0.4B) for fast Space boot.
|
| 96 |
print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
|
| 97 |
_get_pointmap_model(DEFAULT_SIZE)
|
| 98 |
_get_fg_model()
|
|
@@ -100,7 +103,7 @@ print("[startup] ready.")
|
|
| 100 |
|
| 101 |
|
| 102 |
# -----------------------------------------------------------------------------
|
| 103 |
-
# Inference (
|
| 104 |
|
| 105 |
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
|
| 106 |
data = model.pipeline(dict(img=image_bgr))
|
|
@@ -111,7 +114,7 @@ def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
|
|
| 111 |
|
| 112 |
with torch.no_grad():
|
| 113 |
pointmap, scale = model(inputs)
|
| 114 |
-
pointmap = pointmap / scale # → metric
|
| 115 |
|
| 116 |
pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
|
| 117 |
pointmap = pointmap[
|
|
@@ -131,33 +134,52 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
|
|
| 131 |
return (out.argmax(dim=1)[0] > 0).cpu().numpy()
|
| 132 |
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
"""
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
# -----------------------------------------------------------------------------
|
|
@@ -166,23 +188,19 @@ def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray | None = None) -> np.ndarr
|
|
| 166 |
@spaces.GPU(duration=120)
|
| 167 |
def predict(image: Image.Image, size: str):
|
| 168 |
if image is None:
|
| 169 |
-
return None
|
| 170 |
|
| 171 |
image_pil = image.convert("RGB")
|
| 172 |
image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
| 173 |
|
| 174 |
model = _get_pointmap_model(size)
|
| 175 |
-
pointmap = _estimate_pointmap(image_bgr, model)
|
| 176 |
h_n, w_n = pointmap.shape[:2]
|
| 177 |
|
| 178 |
mask = _foreground_mask(image_pil, h_n, w_n) # native-res mask, fast
|
| 179 |
-
|
| 180 |
-
rgb_native = _depth_to_rgb(depth, mask) # (H_native, W_native, 3) uint8
|
| 181 |
|
| 182 |
-
|
| 183 |
-
w0, h0 = image_pil.size
|
| 184 |
-
rgb_pil = Image.fromarray(rgb_native).resize((w0, h0), Image.LANCZOS)
|
| 185 |
-
return rgb_pil
|
| 186 |
|
| 187 |
|
| 188 |
# -----------------------------------------------------------------------------
|
|
@@ -250,7 +268,14 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
|
|
| 250 |
|
| 251 |
with gr.Row(equal_height=True):
|
| 252 |
inp = gr.Image(label="Input", type="pil", height=640)
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
with gr.Row():
|
| 256 |
size = gr.Radio(
|
|
@@ -263,7 +288,10 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
|
|
| 263 |
|
| 264 |
gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
|
| 265 |
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
|
| 269 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""Sapiens2 pointmap Gradio Space.
|
| 2 |
|
| 3 |
+
Image → per-pixel 3D pointmap (camera frame, metric units). Visualized as a
|
| 4 |
+
.ply point cloud rendered with Gradio's Model3D component for interactive 3D
|
| 5 |
+
viewing. Foreground mask is mandatory.
|
| 6 |
+
|
| 7 |
+
Everything runs at the model's NATIVE resolution (max 1024×768 grid → at most
|
| 8 |
+
~786K points before subsampling to 200K). No huge interpolations.
|
| 9 |
"""
|
| 10 |
|
| 11 |
import sys
|
|
|
|
| 17 |
import cv2
|
| 18 |
import gradio as gr
|
| 19 |
import numpy as np
|
| 20 |
+
import open3d as o3d
|
| 21 |
import spaces
|
| 22 |
import torch
|
| 23 |
import torch.nn.functional as F
|
|
|
|
| 96 |
return _fg_model
|
| 97 |
|
| 98 |
|
|
|
|
| 99 |
print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
|
| 100 |
_get_pointmap_model(DEFAULT_SIZE)
|
| 101 |
_get_fg_model()
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
# -----------------------------------------------------------------------------
|
| 106 |
+
# Inference (always at native resolution)
|
| 107 |
|
| 108 |
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
|
| 109 |
data = model.pipeline(dict(img=image_bgr))
|
|
|
|
| 114 |
|
| 115 |
with torch.no_grad():
|
| 116 |
pointmap, scale = model(inputs)
|
| 117 |
+
pointmap = pointmap / scale # → metric
|
| 118 |
|
| 119 |
pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
|
| 120 |
pointmap = pointmap[
|
|
|
|
| 134 |
return (out.argmax(dim=1)[0] > 0).cpu().numpy()
|
| 135 |
|
| 136 |
|
| 137 |
+
# -----------------------------------------------------------------------------
|
| 138 |
+
# Point cloud export (camera marker + cloud, native-res grid)
|
| 139 |
+
|
| 140 |
+
def _camera_marker(radius: float = 0.04, n_points: int = 800,
                   color=(0.20, 0.55, 0.96)) -> o3d.geometry.PointCloud:
    """Small slate-blue sphere of points at the world origin (the camera).

    Points are laid out on a Fibonacci lattice so they cover the sphere
    evenly with no clustering at the poles.
    """
    idx = np.arange(n_points) + 0.5
    polar = np.arccos(1.0 - 2.0 * idx / n_points)      # inclination from +Z
    azimuth = np.pi * (1.0 + 5 ** 0.5) * idx           # golden-angle spacing
    xyz = np.column_stack((
        radius * np.sin(polar) * np.cos(azimuth),
        radius * np.sin(polar) * np.sin(azimuth),
        radius * np.cos(polar),
    ))
    marker = o3d.geometry.PointCloud()
    marker.points = o3d.utility.Vector3dVector(xyz.astype(np.float64))
    marker.colors = o3d.utility.Vector3dVector(
        np.tile(color, (n_points, 1)).astype(np.float64))
    return marker
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _make_ply(image_pil_native: Image.Image, pointmap_hwc: np.ndarray,
              mask_hw: np.ndarray, max_points: int = 200_000) -> str:
    """Write the foreground point cloud (plus a camera-origin marker) to a .ply.

    Args:
        image_pil_native: RGB image used for per-point colors. It is resized
            here to the pointmap grid, so any input size is accepted.
        pointmap_hwc: (H, W, 3) per-pixel 3D points in the camera frame, metric.
        mask_hw: (H, W) boolean foreground mask aligned with ``pointmap_hwc``.
        max_points: cap on exported foreground points; excess points are
            subsampled with a fixed seed so output is deterministic.

    Returns:
        Path to a binary .ply temp file (caller is responsible for cleanup).
    """
    h, w = pointmap_hwc.shape[:2]
    # Resample colors onto the pointmap grid so colors and points align 1:1.
    image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))

    pts = pointmap_hwc.reshape(-1, 3)
    cols = image_rgb.reshape(-1, 3).astype(np.float32) / 255.0

    # Keep finite, plausibly-ranged (5 cm .. 25 m depth) foreground points only.
    z = pts[:, 2]
    keep = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
    pts, cols = pts[keep], cols[keep]

    if len(pts) > max_points:
        # Fixed seed → the same input always yields the same subsampled cloud.
        idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
        pts, cols = pts[idx], cols[idx]

    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
    pc.colors = o3d.utility.Vector3dVector(cols.astype(np.float64))
    pc += _camera_marker()  # blue dot marking the camera at the origin

    # Close the handle before open3d writes to the path: the original kept the
    # NamedTemporaryFile object's fd open (leak, and write-while-open breaks on
    # platforms with exclusive file locking).
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".ply")
    tmp.close()
    o3d.io.write_point_cloud(tmp.name, pc, write_ascii=False)
    return tmp.name
|
| 183 |
|
| 184 |
|
| 185 |
# -----------------------------------------------------------------------------
|
|
|
|
| 188 |
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str):
    """Full pipeline: image → native-res pointmap → fg mask → .ply point cloud.

    Returns the same .ply path twice — once for the Model3D viewer and once
    for the downloadable File output — or (None, None) when no image is given.
    """
    if image is None:
        return None, None

    rgb = image.convert("RGB")
    bgr = cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR)

    # Pointmap at the model's native grid (at most 1024 px in either dim).
    pointmap = _estimate_pointmap(bgr, _get_pointmap_model(size))
    native_h, native_w = pointmap.shape[:2]

    # Foreground mask computed directly at native resolution — cheap.
    fg = _foreground_mask(rgb, native_h, native_w)

    cloud_path = _make_ply(rgb, pointmap, fg)
    return cloud_path, cloud_path
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
# -----------------------------------------------------------------------------
|
|
|
|
| 268 |
|
| 269 |
with gr.Row(equal_height=True):
|
| 270 |
inp = gr.Image(label="Input", type="pil", height=640)
|
| 271 |
+
out_ply = gr.Model3D(
|
| 272 |
+
label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
|
| 273 |
+
height=640,
|
| 274 |
+
clear_color=[0.07, 0.09, 0.13, 1.0],
|
| 275 |
+
display_mode="point_cloud",
|
| 276 |
+
zoom_speed=0.7,
|
| 277 |
+
pan_speed=0.5,
|
| 278 |
+
)
|
| 279 |
|
| 280 |
with gr.Row():
|
| 281 |
size = gr.Radio(
|
|
|
|
| 288 |
|
| 289 |
gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
|
| 290 |
|
| 291 |
+
with gr.Accordion("Raw Pointmap", open=False):
|
| 292 |
+
out_ply_file = gr.File(label="Point cloud (.ply — open in MeshLab/CloudCompare/Blender)")
|
| 293 |
+
|
| 294 |
+
run.click(predict, inputs=[inp, size], outputs=[out_ply, out_ply_file])
|
| 295 |
|
| 296 |
|
| 297 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
|
@@ -19,3 +19,4 @@ prettytable
|
|
| 19 |
termcolor
|
| 20 |
accelerate
|
| 21 |
rich
|
|
|
|
|
|
| 19 |
termcolor
|
| 20 |
accelerate
|
| 21 |
rich
|
| 22 |
+
open3d
|