Image2Model / pipeline /head_replace.py
Daankular's picture
Port MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search
8f1bcd9
"""
head_replace.py — Replace TripoSG head with DECA-reconstructed head at mesh level.
Requires: trimesh, numpy, scipy, cv2, torch (+ face-alignment via DECA deps)
Optional: pymeshlab (for mesh clean-up)
Usage (standalone):
python head_replace.py --body /tmp/triposg_textured.glb \
--face /path/to/face.jpg \
--out /tmp/head_replaced.glb
Returns combined GLB with DECA head geometry + TripoSG body.
"""
import os, sys, argparse, warnings
warnings.filterwarnings('ignore')
import numpy as np
import cv2
from PIL import Image
# ──────────────────────────────────────────────────────────────────
# Patch DECA before importing it to avoid pytorch3d dependency
# ──────────────────────────────────────────────────────────────────
# Directory of the DECA checkout. It is inserted at the front of sys.path so
# the `decalib` import further down is searched for there first.
DECA_ROOT = '/root/DECA'
sys.path.insert(0, DECA_ROOT)
# Stub out the rasterizer so DECA doesn't try to import pytorch3d
import importlib, types
# Empty module object that stands in for the real renderer module; it is
# registered in sys.modules below under the same dotted name.
_fake_renderer = types.ModuleType('decalib.utils.renderer')
# set_rasterizer becomes a no-op that accepts one optional argument.
_fake_renderer.set_rasterizer = lambda t='pytorch3d': None
class _FakeRender:
    """Inert renderer stand-in: only the mesh is needed, never rendered images.

    Every method accepts arbitrary arguments and returns a placeholder, so
    code that constructs, moves or calls the renderer keeps running.
    """

    def __init__(self, *args, **kwargs):
        # Nothing to set up.
        pass

    def to(self, *args, **kwargs):
        # Device moves are meaningless here; hand back the same object.
        return self

    def __call__(self, *args, **kwargs):
        # A fresh dict per call, with every render output present but empty.
        output_names = ('images', 'alpha_images', 'normal_images', 'grid',
                        'transformed_normals', 'normals')
        return dict.fromkeys(output_names)

    def render_shape(self, *args, **kwargs):
        return (None,) * 4

    def world2uv(self, *args, **kwargs):
        return None

    def add_SHlight(self, *args, **kwargs):
        return None
# Expose the no-op class as SRenderY on the stub module (presumably the
# renderer class name decalib looks up — verify against decalib/deca.py).
_fake_renderer.SRenderY = _FakeRender
# Register the stub BEFORE decalib is imported, so an import of
# decalib.utils.renderer resolves to the stub instead of the real module.
sys.modules['decalib.utils.renderer'] = _fake_renderer
# Patch deca.py: make _setup_renderer a no-op when renderer not available
from decalib import deca as _deca_mod
# Keep the original method so the patched version can try it first.
_orig_setup = _deca_mod.DECA._setup_renderer
def _patched_setup(self, model_cfg):
    """
    Drop-in replacement for DECA._setup_renderer that survives a missing renderer.

    The original setup is tried first; when it succeeds nothing else happens.
    When it raises, the failure is printed, self.render becomes a _FakeRender
    and the auxiliary data is loaded here instead. Each group of assets is
    best-effort: a load failure is reported (not silently dropped) and the
    corresponding attribute(s) are left unset.

    self      : the object whose renderer is being set up
    model_cfg : config object; uv_size, face_eye_mask_path, face_mask_path,
                fixed_displacement_path, mean_tex_path and
                dense_template_path are read from it
    """
    try:
        _orig_setup(self, model_cfg)
        return
    except Exception as e:
        print(f'[head_replace] Renderer disabled ({e})')

    self.render = _FakeRender()
    # Still load mask / displacement data we need for UV baking
    from skimage.io import imread
    import torch
    import torch.nn.functional as F

    def _mask_from(path):
        # First channel of the image divided by 255, as a (1, 1, H, W)
        # tensor resized to uv_size x uv_size.
        img = imread(path).astype(np.float32) / 255.
        plane = torch.from_numpy(img[:, :, 0])[None, None, :, :].contiguous()
        return F.interpolate(plane, [model_cfg.uv_size, model_cfg.uv_size])

    try:
        self.uv_face_eye_mask = _mask_from(model_cfg.face_eye_mask_path)
        self.uv_face_mask = _mask_from(model_cfg.face_mask_path)
    except Exception as err:
        print(f'[head_replace] UV face masks not loaded ({err})')

    try:
        fixed_dis = np.load(model_cfg.fixed_displacement_path)
        self.fixed_uv_dis = torch.tensor(fixed_dis).float()
    except Exception as err:
        print(f'[head_replace] Fixed displacement map not loaded ({err})')

    try:
        mean_tex_np = imread(model_cfg.mean_tex_path).astype(np.float32) / 255.
        mean_tex = torch.from_numpy(mean_tex_np.transpose(2, 0, 1))[None]
        self.mean_texture = F.interpolate(
            mean_tex, [model_cfg.uv_size, model_cfg.uv_size])
    except Exception as err:
        print(f'[head_replace] Mean texture not loaded ({err})')

    try:
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point dense_template_path at a trusted file.
        self.dense_template = np.load(model_cfg.dense_template_path,
                                      allow_pickle=True, encoding='latin1').item()
    except Exception as err:
        print(f'[head_replace] Dense template not loaded ({err})')
# Install the tolerant setup on the DECA class so instances created afterwards
# use it (assumes DECA calls _setup_renderer while constructing — TODO confirm).
_deca_mod.DECA._setup_renderer = _patched_setup
# ──────────────────────────────────────────────────────────────────
# FLAME mesh: parse head_template.obj for UV map
# ──────────────────────────────────────────────────────────────────
def _load_flame_template(obj_path=None):
    """
    Parse a Wavefront OBJ file and return its triangles together with their UV layout.

    obj_path : path of the OBJ file. None (the default) selects
               <DECA_ROOT>/data/head_template.obj, resolved at call time.

    Returns (verts, faces, uv_verts, uv_faces):
        verts    : (N, 3) float32, the `v` records
        faces    : (M, 3) int32, zero-based vertex index per triangle corner
        uv_verts : (K, 2) float32, first two components of the `vt` records
        uv_faces : (M, 3) int32, zero-based `vt` index per triangle corner;
                   0 where a corner carries no texture index

    Only `f` records with exactly three corners are kept. Arrays keep their
    column count even when the file contains no matching records.
    """
    if obj_path is None:
        obj_path = os.path.join(DECA_ROOT, 'data', 'head_template.obj')

    verts, uv_verts = [], []
    faces_v, faces_uv = [], []
    # Context manager so the handle is closed even if parsing raises.
    with open(obj_path) as obj_file:
        for line in obj_file:
            tokens = line.split()
            if not tokens:
                continue
            tag = tokens[0]
            if tag == 'v':
                verts.append([float(tokens[1]), float(tokens[2]), float(tokens[3])])
            elif tag == 'vt':
                uv_verts.append([float(tokens[1]), float(tokens[2])])
            elif tag == 'f':
                v_idx, uv_idx = [], []
                for corner in tokens[1:]:
                    parts = corner.split('/')
                    v_idx.append(int(parts[0]) - 1)  # OBJ indices are 1-based
                    has_uv = len(parts) > 1 and parts[1]
                    uv_idx.append(int(parts[1]) - 1 if has_uv else 0)
                if len(v_idx) == 3:
                    faces_v.append(v_idx)
                    faces_uv.append(uv_idx)

    # reshape(-1, k) pins the column count for empty inputs too.
    return (np.array(verts, dtype=np.float32).reshape(-1, 3),
            np.array(faces_v, dtype=np.int32).reshape(-1, 3),
            np.array(uv_verts, dtype=np.float32).reshape(-1, 2),
            np.array(faces_uv, dtype=np.int32).reshape(-1, 3))
# ──────────────────────────────────────────────────────────────────
# UV texture baking (software rasteriser, no pytorch3d needed)
# ──────────────────────────────────────────────────────────────────
def _bake_uv_texture(verts3d, faces_v, uv_verts, faces_uv, cam, face_img_bgr, tex_size=256):
    """
    Project face_img_bgr onto the FLAME UV map using an orthographic camera.

    Every corner of every kept triangle yields one sample that pairs a
    UV-space position with the colour of the image pixel the corner projects
    to. The scattered samples are linearly interpolated over the texel grid;
    texels outside the samples' convex hull keep the fallback texture
    (mean_texture.jpg from DECA_ROOT/data, or flat grey 180 if unavailable).

    verts3d      : (N,3) FLAME vertices in world space
    faces_v      : (M,3) vertex index per triangle corner
    uv_verts     : (K,2) UV coordinates
    faces_uv     : (M,3) UV index per triangle corner
    cam          : (3,) = [scale, tx, ty] orthographic camera
    face_img_bgr : (H,W,3) image the camera refers to (BGR)
    tex_size     : side length of the square output texture
    Returns      : (tex_size, tex_size, 3) uint8 texture (BGR)
    """
    from scipy.interpolate import LinearNDInterpolator

    H, W = face_img_bgr.shape[:2]
    scale, tx, ty = float(cam[0]), float(cam[1]), float(cam[2])

    # Orthographic project: DECA formula = (vert_2D + [tx,ty]) * scale, then flip y
    proj = np.zeros((len(verts3d), 2), dtype=np.float32)
    proj[:, 0] = (verts3d[:, 0] + tx) * scale
    proj[:, 1] = -((verts3d[:, 1] + ty) * scale)

    # Map to pixel coords: proj in [-1, 1] -> pixel [0, W] x [0, H]
    img_pts = (proj + 1.0) * 0.5 * np.array([W, H], dtype=np.float32)
    # NOTE(review): texel rows are indexed by v * tex_size with no vertical
    # flip — confirm this matches the orientation of mean_texture.jpg and the
    # UV convention of whatever consumes the texture.
    uv_px = uv_verts * tex_size

    # Keep faces whose mean z is at or above the threshold (front and sides).
    z_face = verts3d[faces_v, 2].mean(axis=1)
    front_mask = z_face >= -0.02

    # One (uv position, image position) pair per face corner, flattened.
    src_uv = uv_px[faces_uv[front_mask]].reshape(-1, 2)
    src_img = img_pts[faces_v[front_mask]].reshape(-1, 2)

    # Drop corners that project outside the image.
    valid = ((src_img[:, 0] >= 0) & (src_img[:, 0] < W) &
             (src_img[:, 1] >= 0) & (src_img[:, 1] < H))
    src_uv = src_uv[valid]
    src_img = src_img[valid]

    # Nearest-pixel colour lookup.
    ix = np.clip(src_img[:, 0].astype(int), 0, W - 1)
    iy = np.clip(src_img[:, 1].astype(int), 0, H - 1)
    colours = face_img_bgr[iy, ix].astype(np.float32)

    uv_dest = np.clip(src_uv, 0, tex_size - 1 - 1e-6).astype(np.float32)

    tex_baked = None
    face_baked_mask = np.zeros((tex_size, tex_size), dtype=bool)
    if len(uv_dest) >= 3:
        grid_u, grid_v = np.meshgrid(np.arange(tex_size), np.arange(tex_size))
        grid_pts = np.column_stack([grid_u.ravel(), grid_v.ravel()])
        # One triangulation serves all three channels (values are (P, 3)).
        interp = LinearNDInterpolator(uv_dest, colours, fill_value=np.nan)
        tex_baked = interp(grid_pts).astype(np.float32).reshape(tex_size, tex_size, 3)
        face_baked_mask = ~np.isnan(tex_baked[:, :, 0])
    else:
        # A triangulation needs at least three points; keep the fallback only.
        print('[bake] Fewer than 3 usable samples; returning fallback texture')

    # Base texture: mean_texture (skin tone fallback for unsampled regions)
    mean_tex_path = os.path.join(DECA_ROOT, 'data', 'mean_texture.jpg')
    mean_img = cv2.imread(mean_tex_path) if os.path.exists(mean_tex_path) else None
    if mean_img is not None:
        result = cv2.resize(mean_img, (tex_size, tex_size)).astype(np.float32)
    else:
        # File missing, or present but unreadable (cv2.imread returned None).
        result = np.full((tex_size, tex_size, 3), 180.0, dtype=np.float32)

    # Blend: baked face over mean texture
    if tex_baked is not None:
        result[face_baked_mask] = np.clip(tex_baked[face_baked_mask], 0, 255)
    return result.astype(np.uint8)
# ──────────────────────────────────────────────────────────────────
# DECA inference
# ──────────────────────────────────────────────────────────────────
def run_deca(face_img_path, device='cuda'):
    """
    Reconstruct a FLAME head from one face image with DECA and bake its texture.

    face_img_path : image handed to decalib's TestData loader
    device        : torch device string used for the model and the input

    Returns (verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_img_bgr):
    the predicted vertices and camera as numpy arrays, the template topology
    and UVs from _load_flame_template(), and the texture produced by
    _bake_uv_texture() from the preprocessed input image.
    """
    import torch
    from decalib.deca import DECA
    from decalib.utils import config as cfg_module
    from decalib.datasets import datasets

    deca_cfg = cfg_module.get_cfg_defaults()
    deca_cfg.model.use_tex = False

    print('[DECA] Loading model...')
    model = DECA(config=deca_cfg, device=device)
    model.eval()

    print('[DECA] Preprocessing image...')
    dataset = datasets.TestData(face_img_path)
    first_sample = dataset[0]
    batch = first_sample['image'].to(device)[None, ...]  # add batch axis

    print('[DECA] Encoding...')
    with torch.no_grad():
        codes = model.encode(batch, use_detail=False)
        verts, _, _ = model.flame(
            shape_params=codes['shape'],
            expression_params=codes['exp'],
            pose_params=codes['pose'],
        )

    verts_np = verts[0].cpu().numpy()
    cam_np = codes['cam'][0].cpu().numpy()
    print(f'[DECA] Mesh: {verts_np.shape}, cam={cam_np}')

    # Topology and UV layout come from the template OBJ; its vertices are unused.
    _, faces_v, uv_verts, faces_uv = _load_flame_template()

    # Turn the preprocessed network input back into a uint8 BGR image
    # (channels-first -> channels-last, scaled by 255).
    chw = batch[0].cpu().numpy()
    img_rgb = (chw.transpose(1, 2, 0) * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    print('[DECA] Baking UV texture...')
    tex_bgr = _bake_uv_texture(verts_np, faces_v, uv_verts, faces_uv,
                               cam_np, img_bgr, tex_size=256)
    return verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr
# ──────────────────────────────────────────────────────────────────
# Mesh helpers
# ──────────────────────────────────────────────────────────────────
def _find_neck_height(mesh, lo_frac=0.80, hi_frac=0.87, steps=20, band_frac=0.015):
    """
    Find the neck cut height (y value) of a y-up body mesh.

    Strategy: sample `steps` heights between lo_frac and hi_frac of the mesh
    height. At each height, take the vertices within +/- band_frac of the
    mesh height and measure the 10th-percentile distance to their XZ
    centroid. The values are smoothed with a 3-wide box filter and the
    smallest one wins; the two samples at either end are never chosen.

    mesh      : object exposing `vertices` as an (N, 3) array
    lo_frac   : lower end of the scanned window, as a fraction of mesh height
    hi_frac   : upper end of the scanned window, as a fraction of mesh height
    steps     : number of sampled heights (at least 5)
    band_frac : half-thickness of each slice, as a fraction of mesh height
    Returns   : y value of the cut plane (float)
    Raises    : ValueError if steps < 5
    """
    from scipy.ndimage import uniform_filter1d

    if steps < 5:
        # r10[2:-2] below would be empty and argmin would fail obscurely.
        raise ValueError(f'steps must be at least 5, got {steps}')

    verts = np.asarray(mesh.vertices)
    heights = verts[:, 1]
    y_min, y_max = heights.min(), heights.max()
    y_range = y_max - y_min

    # The default window [80%, 87%] targets the neck-base narrowing below the
    # face; [83%, 91%] was picking the crown taper instead of the neck.
    ys = np.linspace(y_min + y_range * lo_frac, y_min + y_range * hi_frac, steps)
    band = y_range * band_frac

    r10_vals = []
    for y in ys:
        pts = verts[(heights >= y - band) & (heights <= y + band)]
        if len(pts) < 6:
            # Too sparse to measure: use a sentinel that should lose the argmin.
            # NOTE(review): 1.0 is only "large" for meshes of roughly unit size.
            r10_vals.append(1.0)
            continue
        xz = pts[:, [0, 2]]
        cx, cz = xz.mean(0)
        radii = np.sqrt((xz[:, 0] - cx) ** 2 + (xz[:, 1] - cz) ** 2)
        r10_vals.append(float(np.percentile(radii, 10)))

    r10 = uniform_filter1d(np.array(r10_vals), size=3)
    neck_idx = int(np.argmin(r10[2:-2])) + 2
    neck_y = float(ys[neck_idx])
    # Guard the report against a flat mesh (zero height).
    frac = (neck_y - y_min) / y_range if y_range > 0 else 0.0
    print(f'[neck] Cut height: {neck_y:.4f} (y_range {y_min:.3f}–{y_max:.3f}, frac={frac:.2f})')
    return neck_y
def _weld_mesh(mesh):
    """
    Merge duplicate vertices (UV-split mesh -> geometric mesh).

    Vertices closer than 1e-5 are linked, and every connected group of linked
    vertices collapses onto its lowest-indexed member. Faces are re-indexed
    accordingly; no faces are removed, so a face whose corners were welded
    together stays in the result as a degenerate triangle.

    mesh    : object exposing `vertices` (N, 3) and `faces` (M, 3)
    Returns : a new trimesh.Trimesh (process=False) carrying geometry only
    """
    import trimesh
    from scipy.sparse import coo_matrix
    from scipy.sparse.csgraph import connected_components
    from scipy.spatial import cKDTree

    verts = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.faces)
    n_verts = len(verts)
    if n_verts == 0:
        # Nothing to weld.
        return trimesh.Trimesh(vertices=verts, faces=faces, process=False)

    # All index pairs (i < j) whose vertices lie within the weld tolerance.
    pairs = cKDTree(verts).query_pairs(r=1e-5, output_type='ndarray')

    # Treat the pairs as edges of an undirected graph; each connected
    # component is one welded vertex. This replaces a hand-rolled union-find.
    links = coo_matrix(
        (np.ones(len(pairs), dtype=np.int8), (pairs[:, 0], pairs[:, 1])),
        shape=(n_verts, n_verts))
    _, labels = connected_components(links, directed=False)

    # Representative of each component = first vertex carrying its label.
    _, first_member = np.unique(labels, return_index=True)
    new_verts = verts[first_member]
    new_faces = labels.astype(np.int64)[faces]
    return trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
def _cut_mesh_below(mesh, y_cut):
    """
    Return the part of `mesh` at or below the plane y = y_cut.

    A face survives only if all of its vertices satisfy y <= y_cut. Vertices
    no longer referenced are dropped and faces re-indexed. When the source
    visual carries UVs, they are carried over together with its material.
    """
    import trimesh
    from trimesh.visual.texture import TextureVisuals

    is_below = mesh.vertices[:, 1] <= y_cut
    surviving_faces = mesh.faces[np.all(is_below[mesh.faces], axis=1)]

    # Compact the vertex set: old index -> new index (-1 marks "dropped").
    kept_ids = np.unique(surviving_faces)
    remap = np.full(len(mesh.vertices), -1, dtype=np.int64)
    remap[kept_ids] = np.arange(len(kept_ids))

    lower = trimesh.Trimesh(vertices=mesh.vertices[kept_ids],
                            faces=remap[surviving_faces],
                            process=False)

    # Preserve UV + texture if present
    if getattr(mesh.visual, 'uv', None) is not None:
        lower.visual = TextureVisuals(uv=mesh.visual.uv[kept_ids],
                                      material=mesh.visual.material)
    return lower
def _extract_neck_ring_geometric(mesh, neck_y, n_pts=64, band_frac=0.02):
    """
    Build a ring of n_pts points on the plane y = neck_y around the neck.

    First choice: boundary edges (edges used once) whose endpoints both lie
    within 2 * band of neck_y are chained into vertex paths, and the longest
    path is used if it has at least 8 vertices. Otherwise the vertices within
    one band of neck_y are taken, narrowed to the 35% closest to their XZ
    centroid when that leaves at least 8, and ordered by angle around the
    centroid. In both cases the points are snapped to neck_y and resampled
    with _resample_loop.

    band is band_frac times the mesh height.
    Raises ValueError when fewer than 8 vertices lie in the fallback band.
    """
    verts = mesh.vertices
    mesh_height = verts[:, 1].max() - verts[:, 1].min()
    band = mesh_height * band_frac

    # --- Attempt 1: boundary edges near the cut plane ---
    undirected = np.sort(mesh.edges, axis=1)
    uniq_edges, use_count = np.unique(undirected, axis=0, return_counts=True)
    open_edges = uniq_edges[use_count == 1]
    near_cut = np.abs(verts[:, 1] - neck_y) <= band * 2
    ring_edges = open_edges[near_cut[open_edges[:, 0]] & near_cut[open_edges[:, 1]]]

    if len(ring_edges) >= 8:
        neighbours = {}
        for edge in ring_edges:
            a, b = int(edge[0]), int(edge[1])
            neighbours.setdefault(a, []).append(b)
            neighbours.setdefault(b, []).append(a)

        seen = set()
        chains = []
        step_limit = len(ring_edges) + 1
        for seed in neighbours:
            if seed in seen:
                continue
            chain = [seed]
            seen.add(seed)
            previous, current = -1, seed
            for _ in range(step_limit):
                # Never step straight back to where we came from.
                options = [v for v in neighbours.get(current, []) if v != previous]
                if not options:
                    break
                following = options[0]
                # Stop when the path closes or runs into a visited vertex.
                if following == seed or following in seen:
                    break
                seen.add(following)
                chain.append(following)
                previous, current = current, following
            chains.append(chain)

        if chains:
            longest = max(chains, key=len)
            if len(longest) >= 8:
                ring = verts[longest].copy()
                ring[:, 1] = neck_y  # flatten onto the cut plane
                return _resample_loop(ring, n_pts)

    # --- Attempt 2: innermost vertices of the band, ordered by angle ---
    in_band = (verts[:, 1] >= neck_y - band) & (verts[:, 1] <= neck_y + band)
    pts = verts[in_band]
    if len(pts) < 8:
        raise ValueError(f'Too few vertices near neck_y={neck_y:.4f}: {len(pts)}')

    xz = pts[:, [0, 2]]
    cx, cz = xz.mean(0)
    radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
    # Keep the inner 35% so the outer surface does not pollute the ring.
    innermost = radii <= np.percentile(radii, 35)
    if innermost.sum() >= 8:
        pts = pts[innermost]
    cx, cz = pts[:, [0, 2]].mean(0)

    by_angle = np.argsort(np.arctan2(pts[:, 2] - cz, pts[:, 0] - cx))
    ring = pts[by_angle].copy()
    ring[:, 1] = neck_y  # snap to cut plane
    return _resample_loop(ring, n_pts)
def _extract_boundary_loop(mesh):
    """
    Return the positions of the longest chain of boundary vertices, in walk order.

    Boundary edges are the undirected edges that occur exactly once in
    mesh.edges. They are chained vertex to vertex; when several separate
    chains exist, the one with the most vertices is returned.

    Returns (N, 3) ordered vertex positions.
    Raises ValueError when the mesh has no boundary edges.
    """
    undirected = np.sort(mesh.edges, axis=1)
    uniq_edges, use_count = np.unique(undirected, axis=0, return_counts=True)
    open_edges = uniq_edges[use_count == 1]
    if len(open_edges) == 0:
        raise ValueError('No boundary edges found β€” mesh may be closed')

    # Vertex -> neighbouring boundary vertices.
    neighbours = {}
    for edge in open_edges:
        a, b = int(edge[0]), int(edge[1])
        neighbours.setdefault(a, []).append(b)
        neighbours.setdefault(b, []).append(a)

    seen = set()
    chains = []
    step_limit = len(open_edges) + 1
    for seed in neighbours:
        if seed in seen:
            continue
        chain = [seed]
        seen.add(seed)
        previous, current = -1, seed
        for _ in range(step_limit):
            # Never step straight back to where we came from.
            options = [v for v in neighbours.get(current, []) if v != previous]
            if not options:
                break
            following = options[0]
            # Stop when the path closes or runs into a visited vertex.
            if following == seed or following in seen:
                break
            seen.add(following)
            chain.append(following)
            previous, current = current, following
        chains.append(chain)

    longest = max(chains, key=len)
    return mesh.vertices[longest]
def _resample_loop(loop_pts, N):
    """
    Resample an ordered closed loop to exactly N points evenly spaced by arc length.

    loop_pts is treated as implicitly closed: the segment from the last point
    back to the first is part of the loop and must not be repeated in the
    input. The first output point coincides with loop_pts[0].

    loop_pts : (M, D) ordered points, M >= 1
    N        : number of output points
    Returns  : (N, D) float64 array
    Raises   : ValueError if loop_pts is empty or not two-dimensional
    """
    pts = np.asarray(loop_pts, dtype=np.float64)
    if pts.ndim != 2 or len(pts) == 0:
        raise ValueError('loop_pts must be a non-empty (M, D) array')

    # Close the loop, then measure each segment between consecutive points of
    # the closed polyline, so point i sits at the summed length of the
    # segments that actually precede it.
    closed = np.vstack([pts, pts[:1]])
    seg_lens = np.linalg.norm(np.diff(closed, axis=0), axis=1)
    total = seg_lens.sum()
    if total <= 0:
        # All points coincide: every sample is that point.
        return np.repeat(pts[:1], N, axis=0)

    t = np.concatenate([[0.0], np.cumsum(seg_lens)]) / total
    t[-1] = 1.0  # remove rounding error at the closing point

    t_new = np.linspace(0, 1, N, endpoint=False)
    # np.interp tolerates repeated parameters (zero-length segments), where a
    # generic linear interpolator would divide by zero.
    return np.column_stack(
        [np.interp(t_new, t, closed[:, k]) for k in range(closed.shape[1])])
def _bridge_loops(loop_a, loop_b):
    """
    Create a triangle strip bridging two ordered loops of equal length N.

    Point i of loop_a is joined to point i of loop_b; each quad
    (a_i, a_i+1, b_i+1, b_i) is split into the triangles (a_i, a_i+1, b_i)
    and (a_i+1, b_i+1, b_i), wrapping around at the end of the loops.

    loop_a, loop_b : (N, 3) vertex positions
    Returns (verts, faces):
        verts : (2N, 3), loop_a at rows 0..N-1 and loop_b at rows N..2N-1
        faces : (2N, 3) int32 indices into verts
    Raises ValueError if the loops differ in length.
    """
    loop_a = np.asarray(loop_a)
    loop_b = np.asarray(loop_b)
    if len(loop_a) != len(loop_b):
        # Indices below assume row i + N exists for every i.
        raise ValueError(
            f'Loops must have equal length, got {len(loop_a)} and {len(loop_b)}')

    n = len(loop_a)
    verts = np.vstack([loop_a, loop_b])

    a_cur = np.arange(n, dtype=np.int32)
    a_next = np.roll(a_cur, -1)  # successor with wrap-around
    b_cur = a_cur + n
    b_next = a_next + n

    first_tri = np.stack([a_cur, a_next, b_cur], axis=1)
    second_tri = np.stack([a_next, b_next, b_cur], axis=1)
    # Interleave so the two triangles of each quad are adjacent rows.
    faces = np.stack([first_tri, second_tri], axis=1).reshape(-1, 3)
    return verts, faces
# ──────────────────────────────────────────────────────────────────
# DECA head → trimesh
# ──────────────────────────────────────────────────────────────────
def deca_to_trimesh(verts_np, faces_v, uv_verts, faces_uv, tex_bgr):
    """
    Wrap DECA outputs into a textured trimesh.Trimesh.

    Each vertex receives a single UV: the mean of the UVs of all face corners
    that reference it (vertices referenced by no face get (0, 0)). The BGR
    texture is converted to RGB and attached through a PBR material. If
    attaching the texture fails, the failure is printed and a uniform vertex
    colour is used instead.

    NOTE(review): averaging merges the distinct UVs a vertex may have on
    either side of a UV seam — check the result along seams.
    """
    import trimesh
    from trimesh.visual.texture import TextureVisuals
    from trimesh.visual.material import PBRMaterial

    n_verts = len(verts_np)
    # Flatten face-major, corner-minor: one entry per face corner.
    corner_vert = np.asarray(faces_v).reshape(-1)
    corner_uv = np.asarray(faces_uv).reshape(-1)

    # Unbuffered scatter-add, so repeated vertex indices all accumulate.
    uv_total = np.zeros((n_verts, 2), dtype=np.float64)
    np.add.at(uv_total, corner_vert, uv_verts[corner_uv])
    corner_count = np.zeros(n_verts, dtype=np.int32)
    np.add.at(corner_count, corner_vert, 1)

    corner_count = np.maximum(corner_count, 1)  # avoid 0/0 for unused vertices
    uv_per_vert = (uv_total / corner_count[:, None]).astype(np.float32)

    mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)
    tex_rgb = cv2.cvtColor(tex_bgr, cv2.COLOR_BGR2RGB)
    tex_pil = Image.fromarray(tex_rgb)
    try:
        material = PBRMaterial(baseColorTexture=tex_pil)
        mesh.visual = TextureVisuals(uv=uv_per_vert, material=material)
        print(f'[deca_to_trimesh] UV attached: {uv_per_vert.shape}, tex={tex_rgb.shape}')
    except Exception as e:
        print(f'[deca_to_trimesh] UV attach failed ({e}) β€” using vertex colours')
        mesh.visual.vertex_colors = np.tile([200, 175, 155, 255], (len(verts_np), 1))
    return mesh
# ──────────────────────────────────────────────────────────────────
# Main head-replacement function
# ──────────────────────────────────────────────────────────────────
def replace_head(body_glb: str, face_img_path: str, out_glb: str,
                 device: str = 'cuda', bridge_n: int = 64):
    """
    Main entry point: swap the head of a body mesh for a DECA/FLAME head.

    Steps: load and weld the body, pick a neck cut height, drop every body
    face above it, reconstruct a FLAME head from the face image, scale and
    translate it onto the neck, bridge the two neck rings with a triangle
    strip and export body + bridge + head to out_glb.

    body_glb      : path to TripoSG textured GLB
    face_img_path : path to reference face image
    out_glb       : output path for combined GLB
    device        : device string forwarded to run_deca
    bridge_n      : number of vertices in the stitching ring
    Returns       : out_glb
    Raises        : ValueError if body_glb loads as a scene without any
                    triangle mesh
    """
    import trimesh

    # ── 1. Load body GLB ──────────────────────────────────────────
    print('[replace_head] Loading body GLB...')
    scene = trimesh.load(body_glb)
    if isinstance(scene, trimesh.Scene):
        parts = [g for g in scene.geometry.values() if isinstance(g, trimesh.Trimesh)]
        if not parts:
            raise ValueError(f'No triangle mesh found in {body_glb}')
        # NOTE(review): geometries are concatenated as stored; scene-graph
        # node transforms are not applied here.
        body_mesh = trimesh.util.concatenate(parts)
    else:
        body_mesh = scene
    print(f' Body: {len(body_mesh.vertices)} verts, {len(body_mesh.faces)} faces')

    # ── 1b. Weld body mesh (UV-split → geometric) ─────────────────
    print('[replace_head] Welding mesh vertices...')
    body_welded = _weld_mesh(body_mesh)
    print(f' Welded: {len(body_welded.vertices)} verts (was {len(body_mesh.vertices)})')

    # ── 2. Find neck cut height ───────────────────────────────────
    neck_y = _find_neck_height(body_welded)

    # ── 3. Cut body at neck ───────────────────────────────────────
    print('[replace_head] Cutting body at neck...')
    # The welded mesh is used for measurements only; the exported lower body
    # is cut from the original mesh so it keeps its UV/texture.
    body_lower = _cut_mesh_below(body_mesh, neck_y)
    print(f' Body lower: {len(body_lower.vertices)} verts')
    # Extract neck ring geometrically (robust for non-manifold UV-split meshes)
    body_neck_loop = _extract_neck_ring_geometric(body_welded, neck_y, n_pts=bridge_n)
    print(f' Body neck ring: {len(body_neck_loop)} pts (geometric)')

    # ── 4. Run DECA ───────────────────────────────────────────────
    print('[replace_head] Running DECA...')
    verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr = run_deca(face_img_path, device=device)

    # ── 5. Align DECA head to body coordinate system ─────────────
    # a) scale the FLAME head so its height fills the space above the cut
    # b) move it so its neck ring sits on the body neck ring
    flame_mesh_tmp = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)
    try:
        flame_neck_loop = _extract_boundary_loop(flame_mesh_tmp)
        print(f' FLAME neck ring (topology): {len(flame_neck_loop)} verts')
    except Exception as e:
        print(f' FLAME boundary loop failed ({e}), using geometric extraction')
        # Geometric fallback: ring at 8% of the head height above its lowest point
        flame_neck_y = verts_np[:, 1].min() + (verts_np[:, 1].max() - verts_np[:, 1].min()) * 0.08
        flame_neck_loop = _extract_neck_ring_geometric(flame_mesh_tmp, flame_neck_y, n_pts=bridge_n)
        print(f' FLAME neck ring (geometric): {len(flame_neck_loop)} pts')

    # ── 5b. Compute head position using NECK RING centroid ───────
    body_neck_center = body_neck_loop.mean(axis=0)
    # Head space = distance from the cut plane to the crown of the welded mesh
    welded_y_max = float(body_welded.vertices[:, 1].max())
    body_head_height = welded_y_max - neck_y
    flame_neck_center_unscaled = flame_neck_loop.mean(axis=0)
    flame_y_min = verts_np[:, 1].min()
    flame_y_max = verts_np[:, 1].max()
    flame_head_height = flame_y_max - flame_y_min
    print(f' Body neck center: {body_neck_center.round(4)}')
    print(f' Body head space: {body_head_height:.4f} (neck_y={neck_y:.4f}, crown_y={welded_y_max:.4f})')
    print(f' FLAME head height (unscaled): {flame_head_height:.4f}')
    print(f' FLAME neck center (unscaled): {flame_neck_center_unscaled.round(4)}')

    # Scale FLAME head to match body head height (guard a degenerate head)
    if flame_head_height > 1e-5:
        head_scale = body_head_height / flame_head_height
    else:
        head_scale = 1.0
    print(f' Head scale: {head_scale:.4f}')

    # Translate: FLAME neck ring center -> body neck ring center in XZ,
    # FLAME mesh bottom (flame_y_min) -> neck_y in Y, so the head spans the
    # full space from neck_y to the body crown.
    translate = np.array([
        body_neck_center[0] - flame_neck_center_unscaled[0] * head_scale,
        neck_y - flame_y_min * head_scale,
        body_neck_center[2] - flame_neck_center_unscaled[2] * head_scale,
    ])
    print(f' Translate: {translate.round(4)}')
    verts_aligned = verts_np * head_scale + translate
    print(f' FLAME aligned y={verts_aligned[:,1].min():.4f}β†’{verts_aligned[:,1].max():.4f}'
          f' x={verts_aligned[:,0].min():.4f}β†’{verts_aligned[:,0].max():.4f}'
          f' z={verts_aligned[:,2].min():.4f}β†’{verts_aligned[:,2].max():.4f}')

    # Extract FLAME neck loop again, now in body coordinates
    flame_verts_aligned = verts_aligned
    flame_mesh_aligned = trimesh.Trimesh(
        vertices=flame_verts_aligned, faces=faces_v, process=False)
    try:
        flame_neck_loop_aligned = _extract_boundary_loop(flame_mesh_aligned)
        print(f' FLAME neck ring (topology): {len(flame_neck_loop_aligned)} verts')
    except Exception:
        # Geometric fallback: ring at 5% of the aligned head height
        flame_neck_y_aligned = flame_verts_aligned[:, 1].min() + (
            flame_verts_aligned[:, 1].max() - flame_verts_aligned[:, 1].min()) * 0.05
        flame_neck_loop_aligned = _extract_neck_ring_geometric(
            flame_mesh_aligned, flame_neck_y_aligned, n_pts=bridge_n)
        print(f' FLAME neck ring (geometric): {len(flame_neck_loop_aligned)} pts')

    flame_neck_r = np.linalg.norm(flame_neck_loop_aligned - flame_neck_loop_aligned.mean(0), axis=1).mean()
    body_neck_r = np.linalg.norm(body_neck_loop - body_neck_loop.mean(0), axis=1).mean()
    print(f' Body neck radius: {body_neck_r:.4f} FLAME neck radius (scaled): {flame_neck_r:.4f}')

    # ── 6. Resample both neck loops to bridge_n points ────────────
    body_loop_r = _resample_loop(body_neck_loop, bridge_n)
    flame_loop_r = _resample_loop(flame_neck_loop_aligned, bridge_n)

    # Both loops must wind the same way around the y axis, otherwise the
    # bridge would cross itself.
    def _loop_orientation(loop):
        rel = loop - loop.mean(0)
        # Pair every point with its successor, including last -> first, so
        # the closing segment contributes to the signed sum as well.
        cross = np.cross(rel, np.roll(rel, -1, axis=0))
        return float(np.sum(cross[:, 1]))  # y-component

    o_body = _loop_orientation(body_loop_r)
    o_flame = _loop_orientation(flame_loop_r)
    if (o_body > 0) != (o_flame > 0):
        flame_loop_r = flame_loop_r[::-1]

    # ── 7. Align loop starting points (minimise bridge twist) ─────
    # Rotate the FLAME loop so it starts at the point closest to the body
    # loop's first point.
    dists = np.linalg.norm(flame_loop_r - body_loop_r[0], axis=1)
    best_offset = int(np.argmin(dists))
    flame_loop_r = np.roll(flame_loop_r, -best_offset, axis=0)

    # ── 8. Build bridge strip ─────────────────────────────────────
    bridge_verts, bridge_faces = _bridge_loops(body_loop_r, flame_loop_r)
    bridge_mesh = trimesh.Trimesh(vertices=bridge_verts, faces=bridge_faces, process=False)

    # ── 9. FLAME head mesh with texture ───────────────────────────
    head_mesh = deca_to_trimesh(flame_verts_aligned, faces_v, uv_verts, faces_uv, tex_bgr)

    # ── 10. Export ────────────────────────────────────────────────
    # Vertex total of the three parts (what a plain concatenation would hold).
    n_total = len(body_lower.vertices) + len(bridge_mesh.vertices) + len(head_mesh.vertices)
    print(f'[replace_head] Exporting combined mesh: {n_total} verts...')
    os.makedirs(os.path.dirname(out_glb) or '.', exist_ok=True)
    # Export as GLB scene with separate submeshes (preserves textures)
    try:
        scene_out = trimesh.Scene()
        scene_out.add_geometry(body_lower, geom_name='body')
        scene_out.add_geometry(bridge_mesh, geom_name='bridge')
        scene_out.add_geometry(head_mesh, geom_name='head')
        scene_out.export(out_glb)
        print(f'[replace_head] Saved scene GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    except Exception as e:
        print(f'[replace_head] Scene export failed ({e}), trying single mesh...')
        # Geometry-only fallback, built only when actually needed.
        merged = trimesh.util.concatenate([body_lower, bridge_mesh, head_mesh])
        combined = trimesh.Trimesh(vertices=merged.vertices, faces=merged.faces, process=False)
        combined.export(out_glb)
        print(f'[replace_head] Saved GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    return out_glb
# ──────────────────────────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────────────────────────
if __name__ == '__main__':
    # Command-line front end: parse the paths and options, choose a device,
    # then run the head replacement once.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ('--body', dict(required=True, help='TripoSG body GLB path')),
        ('--face', dict(required=True, help='Reference face image path')),
        ('--out', dict(required=True, help='Output GLB path')),
        ('--bridge', dict(type=int, default=64, help='Bridge ring vertex count')),
        ('--cpu', dict(action='store_true', help='Use CPU instead of CUDA')),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # --cpu wins outright; otherwise use CUDA only when torch reports it.
    if args.cpu:
        device = 'cpu'
    elif __import__('torch').cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    replace_head(args.body, args.face, args.out, device=device, bridge_n=args.bridge)