Rawal Khirodkar commited on
Commit
b66298c
Β·
1 Parent(s): 2c70f2e

Pointmap: depth heatmap on the right with solid grey bg; .ply still in accordion

Browse files
Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -134,6 +134,26 @@ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np
134
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # -----------------------------------------------------------------------------
138
  # Point cloud export (camera marker + cloud, native-res grid)
139
 
@@ -194,13 +214,24 @@ def predict(image: Image.Image, size: str):
194
  image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
195
 
196
  model = _get_pointmap_model(size)
197
- pointmap = _estimate_pointmap(image_bgr, model) # (H_n, W_n, 3) β€” at most 1024 in either dim
198
  h_n, w_n = pointmap.shape[:2]
199
 
200
  mask = _foreground_mask(image_pil, h_n, w_n) # native-res mask, fast
201
- ply_path = _make_ply(image_pil, pointmap, mask) # native-res .ply
202
 
203
- return ply_path, ply_path
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  # -----------------------------------------------------------------------------
@@ -268,14 +299,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
268
 
269
  with gr.Row(equal_height=True):
270
  inp = gr.Image(label="Input", type="pil", height=640)
271
- out_ply = gr.Model3D(
272
- label="Point cloud β€” drag to rotate, scroll to zoom, shift+drag to pan",
273
- height=640,
274
- clear_color=[0.07, 0.09, 0.13, 1.0],
275
- display_mode="point_cloud",
276
- zoom_speed=0.7,
277
- pan_speed=0.5,
278
- )
279
 
280
  with gr.Row():
281
  size = gr.Radio(
@@ -291,7 +315,7 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS
291
  with gr.Accordion("Raw Pointmap", open=False):
292
  out_ply_file = gr.File(label="Point cloud (.ply β€” open in MeshLab/CloudCompare/Blender)")
293
 
294
- run.click(predict, inputs=[inp, size], outputs=[out_ply, out_ply_file])
295
 
296
 
297
  if __name__ == "__main__":
 
134
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
135
 
136
 
137
+ def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray:
138
+ """Inverse-depth turbo colormap (matches sapiens2 vis_pointmap.py).
139
+ Background pixels are left at 0 β€” caller should overlay them."""
140
+ valid = np.isfinite(depth) & (depth > 1e-3) & mask
141
+ rgb = np.zeros((*depth.shape, 3), dtype=np.uint8)
142
+ if not valid.any():
143
+ return rgb
144
+ inv = np.zeros_like(depth, dtype=np.float32)
145
+ inv[valid] = 1.0 / depth[valid]
146
+ p1, p99 = np.percentile(inv[valid], [1, 99])
147
+ lo, hi = float(p1), float(p99)
148
+ if hi <= lo:
149
+ hi = lo + 1e-3
150
+ norm = ((inv - lo) / (hi - lo)).clip(0, 1)
151
+ grey = (norm * 255.0).astype(np.uint8)
152
+ color = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1] # cv2 is BGR β†’ RGB
153
+ rgb[valid] = color[valid]
154
+ return rgb
155
+
156
+
157
  # -----------------------------------------------------------------------------
158
  # Point cloud export (camera marker + cloud, native-res grid)
159
 
 
214
  image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
215
 
216
  model = _get_pointmap_model(size)
217
+ pointmap = _estimate_pointmap(image_bgr, model) # (H_n, W_n, 3) at most 1024 in either dim
218
  h_n, w_n = pointmap.shape[:2]
219
 
220
  mask = _foreground_mask(image_pil, h_n, w_n) # native-res mask, fast
 
221
 
222
+ # Depth heatmap (right pane). Solid mid-grey background with the foreground
223
+ # turbo-coloured by inverse depth. Mirrors sapiens2 vis_pointmap.py colormap.
224
+ depth = pointmap[:, :, 2]
225
+ depth_rgb = _depth_to_rgb(depth, mask)
226
+ BG_GREY = 200
227
+ depth_rgb[~mask] = BG_GREY
228
+ w0, h0 = image_pil.size
229
+ depth_pil = Image.fromarray(depth_rgb).resize((w0, h0), Image.LANCZOS)
230
+
231
+ # PLY (download in accordion). Native-res, ≀200K points.
232
+ ply_path = _make_ply(image_pil, pointmap, mask)
233
+
234
+ return depth_pil, ply_path
235
 
236
 
237
  # -----------------------------------------------------------------------------
 
299
 
300
  with gr.Row(equal_height=True):
301
  inp = gr.Image(label="Input", type="pil", height=640)
302
+ out_img = gr.Image(label="Depth (Z)", type="pil", height=640)
 
 
 
 
 
 
 
303
 
304
  with gr.Row():
305
  size = gr.Radio(
 
315
  with gr.Accordion("Raw Pointmap", open=False):
316
  out_ply_file = gr.File(label="Point cloud (.ply β€” open in MeshLab/CloudCompare/Blender)")
317
 
318
+ run.click(predict, inputs=[inp, size], outputs=[out_img, out_ply_file])
319
 
320
 
321
  if __name__ == "__main__":