Rawal Khirodkar committed
Commit 2c70f2e · 1 Parent(s): 380dd37

Pointmap: bring back .ply + Model3D, but native-res only (max 1024×768 grid → 200K pts)

Files changed (2)
  1. app.py +71 -43
  2. requirements.txt +1 -0
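
The point budget in the title is straightforward to verify (a quick back-of-the-envelope check, not part of the diff; the variable names below are illustrative):

    # A full native-res grid is 1024 x 768 = 786,432 candidate points, which
    # matches the "~786K" in the new docstring; _make_ply keeps at most 200K.
    grid_points = 1024 * 768
    max_points = 200_000
    print(grid_points)                        # 786432
    print(f"{max_points / grid_points:.1%}")  # ~25.4% of a full grid survives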
app.py CHANGED
@@ -1,8 +1,11 @@
 """Sapiens2 pointmap Gradio Space.

-Image → per-pixel 3D pointmap (camera frame, metric units). For now we just
-visualize the depth (z) channel as a colored heatmap, matching the look of the
-normal demo. The 3D point-cloud viewer can be re-enabled later.
+Image → per-pixel 3D pointmap (camera frame, metric units). Visualized as a
+.ply point cloud rendered with Gradio's Model3D component for interactive 3D
+viewing. Foreground mask is mandatory.
+
+Everything runs at the model's NATIVE resolution (max 1024×768 grid → at most
+~786K points before subsampling to 200K). No huge interpolations.
 """

 import sys
@@ -14,6 +17,7 @@ import tempfile
 import cv2
 import gradio as gr
 import numpy as np
+import open3d as o3d
 import spaces
 import torch
 import torch.nn.functional as F
@@ -92,7 +96,6 @@ def _get_fg_model():
     return _fg_model


-# Iteration mode: only preload the default (0.4B) for fast Space boot.
 print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
 _get_pointmap_model(DEFAULT_SIZE)
 _get_fg_model()
@@ -100,7 +103,7 @@ print("[startup] ready.")


 # -----------------------------------------------------------------------------
-# Inference (operates at the model's native resolution — no big upsamples)
+# Inference (always at native resolution)

 def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
     data = model.pipeline(dict(img=image_bgr))
@@ -111,7 +114,7 @@ def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:

     with torch.no_grad():
         pointmap, scale = model(inputs)
-    pointmap = pointmap / scale  # → metric units
+    pointmap = pointmap / scale  # → metric

     pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
     pointmap = pointmap[
@@ -131,33 +134,52 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
     return (out.argmax(dim=1)[0] > 0).cpu().numpy()


-def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
-    """Depth (H, W) → RGB (H, W, 3) uint8 via inverse-depth turbo colormap.
-
-    Inverse-depth (1/z) gives more contrast on near surfaces (where humans tend
-    to be), which matches what most SfM/depth viewers show.
-    """
-    valid = np.isfinite(depth) & (depth > 1e-3)
-    if mask is not None:
-        valid &= mask
-    if not valid.any():
-        return np.zeros((*depth.shape, 3), dtype=np.uint8)
-
-    inv = np.zeros_like(depth, dtype=np.float32)
-    inv[valid] = 1.0 / depth[valid]
-    p1, p99 = np.percentile(inv[valid], [1, 99])
-    lo, hi = float(p1), float(p99)
-    if hi <= lo:
-        hi = lo + 1e-3
-    norm = np.zeros_like(inv, dtype=np.float32)
-    norm[valid] = ((inv[valid] - lo) / (hi - lo)).clip(0, 1)
-    grey = (norm * 255.0).astype(np.uint8)
-
-    # cv2.applyColorMap returns BGR → flip to RGB for Gradio.
-    rgb = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1].copy()
-    if mask is not None:
-        rgb[~mask] = 0  # background → black
-    return rgb
+# -----------------------------------------------------------------------------
+# Point cloud export (camera marker + cloud, native-res grid)
+
+def _camera_marker(radius: float = 0.04, n_points: int = 800,
+                   color=(0.20, 0.55, 0.96)) -> o3d.geometry.PointCloud:
+    """Tiny slate-blue Fibonacci sphere at the world origin."""
+    i = np.arange(n_points)
+    phi = np.arccos(1 - 2 * (i + 0.5) / n_points)
+    theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)
+    pts = np.stack([
+        radius * np.sin(phi) * np.cos(theta),
+        radius * np.sin(phi) * np.sin(theta),
+        radius * np.cos(phi),
+    ], axis=1)
+    pc = o3d.geometry.PointCloud()
+    pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
+    pc.colors = o3d.utility.Vector3dVector(np.tile(color, (n_points, 1)).astype(np.float64))
+    return pc
+
+
+def _make_ply(image_pil_native: Image.Image, pointmap_hwc: np.ndarray,
+              mask_hw: np.ndarray, max_points: int = 200_000) -> str:
+    """`image_pil_native` MUST already be sized to `pointmap_hwc.shape[:2]` so
+    point colors line up. Output .ply: foreground points + camera marker."""
+    h, w = pointmap_hwc.shape[:2]
+    image_rgb = np.asarray(image_pil_native.resize((w, h), Image.LANCZOS))
+
+    pts = pointmap_hwc.reshape(-1, 3)
+    cols = image_rgb.reshape(-1, 3).astype(np.float32) / 255.0
+
+    z = pts[:, 2]
+    finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0) & mask_hw.reshape(-1)
+    pts, cols = pts[finite], cols[finite]
+
+    if len(pts) > max_points:
+        idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
+        pts, cols = pts[idx], cols[idx]
+
+    pc = o3d.geometry.PointCloud()
+    pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
+    pc.colors = o3d.utility.Vector3dVector(cols.astype(np.float64))
+    pc += _camera_marker()
+
+    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".ply").name
+    o3d.io.write_point_cloud(out_path, pc, write_ascii=False)
+    return out_path


 # -----------------------------------------------------------------------------
@@ -166,23 +188,19 @@ def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray | None = None) -> np.ndarr
 @spaces.GPU(duration=120)
 def predict(image: Image.Image, size: str):
     if image is None:
-        return None
+        return None, None

     image_pil = image.convert("RGB")
     image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

     model = _get_pointmap_model(size)
-    pointmap = _estimate_pointmap(image_bgr, model)  # (H_native, W_native, 3)
+    pointmap = _estimate_pointmap(image_bgr, model)  # (H_n, W_n, 3) — at most 1024 in either dim
     h_n, w_n = pointmap.shape[:2]

     mask = _foreground_mask(image_pil, h_n, w_n)  # native-res mask, fast
-    depth = pointmap[:, :, 2]  # z channel
-    rgb_native = _depth_to_rgb(depth, mask)  # (H_native, W_native, 3) uint8
+    ply_path = _make_ply(image_pil, pointmap, mask)  # native-res .ply

-    # Lanczos upsample the RGB heatmap to the original image size — sharp.
-    w0, h0 = image_pil.size
-    rgb_pil = Image.fromarray(rgb_native).resize((w0, h0), Image.LANCZOS)
-    return rgb_pil
+    return ply_path, ply_path


 # -----------------------------------------------------------------------------
@@ -250,7 +268,14 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS

     with gr.Row(equal_height=True):
         inp = gr.Image(label="Input", type="pil", height=640)
-        out_img = gr.Image(label="Depth (Z)", type="pil", height=640)
+        out_ply = gr.Model3D(
+            label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
+            height=640,
+            clear_color=[0.07, 0.09, 0.13, 1.0],
+            display_mode="point_cloud",
+            zoom_speed=0.7,
+            pan_speed=0.5,
+        )

     with gr.Row():
         size = gr.Radio(
@@ -263,7 +288,10 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS

     gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)

-    run.click(predict, inputs=[inp, size], outputs=[out_img])
+    with gr.Accordion("Raw Pointmap", open=False):
+        out_ply_file = gr.File(label="Point cloud (.ply — open in MeshLab/CloudCompare/Blender)")
+
+    run.click(predict, inputs=[inp, size], outputs=[out_ply, out_ply_file])


 if __name__ == "__main__":
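
To eyeball the exported cloud outside the Space, a minimal sketch (assumes a local open3d install; "pointmap.ply" is a stand-in for the temp path `_make_ply` returns):

    import numpy as np
    import open3d as o3d

    pc = o3d.io.read_point_cloud("pointmap.ply")  # hypothetical path from the gr.File output
    pts = np.asarray(pc.points)
    # Expect at most 200_000 subsampled foreground points plus the
    # 800-point camera marker appended by _camera_marker().
    assert len(pts) <= 200_000 + 800
    o3d.visualization.draw_geometries([pc])  # interactive viewer window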
requirements.txt CHANGED
@@ -19,3 +19,4 @@ prettytable
 termcolor
 accelerate
 rich
+open3d
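
A closing note on `_camera_marker`: the Fibonacci-sphere construction puts every marker point exactly on the radius-0.04 sphere, since sin²(phi)·(cos²(theta) + sin²(theta)) + cos²(phi) = 1. A minimal numpy check, standalone but using the same formulas as the diff:

    import numpy as np

    n, r = 800, 0.04
    i = np.arange(n)
    phi = np.arccos(1 - 2 * (i + 0.5) / n)      # uniform in cos(phi): even latitude coverage
    theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5)  # golden-angle longitude spacing
    pts = np.stack([r * np.sin(phi) * np.cos(theta),
                    r * np.sin(phi) * np.sin(theta),
                    r * np.cos(phi)], axis=1)
    assert np.allclose(np.linalg.norm(pts, axis=1), r)  # all points on the sphere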