somatosmpl / tests /test_pose_inversion.py
zirobtc's picture
Upload folder using huggingface_hub
bd95c9c verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for PoseInversion.
Tests both ``fit()`` (analytical Kabsch) and ``fit(autograd_iters=...)``
(FK-based gradient optimization) against ground-truth posed vertices
from example_animation.npy.
Pose conventions
~~~~~~~~~~~~~~~~
- example_animation.npy stores local rotations *relative to T-pose*
(joint orient not applied). demo_soma_vis.py applies a t-pose
correction before passing to ``soma.pose(absolute_pose=False)``.
- Both ``fit()`` and ``fit(autograd_iters=...)`` return *absolute*
local rotations (joint orient already baked in), suitable for
``soma.pose(absolute_pose=True)`` or direct LBS via
``BatchedSkinning.pose(absolute_pose=True)``.
Requires CUDA and assets/.
"""
from pathlib import Path
import numpy as np
import pytest
import torch
# Repository layout assumed by these tests:
#   <repo>/tests/<this file>
#   <repo>/assets/example_animation.npy
REPO_ROOT = Path(__file__).resolve().parent.parent
ASSETS_DIR = REPO_ROOT.joinpath("assets")
MOTION_FILE = ASSETS_DIR.joinpath("example_animation.npy")
# Joint-name tables used to remap example_animation.npy (taken from
# demo_soma_vis.py).
#
# _NVSKEL93_NAMES: the 93 named joints of the source skeleton. The motion
#   file stores 94 joints because a root joint precedes these 93 (hence the
#   ``i + 1`` offset when _load_motion builds its subset indices).
# _NVSKEL77_NAMES: the 77-joint subset SOMA poses -- the same joints in the
#   same order, minus the 16 "*Twist*" helper joints (arm, forearm, leg and
#   shin twists on each side).
# fmt: off
_NVSKEL93_NAMES = [
"Hips", "Spine1", "Spine2", "Chest", "Neck1", "Neck2", "Head", "HeadEnd", "Jaw",
"LeftEye", "RightEye", "LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand",
"LeftHandThumb1", "LeftHandThumb2", "LeftHandThumb3", "LeftHandThumbEnd",
"LeftHandIndex1", "LeftHandIndex2", "LeftHandIndex3", "LeftHandIndex4", "LeftHandIndexEnd",
"LeftHandMiddle1", "LeftHandMiddle2", "LeftHandMiddle3", "LeftHandMiddle4", "LeftHandMiddleEnd",
"LeftHandRing1", "LeftHandRing2", "LeftHandRing3", "LeftHandRing4", "LeftHandRingEnd",
"LeftHandPinky1", "LeftHandPinky2", "LeftHandPinky3", "LeftHandPinky4", "LeftHandPinkyEnd",
"LeftForeArmTwist1", "LeftForeArmTwist2", "LeftArmTwist1", "LeftArmTwist2",
"RightShoulder", "RightArm", "RightForeArm", "RightHand",
"RightHandThumb1", "RightHandThumb2", "RightHandThumb3", "RightHandThumbEnd",
"RightHandIndex1", "RightHandIndex2", "RightHandIndex3", "RightHandIndex4", "RightHandIndexEnd",
"RightHandMiddle1", "RightHandMiddle2", "RightHandMiddle3", "RightHandMiddle4", "RightHandMiddleEnd",
"RightHandRing1", "RightHandRing2", "RightHandRing3", "RightHandRing4", "RightHandRingEnd",
"RightHandPinky1", "RightHandPinky2", "RightHandPinky3", "RightHandPinky4", "RightHandPinkyEnd",
"RightForeArmTwist1", "RightForeArmTwist2", "RightArmTwist1", "RightArmTwist2",
"LeftLeg", "LeftShin", "LeftFoot", "LeftToeBase", "LeftToeEnd",
"LeftShinTwist1", "LeftShinTwist2", "LeftLegTwist1", "LeftLegTwist2",
"RightLeg", "RightShin", "RightFoot", "RightToeBase", "RightToeEnd",
"RightShinTwist1", "RightShinTwist2", "RightLegTwist1", "RightLegTwist2",
]
_NVSKEL77_NAMES = [
"Hips", "Spine1", "Spine2", "Chest", "Neck1", "Neck2", "Head", "HeadEnd", "Jaw",
"LeftEye", "RightEye",
"LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand",
"LeftHandThumb1", "LeftHandThumb2", "LeftHandThumb3", "LeftHandThumbEnd",
"LeftHandIndex1", "LeftHandIndex2", "LeftHandIndex3", "LeftHandIndex4", "LeftHandIndexEnd",
"LeftHandMiddle1", "LeftHandMiddle2", "LeftHandMiddle3", "LeftHandMiddle4", "LeftHandMiddleEnd",
"LeftHandRing1", "LeftHandRing2", "LeftHandRing3", "LeftHandRing4", "LeftHandRingEnd",
"LeftHandPinky1", "LeftHandPinky2", "LeftHandPinky3", "LeftHandPinky4", "LeftHandPinkyEnd",
"RightShoulder", "RightArm", "RightForeArm", "RightHand",
"RightHandThumb1", "RightHandThumb2", "RightHandThumb3", "RightHandThumbEnd",
"RightHandIndex1", "RightHandIndex2", "RightHandIndex3", "RightHandIndex4", "RightHandIndexEnd",
"RightHandMiddle1", "RightHandMiddle2", "RightHandMiddle3", "RightHandMiddle4", "RightHandMiddleEnd",
"RightHandRing1", "RightHandRing2", "RightHandRing3", "RightHandRing4", "RightHandRingEnd",
"RightHandPinky1", "RightHandPinky2", "RightHandPinky3", "RightHandPinky4", "RightHandPinkyEnd",
"LeftLeg", "LeftShin", "LeftFoot", "LeftToeBase", "LeftToeEnd",
"RightLeg", "RightShin", "RightFoot", "RightToeBase", "RightToeEnd",
]
# fmt: on
# For each of the 77 SOMA joints, its 0-based position in _NVSKEL93_NAMES
# (root not counted). list.index raises ValueError at import time if a
# 77-skeleton name is missing from the 93-joint table.
_93TO77_IDX = [_NVSKEL93_NAMES.index(n) for n in _NVSKEL77_NAMES]
# Skip marker for GPU-only tests; the CUDA check runs once, at import time.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA not available",
)
def _load_motion(soma, frames):
    """Load example_animation.npy frames, return ground-truth posed vertices.

    Follows the same pipeline as tools/demo_soma_vis.py:

    1. Remap 94-joint → 78-joint (root + 77)
    2. Apply t-pose correction
    3. Forward pass through soma.pose()

    Only the requested frames are converted, moved to ``soma.device`` and
    pushed through the correction / forward pass; the rest of the clip is
    discarded right after loading.

    Args:
        soma: SOMALayer whose ``device``, ``t_pose_world``,
            ``joint_parent_ids`` and ``pose()`` are used.
        frames: List of frame indices into the animation clip.

    Returns:
        (posed_vertices, root_translation) for the selected frames.
    """
    from soma.geometry.rig_utils import joint_local_to_world, joint_world_to_local

    device = soma.device
    # Select frames on the CPU copy before the GPU transfer so the t-pose
    # correction and FK only touch the frames we return. The rig_utils FK
    # is assumed to act independently per frame, so results are unchanged.
    motion = torch.from_numpy(np.load(MOTION_FILE))[frames].float().to(device)
    rot_local = motion[..., :3, :3]
    # Translation of joint 1 (Hips); index 0 is the file's root joint.
    # .contiguous() keeps this a dense tensor, as frame indexing produced before.
    transl = motion[:, 1, :3, 3].contiguous()

    # Remap 94 → 78 joints (root + 77)
    if rot_local.shape[1] == 94:
        subset_idx = [0] + [i + 1 for i in _93TO77_IDX]
        rot_local = rot_local[:, subset_idx]

    # T-pose correction: animation data is in a different skeleton
    # convention; rotate world transforms to match SOMA's joint orient.
    correction = soma.t_pose_world[:, :3, :3].transpose(-2, -1)
    rot_world = joint_local_to_world(rot_local, soma.joint_parent_ids)
    rot_world = rot_world @ correction
    rot_local = joint_world_to_local(rot_world, soma.joint_parent_ids)

    # Drop the root (index 0): global_orient (Hips = joint 1) followed by the
    # body joints (2:), i.e. the 77 joints soma.pose() expects.
    pose = rot_local[:, 1:].contiguous()

    # Forward pass — these rotations are relative to T-pose
    with torch.no_grad():
        out = soma.pose(pose, transl=transl, pose2rot=False, absolute_pose=False)
    return out["vertices"], transl
@pytest.fixture(scope="module")
def soma_and_inv():
    """Module-scoped fixture returning ``(soma, inv)`` for the mean identity.

    ``soma`` is a low-LOD, warp-mode SOMALayer on CUDA built from
    ``ASSETS_DIR``; ``inv`` is a low-LOD PoseInversion bound to it. Both are
    prepared with all-zero identity coefficients (the mean shape).

    Skips when CUDA is unavailable; fails if the asset directory or the
    example motion file is missing.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    problem = None
    if not ASSETS_DIR.is_dir():
        problem = f"Assets directory not found: {ASSETS_DIR}"
    elif not MOTION_FILE.is_file():
        problem = f"Motion file not found: {MOTION_FILE}"
    if problem is not None:
        pytest.fail(problem)

    from soma.pose_inversion import PoseInversion
    from soma.soma import SOMALayer

    soma = SOMALayer(
        data_root=str(ASSETS_DIR),
        identity_model_type="soma",
        device="cuda",
        mode="warp",
        low_lod=True,
    )

    # Zero identity coefficients -> mean body shape, shared by both objects.
    mean_identity = torch.zeros(
        1, soma.identity_model.num_identity_coeffs, device="cuda"
    )
    soma.prepare_identity(mean_identity)

    inv = PoseInversion(soma, low_lod=True)
    inv.prepare_identity(mean_identity)
    return soma, inv
@requires_cuda
class TestInvert:
    """Tests for PoseInversion.fit() (analytical Kabsch)."""

    # fit() returns one rotation per joint: the root plus the 77 SOMA joints
    # (78 total). Pinned rather than read back from the result, so a wrong
    # joint count actually fails the shape checks.
    NUM_FIT_JOINTS = len(_NVSKEL77_NAMES) + 1

    def test_single_frame_roundtrip(self, soma_and_inv):
        """Single frame: fit recovers pose with low error."""
        soma, inv = soma_and_inv
        verts, _ = _load_motion(soma, frames=[0])
        result = inv.fit(verts, body_iters=10, finger_iters=2, full_iters=1)

        assert result["rotations"].shape == (1, self.NUM_FIT_JOINTS, 3, 3)
        assert result["root_translation"].shape == (1, 3)
        assert result["per_vertex_error"].shape[0] == 1

        mean_err = result["per_vertex_error"].mean().item()
        max_err = result["per_vertex_error"].max().item()
        assert mean_err < 0.01, f"Mean vertex error too high: {mean_err:.6f} m"
        assert max_err < 0.05, f"Max vertex error too high: {max_err:.6f} m"

    def test_batch_roundtrip(self, soma_and_inv):
        """Multiple diverse frames: consistent low error across batch."""
        soma, inv = soma_and_inv
        verts, _ = _load_motion(soma, frames=[0, 100, 300, 600])
        result = inv.fit(verts, body_iters=10, finger_iters=2, full_iters=1)

        assert result["rotations"].shape == (4, self.NUM_FIT_JOINTS, 3, 3)
        assert result["root_translation"].shape == (4, 3)
        assert result["per_vertex_error"].shape[0] == 4

        mean_err = result["per_vertex_error"].mean().item()
        assert mean_err < 0.01, f"Mean vertex error too high: {mean_err:.6f} m"

    def test_roundtrip_forward_pass(self, soma_and_inv):
        """Verify inverted rotations reproduce vertices via soma.pose().

        fit returns absolute local rotations for 78 joints
        (root + 77). Strip the root (index 0) and pass to
        soma.pose(absolute_pose=True) to reconstruct.
        """
        soma, inv = soma_and_inv
        verts_gt, _ = _load_motion(soma, frames=[50, 200])
        result = inv.fit(verts_gt, body_iters=10, finger_iters=2, full_iters=1)

        # Strip root joint (index 0) — soma.pose() expects 77 joints
        rotations_no_root = result["rotations"][:, 1:]

        # fit uses raw LBS without correctives, so disable
        # correctives in the forward pass for a fair comparison.
        with torch.no_grad():
            out = soma.pose(
                rotations_no_root,
                transl=result["root_translation"],
                pose2rot=False,
                absolute_pose=True,
                apply_correctives=False,
            )
        verts_recon = out["vertices"]

        # Guard against silent broadcasting in the subtraction below.
        assert verts_recon.shape == verts_gt.shape, (
            f"Shape mismatch: {tuple(verts_recon.shape)} vs {tuple(verts_gt.shape)}"
        )
        err = torch.norm(verts_recon - verts_gt, dim=-1)
        mean_err = err.mean().item()

        # Slightly higher threshold than internal per_vertex_error because
        # soma.pose() uses full skinning weights while fit
        # uses sparse top-K weights internally.
        assert mean_err < 0.02, f"Forward-pass roundtrip error too high: {mean_err:.6f} m"

    def test_batch_size_chunking(self, soma_and_inv):
        """batch_size parameter produces comparable results to all-at-once."""
        soma, inv = soma_and_inv
        verts, _ = _load_motion(soma, frames=[0, 50, 100, 150])

        result_all = inv.fit(verts, body_iters=5, finger_iters=2)
        result_chunked = inv.fit(verts, body_iters=5, finger_iters=2, batch_size=2)

        assert result_chunked["rotations"].shape == result_all["rotations"].shape
        assert (
            result_chunked["per_vertex_error"].shape
            == result_all["per_vertex_error"].shape
        )

        # Analytical is deterministic, so results should be very close
        err_all = result_all["per_vertex_error"].mean().item()
        err_chunked = result_chunked["per_vertex_error"].mean().item()
        assert abs(err_all - err_chunked) < 0.005, (
            f"Chunked vs all-at-once error mismatch: {err_all:.6f} vs {err_chunked:.6f}"
        )

    def test_identity_pose_near_zero_error(self, soma_and_inv):
        """Rest pose (identity rotations) should fit with near-zero error."""
        soma, inv = soma_and_inv
        device = soma.device

        # soma.pose() takes the 77 non-root joints.
        num_joints = len(_NVSKEL77_NAMES)
        rot_mats = torch.eye(3, device=device).expand(1, num_joints, 3, 3).clone()
        transl = torch.zeros(1, 3, device=device)

        with torch.no_grad():
            out = soma.pose(rot_mats, transl=transl, pose2rot=False)
        verts = out["vertices"]

        result = inv.fit(verts, body_iters=5, finger_iters=2, full_iters=1)
        mean_err = result["per_vertex_error"].mean().item()
        assert mean_err < 0.02, f"Identity pose error too high: {mean_err:.6f} m"
@requires_cuda
class TestInvertAutogradFK:
    """Tests for PoseInversion.fit(autograd_iters=...)."""

    # fit() returns one rotation per joint: the root plus the 77 SOMA joints
    # (78 total). Pinned rather than read back from the result, so a wrong
    # joint count actually fails the shape checks.
    NUM_FIT_JOINTS = len(_NVSKEL77_NAMES) + 1

    def test_single_frame_roundtrip(self, soma_and_inv):
        """Single frame: fit(autograd_iters) recovers pose with low error."""
        soma, inv = soma_and_inv
        verts, _ = _load_motion(soma, frames=[0])
        result = inv.fit(verts, body_iters=0, full_iters=0, autograd_iters=20, autograd_lr=5e-3)

        assert result["rotations"].shape == (1, self.NUM_FIT_JOINTS, 3, 3)
        assert result["root_translation"].shape == (1, 3)
        assert result["per_vertex_error"].shape[0] == 1

        mean_err = result["per_vertex_error"].mean().item()
        assert mean_err < 0.01, f"Mean vertex error too high: {mean_err:.6f} m"

    def test_batch_roundtrip(self, soma_and_inv):
        """Multiple diverse frames: consistent low error across batch."""
        soma, inv = soma_and_inv
        verts, _ = _load_motion(soma, frames=[0, 100, 300, 600])
        result = inv.fit(verts, body_iters=0, full_iters=0, autograd_iters=20, autograd_lr=5e-3)

        assert result["rotations"].shape == (4, self.NUM_FIT_JOINTS, 3, 3)
        assert result["root_translation"].shape == (4, 3)
        assert result["per_vertex_error"].shape[0] == 4

        mean_err = result["per_vertex_error"].mean().item()
        assert mean_err < 0.01, f"Mean vertex error too high: {mean_err:.6f} m"