# Image2Model / Retarget — io/gltf_io.py
# Ported from MeshForge (commit 8f1bcd9): FireRed, PSHuman, Motion Search features.
"""
io/gltf_io.py
Load a glTF/GLB skeleton (e.g. UniRig output) into an Armature.
Write retargeted animation back into a glTF/GLB file.
Requires: pip install pygltflib
"""
from __future__ import annotations
import base64
import json
import struct
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import pygltflib
except ImportError:
raise ImportError("pip install pygltflib")
from ..skeleton import Armature, PoseBone
from ..math3d import (
quat_identity, quat_normalize, matrix4_to_quat, matrix4_to_trs,
trs_to_matrix4, vec3,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _node_local_trs(node: "pygltflib.Node"):
"""Extract TRS from a glTF node. Returns (t[3], r_wxyz[4], s[3])."""
t = np.array(node.translation or [0.0, 0.0, 0.0])
r_xyzw = np.array(node.rotation or [0.0, 0.0, 0.0, 1.0])
s = np.array(node.scale or [1.0, 1.0, 1.0])
# Convert glTF (x,y,z,w) → our (w,x,y,z)
r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]])
return t, r_wxyz, s
def _node_local_matrix(node: "pygltflib.Node") -> np.ndarray:
if node.matrix:
# glTF stores column-major; convert to row-major
m = np.array(node.matrix, dtype=float).reshape(4, 4).T
return m
t, r, s = _node_local_trs(node)
return trs_to_matrix4(t, r, s)
def _read_accessor(gltf: "pygltflib.GLTF2", accessor_idx: int) -> np.ndarray:
"""Read a glTF accessor into a numpy array."""
acc = gltf.accessors[accessor_idx]
bv = gltf.bufferViews[acc.bufferView]
buf = gltf.buffers[bv.buffer]
# Inline base64 data URI
if buf.uri and buf.uri.startswith("data:"):
_, b64 = buf.uri.split(",", 1)
raw = base64.b64decode(b64)
elif buf.uri:
base_dir = Path(gltf._path).parent if hasattr(gltf, "_path") and gltf._path else Path(".")
raw = (base_dir / buf.uri).read_bytes()
else:
# Binary GLB — data stored in gltf.binary_blob
raw = bytes(gltf.binary_blob())
start = bv.byteOffset + (acc.byteOffset or 0)
count = acc.count
type_to_components = {
"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4,
"MAT2": 4, "MAT3": 9, "MAT4": 16,
}
component_type_to_fmt = {
5120: "b", 5121: "B", 5122: "h", 5123: "H",
5125: "I", 5126: "f",
}
n_comp = type_to_components[acc.type]
fmt = component_type_to_fmt[acc.componentType]
item_size = struct.calcsize(fmt) * n_comp
stride = bv.byteStride or item_size
items = []
for i in range(count):
offset = start + i * stride
vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
items.append(vals)
return np.array(items, dtype=float).squeeze()
# ---------------------------------------------------------------------------
# Load skeleton from glTF
# ---------------------------------------------------------------------------
def load_gltf(filepath: str, skin_index: int = 0) -> Armature:
    """
    Load the first (or specified) skin from a glTF/GLB file into an Armature.

    Bind-pose world matrices are recovered by inverting the skin's
    inverseBindMatrices accessor (identity fallback when absent), and each
    bone's rest-local matrix is expressed relative to its nearest ancestor
    node that is also a skin joint.  The armature world_matrix is set to
    identity (typical for UniRig output).

    Raises:
        ValueError: if the file contains no skins.
    """
    gltf = pygltflib.GLTF2().load(filepath)
    gltf._path = filepath  # remembered so _read_accessor can resolve external .bin URIs
    if not gltf.skins:
        raise ValueError(f"No skins found in '{filepath}'")
    skin = gltf.skins[skin_index]
    # Read inverse bind matrices (one 4x4 per joint, column-major in glTF).
    n_joints = len(skin.joints)
    ibm_array: Optional[np.ndarray] = None
    if skin.inverseBindMatrices is not None:
        raw = _read_accessor(gltf, skin.inverseBindMatrices)
        ibm_array = raw.reshape(n_joints, 4, 4)
    # Compute bind-pose world matrices: world_bind = inv(ibm)
    joint_world_bind: Dict[int, np.ndarray] = {}
    for i, j_idx in enumerate(skin.joints):
        if ibm_array is not None:
            ibm = ibm_array[i].T  # glTF column-major → numpy row-major
            joint_world_bind[j_idx] = np.linalg.inv(ibm)
        else:
            # Fallback: no IBMs provided; treat the joint as bound at identity.
            joint_world_bind[j_idx] = np.eye(4)
    # Build parent map for nodes (glTF only stores children lists).
    parent_of: Dict[int, Optional[int]] = {}
    for ni, node in enumerate(gltf.nodes):
        for child_idx in (node.children or []):
            parent_of[child_idx] = ni
    arm = Armature(skin.name or f"Skin_{skin_index}")
    # Process joints in order (parent always before child in glTF spec)
    joint_set = set(skin.joints)
    processed: Dict[int, str] = {}  # node index -> bone name, in creation order
    for i, j_idx in enumerate(skin.joints):
        node = gltf.nodes[j_idx]
        bone_name = node.name or f"joint_{i}"
        # Walk up the node hierarchy to the nearest ancestor that is a joint.
        parent_node_idx = parent_of.get(j_idx)
        while parent_node_idx is not None and parent_node_idx not in joint_set:
            parent_node_idx = parent_of.get(parent_node_idx)
        parent_bone_name: Optional[str] = (
            processed.get(parent_node_idx) if parent_node_idx is not None else None
        )
        # rest_matrix_local expressed in the parent joint's space
        # (world space for roots / joints whose parent isn't processed yet).
        if parent_bone_name is not None:
            # Use the parent's node index directly instead of a fragile
            # name-based reverse lookup (which broke on duplicate names).
            parent_world = joint_world_bind.get(parent_node_idx, np.eye(4))
            rest_local = np.linalg.inv(parent_world) @ joint_world_bind[j_idx]
        else:
            rest_local = joint_world_bind[j_idx]
        bone = PoseBone(bone_name, rest_local)
        arm.add_bone(bone, parent_bone_name)
        processed[j_idx] = bone_name
    arm.update_fk()
    return arm
# ---------------------------------------------------------------------------
# Write animation to glTF
# ---------------------------------------------------------------------------
def write_gltf_animation(
    source_filepath: str,
    dest_armature: Armature,
    keyframes: List[Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]],
    output_filepath: str,
    fps: float = 30.0,
    skin_index: int = 0,
) -> None:
    """
    Embed animation keyframes into a copy of source_filepath (the UniRig GLB).

    keyframes: list of dicts, one per frame.
        Each dict maps bone_name → (pose_location, pose_rotation_quat, pose_scale)
        These are LOCAL values (relative to rest pose local matrix);
        quaternions use (w, x, y, z) order.
    The function adds one glTF Animation with translation, rotation and
    scale channels for each bone that has data, appends the sampler data
    to buffer 0's binary blob, and saves the result to output_filepath.

    Raises:
        ValueError: if keyframes is empty or the source file has no skins.
    """
    from ..math3d import quat_mul  # hoisted out of the per-frame loop

    if not keyframes:
        # An empty list would otherwise crash later in data.max().
        raise ValueError("keyframes is empty — nothing to write")
    gltf = pygltflib.GLTF2().load(source_filepath)
    gltf._path = source_filepath  # lets _read_accessor resolve external URIs
    if not gltf.skins:
        raise ValueError("No skins in source file")
    skin = gltf.skins[skin_index]
    # Build node_name → node_index map for skin joints
    joint_name_to_node: Dict[str, int] = {}
    for j_idx in skin.joints:
        node = gltf.nodes[j_idx]
        joint_name_to_node[node.name or f"joint_{j_idx}"] = j_idx
    n_frames = len(keyframes)
    times = np.array([i / fps for i in range(n_frames)], dtype=np.float32)
    # New binary payload; appended to buffer 0 just before saving.
    binary_chunks: List[bytes] = []

    def _add_data(data: np.ndarray, acc_type: str) -> int:
        """Append a float32 array as bufferView + accessor; return accessor index."""
        raw = data.astype(np.float32).tobytes()
        bv_offset = sum(len(c) for c in binary_chunks)  # offset within the new payload
        binary_chunks.append(raw)
        bv_idx = len(gltf.bufferViews)
        gltf.bufferViews.append(pygltflib.BufferView(
            buffer=0,
            byteOffset=bv_offset,  # patched to an absolute offset before saving
            byteLength=len(raw),
        ))
        acc_idx = len(gltf.accessors)
        gltf.accessors.append(pygltflib.Accessor(
            bufferView=bv_idx,
            componentType=pygltflib.FLOAT,
            count=len(data),
            type=acc_type,
            # min/max are required for animation sampler inputs.
            max=data.max(axis=0).tolist() if data.ndim > 1 else [float(data.max())],
            min=data.min(axis=0).tolist() if data.ndim > 1 else [float(data.min())],
        ))
        return acc_idx

    time_acc_idx = _add_data(times, "SCALAR")
    channels: List[pygltflib.AnimationChannel] = []
    samplers: List[pygltflib.AnimationSampler] = []

    def _add_channel(data: np.ndarray, acc_type: str, node_idx: int, path: str) -> None:
        """Create one LINEAR sampler + channel pair targeting node_idx/path."""
        out_acc = _add_data(data, acc_type)
        samplers.append(pygltflib.AnimationSampler(
            input=time_acc_idx, output=out_acc, interpolation="LINEAR"))
        channels.append(pygltflib.AnimationChannel(
            sampler=len(samplers) - 1,
            target=pygltflib.AnimationChannelTarget(node=node_idx, path=path),
        ))

    bone_names = set()
    for frame in keyframes:
        bone_names |= frame.keys()
    for bone_name in sorted(bone_names):
        if bone_name not in joint_name_to_node:
            continue
        node_idx = joint_name_to_node[bone_name]
        node = gltf.nodes[node_idx]
        rest_t, rest_r, rest_s = _node_local_trs(node)
        # Collect final local TRS across frames: rest composed with pose delta.
        rot_data = np.zeros((n_frames, 4), dtype=np.float32)  # glTF (x,y,z,w)
        trans_data = np.zeros((n_frames, 3), dtype=np.float32)
        scale_data = np.ones((n_frames, 3), dtype=np.float32)
        for fi, frame in enumerate(keyframes):
            if bone_name in frame:
                pose_loc, pose_rot, pose_scale = frame[bone_name]
            else:
                # Bone absent on this frame → rest pose.
                pose_loc = vec3()
                pose_rot = quat_identity()
                pose_scale = np.ones(3)
            # Final local = rest + delta (addition for translation,
            # quaternion multiply for rotation, elementwise for scale).
            final_t = rest_t + pose_loc
            final_r = quat_mul(rest_r, pose_rot)  # (w,x,y,z)
            final_s = rest_s * pose_scale
            w, x, y, z = final_r
            rot_data[fi] = [x, y, z, w]  # reorder to glTF convention
            trans_data[fi] = final_t
            scale_data[fi] = final_s
        _add_channel(rot_data, "VEC4", node_idx, "rotation")
        _add_channel(trans_data, "VEC3", node_idx, "translation")
        # Scale was previously computed but never exported; write it too.
        _add_channel(scale_data, "VEC3", node_idx, "scale")
    if not channels:
        print("[gltf_io] Warning: no channels written — check bone name mapping.")
        return
    gltf.animations.append(pygltflib.Animation(
        name="RetargetedAnimation",
        samplers=samplers,
        channels=channels,
    ))
    # Append the new payload to buffer 0 and fix up the new views' offsets.
    new_blob = b"".join(binary_chunks)
    blob = gltf.binary_blob()
    existing_blob = bytes(blob) if blob else b""
    # Each _add_data call appended exactly one bufferView, so the last
    # len(binary_chunks) views are ours; rebase them past the old blob.
    for bv in gltf.bufferViews[-len(binary_chunks):]:
        bv.byteOffset += len(existing_blob)
    gltf.set_binary_blob(existing_blob + new_blob)
    gltf.buffers[0].byteLength = len(existing_blob) + len(new_blob)
    gltf.save(output_filepath)
    print(f"[gltf_io] Saved animated GLB -> {output_filepath}")