Rawal Khirodkar commited on
Commit
2482c8d
·
1 Parent(s): 2593450

Pointmap: pivot to depth-z heatmap (turbo); drop Model3D + Open3D + .ply pipeline

Browse files
Files changed (2) hide show
  1. app.py +45 -86
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,9 +1,8 @@
1
  """Sapiens2 pointmap Gradio Space.
2
 
3
- Image → per-pixel 3D pointmap (camera frame, metric units). The result is
4
- exported as a .ply point cloud and rendered with Gradio's Model3D component
5
- for interactive 3D viewing. Optionally applies a v1 binary fg/bg mask so only
6
- foreground points end up in the cloud.
7
  """
8
 
9
  import sys
@@ -15,7 +14,6 @@ import tempfile
15
  import cv2
16
  import gradio as gr
17
  import numpy as np
18
- import open3d as o3d
19
  import spaces
20
  import torch
21
  import torch.nn.functional as F
@@ -95,20 +93,16 @@ def _get_fg_model():
95
 
96
 
97
  # Iteration mode: only preload the default (0.4B) for fast Space boot.
98
- # Re-enable full preload by uncommenting the loop below.
99
  print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
100
  _get_pointmap_model(DEFAULT_SIZE)
101
  _get_fg_model()
102
- # for _size in POINTMAP_MODELS:
103
- # _get_pointmap_model(_size)
104
  print("[startup] ready.")
105
 
106
 
107
  # -----------------------------------------------------------------------------
108
- # Inference
109
 
110
  def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
111
- h0, w0 = image_bgr.shape[:2]
112
  data = model.pipeline(dict(img=image_bgr))
113
  data = model.data_preprocessor(data)
114
  inputs, data_samples = data["inputs"], data["data_samples"]
@@ -119,15 +113,13 @@ def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
119
  pointmap, scale = model(inputs)
120
  pointmap = pointmap / scale # → metric units
121
 
122
- pad = data_samples["meta"]["padding_size"]
123
- pad_left, pad_right, pad_top, pad_bottom = pad
124
  pointmap = pointmap[
125
  :, :,
126
  pad_top : inputs.shape[2] - pad_bottom,
127
  pad_left : inputs.shape[3] - pad_right,
128
  ]
129
- pointmap = F.interpolate(pointmap, size=(h0, w0), mode="bilinear", align_corners=False)
130
- return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H, W, 3)
131
 
132
 
133
  def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
@@ -139,81 +131,58 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
139
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
140
 
141
 
142
- # -----------------------------------------------------------------------------
143
- # Point cloud export
144
-
145
- def _camera_marker(radius: float = 0.04, n_points: int = 800,
146
- color=(0.20, 0.55, 0.96)) -> o3d.geometry.PointCloud:
147
- """Small uniformly-blue sphere at the world origin marking the camera.
148
 
149
- Manual Fibonacci-sphere sampling — instant, vs Open3D's poisson-disk which
150
- can take seconds per call.
151
  """
152
- rng = np.random.default_rng(0)
153
- i = np.arange(n_points)
154
- phi = np.arccos(1 - 2 * (i + 0.5) / n_points) # latitude
155
- theta = np.pi * (1 + 5 ** 0.5) * (i + 0.5) # golden-angle longitude
156
- pts = np.stack([
157
- radius * np.sin(phi) * np.cos(theta),
158
- radius * np.sin(phi) * np.sin(theta),
159
- radius * np.cos(phi),
160
- ], axis=1)
161
- pc = o3d.geometry.PointCloud()
162
- pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
163
- pc.colors = o3d.utility.Vector3dVector(np.tile(color, (n_points, 1)).astype(np.float64))
164
- return pc
165
-
166
-
167
- def _make_ply(image_rgb: np.ndarray, pointmap_hwc: np.ndarray, mask_hw: np.ndarray | None = None,
168
- max_points: int = 200_000) -> str:
169
- pts = pointmap_hwc.reshape(-1, 3)
170
- cols = (image_rgb.reshape(-1, 3).astype(np.float32) / 255.0)
171
-
172
- z = pts[:, 2]
173
- finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0)
174
- if mask_hw is not None:
175
- finite &= mask_hw.reshape(-1)
176
- pts, cols = pts[finite], cols[finite]
177
-
178
- if len(pts) > max_points:
179
- idx = np.random.default_rng(0).choice(len(pts), size=max_points, replace=False)
180
- pts, cols = pts[idx], cols[idx]
181
-
182
- pc = o3d.geometry.PointCloud()
183
- pc.points = o3d.utility.Vector3dVector(pts.astype(np.float64))
184
- pc.colors = o3d.utility.Vector3dVector(cols.astype(np.float64))
185
-
186
- # Add the camera marker (blue ball at origin) so users see where the
187
- # observer is in the reconstructed 3D scene.
188
- pc += _camera_marker()
189
-
190
- out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".ply").name
191
- o3d.io.write_point_cloud(out_path, pc, write_ascii=False)
192
- return out_path
193
 
194
 
195
  # -----------------------------------------------------------------------------
196
  # Gradio handler
197
 
198
- @spaces.GPU(duration=180)
199
  def predict(image: Image.Image, size: str):
200
  if image is None:
201
- return None, None
202
 
203
  image_pil = image.convert("RGB")
204
- image_rgb = np.array(image_pil)
205
- image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
206
- h0, w0 = image_rgb.shape[:2]
207
 
208
  model = _get_pointmap_model(size)
209
- pointmap = _estimate_pointmap(image_bgr, model)
 
210
 
211
- Foreground masking is mandatory — keeps the cloud clean and the camera
212
- # marker meaningful (background depth is unreliable).
213
- mask = _foreground_mask(image_pil, h0, w0)
214
- ply_path = _make_ply(image_rgb, pointmap, mask)
215
 
216
- return ply_path, ply_path
 
 
 
217
 
218
 
219
  # -----------------------------------------------------------------------------
@@ -281,14 +250,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
281
 
282
  with gr.Row(equal_height=True):
283
  inp = gr.Image(label="Input", type="pil", height=640)
284
- out_ply = gr.Model3D(
285
- label="Point cloud — drag to rotate, scroll to zoom, shift+drag to pan",
286
- height=640,
287
- clear_color=[0.07, 0.09, 0.13, 1.0], # subtle slate-900 backdrop
288
- display_mode="point_cloud",
289
- zoom_speed=0.7,
290
- pan_speed=0.5,
291
- )
292
 
293
  with gr.Row():
294
  size = gr.Radio(
@@ -301,10 +263,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
301
 
302
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
303
 
304
- with gr.Accordion("Raw Pointmap", open=False):
305
- out_ply_file = gr.File(label="Point cloud (.ply — open in MeshLab/CloudCompare/Blender)")
306
-
307
- run.click(predict, inputs=[inp, size], outputs=[out_ply, out_ply_file])
308
 
309
 
310
  if __name__ == "__main__":
 
1
  """Sapiens2 pointmap Gradio Space.
2
 
3
+ Image → per-pixel 3D pointmap (camera frame, metric units). For now we just
4
+ visualize the depth (z) channel as a colored heatmap, matching the look of the
5
+ normal demo. The 3D point-cloud viewer can be re-enabled later.
 
6
  """
7
 
8
  import sys
 
14
  import cv2
15
  import gradio as gr
16
  import numpy as np
 
17
  import spaces
18
  import torch
19
  import torch.nn.functional as F
 
93
 
94
 
95
  # Iteration mode: only preload the default (0.4B) for fast Space boot.
 
96
  print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
97
  _get_pointmap_model(DEFAULT_SIZE)
98
  _get_fg_model()
 
 
99
  print("[startup] ready.")
100
 
101
 
102
  # -----------------------------------------------------------------------------
103
+ # Inference (operates at the model's native resolution — no big upsamples)
104
 
105
  def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
 
106
  data = model.pipeline(dict(img=image_bgr))
107
  data = model.data_preprocessor(data)
108
  inputs, data_samples = data["inputs"], data["data_samples"]
 
113
  pointmap, scale = model(inputs)
114
  pointmap = pointmap / scale # → metric units
115
 
116
+ pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
 
117
  pointmap = pointmap[
118
  :, :,
119
  pad_top : inputs.shape[2] - pad_bottom,
120
  pad_left : inputs.shape[3] - pad_right,
121
  ]
122
+ return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H_native, W_native, 3)
 
123
 
124
 
125
  def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
 
131
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
132
 
133
 
134
+ def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
135
+ """Depth (H, W) → RGB (H, W, 3) uint8 via inverse-depth turbo colormap.
 
 
 
 
136
 
137
+ Inverse-depth (1/z) gives more contrast on near surfaces (where humans tend
138
+ to be), which matches what most SfM/depth viewers show.
139
  """
140
+ valid = np.isfinite(depth) & (depth > 1e-3)
141
+ if mask is not None:
142
+ valid &= mask
143
+ if not valid.any():
144
+ return np.zeros((*depth.shape, 3), dtype=np.uint8)
145
+
146
+ inv = np.zeros_like(depth, dtype=np.float32)
147
+ inv[valid] = 1.0 / depth[valid]
148
+ p1, p99 = np.percentile(inv[valid], [1, 99])
149
+ lo, hi = float(p1), float(p99)
150
+ if hi <= lo:
151
+ hi = lo + 1e-3
152
+ norm = np.zeros_like(inv, dtype=np.float32)
153
+ norm[valid] = ((inv[valid] - lo) / (hi - lo)).clip(0, 1)
154
+ grey = (norm * 255.0).astype(np.uint8)
155
+
156
+ # cv2.applyColorMap returns BGR flip to RGB for Gradio.
157
+ rgb = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1].copy()
158
+ if mask is not None:
159
+ rgb[~mask] = 0 # background → black
160
+ return rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
 
163
  # -----------------------------------------------------------------------------
164
  # Gradio handler
165
 
166
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str):
    """Run pointmap inference on ``image`` and return a turbo depth heatmap.

    The z channel of the predicted pointmap is colored (foreground only) at the
    model's native resolution, then Lanczos-upsampled back to the input size.
    Returns ``None`` when no image is provided.
    """
    if image is None:
        return None

    rgb_input = image.convert("RGB")
    bgr_input = cv2.cvtColor(np.array(rgb_input), cv2.COLOR_RGB2BGR)

    # Pointmap at the model's native resolution — shape (H_native, W_native, 3).
    pointmap = _estimate_pointmap(bgr_input, _get_pointmap_model(size))
    native_h, native_w = pointmap.shape[:2]

    # Color the z channel, masked to the foreground (native-res mask, fast).
    fg = _foreground_mask(rgb_input, native_h, native_w)
    heatmap = _depth_to_rgb(pointmap[:, :, 2], fg)

    # Lanczos upsample the RGB heatmap to the original image size — sharp.
    full_w, full_h = rgb_input.size
    return Image.fromarray(heatmap).resize((full_w, full_h), Image.LANCZOS)
186
 
187
 
188
  # -----------------------------------------------------------------------------
 
250
 
251
  with gr.Row(equal_height=True):
252
  inp = gr.Image(label="Input", type="pil", height=640)
253
+ out_img = gr.Image(label="Depth (turbo)", type="pil", height=640)
 
 
 
 
 
 
 
254
 
255
  with gr.Row():
256
  size = gr.Radio(
 
263
 
264
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
265
 
266
+ run.click(predict, inputs=[inp, size], outputs=[out_img])
 
 
 
267
 
268
 
269
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -19,4 +19,3 @@ prettytable
19
  termcolor
20
  accelerate
21
  rich
22
- open3d
 
19
  termcolor
20
  accelerate
21
  rich