Rawal Khirodkar committed
Commit 2482c8d · Parent(s): 2593450

Pointmap: pivot to depth-z heatmap (turbo); drop Model3D + Open3D + .ply pipeline

Files changed:
- app.py  +45 -86
- requirements.txt  +0 -1

app.py CHANGED
@@ -1,9 +1,8 @@
 """Sapiens2 pointmap Gradio Space.
 
-Image → per-pixel 3D pointmap (camera frame, metric units).
-[...]
-foreground points end up in the cloud.
+Image → per-pixel 3D pointmap (camera frame, metric units). For now we just
+visualize the depth (z) channel as a colored heatmap, matching the look of the
+normal demo. The 3D point-cloud viewer can be re-enabled later.
 """
 
 import sys
@@ -15,7 +14,6 @@ import tempfile
 import cv2
 import gradio as gr
 import numpy as np
-import open3d as o3d
 import spaces
 import torch
 import torch.nn.functional as F
@@ -95,20 +93,16 @@ def _get_fg_model():
 
 
 # Iteration mode: only preload the default (0.4B) for fast Space boot.
-# Re-enable full preload by uncommenting the loop below.
 print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
 _get_pointmap_model(DEFAULT_SIZE)
 _get_fg_model()
-# for _size in POINTMAP_MODELS:
-#     _get_pointmap_model(_size)
 print("[startup] ready.")
 
 
 # -----------------------------------------------------------------------------
-# Inference
+# Inference (operates at the model's native resolution — no big upsamples)
 
 def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
-    h0, w0 = image_bgr.shape[:2]
     data = model.pipeline(dict(img=image_bgr))
     data = model.data_preprocessor(data)
     inputs, data_samples = data["inputs"], data["data_samples"]
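The preload above only makes sense if `_get_pointmap_model` and `_get_fg_model` memoize what they build, so later `predict` calls reuse the cached weights. A minimal sketch of that lazy-loader pattern (the cache dict and the `build_model` stand-in are assumptions, not the Space's actual code):

# Sketch of the memoized loader implied by the preload block above;
# build_model is a hypothetical stand-in for the real model constructor.
_MODEL_CACHE: dict[str, object] = {}

def build_model(size: str) -> object:
    return object()  # placeholder so the sketch runs standalone

def get_model(size: str):
    """Build the model for `size` once; later calls return the cached one."""
    if size not in _MODEL_CACHE:
        print(f"[loader] building {size} ...")
        _MODEL_CACHE[size] = build_model(size)
    return _MODEL_CACHE[size]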
@@ -119,15 +113,13 @@ def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
     pointmap, scale = model(inputs)
     pointmap = pointmap / scale  # → metric units
 
-    pad = data_samples["meta"]["padding_size"]
-    pad_left, pad_right, pad_top, pad_bottom = pad
+    pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
     pointmap = pointmap[
         :, :,
         pad_top : inputs.shape[2] - pad_bottom,
         pad_left : inputs.shape[3] - pad_right,
     ]
-
-    return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)  # (H, W, 3)
+    return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)  # (H_native, W_native, 3)
 
 
 def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
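As a sanity check on the de-padding slice above, a minimal sketch with a dummy padded tensor (the shapes and padding values here are made up; the (left, right, top, bottom) order follows the unpacking in `_estimate_pointmap`):

import torch

inputs = torch.zeros(1, 3, 1024, 768)    # padded model input (N, C, H, W)
pointmap = torch.randn(1, 3, 1024, 768)  # model output on the same padded grid
pad_left, pad_right, pad_top, pad_bottom = 0, 0, 64, 64

cropped = pointmap[
    :, :,
    pad_top : inputs.shape[2] - pad_bottom,   # rows 64..959 → 896 kept
    pad_left : inputs.shape[3] - pad_right,   # cols 0..767 → 768 kept
]
assert cropped.shape == (1, 3, 896, 768)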
@@ -139,81 +131,58 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
     return (out.argmax(dim=1)[0] > 0).cpu().numpy()
 
 
-def _camera_marker(radius: float = 0.04, n_points: int = 800,
-                   color=(0.20, 0.55, 0.96)) -> o3d.geometry.PointCloud:
-    """Small uniformly-blue sphere at the world origin marking the camera.
-    [...]
-    """
-    [...]
-
-
-def _make_ply(...):
-    [...]
-    finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0)
-    if mask_hw is not None:
-        finite &= mask_hw.reshape(-1)
-    pts, cols = pts[finite], cols[finite]
-
-    if len(pts) > max_points:
-        idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
-        pts, cols = pts[idx], cols[idx]
-
-    pc = o3d.geometry.PointCloud()
-    pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
-    pc.colors = o3d.utility.Vector3dVector(cols.astype(np.float64))
-
-    # Add the camera marker (blue ball at origin) so users see where the
-    # observer is in the reconstructed 3D scene.
-    pc += _camera_marker()
-
-    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".ply").name
-    o3d.io.write_point_cloud(out_path, pc, write_ascii=False)
-    return out_path
+def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
+    """Depth (H, W) → RGB (H, W, 3) uint8 via inverse-depth turbo colormap.
+
+    Inverse-depth (1/z) gives more contrast on near surfaces (where humans tend
+    to be), which matches what most SfM/depth viewers show.
+    """
+    valid = np.isfinite(depth) & (depth > 1e-3)
+    if mask is not None:
+        valid &= mask
+    if not valid.any():
+        return np.zeros((*depth.shape, 3), dtype=np.uint8)
+
+    inv = np.zeros_like(depth, dtype=np.float32)
+    inv[valid] = 1.0 / depth[valid]
+    p1, p99 = np.percentile(inv[valid], [1, 99])
+    lo, hi = float(p1), float(p99)
+    if hi <= lo:
+        hi = lo + 1e-3
+    norm = np.zeros_like(inv, dtype=np.float32)
+    norm[valid] = ((inv[valid] - lo) / (hi - lo)).clip(0, 1)
+    grey = (norm * 255.0).astype(np.uint8)
+
+    # cv2.applyColorMap returns BGR — flip to RGB for Gradio.
+    rgb = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1].copy()
+    if mask is not None:
+        rgb[~mask] = 0  # background → black
+    return rgb
 
 
 # -----------------------------------------------------------------------------
 # Gradio handler
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=120)
 def predict(image: Image.Image, size: str):
     if image is None:
-        return None
+        return None
 
     image_pil = image.convert("RGB")
-    image_rgb = np.array(image_pil)
-    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
-    h0, w0 = image_rgb.shape[:2]
+    image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
 
     model = _get_pointmap_model(size)
-    pointmap = _estimate_pointmap(image_bgr, model)
-
-    [...]
-    ply_path = _make_ply(image_rgb, pointmap, mask)
-
-    [...]
+    pointmap = _estimate_pointmap(image_bgr, model)  # (H_native, W_native, 3)
+    h_n, w_n = pointmap.shape[:2]
+
+    mask = _foreground_mask(image_pil, h_n, w_n)  # native-res mask, fast
+    depth = pointmap[:, :, 2]  # z channel
+    rgb_native = _depth_to_rgb(depth, mask)  # (H_native, W_native, 3) uint8
+
+    # Lanczos upsample the RGB heatmap to the original image size — sharp.
+    w0, h0 = image_pil.size
+    rgb_pil = Image.fromarray(rgb_native).resize((w0, h0), Image.LANCZOS)
+    return rgb_pil
 
 
 # -----------------------------------------------------------------------------
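The new `_depth_to_rgb` can be exercised standalone. A small demo on a synthetic depth ramp (values chosen arbitrarily) showing the same inverse-depth normalization plus turbo mapping; note cv2.COLORMAP_TURBO requires OpenCV ≥ 4.1.2:

import cv2
import numpy as np

# Synthetic depth: 1 m at the top row, 10 m at the bottom row.
depth = np.linspace(1.0, 10.0, 256, dtype=np.float32)[:, None].repeat(256, axis=1)

inv = 1.0 / depth                     # inverse depth: near surfaces → large values
lo, hi = np.percentile(inv, [1, 99])  # robust range, as in _depth_to_rgb
norm = ((inv - lo) / max(hi - lo, 1e-3)).clip(0.0, 1.0)
grey = (norm * 255.0).astype(np.uint8)
rgb = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1]  # BGR → RGB
# Near rows land on the hot end of turbo, far rows on the cool end.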
@@ -281,14 +250,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
 
     with gr.Row(equal_height=True):
         inp = gr.Image(label="Input", type="pil", height=640)
-        out_ply = gr.Model3D(
-            label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
-            height=640,
-            clear_color=[0.07, 0.09, 0.13, 1.0],  # subtle slate-900 backdrop
-            display_mode="point_cloud",
-            zoom_speed=0.7,
-            pan_speed=0.5,
-        )
+        out_img = gr.Image(label="Depth (turbo)", type="pil", height=640)
 
     with gr.Row():
         size = gr.Radio(
@@ -301,10 +263,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
 
     gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
 
-
-    out_ply_file = gr.File(label="Point cloud (.ply — open in MeshLab/CloudCompare/Blender)")
-
-    run.click(predict, inputs=[inp, size], outputs=[out_ply, out_ply_file])
+    run.click(predict, inputs=[inp, size], outputs=[out_img])
 
 
 if __name__ == "__main__":
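With the .ply download gone, the event wiring reduces to a plain image-in/image-out round trip. A self-contained sketch of that wiring (component labels mirror the diff; `predict_fn` and the single radio choice are stand-ins for the GPU-decorated `predict` and the real size list):

import gradio as gr
from PIL import Image

def predict_fn(image: Image.Image, size: str):
    return image  # stand-in for the depth-heatmap predict()

with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640)
        out_img = gr.Image(label="Depth (turbo)", type="pil", height=640)
    size = gr.Radio(choices=["0.4B"], value="0.4B", label="Model size")
    run = gr.Button("Run")
    run.click(predict_fn, inputs=[inp, size], outputs=[out_img])

# demo.launch()  # uncomment to serve locally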
requirements.txt CHANGED

@@ -19,4 +19,3 @@ prettytable
 termcolor
 accelerate
 rich
-open3d