Image2Model / pipeline /face_transplant.py
Daankular's picture
Port MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search
8f1bcd9
"""
face_transplant.py
==================
Replace the face/head region of a rigged UniRig GLB with a higher-detail
PSHuman mesh, while preserving the skeleton, rig, and skinning weights.
Algorithm
---------
1. Parse rigged GLB → vertices, faces, UVs, JOINTS_0, WEIGHTS_0, bone list
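2. Identify head vertices → any vert whose summed skinning weight on head
bones is at least the weight threshold (default 0.35)
3. Load PSHuman mesh (OBJ, GLB or PLY; no rig) and crop it to the top
(head) portion of the body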
4. Align PSHuman head to UniRig head bounding box (scale + translate)
5. Transfer skinning weights to PSHuman verts via K-nearest-neighbour from
UniRig head verts (scipy KDTree, weighted average)
6. Retract UniRig face verts slightly inward so PSHuman sits on top cleanly
7. Rebuild the GLB with two mesh primitives:
- Primitive 0 : UniRig body (face verts retracted)
- Primitive 1 : PSHuman face (new, with transferred weights)
8. Write output GLB
Usage
-----
python -m pipeline.face_transplant \\
--body rigged_body.glb \\
--face pshuman_output.obj \\
--output rigged_body_with_pshuman_face.glb
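Optionally supply --head-bones as comma-separated bone-name substrings
(default: head,Head,skull,Skull). Any bone whose name contains one of these
substrings is treated as a head bone; if none match, the bone whose
dominantly-weighted vertices have the highest average Y is used instead.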
Requires: pygltflib numpy scipy trimesh (pip install each)
"""
from __future__ import annotations
import argparse
import base64
import struct
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from scipy.spatial import KDTree
import trimesh
# ---------------------------------------------------------------------------
# GLB low-level helpers (subset of Retarget/io/gltf_io.py re-used here)
# ---------------------------------------------------------------------------
try:
import pygltflib
except ImportError:
raise ImportError("pip install pygltflib")
def _read_accessor_raw(gltf: pygltflib.GLTF2, accessor_idx: int) -> np.ndarray:
acc = gltf.accessors[accessor_idx]
bv = gltf.bufferViews[acc.bufferView]
buf = gltf.buffers[bv.buffer]
if buf.uri and buf.uri.startswith("data:"):
_, b64 = buf.uri.split(",", 1)
raw = base64.b64decode(b64)
elif buf.uri:
base_dir = Path(gltf._path).parent if getattr(gltf, "_path", None) else Path(".")
raw = (base_dir / buf.uri).read_bytes()
else:
raw = bytes(gltf.binary_blob())
type_nc = {"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4, "MAT4": 16}
fmt_map = {5120: "b", 5121: "B", 5122: "h", 5123: "H", 5125: "I", 5126: "f"}
n_comp = type_nc[acc.type]
fmt = fmt_map[acc.componentType]
item_sz = struct.calcsize(fmt) * n_comp
stride = bv.byteStride or item_sz
start = bv.byteOffset + (acc.byteOffset or 0)
items = []
for i in range(acc.count):
offset = start + i * stride
vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
items.append(vals)
arr = np.array(items)
if arr.ndim == 2 and arr.shape[1] == 1:
arr = arr[:, 0]
return arr
def _accessor_dtype(gltf: pygltflib.GLTF2, accessor_idx: int):
fmt_map = {5120: np.int8, 5121: np.uint8, 5122: np.int16, 5123: np.uint16,
5125: np.uint32, 5126: np.float32}
return fmt_map[gltf.accessors[accessor_idx].componentType]
# ---------------------------------------------------------------------------
# Mesh extraction
# ---------------------------------------------------------------------------
class GLBMesh:
    """
    Geometry and skinning of one mesh primitive from a rigged GLB.

    The first skin (``skins[0]``) is used. The primitive read is primitive 0
    of the first node whose ``skin`` is 0, falling back to primitive 0 of the
    first node that has any mesh. Other primitives are not read.

    Attributes
    ----------
    path          : source file path
    gltf          : the loaded ``pygltflib.GLTF2`` document
    skin          : ``gltf.skins[0]``
    joint_names   : one name per skin joint (``joint_<k>`` if the node is unnamed)
    mesh_prim     : the primitive that was read
    mesh_node_idx : index of the node owning that primitive's mesh
    verts         : (V, 3) float32 POSITION
    normals       : (V, 3) float32 NORMAL, or None
    uvs           : (V, 2) float32 TEXCOORD_0, or None
    faces         : (F, 3) int32 vertex indices
    joints4       : (V, 4) int32 JOINTS_0, or None when unskinned
    weights4      : (V, 4) float32 WEIGHTS_0, or None when unskinned
    material_idx  : the primitive's material index (may be None)
    """

    def __init__(self, path: str):
        self.path = path
        gltf = pygltflib.GLTF2().load(path)
        gltf._path = path  # read by _read_accessor_raw to resolve external buffers
        self.gltf = gltf

        if not gltf.skins:
            raise ValueError("No skin found in GLB — is this a rigged file?")
        self.skin = gltf.skins[0]
        self.joint_names: List[str] = [gltf.nodes[j].name or f"joint_{k}"
                                       for k, j in enumerate(self.skin.joints)]

        # Find a mesh node that uses this skin
        self.mesh_prim, self.mesh_node_idx = self._find_skinned_prim()
        attrs = self.mesh_prim.attributes

        def read(accessor_idx, dtype):
            return _read_accessor_raw(gltf, accessor_idx).astype(dtype)

        self.verts = read(attrs.POSITION, np.float32)
        self.normals = read(attrs.NORMAL, np.float32) if attrs.NORMAL is not None else None
        self.uvs = read(attrs.TEXCOORD_0, np.float32) if attrs.TEXCOORD_0 is not None else None
        if self.mesh_prim.indices is not None:
            self.faces = read(self.mesh_prim.indices, np.int32).reshape(-1, 3)
        else:
            # Non-indexed geometry: consecutive vertex triples form triangles
            # (assumes mode TRIANGLES, as the indexed path does).
            self.faces = np.arange(len(self.verts), dtype=np.int32).reshape(-1, 3)

        # Skinning needs both JOINTS_0 and WEIGHTS_0. If either is missing the
        # mesh is treated as unskinned (joints4 / weights4 stay None).
        self.joints4 = None
        self.weights4 = None
        if attrs.JOINTS_0 is not None and attrs.WEIGHTS_0 is not None:
            self.joints4 = read(attrs.JOINTS_0, np.int32)
            self.weights4 = read(attrs.WEIGHTS_0, np.float32)

        # Material index (carried over to the output primitives)
        self.material_idx = self.mesh_prim.material

    def _find_skinned_prim(self):
        """
        Return ``(primitive, node_index)`` for primitive 0 of the first node
        skinned with skin 0 (``self.skin``), or else of the first node that
        has any mesh.

        Raises ValueError if no node has a mesh.
        """
        nodes = self.gltf.nodes
        node_idx = next((ni for ni, node in enumerate(nodes)
                         if node.skin == 0 and node.mesh is not None), None)
        if node_idx is None:
            node_idx = next((ni for ni, node in enumerate(nodes)
                             if node.mesh is not None), None)
        if node_idx is None:
            raise ValueError("No mesh primitive found")
        return self.gltf.meshes[nodes[node_idx].mesh].primitives[0], node_idx

    def head_bone_indices(self, substrings=("head", "Head", "skull", "Skull", "neck", "Neck")) -> List[int]:
        """Return joint indices (into self.joint_names) whose name contains any of
        *substrings* (case-sensitive).

        Falls back to a positional heuristic when no bone names match (e.g. the
        generic bone_0/bone_1 naming from UniRig): each vertex is assigned to
        the joint in its highest-weight slot, and the single joint whose
        assigned vertices have the highest average Y is returned.
        """
        result = [i for i, name in enumerate(self.joint_names)
                  if any(s in name for s in substrings)]
        if result or self.joints4 is None or self.weights4 is None:
            return result

        n_bones = len(self.joint_names)
        # Dominant joint per vertex (first slot wins on ties, like np.argmax).
        dominant = self.joints4[np.arange(len(self.joints4)),
                                np.argmax(self.weights4, axis=1)]
        bone_y_sum = np.bincount(dominant, weights=self.verts[:, 1].astype(np.float64),
                                 minlength=n_bones)[:n_bones]
        bone_y_cnt = np.bincount(dominant, minlength=n_bones)[:n_bones]
        with np.errstate(invalid='ignore'):
            bone_y_avg = np.where(bone_y_cnt > 0, bone_y_sum / bone_y_cnt, -np.inf)
        top = int(np.argmax(bone_y_avg))
        print(f"[face_transplant] No named head bones; positional fallback: "
              f"bone {top} ({self.joint_names[top]}, avg_y={bone_y_avg[top]:.3f})")
        return [top]
# ---------------------------------------------------------------------------
# Face-region identification
# ---------------------------------------------------------------------------
def find_face_verts(glb_mesh: GLBMesh, head_joint_indices: List[int],
                    weight_threshold: float = 0.35) -> np.ndarray:
    """
    Return a boolean mask over the mesh vertices marking face/head vertices.

    A vertex is selected when the sum of its JOINTS_0/WEIGHTS_0 weights on
    joints listed in *head_joint_indices* is ``>= weight_threshold``.

    Parameters
    ----------
    glb_mesh           : object with ``joints4`` / ``weights4`` (V, 4) arrays
    head_joint_indices : joint indices considered part of the head
    weight_threshold   : minimum summed head weight for a vertex to count

    Raises
    ------
    ValueError : if the mesh has no skinning weights (``joints4`` is None).
    """
    if glb_mesh.joints4 is None:
        raise ValueError("Mesh has no skinning weights — cannot identify face region")
    head_ids = np.asarray(list(head_joint_indices), dtype=np.int64)
    # (V, 4) mask of influence slots that point at a head joint.
    is_head_slot = np.isin(glb_mesh.joints4, head_ids)
    # Non-head slots contribute 0, so this equals summing only the head slots.
    head_weight = (glb_mesh.weights4 * is_head_slot).sum(axis=1)
    return head_weight >= weight_threshold
# ---------------------------------------------------------------------------
# PSHuman mesh loading + alignment
# ---------------------------------------------------------------------------
def _crop_to_head(mesh: trimesh.Trimesh, head_fraction: float = 0.22) -> trimesh.Trimesh:
    """
    Keep only the top ``head_fraction`` of *mesh*'s vertical (Y) extent.

    PSHuman produces a full-body mesh; we only want the head/face portion.
    A triangle is kept only if all three of its vertices lie in the band.
    Vertices not used by any kept triangle are dropped and indices are
    compacted. UVs, if present, are carried over; other visual data (e.g.
    materials) is not.

    NOTE(review): assumes the PSHuman mesh is Y-up — confirm for its exports.

    Raises
    ------
    ValueError : if *head_fraction* is not in (0, 1], the mesh has no
                 vertices, or no triangle lies entirely inside the band.
    """
    if not 0.0 < head_fraction <= 1.0:
        raise ValueError(f"head_fraction must be in (0, 1], got {head_fraction}")
    if len(mesh.vertices) == 0:
        raise ValueError("PSHuman mesh has no vertices to crop")

    y = mesh.vertices[:, 1]
    threshold = y.max() - (y.max() - y.min()) * head_fraction
    vert_keep = y >= threshold
    # A triangle survives only if all three corners are inside the band.
    face_keep = vert_keep[mesh.faces].all(axis=1)
    if not face_keep.any():
        raise ValueError(
            f"No PSHuman triangles lie entirely within the top "
            f"{head_fraction*100:.0f}% of the mesh (Y ≥ {threshold:.3f}); "
            f"try a larger head_fraction"
        )
    kept_faces = mesh.faces[face_keep]

    used = np.unique(kept_faces)
    # Old vertex index -> compacted index (-1 marks dropped vertices).
    remap = np.full(len(mesh.vertices), -1, dtype=np.int32)
    remap[used] = np.arange(len(used))
    new_verts = mesh.vertices[used].astype(np.float32)
    new_faces = remap[kept_faces]

    result = trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
    if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
        result.visual = trimesh.visual.TextureVisuals(uv=np.array(mesh.visual.uv)[used])
    print(f"[face_transplant] PSHuman head crop ({head_fraction*100:.0f}%): "
          f"{len(mesh.vertices)}{len(new_verts)} verts (Y ≥ {threshold:.3f})")
    return result
def load_and_align_pshuman(pshuman_path: str, target_verts: np.ndarray) -> trimesh.Trimesh:
    """
    Load a PSHuman mesh, keep its head region and fit it onto *target_verts*.

    The file is loaded with ``trimesh.load(force="mesh")`` (e.g. OBJ/GLB/PLY)
    and cropped with ``_crop_to_head``. It is then uniformly scaled and
    translated so that its bounding-box centre matches that of *target_verts*
    (the UniRig head vertices). The scale makes its extent along the target's
    longest bounding-box axis equal to the target's. The cropped mesh is
    returned with its vertices replaced.
    """
    mesh: trimesh.Trimesh = trimesh.load(pshuman_path, force="mesh", process=False)
    print(f"[face_transplant] PSHuman mesh: {len(mesh.vertices)} verts, {len(mesh.faces)} faces")

    # PSHuman is full-body — crop to just the head before aligning
    mesh = _crop_to_head(mesh)

    def bbox_centre_and_extent(points):
        lo = points.min(axis=0)
        hi = points.max(axis=0)
        return (lo + hi) * 0.5, hi - lo

    target_centre, target_extent = bbox_centre_and_extent(target_verts)
    src_points = mesh.vertices.astype(np.float32)
    source_centre, source_extent = bbox_centre_and_extent(src_points)

    # Uniform scale keyed on the target's largest axis
    axis = np.argmax(target_extent)
    scale = float(target_extent[axis]) / float(source_extent[axis] + 1e-9)

    mesh.vertices = (src_points - source_centre) * scale + target_centre
    print(f"[face_transplant] PSHuman aligned: scale={scale:.4f}, center={target_centre}")
    return mesh
# ---------------------------------------------------------------------------
# Weight transfer via KDTree
# ---------------------------------------------------------------------------
def transfer_weights(
    donor_verts: np.ndarray,      # (M, 3) UniRig face verts
    donor_joints: np.ndarray,     # (M, 4) integer joint indices
    donor_weights: np.ndarray,    # (M, 4) float weights
    recipient_verts: np.ndarray,  # (N, 3) PSHuman face verts
    k: int = 5,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    K-nearest-neighbour skin-weight transfer.

    Each recipient vertex accumulates the 4-slot skin weights of its *k*
    nearest donor vertices, each scaled by inverse distance
    ``1 / (d + 1e-8)``. The per-joint totals are normalised, the four
    largest are kept, and those four are renormalised to sum to 1. *k* is
    clamped to the number of donor vertices.

    Returns
    -------
    (joints4, weights4) : (N, 4) uint16 joint indices ordered by descending
                          weight, and (N, 4) float32 weights. When fewer than
                          four joints exist, unused slots hold joint 0 with
                          weight 0.

    Raises
    ------
    ValueError : if *k* < 1 or there are no donor vertices.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    donor_joints = np.asarray(donor_joints)
    donor_weights = np.asarray(donor_weights)
    n_donors = len(donor_verts)
    if n_donors == 0:
        raise ValueError("transfer_weights needs at least one donor vertex")

    n_recip = len(recipient_verts)
    out_joints = np.zeros((n_recip, 4), dtype=np.uint16)
    out_weights = np.zeros((n_recip, 4), dtype=np.float32)
    if n_recip == 0:
        return out_joints, out_weights

    # KDTree.query pads with index == n_donors when k exceeds the donor count.
    k = min(k, n_donors)
    tree = KDTree(donor_verts)
    dists, idxs = tree.query(recipient_verts, k=k)
    # query() returns 1-D arrays for k == 1; normalise both to (N, k).
    dists = np.reshape(dists, (n_recip, k))
    idxs = np.reshape(idxs, (n_recip, k))

    # Build dense per-recipient weight vector
    n_joints_total = int(donor_joints.max()) + 1
    dense = np.zeros((n_recip, n_joints_total), dtype=np.float64)
    rows = np.arange(n_recip)
    for ki in range(k):
        w_dist = 1.0 / (dists[:, ki] + 1e-8)  # inverse-distance
        nbr = idxs[:, ki]
        for c in range(4):
            # Each row appears once per call, so fancy-index += loses nothing.
            dense[rows, donor_joints[nbr, c]] += donor_weights[nbr, c] * w_dist

    # Re-normalise rows
    dense /= dense.sum(axis=1, keepdims=True) + 1e-12

    # Pack back into 4-bone format (top-4 by weight, descending)
    n_keep = min(4, n_joints_total)
    top = np.argsort(dense, axis=1)[:, -n_keep:][:, ::-1]
    top_w = np.take_along_axis(dense, top, axis=1)
    out_joints[:, :n_keep] = top
    out_weights[:, :n_keep] = top_w / (top_w.sum(axis=1, keepdims=True) + 1e-12)
    return out_joints, out_weights
# ---------------------------------------------------------------------------
# GLB rebuild
# ---------------------------------------------------------------------------
def _pack_buffer_view(data_bytes: bytes, target: list, byte_offset: int,
                      byte_stride: Optional[int] = None) -> Tuple[pygltflib.BufferView, int]:
    """
    Build a BufferView on buffer 0 describing *data_bytes* placed at
    *byte_offset*. Return ``(buffer_view, new_offset)``, where
    ``new_offset = byte_offset + len(data_bytes)``.

    Despite its name, this helper appends nothing. *target* is neither read
    nor modified, and the caller must store both the bytes and the returned
    BufferView. ``byteStride`` is set only when *byte_stride* is truthy.
    NOTE(review): no caller is visible in this file.
    """
    bv = pygltflib.BufferView(
        buffer=0,
        byteOffset=byte_offset,
        byteLength=len(data_bytes),
    )
    if byte_stride:
        bv.byteStride = byte_stride
    return bv, byte_offset + len(data_bytes)
def _make_accessor(component_type: int, type_str: str, count: int,
                   bv_idx: int, min_vals=None, max_vals=None) -> pygltflib.Accessor:
    """
    Create an Accessor over buffer view *bv_idx* at byteOffset 0 with the
    given glTF component type, element type string and element count.
    ``min``/``max`` bounds are attached, each value cast to float, only when
    supplied.
    """
    accessor = pygltflib.Accessor(
        bufferView=bv_idx,
        byteOffset=0,
        componentType=component_type,
        count=count,
        type=type_str,
    )
    for field_name, bounds in (("min", min_vals), ("max", max_vals)):
        if bounds is not None:
            setattr(accessor, field_name, list(map(float, bounds)))
    return accessor
# glTF accessor componentType codes used when writing the output GLB.
# NOTE(review): UBYTE is not referenced elsewhere in this file's visible code.
FLOAT32 = pygltflib.FLOAT  # 5126
UINT16 = pygltflib.UNSIGNED_SHORT  # 5123
UINT32 = pygltflib.UNSIGNED_INT  # 5125
UBYTE = pygltflib.UNSIGNED_BYTE  # 5121
def _retract_face_verts(verts: np.ndarray, normals: Optional[np.ndarray],
                        mask: np.ndarray, amount: float) -> np.ndarray:
    """
    Return a copy of *verts* with the *mask*-selected vertices moved *amount*
    inward: along -normal when *normals* is given, otherwise toward the
    centroid of the selected vertices.
    """
    moved = verts.copy()
    if normals is not None:
        moved[mask] -= normals[mask] * amount
        return moved
    region = moved[mask]
    toward_centre = region.mean(axis=0) - region
    lengths = np.linalg.norm(toward_centre, axis=1, keepdims=True) + 1e-9
    moved[mask] = region + (toward_centre / lengths) * amount
    return moved


def transplant_face(
    body_glb_path: str,
    pshuman_mesh_path: str,
    output_path: str,
    head_bone_substrings: Tuple[str, ...] = ("head", "Head", "skull", "Skull"),
    weight_threshold: float = 0.35,
    retract_amount: float = 0.004,  # metres — how far to push face verts inward
    knn: int = 5,
):
    """
    Main entry point: graft a PSHuman head onto a rigged UniRig GLB.

    Parameters
    ----------
    body_glb_path        : rigged UniRig GLB
    pshuman_mesh_path    : PSHuman output mesh (OBJ / GLB / PLY)
    output_path          : result GLB path
    head_bone_substrings : bone name fragments that identify head joints
    weight_threshold     : head-weight sum at or above which a vertex is "face"
    retract_amount       : metres to push face verts inward to avoid z-fight
    knn                  : neighbours for weight transfer

    Raises
    ------
    RuntimeError : if no head bones are found, or fewer than
                   ``max(3, min(10, V // 4))`` face vertices are selected.
    """
    print(f"[face_transplant] Loading rigged GLB: {body_glb_path}")
    glb = GLBMesh(body_glb_path)
    print(f" Verts: {len(glb.verts)} Faces: {len(glb.faces)}")
    print(f" Bones ({len(glb.joint_names)}): {', '.join(glb.joint_names[:8])} ...")

    # Head joints first, then the vertices weighted to them.
    head_ji = glb.head_bone_indices(substrings=head_bone_substrings)
    if not head_ji:
        raise RuntimeError(
            f"No head bones found with substrings {head_bone_substrings}.\n"
            f"Available bones: {glb.joint_names}"
        )
    print(f" Head joints ({len(head_ji)}): {[glb.joint_names[i] for i in head_ji]}")

    face_mask = find_face_verts(glb, head_ji, weight_threshold=weight_threshold)
    n_face = face_mask.sum()
    print(f" Face verts: {n_face} / {len(glb.verts)}")
    min_face_verts = max(3, min(10, len(glb.verts) // 4))
    if n_face < min_face_verts:
        raise RuntimeError(
            f"Only {n_face} face vertices found (need >= {min_face_verts}) — "
            f"try lowering --weight-threshold (current: {weight_threshold})"
        )

    # Fit the PSHuman head onto the UniRig head region.
    face_region = glb.verts[face_mask]
    ps_mesh = load_and_align_pshuman(pshuman_mesh_path, face_region)
    ps_verts = np.array(ps_mesh.vertices, dtype=np.float32)
    ps_faces = np.array(ps_mesh.faces, dtype=np.int32)
    uv = getattr(ps_mesh.visual, "uv", None)
    ps_uvs = np.array(uv, dtype=np.float32) if uv is not None else None

    # Donor = UniRig face verts, recipient = PSHuman verts.
    print("[face_transplant] Transferring skinning weights via KNN ...")
    ps_joints, ps_weights = transfer_weights(
        donor_verts=face_region.astype(np.float64),
        donor_joints=glb.joints4[face_mask],
        donor_weights=glb.weights4[face_mask],
        recipient_verts=ps_verts.astype(np.float64),
        k=knn,
    )
    primary_is_head = np.isin(ps_joints[:, 0], head_ji)
    print(f" Done. Head joint coverage in PSHuman: "
          f"{primary_is_head.mean() * 100:.1f}% primary bone is head")

    # Sink the original face slightly so the PSHuman surface sits on top.
    body_verts = _retract_face_verts(glb.verts, glb.normals, face_mask, retract_amount)

    print("[face_transplant] Rebuilding GLB ...")
    _write_transplanted_glb(
        source_gltf=glb,
        body_verts=body_verts,
        ps_verts=ps_verts,
        ps_faces=ps_faces,
        ps_uvs=ps_uvs,
        ps_joints=ps_joints,
        ps_weights=ps_weights,
        output_path=output_path,
    )
    print(f"[face_transplant] Saved -> {output_path}")
# ---------------------------------------------------------------------------
# GLB writer
# ---------------------------------------------------------------------------
def _write_transplanted_glb(
    source_gltf: GLBMesh,
    body_verts: np.ndarray,
    ps_verts: np.ndarray,
    ps_faces: np.ndarray,
    ps_uvs: Optional[np.ndarray],
    ps_joints: np.ndarray,
    ps_weights: np.ndarray,
    output_path: str,
):
    """
    Write a copy of the source GLB whose skinned mesh is replaced by a mesh
    holding the UniRig body (primitive 0) and the PSHuman face (primitive 1).

    The binary buffer is rebuilt from scratch:

    * the retracted body and the PSHuman face are packed as fresh accessors;
    * every accessor the rest of the document still references is copied
      byte-for-byte into the new buffer and its index remapped, so those
      references stay valid. This covers each skin's ``inverseBindMatrices``,
      all primitives of other meshes, and any further primitives (index >= 1)
      of the replaced mesh. Extra primitives follow the face primitive and
      lose their morph targets, because the new primitives have none;
    * images stored in buffer views are re-embedded as base64 data URIs;
    * animations are dropped (their samplers are not remapped).

    NOTE(review): accessor/bufferView references inside extensions (e.g.
    Draco compression) are not remapped.

    Parameters
    ----------
    source_gltf : parsed source mesh. Its path, faces, normals, UVs, skin
                  arrays, material index and mesh-node index are used.
    body_verts  : (V, 3) body positions replacing primitive 0's POSITION
    ps_verts, ps_faces, ps_uvs : PSHuman geometry (UVs optional)
    ps_joints, ps_weights      : (N, 4) transferred skinning for PSHuman verts
    output_path : destination passed to ``GLTF2.save``

    Raises
    ------
    ValueError : if the PSHuman face mesh has no triangles.
    """
    import copy

    gltf = pygltflib.GLTF2().load(source_gltf.path)
    gltf._path = source_gltf.path

    # Snapshot the original tables. Every old index below refers to these;
    # gltf.bufferViews / gltf.accessors / gltf.buffers are replaced at the end.
    old_buffers = list(gltf.buffers)
    old_bviews = list(gltf.bufferViews)
    old_accessors = list(gltf.accessors)
    try:
        glb_blob = bytes(gltf.binary_blob())
    except Exception:  # best-effort: treat a missing/unreadable BIN chunk as empty
        glb_blob = b""

    buffer_cache: Dict[int, bytes] = {}

    def old_buffer_bytes(buffer_idx: int) -> bytes:
        """Raw bytes of original buffer *buffer_idx* (data URI, file, or GLB BIN chunk)."""
        if buffer_idx not in buffer_cache:
            uri = old_buffers[buffer_idx].uri
            if uri and uri.startswith("data:"):
                data = base64.b64decode(uri.split(",", 1)[1])
            elif uri:
                data = (Path(source_gltf.path).parent / uri).read_bytes()
            else:
                data = glb_blob
            buffer_cache[buffer_idx] = data
        return buffer_cache[buffer_idx]

    # ------------------------------------------------------------------
    # Preserve embedded images as data URIs BEFORE we wipe buffer views.
    # ------------------------------------------------------------------
    for img in gltf.images:
        if img.bufferView is None or img.uri is not None:
            continue
        img_bv = old_bviews[img.bufferView]
        data = old_buffer_bytes(img_bv.buffer)
        if not data:
            continue
        start = img_bv.byteOffset or 0
        img_bytes = data[start:start + img_bv.byteLength]
        mime = img.mimeType or "image/png"
        img.uri = "data:{};base64,{}".format(mime, base64.b64encode(img_bytes).decode())
        img.bufferView = None

    # ------------------------------------------------------------------
    # New binary buffer: collect chunks; track buffer views + accessors.
    # ------------------------------------------------------------------
    chunks: List[bytes] = []
    bviews: List[pygltflib.BufferView] = []
    accors: List[pygltflib.Accessor] = []
    byte_offset = 0

    def append_bytes(data: bytes) -> int:
        """Append *data* at the next 4-byte-aligned offset and return that offset."""
        nonlocal byte_offset
        pad = -byte_offset % 4  # glTF requires component-size alignment (<= 4)
        if pad:
            chunks.append(b"\x00" * pad)
            byte_offset += pad
        start = byte_offset
        chunks.append(data)
        byte_offset += len(data)
        return start

    def add_chunk(data: bytes, component_type: int, type_str: str, count: int,
                  min_v=None, max_v=None, stride: Optional[int] = None) -> int:
        """Append data, create buffer view + accessor, return accessor index."""
        bv = pygltflib.BufferView(buffer=0, byteOffset=append_bytes(data), byteLength=len(data))
        if stride:
            bv.byteStride = stride
        bviews.append(bv)
        acc = pygltflib.Accessor(
            bufferView=len(bviews) - 1,
            byteOffset=0,
            componentType=component_type,
            count=count,
            type=type_str,
        )
        if min_v is not None:
            acc.min = [float(x) for x in np.atleast_1d(min_v)]
        if max_v is not None:
            acc.max = [float(x) for x in np.atleast_1d(max_v)]
        accors.append(acc)
        return len(accors) - 1

    # ------------------------------------------------------------------
    # Carry-over of original accessors still referenced elsewhere.
    # ------------------------------------------------------------------
    bv_map: Dict[int, int] = {}
    acc_map: Dict[int, int] = {}

    def carry_buffer_view(old_idx: int) -> int:
        """Copy original bufferView *old_idx* and its bytes into the new buffer."""
        if old_idx not in bv_map:
            old_bv = old_bviews[old_idx]
            start = old_bv.byteOffset or 0
            data = old_buffer_bytes(old_bv.buffer)[start:start + old_bv.byteLength]
            new_bv = copy.deepcopy(old_bv)
            new_bv.buffer = 0
            new_bv.byteOffset = append_bytes(data)
            bviews.append(new_bv)
            bv_map[old_idx] = len(bviews) - 1
        return bv_map[old_idx]

    def carry_accessor(old_idx: Optional[int]) -> Optional[int]:
        """Copy original accessor *old_idx* into the new tables; None passes through."""
        if old_idx is None:
            return None
        if old_idx not in acc_map:
            new_acc = copy.deepcopy(old_accessors[old_idx])
            if new_acc.bufferView is not None:
                new_acc.bufferView = carry_buffer_view(new_acc.bufferView)
            sparse = getattr(new_acc, "sparse", None)
            if sparse is not None:
                for part in (sparse.indices, sparse.values):
                    part.bufferView = carry_buffer_view(part.bufferView)
            accors.append(new_acc)
            acc_map[old_idx] = len(accors) - 1
        return acc_map[old_idx]

    def remap_accessor_refs(holder) -> None:
        """Remap every int-valued accessor reference in an attributes dict/object."""
        if isinstance(holder, dict):
            for key, val in list(holder.items()):
                if isinstance(val, int):
                    holder[key] = carry_accessor(val)
        else:
            for key, val in list(vars(holder).items()):
                if isinstance(val, int):
                    setattr(holder, key, carry_accessor(val))

    def carry_primitive(prim, keep_targets: bool = True):
        """Deep-copy an original primitive with all accessor references remapped."""
        new_prim = copy.deepcopy(prim)
        remap_accessor_refs(new_prim.attributes)
        new_prim.indices = carry_accessor(new_prim.indices)
        if keep_targets:
            for target in getattr(new_prim, "targets", None) or []:
                remap_accessor_refs(target)
        else:
            new_prim.targets = []
        return new_prim

    # ------------------------------------------------------------------
    # Primitive 0 — UniRig body (retracted face verts)
    # ------------------------------------------------------------------
    body_v = body_verts.astype(np.float32)
    body_i = source_gltf.faces.astype(np.uint32).flatten()
    body_n = (source_gltf.normals.astype(np.float32)
              if source_gltf.normals is not None else None)
    body_uv = (source_gltf.uvs.astype(np.float32)
               if source_gltf.uvs is not None else None)
    body_j = (source_gltf.joints4.astype(np.uint16)
              if source_gltf.joints4 is not None else None)
    body_w = (source_gltf.weights4.astype(np.float32)
              if source_gltf.weights4 is not None else None)

    bi_idx = add_chunk(body_i.tobytes(), UINT32, "SCALAR", len(body_i),
                       min_v=[int(body_i.min())], max_v=[int(body_i.max())])
    bv_idx = add_chunk(body_v.tobytes(), FLOAT32, "VEC3", len(body_v),
                       min_v=body_v.min(axis=0), max_v=body_v.max(axis=0))
    body_attrs = pygltflib.Attributes(POSITION=bv_idx)
    if body_n is not None:
        body_attrs.NORMAL = add_chunk(body_n.tobytes(), FLOAT32, "VEC3", len(body_n))
    if body_uv is not None:
        body_attrs.TEXCOORD_0 = add_chunk(body_uv.tobytes(), FLOAT32, "VEC2", len(body_uv))
    if body_j is not None and body_w is not None:
        body_attrs.JOINTS_0 = add_chunk(body_j.tobytes(), UINT16, "VEC4", len(body_j))
        body_attrs.WEIGHTS_0 = add_chunk(body_w.tobytes(), FLOAT32, "VEC4", len(body_w))
    prim0 = pygltflib.Primitive(
        attributes=body_attrs,
        indices=bi_idx,
        material=source_gltf.material_idx,
        mode=4,  # TRIANGLES
    )

    # ------------------------------------------------------------------
    # Primitive 1 — PSHuman face
    # ------------------------------------------------------------------
    ps_v = ps_verts.astype(np.float32)
    ps_i = ps_faces.astype(np.uint32).flatten()
    if ps_i.size == 0:
        raise ValueError("PSHuman face mesh has no triangles to write")
    ps_j4 = ps_joints.astype(np.uint16)
    ps_w4 = ps_weights.astype(np.float32)

    # PSHuman material — reuse body material for now (same texture look).
    # If PSHuman has its own texture, you'd add a new material here.
    face_mat_idx = source_gltf.material_idx

    fi_idx = add_chunk(ps_i.tobytes(), UINT32, "SCALAR", len(ps_i),
                       min_v=[int(ps_i.min())], max_v=[int(ps_i.max())])
    fv_idx = add_chunk(ps_v.tobytes(), FLOAT32, "VEC3", len(ps_v),
                       min_v=ps_v.min(axis=0), max_v=ps_v.max(axis=0))
    face_attrs = pygltflib.Attributes(POSITION=fv_idx)
    if ps_uvs is not None:
        face_attrs.TEXCOORD_0 = add_chunk(ps_uvs.tobytes(), FLOAT32, "VEC2", len(ps_uvs))
    face_attrs.JOINTS_0 = add_chunk(ps_j4.tobytes(), UINT16, "VEC4", len(ps_j4))
    face_attrs.WEIGHTS_0 = add_chunk(ps_w4.tobytes(), FLOAT32, "VEC4", len(ps_w4))
    prim1 = pygltflib.Primitive(
        attributes=face_attrs,
        indices=fi_idx,
        material=face_mat_idx,
        mode=4,
    )

    # ------------------------------------------------------------------
    # Patch gltf structure
    # ------------------------------------------------------------------
    mesh_node = gltf.nodes[source_gltf.mesh_node_idx]
    old_mesh_idx = mesh_node.mesh
    replacing = old_mesh_idx is not None and old_mesh_idx < len(gltf.meshes)

    # Only primitive 0 was read and rewritten; keep the mesh's other
    # primitives (morph targets stripped so all primitives agree on count).
    extra_prims = ([carry_primitive(p, keep_targets=False)
                    for p in gltf.meshes[old_mesh_idx].primitives[1:]]
                   if replacing else [])
    new_mesh = pygltflib.Mesh(
        name="body_with_pshuman_face",
        primitives=[prim0, prim1, *extra_prims],
    )
    if replacing:
        gltf.meshes[old_mesh_idx] = new_mesh
        target_mesh_idx = old_mesh_idx
    else:
        gltf.meshes.append(new_mesh)
        target_mesh_idx = len(gltf.meshes) - 1
    mesh_node.mesh = target_mesh_idx

    # Every other mesh still points into the original accessor table.
    for mi, mesh in enumerate(gltf.meshes):
        if mi != target_mesh_idx:
            mesh.primitives = [carry_primitive(p) for p in mesh.primitives]

    # Inverse-bind matrices are accessors too; without remapping, each skin
    # would reference an unrelated (or out-of-range) accessor.
    for skin in gltf.skins:
        skin.inverseBindMatrices = carry_accessor(skin.inverseBindMatrices)

    # Replace buffer views and accessors
    gltf.bufferViews = bviews
    gltf.accessors = accors

    # Rewrite buffer, padded to 4-byte alignment
    combined = b"".join(chunks)
    if len(combined) % 4:
        combined += b"\x00" * (4 - len(combined) % 4)
    gltf.buffers = [pygltflib.Buffer(byteLength=len(combined))]
    gltf.set_binary_blob(combined)

    # Animation samplers reference the original accessor table and are not
    # remapped, so they are dropped. The user can re-add animation later.
    gltf.animations = []
    gltf.save(output_path)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="Transplant PSHuman face into UniRig GLB")
parser.add_argument("--body", required=True, help="Rigged UniRig GLB")
parser.add_argument("--face", required=True, help="PSHuman mesh (OBJ/GLB/PLY)")
parser.add_argument("--output", required=True, help="Output GLB path")
parser.add_argument("--head-bones", default="head,Head,skull,Skull",
help="Comma-separated bone name substrings for head detection")
parser.add_argument("--weight-threshold", type=float, default=0.35,
help="Minimum head-bone weight sum to classify a vert as face")
parser.add_argument("--retract", type=float, default=0.004,
help="Metres to retract UniRig face verts inward (default 0.004)")
parser.add_argument("--knn", type=int, default=5,
help="K nearest neighbours for weight transfer")
args = parser.parse_args()
subs = tuple(s.strip() for s in args.head_bones.split(","))
transplant_face(
body_glb_path = args.body,
pshuman_mesh_path = args.face,
output_path = args.output,
head_bone_substrings = subs,
weight_threshold = args.weight_threshold,
retract_amount = args.retract,
knn = args.knn,
)
if __name__ == "__main__":
main()