Rawal Khirodkar committed on
Commit
2070091
·
1 Parent(s): 9884195

Normal: copy seg aesthetic; fix output (unpad + drop bogus channel swap); 0.4B-only preload

Browse files
Files changed (1) hide show
  1. app.py +111 -63
app.py CHANGED
@@ -3,8 +3,7 @@
3
  Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
4
  (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
5
 
6
- Optionally applies a v1 foreground/background mask so only person pixels are
7
- shown (background reads as a flat colour).
8
  """
9
 
10
  import sys
@@ -23,7 +22,7 @@ from PIL import Image
23
  from torchvision import transforms
24
 
25
  from huggingface_hub import hf_hub_download
26
- from sapiens.dense.models import NormalEstimator, init_model # NormalEstimator triggers registry
27
  _ = NormalEstimator
28
 
29
 
@@ -55,9 +54,9 @@ NORMAL_MODELS = {
55
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
56
  },
57
  }
58
- DEFAULT_SIZE = "1B"
59
 
60
- # v1 binary fg/bg TorchScript model — uses a different normalization (PIL → tensor → ImageNet).
61
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
62
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
63
  BG_OPTIONS = ["fg-bg", "no-bg-removal"]
@@ -65,7 +64,6 @@ DEFAULT_BG = "fg-bg"
65
 
66
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
 
68
- # Pre-process for v1 fg-bg model (matches v1 sapiens-normal Space recipe).
69
  _fg_transform = transforms.Compose([
70
  transforms.Resize((1024, 768)),
71
  transforms.ToTensor(),
@@ -94,51 +92,56 @@ def _get_fg_model():
94
  global _fg_model
95
  if _fg_model is None:
96
  ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
97
- model = torch.jit.load(ckpt).eval().to(DEVICE)
98
- _fg_model = model
99
  return _fg_model
100
 
101
 
102
- print("[startup] pre-loading all normal sizes + fg/bg model ...")
103
- for _size in NORMAL_MODELS:
104
- _get_normal_model(_size)
 
105
  _get_fg_model()
 
 
106
  print("[startup] ready.")
107
 
108
 
109
  # -----------------------------------------------------------------------------
110
- # Inference
111
 
112
  def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
113
  h0, w0 = image_bgr.shape[:2]
114
- data = model.pipeline(dict(img=image_bgr))
115
- data = model.data_preprocessor(data)
116
- inputs = data["inputs"]
117
- if inputs.ndim == 3:
118
- inputs = inputs.unsqueeze(0)
119
 
120
  with torch.no_grad():
121
- normals = model(inputs) # (1, 3, H, W)
 
122
 
123
- normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(1e-6)
124
- normals = F.interpolate(normals, size=(h0, w0), mode="bilinear", align_corners=False)
125
- normals = normals[0].cpu().float().numpy()
126
- return normals.transpose(1, 2, 0) # (H, W, 3)
 
 
 
 
127
 
128
 
129
  def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
130
- """Returns a (H, W) bool mask using the v1 binary fg/bg torchscript model."""
131
  fg = _get_fg_model()
132
  inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
133
  with torch.no_grad():
134
- out = fg(inputs) # (1, K, H, W) logits
135
  out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False)
136
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
137
 
138
 
139
  def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
140
- rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
141
- return rgb[:, :, ::-1]
 
142
 
143
 
144
  # -----------------------------------------------------------------------------
@@ -155,21 +158,17 @@ def predict(image: Image.Image, size: str, bg_mode: str):
155
  h0, w0 = image_rgb.shape[:2]
156
 
157
  model = _get_normal_model(size)
158
- normals = _estimate_normal(image_bgr, model) # (H, W, 3) in [-1, 1]
159
 
160
- raw = normals.copy()
161
  if bg_mode == "fg-bg":
162
  mask = _foreground_mask(image_pil, h0, w0)
163
  raw[~mask] = np.nan
164
- # For viz, show background as middle-grey rather than a saturated colour.
165
- rgb = _normal_to_rgb(normals)
166
- rgb[~mask] = 128
167
- else:
168
- rgb = _normal_to_rgb(normals)
169
 
170
- with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f:
171
- np.save(f.name, raw.astype(np.float32))
172
- npy_path = f.name
173
 
174
  return Image.fromarray(rgb), npy_path
175
 
@@ -183,34 +182,83 @@ EXAMPLES = sorted(
183
  if n.lower().endswith((".jpg", ".jpeg", ".png"))
184
  )
185
 
186
- with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Default()) as demo:
187
- gr.Markdown(
188
- "# Sapiens2: Surface Normal Estimation\n"
189
- "### ICLR 2026\n"
190
- "Per-pixel surface-normal estimation. Output is RGB-encoded (x, y, z → R, G, B).\n\n"
191
- "[Code](https://github.com/facebookresearch/sapiens2) · "
192
- "[Models](https://huggingface.co/facebook/sapiens2) · "
193
- "[Paper](https://openreview.net/pdf?id=IVAlYCqdvW)"
194
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  with gr.Row():
196
- with gr.Column():
197
- inp = gr.Image(label="Input", type="pil")
198
- with gr.Row():
199
- size = gr.Radio(
200
- choices=list(NORMAL_MODELS.keys()),
201
- value=DEFAULT_SIZE,
202
- label="Model size",
203
- )
204
- bg = gr.Radio(
205
- choices=BG_OPTIONS,
206
- value=DEFAULT_BG,
207
- label="Background",
208
- )
209
- run = gr.Button("Run", variant="primary")
210
- gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
211
- with gr.Column():
212
- out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil")
213
- out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg)")
214
 
215
  run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy])
216
 
 
3
  Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
4
  (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
5
 
6
+ Optionally applies a v1 binary fg/bg mask so background pixels are blacked out.
 
7
  """
8
 
9
  import sys
 
22
  from torchvision import transforms
23
 
24
  from huggingface_hub import hf_hub_download
25
+ from sapiens.dense.models import NormalEstimator, init_model # registers NormalEstimator
26
  _ = NormalEstimator
27
 
28
 
 
54
  "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
55
  },
56
  }
57
+ DEFAULT_SIZE = "0.4B" # iteration mode — only this is preloaded; others lazy-load on click
58
 
59
+ # v1 binary fg/bg TorchScript model.
60
  FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
61
  FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
62
  BG_OPTIONS = ["fg-bg", "no-bg-removal"]
 
64
 
65
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
66
 
 
67
  _fg_transform = transforms.Compose([
68
  transforms.Resize((1024, 768)),
69
  transforms.ToTensor(),
 
92
  global _fg_model
93
  if _fg_model is None:
94
  ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
95
+ _fg_model = torch.jit.load(ckpt).eval().to(DEVICE)
 
96
  return _fg_model
97
 
98
 
99
+ # Iteration mode: only preload the default (0.4B) for fast Space boot.
100
+ # Re-enable full preload by uncommenting the loop below.
101
+ print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
102
+ _get_normal_model(DEFAULT_SIZE)
103
  _get_fg_model()
104
+ # for _size in NORMAL_MODELS:
105
+ # _get_normal_model(_size)
106
  print("[startup] ready.")
107
 
108
 
109
  # -----------------------------------------------------------------------------
110
+ # Inference (mirrors sapiens/dense/tools/vis/vis_normal.py)
111
 
112
  def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
113
  h0, w0 = image_bgr.shape[:2]
114
+ data = model.pipeline(dict(img=image_bgr)) # resize + pad
115
+ data = model.data_preprocessor(data) # normalize + batch
116
+ inputs, data_samples = data["inputs"], data["data_samples"]
 
 
117
 
118
  with torch.no_grad():
119
+ normal = model(inputs) # (1, 3, padded_H, padded_W)
120
+ normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
121
 
122
+ pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
123
+ normal = normal[
124
+ :, :,
125
+ pad_top : inputs.shape[2] - pad_bottom,
126
+ pad_left : inputs.shape[3] - pad_right,
127
+ ]
128
+ normal = F.interpolate(normal, size=(h0, w0), mode="bilinear", align_corners=False)
129
+ return normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H, W, 3) in [-1, 1]
130
 
131
 
132
  def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
 
133
  fg = _get_fg_model()
134
  inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
135
  with torch.no_grad():
136
+ out = fg(inputs) # (1, K, H, W) logits
137
  out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False)
138
  return (out.argmax(dim=1)[0] > 0).cpu().numpy()
139
 
140
 
141
  def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
142
+ """(H, W, 3) in [-1, 1] → (H, W, 3) uint8 RGB. NO channel swap (the swap in
143
+ vis_normal.py is purely for cv2.imwrite's BGR convention)."""
144
+ return (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
145
 
146
 
147
  # -----------------------------------------------------------------------------
 
158
  h0, w0 = image_rgb.shape[:2]
159
 
160
  model = _get_normal_model(size)
161
+ normal = _estimate_normal(image_bgr, model) # (H, W, 3) in [-1, 1]
162
 
163
+ raw = normal.copy()
164
  if bg_mode == "fg-bg":
165
  mask = _foreground_mask(image_pil, h0, w0)
166
  raw[~mask] = np.nan
167
+ normal[~mask] = -1.0 # → RGB(0,0,0) after vis
168
+ rgb = _normal_to_rgb(normal)
 
 
 
169
 
170
+ npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
171
+ np.save(npy_path, raw.astype(np.float32))
 
172
 
173
  return Image.fromarray(rgb), npy_path
174
 
 
182
  if n.lower().endswith((".jpg", ".jpeg", ".png"))
183
  )
184
 
185
+ CUSTOM_CSS = """
186
+ :root, body, .gradio-container, button, input, select, textarea,
187
+ .gradio-container *:not(code):not(pre) {
188
+ font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
189
+ -webkit-font-smoothing: antialiased;
190
+ -moz-osx-font-smoothing: grayscale;
191
+ }
192
+
193
+ #title { text-align: center; font-size: 44px; font-weight: 700;
194
+ letter-spacing: -0.01em; margin: 28px 0 4px;
195
+ background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
196
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;
197
+ background-clip: text; }
198
+ #subtitle { text-align: center; font-size: 12px; color: #64748b;
199
+ letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
200
+ font-weight: 500; }
201
+ #badges { display: flex; justify-content: center; flex-wrap: wrap;
202
+ gap: 8px; margin: 0 0 32px; }
203
+ .pill { display: inline-flex; align-items: center; gap: 6px;
204
+ padding: 7px 14px; border-radius: 999px;
205
+ background: #f1f5f9; color: #0f172a !important;
206
+ font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
207
+ text-decoration: none !important; border: 1px solid #e2e8f0;
208
+ transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
209
+ .pill:hover { background: #0f172a; color: #f8fafc !important;
210
+ border-color: #0f172a; transform: translateY(-1px); }
211
+ .pill svg { width: 14px; height: 14px; }
212
+ """
213
+
214
+ HEADER_HTML = """
215
+ <div id="title">Sapiens2: Normal</div>
216
+ <div id="subtitle">ICLR 2026</div>
217
+ <div id="badges">
218
+ <a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
219
+ <svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
220
+ Code
221
+ </a>
222
+ <a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
223
+ 🤗 Models
224
+ </a>
225
+ <a class="pill" href="https://openreview.net/pdf?id=IVAlYCqdvW" target="_blank" rel="noopener">
226
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
227
+ Paper
228
+ </a>
229
+ <a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
230
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
231
+ Project
232
+ </a>
233
+ </div>
234
+ """
235
+
236
+ with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
237
+ gr.HTML(HEADER_HTML)
238
+
239
+ with gr.Row(equal_height=True):
240
+ inp = gr.Image(label="Input", type="pil", height=640)
241
+ out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil", height=640)
242
+
243
  with gr.Row():
244
+ size = gr.Radio(
245
+ choices=list(NORMAL_MODELS.keys()),
246
+ value=DEFAULT_SIZE,
247
+ label="Model",
248
+ scale=2,
249
+ )
250
+ bg = gr.Radio(
251
+ choices=BG_OPTIONS,
252
+ value=DEFAULT_BG,
253
+ label="Background",
254
+ scale=2,
255
+ )
256
+ run = gr.Button("Run", variant="primary", size="lg", scale=1)
257
+
258
+ gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
259
+
260
+ with gr.Accordion("Original Res + Raw Normals", open=False):
261
+ out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg removed)")
262
 
263
  run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy])
264