Spaces:
Sleeping
Sleeping
| """ | |
head_replace.py - Replace the TripoSG head with a DECA-reconstructed head at mesh level.

Requires: trimesh, numpy, scipy, cv2, torch (+ face-alignment via DECA deps)
Optional: pymeshlab (for mesh clean-up)

Usage (standalone):
    python head_replace.py --body /tmp/triposg_textured.glb \
                           --face /path/to/face.jpg \
                           --out /tmp/head_replaced.glb

Writes a combined GLB with DECA head geometry + TripoSG body.
| """ | |
| import os, sys, argparse, warnings | |
| warnings.filterwarnings('ignore') | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
# ------------------------------------------------------------------
# Patch DECA before importing it to avoid the pytorch3d dependency
# ------------------------------------------------------------------
# Root of the DECA checkout; it is put first on sys.path so `decalib` can be
# imported, and its data/ folder is read by the helpers below. Can be
# overridden with the DECA_ROOT environment variable (default unchanged).
DECA_ROOT = os.environ.get('DECA_ROOT', '/root/DECA')
sys.path.insert(0, DECA_ROOT)

# Stub out the rasterizer module so DECA never tries to import pytorch3d.
# The stub is registered in sys.modules further below, before decalib loads.
import importlib, types
_fake_renderer = types.ModuleType('decalib.utils.renderer')
# DECA calls set_rasterizer(<type>) during setup; accept and ignore it.
_fake_renderer.set_rasterizer = lambda t='pytorch3d': None
| class _FakeRender: | |
| """No-op renderer β we only need the mesh, not rendered images.""" | |
| def __init__(self, *a, **kw): pass | |
| def to(self, *a, **kw): return self | |
| def __call__(self, *a, **kw): return {'images': None, 'alpha_images': None, | |
| 'normal_images': None, 'grid': None, | |
| 'transformed_normals': None, 'normals': None} | |
| def render_shape(self, *a, **kw): return None, None, None, None | |
| def world2uv(self, *a, **kw): return None | |
| def add_SHlight(self, *a, **kw): return None | |
_fake_renderer.SRenderY = _FakeRender
sys.modules['decalib.utils.renderer'] = _fake_renderer

# Patch DECA._setup_renderer: if the real setup raises, install the no-op
# renderer and still (best effort) load the UV assets it would have loaded.
from decalib import deca as _deca_mod
_orig_setup = _deca_mod.DECA._setup_renderer


def _patched_setup(self, model_cfg):
    """Run DECA's renderer setup, falling back to a no-op renderer on failure.

    On failure, assigns self.render = _FakeRender() and then tries to load
    uv_face_eye_mask / uv_face_mask, fixed_uv_dis, mean_texture and
    dense_template from the paths in model_cfg. Each asset is independent:
    one that fails to load is reported and skipped (run_deca() only uses
    encode() and flame(), which do not read these attributes).
    """
    try:
        _orig_setup(self, model_cfg)
        return
    except Exception as e:
        print(f'[head_replace] Renderer disabled ({e})')
    self.render = _FakeRender()

    from skimage.io import imread
    import torch, torch.nn.functional as F

    # Upstream DECA places these tensors on self.device; do the same so a
    # later decode() does not mix CPU and GPU tensors.
    device = getattr(self, 'device', 'cpu')
    uv_hw = [model_cfg.uv_size, model_cfg.uv_size]

    def _load_masks():
        mask = imread(model_cfg.face_eye_mask_path).astype(np.float32) / 255.
        mask = torch.from_numpy(mask[:, :, 0])[None, None, :, :].contiguous()
        self.uv_face_eye_mask = F.interpolate(mask, uv_hw).to(device)
        mask2 = imread(model_cfg.face_mask_path).astype(np.float32) / 255.
        mask2 = torch.from_numpy(mask2[:, :, 0])[None, None, :, :].contiguous()
        self.uv_face_mask = F.interpolate(mask2, uv_hw).to(device)

    def _load_fixed_displacement():
        fixed_dis = np.load(model_cfg.fixed_displacement_path)
        self.fixed_uv_dis = torch.tensor(fixed_dis).float().to(device)

    def _load_mean_texture():
        mean_tex_np = imread(model_cfg.mean_tex_path).astype(np.float32) / 255.
        mean_tex = torch.from_numpy(mean_tex_np.transpose(2, 0, 1))[None]
        self.mean_texture = F.interpolate(mean_tex, uv_hw).to(device)

    def _load_dense_template():
        # Local DECA asset shipped with the checkout (pickled dict).
        self.dense_template = np.load(model_cfg.dense_template_path,
                                      allow_pickle=True, encoding='latin1').item()

    for label, loader in (('UV face masks', _load_masks),
                          ('fixed displacement', _load_fixed_displacement),
                          ('mean texture', _load_mean_texture),
                          ('dense template', _load_dense_template)):
        try:
            loader()
        except Exception as exc:  # best effort: report instead of hiding it
            print(f'[head_replace] Could not load {label} ({exc}); skipping')


_deca_mod.DECA._setup_renderer = _patched_setup
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # FLAME mesh: parse head_template.obj for UV map | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _load_flame_template(obj_path=os.path.join(DECA_ROOT, 'data', 'head_template.obj')): | |
| """Return (verts, faces, uv_verts, uv_faces) from head_template.obj.""" | |
| verts, uv_verts = [], [] | |
| faces_v, faces_uv = [], [] | |
| for line in open(obj_path): | |
| t = line.split() | |
| if not t: | |
| continue | |
| if t[0] == 'v': | |
| verts.append([float(t[1]), float(t[2]), float(t[3])]) | |
| elif t[0] == 'vt': | |
| uv_verts.append([float(t[1]), float(t[2])]) | |
| elif t[0] == 'f': | |
| vi, uvi = [], [] | |
| for tok in t[1:]: | |
| parts = tok.split('/') | |
| vi.append(int(parts[0]) - 1) | |
| uvi.append(int(parts[1]) - 1 if len(parts) > 1 and parts[1] else 0) | |
| if len(vi) == 3: | |
| faces_v.append(vi) | |
| faces_uv.append(uvi) | |
| return (np.array(verts, dtype=np.float32), | |
| np.array(faces_v, dtype=np.int32), | |
| np.array(uv_verts, dtype=np.float32), | |
| np.array(faces_uv, dtype=np.int32)) | |
# ------------------------------------------------------------------
# UV texture baking (software rasteriser, no pytorch3d needed)
# ------------------------------------------------------------------
def _bake_uv_texture(verts3d, faces_v, uv_verts, faces_uv, cam, face_img_bgr, tex_size=256):
    """
    Project face_img_bgr onto the FLAME UV map using DECA's orthographic camera.

    Every corner of a roughly front-facing triangle (mean vertex z >= -0.02)
    is projected into the image and its colour is sampled (nearest pixel).
    The samples are then linearly interpolated over the UV texel grid.
    Texels outside the interpolation keep a base texture: DECA's
    data/mean_texture.jpg, or flat 180 grey if it is missing or unreadable.

    verts3d      : (N, 3) FLAME vertices
    faces_v      : (M, 3) vertex indices per triangle
    uv_verts     : (K, 2) UVs, assumed OBJ convention (v = 0 at image bottom)
    faces_uv     : (M, 3) UV indices per triangle, row-aligned with faces_v
    cam          : (3,) = [scale, tx, ty] orthographic camera
    face_img_bgr : (H, W, 3) image the camera was estimated on
    Returns      : (tex_size, tex_size, 3) uint8 texture (BGR); row 0 is v = 1
    """
    from scipy.interpolate import LinearNDInterpolator

    H, W = face_img_bgr.shape[:2]
    scale, tx, ty = float(cam[0]), float(cam[1]), float(cam[2])

    # DECA orthographic projection: (xy + [tx, ty]) * scale, then flip y.
    proj = np.empty((len(verts3d), 2), dtype=np.float32)
    proj[:, 0] = (verts3d[:, 0] + tx) * scale
    proj[:, 1] = -((verts3d[:, 1] + ty) * scale)
    # Normalised coords in [-1, 1] -> pixel coords in [0, W] x [0, H].
    img_pts = (proj + 1.0) * 0.5 * np.array([W, H], dtype=np.float32)

    # UV -> texel coords. Image rows grow downwards while OBJ v grows upwards,
    # so v is flipped. This matches how trimesh samples textures
    # (row = (1 - v) * H) and the layout of the mean texture used as base.
    uv_px = np.column_stack([uv_verts[:, 0], 1.0 - uv_verts[:, 1]]) * tex_size

    # Base texture (skin-tone fallback for unsampled texels).
    mean_tex_path = os.path.join(DECA_ROOT, 'data', 'mean_texture.jpg')
    mt_img = cv2.imread(mean_tex_path) if os.path.exists(mean_tex_path) else None
    if mt_img is not None:
        result = cv2.resize(mt_img, (tex_size, tex_size)).astype(np.float32)
    else:
        result = np.full((tex_size, tex_size, 3), 180.0, dtype=np.float32)

    # Keep front and side faces only (cheap visibility test).
    z_face = verts3d[faces_v, 2].mean(axis=1)
    front_mask = z_face >= -0.02
    # One (uv texel, image pixel) sample per triangle corner.
    src_uv = uv_px[faces_uv[front_mask]].reshape(-1, 2)
    src_img = img_pts[faces_v[front_mask]].reshape(-1, 2)

    valid = ((src_img[:, 0] >= 0) & (src_img[:, 0] < W) &
             (src_img[:, 1] >= 0) & (src_img[:, 1] < H))
    src_uv = src_uv[valid]
    src_img = src_img[valid]
    if len(src_uv) < 3:
        print('[bake] Too few visible samples; using base texture only')
        return result.astype(np.uint8)

    ix = np.clip(src_img[:, 0].astype(int), 0, W - 1)
    iy = np.clip(src_img[:, 1].astype(int), 0, H - 1)
    colours = face_img_bgr[iy, ix].astype(np.float32)  # (P, 3)
    uv_dest = np.clip(src_uv, 0, tex_size - 1 - 1e-6).astype(np.float32)

    grid_u, grid_v = np.meshgrid(np.arange(tex_size), np.arange(tex_size))
    grid_pts = np.column_stack([grid_u.ravel(), grid_v.ravel()])
    try:
        # One triangulation shared by all three colour channels.
        interp = LinearNDInterpolator(uv_dest, colours, fill_value=np.nan)
        tex_baked = interp(grid_pts).reshape(tex_size, tex_size, 3)
    except (ValueError, RuntimeError) as e:  # e.g. Qhull fails on degenerate samples
        print(f'[bake] UV interpolation failed ({e}); using base texture only')
        return result.astype(np.uint8)

    baked_mask = ~np.isnan(tex_baked[:, :, 0])
    result[baked_mask] = np.clip(tex_baked[baked_mask], 0, 255)
    return result.astype(np.uint8)
# ------------------------------------------------------------------
# DECA inference
# ------------------------------------------------------------------
def run_deca(face_img_path, device='cuda'):
    """
    Reconstruct a FLAME head from a single face photo with DECA and bake its texture.

    Args:
        face_img_path: path to the face image (DECA's TestData crops/aligns it).
        device: torch device string for the DECA model.

    Returns:
        (verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_img_bgr)
          verts_np    FLAME vertices for the predicted shape/expression/pose
          cam_np      (3,) orthographic camera [scale, tx, ty]
          faces_v, uv_verts, faces_uv   topology/UVs from head_template.obj
          tex_img_bgr (256, 256, 3) uint8 baked UV texture (BGR)
    """
    import torch
    from decalib.deca import DECA
    from decalib.utils import config as cfg_module
    from decalib.datasets import datasets

    cfg = cfg_module.get_cfg_defaults()
    # Texture comes from _bake_uv_texture, not from DECA's texture model.
    cfg.model.use_tex = False
    print('[DECA] Loading model...')
    deca = DECA(config=cfg, device=device)
    deca.eval()

    print('[DECA] Preprocessing image...')
    testdata = datasets.TestData(face_img_path)
    img_tensor = testdata[0]['image'].to(device)[None, ...]

    print('[DECA] Encoding...')
    with torch.no_grad():
        codedict = deca.encode(img_tensor, use_detail=False)
        verts, _, _ = deca.flame(
            shape_params=codedict['shape'],
            expression_params=codedict['exp'],
            pose_params=codedict['pose'],
        )
    verts_np = verts[0].cpu().numpy()           # (5023, 3)
    cam_np = codedict['cam'][0].cpu().numpy()   # (3,)
    print(f'[DECA] Mesh: {verts_np.shape}, cam={cam_np}')

    # FLAME topology + UV map.
    _, faces_v, uv_verts, faces_uv = _load_flame_template()

    # Bake from the same cropped/aligned image DECA saw (the camera refers to it).
    img_rgb = img_tensor[0].cpu().numpy().transpose(1, 2, 0)
    # Clamp before the uint8 cast: out-of-range floats would wrap around.
    img_np = (np.clip(img_rgb, 0.0, 1.0) * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    print('[DECA] Baking UV texture...')
    tex_bgr = _bake_uv_texture(verts_np, faces_v, uv_verts, faces_uv, cam_np, img_bgr, tex_size=256)
    return verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Mesh helpers | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _find_neck_height(mesh): | |
| """ | |
| Find the best neck cut height in a body mesh. | |
| Strategy: in the top 40% of the mesh, find the local minimum of | |
| cross-sectional area (the neck is narrower than the head). | |
| Returns the y-value of the cut plane. | |
| """ | |
| verts = mesh.vertices | |
| y_min, y_max = verts[:, 1].min(), verts[:, 1].max() | |
| y_range = y_max - y_min | |
| # Scan [80%, 87%] to find the neck-base narrowing below the face. | |
| # The range [83%, 91%] was picking the crown taper instead of the neck. | |
| y_start = y_min + y_range * 0.80 | |
| y_end = y_min + y_range * 0.87 | |
| steps = 20 | |
| ys = np.linspace(y_start, y_end, steps) | |
| band = y_range * 0.015 | |
| r10_vals = [] | |
| for y in ys: | |
| pts = verts[(verts[:, 1] >= y - band) & (verts[:, 1] <= y + band)] | |
| if len(pts) < 6: | |
| r10_vals.append(1.0); continue | |
| xz = pts[:, [0, 2]] | |
| cx, cz = xz.mean(0) | |
| radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2) | |
| r10_vals.append(float(np.percentile(radii, 10))) | |
| from scipy.ndimage import uniform_filter1d | |
| r10 = uniform_filter1d(np.array(r10_vals), size=3) | |
| neck_idx = int(np.argmin(r10[2:-2])) + 2 | |
| neck_y = float(ys[neck_idx]) | |
| frac = (neck_y - y_min) / y_range | |
| print(f'[neck] Cut height: {neck_y:.4f} (y_range {y_min:.3f}β{y_max:.3f}, frac={frac:.2f})') | |
| return neck_y | |
def _weld_mesh(mesh, tol=1e-5):
    """
    Merge coincident vertices (e.g. UV-split duplicates) into a geometric mesh.

    Vertices closer than `tol` are linked, and each connected group of links
    becomes one vertex, placed at the group's lowest-index member. Grouping is
    transitive, so chains of near points merge too.

    Args:
        mesh: trimesh.Trimesh (any object with `vertices` and `faces`).
        tol: merge distance in mesh units.
    Returns:
        New trimesh.Trimesh (process=False) without visuals.
    """
    import trimesh
    from scipy.spatial import cKDTree
    from scipy.sparse import coo_matrix
    from scipy.sparse.csgraph import connected_components

    verts = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.faces)
    n = len(verts)
    if n == 0:
        return trimesh.Trimesh(vertices=verts, faces=faces, process=False)

    pairs = cKDTree(verts).query_pairs(r=tol, output_type='ndarray')
    graph = coo_matrix((np.ones(len(pairs), dtype=np.int8), (pairs[:, 0], pairs[:, 1])),
                       shape=(n, n))
    # Vectorised union-find: label each vertex with its merged group.
    _, labels = connected_components(graph, directed=False)
    _, first_member = np.unique(labels, return_index=True)
    new_verts = verts[first_member]
    new_faces = labels[faces]
    return trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
def _cut_mesh_below(mesh, y_cut):
    """
    Keep only faces whose three vertices are all at or below y_cut.

    Unused vertices are dropped and faces re-indexed. Appearance is carried
    over: per-vertex UVs and the material of a textured mesh, or per-vertex /
    per-face colours of a colour-visual mesh.
    """
    import trimesh
    from trimesh.visual.texture import TextureVisuals

    faces = np.asarray(mesh.faces)
    v_below = mesh.vertices[:, 1] <= y_cut
    f_keep = np.all(v_below[faces], axis=1)
    faces_kept = faces[f_keep]
    used_verts = np.unique(faces_kept)
    old_to_new = np.full(len(mesh.vertices), -1, dtype=np.int64)
    old_to_new[used_verts] = np.arange(len(used_verts))
    new_mesh = trimesh.Trimesh(vertices=mesh.vertices[used_verts],
                               faces=old_to_new[faces_kept], process=False)

    visual = mesh.visual
    uv = getattr(visual, 'uv', None)
    if uv is not None:
        new_mesh.visual = TextureVisuals(uv=uv[used_verts], material=visual.material)
    else:
        kind = getattr(visual, 'kind', None)
        if kind == 'vertex':
            new_mesh.visual.vertex_colors = visual.vertex_colors[used_verts]
        elif kind == 'face':
            new_mesh.visual.face_colors = visual.face_colors[f_keep]
    return new_mesh
def _extract_neck_ring_geometric(mesh, neck_y, n_pts=64, band_frac=0.02):
    """
    Build an ordered ring of n_pts points around the mesh at height neck_y.

    band = band_frac * (vertical extent of the mesh).

    Strategy 1 (topology): take boundary edges (edges owned by one face) whose
    two endpoints are within 2 * band of neck_y, and walk them into chains.
    If the longest chain has >= 8 vertices, it is used.

    Strategy 2 (fallback): take vertices within +/- band of neck_y, keep those
    at or below the 35th-percentile XZ distance from their centroid (when that
    leaves >= 8), and order them by polar angle around the new centroid.

    Either ring is flattened onto y = neck_y and passed to _resample_loop.
    Raises ValueError when strategy 2 finds fewer than 8 vertices in the band.
    """
    vertices = mesh.vertices
    heights = vertices[:, 1]
    band = (heights.max() - heights.min()) * band_frac

    # Strategy 1: boundary-edge chains near the cut height.
    undirected = np.sort(mesh.edges, axis=1)
    edge_rows, edge_counts = np.unique(undirected, axis=0, return_counts=True)
    boundary = edge_rows[edge_counts == 1]
    near = np.abs(heights - neck_y) <= band * 2
    neck_edges = boundary[near[boundary[:, 0]] & near[boundary[:, 1]]]

    if len(neck_edges) >= 8:
        neighbours = {}
        for a, b in neck_edges.tolist():
            neighbours.setdefault(a, []).append(b)
            neighbours.setdefault(b, []).append(a)

        seen = set()
        chains = []
        for origin in neighbours:
            if origin in seen:
                continue
            chain = [origin]
            seen.add(origin)
            previous, current = -1, origin
            for _ in range(len(neck_edges) + 1):
                options = [v for v in neighbours.get(current, []) if v != previous]
                if not options:
                    break
                step = options[0]
                if step == origin or step in seen:
                    break
                seen.add(step)
                chain.append(step)
                previous, current = current, step
            chains.append(chain)

        if chains:
            longest = max(chains, key=len)
            if len(longest) >= 8:
                ring = vertices[longest].copy()
                ring[:, 1] = neck_y  # flatten onto the cut plane
                return _resample_loop(ring, n_pts)

    # Strategy 2: inner (neck-column) vertices in the band, sorted by angle.
    in_band = (heights >= neck_y - band) & (heights <= neck_y + band)
    pts = vertices[in_band]
    if len(pts) < 8:
        raise ValueError(f'Too few vertices near neck_y={neck_y:.4f}: {len(pts)}')

    xz = pts[:, [0, 2]]
    cx, cz = xz.mean(0)
    dist = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
    # Dropping the outer 65% discards face/head surface, keeping the neck column.
    inner = dist <= np.percentile(dist, 35)
    if inner.sum() >= 8:
        pts = pts[inner]

    cx, cz = pts[:, [0, 2]].mean(0)
    order = np.argsort(np.arctan2(pts[:, 2] - cz, pts[:, 0] - cx))
    ring = pts[order].copy()
    ring[:, 1] = neck_y
    return _resample_loop(ring, n_pts)
| def _extract_boundary_loop(mesh): | |
| """ | |
| Extract the boundary edge loop (ordered) from a welded mesh. | |
| Returns (N, 3) ordered vertex positions. | |
| """ | |
| # Find boundary edges (edges used by exactly one face) | |
| edges = np.sort(mesh.edges, axis=1) | |
| unique, counts = np.unique(edges, axis=0, return_counts=True) | |
| boundary_edges = unique[counts == 1] | |
| if len(boundary_edges) == 0: | |
| raise ValueError('No boundary edges found β mesh may be closed') | |
| # Build adjacency for boundary edges | |
| adj = {} | |
| for e in boundary_edges: | |
| adj.setdefault(int(e[0]), []).append(int(e[1])) | |
| adj.setdefault(int(e[1]), []).append(int(e[0])) | |
| # Walk the longest connected loop | |
| # Find all loops | |
| visited = set() | |
| loops = [] | |
| for start_v in adj: | |
| if start_v in visited: | |
| continue | |
| loop = [start_v] | |
| visited.add(start_v) | |
| prev = -1 | |
| cur = start_v | |
| for _ in range(len(boundary_edges) + 1): | |
| nbrs = [v for v in adj.get(cur, []) if v != prev] | |
| if not nbrs: | |
| break | |
| nxt = nbrs[0] | |
| if nxt == start_v: | |
| break | |
| if nxt in visited: | |
| break | |
| visited.add(nxt) | |
| prev = cur | |
| cur = nxt | |
| loop.append(cur) | |
| loops.append(loop) | |
| # Use the longest loop | |
| best = max(loops, key=len) | |
| return mesh.vertices[best] | |
| def _resample_loop(loop_pts, N): | |
| """Resample an ordered set of 3D points to exactly N evenly-spaced points.""" | |
| from scipy.interpolate import interp1d | |
| # Arc-length parameterisation | |
| diffs = np.diff(loop_pts, axis=0, prepend=loop_pts[-1:]) | |
| seg_lens = np.linalg.norm(diffs, axis=1) | |
| t = np.cumsum(seg_lens) | |
| t = np.insert(t, 0, 0) | |
| t /= t[-1] | |
| # Close the loop | |
| t[-1] = 1.0 | |
| loop_closed = np.vstack([loop_pts, loop_pts[0]]) | |
| interp = interp1d(t, loop_closed, axis=0) | |
| t_new = np.linspace(0, 1, N, endpoint=False) | |
| return interp(t_new) | |
| def _bridge_loops(loop_a, loop_b): | |
| """ | |
| Create a triangle strip bridging two ordered loops of equal length N. | |
| loop_a, loop_b: (N, 3) vertex positions | |
| Returns (verts, faces) β just the bridge strip as a trimesh-ready array. | |
| """ | |
| N = len(loop_a) | |
| verts = np.vstack([loop_a, loop_b]) # (2N, 3) β a:0..N-1, b:N..2N-1 | |
| faces = [] | |
| for i in range(N): | |
| j = (i + 1) % N | |
| ai, aj = i, j | |
| bi, bj = i + N, j + N | |
| faces.append([ai, aj, bi]) | |
| faces.append([aj, bj, bi]) | |
| return verts, np.array(faces, dtype=np.int32) | |
# ------------------------------------------------------------------
# DECA head -> trimesh
# ------------------------------------------------------------------
def deca_to_trimesh(verts_np, faces_v, uv_verts, faces_uv, tex_bgr):
    """
    Assemble a textured trimesh.Trimesh from DECA/FLAME outputs.

    trimesh stores one UV per vertex, but FLAME has UV seams (a vertex may
    have different UVs on different faces). Averaging those UVs produced
    triangles that stretch across the texture. Instead, one output vertex is
    created per distinct (vertex index, uv index) pair used by the faces.
    Vertices not referenced by any face are dropped.

    Args:
        verts_np: (V, 3) vertex positions.
        faces_v, faces_uv: (F, 3) vertex / UV indices per triangle.
        uv_verts: (T, 2) UV coordinates.
        tex_bgr: (H, W, 3) uint8 BGR texture image.
    Returns:
        trimesh.Trimesh with a PBR base-colour texture, or flat vertex colours
        if attaching the texture fails.
    """
    import trimesh
    from trimesh.visual.texture import TextureVisuals
    from trimesh.visual.material import PBRMaterial

    faces_v = np.asarray(faces_v, dtype=np.int64)
    faces_uv = np.asarray(faces_uv, dtype=np.int64)
    corners = np.column_stack([faces_v.reshape(-1), faces_uv.reshape(-1)])
    pairs, inverse = np.unique(corners, axis=0, return_inverse=True)
    inverse = np.asarray(inverse).reshape(-1)  # shape differs across numpy versions
    split_verts = np.asarray(verts_np)[pairs[:, 0]]
    split_uv = np.asarray(uv_verts, dtype=np.float32)[pairs[:, 1]]
    split_faces = inverse.reshape(-1, 3)

    mesh = trimesh.Trimesh(vertices=split_verts, faces=split_faces, process=False)

    tex_rgb = cv2.cvtColor(tex_bgr, cv2.COLOR_BGR2RGB)
    tex_pil = Image.fromarray(tex_rgb)
    try:
        mat = PBRMaterial(baseColorTexture=tex_pil)
        mesh.visual = TextureVisuals(uv=split_uv, material=mat)
        print(f'[deca_to_trimesh] UV attached: {split_uv.shape}, tex={tex_rgb.shape}')
    except Exception as e:
        print(f'[deca_to_trimesh] UV attach failed ({e}) - using vertex colours')
        mesh.visual.vertex_colors = np.tile([200, 175, 155, 255], (len(split_verts), 1))
    return mesh
# ------------------------------------------------------------------
# Main head-replacement function
# ------------------------------------------------------------------
def _load_body_mesh(body_glb):
    """Load a GLB; concatenate a scene's triangle meshes into one Trimesh."""
    import trimesh
    loaded = trimesh.load(body_glb)
    if isinstance(loaded, trimesh.Scene):
        parts = [g for g in loaded.geometry.values() if isinstance(g, trimesh.Trimesh)]
        if not parts:
            raise ValueError(f'No triangle meshes found in {body_glb}')
        return trimesh.util.concatenate(parts)
    return loaded


def _flame_neck_ring(flame_mesh, n_pts, fallback_frac=0.05):
    """Return the FLAME neck opening as ordered points.

    Prefers the mesh's longest boundary loop. If that fails, uses a geometric
    ring at `fallback_frac` of the head height above its lowest vertex.
    """
    try:
        loop = _extract_boundary_loop(flame_mesh)
        print(f' FLAME neck ring (topology): {len(loop)} verts')
    except Exception as e:
        print(f' FLAME boundary loop failed ({e}), using geometric extraction')
        y = flame_mesh.vertices[:, 1]
        ring_y = y.min() + (y.max() - y.min()) * fallback_frac
        loop = _extract_neck_ring_geometric(flame_mesh, ring_y, n_pts=n_pts)
        print(f' FLAME neck ring (geometric): {len(loop)} pts')
    return np.asarray(loop, dtype=np.float64)


def _loop_orientation(loop):
    """Sum of y-components of consecutive cross products about the centroid.

    The sign gives the winding of the loop seen along the y axis. The closing
    edge (last -> first point) is included.
    """
    t = loop - loop.mean(0)
    return float(np.cross(t, np.roll(t, -1, axis=0))[:, 1].sum())


def _match_loop(reference, loop):
    """Reorder `loop` to share `reference`'s winding and start near reference[0]."""
    if (_loop_orientation(reference) > 0) != (_loop_orientation(loop) > 0):
        loop = loop[::-1]
    offset = int(np.argmin(np.linalg.norm(loop - reference[0], axis=1)))
    return np.roll(loop, -offset, axis=0)


def replace_head(body_glb: str, face_img_path: str, out_glb: str,
                 device: str = 'cuda', bridge_n: int = 64):
    """
    Replace the head of a TripoSG body GLB with a DECA/FLAME head built from a photo.

    Pipeline:
      1. Load the body and weld duplicate vertices.
      2. Find the neck height and keep the (textured) body below it.
      3. Reconstruct and texture a FLAME head with DECA.
      4. Scale the head to the body's neck-to-crown height. Translate it so
         its lowest vertex sits at the cut height and its neck ring is
         centred on the body's neck ring in X/Z.
      5. Bridge the two rings with a triangle strip.
    The output is a 3-part GLB scene (body / bridge / head). If scene export
    fails, a single mesh without visuals is exported instead.

    Args:
        body_glb: path to TripoSG textured GLB
        face_img_path: path to reference face image
        out_glb: output path for combined GLB (parent directory is created)
        device: torch device string for DECA
        bridge_n: number of vertices in the stitching ring
    Returns:
        out_glb
    """
    import trimesh

    # -- 1. Load and weld the body --------------------------------------
    print('[replace_head] Loading body GLB...')
    body_mesh = _load_body_mesh(body_glb)
    print(f' Body: {len(body_mesh.vertices)} verts, {len(body_mesh.faces)} faces')
    print('[replace_head] Welding mesh vertices...')
    body_welded = _weld_mesh(body_mesh)
    print(f' Welded: {len(body_welded.vertices)} verts (was {len(body_mesh.vertices)})')

    # -- 2. Neck height, body cut, body neck ring -----------------------
    neck_y = _find_neck_height(body_welded)
    print('[replace_head] Cutting body at neck...')
    body_lower = _cut_mesh_below(body_mesh, neck_y)  # keeps original UV/texture
    print(f' Body lower: {len(body_lower.vertices)} verts')
    body_neck_loop = _extract_neck_ring_geometric(body_welded, neck_y, n_pts=bridge_n)
    print(f' Body neck ring: {len(body_neck_loop)} pts (geometric)')

    # -- 3. DECA head ---------------------------------------------------
    print('[replace_head] Running DECA...')
    verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr = run_deca(face_img_path, device=device)
    flame_raw = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)
    flame_neck_loop = _flame_neck_ring(flame_raw, bridge_n)

    # -- 4. Align the FLAME head to the body ----------------------------
    body_neck_center = body_neck_loop.mean(axis=0)
    # Crown from the welded mesh (more reliable than the UV-split mesh).
    welded_y_max = float(body_welded.vertices[:, 1].max())
    body_head_height = welded_y_max - neck_y
    flame_neck_center = flame_neck_loop.mean(axis=0)
    flame_y_min = verts_np[:, 1].min()
    flame_y_max = verts_np[:, 1].max()
    flame_head_height = flame_y_max - flame_y_min
    print(f' Body neck center: {body_neck_center.round(4)}')
    print(f' Body head space: {body_head_height:.4f} (neck_y={neck_y:.4f}, crown_y={welded_y_max:.4f})')
    print(f' FLAME head height (unscaled): {flame_head_height:.4f}')
    print(f' FLAME neck center (unscaled): {flame_neck_center.round(4)}')

    head_scale = body_head_height / flame_head_height if flame_head_height > 1e-5 else 1.0
    print(f' Head scale: {head_scale:.4f}')
    # X/Z: neck-ring centres coincide. Y: FLAME bottom lands on neck_y, so
    # the head fills the space from the cut to the body's crown.
    translate = np.array([
        body_neck_center[0] - flame_neck_center[0] * head_scale,
        neck_y - flame_y_min * head_scale,
        body_neck_center[2] - flame_neck_center[2] * head_scale,
    ])
    print(f' Translate: {translate.round(4)}')
    verts_aligned = verts_np * head_scale + translate
    print(f' FLAME aligned y={verts_aligned[:,1].min():.4f}..{verts_aligned[:,1].max():.4f}'
          f' x={verts_aligned[:,0].min():.4f}..{verts_aligned[:,0].max():.4f}'
          f' z={verts_aligned[:,2].min():.4f}..{verts_aligned[:,2].max():.4f}')
    # The alignment is scale + translate, so the ring transforms directly.
    # There is no need to re-extract it from the aligned mesh.
    flame_neck_loop_aligned = flame_neck_loop * head_scale + translate

    flame_neck_r = np.linalg.norm(flame_neck_loop_aligned - flame_neck_loop_aligned.mean(0), axis=1).mean()
    body_neck_r = np.linalg.norm(body_neck_loop - body_neck_loop.mean(0), axis=1).mean()
    print(f' Body neck radius: {body_neck_r:.4f} FLAME neck radius (scaled): {flame_neck_r:.4f}')

    # -- 5. Bridge the rings --------------------------------------------
    body_loop_r = _resample_loop(body_neck_loop, bridge_n)
    flame_loop_r = _match_loop(body_loop_r, _resample_loop(flame_neck_loop_aligned, bridge_n))
    bridge_verts, bridge_faces = _bridge_loops(body_loop_r, flame_loop_r)
    bridge_mesh = trimesh.Trimesh(vertices=bridge_verts, faces=bridge_faces, process=False)

    head_mesh = deca_to_trimesh(verts_aligned, faces_v, uv_verts, faces_uv, tex_bgr)

    # -- 6. Export ------------------------------------------------------
    parts = [('body', body_lower), ('bridge', bridge_mesh), ('head', head_mesh)]
    n_total = sum(len(m.vertices) for _, m in parts)
    print(f'[replace_head] Exporting combined mesh: {n_total} verts...')
    os.makedirs(os.path.dirname(out_glb) or '.', exist_ok=True)
    try:
        # Separate sub-meshes keep each part's own material/texture.
        scene_out = trimesh.Scene()
        for name, part in parts:
            scene_out.add_geometry(part, geom_name=name)
        scene_out.export(out_glb)
        print(f'[replace_head] Saved scene GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    except Exception as e:
        print(f'[replace_head] Scene export failed ({e}), trying single mesh...')
        merged = trimesh.util.concatenate([m for _, m in parts])
        combined = trimesh.Trimesh(vertices=merged.vertices, faces=merged.faces, process=False)
        combined.export(out_glb)
        print(f'[replace_head] Saved GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    return out_glb
# ------------------------------------------------------------------
# CLI
# ------------------------------------------------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--body', required=True, help='TripoSG body GLB path')
    parser.add_argument('--face', required=True, help='Reference face image path')
    parser.add_argument('--out', required=True, help='Output GLB path')
    parser.add_argument('--bridge', type=int, default=64, help='Bridge ring vertex count')
    parser.add_argument('--cpu', action='store_true', help='Use CPU instead of CUDA')
    cli = parser.parse_args()

    # Only touch torch when a GPU might actually be used.
    if cli.cpu:
        device = 'cpu'
    else:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    replace_head(cli.body, cli.face, cli.out, device=device, bridge_n=cli.bridge)