Rawal Khirodkar committed on
Commit
5dd5fbb
·
1 Parent(s): 824c1d9

Pointmap: trimesh→.glb (MoGe-2 pattern), Model3D back, cap input height to 1024

Browse files
Files changed (2) hide show
  1. app.py +65 -76
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,11 +1,12 @@
1
  """Sapiens2 pointmap Gradio Space.
2
 
3
- Image → per-pixel 3D pointmap (camera frame, metric units). Visualized as a
4
- .ply point cloud rendered with Gradio's Model3D component for interactive 3D
5
- viewing. Foreground mask is mandatory.
6
 
7
- Everything runs at the model's NATIVE resolution (max 1024×768 grid → at most
8
- ~786K points before subsampling to 200K). No huge interpolations.
 
9
  """
10
 
11
  import sys
@@ -13,14 +14,15 @@ import os
13
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
 
15
  import tempfile
 
16
 
17
  import cv2
18
  import gradio as gr
19
  import numpy as np
20
- import open3d as o3d
21
  import spaces
22
  import torch
23
  import torch.nn.functional as F
 
24
  from PIL import Image
25
  from torchvision import transforms
26
 
@@ -57,12 +59,13 @@ POINTMAP_MODELS = {
57
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
58
  },
59
  }
60
- DEFAULT_SIZE = "0.4B" # iteration mode — only this is preloaded; others lazy-load on click
61
 
62
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
63
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
64
 
65
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
66
 
67
  _fg_transform = transforms.Compose([
68
  transforms.Resize((1024, 768)),
@@ -103,7 +106,15 @@ print("[startup] ready.")
103
 
104
 
105
  # -----------------------------------------------------------------------------
106
- # Inference (always at native resolution)
 
 
 
 
 
 
 
 
107
 
108
  def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
109
  data = model.pipeline(dict(img=image_bgr))
@@ -134,116 +145,87 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
134
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
135
 
136
 
137
- def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray:
138
- """Inverse-depth turbo colormap (matches sapiens2 vis_pointmap.py).
139
- Background pixels are left at 0 — caller should overlay them."""
140
- valid = np.isfinite(depth) & (depth > 1e-3) & mask
141
- rgb = np.zeros((*depth.shape, 3), dtype=np.uint8)
142
- if not valid.any():
143
- return rgb
144
- inv = np.zeros_like(depth, dtype=np.float32)
145
- inv[valid] = 1.0 / depth[valid]
146
- p1, p99 = np.percentile(inv[valid], [1, 99])
147
- lo, hi = float(p1), float(p99)
148
- if hi <= lo:
149
- hi = lo + 1e-3
150
- norm = ((inv - lo) / (hi - lo)).clip(0, 1)
151
- grey = (norm * 255.0).astype(np.uint8)
152
- color = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1] # cv2 is BGR → RGB
153
- rgb[valid] = color[valid]
154
- return rgb
155
-
156
-
157
  # -----------------------------------------------------------------------------
158
- # Point cloud export (camera marker + cloud, native-res grid)
159
 
160
  def _camera_marker(radius: float = 0.04, n_points: int = 800,
161
- color=(0.20, 0.55, 0.96)) -> o3d.geometry.PointCloud:
162
- """Tiny slate-blue Fibonacci sphere at the world origin."""
163
  i = np.arange(n_points)
164
  phi = np.arccos(1 - 2 * (i + 0.5) / n_points)
165
  theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)
166
- pts = np.stack([
167
  radius * np.sin(phi) * np.cos(theta),
168
  radius * np.sin(phi) * np.sin(theta),
169
  radius * np.cos(phi),
170
- ], axis=1)
171
- pc = o3d.geometry.PointCloud()
172
- pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
173
- pc.colors = o3d.utility.Vector3dVector(np.tile(color, (n_points, 1)).astype(np.float64))
174
- return pc
175
 
176
 
177
- def _make_ply(image_pil_native: Image.Image, pointmap_hwc: np.ndarray,
178
  mask_hw: np.ndarray, max_points: int = 200_000) -> str:
179
- """`image_pil_native` MUST already be sized to `pointmap_hwc.shape[:2]` so
180
- point colors line up. Output .ply: foreground points + camera marker."""
181
  h, w = pointmap_hwc.shape[:2]
182
  image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))
183
 
184
- pts = pointmap_hwc.reshape(-1, 3)
185
- cols = image_rgb.reshape(-1, 3).astype(np.float32) / 255.0
186
 
187
  z = pts[:, 2]
188
  finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
189
- pts, cols = pts[finite], cols[finite]
190
 
191
  if len(pts) > max_points:
192
  idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
193
- pts, cols = pts[idx], cols[idx]
194
-
195
- pc = o3d.geometry.PointCloud()
196
- pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
197
- pc.colors = o3d.utility.Vector3dVector(cols.astype(np.float64))
198
- pc += _camera_marker()
199
-
200
- out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".ply").name
201
- o3d.io.write_point_cloud(out_path, pc, write_ascii=False)
 
 
 
 
 
 
 
 
202
  return out_path
203
 
204
 
205
  # -----------------------------------------------------------------------------
206
  # Gradio handler
207
 
208
- import time as _t
209
-
210
  @spaces.GPU(duration=120)
211
  def predict(image: Image.Image, size: str):
212
  if image is None:
213
  return None, None
214
 
215
  t0 = _t.perf_counter()
216
- image_pil = image.convert("RGB")
217
  image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
218
- print(f"[time] convert+bgr {(_t.perf_counter()-t0)*1000:.0f} ms (input {image_pil.size})")
219
 
220
  t = _t.perf_counter()
221
  model = _get_pointmap_model(size)
222
- print(f"[time] _get_pointmap_model {(_t.perf_counter()-t)*1000:.0f} ms")
223
-
224
- t = _t.perf_counter()
225
  pointmap = _estimate_pointmap(image_bgr, model)
226
  h_n, w_n = pointmap.shape[:2]
227
- print(f"[time] _estimate_pointmap {(_t.perf_counter()-t)*1000:.0f} ms (native {w_n}x{h_n})")
228
 
229
  t = _t.perf_counter()
230
  mask = _foreground_mask(image_pil, h_n, w_n)
231
- print(f"[time] _foreground_mask {(_t.perf_counter()-t)*1000:.0f} ms")
232
 
233
  t = _t.perf_counter()
234
- depth = pointmap[:, :, 2]
235
- depth_rgb = _depth_to_rgb(depth, mask)
236
- depth_rgb[~mask] = 200
237
- w0, h0 = image_pil.size
238
- depth_pil = Image.fromarray(depth_rgb).resize((w0, h0), Image.LANCZOS)
239
- print(f"[time] depth heatmap+resize {(_t.perf_counter()-t)*1000:.0f} ms (target {w0}x{h0})")
240
 
241
- t = _t.perf_counter()
242
- ply_path = _make_ply(image_pil, pointmap, mask)
243
- print(f"[time] _make_ply {(_t.perf_counter()-t)*1000:.0f} ms")
244
-
245
- print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms")
246
- return depth_pil, ply_path
247
 
248
 
249
  # -----------------------------------------------------------------------------
@@ -311,7 +293,14 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
311
 
312
  with gr.Row(equal_height=True):
313
  inp = gr.Image(label="Input", type="pil", height=640)
314
- out_img = gr.Image(label="Depth (Z)", type="pil", height=640)
 
 
 
 
 
 
 
315
 
316
  with gr.Row():
317
  size = gr.Radio(
@@ -325,9 +314,9 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
325
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
326
 
327
  with gr.Accordion("Raw Pointmap", open=False):
328
- out_ply_file = gr.File(label="Point cloud (.ply — open in MeshLab/CloudCompare/Blender)")
329
 
330
- run.click(predict, inputs=[inp, size], outputs=[out_img, out_ply_file])
331
 
332
 
333
  if __name__ == "__main__":
 
1
  """Sapiens2 pointmap Gradio Space.
2
 
3
+ Image → per-pixel 3D pointmap (camera frame, metric units). Right pane is an
4
+ interactive 3D point-cloud viewer rendering a `.glb` exported via trimesh
5
+ (MoGe-2's approach — much faster than Open3D's `.ply` for Three.js viewers).
6
 
7
+ All work happens at the model's NATIVE resolution. We additionally cap the
8
+ input image to height=1024 before processing so 4K uploads don't blow up
9
+ downstream sizes.
10
  """
11
 
12
  import sys
 
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
 
16
  import tempfile
17
+ import time as _t
18
 
19
  import cv2
20
  import gradio as gr
21
  import numpy as np
 
22
  import spaces
23
  import torch
24
  import torch.nn.functional as F
25
+ import trimesh
26
  from PIL import Image
27
  from torchvision import transforms
28
 
 
59
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"),
60
  },
61
  }
62
+ DEFAULT_SIZE = "0.4B" # iteration mode
63
 
64
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
65
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
66
 
67
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
68
+ MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast
69
 
70
  _fg_transform = transforms.Compose([
71
  transforms.Resize((1024, 768)),
 
106
 
107
 
108
  # -----------------------------------------------------------------------------
109
+ # Helpers
110
+
111
def _cap_height(image_pil: Image.Image, max_h: int = MAX_HEIGHT) -> Image.Image:
    """Downscale *image_pil* so its height is at most *max_h*, keeping aspect ratio.

    Images already within the cap are returned unchanged (same object, no copy).
    LANCZOS resampling is used for quality on the downscale path.
    """
    width, height = image_pil.size
    if height > max_h:
        scaled_w = int(round(width * max_h / height))
        return image_pil.resize((scaled_w, max_h), Image.LANCZOS)
    return image_pil
117
+
118
 
119
  def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
120
  data = model.pipeline(dict(img=image_bgr))
 
145
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # -----------------------------------------------------------------------------
149
+ # Point cloud export — trimesh → .glb (much faster than Open3D .ply for Three.js)
150
 
151
  def _camera_marker(radius: float = 0.04, n_points: int = 800,
152
+ color=(51, 140, 245)):
153
+ """Tiny slate-blue Fibonacci sphere at the world origin. Returns (verts, cols)."""
154
  i = np.arange(n_points)
155
  phi = np.arccos(1 - 2 * (i + 0.5) / n_points)
156
  theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)
157
+ verts = np.stack([
158
  radius * np.sin(phi) * np.cos(theta),
159
  radius * np.sin(phi) * np.sin(theta),
160
  radius * np.cos(phi),
161
+ ], axis=1).astype(np.float32)
162
+ cols = np.tile(np.array(color + (255,), dtype=np.uint8), (n_points, 1))
163
+ return verts, cols
 
 
164
 
165
 
166
def _make_glb(image_pil_native: Image.Image, pointmap_hwc: np.ndarray,
              mask_hw: np.ndarray, max_points: int = 200_000) -> str:
    """Export the foreground pointmap (plus a camera-origin marker) as a .glb file.

    Args:
        image_pil_native: RGB image used to color the points; it is resized here
            to the pointmap grid, so any input size works.
        pointmap_hwc: (H, W, 3) per-pixel XYZ in camera frame (metric units —
            presumably meters given the 0.05–25.0 Z gate; confirm upstream).
        mask_hw: (H, W) boolean foreground mask; background points are dropped.
        max_points: cap on exported foreground points; excess is subsampled with
            a fixed-seed RNG so the output is deterministic.

    Returns:
        Path to a temporary .glb file (delete=False — caller/OS owns cleanup).
    """
    h, w = pointmap_hwc.shape[:2]
    # Resample colors onto the pointmap grid so colors and points align 1:1.
    image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))

    pts = pointmap_hwc.reshape(-1, 3).astype(np.float32)
    cols_rgb = image_rgb.reshape(-1, 3).astype(np.uint8)

    # Keep only finite, in-range (0.05 m < Z < 25 m), foreground points.
    z = pts[:, 2]
    finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
    pts, cols_rgb = pts[finite], cols_rgb[finite]

    if len(pts) > max_points:
        # Seeded RNG: the same input always yields the same subsample.
        idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
        pts, cols_rgb = pts[idx], cols_rgb[idx]

    # Append the camera marker; its colors are already RGBA, so pad the image
    # colors with an opaque alpha channel before concatenating.
    cam_verts, cam_cols = _camera_marker()
    verts = np.concatenate([pts, cam_verts], axis=0)
    cols_rgba = np.concatenate(
        [np.concatenate([cols_rgb, np.full((len(cols_rgb), 1), 255, dtype=np.uint8)], axis=1),
         cam_cols], axis=0,
    )

    # Sapiens2 pointmaps are in camera frame (Y down, Z forward). glTF/Three.js
    # viewers (and gr.Model3D) use Y-up with -Z forward — negate Y AND Z so the
    # viewer's default orientation matches photographic intuition.
    verts = verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)

    pc = trimesh.PointCloud(vertices=verts, colors=cols_rgba)
    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".glb").name
    pc.export(out_path)
    return out_path
198
 
199
 
200
  # -----------------------------------------------------------------------------
201
  # Gradio handler
202
 
 
 
203
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str):
    """Full pipeline for one image: pointmap → foreground mask → .glb export.

    Args:
        image: uploaded PIL image, or None when the input is empty.
        size: model-size key passed to `_get_pointmap_model` (e.g. "0.4B").

    Returns:
        (glb_path, glb_path) — the same file path twice, feeding both the
        Model3D viewer and the downloadable gr.File output; (None, None) when
        no image was provided.
    """
    if image is None:
        return None, None

    t0 = _t.perf_counter()
    # Cap height to MAX_HEIGHT up front so 4K uploads don't inflate later steps.
    image_pil = _cap_height(image.convert("RGB"))  # cap to 1024px height
    image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    print(f"[time] convert+cap {(_t.perf_counter()-t0)*1000:.0f} ms (input {image_pil.size})")

    t = _t.perf_counter()
    model = _get_pointmap_model(size)
    pointmap = _estimate_pointmap(image_bgr, model)
    # Native model-grid resolution; the mask below is produced at this size.
    h_n, w_n = pointmap.shape[:2]
    print(f"[time] pointmap {(_t.perf_counter()-t)*1000:.0f} ms (native {w_n}x{h_n})")

    t = _t.perf_counter()
    mask = _foreground_mask(image_pil, h_n, w_n)
    print(f"[time] fg mask {(_t.perf_counter()-t)*1000:.0f} ms")

    t = _t.perf_counter()
    glb_path = _make_glb(image_pil, pointmap, mask)
    print(f"[time] glb export {(_t.perf_counter()-t)*1000:.0f} ms")

    print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms")
    return glb_path, glb_path
 
 
 
 
229
 
230
 
231
  # -----------------------------------------------------------------------------
 
293
 
294
  with gr.Row(equal_height=True):
295
  inp = gr.Image(label="Input", type="pil", height=640)
296
+ out_glb = gr.Model3D(
297
+ label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
298
+ height=640,
299
+ clear_color=[0.07, 0.09, 0.13, 1.0],
300
+ display_mode="point_cloud",
301
+ zoom_speed=0.7,
302
+ pan_speed=0.5,
303
+ )
304
 
305
  with gr.Row():
306
  size = gr.Radio(
 
314
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
315
 
316
  with gr.Accordion("Raw Pointmap", open=False):
317
+ out_glb_file = gr.File(label="Point cloud (.glb — open in Blender/MeshLab/web viewers)")
318
 
319
+ run.click(predict, inputs=[inp, size], outputs=[out_glb, out_glb_file])
320
 
321
 
322
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -19,4 +19,4 @@ prettytable
19
  termcolor
20
  accelerate
21
  rich
22
- open3d
 
19
  termcolor
20
  accelerate
21
  rich
22
+ trimesh