Spaces:
Running on Zero
Rawal Khirodkar committed on
Commit 2070091 · 1 Parent(s): 9884195
Normal: copy seg aesthetic; fix output (unpad + drop bogus channel swap); 0.4B-only preload
app.py
CHANGED
@@ -3,8 +3,7 @@
 Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
 (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
 
-Optionally applies a v1
-shown (background reads as a flat colour).
+Optionally applies a v1 binary fg/bg mask so background pixels are blacked out.
 """
 
 import sys
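A quick illustration of the encoding convention the docstring pins down (hypothetical helpers, not part of app.py):

```python
import numpy as np

def encode_normal(n: np.ndarray) -> np.ndarray:
    """(..., 3) unit normals in [-1, 1] -> uint8 RGB per r = (x + 1) / 2 etc."""
    return np.clip((n + 1.0) / 2.0 * 255.0, 0, 255).astype(np.uint8)

def decode_normal(rgb: np.ndarray) -> np.ndarray:
    """Inverse map; re-normalizes to undo uint8 quantization error."""
    n = rgb.astype(np.float32) / 255.0 * 2.0 - 1.0
    return n / np.maximum(np.linalg.norm(n, axis=-1, keepdims=True), 1e-8)

up = np.array([0.0, 1.0, 0.0])                    # normal along +y
assert tuple(encode_normal(up)) == (127, 255, 127)
```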
@@ -23,7 +22,7 @@ from PIL import Image
 from torchvision import transforms
 
 from huggingface_hub import hf_hub_download
-from sapiens.dense.models import NormalEstimator, init_model  #
+from sapiens.dense.models import NormalEstimator, init_model  # registers NormalEstimator
 _ = NormalEstimator
 
 
@@ -55,9 +54,9 @@ NORMAL_MODELS = {
         "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
     },
 }
-DEFAULT_SIZE = "
+DEFAULT_SIZE = "0.4B"  # iteration mode → only this is preloaded; others lazy-load on click
 
-# v1 binary fg/bg TorchScript model
+# v1 binary fg/bg TorchScript model.
 FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
 FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
 BG_OPTIONS = ["fg-bg", "no-bg-removal"]
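`_get_normal_model` itself sits outside these hunks; a minimal sketch of the lazy, cached loader that the `DEFAULT_SIZE` comment implies, reusing app.py's `NORMAL_MODELS`, `DEVICE`, `hf_hub_download`, and `init_model` (the mmseg-style `init_model(config, checkpoint, device)` signature and the `repo`/`filename` fields are assumptions):

```python
_normal_models: dict = {}

def _get_normal_model(size: str):
    """Load-once cache: the first call per size downloads + builds; later calls hit the dict."""
    if size not in _normal_models:
        entry = NORMAL_MODELS[size]
        ckpt = hf_hub_download(repo_id=entry["repo"], filename=entry["filename"])  # fields assumed
        _normal_models[size] = init_model(entry["config"], ckpt, device=DEVICE)
    return _normal_models[size]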
@@ -65,7 +64,6 @@ DEFAULT_BG = "fg-bg"
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Pre-process for v1 fg-bg model (matches v1 sapiens-normal Space recipe).
 _fg_transform = transforms.Compose([
     transforms.Resize((1024, 768)),
     transforms.ToTensor(),
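The hunk's context window ends inside the Compose, so the tail of `_fg_transform` isn't shown. A plausible full version for orientation (the Normalize statistics are an assumption matching common Sapiens v1 recipes, not read from this diff):

```python
_fg_transform = transforms.Compose([
    transforms.Resize((1024, 768)),
    transforms.ToTensor(),
    # Assumed ImageNet-magnitude statistics, divided by 255 since ToTensor scales to [0, 1].
    transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                         std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
])
```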
@@ -94,51 +92,56 @@ def _get_fg_model():
     global _fg_model
     if _fg_model is None:
         ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
-
-        _fg_model = model
+        _fg_model = torch.jit.load(ckpt).eval().to(DEVICE)
     return _fg_model
 
 
-
-
-
+# Iteration mode: only preload the default (0.4B) for fast Space boot.
+# Re-enable full preload by uncommenting the loop below.
+print("[startup] pre-loading 0.4B (iteration mode) + fg/bg ...")
+_get_normal_model(DEFAULT_SIZE)
 _get_fg_model()
+# for _size in NORMAL_MODELS:
+#     _get_normal_model(_size)
 print("[startup] ready.")
 
 
 # -----------------------------------------------------------------------------
-# Inference
+# Inference (mirrors sapiens/dense/tools/vis/vis_normal.py)
 
 def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
     h0, w0 = image_bgr.shape[:2]
-    data = model.pipeline(dict(img=image_bgr))
-    data = model.data_preprocessor(data)
-    inputs = data["inputs"]
-    if inputs.ndim == 3:
-        inputs = inputs.unsqueeze(0)
+    data = model.pipeline(dict(img=image_bgr))  # resize + pad
+    data = model.data_preprocessor(data)  # normalize + batch
+    inputs, data_samples = data["inputs"], data["data_samples"]
 
     with torch.no_grad():
-
+        normal = model(inputs)  # (1, 3, padded_H, padded_W)
+        normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
 
-
-
-
-
+    pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
+    normal = normal[
+        :, :,
+        pad_top : inputs.shape[2] - pad_bottom,
+        pad_left : inputs.shape[3] - pad_right,
+    ]
+    normal = F.interpolate(normal, size=(h0, w0), mode="bilinear", align_corners=False)
+    return normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)  # (H, W, 3) in [-1, 1]
 
 
 def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
-    """Returns a (H, W) bool mask using the v1 binary fg/bg torchscript model."""
     fg = _get_fg_model()
     inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
     with torch.no_grad():
-        out = fg(inputs)
+        out = fg(inputs)  # (1, K, H, W) logits
     out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False)
     return (out.argmax(dim=1)[0] > 0).cpu().numpy()
 
 
 def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
-
-
+    """(H, W, 3) in [-1, 1] → (H, W, 3) uint8 RGB. NO channel swap (the swap in
+    vis_normal.py is purely for cv2.imwrite's BGR convention)."""
+    return (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
 
 
 # -----------------------------------------------------------------------------
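The headline fix in this commit is the unpad step: the model predicts on a padded canvas, so the padding must be cropped before resizing back to the input resolution. A self-contained shape check of that logic with a dummy tensor (arbitrary pad values and sizes; no model involved):

```python
import torch
import torch.nn.functional as F

h0, w0 = 700, 500                                  # original image size
pad_left, pad_right, pad_top, pad_bottom = 10, 10, 0, 0
normal = torch.randn(1, 3, 1024, 768)              # stand-in for model(inputs)
normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)  # unit length

cropped = normal[:, :, pad_top:1024 - pad_bottom, pad_left:768 - pad_right]
resized = F.interpolate(cropped, size=(h0, w0), mode="bilinear", align_corners=False)
out = resized.squeeze(0).numpy().transpose(1, 2, 0)
assert out.shape == (700, 500, 3)                  # matches the input image again
```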
@@ -155,21 +158,17 @@ def predict(image: Image.Image, size: str, bg_mode: str):
     h0, w0 = image_rgb.shape[:2]
 
     model = _get_normal_model(size)
-
+    normal = _estimate_normal(image_bgr, model)  # (H, W, 3) in [-1, 1]
 
-    raw =
+    raw = normal.copy()
     if bg_mode == "fg-bg":
         mask = _foreground_mask(image_pil, h0, w0)
         raw[~mask] = np.nan
-
-
-        rgb[~mask] = 128
-    else:
-        rgb = _normal_to_rgb(normals)
+        normal[~mask] = -1.0  # → RGB(0,0,0) after vis
+    rgb = _normal_to_rgb(normal)
 
-
-
-    npy_path = f.name
+    npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
+    np.save(npy_path, raw.astype(np.float32))
 
     return Image.fromarray(rgb), npy_path
 
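Downstream, the raw `.npy` that `predict` returns can be consumed like this (illustrative; the file path is hypothetical). NaNs mark background pixels whenever fg-bg was selected:

```python
import numpy as np

raw = np.load("normals.npy")            # (H, W, 3) float32 in [-1, 1]; NaN where bg removed
valid = ~np.isnan(raw).any(axis=-1)     # (H, W) bool foreground mask
mean_fg_normal = raw[valid].mean(axis=0)
print(f"foreground coverage: {valid.mean():.1%}, mean normal: {mean_fg_normal}")
```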
@@ -183,34 +182,83 @@ EXAMPLES = sorted(
     if n.lower().endswith((".jpg", ".jpeg", ".png"))
 )
 
-
-
-
-
-
-
-
-
-
+CUSTOM_CSS = """
+:root, body, .gradio-container, button, input, select, textarea,
+.gradio-container *:not(code):not(pre) {
+  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
+  -webkit-font-smoothing: antialiased;
+  -moz-osx-font-smoothing: grayscale;
+}
+
+#title { text-align: center; font-size: 44px; font-weight: 700;
+         letter-spacing: -0.01em; margin: 28px 0 4px;
+         background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
+         -webkit-background-clip: text; -webkit-text-fill-color: transparent;
+         background-clip: text; }
+#subtitle { text-align: center; font-size: 12px; color: #64748b;
+            letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
+            font-weight: 500; }
+#badges { display: flex; justify-content: center; flex-wrap: wrap;
+          gap: 8px; margin: 0 0 32px; }
+.pill { display: inline-flex; align-items: center; gap: 6px;
+        padding: 7px 14px; border-radius: 999px;
+        background: #f1f5f9; color: #0f172a !important;
+        font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
+        text-decoration: none !important; border: 1px solid #e2e8f0;
+        transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
+.pill:hover { background: #0f172a; color: #f8fafc !important;
+              border-color: #0f172a; transform: translateY(-1px); }
+.pill svg { width: 14px; height: 14px; }
+"""
+
+HEADER_HTML = """
+<div id="title">Sapiens2: Normal</div>
+<div id="subtitle">ICLR 2026</div>
+<div id="badges">
+  <a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
+    <svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
+    Code
+  </a>
+  <a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
+    🤗 Models
+  </a>
+  <a class="pill" href="https://openreview.net/pdf?id=IVAlYCqdvW" target="_blank" rel="noopener">
+    <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
+    Paper
+  </a>
+  <a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
+    <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
+    Project
+  </a>
+</div>
+"""
+
+with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
+    gr.HTML(HEADER_HTML)
+
+    with gr.Row(equal_height=True):
+        inp = gr.Image(label="Input", type="pil", height=640)
+        out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil", height=640)
+
     with gr.Row():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        size = gr.Radio(
+            choices=list(NORMAL_MODELS.keys()),
+            value=DEFAULT_SIZE,
+            label="Model",
+            scale=2,
+        )
+        bg = gr.Radio(
+            choices=BG_OPTIONS,
+            value=DEFAULT_BG,
+            label="Background",
+            scale=2,
+        )
+        run = gr.Button("Run", variant="primary", size="lg", scale=1)
+
+    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
+
+    with gr.Accordion("Original Res + Raw Normals", open=False):
+        out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg removed)")
 
     run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy])
 
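The diff ends at the `run.click` wiring; a Spaces `app.py` conventionally closes with a launch guard along these lines (an assumption — not shown in this commit):

```python
if __name__ == "__main__":
    demo.launch()  # serves the Blocks app; Spaces also auto-launch `demo`
```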