somatosmpl / soma /geometry /skeleton_transfer.py
zirobtc's picture
Upload folder using huggingface_hub
bd95c9c verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from ._utils import one_hot_1d, require_torch_tensors
from .interpolate import RadialBasisFunction
from .rig_utils import get_joint_children_ids, joint_world_to_local
from .transforms import SE3_from_Rt, align_vectors
try:
from .align_vectors_warp import (
align_vectors_warp,
parallel_rodrigues_kabsch_warp,
rodrigues_rotation_warp,
)
except ImportError:
align_vectors_warp = None
rodrigues_rotation_warp = None
parallel_rodrigues_kabsch_warp = None
class SkeletonTransfer(torch.nn.Module):
def __init__(
    self,
    joint_parent_ids,
    bind_world_transforms,
    bind_shape,
    skinning_weights,
    rbf_kernel="linear",
    vertex_ids_to_exclude=None,
    freeze_rotations=None,
    skip_endjoints=True,
    use_sparse_rbf_matrix=True,
    use_warp_for_rotations=True,
    rotation_method="kabsch",
    skip_inverse_lbs=False,
):
    """Initialize a SkeletonTransfer instance for fitting a skeleton to new shapes.

    Args:
        joint_parent_ids: (J,) int array of joint parent indices
        bind_world_transforms: (J, 4, 4) array of joint bind poses in world space
        bind_shape: (V, 3) array of vertex positions in bind pose
        skinning_weights: (V, J) array of skinning weights
        rbf_kernel: type of RBF kernel to use for joint position regression
        vertex_ids_to_exclude: int array of vertex ids to exclude from the joint position regressors
        freeze_rotations: joint ids to freeze to bind pose (list, set or tensor)
        skip_endjoints: bool, whether to skip rotation fitting for end joints
        use_sparse_rbf_matrix: bool, whether to use a sparse RBF matrix for joint position regression
        use_warp_for_rotations: bool, whether to use Warp-based rotation fitting (requires Warp)
        rotation_method: str, rotation extraction method ('kabsch' or 'newton-schulz')
        skip_inverse_lbs: bool, whether to skip Inverse LBS (skinned vertex fitting) and use identity R_init

    Raises:
        AssertionError: if the joint counts implied by ``joint_parent_ids``,
            ``bind_world_transforms`` and ``skinning_weights`` disagree.
        ImportError: if ``use_warp_for_rotations`` is set but the Warp helpers
            could not be imported.
    """
    super().__init__()
    require_torch_tensors(
        bind_world_transforms, bind_shape, skinning_weights, name="SkeletonTransfer inputs"
    )
    assert len(joint_parent_ids) == len(bind_world_transforms) == skinning_weights.shape[1], (
        "joint_parent_ids, bind_world_transforms, and skinning_weights.shape[1] must have the same length"
    )
    self.num_joints = len(joint_parent_ids)

    # Index-like arguments are kept as plain Python containers on the host.
    if isinstance(joint_parent_ids, torch.Tensor):
        self.joint_parent_ids = joint_parent_ids.detach().cpu().tolist()
    else:
        self.joint_parent_ids = joint_parent_ids
    self.joint_children_ids = get_joint_children_ids(self.joint_parent_ids)

    # Bind data is stored detached so the buffers never hold an autograd graph.
    bind_world_transforms = bind_world_transforms.detach()
    bind_shape = bind_shape.detach()
    skinning_weights = skinning_weights.detach()
    self.register_buffer("bind_world_transforms", bind_world_transforms, persistent=False)
    self.register_buffer(
        "bind_local_transforms",
        joint_world_to_local(bind_world_transforms, joint_parent_ids),
        persistent=False,
    )
    self.register_buffer("bind_shape", bind_shape, persistent=False)
    self.register_buffer("skinning_weights", skinning_weights, persistent=False)
    self.register_buffer("regressor_mask", None, persistent=False)

    self.sparse_rbf_matrix = None
    self.joint_pos_regressors = None
    self.rbf_kernel = rbf_kernel
    if isinstance(vertex_ids_to_exclude, torch.Tensor):
        self.vertex_ids_to_exclude = vertex_ids_to_exclude.detach().cpu().tolist()
    else:
        self.vertex_ids_to_exclude = vertex_ids_to_exclude

    # Same normalisation as the other id arguments: a tensor of joint ids is
    # converted to Python ints so that ``i in self.freeze_rotations`` works.
    if isinstance(freeze_rotations, torch.Tensor):
        freeze_rotations = freeze_rotations.detach().cpu().tolist()
    self.freeze_rotations = set(freeze_rotations) if freeze_rotations is not None else set()

    self.skip_endjoints = skip_endjoints
    self.use_sparse_rbf_matrix = use_sparse_rbf_matrix
    self.rotation_method = rotation_method
    self.skip_inverse_lbs = skip_inverse_lbs
    self._precompute_regressors()

    # Warp-specific precomputed data. The buffers start as None and are
    # filled by _precompute_warp_data(). Each name is registered exactly once,
    # in a fixed order.
    warp_buffer_names = (
        "_warp_stage1_offsets",
        "_warp_stage1_counts",
        "_warp_stage1_skinned_vids_flat",
        "_warp_stage1_joint_indices",
        "_warp_stage1_repeat_indices",
        "_warp_stage1_skinned_orig_centered",
        "_warp_stage2_offsets",
        "_warp_stage2_counts",
        "_warp_stage2_child_flat",
        "_warp_stage2_joint_indices",
        "_warp_stage2_repeat_indices",
        "_warp_stage2_R_repeat_indices",
        "_warp_frozen_joints",
        "_warp_unskinned_end_joints",
        "_warp_unskinned_end_joint_parents",
        "_warp_stage2_n1_joint_indices",
        "_warp_stage2_n1_child_indices",
        "_warp_stage2_n1_to_stage1_indices",
        "_warp_stage2_n1_children_orig_centered",
        "_warp_stage2_n2_offsets",
        "_warp_stage2_n2_counts",
        "_warp_stage2_n2_child_flat",
        "_warp_stage2_n2_joint_indices",
        "_warp_stage2_n2_to_stage1_indices",
        "_warp_stage2_n2_children_orig_centered",
        "_warp_stage2_n2_repeat_indices",
        "_warp_stage2_n2_R_repeat_indices",
        "_warp_frozen_parents",
    )
    for buffer_name in warp_buffer_names:
        self.register_buffer(buffer_name, None, persistent=False)
    # Plain (non-buffer) lookup tables from joint id to batch slot.
    self._warp_stage1_joint_to_batch_idx = None
    self._warp_stage2_joint_to_batch_idx = None

    self.use_warp_for_rotations = use_warp_for_rotations
    if use_warp_for_rotations:
        if align_vectors_warp is None or rodrigues_rotation_warp is None:
            raise ImportError("Warp-based rotation fitting requires Warp to be installed.")
        self._precompute_warp_data()
def update_bind(self, bind_world_transforms, bind_shape):
    """Update bind-pose data without rebuilding structural caches.

    This is much faster than constructing a new SkeletonTransfer and is
    suitable when only the identity (shape) changes but the topology,
    skinning weights, and skeleton structure remain the same.

    Args:
        bind_world_transforms: new joint bind poses in world space
        bind_shape: new vertex positions in bind pose

    All three stored tensors are detached from autograd, as in ``__init__``.

    NOTE(review): the RBF regressors / sparse RBF matrix and the centered
    bind offsets cached by ``_precompute_warp_data`` were computed from the
    previous bind data and are not refreshed here. Confirm callers either do
    not need them refreshed or trigger a recompute themselves.
    """
    # Detach once and derive the local transforms from the detached tensor so
    # the buffer cannot keep the caller's autograd graph alive.
    bind_world_transforms = bind_world_transforms.detach()
    self.bind_world_transforms = bind_world_transforms
    self.bind_local_transforms = joint_world_to_local(
        bind_world_transforms, self.joint_parent_ids
    )
    self.bind_shape = bind_shape.detach()
@property
def device(self):
return self.bind_world_transforms.device
@property
def dtype(self):
return self.bind_world_transforms.dtype
def _apply(self, fn):
super()._apply(fn)
self._precompute_regressors()
if self.use_warp_for_rotations:
self._precompute_warp_data()
return self
def _precompute_regressors(self):
    """Build the per-joint vertex masks and the RBF joint-position regressors.

    Sets ``self.regressor_mask`` (a (V, J) boolean mask of the vertices that
    drive each joint), ``self.joint_pos_regressors`` (one
    ``RadialBasisFunction`` per joint, ``None`` for joint 0) and, when
    ``use_sparse_rbf_matrix`` is set, ``self.sparse_rbf_matrix`` (a CSR
    matrix with one row per joint and one column per vertex).
    """
    # A vertex supports joint i when it has positive weight for joint i and
    # for joint i's parent.
    regressor_mask = self.skinning_weights > 0.0
    regressor_mask &= self.skinning_weights[:, self.joint_parent_ids] > 0.0
    # Joints whose mask column selected no vertex at all.
    zero_weight_ids = torch.where(regressor_mask.sum(dim=0) == 0.0)[0]
    joint_parent_ids = torch.as_tensor(
        self.joint_parent_ids, dtype=torch.long, device=self.device
    )
    # Ancestor pointer per joint; starts at the direct parent and is moved
    # one level up the hierarchy on every loop iteration below.
    joint_parent_ids_cur = joint_parent_ids.clone()
    # First fallback for empty joints: use every vertex weighted to the joint itself.
    regressor_mask[:, zero_weight_ids] = self.skinning_weights[:, zero_weight_ids] > 0.0
    # Second fallback: while more than one joint is still empty, also accept
    # vertices weighted to the current ancestor of each empty joint.
    # NOTE(review): the threshold is "> 1", not "> 0" - presumably one joint
    # (the root?) is expected to stay empty; confirm.
    while len(zero_weight_ids) > 1:
        regressor_mask[:, zero_weight_ids] |= (
            self.skinning_weights[:, joint_parent_ids_cur][:, zero_weight_ids] > 0.0
        )
        zero_weight_ids = torch.where(regressor_mask.sum(dim=0) == 0.0)[0]
        joint_parent_ids_cur_update = joint_parent_ids[joint_parent_ids_cur]
        # Stop once climbing the hierarchy no longer changes any pointer.
        if torch.equal(joint_parent_ids_cur_update, joint_parent_ids_cur):
            break
        joint_parent_ids_cur = joint_parent_ids_cur_update
    # Special case: exactly joints 0 and 1 are still empty. Joint 1 then takes
    # the union of its children's masks.
    if torch.equal(zero_weight_ids, torch.tensor([0, 1], device=self.device)):
        print("Aggregating children of hips")
        child_ids = get_joint_children_ids(joint_parent_ids)[1]
        regressor_mask[:, 1] = regressor_mask[:, child_ids].any(axis=1)
    # Explicitly excluded vertices never drive any joint.
    if self.vertex_ids_to_exclude is not None:
        regressor_mask[self.vertex_ids_to_exclude] = False
    self.regressor_mask = regressor_mask
    # The RBFs are built from float32 positions when the module is half precision.
    if self.dtype == torch.float16:
        bind_shape_rbf = self.bind_shape.to(torch.float32)
    else:
        bind_shape_rbf = self.bind_shape
    # One RBF per joint over that joint's masked bind vertices; joint 0 gets none.
    self.joint_pos_regressors = [
        RadialBasisFunction(
            bind_shape_rbf[regressor_mask[:, i]],
            kernel=self.rbf_kernel,
            include_polynomial=True,
        )
        if i != 0
        else None
        for i in range(self.num_joints)
    ]
    if self.use_sparse_rbf_matrix:
        # Assemble a CSR matrix: row i holds the basis weights of joint i
        # evaluated at its bind position, placed at the columns of the
        # vertices selected by its mask.
        all_weights = []
        all_col_indices = []
        crow_indices = [0]
        for i, rbf in enumerate(self.joint_pos_regressors):
            if rbf is None:
                # Joint without a regressor: empty row.
                crow_indices.append(crow_indices[-1])
                continue
            joint_query_position = self.bind_world_transforms[i, :3, 3]
            if rbf.dtype != self.dtype:
                joint_query_position = joint_query_position.to(rbf.dtype)
            w = rbf.get_basis_weights(joint_query_position)
            all_weights.append(w)
            all_col_indices.append(torch.where(regressor_mask[:, i])[0])
            crow_indices.append(crow_indices[-1] + len(w))
        flat_values = torch.cat(all_weights)
        flat_indices = torch.cat(all_col_indices)
        crow_indices_tensor = torch.tensor(crow_indices, device=self.device, dtype=torch.int64)
        self.sparse_rbf_matrix = torch.sparse_csr_tensor(
            crow_indices=crow_indices_tensor,
            col_indices=flat_indices,
            values=flat_values.to(self.dtype),
            size=(len(self.bind_world_transforms), len(self.bind_shape)),
            device=self.device,
            dtype=self.dtype,
        )
    else:
        self.sparse_rbf_matrix = None
def _precompute_warp_data(self):
"""Precompute offsets, counts, and indices for Warp-based rotation fitting."""
frozen = self.freeze_rotations
unskinned_end_joints = []
unskinned_end_joint_parents = []
# ===== Stage 1: Skinned vertex alignment =====
stage1_skinned_vids_flat = []
stage1_counts_list = []
stage1_joint_indices_list = []
stage1_joint_to_batch_idx = {}
batch_idx = 0
for i in range(1, self.num_joints):
if i in frozen:
continue
children = self.joint_children_ids[i]
is_end_joint = len(children) == 0
skinned_vids = torch.where(self.skinning_weights[:, i] > 0.01)[0]
num_skinned = len(skinned_vids)
if is_end_joint:
if self.skip_endjoints:
unskinned_end_joints.append(i)
unskinned_end_joint_parents.append(self.joint_parent_ids[i])
continue
elif num_skinned < 1:
unskinned_end_joints.append(i)
unskinned_end_joint_parents.append(self.joint_parent_ids[i])
continue
stage1_joint_indices_list.append(i)
stage1_joint_to_batch_idx[i] = batch_idx
stage1_skinned_vids_flat.extend(skinned_vids.tolist())
stage1_counts_list.append(num_skinned)
batch_idx += 1
if stage1_counts_list:
counts = torch.tensor(stage1_counts_list, dtype=torch.int32, device=self.device)
offsets = torch.zeros_like(counts)
if counts.numel() > 1:
offsets[1:] = torch.cumsum(counts[:-1], dim=0)
self._warp_stage1_offsets = offsets
self._warp_stage1_counts = counts
self._warp_stage1_skinned_vids_flat = torch.tensor(
stage1_skinned_vids_flat, dtype=torch.long, device=self.device
)
self._warp_stage1_joint_indices = torch.tensor(
stage1_joint_indices_list, dtype=torch.long, device=self.device
)
self._warp_stage1_joint_to_batch_idx = stage1_joint_to_batch_idx
skinned_orig = self.bind_shape[self._warp_stage1_skinned_vids_flat]
joint_positions_for_verts = self.bind_world_transforms[
self._warp_stage1_joint_indices, :3, 3
]
joint_positions_expanded = torch.repeat_interleave(
joint_positions_for_verts, self._warp_stage1_counts, dim=0
)
self._warp_stage1_skinned_orig_centered = skinned_orig - joint_positions_expanded
repeat_indices = torch.repeat_interleave(
torch.arange(len(stage1_joint_indices_list), device=self.device),
self._warp_stage1_counts,
)
self._warp_stage1_repeat_indices = repeat_indices
else:
self._warp_stage1_offsets = torch.tensor([], dtype=torch.int32, device=self.device)
self._warp_stage1_counts = torch.tensor([], dtype=torch.int32, device=self.device)
self._warp_stage1_skinned_vids_flat = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage1_joint_indices = torch.tensor([], dtype=torch.long, device=self.device)
self._warp_stage1_joint_to_batch_idx = {}
self._warp_stage1_skinned_orig_centered = torch.tensor(
[], dtype=self.dtype, device=self.device
)
self._warp_stage1_repeat_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
# ===== Stage 2: Child joint alignment =====
stage2_n1_child_list = []
stage2_n1_joint_indices_list = []
stage2_n2_child_flat = []
stage2_n2_counts_list = []
stage2_n2_joint_indices_list = []
stage2_n2_joint_to_batch_idx = {}
batch_idx_n2 = 0
for i in range(1, self.num_joints):
children = self.joint_children_ids[i]
if not children:
continue
if i in frozen:
continue
if len(children) == 1:
stage2_n1_joint_indices_list.append(i)
stage2_n1_child_list.append(children[0])
else:
stage2_n2_joint_indices_list.append(i)
stage2_n2_joint_to_batch_idx[i] = batch_idx_n2
stage2_n2_child_flat.extend(children)
stage2_n2_counts_list.append(len(children))
batch_idx_n2 += 1
if stage2_n1_joint_indices_list:
self._warp_stage2_n1_joint_indices = torch.tensor(
stage2_n1_joint_indices_list, dtype=torch.long, device=self.device
)
self._warp_stage2_n1_child_indices = torch.tensor(
stage2_n1_child_list, dtype=torch.long, device=self.device
)
stage1_batch_indices_n1 = [
stage1_joint_to_batch_idx[j] for j in stage2_n1_joint_indices_list
]
self._warp_stage2_n1_to_stage1_indices = torch.tensor(
stage1_batch_indices_n1, dtype=torch.long, device=self.device
)
bind_world = self.bind_world_transforms
pos_children_n1 = bind_world[self._warp_stage2_n1_child_indices, :3, 3]
parent_positions_n1 = bind_world[self._warp_stage2_n1_joint_indices, :3, 3]
self._warp_stage2_n1_children_orig_centered = pos_children_n1 - parent_positions_n1
else:
self._warp_stage2_n1_joint_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n1_child_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n1_to_stage1_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n1_children_orig_centered = torch.tensor(
[], dtype=self.dtype, device=self.device
)
if stage2_n2_counts_list:
counts = torch.tensor(stage2_n2_counts_list, dtype=torch.int32, device=self.device)
offsets = torch.zeros_like(counts)
if counts.numel() > 1:
offsets[1:] = torch.cumsum(counts[:-1], dim=0)
self._warp_stage2_n2_offsets = offsets
self._warp_stage2_n2_counts = counts
self._warp_stage2_n2_child_flat = torch.tensor(
stage2_n2_child_flat, dtype=torch.long, device=self.device
)
self._warp_stage2_n2_joint_indices = torch.tensor(
stage2_n2_joint_indices_list, dtype=torch.long, device=self.device
)
stage1_batch_indices_n2 = [
stage1_joint_to_batch_idx[j] for j in stage2_n2_joint_indices_list
]
self._warp_stage2_n2_to_stage1_indices = torch.tensor(
stage1_batch_indices_n2, dtype=torch.long, device=self.device
)
bind_world = self.bind_world_transforms
pos_children_n2 = bind_world[self._warp_stage2_n2_child_flat, :3, 3]
parent_positions_n2 = bind_world[self._warp_stage2_n2_joint_indices, :3, 3]
parent_positions_n2_expanded = torch.repeat_interleave(
parent_positions_n2, self._warp_stage2_n2_counts, dim=0
)
self._warp_stage2_n2_children_orig_centered = (
pos_children_n2 - parent_positions_n2_expanded
)
repeat_indices = torch.repeat_interleave(
torch.arange(len(stage2_n2_joint_indices_list), device=self.device),
self._warp_stage2_n2_counts,
)
self._warp_stage2_n2_repeat_indices = repeat_indices
self._warp_stage2_n2_R_repeat_indices = repeat_indices
else:
self._warp_stage2_n2_offsets = torch.tensor([], dtype=torch.int32, device=self.device)
self._warp_stage2_n2_counts = torch.tensor([], dtype=torch.int32, device=self.device)
self._warp_stage2_n2_child_flat = torch.tensor([], dtype=torch.long, device=self.device)
self._warp_stage2_n2_joint_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n2_to_stage1_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n2_children_orig_centered = torch.tensor(
[], dtype=self.dtype, device=self.device
)
self._warp_stage2_n2_repeat_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
self._warp_stage2_n2_R_repeat_indices = torch.tensor(
[], dtype=torch.long, device=self.device
)
all_stage2_joints = stage2_n1_joint_indices_list + stage2_n2_joint_indices_list
if all_stage2_joints:
self._warp_stage2_joint_indices = torch.tensor(
all_stage2_joints, dtype=torch.long, device=self.device
)
else:
self._warp_stage2_joint_indices = torch.tensor([], dtype=torch.long, device=self.device)
self._warp_frozen_joints = torch.tensor(list(frozen), dtype=torch.long, device=self.device)
if unskinned_end_joints:
self._warp_unskinned_end_joints = torch.tensor(
unskinned_end_joints, dtype=torch.long, device=self.device
)
self._warp_unskinned_end_joint_parents = torch.tensor(
unskinned_end_joint_parents, dtype=torch.long, device=self.device
)
else:
self._warp_unskinned_end_joints = torch.tensor([], dtype=torch.long, device=self.device)
self._warp_unskinned_end_joint_parents = torch.tensor(
[], dtype=torch.long, device=self.device
)
frozen_parents = [self.joint_parent_ids[i] for i in frozen if i > 0]
self._warp_frozen_parents = torch.tensor(
frozen_parents, dtype=torch.long, device=self.device
)
def fit(self, target_shapes):
    """Fit the skeleton to new shapes: regress joint positions, then orientations.

    Args:
        target_shapes: (B, V, 3) or (V, 3) array of new vertex positions

    Returns:
        target_bind_world_transforms: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space
    """
    joint_positions = self.fit_joint_positions(target_shapes)
    # Pick the rotation solver according to the configured backend.
    if self.use_warp_for_rotations:
        solve_rotations = self.fit_rotations_warp
    else:
        solve_rotations = self.fit_joint_rotations
    return solve_rotations(joint_positions, target_shapes)
def fit_joint_positions(self, target_shapes):
"""Fit the skeleton to new shapes by adjusting joint positions.
Args:
target_shapes: (B, V, 3) or (V, 3) array of new vertex positions
Returns:
new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions
"""
dtype, device = self.dtype, self.device
J = self.num_joints
added_batch = False
if target_shapes.ndim == 2:
target_shapes = target_shapes[None, :, :]
added_batch = True
B = target_shapes.shape[0]
if self.sparse_rbf_matrix is not None:
target_shapes_flat = target_shapes.permute(1, 0, 2).reshape(target_shapes.shape[1], -1)
new_joint_positions = torch.mm(self.sparse_rbf_matrix, target_shapes_flat)
new_joint_positions = new_joint_positions.reshape(J, B, 3).permute(1, 0, 2)
else:
cols = []
root_pos = self.bind_world_transforms[0, :3, 3].to(dtype=dtype, device=device)
root_pos = root_pos.view(1, 1, 3).expand(B, 1, 3)
cols.append(root_pos)
for i in range(1, J):
target_vertex_positions = target_shapes[:, self.regressor_mask[:, i]]
joint_query_position = self.bind_world_transforms[i : i + 1, :3, 3]
pred = self.joint_pos_regressors[i].interpolate(
target_vertex_positions, joint_query_position
)
if pred.ndim == 2:
pred = pred[:, None, :]
cols.append(pred)
new_joint_positions = torch.cat(cols, dim=1)
return new_joint_positions[0] if added_batch else new_joint_positions
def fit_joint_rotations(self, new_joint_positions, target_shapes):
    """Fit joint orientations for already-fitted joint positions (sequential PyTorch path).

    Joints are visited in index order starting at joint 1; joint 0 keeps its
    bind rotation. Each joint's rotation is written into the running rotation
    stack through a one-hot blend rather than in-place indexing.

    Args:
        new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions
        target_shapes: (B, V, 3) or (V, 3) array of new vertex positions

    Returns:
        world_bind_pose: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space

    Raises:
        ValueError: if ``new_joint_positions`` does not end in shape (J, 3).
    """
    dtype, device = self.dtype, self.device
    J = self.num_joints
    frozen_ids = self.freeze_rotations
    skip_leaf_joints = self.skip_endjoints

    added_batch = new_joint_positions.ndim == 2
    if added_batch:
        new_joint_positions = new_joint_positions[None, :, :]
    if new_joint_positions.shape[-2:] != (J, 3):
        raise ValueError(
            f"Expected new_joint_positions to have shape (...,{J},3); got {new_joint_positions.shape}"
        )
    if target_shapes.ndim == 2:
        target_shapes = target_shapes[None, :, :]

    t = new_joint_positions
    B = t.shape[0]
    bind_world = self.bind_world_transforms[None, ...].expand(B, J, 4, 4)
    bind_local = self.bind_local_transforms[None, ...].expand(B, J, 4, 4)
    # Running rotation stack, initialised with the bind-pose world rotations.
    R = self.bind_world_transforms[..., :3, :3].clone()[None, ...].expand(B, J, 3, 3)

    def put(rotations, replacement, selector):
        # Out-of-place write of `replacement` into the joint picked by `selector`.
        return rotations * (1 - selector) + replacement * selector

    for i in range(1, J):
        selector = one_hot_1d(J, i, dtype=dtype, device=device)[None, :, None, None]
        child_ids = self.joint_children_ids[i]

        # Skipped end joints simply inherit the parent's current rotation.
        if not child_ids and skip_leaf_joints:
            parent = self.joint_parent_ids[i]
            R = put(R, R[:, parent : parent + 1, :, :], selector)
            continue

        # Frozen joints keep their bind-pose rotation relative to the parent.
        if i in frozen_ids:
            parent = self.joint_parent_ids[i]
            frozen_rotation = R[:, parent : parent + 1, :, :] @ bind_local[:, i : i + 1, :3, :3]
            R = put(R, frozen_rotation, selector)
            continue

        # Initial rotation: identity, or alignment of the joint's skinned
        # vertices (weight > 0.01) between bind shape and target shape, both
        # expressed relative to the joint position.
        if not self.skip_inverse_lbs:
            skinned_vids = torch.where(self.skinning_weights[:, i] > 0.01)[0]
            bind_offsets = self.bind_shape[skinned_vids] - self.bind_world_transforms[i, :3, 3]
            skinned_orig = bind_offsets[None, :, :]
            skinned_new = target_shapes[:, skinned_vids, :] - t[:, i : i + 1]
            R_init = align_vectors(skinned_new, skinned_orig)
        else:
            identity = torch.eye(3, dtype=dtype, device=device)
            R_init = identity.unsqueeze(0).unsqueeze(0).expand(B, 1, 3, 3)

        R_init_squeezed = R_init.squeeze(1)
        if len(child_ids) == 0:
            R_i_new = R_init_squeezed @ R[:, i, :, :]
        else:
            # Refinement: rotate the bind-pose child offsets by the initial
            # rotation, then align them with the new child offsets.
            bind_joint_positions = bind_world[:, :, :3, 3]
            pos_children_orig = bind_joint_positions[:, child_ids] - bind_world[:, i : i + 1, :3, 3]
            rotated = R_init_squeezed @ pos_children_orig.swapaxes(-2, -1)
            pos_children_orig = rotated.swapaxes(-2, -1)
            pos_children_new = t[:, child_ids, :] - t[:, i : i + 1, :]
            align_rot = align_vectors(pos_children_new, pos_children_orig)
            R_i_new = align_rot @ R_init_squeezed @ R[:, i, :, :]
        R = put(R, R_i_new[:, None, :, :], selector)

    world_bind_pose = SE3_from_Rt(R, t)
    if added_batch:
        return world_bind_pose[0]
    return world_bind_pose
def fit_rotations_warp(self, new_joint_positions, target_shapes):
    """Warp-accelerated version of fit_joint_rotations using GPU-parallel alignment.

    Uses the index tables built by ``_precompute_warp_data``:

    1. Stage 1 ("inverse LBS"): one initial rotation per stage-1 joint, either
       identity (``skip_inverse_lbs``) or from ``align_vectors_warp`` on the
       joint's skinned vertices.
    2. Stage 2: a refinement rotation per joint with children, from
       ``parallel_rodrigues_kabsch_warp`` on the child-joint offsets.
    3. Combination with the bind rotations, then frozen joints, then end
       joints that copy their parent.

    Frozen joints are resolved before the end-joint copy so that an end joint
    under a frozen parent receives the parent's final rotation, as in
    ``fit_joint_rotations``.

    NOTE(review): a frozen joint whose parent is also frozen still reads the
    parent's rotation from before the frozen update (single vectorized step).

    Args:
        new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions
        target_shapes: (B, V, 3) or (V, 3) array of new vertex positions

    Returns:
        world_bind_pose: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space

    Raises:
        ValueError: if ``new_joint_positions`` does not end in shape (J, 3).
    """
    if self._warp_stage1_offsets is None:
        self._precompute_warp_data()
    dtype, device = self.dtype, self.device
    J = self.num_joints
    added_batch = False
    if new_joint_positions.ndim == 2:
        new_joint_positions = new_joint_positions[None, :, :]
        added_batch = True
    if new_joint_positions.shape[-2:] != (J, 3):
        raise ValueError(
            f"Expected new_joint_positions to have shape (...,{J},3); got {new_joint_positions.shape}"
        )
    if target_shapes.ndim == 2:
        target_shapes = target_shapes[None, :, :]
    t = new_joint_positions
    B = t.shape[0]
    bind_local = self.bind_local_transforms[None, ...].expand(B, J, 4, 4)
    R0 = self.bind_world_transforms[..., :3, :3].clone()
    R = R0[None, ...].expand(B, J, 3, 3)

    # ===== Inverse LBS: Skinned vertices (single Warp call) =====
    R_init_all = None
    num_joints_stage1 = len(self._warp_stage1_joint_indices)
    if num_joints_stage1 > 0:
        if self.skip_inverse_lbs:
            R_init_all = (
                torch.eye(3, dtype=dtype, device=device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(B, num_joints_stage1, 3, 3)
            )
        else:
            skinned_orig = self._warp_stage1_skinned_orig_centered
            # Target vertices relative to the new position of their joint.
            skinned_new = target_shapes[:, self._warp_stage1_skinned_vids_flat, :]
            new_joint_positions_for_verts = t[:, self._warp_stage1_joint_indices, :]
            new_joint_positions_expanded = new_joint_positions_for_verts[
                :, self._warp_stage1_repeat_indices, :
            ]
            skinned_new = skinned_new - new_joint_positions_expanded
            skinned_orig_batched = skinned_orig.unsqueeze(0).expand(B, -1, -1).reshape(-1, 3)
            skinned_new_flat = skinned_new.reshape(-1, 3)
            # Batch item b uses the same segments shifted by b * (vertices per item).
            offsets_batched = (
                self._warp_stage1_offsets.unsqueeze(0)
                + torch.arange(B, device=device, dtype=torch.int32).unsqueeze(1)
                * skinned_orig.shape[0]
            )
            offsets_batched = offsets_batched.flatten()
            counts_batched = self._warp_stage1_counts.unsqueeze(0).expand(B, -1).flatten()
            R_init_all = align_vectors_warp(
                skinned_new_flat,
                skinned_orig_batched,
                offsets_batched,
                counts_batched,
                method=self.rotation_method,
            )
            # Presumably one 3x3 rotation per (offset, count) segment; the
            # reshape below relies on that.
            R_init_all = R_init_all.reshape(B, num_joints_stage1, 3, 3)

    # ===== Stage 2: Child alignment =====
    align_rot_n1 = None
    align_rot_n2 = None
    if R_init_all is not None and parallel_rodrigues_kabsch_warp is not None:
        num_n1 = len(self._warp_stage2_n1_joint_indices)
        if num_n1 > 0:
            pos_children_orig_n1 = self._warp_stage2_n1_children_orig_centered
            pos_children_new_n1 = (
                t[:, self._warp_stage2_n1_child_indices, :]
                - t[:, self._warp_stage2_n1_joint_indices, :]
            )
            pos_children_orig_n1_batched = pos_children_orig_n1.unsqueeze(0).expand(B, -1, -1)
            # Rotate the bind-pose child offset by the joint's stage-1 rotation.
            R_init_for_n1 = R_init_all[:, self._warp_stage2_n1_to_stage1_indices, :, :]
            pos_children_orig_n1_rotated = torch.bmm(
                R_init_for_n1.reshape(-1, 3, 3), pos_children_orig_n1_batched.reshape(-1, 3, 1)
            ).reshape(B, num_n1, 3)
            src_vecs_n1 = pos_children_orig_n1_rotated.reshape(-1, 3)
            tgt_vecs_n1 = pos_children_new_n1.reshape(-1, 3)
        else:
            src_vecs_n1 = torch.empty((0, 3), dtype=dtype, device=device)
            tgt_vecs_n1 = torch.empty((0, 3), dtype=dtype, device=device)

        num_n2 = len(self._warp_stage2_n2_joint_indices)
        if num_n2 > 0:
            pos_children_orig_n2 = self._warp_stage2_n2_children_orig_centered
            pos_children_new_n2 = t[:, self._warp_stage2_n2_child_flat, :]
            parent_positions_new_n2 = t[:, self._warp_stage2_n2_joint_indices, :]
            parent_positions_new_n2_expanded = parent_positions_new_n2[
                :, self._warp_stage2_n2_repeat_indices, :
            ]
            pos_children_new_n2 = pos_children_new_n2 - parent_positions_new_n2_expanded
            pos_children_orig_n2_batched = pos_children_orig_n2.unsqueeze(0).expand(B, -1, -1)
            R_init_for_n2 = R_init_all[:, self._warp_stage2_n2_to_stage1_indices, :, :]
            # One copy of the joint's stage-1 rotation per child.
            R_init_n2_expanded = R_init_for_n2[:, self._warp_stage2_n2_R_repeat_indices, :, :]
            pos_children_orig_n2_rotated = torch.bmm(
                R_init_n2_expanded.reshape(-1, 3, 3),
                pos_children_orig_n2_batched.reshape(-1, 3, 1),
            ).reshape(B, -1, 3)
            pos_children_orig_n2_flat = pos_children_orig_n2_rotated.reshape(-1, 3)
            pos_children_new_n2_flat = pos_children_new_n2.reshape(-1, 3)
            offsets_batched_n2 = self._warp_stage2_n2_offsets.unsqueeze(0) + torch.arange(
                B, device=device, dtype=torch.int32
            ).unsqueeze(1) * len(self._warp_stage2_n2_child_flat)
            offsets_batched_n2 = offsets_batched_n2.flatten()
            counts_batched_n2 = self._warp_stage2_n2_counts.unsqueeze(0).expand(B, -1).flatten()
        else:
            pos_children_orig_n2_flat = torch.empty((0, 3), dtype=dtype, device=device)
            pos_children_new_n2_flat = torch.empty((0, 3), dtype=dtype, device=device)
            offsets_batched_n2 = torch.empty((0,), dtype=torch.int32, device=device)
            counts_batched_n2 = torch.empty((0,), dtype=torch.int32, device=device)

        align_rot_n1_flat, align_rot_n2_flat = parallel_rodrigues_kabsch_warp(
            tgt_vecs_n1,
            src_vecs_n1,
            pos_children_new_n2_flat,
            pos_children_orig_n2_flat,
            offsets_batched_n2,
            counts_batched_n2,
            method=self.rotation_method,
        )
        align_rot_n1 = align_rot_n1_flat.reshape(B, num_n1, 3, 3)
        align_rot_n2 = align_rot_n2_flat.reshape(B, num_n2, 3, 3)

    # ===== Combine rotations (no loops, fully vectorized) =====
    R_new = R.clone()
    if R_init_all is not None:
        # Stage-1 only: R_init @ R_bind.
        stage1_joints = self._warp_stage1_joint_indices
        n_stage1 = len(stage1_joints)
        R_new[:, stage1_joints, :, :] = torch.bmm(
            R_init_all.reshape(B * n_stage1, 3, 3),
            R[:, stage1_joints, :, :].reshape(B * n_stage1, 3, 3),
        ).reshape(B, n_stage1, 3, 3)
    if align_rot_n1 is not None:
        # Single-child joints: align_rot @ R_init @ R_bind (overwrites stage 1).
        stage2_n1_joints = self._warp_stage2_n1_joint_indices
        n_n1 = len(stage2_n1_joints)
        R_init_for_n1 = R_init_all[:, self._warp_stage2_n1_to_stage1_indices, :, :]
        temp = torch.bmm(
            align_rot_n1.reshape(B * n_n1, 3, 3),
            R_init_for_n1.reshape(B * n_n1, 3, 3),
        ).reshape(B, n_n1, 3, 3)
        R_new[:, stage2_n1_joints, :, :] = torch.bmm(
            temp.reshape(B * n_n1, 3, 3),
            R[:, stage2_n1_joints, :, :].reshape(B * n_n1, 3, 3),
        ).reshape(B, n_n1, 3, 3)
    if align_rot_n2 is not None:
        # Multi-child joints: same composition as above.
        stage2_n2_joints = self._warp_stage2_n2_joint_indices
        n_n2 = len(stage2_n2_joints)
        R_init_for_n2 = R_init_all[:, self._warp_stage2_n2_to_stage1_indices, :, :]
        temp = torch.bmm(
            align_rot_n2.reshape(B * n_n2, 3, 3),
            R_init_for_n2.reshape(B * n_n2, 3, 3),
        ).reshape(B, n_n2, 3, 3)
        R_new[:, stage2_n2_joints, :, :] = torch.bmm(
            temp.reshape(B * n_n2, 3, 3),
            R[:, stage2_n2_joints, :, :].reshape(B * n_n2, 3, 3),
        ).reshape(B, n_n2, 3, 3)

    # Frozen joints: parent's fitted rotation composed with the bind-pose
    # local rotation. Done before the end-joint copy because an end joint may
    # hang off a frozen joint, while a frozen joint's parent has a child and
    # therefore is never an end joint.
    if len(self._warp_frozen_joints) > 0:
        R_parents = R_new[:, self._warp_frozen_parents, :, :]
        R_bind_local = bind_local[0, self._warp_frozen_joints, :3, :3]
        num_frozen = len(self._warp_frozen_joints)
        R_frozen = torch.bmm(
            R_parents.reshape(B * num_frozen, 3, 3),
            R_bind_local.unsqueeze(0).expand(B, -1, -1, -1).reshape(B * num_frozen, 3, 3),
        ).reshape(B, num_frozen, 3, 3)
        R_new[:, self._warp_frozen_joints, :, :] = R_frozen

    # End joints that were not fitted inherit their parent's final rotation.
    if len(self._warp_unskinned_end_joints) > 0:
        R_new[:, self._warp_unskinned_end_joints, :, :] = R_new[
            :, self._warp_unskinned_end_joint_parents, :, :
        ]

    world_bind_pose = SE3_from_Rt(R_new, t)
    return world_bind_pose[0] if added_batch else world_bind_pose