Daankular committed on
Commit
4a05450
·
1 Parent(s): 60f5fd7

Replace runtime TripoSG string-patching with pre-patched files in patches/triposg/

Browse files

All three TripoSG scripts that previously required runtime str.replace() hacks are
now committed as fixed files under patches/triposg/ (mirroring upstream layout).
load_triposg() now simply copies them over the cloned repo (via shutil.copy2) after checkout.

Patches applied in the files:
- scripts/inference_triposg.py: remove pymeshlab import + helper fns, use trimesh.simplify_quadric_decimation
- scripts/image_process.py: empty contours guard, rmbg_net=None guard, all-zero alpha fallback
- triposg/inference_utils.py: diso optional try/except, hierarchical wrapper, queries dtype cast

app.py CHANGED
@@ -365,6 +365,7 @@ def load_triposg():
365
  return _triposg_pipe, _rmbg_net
366
 
367
  print("[load_triposg] Loading TripoSG pipeline...")
 
368
  from huggingface_hub import snapshot_download
369
 
370
  # TripoSG source has no setup.py — clone GitHub repo and add to sys.path
@@ -377,160 +378,19 @@ def load_triposg():
377
  str(triposg_src)],
378
  check=True
379
  )
380
- if str(triposg_src) not in sys.path:
381
- sys.path.insert(0, str(triposg_src))
382
 
383
- # Patch image_process.py: guard rmbg_net=None in load_image.
384
- # TripoSG calls rmbg(rgb_image_resized) unconditionally when alpha is None,
385
- # with no check for rmbg_net being None. Fallback: all-white alpha (full foreground).
386
- _ip_path = triposg_src / "scripts" / "image_process.py"
387
- if _ip_path.exists():
388
- _ip_text = _ip_path.read_text()
389
- if "rmbg_net_none_guard_v2" not in _ip_text:
390
- _ip_text = _ip_text.replace(
391
- " # seg from rmbg\n alpha_gpu_rmbg = rmbg(rgb_image_resized)",
392
- " # seg from rmbg\n"
393
- " if rmbg_net is None: # rmbg_net_none_guard_v2\n"
394
- " alpha_gpu_rmbg = torch.ones(\n"
395
- " 1, 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
396
- " device=rgb_image_resized.device)\n"
397
- " else:\n"
398
- " alpha_gpu_rmbg = rmbg(rgb_image_resized)",
399
- )
400
- _ip_path.write_text(_ip_text)
401
- print("[load_triposg] Patched image_process.py: rmbg_net None guard")
402
-
403
- # Patch find_bounding_box: guard against empty contours (blank alpha mask).
404
- # When RMBG produces an all-black mask, findContours returns [] and max() raises.
405
- # Fallback: return the full image bounding box so pipeline can continue.
406
- # NOTE: parameter is gray_image, not alpha.
407
- _ip_text2 = _ip_path.read_text()
408
- if "empty_contours_guard" not in _ip_text2:
409
- _ip_text2 = _ip_text2.replace(
410
- " max_contour = max(contours, key=cv2.contourArea)",
411
- " if not contours: # empty_contours_guard\n"
412
- " h, w = gray_image.shape[:2]\n"
413
- " return 0, 0, w, h\n"
414
- " max_contour = max(contours, key=cv2.contourArea)",
415
- )
416
- _ip_path.write_text(_ip_text2)
417
- print("[load_triposg] Patched image_process.py: empty contours guard")
418
-
419
- # Patch all-zero alpha guard: instead of raising ValueError("input image too small"),
420
- # fall back to full-foreground alpha so the pipeline can continue with the whole image.
421
- # Happens when RMBG produces a blank mask (e.g. remove_small_objects wipes everything).
422
- _ip_text3 = _ip_path.read_text()
423
- if "all_zero_alpha_guard" not in _ip_text3:
424
- _ip_text3 = _ip_text3.replace(
425
- ' if np.all(alpha==0):\n raise ValueError(f"input image too small")',
426
- " if np.all(alpha==0): # all_zero_alpha_guard\n"
427
- " h_full, w_full = alpha.shape[:2]\n"
428
- " alpha = np.full((h_full, w_full), 255, dtype=np.uint8)\n"
429
- " alpha_gpu = torch.ones(1, h_full, w_full, dtype=torch.float32,\n"
430
- " device=rgb_image_gpu.device)\n"
431
- " x, y, w, h = 0, 0, w_full, h_full",
432
- )
433
- _ip_path.write_text(_ip_text3)
434
- print("[load_triposg] Patched image_process.py: all-zero alpha fallback")
435
-
436
- # Safety net: patch inference_utils.py to make diso import optional.
437
- # Even if diso compiled with submodules, guard against any residual link errors.
438
- _iu_path = triposg_src / "triposg" / "inference_utils.py"
439
- if _iu_path.exists():
440
- _iu_text = _iu_path.read_text()
441
- if "queries.to(dtype=batch_latents.dtype)" not in _iu_text:
442
- _iu_text = _iu_text.replace(
443
- "from diso import DiffDMC",
444
- "try:\n from diso import DiffDMC\n"
445
- "except Exception as _diso_err:\n"
446
- " print(f'[TripoSG] diso unavailable ({_diso_err}), using flash fallback')\n"
447
- " DiffDMC = None",
448
- )
449
- if ("def hierarchical_extract_geometry(" in _iu_text
450
- and "flash_extract_geometry" in _iu_text):
451
- _iu_text = _iu_text.replace(
452
- "def hierarchical_extract_geometry(",
453
- "def _hierarchical_extract_geometry_impl(",
454
- )
455
- _iu_text += (
456
- "\n\n"
457
- "def hierarchical_extract_geometry(*args, **kwargs):\n"
458
- " if DiffDMC is None:\n"
459
- " return flash_extract_geometry(*args, **kwargs)\n"
460
- " return _hierarchical_extract_geometry_impl(*args, **kwargs)\n"
461
- )
462
- # Also cast queries to match batch_latents dtype before vae.decode.
463
- # TripoSGPipeline loads as float16 but flash_extract_geometry creates
464
- # query grids as float32, causing a dtype mismatch in F.linear.
465
- _iu_text = _iu_text.replace(
466
- "logits = vae.decode(batch_latents, queries).sample",
467
- "logits = vae.decode(batch_latents, queries.to(dtype=batch_latents.dtype)).sample",
468
- )
469
- _iu_path.write_text(_iu_text)
470
- print("[load_triposg] Patched inference_utils.py: diso optional + queries dtype cast")
471
-
472
- # Patch inference_triposg.py: replace pymeshlab (no py3.13 wheels) with trimesh.
473
- # pymeshlab is only used for simplify_mesh() — QEM decimation + vertex merging.
474
- # trimesh.simplify_quadric_decimation() is a direct equivalent (needs fast-simplification).
475
- _it_path = triposg_src / "scripts" / "inference_triposg.py"
476
- if _it_path.exists():
477
- _it_text = _it_path.read_text()
478
- if "pymeshlab_replaced_v1" not in _it_text:
479
- # Step 1: always strip the top-level pymeshlab import
480
- _it_text = _it_text.replace("import pymeshlab\n", "")
481
-
482
- # Step 2: try to replace the pymeshlab helper functions with trimesh equivalent
483
- _new_simplify = (
484
- "# pymeshlab_replaced_v1: replaced with trimesh (no py3.13 wheels for pymeshlab)\n"
485
- "def simplify_mesh(mesh: trimesh.Trimesh, n_faces):\n"
486
- " if mesh.faces.shape[0] > n_faces:\n"
487
- " mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=True)\n"
488
- " mesh = mesh.simplify_quadric_decimation(n_faces)\n"
489
- " return mesh\n"
490
- )
491
- # Try the exact upstream string first
492
- _old_exact = (
493
- "def mesh_to_pymesh(vertices, faces):\n"
494
- " mesh = pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces)\n"
495
- " ms = pymeshlab.MeshSet()\n"
496
- " ms.add_mesh(mesh)\n"
497
- " return ms\n"
498
- "\n"
499
- "\n"
500
- "def pymesh_to_trimesh(mesh):\n"
501
- " verts = mesh.vertex_matrix()\n"
502
- " faces = mesh.face_matrix()\n"
503
- " return trimesh.Trimesh(vertices=verts, faces=faces)\n"
504
- "\n"
505
- "\n"
506
- "def simplify_mesh(mesh: trimesh.Trimesh, n_faces):\n"
507
- " if mesh.faces.shape[0] > n_faces:\n"
508
- " ms = mesh_to_pymesh(mesh.vertices, mesh.faces)\n"
509
- " ms.meshing_merge_close_vertices()\n"
510
- " ms.meshing_decimation_quadric_edge_collapse(targetfacenum = n_faces)\n"
511
- " return pymesh_to_trimesh(ms.current_mesh())\n"
512
- " else:\n"
513
- " return mesh\n"
514
- )
515
- if _old_exact in _it_text:
516
- _it_text = _it_text.replace(_old_exact, _new_simplify)
517
- print("[load_triposg] Patched inference_triposg.py: pymeshlab → trimesh (exact match)")
518
- else:
519
- # Fallback: regex — handles whitespace/indent variants
520
- import re as _re
521
- _it_text = _re.sub(
522
- r"def mesh_to_pymesh\(.*?\n.*?\n.*?\n.*?\n.*?\n\n\ndef pymesh_to_trimesh\(.*?\n.*?\n.*?\n.*?\n\n\ndef simplify_mesh\([^)]*\):[^\n]*\n(?: [^\n]*\n)*",
523
- _new_simplify,
524
- _it_text,
525
- flags=_re.DOTALL,
526
- )
527
- print("[load_triposg] Patched inference_triposg.py: pymeshlab → trimesh (regex fallback)")
528
 
529
- # Step 3: always write — import removal alone fixes the crash even if
530
- # function replacement didn't match (simplify_mesh just won't exist,
531
- # which is a NameError only if called, not at import time)
532
- _it_path.write_text(_it_text)
533
- print("[load_triposg] inference_triposg.py written.")
534
 
535
  weights_path = snapshot_download("VAST-AI/TripoSG")
536
 
 
365
  return _triposg_pipe, _rmbg_net
366
 
367
  print("[load_triposg] Loading TripoSG pipeline...")
368
+ import shutil as _shutil
369
  from huggingface_hub import snapshot_download
370
 
371
  # TripoSG source has no setup.py — clone GitHub repo and add to sys.path
 
378
  str(triposg_src)],
379
  check=True
380
  )
 
 
381
 
382
+ # Overwrite upstream scripts with pre-patched versions committed to this repo.
383
+ # Patches live in patches/triposg/ and mirror the upstream directory layout.
384
+ _patches_dir = HERE / "patches" / "triposg"
385
+ for _pf in _patches_dir.rglob("*"):
386
+ if _pf.is_file():
387
+ _dest = triposg_src / _pf.relative_to(_patches_dir)
388
+ _dest.parent.mkdir(parents=True, exist_ok=True)
389
+ _shutil.copy2(str(_pf), str(_dest))
390
+ print("[load_triposg] Applied pre-patched scripts from patches/triposg/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ if str(triposg_src) not in sys.path:
393
+ sys.path.insert(0, str(triposg_src))
 
 
 
394
 
395
  weights_path = snapshot_download("VAST-AI/TripoSG")
396
 
patches/triposg/scripts/image_process.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ from skimage.morphology import remove_small_objects
4
+ from skimage.measure import label
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ from torchvision import transforms
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+
13
def find_bounding_box(gray_image):
    """Return (x, y, w, h) of the largest bright region in a grayscale mask.

    Thresholds at >1 to binarize, then takes the bounding rectangle of the
    largest external contour. Falls back to the full image extent when the
    mask is blank (no contours found).
    """
    _, binarized = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)
    found_contours, _ = cv2.findContours(binarized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found_contours:
        # Blank mask: treat the whole image as foreground.
        full_h, full_w = gray_image.shape[:2]
        return 0, 0, full_w, full_h
    biggest = max(found_contours, key=cv2.contourArea)
    return cv2.boundingRect(biggest)
22
+
23
def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1):
    """Load an image, segment the foreground, and return a padded CHW float tensor.

    The foreground alpha comes either from a usable alpha channel in the file
    or from the RMBG segmentation network; the foreground is composited over
    ``bg_color``, cropped to its bounding box, and padded by ``padding_ratio``.

    Args:
        img_path: path to the input image file.
        bg_color: background color; indexed as ``bg_color[0]`` and passed to
            ``torch.from_numpy`` — presumably a 3-element numpy float array in
            [0, 1], verify against callers (``prepare_image`` passes
            ``np.array([1.0, 1.0, 1.0])``).
        rmbg_net: optional background-removal network; when None, a
            full-foreground (all-ones) alpha is used instead of segmentation.
        padding_ratio: fraction of the larger crop dimension added as padding.

    Returns:
        A CUDA float tensor of shape (3, H, W) in [0, 1], or an error
        description string on invalid input (callers must check the type).
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # Errors are reported as strings, not exceptions (upstream convention).
        return f"invalid image path {img_path}"

    def is_valid_alpha(alpha, min_ratio = 0.01):
        # An alpha channel is "usable" only if its histogram has at least
        # min_ratio of the pixels in both the lowest and highest bins,
        # i.e. it actually separates background from foreground.
        bins = 20
        if isinstance(alpha, np.ndarray):
            hist = cv2.calcHist([alpha], [0], None, [bins], [0, 256])
        else:
            hist = torch.histc(alpha, bins=bins, min=0, max=1)
        min_hist_val = alpha.shape[0] * alpha.shape[1] * min_ratio
        return hist[0] >= min_hist_val and hist[-1] >= min_hist_val

    def rmbg(image: torch.Tensor) -> torch.Tensor:
        # Run the RMBG segmentation net on a normalized CHW image; the net's
        # first output is the predicted alpha map.
        image = TF.normalize(image, [0.5,0.5,0.5], [1.0,1.0,1.0]).unsqueeze(0)
        result=rmbg_net(image)
        return result[0][0]

    if len(img.shape) == 2:
        num_channels = 1
    else:
        num_channels = img.shape[2]

    # check if too large: cap the longer side at 2000 px
    height, width = img.shape[:2]
    if height > width:
        scale = 2000 / height
    else:
        scale = 2000 / width
    if scale < 1:
        new_size = (int(width * scale), int(height * scale))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

    # Normalize 16-bit (or other integer) images down to uint8 range.
    if img.dtype != 'uint8':
        img = (img * (255. / np.iinfo(img.dtype).max)).astype(np.uint8)

    rgb_image = None
    alpha = None

    if num_channels == 1:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif num_channels == 3:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    elif num_channels == 4:
        rgb_image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

        b, g, r, alpha = cv2.split(img)
        if not is_valid_alpha(alpha):
            # Degenerate alpha channel — fall through to RMBG segmentation.
            alpha = None
        else:
            alpha_gpu = torch.from_numpy(alpha).unsqueeze(0).cuda().float() / 255.
    else:
        return f"invalid image: channels {num_channels}"

    rgb_image_gpu = torch.from_numpy(rgb_image).cuda().float().permute(2, 0, 1) / 255.
    if alpha is None:
        resize_transform = transforms.Resize((384, 384), antialias=True)
        rgb_image_resized = resize_transform(rgb_image_gpu)
        normalize_image = rgb_image_resized * 2 - 1

        # NOTE(review): normalize_image is recomputed below from the 1024px
        # resize; the 384px pass above appears unused downstream — confirm.
        mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda()
        resize_transform = transforms.Resize((1024, 1024), antialias=True)
        rgb_image_resized = resize_transform(rgb_image_gpu)
        max_value = rgb_image_resized.flatten().max()
        if max_value < 1e-3:
            return "invalid image: pure black image"
        normalize_image = rgb_image_resized / max_value - mean_color
        normalize_image = normalize_image.unsqueeze(0)
        resize_transform = transforms.Resize((rgb_image_gpu.shape[1], rgb_image_gpu.shape[2]), antialias=True)

        # seg from rmbg; when no net is provided, assume the whole image is
        # foreground (pre-applied guard for rmbg_net=None).
        if rmbg_net is None:
            alpha_gpu_rmbg = torch.ones(1, 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2], device=rgb_image_resized.device)
        else:
            alpha_gpu_rmbg = rmbg(rgb_image_resized)
        alpha_gpu_rmbg = alpha_gpu_rmbg.squeeze(0)
        alpha_gpu_rmbg = resize_transform(alpha_gpu_rmbg)
        # Min-max normalize the predicted alpha to [0, 1].
        ma, mi = alpha_gpu_rmbg.max(), alpha_gpu_rmbg.min()
        alpha_gpu_rmbg = (alpha_gpu_rmbg - mi) / (ma - mi)

        alpha_gpu = alpha_gpu_rmbg

        alpha_gpu_tmp = alpha_gpu * 255
        alpha = alpha_gpu_tmp.to(torch.uint8).squeeze().cpu().numpy()

        # Binarize (Otsu) and drop connected components smaller than 200 px.
        _, alpha = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        labeled_alpha = label(alpha)
        cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
        cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
        alpha = cleaned_alpha * 255
        alpha_gpu = torch.from_numpy(cleaned_alpha).cuda().float().unsqueeze(0)
        x, y, w, h = find_bounding_box(alpha)

    # If alpha is provided, the bounds of all foreground are used
    else:
        rows, cols = np.where(alpha > 0)
        if rows.size > 0 and cols.size > 0:
            x_min = np.min(cols)
            y_min = np.min(rows)
            x_max = np.max(cols)
            y_max = np.max(rows)

            width = x_max - x_min + 1
            height = y_max - y_min + 1
            x, y, w, h = x_min, y_min, width, height

    if np.all(alpha==0):
        # Blank alpha: treat entire image as foreground instead of raising
        h_full, w_full = alpha.shape[:2]
        alpha = np.full((h_full, w_full), 255, dtype=np.uint8)
        alpha_gpu = torch.ones(1, h_full, w_full, dtype=torch.float32, device=rgb_image_gpu.device)
        x, y, w, h = 0, 0, w_full, h_full

    # Composite foreground over the requested background color.
    bg_gray = bg_color[0]
    bg_color = torch.from_numpy(bg_color).float().cuda().repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
    rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
    # Pad the crop so it becomes square with padding_ratio extra margin.
    padding_size = [0] * 6
    if w > h:
        padding_size[0] = int(w * padding_ratio)
        padding_size[2] = int(padding_size[0] + (w - h) / 2)
    else:
        padding_size[2] = int(h * padding_ratio)
        padding_size[0] = int(padding_size[2] + (h - w) / 2)
    padding_size[1] = padding_size[0]
    padding_size[3] = padding_size[2]
    padded_tensor = F.pad(rgb_image_gpu[:, y:(y+h), x:(x+w)], pad=tuple(padding_size), mode='constant', value=bg_gray)

    return padded_tensor
152
+
153
def prepare_image(image_path, bg_color, rmbg_net=None):
    """Load and preprocess an image file into a PIL RGB image.

    Returns None when image_path is not an existing file (matching the
    original implicit-None behavior).
    """
    if not os.path.isfile(image_path):
        return None
    tensor_chw = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net)
    # CHW float tensor in [0, 1] -> HWC uint8 array -> PIL image.
    array_hwc = tensor_chw.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((array_hwc * 255).astype(np.uint8))
patches/triposg/scripts/inference_triposg.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from glob import glob
5
+ from typing import Any, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import trimesh
10
+ from huggingface_hub import snapshot_download
11
+ from PIL import Image
12
+
13
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
16
+ from image_process import prepare_image
17
+ from briarmbg import BriaRMBG
18
+
19
+
20
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    faces: int = -1,
) -> trimesh.Scene:
    """Generate a 3D mesh from a single image with the TripoSG pipeline.

    The image is background-removed (via rmbg_net) and composited onto a
    white canvas, run through the diffusion pipeline with a seeded generator,
    and converted into a trimesh.Trimesh. When faces > 0 the mesh is
    decimated down to that face budget.
    """
    # Composite the foreground onto a white background before inference.
    white_bg = np.array([1.0, 1.0, 1.0])
    img_pil = prepare_image(image_input, bg_color=white_bg, rmbg_net=rmbg_net)

    rng = torch.Generator(device=pipe.device).manual_seed(seed)
    sample = pipe(
        image=img_pil,
        generator=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]

    # sample[0]: vertices, sample[1]: faces (per upstream pipeline output).
    mesh = trimesh.Trimesh(sample[0].astype(np.float32), np.ascontiguousarray(sample[1]))

    if faces > 0:
        mesh = simplify_mesh(mesh, faces)

    return mesh
45
+
46
+
47
def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
    """Decimate *mesh* to at most ``n_faces`` faces via quadric edge collapse.

    Returns the mesh unchanged when it already has ``n_faces`` faces or fewer.

    The mesh is first rebuilt with ``process=True`` so duplicate/close
    vertices are merged (mirrors the old pymeshlab
    ``meshing_merge_close_vertices`` step this function replaced).
    """
    if mesh.faces.shape[0] > n_faces:
        mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=True)
        # trimesh >= 4 signature is simplify_quadric_decimation(percent=None,
        # face_count=None, ...): a bare positional argument binds to `percent`,
        # so an absolute face count must be passed as `face_count=`. Older
        # trimesh releases only accepted the count positionally — fall back.
        try:
            mesh = mesh.simplify_quadric_decimation(face_count=n_faces)
        except TypeError:
            mesh = mesh.simplify_quadric_decimation(n_faces)
    return mesh
52
+
53
+
54
if __name__ == "__main__":
    # CLI entry point: single image in, GLB mesh out. Requires a CUDA device.
    device = "cuda"
    dtype = torch.float16

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", type=str, required=True)
    parser.add_argument("--output-path", type=str, default="./output.glb")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    # faces <= 0 disables mesh decimation (see run_triposg).
    parser.add_argument("--faces", type=int, default=-1)
    args = parser.parse_args()

    # download pretrained weights (no-op if already cached locally)
    triposg_weights_dir = "pretrained_weights/TripoSG"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=triposg_weights_dir)
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

    # init rmbg model for background removal
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
    rmbg_net.eval()

    # init tripoSG pipeline (fp16 on GPU)
    pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(triposg_weights_dir).to(device, dtype)

    # run inference and export the resulting mesh
    run_triposg(
        pipe,
        image_input=args.image_input,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        faces=args.faces,
    ).export(args.output_path)
    print(f"Mesh saved to {args.output_path}")
patches/triposg/triposg/inference_utils.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import scipy.ndimage
5
+ from skimage import measure
6
+ from einops import repeat
7
+ try:
8
+ from diso import DiffDMC
9
+ except Exception as _diso_err:
10
+ print(f'[TripoSG] diso unavailable ({_diso_err}), using flash fallback')
11
+ DiffDMC = None
12
+ import torch.nn.functional as F
13
+
14
+ from triposg.utils.typing import *
15
+
16
def generate_dense_grid_points_gpu(bbox_min: torch.Tensor,
                                   bbox_max: torch.Tensor,
                                   octree_depth: int,
                                   indexing: str = "ij"):
    """Sample a dense (2**octree_depth)^3 float16 grid of xyz points in the bbox.

    Returns:
        (points, grid_size, length): flattened [N, 3] point tensor on the
        bbox's device, per-axis cell counts as a 3-element list, and the
        bbox edge-length tensor (bbox_max - bbox_min).
    """
    extent = bbox_max - bbox_min
    cells_per_axis = int(2 ** octree_depth)
    dev = bbox_min.device

    # One evenly-spaced axis per dimension, then the full Cartesian grid.
    axes = [
        torch.linspace(bbox_min[axis], bbox_max[axis], cells_per_axis,
                       dtype=torch.float16, device=dev)
        for axis in range(3)
    ]
    mesh = torch.meshgrid(axes[0], axes[1], axes[2], indexing=indexing)
    points = torch.stack(mesh, dim=-1).view(-1, 3)

    return points, [cells_per_axis] * 3, extent
34
+
35
def find_mesh_grid_coordinates_fast_gpu(occupancy_grid, n_limits=-1):
    """Indices of occupied interior voxels that touch an unoccupied neighbor.

    An interior voxel (grid border excluded) is selected when its value is > 0
    and any of its 26 neighbors has a value < 0 — i.e. it sits on the surface.
    When n_limits != -1 and more voxels qualify, a random subsample (with
    replacement) of n_limits indices is returned. Output is an [N, 3] tensor
    of coordinates in the full grid's index space.
    """
    d0, d1, d2 = (s - 2 for s in occupancy_grid.shape)
    occupied = occupancy_grid[1:-1, 1:-1, 1:-1] > 0

    # OR together "neighbor < 0" across all 26 offsets around each voxel.
    any_empty_neighbor = torch.zeros_like(occupied)
    for off0 in range(3):
        for off1 in range(3):
            for off2 in range(3):
                if off0 == 1 and off1 == 1 and off2 == 1:
                    continue  # skip the voxel itself
                shifted = occupancy_grid[off0:off0 + d0,
                                         off1:off1 + d1,
                                         off2:off2 + d2]
                any_empty_neighbor |= shifted < 0

    # +1 converts interior-relative indices back to full-grid indices.
    coords = torch.nonzero(occupied & any_empty_neighbor, as_tuple=False) + 1

    if n_limits != -1 and coords.shape[0] > n_limits:
        print(f"core mesh coords {coords.shape[0]} is too large, limited to {n_limits}")
        pick = np.random.choice(coords.shape[0], n_limits, True)
        coords = coords[pick]

    return coords
75
+
76
def find_candidates_band(occupancy_grid: torch.Tensor, band_threshold: float, n_limits: int = -1) -> torch.Tensor:
    """Indices of interior voxels whose pseudo-SDF magnitude is below a threshold.

    The grid holds logits; they are squashed to (-1, 1) via sigmoid(x)*2 - 1
    and voxels with |value| < band_threshold are selected. Border voxels are
    excluded; returned indices are in the full grid's index space. When
    n_limits != -1 and more voxels qualify, a random subsample (with
    replacement) of n_limits indices is returned.
    """
    interior = occupancy_grid[1:-1, 1:-1, 1:-1]
    # logits -> pseudo-SDF in (-1, 1)
    pseudo_sdf = torch.sigmoid(interior) * 2 - 1

    # +1 converts interior-relative indices back to full-grid indices.
    candidates = torch.nonzero(pseudo_sdf.abs() < band_threshold, as_tuple=False) + 1

    if n_limits != -1 and candidates.shape[0] > n_limits:
        print(f"core mesh coords {candidates.shape[0]} is too large, limited to {n_limits}")
        chosen = np.random.choice(candidates.shape[0], n_limits, True)
        candidates = candidates[chosen]

    return candidates
103
+
104
def expand_edge_region_fast(edge_coords, grid_size):
    """Dilate edge voxel coordinates and lift them to the next (2x) resolution.

    Scatters the given [N, 3] coordinates into a dense float16 CUDA grid,
    dilates via max_pool3d (kernel 5 below 512^3, else 3 to bound memory),
    and expands every surviving low-res voxel into its 8 children at double
    resolution. Returns an [M, 3] int16-derived coordinate tensor in the
    high-res index space. Requires a CUDA device (grid is allocated on 'cuda').
    """
    expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=torch.float16, requires_grad=False)
    expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1
    if grid_size < 512:
        kernel_size = 5
        pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=2).squeeze()
    else:
        # Larger grids use a smaller dilation kernel.
        kernel_size = 3
        pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=1).squeeze()
    expanded_coords_low_res = torch.nonzero(pooled_tensor, as_tuple=False).to(torch.int16)

    # Each low-res voxel (i, j, k) maps to the 8 high-res corners
    # (2i + {0,1}, 2j + {0,1}, 2k + {0,1}); the three cats enumerate them
    # in lockstep so row r of each stacked column belongs to the same corner.
    expanded_coords_high_res = torch.stack([
        torch.cat((expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1)),
        torch.cat((expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2+1, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2 + 1)),
        torch.cat((expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1))
    ], dim=1)

    return expanded_coords_high_res
122
+
123
def zoom_block(block, scale_factor, order=3):
    """Resample a numpy block by scale_factor using spline interpolation.

    The block is cast to float32 first; `order` is the spline order
    (default cubic).
    """
    as_float = block.astype(np.float32)
    return scipy.ndimage.zoom(as_float, scale_factor, order=order)
126
+
127
def parallel_zoom(occupancy_grid, scale_factor):
    """Upsample a 3D tensor by scale_factor (F.interpolate default mode, nearest)."""
    batched = occupancy_grid[None, None]  # add batch + channel dims for interpolate
    zoomed = torch.nn.functional.interpolate(batched, scale_factor=scale_factor)
    return zoomed[0, 0]
130
+
131
+
132
@torch.no_grad()
def _hierarchical_extract_geometry_impl(geometric_func: Callable,
                                        device: torch.device,
                                        bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                                        dense_octree_depth: int = 8,
                                        hierarchical_octree_depth: int = 9,
                                        ):
    """Coarse-to-fine SDF evaluation followed by marching-cubes extraction.

    Evaluates ``geometric_func`` on a dense (2**dense_octree_depth)^3 grid,
    then for each extra level up to ``hierarchical_octree_depth`` upsamples
    the logit grid 2x and re-queries only a dilated band of near-surface
    voxels. Finally runs skimage marching cubes at iso-level 0.

    Args:
        geometric_func: maps a [1, N, 3] query-point tensor to [1, N, 1] logits.
        device: device for the bbox tensors (grids themselves go through
            CUDA-backed helpers).
        bounds: (xmin, ymin, zmin, xmax, ymax, zmax), or a single float b
            meaning (-b, -b, -b, b, b, b). The band rescaling below uses
            abs(bounds[0]) — presumably assumes a cube centered at the
            origin; verify for asymmetric bounds.
        dense_octree_depth: depth of the initial dense grid.
        hierarchical_octree_depth: final refinement depth.

    Returns:
        A one-element list [(vertices, faces)]; (None, None) when marching
        cubes fails (e.g. no zero crossing).
    """
    if isinstance(bounds, float):
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

    bbox_min = torch.tensor(bounds[0:3]).to(device)
    bbox_max = torch.tensor(bounds[3:6]).to(device)
    bbox_size = bbox_max - bbox_min

    xyz_samples, grid_size, length = generate_dense_grid_points_gpu(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_depth=dense_octree_depth,
        indexing="ij"
    )

    print(f'step 1 query num: {xyz_samples.shape[0]}')
    grid_logits = geometric_func(xyz_samples.unsqueeze(0)).to(torch.float16).view(grid_size[0], grid_size[1], grid_size[2])
    # print(f'step 1 grid_logits shape: {grid_logits.shape}')
    for i in range(hierarchical_octree_depth - dense_octree_depth):
        curr_octree_depth = dense_octree_depth + i + 1
        # upsample the logit grid to the next resolution
        grid_size = 2**curr_octree_depth
        normalize_offset = grid_size / 2
        high_res_occupancy = parallel_zoom(grid_logits, 2)

        # Re-query only a dilated band of voxels near the current surface.
        band_threshold = 1.0
        edge_coords = find_candidates_band(grid_logits, band_threshold)
        expanded_coords = expand_edge_region_fast(edge_coords, grid_size=int(grid_size/2)).to(torch.float16)
        print(f'step {i+2} query num: {len(expanded_coords)}')
        # Voxel indices -> world coordinates in [-|bounds[0]|, |bounds[0]|].
        expanded_coords_norm = (expanded_coords - normalize_offset) * (abs(bounds[0]) / normalize_offset)

        all_logits = None

        all_logits = geometric_func(expanded_coords_norm.unsqueeze(0)).to(torch.float16)
        # Carry the query coordinates alongside their logits: [N, 3 + 1].
        all_logits = torch.cat([expanded_coords_norm, all_logits[0]], dim=1)
        # print("all logits shape = ", all_logits.shape)

        # World coordinates back to voxel indices; scatter refined logits
        # into the upsampled grid.
        indices = all_logits[..., :3]
        indices = indices * (normalize_offset / abs(bounds[0])) + normalize_offset
        indices = indices.type(torch.IntTensor)
        values = all_logits[:, 3]
        # breakpoint()
        high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
        grid_logits = high_res_occupancy
        torch.cuda.empty_cache()
    mesh_v_f = []
    try:
        print("final grids shape = ", grid_logits.shape)
        vertices, faces, normals, _ = measure.marching_cubes(grid_logits.float().cpu().numpy(), 0, method="lewiner")
        # Rescale unit-grid vertices back into the world-space bbox.
        vertices = vertices / (2**hierarchical_octree_depth) * bbox_size.cpu().numpy() + bbox_min.cpu().numpy()
        mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
    except Exception as e:
        # Best-effort: extraction failure yields (None, None) rather than raising.
        print(e)
        torch.cuda.empty_cache()
        mesh_v_f = (None, None)

    return [mesh_v_f]
206
+
207
+
208
def hierarchical_extract_geometry(*args, **kwargs):
    """Dispatch geometry extraction by diso availability.

    When the diso import failed (DiffDMC is None) the flash fallback path is
    used; otherwise the hierarchical marching-cubes implementation runs.
    Arguments are forwarded unchanged in both cases.
    """
    extractor = flash_extract_geometry if DiffDMC is None else _hierarchical_extract_geometry_impl
    return extractor(*args, **kwargs)
214
+
215
+
216
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
    """Mask voxels whose sign differs from at least one of their 6 face neighbors.

    After shifting values by ``alpha`` (moving the isosurface), a voxel is
    marked (1) when any axis-aligned neighbor has a different sign — i.e. the
    surface passes between them. Values <= -9000 are treated as invalid and
    zeroed out of the result.

    Args:
        input_tensor: cubic 3D tensor of SDF-like values; shape [D, D, D]
            (presumably float16 per the original note — verify at call sites).
        alpha: isosurface offset added to every value.

    Returns:
        int32 tensor of shape [D, D, D]; 1 marks near-surface valid voxels.
    """
    device = input_tensor.device
    D = input_tensor.shape[0]
    signed_val = 0.0  # NOTE(review): unused; kept from upstream

    # add isosurface offset and exclude invalid value
    val = input_tensor + alpha
    valid_mask = val > -9000

    # obtain neighbors
    def get_neighbor(t, shift, axis):
        # Shift t by `shift` voxels along `axis`, replicating at the borders,
        # so each output voxel holds its neighbor's value.
        if shift == 0:
            return t.clone()

        pad_dims = [0, 0, 0, 0, 0, 0] # [x_front,x_back,y_front,y_back,z_front,z_back]

        if axis == 0: # x axis
            pad_idx = 0 if shift > 0 else 1
            pad_dims[pad_idx] = abs(shift)
        elif axis == 1: # y axis
            pad_idx = 2 if shift > 0 else 3
            pad_dims[pad_idx] = abs(shift)
        elif axis == 2: # z axis
            pad_idx = 4 if shift > 0 else 5
            pad_dims[pad_idx] = abs(shift)

        # Apply padding with replication at boundaries
        # (F.pad expects the pad spec innermost-dim first, hence the reversal).
        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')

        # Create dynamic slicing indices
        slice_dims = [slice(None)] * 3
        if axis == 0: # x axis
            if shift > 0:
                slice_dims[0] = slice(shift, None)
            else:
                slice_dims[0] = slice(None, shift)
        elif axis == 1: # y axis
            if shift > 0:
                slice_dims[1] = slice(shift, None)
            else:
                slice_dims[1] = slice(None, shift)
        elif axis == 2: # z axis
            if shift > 0:
                slice_dims[2] = slice(shift, None)
            else:
                slice_dims[2] = slice(None, shift)

        # Apply slicing and restore dimensions
        padded = padded.squeeze(0).squeeze(0)
        sliced = padded[slice_dims]
        return sliced

    # Get neighbors in all directions
    left = get_neighbor(val, 1, axis=0) # x axis
    right = get_neighbor(val, -1, axis=0)
    back = get_neighbor(val, 1, axis=1) # y axis
    front = get_neighbor(val, -1, axis=1)
    down = get_neighbor(val, 1, axis=2) # z axis
    up = get_neighbor(val, -1, axis=2)

    # Handle invalid boundary values: fall back to the voxel's own value
    # so invalid neighbors never create a fake sign change.
    def safe_where(neighbor):
        return torch.where(neighbor > -9000, neighbor, val)

    left = safe_where(left)
    right = safe_where(right)
    back = safe_where(back)
    front = safe_where(front)
    down = safe_where(down)
    up = safe_where(up)

    # Calculate sign consistency (float32 to avoid half-precision sign issues)
    sign = torch.sign(val.to(torch.float32))
    neighbors_sign = torch.stack([
        torch.sign(left.to(torch.float32)),
        torch.sign(right.to(torch.float32)),
        torch.sign(back.to(torch.float32)),
        torch.sign(front.to(torch.float32)),
        torch.sign(down.to(torch.float32)),
        torch.sign(up.to(torch.float32))
    ], dim=0)

    # Check if all signs are consistent
    same_sign = torch.all(neighbors_sign == sign, dim=0)

    # Generate final mask: near-surface voxels, restricted to valid ones
    mask = (~same_sign).to(torch.int32)
    return mask * valid_mask.to(torch.int32)
311
+
312
+
313
def generate_dense_grid_points_2(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense grid of query points spanning the bounding box.

    Args:
        bbox_min: per-axis lower bounds, shape [3].
        bbox_max: per-axis upper bounds, shape [3].
        octree_resolution: number of cells per axis (grid has N+1 samples).
        indexing: np.meshgrid indexing convention ("ij" or "xy").
    Returns:
        xyz: float32 sample coordinates, shape [N+1, N+1, N+1, 3].
        grid_size: [N+1, N+1, N+1] axis lengths.
        length: bbox extent per axis (bbox_max - bbox_min).
    """
    length = bbox_max - bbox_min
    points_per_axis = int(octree_resolution) + 1
    axes = [
        np.linspace(bbox_min[d], bbox_max[d], points_per_axis, dtype=np.float32)
        for d in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(mesh, axis=-1)
    grid_size = [points_per_axis] * 3
    return xyz, grid_size, length
330
+
331
@torch.no_grad()
def flash_extract_geometry(
    latents: torch.FloatTensor,
    vae: Callable,
    bounds: Union[Tuple[float], List[float], float] = 1.01,
    num_chunks: int = 10000,
    mc_level: float = 0.0,
    octree_depth: int = 9,
    min_resolution: int = 63,
    mini_grid_num: int = 4,
    **kwargs,
):
    """Extract a mesh from VAE latents via coarse-to-fine SDF evaluation.

    Decodes the latent field on a coarse dense grid, then repeatedly doubles
    the resolution while only re-querying voxels near the current isosurface,
    and finally runs DiffDMC (differentiable dual marching cubes) on the
    refined logits volume.

    Args:
        latents: latent tokens to decode; shape presumably [1, P, C]
            (code squeezes dim 0 and repeats per batch) — TODO confirm.
        vae: project VAE; must expose .decoder (with set_topk) and
            .decode(latents, queries) returning an object with .sample.
        bounds: scalar half-extent or explicit [x0, y0, z0, x1, y1, z1] box.
        num_chunks: max query points decoded per VAE call.
        mc_level: isosurface level used for the near-surface voxel test.
        octree_depth: final grid resolution is 2 ** octree_depth.
        min_resolution: smallest resolution in the coarse-to-fine ladder.
        mini_grid_num: per-axis sub-grid count used to tile the coarse pass.
    Returns:
        [ (vertices float32 ndarray, faces ndarray) ] on success,
        [ (None, None) ] if mesh extraction fails.
    """
    geo_decoder = vae.decoder
    device = latents.device
    dtype = latents.dtype
    # resolution to depth
    octree_resolution = 2 ** octree_depth
    # Build the resolution ladder by halving down to min_resolution,
    # then reverse so it runs coarse -> fine.
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution = octree_resolution // 2
    resolutions.reverse()
    # Snap the coarsest level so it tiles evenly into mini_grid_num chunks
    # per axis, and make every finer level an exact power-of-two multiple.
    resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
    for i, resolution in enumerate(resolutions[1:]):
        resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)


    # 1. generate query points
    if isinstance(bounds, float):
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
    bbox_min = np.array(bounds[0:3])
    bbox_max = np.array(bounds[3:6])
    bbox_size = bbox_max - bbox_min

    xyz_samples, grid_size, length = generate_dense_grid_points_2(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_resolution=resolutions[0],
        indexing="ij"
    )

    # All-ones 3x3x3 conv acts as a binary dilation operator on voxel masks.
    dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
    dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

    grid_size = np.array(grid_size)

    # 2. latents to 3d volume
    # Tile the coarse grid into mini_grid_num^3 contiguous sub-blocks so each
    # VAE call decodes one spatially-local chunk of queries.
    xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
    batch_size = latents.shape[0]
    mini_grid_size = xyz_samples.shape[0] // mini_grid_num
    xyz_samples = xyz_samples.view(
        mini_grid_num, mini_grid_size,
        mini_grid_num, mini_grid_size,
        mini_grid_num, mini_grid_size, 3
    ).permute(
        0, 2, 4, 1, 3, 5, 6
    ).reshape(
        -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
    )
    batch_logits = []
    num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
    for start in range(0, xyz_samples.shape[0], num_batchs):
        queries = xyz_samples[start: start + num_batchs, :]
        batch = queries.shape[0]
        # Broadcast the single latent set across the chunk batch dimension.
        batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
        # geo_decoder.set_topk(True)
        # NOTE(review): set_topk semantics come from the project decoder;
        # here top-k attention is explicitly disabled — confirm intent.
        geo_decoder.set_topk(False)
        logits = vae.decode(batch_latents, queries.to(dtype=batch_latents.dtype)).sample
        batch_logits.append(logits)
    # Undo the sub-block tiling to recover a dense [B, N, N, N] logits volume.
    grid_logits = torch.cat(batch_logits, dim=0).reshape(
        mini_grid_num, mini_grid_num, mini_grid_num,
        mini_grid_size, mini_grid_size,
        mini_grid_size
    ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
        (batch_size, grid_size[0], grid_size[1], grid_size[2])
    )

    # 3. coarse-to-fine refinement: at each level, only re-decode voxels that
    # are near the isosurface of the previous (coarser) level.
    for octree_depth_now in resolutions[1:]:
        grid_size = np.array([octree_depth_now + 1] * 3)
        resolution = bbox_size / octree_depth_now
        next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
        # -10000 marks "not evaluated" cells; replaced by NaN at the end.
        next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
        curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
        # Also keep voxels whose logits are already close to the surface.
        curr_points += grid_logits.squeeze(0).abs() < 0.95

        # The final level skips the extra dilation pass below.
        if octree_depth_now == resolutions[-1]:
            expand_num = 0
        else:
            expand_num = 1
        for i in range(expand_num):
            curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
        curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
        (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)

        # Seed the 2x-finer grid at the positions of the active coarse voxels,
        # then dilate so their fine-grid neighborhoods are also evaluated.
        next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
        for i in range(2 - expand_num):
            next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
        nidx = torch.where(next_index > 0)

        # Convert active voxel indices to world-space query coordinates.
        next_points = torch.stack(nidx, dim=1)
        next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
                       torch.tensor(bbox_min, dtype=torch.float32, device=device))

        # Bucket the sparse queries into query_grid_num^3 spatial cells and
        # sort by cell so each decode call sees spatially-coherent points.
        query_grid_num = 6
        min_val = next_points.min(axis=0).values
        max_val = next_points.max(axis=0).values
        vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
        index = torch.floor(vol_queries_index).long()
        index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
        index = index.sort()
        next_points = next_points[index.indices].unsqueeze(0).contiguous()
        unique_values = torch.unique(index.values, return_counts=True)
        grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
        input_grid = [[], []]
        logits_grid_list = []
        start_num = 0
        sum_num = 0
        # Greedily pack whole cells into decode batches of <= num_chunks
        # points (a single oversized cell still goes through on its own).
        for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
            if sum_num + count < num_chunks or sum_num == 0:
                sum_num += count
                input_grid[0].append(grid_index)
                input_grid[1].append(count)
            else:
                # geo_decoder.set_topk(input_grid)
                geo_decoder.set_topk(False)
                logits_grid = vae.decode(latents, next_points[:, start_num:start_num + sum_num].to(dtype=latents.dtype)).sample
                start_num = start_num + sum_num
                logits_grid_list.append(logits_grid)
                input_grid = [[grid_index], [count]]
                sum_num = count
        # Flush the final partially-filled batch.
        if sum_num > 0:
            # geo_decoder.set_topk(input_grid)
            geo_decoder.set_topk(False)
            logits_grid = vae.decode(latents, next_points[:, start_num:start_num + sum_num].to(dtype=latents.dtype)).sample
            logits_grid_list.append(logits_grid)
        logits_grid = torch.cat(logits_grid_list, dim=1)
        # Scatter decoded logits back to their pre-sort order, then into the
        # dense fine grid; unevaluated cells keep the -10000 sentinel.
        grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
        next_logits[nidx] = grid_logits
        grid_logits = next_logits.unsqueeze(0)

    # 4. mesh extraction: unevaluated cells become NaN so DiffDMC ignores them.
    grid_logits[grid_logits == -10000.] = float('nan')
    torch.cuda.empty_cache()
    mesh_v_f = []
    grid_logits = grid_logits[0]
    try:
        print("final grids shape = ", grid_logits.shape)
        dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
        # DiffDMC expects an SDF-like field: negate and normalize the logits.
        sdf = -grid_logits / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
        vertices = vertices.detach().cpu().numpy()
        # Reverse the face winding to flip the surface orientation.
        faces = faces.detach().cpu().numpy()[:, ::-1]
        # Map grid coordinates back into the world-space bounding box.
        vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
        mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
    except Exception as e:
        # Best-effort: on any extraction failure (including DiffDMC missing
        # or OOM) free GPU memory and report "no mesh" instead of raising.
        print(e)
        torch.cuda.empty_cache()
        mesh_v_f = (None, None)

    return [mesh_v_f]