Spaces:
Running on Zero
Running on Zero
Rawal Khirodkar committed on
Commit ·
5dd5fbb
1
Parent(s): 824c1d9
Pointmap: trimesh→.glb (MoGe-2 pattern), Model3D back, cap input height to 1024
Browse files
- app.py +65 -76
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
"""Sapiens2 pointmap Gradio Space.
|
| 2 |
|
| 3 |
-
Image → per-pixel 3D pointmap (camera frame, metric units).
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import sys
|
|
@@ -13,14 +14,15 @@ import os
|
|
| 13 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
|
| 15 |
import tempfile
|
|
|
|
| 16 |
|
| 17 |
import cv2
|
| 18 |
import gradio as gr
|
| 19 |
import numpy as np
|
| 20 |
-
import open3d as o3d
|
| 21 |
import spaces
|
| 22 |
import torch
|
| 23 |
import torch.nn.functional as F
|
|
|
|
| 24 |
from PIL import Image
|
| 25 |
from torchvision import transforms
|
| 26 |
|
|
@@ -57,12 +59,13 @@ POINTMAP_MODELS = {
|
|
| 57 |
"config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
|
| 58 |
},
|
| 59 |
}
|
| 60 |
-
DEFAULT_SIZE = "0.4B" # iteration mode
|
| 61 |
|
| 62 |
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
|
| 63 |
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
|
| 64 |
|
| 65 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 66 |
|
| 67 |
_fg_transform = transforms.Compose([
|
| 68 |
transforms.Resize((1024, 768)),
|
|
@@ -103,7 +106,15 @@ print("[startup] ready.")
|
|
| 103 |
|
| 104 |
|
| 105 |
# -----------------------------------------------------------------------------
|
| 106 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
|
| 109 |
data = model.pipeline(dict(img=image_bgr))
|
|
@@ -134,116 +145,87 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
|
|
| 134 |
return (out.argmax(dim=1)[0] > 0).cpu().numpy()
|
| 135 |
|
| 136 |
|
| 137 |
-
def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 138 |
-
"""Inverse-depth turbo colormap (matches sapiens2 vis_pointmap.py).
|
| 139 |
-
Background pixels are left at 0 — caller should overlay them."""
|
| 140 |
-
valid = np.isfinite(depth) & (depth > 1e-3) & mask
|
| 141 |
-
rgb = np.zeros((*depth.shape, 3), dtype=np.uint8)
|
| 142 |
-
if not valid.any():
|
| 143 |
-
return rgb
|
| 144 |
-
inv = np.zeros_like(depth, dtype=np.float32)
|
| 145 |
-
inv[valid] = 1.0 / depth[valid]
|
| 146 |
-
p1, p99 = np.percentile(inv[valid], [1, 99])
|
| 147 |
-
lo, hi = float(p1), float(p99)
|
| 148 |
-
if hi <= lo:
|
| 149 |
-
hi = lo + 1e-3
|
| 150 |
-
norm = ((inv - lo) / (hi - lo)).clip(0, 1)
|
| 151 |
-
grey = (norm * 255.0).astype(np.uint8)
|
| 152 |
-
color = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1] # cv2 is BGR → RGB
|
| 153 |
-
rgb[valid] = color[valid]
|
| 154 |
-
return rgb
|
| 155 |
-
|
| 156 |
-
|
| 157 |
# -----------------------------------------------------------------------------
|
| 158 |
-
# Point cloud export (
|
| 159 |
|
| 160 |
def _camera_marker(radius: float = 0.04, n_points: int = 800,
|
| 161 |
-
color=(
|
| 162 |
-
"""Tiny slate-blue Fibonacci sphere at the world origin."""
|
| 163 |
i = np.arange(n_points)
|
| 164 |
phi = np.arccos(1 - 2 * (i + 0.5) / n_points)
|
| 165 |
theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)
|
| 166 |
-
|
| 167 |
radius * np.sin(phi) * np.cos(theta),
|
| 168 |
radius * np.sin(phi) * np.sin(theta),
|
| 169 |
radius * np.cos(phi),
|
| 170 |
-
], axis=1)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
pc.colors = o3d.utility.Vector3dVector(np.tile(color, (n_points, 1)).astype(np.float64))
|
| 174 |
-
return pc
|
| 175 |
|
| 176 |
|
| 177 |
-
def
|
| 178 |
mask_hw: np.ndarray, max_points: int = 200_000) -> str:
|
| 179 |
-
"""`image_pil_native` MUST already be sized to `pointmap_hwc.shape[:2]` so
|
| 180 |
-
point colors line up. Output .ply: foreground points + camera marker."""
|
| 181 |
h, w = pointmap_hwc.shape[:2]
|
| 182 |
image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))
|
| 183 |
|
| 184 |
-
pts = pointmap_hwc.reshape(-1, 3)
|
| 185 |
-
|
| 186 |
|
| 187 |
z = pts[:, 2]
|
| 188 |
finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
|
| 189 |
-
pts,
|
| 190 |
|
| 191 |
if len(pts) > max_points:
|
| 192 |
idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
|
| 193 |
-
pts,
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
return out_path
|
| 203 |
|
| 204 |
|
| 205 |
# -----------------------------------------------------------------------------
|
| 206 |
# Gradio handler
|
| 207 |
|
| 208 |
-
import time as _t
|
| 209 |
-
|
| 210 |
@spaces.GPU(duration=120)
|
| 211 |
def predict(image: Image.Image, size: str):
|
| 212 |
if image is None:
|
| 213 |
return None, None
|
| 214 |
|
| 215 |
t0 = _t.perf_counter()
|
| 216 |
-
image_pil = image.convert("RGB")
|
| 217 |
image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
| 218 |
-
print(f"[time] convert+
|
| 219 |
|
| 220 |
t = _t.perf_counter()
|
| 221 |
model = _get_pointmap_model(size)
|
| 222 |
-
print(f"[time] _get_pointmap_model {(_t.perf_counter()-t)*1000:.0f} ms")
|
| 223 |
-
|
| 224 |
-
t = _t.perf_counter()
|
| 225 |
pointmap = _estimate_pointmap(image_bgr, model)
|
| 226 |
h_n, w_n = pointmap.shape[:2]
|
| 227 |
-
print(f"[time]
|
| 228 |
|
| 229 |
t = _t.perf_counter()
|
| 230 |
mask = _foreground_mask(image_pil, h_n, w_n)
|
| 231 |
-
print(f"[time]
|
| 232 |
|
| 233 |
t = _t.perf_counter()
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
depth_rgb[~mask] = 200
|
| 237 |
-
w0, h0 = image_pil.size
|
| 238 |
-
depth_pil = Image.fromarray(depth_rgb).resize((w0, h0), Image.LANCZOS)
|
| 239 |
-
print(f"[time] depth heatmap+resize {(_t.perf_counter()-t)*1000:.0f} ms (target {w0}x{h0})")
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
print(f"[time] _make_ply {(_t.perf_counter()-t)*1000:.0f} ms")
|
| 244 |
-
|
| 245 |
-
print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms")
|
| 246 |
-
return depth_pil, ply_path
|
| 247 |
|
| 248 |
|
| 249 |
# -----------------------------------------------------------------------------
|
|
@@ -311,7 +293,14 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
|
|
| 311 |
|
| 312 |
with gr.Row(equal_height=True):
|
| 313 |
inp = gr.Image(label="Input", type="pil", height=640)
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
with gr.Row():
|
| 317 |
size = gr.Radio(
|
|
@@ -325,9 +314,9 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
|
|
| 325 |
gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
|
| 326 |
|
| 327 |
with gr.Accordion("Raw Pointmap", open=False):
|
| 328 |
-
|
| 329 |
|
| 330 |
-
run.click(predict, inputs=[inp, size], outputs=[
|
| 331 |
|
| 332 |
|
| 333 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""Sapiens2 pointmap Gradio Space.
|
| 2 |
|
| 3 |
+
Image → per-pixel 3D pointmap (camera frame, metric units). Right pane is an
|
| 4 |
+
interactive 3D point-cloud viewer rendering a `.glb` exported via trimesh
|
| 5 |
+
(MoGe-2's approach — much faster than Open3D's `.ply` for Three.js viewers).
|
| 6 |
|
| 7 |
+
All work happens at the model's NATIVE resolution. We additionally cap the
|
| 8 |
+
input image to height=1024 before processing so 4K uploads don't blow up
|
| 9 |
+
downstream sizes.
|
| 10 |
"""
|
| 11 |
|
| 12 |
import sys
|
|
|
|
| 14 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
|
| 16 |
import tempfile
|
| 17 |
+
import time as _t
|
| 18 |
|
| 19 |
import cv2
|
| 20 |
import gradio as gr
|
| 21 |
import numpy as np
|
|
|
|
| 22 |
import spaces
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
+
import trimesh
|
| 26 |
from PIL import Image
|
| 27 |
from torchvision import transforms
|
| 28 |
|
|
|
|
| 59 |
"config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
|
| 60 |
},
|
| 61 |
}
|
| 62 |
+
DEFAULT_SIZE = "0.4B" # iteration mode
|
| 63 |
|
| 64 |
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
|
| 65 |
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
|
| 66 |
|
| 67 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 68 |
+
MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
|
| 69 |
|
| 70 |
_fg_transform = transforms.Compose([
|
| 71 |
transforms.Resize((1024, 768)),
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
# -----------------------------------------------------------------------------
|
| 109 |
+
# Helpers
|
| 110 |
+
|
| 111 |
+
def _cap_height(image_pil: Image.Image, max_h: int = MAX_HEIGHT) -> Image.Image:
|
| 112 |
+
w, h = image_pil.size
|
| 113 |
+
if h <= max_h:
|
| 114 |
+
return image_pil
|
| 115 |
+
new_w = int(round(w * max_h / h))
|
| 116 |
+
return image_pil.resize((new_w, max_h), Image.LANCZOS)
|
| 117 |
+
|
| 118 |
|
| 119 |
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
|
| 120 |
data = model.pipeline(dict(img=image_bgr))
|
|
|
|
| 145 |
return (out.argmax(dim=1)[0] > 0).cpu().numpy()
|
| 146 |
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
# -----------------------------------------------------------------------------
|
| 149 |
+
# Point cloud export — trimesh → .glb (much faster than Open3D .ply for Three.js)
|
| 150 |
|
| 151 |
def _camera_marker(radius: float = 0.04, n_points: int = 800,
|
| 152 |
+
color=(51, 140, 245)):
|
| 153 |
+
"""Tiny slate-blue Fibonacci sphere at the world origin. Returns (verts, cols)."""
|
| 154 |
i = np.arange(n_points)
|
| 155 |
phi = np.arccos(1 - 2 * (i + 0.5) / n_points)
|
| 156 |
theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)
|
| 157 |
+
verts = np.stack([
|
| 158 |
radius * np.sin(phi) * np.cos(theta),
|
| 159 |
radius * np.sin(phi) * np.sin(theta),
|
| 160 |
radius * np.cos(phi),
|
| 161 |
+
], axis=1).astype(np.float32)
|
| 162 |
+
cols = np.tile(np.array(color + (255,), dtype=np.uint8), (n_points, 1))
|
| 163 |
+
return verts, cols
|
|
|
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
+
def _make_glb(image_pil_native: Image.Image, pointmap_hwc: np.ndarray,
|
| 167 |
mask_hw: np.ndarray, max_points: int = 200_000) -> str:
|
|
|
|
|
|
|
| 168 |
h, w = pointmap_hwc.shape[:2]
|
| 169 |
image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))
|
| 170 |
|
| 171 |
+
pts = pointmap_hwc.reshape(-1, 3).astype(np.float32)
|
| 172 |
+
cols_rgb = image_rgb.reshape(-1, 3).astype(np.uint8)
|
| 173 |
|
| 174 |
z = pts[:, 2]
|
| 175 |
finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
|
| 176 |
+
pts, cols_rgb = pts[finite], cols_rgb[finite]
|
| 177 |
|
| 178 |
if len(pts) > max_points:
|
| 179 |
idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
|
| 180 |
+
pts, cols_rgb = pts[idx], cols_rgb[idx]
|
| 181 |
+
|
| 182 |
+
cam_verts, cam_cols = _camera_marker()
|
| 183 |
+
verts = np.concatenate([pts, cam_verts], axis=0)
|
| 184 |
+
cols_rgba = np.concatenate(
|
| 185 |
+
[np.concatenate([cols_rgb, np.full((len(cols_rgb), 1), 255, dtype=np.uint8)], axis=1),
|
| 186 |
+
cam_cols], axis=0,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Three.js viewers (and gr.Model3D) typically use Y-up. Sapiens2 pointmaps
|
| 190 |
+
# come in camera frame with Y down, Z forward — flip Y and Z so the viewer's
|
| 191 |
+
# default orientation matches photographic intuition.
|
| 192 |
+
verts = verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)
|
| 193 |
+
|
| 194 |
+
pc = trimesh.PointCloud(vertices=verts, colors=cols_rgba)
|
| 195 |
+
out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".glb").name
|
| 196 |
+
pc.export(out_path)
|
| 197 |
return out_path
|
| 198 |
|
| 199 |
|
| 200 |
# -----------------------------------------------------------------------------
|
| 201 |
# Gradio handler
|
| 202 |
|
|
|
|
|
|
|
| 203 |
@spaces.GPU(duration=120)
|
| 204 |
def predict(image: Image.Image, size: str):
|
| 205 |
if image is None:
|
| 206 |
return None, None
|
| 207 |
|
| 208 |
t0 = _t.perf_counter()
|
| 209 |
+
image_pil = _cap_height(image.convert("RGB")) # cap to 1024px height
|
| 210 |
image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
| 211 |
+
print(f"[time] convert+cap {(_t.perf_counter()-t0)*1000:.0f} ms (input {image_pil.size})")
|
| 212 |
|
| 213 |
t = _t.perf_counter()
|
| 214 |
model = _get_pointmap_model(size)
|
|
|
|
|
|
|
|
|
|
| 215 |
pointmap = _estimate_pointmap(image_bgr, model)
|
| 216 |
h_n, w_n = pointmap.shape[:2]
|
| 217 |
+
print(f"[time] pointmap {(_t.perf_counter()-t)*1000:.0f} ms (native {w_n}x{h_n})")
|
| 218 |
|
| 219 |
t = _t.perf_counter()
|
| 220 |
mask = _foreground_mask(image_pil, h_n, w_n)
|
| 221 |
+
print(f"[time] fg mask {(_t.perf_counter()-t)*1000:.0f} ms")
|
| 222 |
|
| 223 |
t = _t.perf_counter()
|
| 224 |
+
glb_path = _make_glb(image_pil, pointmap, mask)
|
| 225 |
+
print(f"[time] glb export {(_t.perf_counter()-t)*1000:.0f} ms")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
+
print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms")
|
| 228 |
+
return glb_path, glb_path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
# -----------------------------------------------------------------------------
|
|
|
|
| 293 |
|
| 294 |
with gr.Row(equal_height=True):
|
| 295 |
inp = gr.Image(label="Input", type="pil", height=640)
|
| 296 |
+
out_glb = gr.Model3D(
|
| 297 |
+
label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
|
| 298 |
+
height=640,
|
| 299 |
+
clear_color=[0.07, 0.09, 0.13, 1.0],
|
| 300 |
+
display_mode="point_cloud",
|
| 301 |
+
zoom_speed=0.7,
|
| 302 |
+
pan_speed=0.5,
|
| 303 |
+
)
|
| 304 |
|
| 305 |
with gr.Row():
|
| 306 |
size = gr.Radio(
|
|
|
|
| 314 |
gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
|
| 315 |
|
| 316 |
with gr.Accordion("Raw Pointmap", open=False):
|
| 317 |
+
out_glb_file = gr.File(label="Point cloud (.glb — open in Blender/MeshLab/web viewers)")
|
| 318 |
|
| 319 |
+
run.click(predict, inputs=[inp, size], outputs=[out_glb, out_glb_file])
|
| 320 |
|
| 321 |
|
| 322 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
|
@@ -19,4 +19,4 @@ prettytable
|
|
| 19 |
termcolor
|
| 20 |
accelerate
|
| 21 |
rich
|
| 22 |
-
|
|
|
|
| 19 |
termcolor
|
| 20 |
accelerate
|
| 21 |
rich
|
| 22 |
+
trimesh
|