Spaces:
Running on Zero
Running on Zero
Replace runtime TripoSG string-patching with pre-patched files in patches/triposg/
All three TripoSG scripts that previously required runtime str.replace() hacks are
now committed as fixed files under patches/triposg/ (mirroring upstream layout).
load_triposg() now simply shutil.copy2s them over the cloned repo after checkout.
Patches applied in the files:
- scripts/inference_triposg.py: remove pymeshlab import + helper fns, use trimesh.simplify_quadric_decimation
- scripts/image_process.py: empty contours guard, rmbg_net=None guard, all-zero alpha fallback
- triposg/inference_utils.py: diso optional try/except, hierarchical wrapper, queries dtype cast
- app.py +12 -152
- patches/triposg/scripts/image_process.py +159 -0
- patches/triposg/scripts/inference_triposg.py +90 -0
- patches/triposg/triposg/inference_utils.py +492 -0
app.py
CHANGED
|
@@ -365,6 +365,7 @@ def load_triposg():
|
|
| 365 |
return _triposg_pipe, _rmbg_net
|
| 366 |
|
| 367 |
print("[load_triposg] Loading TripoSG pipeline...")
|
|
|
|
| 368 |
from huggingface_hub import snapshot_download
|
| 369 |
|
| 370 |
# TripoSG source has no setup.py — clone GitHub repo and add to sys.path
|
|
@@ -377,160 +378,19 @@ def load_triposg():
|
|
| 377 |
str(triposg_src)],
|
| 378 |
check=True
|
| 379 |
)
|
| 380 |
-
if str(triposg_src) not in sys.path:
|
| 381 |
-
sys.path.insert(0, str(triposg_src))
|
| 382 |
|
| 383 |
-
#
|
| 384 |
-
#
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
" # seg from rmbg\n"
|
| 393 |
-
" if rmbg_net is None: # rmbg_net_none_guard_v2\n"
|
| 394 |
-
" alpha_gpu_rmbg = torch.ones(\n"
|
| 395 |
-
" 1, 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
|
| 396 |
-
" device=rgb_image_resized.device)\n"
|
| 397 |
-
" else:\n"
|
| 398 |
-
" alpha_gpu_rmbg = rmbg(rgb_image_resized)",
|
| 399 |
-
)
|
| 400 |
-
_ip_path.write_text(_ip_text)
|
| 401 |
-
print("[load_triposg] Patched image_process.py: rmbg_net None guard")
|
| 402 |
-
|
| 403 |
-
# Patch find_bounding_box: guard against empty contours (blank alpha mask).
|
| 404 |
-
# When RMBG produces an all-black mask, findContours returns [] and max() raises.
|
| 405 |
-
# Fallback: return the full image bounding box so pipeline can continue.
|
| 406 |
-
# NOTE: parameter is gray_image, not alpha.
|
| 407 |
-
_ip_text2 = _ip_path.read_text()
|
| 408 |
-
if "empty_contours_guard" not in _ip_text2:
|
| 409 |
-
_ip_text2 = _ip_text2.replace(
|
| 410 |
-
" max_contour = max(contours, key=cv2.contourArea)",
|
| 411 |
-
" if not contours: # empty_contours_guard\n"
|
| 412 |
-
" h, w = gray_image.shape[:2]\n"
|
| 413 |
-
" return 0, 0, w, h\n"
|
| 414 |
-
" max_contour = max(contours, key=cv2.contourArea)",
|
| 415 |
-
)
|
| 416 |
-
_ip_path.write_text(_ip_text2)
|
| 417 |
-
print("[load_triposg] Patched image_process.py: empty contours guard")
|
| 418 |
-
|
| 419 |
-
# Patch all-zero alpha guard: instead of raising ValueError("input image too small"),
|
| 420 |
-
# fall back to full-foreground alpha so the pipeline can continue with the whole image.
|
| 421 |
-
# Happens when RMBG produces a blank mask (e.g. remove_small_objects wipes everything).
|
| 422 |
-
_ip_text3 = _ip_path.read_text()
|
| 423 |
-
if "all_zero_alpha_guard" not in _ip_text3:
|
| 424 |
-
_ip_text3 = _ip_text3.replace(
|
| 425 |
-
' if np.all(alpha==0):\n raise ValueError(f"input image too small")',
|
| 426 |
-
" if np.all(alpha==0): # all_zero_alpha_guard\n"
|
| 427 |
-
" h_full, w_full = alpha.shape[:2]\n"
|
| 428 |
-
" alpha = np.full((h_full, w_full), 255, dtype=np.uint8)\n"
|
| 429 |
-
" alpha_gpu = torch.ones(1, h_full, w_full, dtype=torch.float32,\n"
|
| 430 |
-
" device=rgb_image_gpu.device)\n"
|
| 431 |
-
" x, y, w, h = 0, 0, w_full, h_full",
|
| 432 |
-
)
|
| 433 |
-
_ip_path.write_text(_ip_text3)
|
| 434 |
-
print("[load_triposg] Patched image_process.py: all-zero alpha fallback")
|
| 435 |
-
|
| 436 |
-
# Safety net: patch inference_utils.py to make diso import optional.
|
| 437 |
-
# Even if diso compiled with submodules, guard against any residual link errors.
|
| 438 |
-
_iu_path = triposg_src / "triposg" / "inference_utils.py"
|
| 439 |
-
if _iu_path.exists():
|
| 440 |
-
_iu_text = _iu_path.read_text()
|
| 441 |
-
if "queries.to(dtype=batch_latents.dtype)" not in _iu_text:
|
| 442 |
-
_iu_text = _iu_text.replace(
|
| 443 |
-
"from diso import DiffDMC",
|
| 444 |
-
"try:\n from diso import DiffDMC\n"
|
| 445 |
-
"except Exception as _diso_err:\n"
|
| 446 |
-
" print(f'[TripoSG] diso unavailable ({_diso_err}), using flash fallback')\n"
|
| 447 |
-
" DiffDMC = None",
|
| 448 |
-
)
|
| 449 |
-
if ("def hierarchical_extract_geometry(" in _iu_text
|
| 450 |
-
and "flash_extract_geometry" in _iu_text):
|
| 451 |
-
_iu_text = _iu_text.replace(
|
| 452 |
-
"def hierarchical_extract_geometry(",
|
| 453 |
-
"def _hierarchical_extract_geometry_impl(",
|
| 454 |
-
)
|
| 455 |
-
_iu_text += (
|
| 456 |
-
"\n\n"
|
| 457 |
-
"def hierarchical_extract_geometry(*args, **kwargs):\n"
|
| 458 |
-
" if DiffDMC is None:\n"
|
| 459 |
-
" return flash_extract_geometry(*args, **kwargs)\n"
|
| 460 |
-
" return _hierarchical_extract_geometry_impl(*args, **kwargs)\n"
|
| 461 |
-
)
|
| 462 |
-
# Also cast queries to match batch_latents dtype before vae.decode.
|
| 463 |
-
# TripoSGPipeline loads as float16 but flash_extract_geometry creates
|
| 464 |
-
# query grids as float32, causing a dtype mismatch in F.linear.
|
| 465 |
-
_iu_text = _iu_text.replace(
|
| 466 |
-
"logits = vae.decode(batch_latents, queries).sample",
|
| 467 |
-
"logits = vae.decode(batch_latents, queries.to(dtype=batch_latents.dtype)).sample",
|
| 468 |
-
)
|
| 469 |
-
_iu_path.write_text(_iu_text)
|
| 470 |
-
print("[load_triposg] Patched inference_utils.py: diso optional + queries dtype cast")
|
| 471 |
-
|
| 472 |
-
# Patch inference_triposg.py: replace pymeshlab (no py3.13 wheels) with trimesh.
|
| 473 |
-
# pymeshlab is only used for simplify_mesh() — QEM decimation + vertex merging.
|
| 474 |
-
# trimesh.simplify_quadric_decimation() is a direct equivalent (needs fast-simplification).
|
| 475 |
-
_it_path = triposg_src / "scripts" / "inference_triposg.py"
|
| 476 |
-
if _it_path.exists():
|
| 477 |
-
_it_text = _it_path.read_text()
|
| 478 |
-
if "pymeshlab_replaced_v1" not in _it_text:
|
| 479 |
-
# Step 1: always strip the top-level pymeshlab import
|
| 480 |
-
_it_text = _it_text.replace("import pymeshlab\n", "")
|
| 481 |
-
|
| 482 |
-
# Step 2: try to replace the pymeshlab helper functions with trimesh equivalent
|
| 483 |
-
_new_simplify = (
|
| 484 |
-
"# pymeshlab_replaced_v1: replaced with trimesh (no py3.13 wheels for pymeshlab)\n"
|
| 485 |
-
"def simplify_mesh(mesh: trimesh.Trimesh, n_faces):\n"
|
| 486 |
-
" if mesh.faces.shape[0] > n_faces:\n"
|
| 487 |
-
" mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=True)\n"
|
| 488 |
-
" mesh = mesh.simplify_quadric_decimation(n_faces)\n"
|
| 489 |
-
" return mesh\n"
|
| 490 |
-
)
|
| 491 |
-
# Try the exact upstream string first
|
| 492 |
-
_old_exact = (
|
| 493 |
-
"def mesh_to_pymesh(vertices, faces):\n"
|
| 494 |
-
" mesh = pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces)\n"
|
| 495 |
-
" ms = pymeshlab.MeshSet()\n"
|
| 496 |
-
" ms.add_mesh(mesh)\n"
|
| 497 |
-
" return ms\n"
|
| 498 |
-
"\n"
|
| 499 |
-
"\n"
|
| 500 |
-
"def pymesh_to_trimesh(mesh):\n"
|
| 501 |
-
" verts = mesh.vertex_matrix()\n"
|
| 502 |
-
" faces = mesh.face_matrix()\n"
|
| 503 |
-
" return trimesh.Trimesh(vertices=verts, faces=faces)\n"
|
| 504 |
-
"\n"
|
| 505 |
-
"\n"
|
| 506 |
-
"def simplify_mesh(mesh: trimesh.Trimesh, n_faces):\n"
|
| 507 |
-
" if mesh.faces.shape[0] > n_faces:\n"
|
| 508 |
-
" ms = mesh_to_pymesh(mesh.vertices, mesh.faces)\n"
|
| 509 |
-
" ms.meshing_merge_close_vertices()\n"
|
| 510 |
-
" ms.meshing_decimation_quadric_edge_collapse(targetfacenum = n_faces)\n"
|
| 511 |
-
" return pymesh_to_trimesh(ms.current_mesh())\n"
|
| 512 |
-
" else:\n"
|
| 513 |
-
" return mesh\n"
|
| 514 |
-
)
|
| 515 |
-
if _old_exact in _it_text:
|
| 516 |
-
_it_text = _it_text.replace(_old_exact, _new_simplify)
|
| 517 |
-
print("[load_triposg] Patched inference_triposg.py: pymeshlab → trimesh (exact match)")
|
| 518 |
-
else:
|
| 519 |
-
# Fallback: regex — handles whitespace/indent variants
|
| 520 |
-
import re as _re
|
| 521 |
-
_it_text = _re.sub(
|
| 522 |
-
r"def mesh_to_pymesh\(.*?\n.*?\n.*?\n.*?\n.*?\n\n\ndef pymesh_to_trimesh\(.*?\n.*?\n.*?\n.*?\n\n\ndef simplify_mesh\([^)]*\):[^\n]*\n(?: [^\n]*\n)*",
|
| 523 |
-
_new_simplify,
|
| 524 |
-
_it_text,
|
| 525 |
-
flags=_re.DOTALL,
|
| 526 |
-
)
|
| 527 |
-
print("[load_triposg] Patched inference_triposg.py: pymeshlab → trimesh (regex fallback)")
|
| 528 |
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
# which is a NameError only if called, not at import time)
|
| 532 |
-
_it_path.write_text(_it_text)
|
| 533 |
-
print("[load_triposg] inference_triposg.py written.")
|
| 534 |
|
| 535 |
weights_path = snapshot_download("VAST-AI/TripoSG")
|
| 536 |
|
|
|
|
| 365 |
return _triposg_pipe, _rmbg_net
|
| 366 |
|
| 367 |
print("[load_triposg] Loading TripoSG pipeline...")
|
| 368 |
+
import shutil as _shutil
|
| 369 |
from huggingface_hub import snapshot_download
|
| 370 |
|
| 371 |
# TripoSG source has no setup.py — clone GitHub repo and add to sys.path
|
|
|
|
| 378 |
str(triposg_src)],
|
| 379 |
check=True
|
| 380 |
)
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
# Overwrite upstream scripts with pre-patched versions committed to this repo.
|
| 383 |
+
# Patches live in patches/triposg/ and mirror the upstream directory layout.
|
| 384 |
+
_patches_dir = HERE / "patches" / "triposg"
|
| 385 |
+
for _pf in _patches_dir.rglob("*"):
|
| 386 |
+
if _pf.is_file():
|
| 387 |
+
_dest = triposg_src / _pf.relative_to(_patches_dir)
|
| 388 |
+
_dest.parent.mkdir(parents=True, exist_ok=True)
|
| 389 |
+
_shutil.copy2(str(_pf), str(_dest))
|
| 390 |
+
print("[load_triposg] Applied pre-patched scripts from patches/triposg/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
+
if str(triposg_src) not in sys.path:
|
| 393 |
+
sys.path.insert(0, str(triposg_src))
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
weights_path = snapshot_download("VAST-AI/TripoSG")
|
| 396 |
|
patches/triposg/scripts/image_process.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import os
|
| 3 |
+
from skimage.morphology import remove_small_objects
|
| 4 |
+
from skimage.measure import label
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import cv2
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
|
| 13 |
+
def find_bounding_box(gray_image):
    """Return (x, y, w, h) of the largest bright region in *gray_image*.

    Pixels > 1 are treated as foreground. When no contour is found (e.g. a
    completely black mask), the full-image bounding box is returned so the
    caller can continue instead of crashing on max() of an empty sequence.
    """
    _, mask = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Empty-contours guard: fall back to the whole image.
    if not found:
        full_h, full_w = gray_image.shape[:2]
        return 0, 0, full_w, full_h

    biggest = max(found, key=cv2.contourArea)
    return cv2.boundingRect(biggest)
| 22 |
+
|
| 23 |
+
def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1):
    """Load an image, segment its foreground, composite over bg_color, and
    return a padded, cropped CHW float tensor on the GPU.

    On failure this function returns a human-readable error *string* instead
    of raising — callers must check ``isinstance(result, str)``.

    Args:
        img_path: Path to the image file (any format cv2.imread supports).
        bg_color: np.ndarray of 3 floats in [0, 1]; also supplies the pad value.
            NOTE(review): despite the ``None`` default, the compositing code
            below indexes ``bg_color[0]`` unconditionally — passing None crashes.
        rmbg_net: Background-removal network; when None, a full-foreground
            alpha is used instead of running segmentation.
        padding_ratio: Fraction of the larger box side added as padding.

    Returns:
        torch.Tensor (3, H, W) on CUDA, or an error message string.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return f"invalid image path {img_path}"

    def is_valid_alpha(alpha, min_ratio = 0.01):
        # An alpha channel is "valid" only if both fully-transparent and
        # fully-opaque pixels each cover at least min_ratio of the image.
        bins = 20
        if isinstance(alpha, np.ndarray):
            hist = cv2.calcHist([alpha], [0], None, [bins], [0, 256])
        else:
            hist = torch.histc(alpha, bins=bins, min=0, max=1)
        min_hist_val = alpha.shape[0] * alpha.shape[1] * min_ratio
        return hist[0] >= min_hist_val and hist[-1] >= min_hist_val

    def rmbg(image: torch.Tensor) -> torch.Tensor:
        # Run the closed-over rmbg_net; result[0][0] is the finest-scale mask.
        image = TF.normalize(image, [0.5,0.5,0.5], [1.0,1.0,1.0]).unsqueeze(0)
        result=rmbg_net(image)
        return result[0][0]

    if len(img.shape) == 2:
        num_channels = 1
    else:
        num_channels = img.shape[2]

    # check if too large: downscale so the longest side is at most 2000 px.
    height, width = img.shape[:2]
    if height > width:
        scale = 2000 / height
    else:
        scale = 2000 / width
    if scale < 1:
        new_size = (int(width * scale), int(height * scale))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

    # Normalize 16-bit (etc.) images down to uint8 range.
    if img.dtype != 'uint8':
        img = (img * (255. / np.iinfo(img.dtype).max)).astype(np.uint8)

    rgb_image = None
    alpha = None

    if num_channels == 1:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif num_channels == 3:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    elif num_channels == 4:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

        # Keep the embedded alpha only if it actually separates fg/bg.
        b, g, r, alpha = cv2.split(img)
        if not is_valid_alpha(alpha):
            alpha = None
        else:
            alpha_gpu = torch.from_numpy(alpha).unsqueeze(0).cuda().float() / 255.
    else:
        return f"invalid image: channels {num_channels}"

    rgb_image_gpu = torch.from_numpy(rgb_image).cuda().float().permute(2, 0, 1) / 255.
    # No usable alpha: derive one via RMBG segmentation (or an all-ones
    # fallback when rmbg_net is None).
    if alpha is None:
        resize_transform = transforms.Resize((384, 384), antialias=True)
        rgb_image_resized = resize_transform(rgb_image_gpu)
        normalize_image = rgb_image_resized * 2 - 1

        mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda()
        resize_transform = transforms.Resize((1024, 1024), antialias=True)
        rgb_image_resized = resize_transform(rgb_image_gpu)
        max_value = rgb_image_resized.flatten().max()
        if max_value < 1e-3:
            return "invalid image: pure black image"
        # NOTE(review): normalize_image is recomputed here and then unused —
        # appears vestigial; kept as-is to preserve behavior.
        normalize_image = rgb_image_resized / max_value - mean_color
        normalize_image = normalize_image.unsqueeze(0)
        resize_transform = transforms.Resize((rgb_image_gpu.shape[1], rgb_image_gpu.shape[2]), antialias=True)

        # seg from rmbg
        if rmbg_net is None:
            # rmbg_net=None guard: treat everything as foreground.
            alpha_gpu_rmbg = torch.ones(1, 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2], device=rgb_image_resized.device)
        else:
            alpha_gpu_rmbg = rmbg(rgb_image_resized)
        alpha_gpu_rmbg = alpha_gpu_rmbg.squeeze(0)
        alpha_gpu_rmbg = resize_transform(alpha_gpu_rmbg)
        # Min-max normalize the mask to [0, 1].
        # NOTE(review): divides by (ma - mi) — a constant mask gives 0/0 = NaN.
        ma, mi = alpha_gpu_rmbg.max(), alpha_gpu_rmbg.min()
        alpha_gpu_rmbg = (alpha_gpu_rmbg - mi) / (ma - mi)

        alpha_gpu = alpha_gpu_rmbg

        alpha_gpu_tmp = alpha_gpu * 255
        alpha = alpha_gpu_tmp.to(torch.uint8).squeeze().cpu().numpy()

        # Binarize (Otsu) and drop connected components under 200 px.
        _, alpha = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        labeled_alpha = label(alpha)
        cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
        cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
        alpha = cleaned_alpha * 255
        alpha_gpu = torch.from_numpy(cleaned_alpha).cuda().float().unsqueeze(0)
        x, y, w, h = find_bounding_box(alpha)

    # If alpha is provided, the bounds of all foreground are used
    else:
        rows, cols = np.where(alpha > 0)
        if rows.size > 0 and cols.size > 0:
            x_min = np.min(cols)
            y_min = np.min(rows)
            x_max = np.max(cols)
            y_max = np.max(rows)

            width = x_max - x_min + 1
            height = y_max - y_min + 1
            x, y, w, h = x_min, y_min, width, height

    # All-zero-alpha fallback: a blank mask (e.g. remove_small_objects wiped
    # everything) becomes full-foreground instead of raising. This also
    # defines x/y/w/h for the empty-rows branch above.
    if np.all(alpha==0):
        # Blank alpha: treat entire image as foreground instead of raising
        h_full, w_full = alpha.shape[:2]
        alpha = np.full((h_full, w_full), 255, dtype=np.uint8)
        alpha_gpu = torch.ones(1, h_full, w_full, dtype=torch.float32, device=rgb_image_gpu.device)
        x, y, w, h = 0, 0, w_full, h_full

    # Composite over the background color, then crop to the box and pad to
    # a square-ish frame (F.pad order: W-left/right, H-top/bottom, C).
    bg_gray = bg_color[0]
    bg_color = torch.from_numpy(bg_color).float().cuda().repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
    rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
    padding_size = [0] * 6
    if w > h:
        padding_size[0] = int(w * padding_ratio)
        padding_size[2] = int(padding_size[0] + (w - h) / 2)
    else:
        padding_size[2] = int(h * padding_ratio)
        padding_size[0] = int(padding_size[2] + (h - w) / 2)
    padding_size[1] = padding_size[0]
    padding_size[3] = padding_size[2]
    padded_tensor = F.pad(rgb_image_gpu[:, y:(y+h), x:(x+w)], pad=tuple(padding_size), mode='constant', value=bg_gray)

    return padded_tensor
| 152 |
+
|
| 153 |
+
def prepare_image(image_path, bg_color, rmbg_net=None):
    """Load *image_path*, composite it over *bg_color*, and return a PIL image.

    Args:
        image_path: Path to the input image on disk.
        bg_color: np.ndarray of 3 floats in [0, 1], forwarded to load_image
            for compositing and padding.
        rmbg_net: Optional background-removal network; when None, load_image
            falls back to a full-foreground alpha.

    Returns:
        PIL.Image.Image: The preprocessed RGB image.

    Raises:
        FileNotFoundError: If image_path is not an existing file. (The
            original silently returned None here, crashing callers later.)
        ValueError: If load_image rejects the image. load_image signals
            failure by returning a message *string*; the original fed that
            string to .permute() and died with an opaque AttributeError.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"image not found: {image_path}")

    img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net)
    if isinstance(img_tensor, str):
        # Surface load_image's error message as a proper exception.
        raise ValueError(img_tensor)

    img_np = img_tensor.permute(1,2,0).cpu().numpy()
    img_pil = Image.fromarray((img_np*255).astype(np.uint8))

    return img_pil
|
patches/triposg/scripts/inference_triposg.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from glob import glob
|
| 5 |
+
from typing import Any, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import trimesh
|
| 10 |
+
from huggingface_hub import snapshot_download
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
|
| 15 |
+
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
|
| 16 |
+
from image_process import prepare_image
|
| 17 |
+
from briarmbg import BriaRMBG
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    faces: int = -1,
) -> trimesh.Trimesh:  # annotation fixed: a Trimesh is returned, not a Scene
    """Generate a 3D mesh from a single image with the TripoSG pipeline.

    Args:
        pipe: A loaded TripoSGPipeline.
        image_input: Image file path (prepare_image reads from disk; passing
            a PIL.Image despite the annotation would fail os.path.isfile —
            NOTE(review): confirm intended input types).
        rmbg_net: Background-removal net forwarded to prepare_image.
        seed: RNG seed for the diffusion generator (reproducibility).
        num_inference_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance strength.
        faces: Target face count; <= 0 skips decimation.

    Returns:
        trimesh.Trimesh built from the pipeline's (vertices, faces) output.
    """
    img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    outputs = pipe(
        image=img_pil,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    # outputs[0] = vertices, outputs[1] = faces (made contiguous for trimesh).
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))

    if faces > 0:
        mesh = simplify_mesh(mesh, faces)

    return mesh
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
    """Decimate *mesh* down toward *n_faces* faces via quadric decimation.

    Meshes already at or under the target are returned untouched. The mesh
    is first rebuilt with process=True (merges duplicate vertices), matching
    the cleanup pymeshlab used to do before decimation.
    """
    if mesh.faces.shape[0] <= n_faces:
        return mesh
    rebuilt = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=True)
    return rebuilt.simplify_quadric_decimation(n_faces)
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
    # CLI entry point: image -> GLB mesh via TripoSG. Requires a CUDA GPU.
    device = "cuda"
    dtype = torch.float16  # pipeline runs in half precision

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", type=str, required=True)
    parser.add_argument("--output-path", type=str, default="./output.glb")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    parser.add_argument("--faces", type=int, default=-1)
    args = parser.parse_args()

    # download pretrained weights
    triposg_weights_dir = "pretrained_weights/TripoSG"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=triposg_weights_dir)
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

    # init rmbg model for background removal
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
    rmbg_net.eval()

    # init tripoSG pipeline
    pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(triposg_weights_dir).to(device, dtype)

    # run inference and export the resulting mesh (format inferred from
    # the output path extension, e.g. .glb).
    run_triposg(
        pipe,
        image_input=args.image_input,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        faces=args.faces,
    ).export(args.output_path)
    print(f"Mesh saved to {args.output_path}")
patches/triposg/triposg/inference_utils.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import scipy.ndimage
|
| 5 |
+
from skimage import measure
|
| 6 |
+
from einops import repeat
|
| 7 |
+
try:
|
| 8 |
+
from diso import DiffDMC
|
| 9 |
+
except Exception as _diso_err:
|
| 10 |
+
print(f'[TripoSG] diso unavailable ({_diso_err}), using flash fallback')
|
| 11 |
+
DiffDMC = None
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from triposg.utils.typing import *
|
| 15 |
+
|
| 16 |
+
def generate_dense_grid_points_gpu(bbox_min: torch.Tensor,
                                   bbox_max: torch.Tensor,
                                   octree_depth: int,
                                   indexing: str = "ij"):
    """Build a dense grid of 2**octree_depth points per axis over the bbox.

    Returns:
        xyz: (n**3, 3) float16 tensor of query points on bbox_min's device.
        grid_size: [n, n, n] python ints.
        length: bbox_max - bbox_min (per-axis extent, original dtype).
    """
    extent = bbox_max - bbox_min
    n = int(2 ** octree_depth)
    dev = bbox_min.device

    # One float16 linspace per axis, then a full cartesian grid.
    axes = [
        torch.linspace(bbox_min[i], bbox_max[i], n, dtype=torch.float16, device=dev)
        for i in range(3)
    ]
    mesh_axes = torch.meshgrid(axes[0], axes[1], axes[2], indexing=indexing)
    xyz = torch.stack(mesh_axes, dim=-1).view(-1, 3)

    return xyz, [n, n, n], extent
| 34 |
+
|
| 35 |
+
def find_mesh_grid_coordinates_fast_gpu(occupancy_grid, n_limits=-1):
    """Return coordinates of surface voxels in a signed occupancy grid.

    A surface voxel is an interior cell with value > 0 that has at least one
    of its 26 neighbors with value < 0. Coordinates are shifted by +1 so they
    index the full grid, not the trimmed core. When n_limits != -1 and more
    points are found, a random subset (with replacement) is returned.
    """
    core = occupancy_grid[1:-1, 1:-1, 1:-1]
    occupied = core > 0

    # Slice table for a neighbor offset of -1/0/+1 along one axis.
    _shift = (slice(None, -2), slice(1, -1), slice(2, None))

    # OR together "neighbor is unoccupied" over all 26 offsets (center skipped).
    near_empty = torch.zeros_like(occupied)
    for dx in range(3):
        for dy in range(3):
            for dz in range(3):
                if dx == 1 and dy == 1 and dz == 1:
                    continue  # the voxel itself, not a neighbor
                near_empty |= occupancy_grid[_shift[dx], _shift[dy], _shift[dz]] < 0

    coords = torch.nonzero(occupied & near_empty, as_tuple=False) + 1

    if n_limits != -1 and coords.shape[0] > n_limits:
        print(f"core mesh coords {coords.shape[0]} is too large, limited to {n_limits}")
        ind = np.random.choice(coords.shape[0], n_limits, True)
        coords = coords[ind]

    return coords
| 75 |
+
|
| 76 |
+
def find_candidates_band(occupancy_grid: torch.Tensor, band_threshold: float, n_limits: int = -1) -> torch.Tensor:
    """Locate interior voxels whose pseudo-SDF magnitude is below the band.

    Logits in *occupancy_grid* are mapped to (-1, 1) via sigmoid(x)*2 - 1
    before thresholding. Returned coordinates are shifted by +1 so they
    index the full grid rather than the trimmed core. When n_limits != -1
    and more candidates exist, a random subset (with replacement) is kept.

    Args:
        occupancy_grid: 3D tensor of logits.
        band_threshold: Keep voxels with |pseudo-SDF| strictly below this.
        n_limits: Maximum number of points to return (-1 for no limit).

    Returns:
        (N, 3) integer tensor of [x, y, z] coordinates.
    """
    # logits -> pseudo-SDF in (-1, 1)
    inner = torch.sigmoid(occupancy_grid[1:-1, 1:-1, 1:-1]) * 2 - 1
    band_mask = inner.abs() < band_threshold
    candidates = torch.nonzero(band_mask, as_tuple=False) + 1

    if n_limits != -1 and candidates.shape[0] > n_limits:
        print(f"core mesh coords {candidates.shape[0]} is too large, limited to {n_limits}")
        pick = np.random.choice(candidates.shape[0], n_limits, True)
        candidates = candidates[pick]

    return candidates
| 103 |
+
|
| 104 |
+
def expand_edge_region_fast(edge_coords, grid_size):
    """Dilate a sparse set of near-surface voxels, then upsample 2x.

    Args:
        edge_coords: (N, 3) integer tensor of voxel coordinates at the
            current (low) resolution.
        grid_size: edge length of the cubic low-resolution grid.

    Returns:
        (8*M, 3) int16 tensor of coordinates at double resolution, where M is
        the number of voxels after dilation: each survives as its 8 children.
    """
    # Bug fix: device was hard-coded to 'cuda', which crashed on CPU-only
    # hosts. Follow the input tensor's device instead.
    device = edge_coords.device
    # float16 keeps the original GPU memory footprint; CPU max_pool3d has
    # spotty half-precision support, so fall back to float32 there. The
    # tensor only ever holds a 0/1 mask, so the result is unaffected.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    occupancy = torch.zeros(grid_size, grid_size, grid_size, device=device,
                            dtype=dtype, requires_grad=False)
    occupancy[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1

    # Dilation radius shrinks at high resolutions to bound the query count.
    if grid_size < 512:
        kernel_size, padding = 5, 2
    else:
        kernel_size, padding = 3, 1
    pooled = torch.nn.functional.max_pool3d(
        occupancy.unsqueeze(0).unsqueeze(0),
        kernel_size=kernel_size, stride=1, padding=padding,
    ).squeeze()
    low_res = torch.nonzero(pooled, as_tuple=False).to(torch.int16)

    # Expand every low-res voxel into its 8 children at double resolution.
    # Offset order matches the original cat/stack construction exactly.
    children = [
        low_res * 2 + torch.tensor(offset, dtype=low_res.dtype, device=device)
        for offset in ((0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
                       (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1))
    ]
    return torch.cat(children, dim=0)
|
| 122 |
+
|
| 123 |
+
def zoom_block(block, scale_factor, order=3):
    """Resample a numpy volume by ``scale_factor`` using spline interpolation.

    Args:
        block: numpy array to resample (any numeric dtype; cast to float32).
        scale_factor: scalar or per-axis zoom factor.
        order: spline interpolation order (default: cubic).

    Returns:
        float32 numpy array holding the resampled volume.
    """
    data = block.astype(np.float32)
    return scipy.ndimage.zoom(data, scale_factor, order=order)
|
| 126 |
+
|
| 127 |
+
def parallel_zoom(occupancy_grid, scale_factor):
    """Upsample a 3D grid by ``scale_factor`` (nearest-neighbour, on-device).

    Args:
        occupancy_grid: (D, H, W) tensor of logits/occupancy values.
        scale_factor: multiplicative resolution factor per axis.

    Returns:
        (D*s, H*s, W*s) tensor; interpolate's default mode is 'nearest'.
    """
    batched = occupancy_grid[None, None]
    zoomed = torch.nn.functional.interpolate(batched, scale_factor=scale_factor)
    return zoomed[0, 0]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
def _hierarchical_extract_geometry_impl(geometric_func: Callable,
                                        device: torch.device,
                                        bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                                        dense_octree_depth: int = 8,
                                        hierarchical_octree_depth: int = 9,
                                        ):
    """Coarse-to-fine field evaluation followed by marching cubes.

    Evaluates ``geometric_func`` on a dense grid at ``dense_octree_depth``,
    then repeatedly doubles resolution, re-querying only voxels near the
    zero band, and finally runs skimage marching cubes on the result.

    Args:
        geometric_func: callable mapping a (1, N, 3) batch of xyz queries to
            per-point logits (exact output shape inferred from usage below —
            indexed as ``out[0]`` with a trailing value column; confirm
            against the decoder).
        device: torch device the bbox tensors are created on.
        bounds: either a symmetric scalar half-extent or a 6-tuple
            (xmin, ymin, zmin, xmax, ymax, zmax).
        dense_octree_depth: depth of the initial dense pass (grid = 2**d).
        hierarchical_octree_depth: final depth after refinement.

    Returns:
        A single-element list ``[(vertices, faces)]`` — float32 vertices in
        world space — or ``[(None, None)]`` if marching cubes fails.
    """
    if isinstance(bounds, float):
        # Scalar bound means a symmetric cube around the origin.
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

    bbox_min = torch.tensor(bounds[0:3]).to(device)
    bbox_max = torch.tensor(bounds[3:6]).to(device)
    bbox_size = bbox_max - bbox_min

    # Dense pass: sample the full grid once at the coarse resolution.
    # NOTE: generate_dense_grid_points_gpu is defined elsewhere in this file.
    xyz_samples, grid_size, length = generate_dense_grid_points_gpu(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_depth=dense_octree_depth,
        indexing="ij"
    )

    print(f'step 1 query num: {xyz_samples.shape[0]}')
    # Logits stored as float16 to halve memory for the big 3D volume.
    grid_logits = geometric_func(xyz_samples.unsqueeze(0)).to(torch.float16).view(grid_size[0], grid_size[1], grid_size[2])
    # print(f'step 1 grid_logits shape: {grid_logits.shape}')
    for i in range(hierarchical_octree_depth - dense_octree_depth):
        curr_octree_depth = dense_octree_depth + i + 1
        # upsample
        grid_size = 2**curr_octree_depth
        normalize_offset = grid_size / 2
        # Nearest-neighbour upsample of the previous level as a starting point.
        high_res_occupancy = parallel_zoom(grid_logits, 2)

        # Only voxels whose |pseudo-SDF| < 1.0 (i.e. near the surface) get
        # re-queried; everything else keeps the upsampled value.
        band_threshold = 1.0
        edge_coords = find_candidates_band(grid_logits, band_threshold)
        expanded_coords = expand_edge_region_fast(edge_coords, grid_size=int(grid_size/2)).to(torch.float16)
        print(f'step {i+2} query num: {len(expanded_coords)}')
        # Voxel index -> world coordinate (symmetric bbox assumed: uses
        # abs(bounds[0]) as the half-extent — TODO confirm for asymmetric bounds).
        expanded_coords_norm = (expanded_coords - normalize_offset) * (abs(bounds[0]) / normalize_offset)

        all_logits = None

        all_logits = geometric_func(expanded_coords_norm.unsqueeze(0)).to(torch.float16)
        # Prepend the query coordinates so value/index stay paired.
        all_logits = torch.cat([expanded_coords_norm, all_logits[0]], dim=1)
        # print("all logits shape = ", all_logits.shape)

        # Invert the normalization to recover integer voxel indices.
        indices = all_logits[..., :3]
        indices = indices * (normalize_offset / abs(bounds[0])) + normalize_offset
        indices = indices.type(torch.IntTensor)
        values = all_logits[:, 3]
        # breakpoint()
        # Scatter the refined logits into the upsampled volume.
        high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
        grid_logits = high_res_occupancy
        torch.cuda.empty_cache()
    mesh_v_f = []
    try:
        print("final grids shape = ", grid_logits.shape)
        # NOTE: 'measure' is skimage.measure, imported elsewhere in this file.
        vertices, faces, normals, _ = measure.marching_cubes(grid_logits.float().cpu().numpy(), 0, method="lewiner")
        # Voxel space -> world space.
        vertices = vertices / (2**hierarchical_octree_depth) * bbox_size.cpu().numpy() + bbox_min.cpu().numpy()
        mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
    except Exception as e:
        # Best-effort: surface extraction can fail on degenerate fields;
        # report and return an empty result rather than aborting the pipeline.
        print(e)
        torch.cuda.empty_cache()
        mesh_v_f = (None, None)

    return [mesh_v_f]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def hierarchical_extract_geometry(*args, **kwargs):
    """Wrapper: uses DiffDMC-based flash path when diso is available, else marching cubes fallback.

    NOTE(review): the two paths have different signatures
    (``flash_extract_geometry`` takes latents + vae, the marching-cubes impl
    takes geometric_func + device); callers must pass arguments matching
    whichever path is active — confirm against the patched pipeline.
    """
    # Bug fix: the branch was inverted. flash_extract_geometry itself
    # instantiates DiffDMC, so it can only run when the diso import
    # succeeded; with DiffDMC None we must take the pure-torch fallback.
    if DiffDMC is None:
        return _hierarchical_extract_geometry_impl(*args, **kwargs)
    # flash_extract_geometry needs latents + vae — forward positional args unchanged
    return flash_extract_geometry(*args, **kwargs)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
    """Mark voxels whose sign differs from at least one 6-neighbour.

    A voxel lies near the (offset) isosurface when ``input + alpha`` changes
    sign between the voxel and any of its axis-aligned neighbours.

    Args:
        input_tensor: shape [D, D, D], torch.float16
        alpha: isosurface offset
    Returns:
        mask: shape [D, D, D], torch.int32 — 1 near the surface, 0 elsewhere;
        voxels with sentinel values (<= -9000) are always 0.
    """
    device = input_tensor.device
    D = input_tensor.shape[0]
    signed_val = 0.0  # NOTE(review): unused — kept from upstream.

    # add isosurface offset and exclude invalid value
    val = input_tensor + alpha
    # Values <= -9000 are treated as "empty/invalid" sentinels (see the
    # -10000 fill used by flash_extract_geometry).
    valid_mask = val > -9000

    # obtain neighbors
    def get_neighbor(t, shift, axis):
        """Return ``t`` shifted by ``shift`` voxels along ``axis`` with
        replicate padding at the boundary (so edges compare to themselves)."""
        if shift == 0:
            return t.clone()

        pad_dims = [0, 0, 0, 0, 0, 0]  # [x_front,x_back,y_front,y_back,z_front,z_back]

        if axis == 0:  # x axis
            pad_idx = 0 if shift > 0 else 1
            pad_dims[pad_idx] = abs(shift)
        elif axis == 1:  # y axis
            pad_idx = 2 if shift > 0 else 3
            pad_dims[pad_idx] = abs(shift)
        elif axis == 2:  # z axis
            pad_idx = 4 if shift > 0 else 5
            pad_dims[pad_idx] = abs(shift)

        # Apply padding with replication at boundaries.
        # NOTE(review): F.pad expects last-dim-first ordering; pad_dims[::-1]
        # also swaps front/back within each axis pair — both ±1 shifts are
        # taken symmetrically below, so the net mask is unaffected, but
        # confirm if reused with asymmetric shifts.
        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')

        # Create dynamic slicing indices
        slice_dims = [slice(None)] * 3
        if axis == 0:  # x axis
            if shift > 0:
                slice_dims[0] = slice(shift, None)
            else:
                slice_dims[0] = slice(None, shift)
        elif axis == 1:  # y axis
            if shift > 0:
                slice_dims[1] = slice(shift, None)
            else:
                slice_dims[1] = slice(None, shift)
        elif axis == 2:  # z axis
            if shift > 0:
                slice_dims[2] = slice(shift, None)
            else:
                slice_dims[2] = slice(None, shift)

        # Apply slicing and restore dimensions
        padded = padded.squeeze(0).squeeze(0)
        sliced = padded[slice_dims]
        return sliced

    # Get neighbors in all six axis directions.
    left = get_neighbor(val, 1, axis=0)  # x axis
    right = get_neighbor(val, -1, axis=0)
    back = get_neighbor(val, 1, axis=1)  # y axis
    front = get_neighbor(val, -1, axis=1)
    down = get_neighbor(val, 1, axis=2)  # z axis
    up = get_neighbor(val, -1, axis=2)

    # Handle invalid boundary values: replace sentinel neighbours with the
    # centre value so they never trigger a spurious sign change.
    def safe_where(neighbor):
        return torch.where(neighbor > -9000, neighbor, val)

    left = safe_where(left)
    right = safe_where(right)
    back = safe_where(back)
    front = safe_where(front)
    down = safe_where(down)
    up = safe_where(up)

    # Calculate sign consistency (float32 to avoid half-precision sign quirks).
    sign = torch.sign(val.to(torch.float32))
    neighbors_sign = torch.stack([
        torch.sign(left.to(torch.float32)),
        torch.sign(right.to(torch.float32)),
        torch.sign(back.to(torch.float32)),
        torch.sign(front.to(torch.float32)),
        torch.sign(down.to(torch.float32)),
        torch.sign(up.to(torch.float32))
    ], dim=0)

    # Check if all signs are consistent across the 6 neighbours.
    same_sign = torch.all(neighbors_sign == sign, dim=0)

    # Generate final mask: sign change AND valid centre value.
    mask = (~same_sign).to(torch.int32)
    return mask * valid_mask.to(torch.int32)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def generate_dense_grid_points_2(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense lattice of query points spanning an axis-aligned bbox.

    Args:
        bbox_min: (3,) lower corner of the bounding box.
        bbox_max: (3,) upper corner of the bounding box.
        octree_resolution: cells per axis; the lattice has ``n + 1``
            samples per axis.
        indexing: meshgrid indexing convention ("ij" or "xy").

    Returns:
        Tuple ``(points, grid_size, length)``: points of shape
        (n+1, n+1, n+1, 3) in float32, grid_size = [n+1]*3, and
        length = ``bbox_max - bbox_min``.
    """
    extent = bbox_max - bbox_min
    n_samples = int(octree_resolution) + 1

    # One evenly spaced float32 axis per dimension.
    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], n_samples, dtype=np.float32)
        for axis in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    points = np.stack(mesh, axis=-1)

    return points, [n_samples, n_samples, n_samples], extent
|
| 330 |
+
|
| 331 |
+
@torch.no_grad()
def flash_extract_geometry(
    latents: torch.FloatTensor,
    vae: Callable,
    bounds: Union[Tuple[float], List[float], float] = 1.01,
    num_chunks: int = 10000,
    mc_level: float = 0.0,
    octree_depth: int = 9,
    min_resolution: int = 63,
    mini_grid_num: int = 4,
    **kwargs,
):
    """Multi-resolution geometry extraction using the VAE decoder + DiffDMC.

    Decodes latents into a coarse logit volume, refines only near-surface
    voxels at successively doubled resolutions, then extracts a mesh with
    diso's differentiable dual marching cubes.

    Args:
        latents: latent point tokens for the VAE decoder (shape inferred from
            usage: squeezed batch dim, then (points, channels) — TODO confirm).
        vae: decoder model exposing ``.decoder`` (with ``set_topk``) and
            ``.decode(latents, queries).sample``.
        bounds: scalar half-extent or 6-tuple bbox.
        num_chunks: max queries per decoder call (memory/throughput knob).
        mc_level: isosurface level passed to the near-surface test.
        octree_depth: final resolution is 2**octree_depth.
        min_resolution: coarsest resolution allowed for the first dense pass.
        mini_grid_num: per-axis split of the dense grid into decode batches.

    Returns:
        Single-element list ``[(vertices, faces)]`` (float32 vertices in
        world space) or ``[(None, None)]`` on extraction failure.
    """
    geo_decoder = vae.decoder
    device = latents.device
    dtype = latents.dtype
    # resolution to depth: build the halving schedule from finest to
    # coarsest, then reverse it so refinement runs coarse -> fine.
    octree_resolution = 2 ** octree_depth
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution = octree_resolution // 2
    resolutions.reverse()
    # Make the coarsest resolution divisible by mini_grid_num (minus one so
    # the sample count n+1 splits evenly into mini-grids).
    resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
    for i, resolution in enumerate(resolutions[1:]):
        resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)


    # 1. generate query points
    if isinstance(bounds, float):
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
    bbox_min = np.array(bounds[0:3])
    bbox_max = np.array(bounds[3:6])
    bbox_size = bbox_max - bbox_min

    xyz_samples, grid_size, length = generate_dense_grid_points_2(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_resolution=resolutions[0],
        indexing="ij"
    )

    # All-ones 3x3x3 conv acts as a binary dilation operator on masks.
    dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
    dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

    grid_size = np.array(grid_size)

    # 2. latents to 3d volume
    xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
    batch_size = latents.shape[0]
    # Split the dense grid into mini_grid_num^3 cubes and flatten each into a
    # query batch; the permute groups the cube index dims in front.
    mini_grid_size = xyz_samples.shape[0] // mini_grid_num
    xyz_samples = xyz_samples.view(
        mini_grid_num, mini_grid_size,
        mini_grid_num, mini_grid_size,
        mini_grid_num, mini_grid_size, 3
    ).permute(
        0, 2, 4, 1, 3, 5, 6
    ).reshape(
        -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
    )
    batch_logits = []
    num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
    for start in range(0, xyz_samples.shape[0], num_batchs):
        queries = xyz_samples[start: start + num_batchs, :]
        batch = queries.shape[0]
        # einops.repeat: broadcast the (points, channels) latents to the batch.
        batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
        # geo_decoder.set_topk(True)
        geo_decoder.set_topk(False)
        # Patch (see HEAD): cast queries to the latents' dtype before decode.
        logits = vae.decode(batch_latents, queries.to(dtype=batch_latents.dtype)).sample
        batch_logits.append(logits)
    # Undo the mini-grid split to reassemble the full dense volume.
    grid_logits = torch.cat(batch_logits, dim=0).reshape(
        mini_grid_num, mini_grid_num, mini_grid_num,
        mini_grid_size, mini_grid_size,
        mini_grid_size
    ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
        (batch_size, grid_size[0], grid_size[1], grid_size[2])
    )

    # 3. coarse-to-fine refinement: at each doubled resolution only decode
    # voxels near the current surface estimate.
    for octree_depth_now in resolutions[1:]:
        grid_size = np.array([octree_depth_now + 1] * 3)
        resolution = bbox_size / octree_depth_now
        next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
        # -10000 marks "never decoded" voxels (turned into NaN at the end).
        next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
        # Near-surface voxels: sign changes plus a |logit| < 0.95 band.
        curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
        curr_points += grid_logits.squeeze(0).abs() < 0.95

        # Extra dilation on all but the final level to avoid missing surface.
        if octree_depth_now == resolutions[-1]:
            expand_num = 0
        else:
            expand_num = 1
        for i in range(expand_num):
            curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
        curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
        (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)

        # Project candidates to the doubled grid, then dilate there to fill
        # the odd-index voxels between projected points.
        next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
        for i in range(2 - expand_num):
            next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
        nidx = torch.where(next_index > 0)

        next_points = torch.stack(nidx, dim=1)
        # Voxel index -> world coordinates.
        next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
                       torch.tensor(bbox_min, dtype=torch.float32, device=device))

        # Bucket the queries into a coarse 6^3 spatial grid and sort by
        # bucket so each decode call sees spatially local points.
        query_grid_num = 6
        min_val = next_points.min(axis=0).values
        max_val = next_points.max(axis=0).values
        vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
        index = torch.floor(vol_queries_index).long()
        index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
        index = index.sort()
        next_points = next_points[index.indices].unsqueeze(0).contiguous()
        unique_values = torch.unique(index.values, return_counts=True)
        grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
        input_grid = [[], []]
        logits_grid_list = []
        start_num = 0
        sum_num = 0
        # Greedily pack consecutive buckets into decode calls of at most
        # num_chunks queries (a single oversized bucket still goes alone).
        for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
            if sum_num + count < num_chunks or sum_num == 0:
                sum_num += count
                input_grid[0].append(grid_index)
                input_grid[1].append(count)
            else:
                # geo_decoder.set_topk(input_grid)
                geo_decoder.set_topk(False)
                logits_grid = vae.decode(latents, next_points[:, start_num:start_num + sum_num].to(dtype=latents.dtype)).sample
                start_num = start_num + sum_num
                logits_grid_list.append(logits_grid)
                input_grid = [[grid_index], [count]]
                sum_num = count
        if sum_num > 0:
            # Flush the final partially-filled pack.
            # geo_decoder.set_topk(input_grid)
            geo_decoder.set_topk(False)
            logits_grid = vae.decode(latents, next_points[:, start_num:start_num + sum_num].to(dtype=latents.dtype)).sample
            logits_grid_list.append(logits_grid)
        logits_grid = torch.cat(logits_grid_list, dim=1)
        # Undo the bucket sort, then scatter into the sparse next-level volume.
        grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
        next_logits[nidx] = grid_logits
        grid_logits = next_logits.unsqueeze(0)

    # Never-decoded voxels become NaN so DiffDMC ignores them.
    grid_logits[grid_logits == -10000.] = float('nan')
    torch.cuda.empty_cache()
    mesh_v_f = []
    grid_logits = grid_logits[0]
    try:
        print("final grids shape = ", grid_logits.shape)
        # diso's differentiable dual marching cubes (DiffDMC may be None when
        # the optional diso import failed — the wrapper routes around this).
        dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
        # Negate and rescale logits into an SDF-like field for DMC.
        sdf = -grid_logits / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
        vertices = vertices.detach().cpu().numpy()
        # Reverse winding so face normals match the negated field.
        faces = faces.detach().cpu().numpy()[:, ::-1]
        # Voxel space -> world space.
        vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
        mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
    except Exception as e:
        # Best-effort: report the failure and return an empty mesh rather
        # than crashing the whole generation pipeline.
        print(e)
        torch.cuda.empty_cache()
        mesh_v_f = (None, None)

    return [mesh_v_f]
|