Rawal Khirodkar committed on
Commit
a0fd52f
·
1 Parent(s): 2f57cf8

Add fg-bg dropdown using v1 binary segmentation TorchScript model

Browse files
Files changed (1) hide show
  1. app.py +61 -14
app.py CHANGED
@@ -3,6 +3,9 @@
3
  Image → per-pixel 3D pointmap (camera frame, metric units). The result is
4
  exported as a .ply point cloud and rendered with Gradio's Model3D component
5
  for interactive 3D viewing.
 
 
 
6
  """
7
 
8
  import sys
@@ -19,6 +22,7 @@ import spaces
19
  import torch
20
  import torch.nn.functional as F
21
  from PIL import Image
 
22
 
23
  from huggingface_hub import hf_hub_download
24
  from sapiens.dense.models import PointmapEstimator, init_model # registers in registry
@@ -55,13 +59,26 @@ POINTMAP_MODELS = {
55
  }
56
  DEFAULT_SIZE = "1B"
57
 
 
 
 
 
 
58
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
59
 
 
 
 
 
 
 
 
60
 
61
  # -----------------------------------------------------------------------------
62
  # Model cache
63
 
64
  _pointmap_model_cache: dict = {}
 
65
 
66
 
67
  def _get_pointmap_model(size: str):
@@ -73,9 +90,18 @@ def _get_pointmap_model(size: str):
73
  return _pointmap_model_cache[size]
74
 
75
 
76
- print("[startup] pre-loading all pointmap sizes ...")
 
 
 
 
 
 
 
 
77
  for _size in POINTMAP_MODELS:
78
  _get_pointmap_model(_size)
 
79
  print("[startup] ready.")
80
 
81
 
@@ -105,17 +131,28 @@ def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
105
  return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H, W, 3)
106
 
107
 
 
 
 
 
 
 
 
 
 
108
  # -----------------------------------------------------------------------------
109
  # Point cloud export
110
 
111
- def _make_ply(image_rgb: np.ndarray, pointmap_hwc: np.ndarray, max_points: int = 200_000) -> str:
112
- """Subsample, filter to a reasonable depth range, and write a .ply file."""
 
113
  pts = pointmap_hwc.reshape(-1, 3)
114
  cols = (image_rgb.reshape(-1, 3).astype(np.float32) / 255.0)
115
 
116
- # Drop points with non-finite or extreme depth
117
  z = pts[:, 2]
118
  finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0)
 
 
119
  pts, cols = pts[finite], cols[finite]
120
 
121
  if len(pts) > max_points:
@@ -135,16 +172,20 @@ def _make_ply(image_rgb: np.ndarray, pointmap_hwc: np.ndarray, max_points: int =
135
  # Gradio handler
136
 
137
  @spaces.GPU(duration=180)
138
- def predict(image: Image.Image, size: str):
139
  if image is None:
140
  return None, None
141
 
142
- image_rgb = np.array(image.convert("RGB"))
 
143
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
 
144
 
145
  model = _get_pointmap_model(size)
146
- pointmap = _estimate_pointmap(image_bgr, model) # (H, W, 3) metric, camera frame
147
- ply_path = _make_ply(image_rgb, pointmap)
 
 
148
 
149
  npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
150
  np.save(npy_path, pointmap.astype(np.float32))
@@ -173,18 +214,24 @@ with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Default()) as demo:
173
  with gr.Row():
174
  with gr.Column():
175
  inp = gr.Image(label="Input", type="pil")
176
- size = gr.Radio(
177
- choices=list(POINTMAP_MODELS.keys()),
178
- value=DEFAULT_SIZE,
179
- label="Model size",
180
- )
 
 
 
 
 
 
181
  run = gr.Button("Run", variant="primary")
182
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
183
  with gr.Column():
184
  out_ply = gr.Model3D(label="Point cloud (drag to rotate)", clear_color=[0.05, 0.05, 0.05, 1.0])
185
  out_npy = gr.File(label="Raw pointmap (.npy float32 [H, W, 3] in meters)")
186
 
187
- run.click(predict, inputs=[inp, size], outputs=[out_ply, out_npy])
188
 
189
 
190
  if __name__ == "__main__":
 
3
  Image → per-pixel 3D pointmap (camera frame, metric units). The result is
4
  exported as a .ply point cloud and rendered with Gradio's Model3D component
5
  for interactive 3D viewing.
6
+
7
+ Optionally applies a v1 foreground/background mask so only person points end
8
+ up in the cloud (background is dropped entirely).
9
  """
10
 
11
  import sys
 
22
  import torch
23
  import torch.nn.functional as F
24
  from PIL import Image
25
+ from torchvision import transforms
26
 
27
  from huggingface_hub import hf_hub_download
28
  from sapiens.dense.models import PointmapEstimator, init_model # registers in registry
 
59
  }
60
  DEFAULT_SIZE = "1B"
61
 
62
+ FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
63
+ FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
64
+ BG_OPTIONS = ["fg-bg", "no-bg-removal"]
65
+ DEFAULT_BG = "fg-bg"
66
+
67
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
68
 
69
+ _fg_transform = transforms.Compose([
70
+ transforms.Resize((1024, 768)),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
73
+ std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
74
+ ])
75
+
76
 
77
  # -----------------------------------------------------------------------------
78
  # Model cache
79
 
80
  _pointmap_model_cache: dict = {}
81
+ _fg_model = None
82
 
83
 
84
  def _get_pointmap_model(size: str):
 
90
  return _pointmap_model_cache[size]
91
 
92
 
93
def _get_fg_model():
    """Return the foreground/background TorchScript model, loading it on first use.

    On the first call the checkpoint named by FG_REPO / FG_FILENAME is fetched
    with ``hf_hub_download``, deserialized with ``torch.jit.load``, switched to
    eval mode, moved to DEVICE, and stored in the module-level ``_fg_model``.
    Later calls return that cached object.
    """
    global _fg_model
    if _fg_model is None:
        ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
        # map_location remaps the saved storages onto DEVICE while loading, so a
        # checkpoint serialized on a GPU still loads when only a CPU is present.
        _fg_model = torch.jit.load(ckpt, map_location=DEVICE).eval().to(DEVICE)
    return _fg_model
99
+
100
+
101
+ print("[startup] pre-loading all pointmap sizes + fg/bg model ...")
102
  for _size in POINTMAP_MODELS:
103
  _get_pointmap_model(_size)
104
+ _get_fg_model()
105
  print("[startup] ready.")
106
 
107
 
 
131
  return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H, W, 3)
132
 
133
 
134
def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
    """Compute a boolean foreground mask for ``image_pil`` at a target resolution.

    The image is preprocessed with ``_fg_transform``, run through the model
    returned by ``_get_fg_model()``, and the model output is bilinearly resized
    to ``(target_h, target_w)``. A pixel is marked True when the arg-max over
    dim 1 is greater than 0.

    Args:
        image_pil: Input image handed to ``_fg_transform``.
        target_h: Height of the returned mask.
        target_w: Width of the returned mask.

    Returns:
        NumPy boolean array of shape ``(target_h, target_w)``.
    """
    segmenter = _get_fg_model()

    # Add a leading batch dimension and move the tensor to the inference device.
    batch = _fg_transform(image_pil)
    batch = batch.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # presumably the model emits per-class scores shaped (1, C, h, w) — TODO confirm
        scores = F.interpolate(
            segmenter(batch),
            size=(target_h, target_w),
            mode="bilinear",
            align_corners=False,
        )
        labels = scores.argmax(dim=1)[0]

    # Any non-zero arg-max index counts as foreground.
    is_foreground = labels > 0
    return is_foreground.cpu().numpy()
141
+
142
+
143
  # -----------------------------------------------------------------------------
144
  # Point cloud export
145
 
146
+ def _make_ply(image_rgb: np.ndarray, pointmap_hwc: np.ndarray, mask_hw: np.ndarray | None = None,
147
+ max_points: int = 200_000) -> str:
148
+ """Subsample, optionally mask to foreground, and write a .ply file."""
149
  pts = pointmap_hwc.reshape(-1, 3)
150
  cols = (image_rgb.reshape(-1, 3).astype(np.float32) / 255.0)
151
 
 
152
  z = pts[:, 2]
153
  finite = np.isfinite(pts).all(axis=1) & (z > 0.05) & (z < 25.0)
154
+ if mask_hw is not None:
155
+ finite &= mask_hw.reshape(-1)
156
  pts, cols = pts[finite], cols[finite]
157
 
158
  if len(pts) > max_points:
 
172
  # Gradio handler
173
 
174
  @spaces.GPU(duration=180)
175
+ def predict(image: Image.Image, size: str, bg_mode: str):
176
  if image is None:
177
  return None, None
178
 
179
+ image_pil = image.convert("RGB")
180
+ image_rgb = np.array(image_pil)
181
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
182
+ h0, w0 = image_rgb.shape[:2]
183
 
184
  model = _get_pointmap_model(size)
185
+ pointmap = _estimate_pointmap(image_bgr, model)
186
+
187
+ mask = _foreground_mask(image_pil, h0, w0) if bg_mode == "fg-bg" else None
188
+ ply_path = _make_ply(image_rgb, pointmap, mask)
189
 
190
  npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
191
  np.save(npy_path, pointmap.astype(np.float32))
 
214
  with gr.Row():
215
  with gr.Column():
216
  inp = gr.Image(label="Input", type="pil")
217
+ with gr.Row():
218
+ size = gr.Radio(
219
+ choices=list(POINTMAP_MODELS.keys()),
220
+ value=DEFAULT_SIZE,
221
+ label="Model size",
222
+ )
223
+ bg = gr.Radio(
224
+ choices=BG_OPTIONS,
225
+ value=DEFAULT_BG,
226
+ label="Background",
227
+ )
228
  run = gr.Button("Run", variant="primary")
229
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
230
  with gr.Column():
231
  out_ply = gr.Model3D(label="Point cloud (drag to rotate)", clear_color=[0.05, 0.05, 0.05, 1.0])
232
  out_npy = gr.File(label="Raw pointmap (.npy float32 [H, W, 3] in meters)")
233
 
234
+ run.click(predict, inputs=[inp, size, bg], outputs=[out_ply, out_npy])
235
 
236
 
237
  if __name__ == "__main__":