Rawal Khirodkar commited on
Commit
d84d54c
·
1 Parent(s): 977839e

Add fg-bg dropdown using v1 binary segmentation TorchScript model

Browse files
Files changed (1) hide show
  1. app.py +70 -15
app.py CHANGED
@@ -2,6 +2,9 @@
2
 
3
  Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
4
  (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
 
 
 
5
  """
6
 
7
  import sys
@@ -17,6 +20,7 @@ import spaces
17
  import torch
18
  import torch.nn.functional as F
19
  from PIL import Image
 
20
 
21
  from huggingface_hub import hf_hub_download
22
  from sapiens.dense.models import NormalEstimator, init_model # NormalEstimator triggers registry
@@ -53,13 +57,28 @@ NORMAL_MODELS = {
53
  }
54
  DEFAULT_SIZE = "1B"
55
 
 
 
 
 
 
 
56
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
 
 
 
 
 
 
 
 
 
58
 
59
  # -----------------------------------------------------------------------------
60
  # Model cache
61
 
62
  _normal_model_cache: dict = {}
 
63
 
64
 
65
  def _get_normal_model(size: str):
@@ -71,9 +90,19 @@ def _get_normal_model(size: str):
71
  return _normal_model_cache[size]
72
 
73
 
74
- print("[startup] pre-loading all normal sizes ...")
 
 
 
 
 
 
 
 
 
75
  for _size in NORMAL_MODELS:
76
  _get_normal_model(_size)
 
77
  print("[startup] ready.")
78
 
79
 
@@ -91,35 +120,55 @@ def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
91
  with torch.no_grad():
92
  normals = model(inputs) # (1, 3, H, W)
93
 
94
- # Unit-length normalization, interpolate to original size, cast to numpy
95
  normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(1e-6)
96
  normals = F.interpolate(normals, size=(h0, w0), mode="bilinear", align_corners=False)
97
- normals = normals[0].cpu().float().numpy() # (3, H, W) in [-1, 1]
98
  return normals.transpose(1, 2, 0) # (H, W, 3)
99
 
100
 
 
 
 
 
 
 
 
 
 
 
101
  def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
102
  rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
103
- return rgb[:, :, ::-1] # match training viz channel order
104
 
105
 
106
  # -----------------------------------------------------------------------------
107
  # Gradio handler
108
 
109
  @spaces.GPU(duration=120)
110
- def predict(image: Image.Image, size: str):
111
  if image is None:
112
  return None, None
113
 
114
- image_rgb = np.array(image.convert("RGB"))
 
115
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
 
116
 
117
  model = _get_normal_model(size)
118
  normals = _estimate_normal(image_bgr, model) # (H, W, 3) in [-1, 1]
119
- rgb = _normal_to_rgb(normals)
 
 
 
 
 
 
 
 
 
120
 
121
  with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f:
122
- np.save(f.name, normals.astype(np.float32))
123
  npy_path = f.name
124
 
125
  return Image.fromarray(rgb), npy_path
@@ -146,18 +195,24 @@ with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Default()) as demo:
146
  with gr.Row():
147
  with gr.Column():
148
  inp = gr.Image(label="Input", type="pil")
149
- size = gr.Radio(
150
- choices=list(NORMAL_MODELS.keys()),
151
- value=DEFAULT_SIZE,
152
- label="Model size",
153
- )
 
 
 
 
 
 
154
  run = gr.Button("Run", variant="primary")
155
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
156
  with gr.Column():
157
  out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil")
158
- out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1])")
159
 
160
- run.click(predict, inputs=[inp, size], outputs=[out_img, out_npy])
161
 
162
 
163
  if __name__ == "__main__":
 
2
 
3
  Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
4
  (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
5
+
6
+ Optionally applies a v1 foreground/background mask so only person pixels are
7
+ shown (background reads as a flat colour).
8
  """
9
 
10
  import sys
 
20
  import torch
21
  import torch.nn.functional as F
22
  from PIL import Image
23
+ from torchvision import transforms
24
 
25
  from huggingface_hub import hf_hub_download
26
  from sapiens.dense.models import NormalEstimator, init_model # NormalEstimator triggers registry
 
57
  }
58
  DEFAULT_SIZE = "1B"
59
 
60
+ # v1 binary fg/bg TorchScript model — uses a different normalization (PIL → tensor → ImageNet).
61
+ FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
62
+ FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
63
+ BG_OPTIONS = ["fg-bg", "no-bg-removal"]
64
+ DEFAULT_BG = "fg-bg"
65
+
66
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
 
68
+ # Pre-process for v1 fg-bg model (matches v1 sapiens-normal Space recipe).
69
+ _fg_transform = transforms.Compose([
70
+ transforms.Resize((1024, 768)),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
73
+ std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
74
+ ])
75
+
76
 
77
  # -----------------------------------------------------------------------------
78
  # Model cache
79
 
80
  _normal_model_cache: dict = {}
81
+ _fg_model = None
82
 
83
 
84
  def _get_normal_model(size: str):
 
90
  return _normal_model_cache[size]
91
 
92
 
93
+ def _get_fg_model():
94
+ global _fg_model
95
+ if _fg_model is None:
96
+ ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
97
+ model = torch.jit.load(ckpt).eval().to(DEVICE)
98
+ _fg_model = model
99
+ return _fg_model
100
+
101
+
102
+ print("[startup] pre-loading all normal sizes + fg/bg model ...")
103
  for _size in NORMAL_MODELS:
104
  _get_normal_model(_size)
105
+ _get_fg_model()
106
  print("[startup] ready.")
107
 
108
 
 
120
  with torch.no_grad():
121
  normals = model(inputs) # (1, 3, H, W)
122
 
 
123
  normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(1e-6)
124
  normals = F.interpolate(normals, size=(h0, w0), mode="bilinear", align_corners=False)
125
+ normals = normals[0].cpu().float().numpy()
126
  return normals.transpose(1, 2, 0) # (H, W, 3)
127
 
128
 
129
+ def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
130
+ """Returns a (H, W) bool mask using the v1 binary fg/bg torchscript model."""
131
+ fg = _get_fg_model()
132
+ inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
133
+ with torch.no_grad():
134
+ out = fg(inputs) # (1, K, H, W) logits
135
+ out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False)
136
+ return (out.argmax(dim=1)[0] > 0).cpu().numpy()
137
+
138
+
139
  def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
140
  rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
141
+ return rgb[:, :, ::-1]
142
 
143
 
144
  # -----------------------------------------------------------------------------
145
  # Gradio handler
146
 
147
  @spaces.GPU(duration=120)
148
+ def predict(image: Image.Image, size: str, bg_mode: str):
149
  if image is None:
150
  return None, None
151
 
152
+ image_pil = image.convert("RGB")
153
+ image_rgb = np.array(image_pil)
154
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
155
+ h0, w0 = image_rgb.shape[:2]
156
 
157
  model = _get_normal_model(size)
158
  normals = _estimate_normal(image_bgr, model) # (H, W, 3) in [-1, 1]
159
+
160
+ raw = normals.copy()
161
+ if bg_mode == "fg-bg":
162
+ mask = _foreground_mask(image_pil, h0, w0)
163
+ raw[~mask] = np.nan
164
+ # For viz, show background as middle-grey rather than a saturated colour.
165
+ rgb = _normal_to_rgb(normals)
166
+ rgb[~mask] = 128
167
+ else:
168
+ rgb = _normal_to_rgb(normals)
169
 
170
  with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f:
171
+ np.save(f.name, raw.astype(np.float32))
172
  npy_path = f.name
173
 
174
  return Image.fromarray(rgb), npy_path
 
195
  with gr.Row():
196
  with gr.Column():
197
  inp = gr.Image(label="Input", type="pil")
198
+ with gr.Row():
199
+ size = gr.Radio(
200
+ choices=list(NORMAL_MODELS.keys()),
201
+ value=DEFAULT_SIZE,
202
+ label="Model size",
203
+ )
204
+ bg = gr.Radio(
205
+ choices=BG_OPTIONS,
206
+ value=DEFAULT_BG,
207
+ label="Background",
208
+ )
209
  run = gr.Button("Run", variant="primary")
210
  gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
211
  with gr.Column():
212
  out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil")
213
+ out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg)")
214
 
215
+ run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy])
216
 
217
 
218
  if __name__ == "__main__":