| |
| |
|
|
| import torch |
|
|
| from ._utils import one_hot_1d, require_torch_tensors |
| from .interpolate import RadialBasisFunction |
| from .rig_utils import get_joint_children_ids, joint_world_to_local |
| from .transforms import SE3_from_Rt, align_vectors |
|
|
| try: |
| from .align_vectors_warp import ( |
| align_vectors_warp, |
| parallel_rodrigues_kabsch_warp, |
| rodrigues_rotation_warp, |
| ) |
| except ImportError: |
| align_vectors_warp = None |
| rodrigues_rotation_warp = None |
| parallel_rodrigues_kabsch_warp = None |
|
|
|
|
| class SkeletonTransfer(torch.nn.Module): |
| def __init__( |
| self, |
| joint_parent_ids, |
| bind_world_transforms, |
| bind_shape, |
| skinning_weights, |
| rbf_kernel="linear", |
| vertex_ids_to_exclude=None, |
| freeze_rotations=None, |
| skip_endjoints=True, |
| use_sparse_rbf_matrix=True, |
| use_warp_for_rotations=True, |
| rotation_method="kabsch", |
| skip_inverse_lbs=False, |
| ): |
| """Initialize a SkeletonTransfer instance for fitting a skeleton to new shapes. |
| Args: |
| joint_parent_ids: (J,) int array of joint parent indices |
| bind_world_transforms: (J, 4, 4) array of joint bind poses in world space |
| bind_shape: (V, 3) array of vertex positions in bind pose |
| skinning_weights: (V, J) array of skinning weights |
| rbf_kernel: type of RBF kernel to use for joint position regression |
| vertex_ids_to_exclude: (V,) int array of vertex ids to exclude from the joint position regressors |
| freeze_rotations: list of joint ids to freeze to bind pose (for Warp mode) |
| skip_endjoints: bool, whether to skip rotation fitting for end joints (for Warp mode) |
| use_sparse_rbf_matrix: bool, whether to use a sparse RBF matrix for joint position regression |
| use_warp_for_rotations: bool, whether to use Warp-based rotation fitting (requires Warp) |
| rotation_method: str, rotation extraction method ('kabsch' or 'newton-schulz') |
| skip_inverse_lbs: bool, whether to skip Inverse LBS (skinned vertex fitting) and use identity R_init |
| """ |
| super().__init__() |
| if freeze_rotations is None: |
| freeze_rotations = [] |
| require_torch_tensors( |
| bind_world_transforms, bind_shape, skinning_weights, name="SkeletonTransfer inputs" |
| ) |
|
|
| assert len(joint_parent_ids) == len(bind_world_transforms) == skinning_weights.shape[1], ( |
| "joint_names, joint_parent_ids, bind_world_transforms, and skinning_weights.shape[1] must have the same length" |
| ) |
| self.num_joints = len(joint_parent_ids) |
|
|
| if isinstance(joint_parent_ids, torch.Tensor): |
| self.joint_parent_ids = joint_parent_ids.detach().cpu().tolist() |
| else: |
| self.joint_parent_ids = joint_parent_ids |
| self.joint_children_ids = get_joint_children_ids(self.joint_parent_ids) |
| bind_world_transforms = bind_world_transforms.detach() |
| bind_shape = bind_shape.detach() |
| skinning_weights = skinning_weights.detach() |
| self.register_buffer("bind_world_transforms", bind_world_transforms, persistent=False) |
| self.register_buffer( |
| "bind_local_transforms", |
| joint_world_to_local(bind_world_transforms, joint_parent_ids), |
| persistent=False, |
| ) |
| self.register_buffer("bind_shape", bind_shape, persistent=False) |
| self.register_buffer("skinning_weights", skinning_weights, persistent=False) |
| self.register_buffer("regressor_mask", None, persistent=False) |
| self.sparse_rbf_matrix = None |
| self.joint_pos_regressors = None |
| self.rbf_kernel = rbf_kernel |
| if isinstance(vertex_ids_to_exclude, torch.Tensor): |
| self.vertex_ids_to_exclude = vertex_ids_to_exclude.detach().cpu().tolist() |
| else: |
| self.vertex_ids_to_exclude = vertex_ids_to_exclude |
| self.freeze_rotations = set(freeze_rotations) if freeze_rotations else set() |
| self.skip_endjoints = skip_endjoints |
| self.use_sparse_rbf_matrix = use_sparse_rbf_matrix |
| self.rotation_method = rotation_method |
| self.skip_inverse_lbs = skip_inverse_lbs |
| self._precompute_regressors() |
|
|
| |
| self.register_buffer("_warp_stage1_offsets", None, persistent=False) |
| self.register_buffer("_warp_stage1_counts", None, persistent=False) |
| self.register_buffer("_warp_stage1_skinned_vids_flat", None, persistent=False) |
| self.register_buffer("_warp_stage1_joint_indices", None, persistent=False) |
| self.register_buffer("_warp_stage1_repeat_indices", None, persistent=False) |
| self.register_buffer("_warp_stage1_skinned_orig_centered", None, persistent=False) |
| self._warp_stage1_joint_to_batch_idx = None |
|
|
| self.register_buffer("_warp_stage2_offsets", None, persistent=False) |
| self.register_buffer("_warp_stage2_counts", None, persistent=False) |
| self.register_buffer("_warp_stage2_child_flat", None, persistent=False) |
| self.register_buffer("_warp_stage2_joint_indices", None, persistent=False) |
| self._warp_stage2_joint_to_batch_idx = None |
| self.register_buffer("_warp_stage2_repeat_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_R_repeat_indices", None, persistent=False) |
| self.register_buffer("_warp_frozen_joints", None, persistent=False) |
| self.register_buffer("_warp_unskinned_end_joints", None, persistent=False) |
| self.register_buffer("_warp_unskinned_end_joint_parents", None, persistent=False) |
| self.register_buffer("_warp_stage2_n1_joint_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n1_child_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n1_to_stage1_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n1_children_orig_centered", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_offsets", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_counts", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_child_flat", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_joint_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_to_stage1_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_children_orig_centered", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_repeat_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_n2_R_repeat_indices", None, persistent=False) |
| self.register_buffer("_warp_stage2_joint_indices", None, persistent=False) |
| self.register_buffer("_warp_frozen_parents", None, persistent=False) |
| self.use_warp_for_rotations = use_warp_for_rotations |
| if use_warp_for_rotations: |
| if align_vectors_warp is None or rodrigues_rotation_warp is None: |
| raise ImportError("Warp-based rotation fitting requires Warp to be installed.") |
| self._precompute_warp_data() |
|
|
| def update_bind(self, bind_world_transforms, bind_shape): |
| """Update bind-pose data without rebuilding structural caches. |
| |
| This is much faster than constructing a new SkeletonTransfer and is |
| suitable when only the identity (shape) changes but the topology, |
| skinning weights, and skeleton structure remain the same. |
| """ |
| self.bind_world_transforms = bind_world_transforms.detach() |
| self.bind_local_transforms = joint_world_to_local( |
| bind_world_transforms, self.joint_parent_ids |
| ) |
| self.bind_shape = bind_shape.detach() |
|
|
| @property |
| def device(self): |
| return self.bind_world_transforms.device |
|
|
| @property |
| def dtype(self): |
| return self.bind_world_transforms.dtype |
|
|
| def _apply(self, fn): |
| super()._apply(fn) |
| self._precompute_regressors() |
| if self.use_warp_for_rotations: |
| self._precompute_warp_data() |
| return self |
|
|
| def _precompute_regressors(self): |
| regressor_mask = self.skinning_weights > 0.0 |
| regressor_mask &= self.skinning_weights[:, self.joint_parent_ids] > 0.0 |
| zero_weight_ids = torch.where(regressor_mask.sum(dim=0) == 0.0)[0] |
|
|
| joint_parent_ids = torch.as_tensor( |
| self.joint_parent_ids, dtype=torch.long, device=self.device |
| ) |
| joint_parent_ids_cur = joint_parent_ids.clone() |
| regressor_mask[:, zero_weight_ids] = self.skinning_weights[:, zero_weight_ids] > 0.0 |
| while len(zero_weight_ids) > 1: |
| regressor_mask[:, zero_weight_ids] |= ( |
| self.skinning_weights[:, joint_parent_ids_cur][:, zero_weight_ids] > 0.0 |
| ) |
| zero_weight_ids = torch.where(regressor_mask.sum(dim=0) == 0.0)[0] |
| joint_parent_ids_cur_update = joint_parent_ids[joint_parent_ids_cur] |
| if torch.equal(joint_parent_ids_cur_update, joint_parent_ids_cur): |
| break |
| joint_parent_ids_cur = joint_parent_ids_cur_update |
|
|
| if torch.equal(zero_weight_ids, torch.tensor([0, 1], device=self.device)): |
| print("Aggregating children of hips") |
| child_ids = get_joint_children_ids(joint_parent_ids)[1] |
| regressor_mask[:, 1] = regressor_mask[:, child_ids].any(axis=1) |
|
|
| if self.vertex_ids_to_exclude is not None: |
| regressor_mask[self.vertex_ids_to_exclude] = False |
| self.regressor_mask = regressor_mask |
|
|
| if self.dtype == torch.float16: |
| bind_shape_rbf = self.bind_shape.to(torch.float32) |
| else: |
| bind_shape_rbf = self.bind_shape |
|
|
| self.joint_pos_regressors = [ |
| RadialBasisFunction( |
| bind_shape_rbf[regressor_mask[:, i]], |
| kernel=self.rbf_kernel, |
| include_polynomial=True, |
| ) |
| if i != 0 |
| else None |
| for i in range(self.num_joints) |
| ] |
| if self.use_sparse_rbf_matrix: |
| all_weights = [] |
| all_col_indices = [] |
| crow_indices = [0] |
|
|
| for i, rbf in enumerate(self.joint_pos_regressors): |
| if rbf is None: |
| crow_indices.append(crow_indices[-1]) |
| continue |
|
|
| joint_query_position = self.bind_world_transforms[i, :3, 3] |
| if rbf.dtype != self.dtype: |
| joint_query_position = joint_query_position.to(rbf.dtype) |
| w = rbf.get_basis_weights(joint_query_position) |
|
|
| all_weights.append(w) |
| all_col_indices.append(torch.where(regressor_mask[:, i])[0]) |
| crow_indices.append(crow_indices[-1] + len(w)) |
|
|
| flat_values = torch.cat(all_weights) |
| flat_indices = torch.cat(all_col_indices) |
| crow_indices_tensor = torch.tensor(crow_indices, device=self.device, dtype=torch.int64) |
|
|
| self.sparse_rbf_matrix = torch.sparse_csr_tensor( |
| crow_indices=crow_indices_tensor, |
| col_indices=flat_indices, |
| values=flat_values.to(self.dtype), |
| size=(len(self.bind_world_transforms), len(self.bind_shape)), |
| device=self.device, |
| dtype=self.dtype, |
| ) |
| else: |
| self.sparse_rbf_matrix = None |
|
|
| def _precompute_warp_data(self): |
| """Precompute offsets, counts, and indices for Warp-based rotation fitting.""" |
| frozen = self.freeze_rotations |
|
|
| unskinned_end_joints = [] |
| unskinned_end_joint_parents = [] |
|
|
| |
| stage1_skinned_vids_flat = [] |
| stage1_counts_list = [] |
| stage1_joint_indices_list = [] |
| stage1_joint_to_batch_idx = {} |
|
|
| batch_idx = 0 |
| for i in range(1, self.num_joints): |
| if i in frozen: |
| continue |
|
|
| children = self.joint_children_ids[i] |
| is_end_joint = len(children) == 0 |
|
|
| skinned_vids = torch.where(self.skinning_weights[:, i] > 0.01)[0] |
| num_skinned = len(skinned_vids) |
|
|
| if is_end_joint: |
| if self.skip_endjoints: |
| unskinned_end_joints.append(i) |
| unskinned_end_joint_parents.append(self.joint_parent_ids[i]) |
| continue |
| elif num_skinned < 1: |
| unskinned_end_joints.append(i) |
| unskinned_end_joint_parents.append(self.joint_parent_ids[i]) |
| continue |
|
|
| stage1_joint_indices_list.append(i) |
| stage1_joint_to_batch_idx[i] = batch_idx |
| stage1_skinned_vids_flat.extend(skinned_vids.tolist()) |
| stage1_counts_list.append(num_skinned) |
| batch_idx += 1 |
|
|
| if stage1_counts_list: |
| counts = torch.tensor(stage1_counts_list, dtype=torch.int32, device=self.device) |
| offsets = torch.zeros_like(counts) |
| if counts.numel() > 1: |
| offsets[1:] = torch.cumsum(counts[:-1], dim=0) |
|
|
| self._warp_stage1_offsets = offsets |
| self._warp_stage1_counts = counts |
| self._warp_stage1_skinned_vids_flat = torch.tensor( |
| stage1_skinned_vids_flat, dtype=torch.long, device=self.device |
| ) |
| self._warp_stage1_joint_indices = torch.tensor( |
| stage1_joint_indices_list, dtype=torch.long, device=self.device |
| ) |
| self._warp_stage1_joint_to_batch_idx = stage1_joint_to_batch_idx |
|
|
| skinned_orig = self.bind_shape[self._warp_stage1_skinned_vids_flat] |
| joint_positions_for_verts = self.bind_world_transforms[ |
| self._warp_stage1_joint_indices, :3, 3 |
| ] |
| joint_positions_expanded = torch.repeat_interleave( |
| joint_positions_for_verts, self._warp_stage1_counts, dim=0 |
| ) |
| self._warp_stage1_skinned_orig_centered = skinned_orig - joint_positions_expanded |
|
|
| repeat_indices = torch.repeat_interleave( |
| torch.arange(len(stage1_joint_indices_list), device=self.device), |
| self._warp_stage1_counts, |
| ) |
| self._warp_stage1_repeat_indices = repeat_indices |
| else: |
| self._warp_stage1_offsets = torch.tensor([], dtype=torch.int32, device=self.device) |
| self._warp_stage1_counts = torch.tensor([], dtype=torch.int32, device=self.device) |
| self._warp_stage1_skinned_vids_flat = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage1_joint_indices = torch.tensor([], dtype=torch.long, device=self.device) |
| self._warp_stage1_joint_to_batch_idx = {} |
| self._warp_stage1_skinned_orig_centered = torch.tensor( |
| [], dtype=self.dtype, device=self.device |
| ) |
| self._warp_stage1_repeat_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
|
|
| |
| stage2_n1_child_list = [] |
| stage2_n1_joint_indices_list = [] |
|
|
| stage2_n2_child_flat = [] |
| stage2_n2_counts_list = [] |
| stage2_n2_joint_indices_list = [] |
| stage2_n2_joint_to_batch_idx = {} |
|
|
| batch_idx_n2 = 0 |
| for i in range(1, self.num_joints): |
| children = self.joint_children_ids[i] |
|
|
| if not children: |
| continue |
| if i in frozen: |
| continue |
|
|
| if len(children) == 1: |
| stage2_n1_joint_indices_list.append(i) |
| stage2_n1_child_list.append(children[0]) |
| else: |
| stage2_n2_joint_indices_list.append(i) |
| stage2_n2_joint_to_batch_idx[i] = batch_idx_n2 |
| stage2_n2_child_flat.extend(children) |
| stage2_n2_counts_list.append(len(children)) |
| batch_idx_n2 += 1 |
|
|
| if stage2_n1_joint_indices_list: |
| self._warp_stage2_n1_joint_indices = torch.tensor( |
| stage2_n1_joint_indices_list, dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n1_child_indices = torch.tensor( |
| stage2_n1_child_list, dtype=torch.long, device=self.device |
| ) |
|
|
| stage1_batch_indices_n1 = [ |
| stage1_joint_to_batch_idx[j] for j in stage2_n1_joint_indices_list |
| ] |
| self._warp_stage2_n1_to_stage1_indices = torch.tensor( |
| stage1_batch_indices_n1, dtype=torch.long, device=self.device |
| ) |
|
|
| bind_world = self.bind_world_transforms |
| pos_children_n1 = bind_world[self._warp_stage2_n1_child_indices, :3, 3] |
| parent_positions_n1 = bind_world[self._warp_stage2_n1_joint_indices, :3, 3] |
| self._warp_stage2_n1_children_orig_centered = pos_children_n1 - parent_positions_n1 |
| else: |
| self._warp_stage2_n1_joint_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n1_child_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n1_to_stage1_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n1_children_orig_centered = torch.tensor( |
| [], dtype=self.dtype, device=self.device |
| ) |
|
|
| if stage2_n2_counts_list: |
| counts = torch.tensor(stage2_n2_counts_list, dtype=torch.int32, device=self.device) |
| offsets = torch.zeros_like(counts) |
| if counts.numel() > 1: |
| offsets[1:] = torch.cumsum(counts[:-1], dim=0) |
|
|
| self._warp_stage2_n2_offsets = offsets |
| self._warp_stage2_n2_counts = counts |
| self._warp_stage2_n2_child_flat = torch.tensor( |
| stage2_n2_child_flat, dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n2_joint_indices = torch.tensor( |
| stage2_n2_joint_indices_list, dtype=torch.long, device=self.device |
| ) |
|
|
| stage1_batch_indices_n2 = [ |
| stage1_joint_to_batch_idx[j] for j in stage2_n2_joint_indices_list |
| ] |
| self._warp_stage2_n2_to_stage1_indices = torch.tensor( |
| stage1_batch_indices_n2, dtype=torch.long, device=self.device |
| ) |
|
|
| bind_world = self.bind_world_transforms |
| pos_children_n2 = bind_world[self._warp_stage2_n2_child_flat, :3, 3] |
| parent_positions_n2 = bind_world[self._warp_stage2_n2_joint_indices, :3, 3] |
| parent_positions_n2_expanded = torch.repeat_interleave( |
| parent_positions_n2, self._warp_stage2_n2_counts, dim=0 |
| ) |
| self._warp_stage2_n2_children_orig_centered = ( |
| pos_children_n2 - parent_positions_n2_expanded |
| ) |
|
|
| repeat_indices = torch.repeat_interleave( |
| torch.arange(len(stage2_n2_joint_indices_list), device=self.device), |
| self._warp_stage2_n2_counts, |
| ) |
| self._warp_stage2_n2_repeat_indices = repeat_indices |
| self._warp_stage2_n2_R_repeat_indices = repeat_indices |
| else: |
| self._warp_stage2_n2_offsets = torch.tensor([], dtype=torch.int32, device=self.device) |
| self._warp_stage2_n2_counts = torch.tensor([], dtype=torch.int32, device=self.device) |
| self._warp_stage2_n2_child_flat = torch.tensor([], dtype=torch.long, device=self.device) |
| self._warp_stage2_n2_joint_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n2_to_stage1_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n2_children_orig_centered = torch.tensor( |
| [], dtype=self.dtype, device=self.device |
| ) |
| self._warp_stage2_n2_repeat_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
| self._warp_stage2_n2_R_repeat_indices = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
|
|
| all_stage2_joints = stage2_n1_joint_indices_list + stage2_n2_joint_indices_list |
| if all_stage2_joints: |
| self._warp_stage2_joint_indices = torch.tensor( |
| all_stage2_joints, dtype=torch.long, device=self.device |
| ) |
| else: |
| self._warp_stage2_joint_indices = torch.tensor([], dtype=torch.long, device=self.device) |
|
|
| self._warp_frozen_joints = torch.tensor(list(frozen), dtype=torch.long, device=self.device) |
|
|
| if unskinned_end_joints: |
| self._warp_unskinned_end_joints = torch.tensor( |
| unskinned_end_joints, dtype=torch.long, device=self.device |
| ) |
| self._warp_unskinned_end_joint_parents = torch.tensor( |
| unskinned_end_joint_parents, dtype=torch.long, device=self.device |
| ) |
| else: |
| self._warp_unskinned_end_joints = torch.tensor([], dtype=torch.long, device=self.device) |
| self._warp_unskinned_end_joint_parents = torch.tensor( |
| [], dtype=torch.long, device=self.device |
| ) |
|
|
| frozen_parents = [self.joint_parent_ids[i] for i in frozen if i > 0] |
| self._warp_frozen_parents = torch.tensor( |
| frozen_parents, dtype=torch.long, device=self.device |
| ) |
|
|
| def fit(self, target_shapes): |
| """Fit the skeleton to new shapes by adjusting joint positions and orientations. |
| Args: |
| target_shapes: (B, V, 3) or (V, 3) array of new vertex positions |
| Returns: |
| target_bind_world_transforms: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space |
| """ |
| new_joint_positions = self.fit_joint_positions(target_shapes) |
|
|
| if self.use_warp_for_rotations: |
| world_bind_pose = self.fit_rotations_warp( |
| new_joint_positions, |
| target_shapes, |
| ) |
| else: |
| world_bind_pose = self.fit_joint_rotations( |
| new_joint_positions, |
| target_shapes, |
| ) |
| return world_bind_pose |
|
|
| def fit_joint_positions(self, target_shapes): |
| """Fit the skeleton to new shapes by adjusting joint positions. |
| Args: |
| target_shapes: (B, V, 3) or (V, 3) array of new vertex positions |
| Returns: |
| new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions |
| """ |
| dtype, device = self.dtype, self.device |
| J = self.num_joints |
|
|
| added_batch = False |
| if target_shapes.ndim == 2: |
| target_shapes = target_shapes[None, :, :] |
| added_batch = True |
| B = target_shapes.shape[0] |
|
|
| if self.sparse_rbf_matrix is not None: |
| target_shapes_flat = target_shapes.permute(1, 0, 2).reshape(target_shapes.shape[1], -1) |
| new_joint_positions = torch.mm(self.sparse_rbf_matrix, target_shapes_flat) |
| new_joint_positions = new_joint_positions.reshape(J, B, 3).permute(1, 0, 2) |
| else: |
| cols = [] |
| root_pos = self.bind_world_transforms[0, :3, 3].to(dtype=dtype, device=device) |
| root_pos = root_pos.view(1, 1, 3).expand(B, 1, 3) |
| cols.append(root_pos) |
| for i in range(1, J): |
| target_vertex_positions = target_shapes[:, self.regressor_mask[:, i]] |
| joint_query_position = self.bind_world_transforms[i : i + 1, :3, 3] |
| pred = self.joint_pos_regressors[i].interpolate( |
| target_vertex_positions, joint_query_position |
| ) |
| if pred.ndim == 2: |
| pred = pred[:, None, :] |
| cols.append(pred) |
| new_joint_positions = torch.cat(cols, dim=1) |
| return new_joint_positions[0] if added_batch else new_joint_positions |
|
|
| def fit_joint_rotations(self, new_joint_positions, target_shapes): |
| """Fit the skeleton to new positions by adjusting joint orientations. |
| Args: |
| new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions |
| target_shapes: (B, V, 3) or (V, 3) array of new vertex positions |
| Returns: |
| world_bind_pose: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space |
| """ |
| dtype, device = self.dtype, self.device |
| J = self.num_joints |
| freeze_rotations = self.freeze_rotations |
| skip_endjoints = self.skip_endjoints |
|
|
| added_batch = False |
| if new_joint_positions.ndim == 2: |
| new_joint_positions = new_joint_positions[None, :, :] |
| added_batch = True |
| if new_joint_positions.shape[-2:] != (J, 3): |
| raise ValueError( |
| f"Expected new_joint_positions to have shape (...,{J},3); got {new_joint_positions.shape}" |
| ) |
| if target_shapes.ndim == 2: |
| target_shapes = target_shapes[None, :, :] |
|
|
| t = new_joint_positions |
| B = t.shape[0] |
|
|
| bind_world = self.bind_world_transforms[None, ...].expand(B, J, 4, 4) |
| bind_local = self.bind_local_transforms[None, ...].expand(B, J, 4, 4) |
|
|
| R0 = self.bind_world_transforms[..., :3, :3].clone() |
| R = R0[None, ...].expand(B, J, 3, 3) |
|
|
| for i in range(1, J): |
| jmask = one_hot_1d(J, i, dtype=dtype, device=device)[None, :, None, None] |
| children = self.joint_children_ids[i] |
|
|
| if not children and skip_endjoints: |
| p = self.joint_parent_ids[i] |
| R_parent = R[:, p : p + 1, :, :] |
| R = R * (1 - jmask) + R_parent * jmask |
| continue |
|
|
| if i in freeze_rotations: |
| p = self.joint_parent_ids[i] |
| R_parent = R[:, p : p + 1, :, :] |
| R_i_new = R_parent @ bind_local[:, i : i + 1, :3, :3] |
| R = R * (1 - jmask) + R_i_new * jmask |
| continue |
|
|
| if self.skip_inverse_lbs: |
| R_init = ( |
| torch.eye(3, dtype=dtype, device=device) |
| .unsqueeze(0) |
| .unsqueeze(0) |
| .expand(B, 1, 3, 3) |
| ) |
| else: |
| skinned_vids = torch.where(self.skinning_weights[:, i] > 0.01)[0] |
| skinned_orig = ( |
| self.bind_shape[skinned_vids] - self.bind_world_transforms[i, :3, 3] |
| )[None, :, :] |
| skinned_new = target_shapes[:, skinned_vids, :] - t[:, i : i + 1] |
| R_init = align_vectors(skinned_new, skinned_orig) |
|
|
| if len(children) > 0: |
| pos_children_orig = ( |
| bind_world[:, :, :3, 3][:, children] - bind_world[:, i : i + 1, :3, 3] |
| ) |
| R_init_squeezed = R_init.squeeze(1) |
| pos_children_orig = (R_init_squeezed @ pos_children_orig.swapaxes(-2, -1)).swapaxes( |
| -2, -1 |
| ) |
| pos_children_new = t[:, children, :] - t[:, i : i + 1, :] |
|
|
| align_rot = align_vectors(pos_children_new, pos_children_orig) |
|
|
| R_i_new = align_rot @ R_init_squeezed @ R[:, i, :, :] |
| else: |
| R_i_new = R_init.squeeze(1) @ R[:, i, :, :] |
|
|
| R = R * (1 - jmask) + R_i_new[:, None, :, :] * jmask |
|
|
| world_bind_pose = SE3_from_Rt(R, t) |
| return world_bind_pose[0] if added_batch else world_bind_pose |
|
|
| def fit_rotations_warp(self, new_joint_positions, target_shapes): |
| """Warp-accelerated version of fit_joint_rotations using GPU-parallel alignment. |
| |
| Args: |
| new_joint_positions: (B, J, 3) or (J, 3) array of new joint positions |
| target_shapes: (B, V, 3) or (V, 3) array of new vertex positions |
| |
| Returns: |
| world_bind_pose: (B, J, 4, 4) or (J, 4, 4) array of new bind poses in world space |
| """ |
| if self._warp_stage1_offsets is None: |
| self._precompute_warp_data() |
|
|
| dtype, device = self.dtype, self.device |
| J = self.num_joints |
|
|
| added_batch = False |
| if new_joint_positions.ndim == 2: |
| new_joint_positions = new_joint_positions[None, :, :] |
| added_batch = True |
| if new_joint_positions.shape[-2:] != (J, 3): |
| raise ValueError( |
| f"Expected new_joint_positions to have shape (...,{J},3); got {new_joint_positions.shape}" |
| ) |
| if target_shapes.ndim == 2: |
| target_shapes = target_shapes[None, :, :] |
|
|
| t = new_joint_positions |
| B = t.shape[0] |
|
|
| bind_local = self.bind_local_transforms[None, ...].expand(B, J, 4, 4) |
|
|
| R0 = self.bind_world_transforms[..., :3, :3].clone() |
| R = R0[None, ...].expand(B, J, 3, 3) |
|
|
| |
| R_init_all = None |
| if len(self._warp_stage1_joint_indices) > 0: |
| if self.skip_inverse_lbs: |
| num_joints_stage1 = len(self._warp_stage1_joint_indices) |
| R_init_all = ( |
| torch.eye(3, dtype=dtype, device=device) |
| .unsqueeze(0) |
| .unsqueeze(0) |
| .expand(B, num_joints_stage1, 3, 3) |
| ) |
| else: |
| skinned_orig = self._warp_stage1_skinned_orig_centered |
|
|
| skinned_new = target_shapes[:, self._warp_stage1_skinned_vids_flat, :] |
| new_joint_positions_for_verts = t[:, self._warp_stage1_joint_indices, :] |
| new_joint_positions_expanded = new_joint_positions_for_verts[ |
| :, self._warp_stage1_repeat_indices, : |
| ] |
| skinned_new = skinned_new - new_joint_positions_expanded |
|
|
| skinned_orig_batched = skinned_orig.unsqueeze(0).expand(B, -1, -1).reshape(-1, 3) |
| skinned_new_flat = skinned_new.reshape(-1, 3) |
|
|
| num_joints_stage1 = len(self._warp_stage1_joint_indices) |
| offsets_batched = ( |
| self._warp_stage1_offsets.unsqueeze(0) |
| + torch.arange(B, device=device, dtype=torch.int32).unsqueeze(1) |
| * skinned_orig.shape[0] |
| ) |
| offsets_batched = offsets_batched.flatten() |
| counts_batched = self._warp_stage1_counts.unsqueeze(0).expand(B, -1).flatten() |
|
|
| R_init_all = align_vectors_warp( |
| skinned_new_flat, |
| skinned_orig_batched, |
| offsets_batched, |
| counts_batched, |
| method=self.rotation_method, |
| ) |
| R_init_all = R_init_all.reshape(B, num_joints_stage1, 3, 3) |
|
|
| |
| align_rot_n1 = None |
| align_rot_n2 = None |
|
|
| if R_init_all is not None and parallel_rodrigues_kabsch_warp is not None: |
| num_n1 = len(self._warp_stage2_n1_joint_indices) |
| if num_n1 > 0: |
| pos_children_orig_n1 = self._warp_stage2_n1_children_orig_centered |
| pos_children_new_n1 = ( |
| t[:, self._warp_stage2_n1_child_indices, :] |
| - t[:, self._warp_stage2_n1_joint_indices, :] |
| ) |
|
|
| pos_children_orig_n1_batched = pos_children_orig_n1.unsqueeze(0).expand(B, -1, -1) |
| R_init_for_n1 = R_init_all[:, self._warp_stage2_n1_to_stage1_indices, :, :] |
| pos_children_orig_n1_rotated = torch.bmm( |
| R_init_for_n1.reshape(-1, 3, 3), pos_children_orig_n1_batched.reshape(-1, 3, 1) |
| ).reshape(B, num_n1, 3) |
|
|
| src_vecs_n1 = pos_children_orig_n1_rotated.reshape(-1, 3) |
| tgt_vecs_n1 = pos_children_new_n1.reshape(-1, 3) |
| else: |
| src_vecs_n1 = torch.empty((0, 3), dtype=dtype, device=device) |
| tgt_vecs_n1 = torch.empty((0, 3), dtype=dtype, device=device) |
|
|
| num_n2 = len(self._warp_stage2_n2_joint_indices) |
| if num_n2 > 0: |
| pos_children_orig_n2 = self._warp_stage2_n2_children_orig_centered |
| pos_children_new_n2 = t[:, self._warp_stage2_n2_child_flat, :] |
| parent_positions_new_n2 = t[:, self._warp_stage2_n2_joint_indices, :] |
| parent_positions_new_n2_expanded = parent_positions_new_n2[ |
| :, self._warp_stage2_n2_repeat_indices, : |
| ] |
| pos_children_new_n2 = pos_children_new_n2 - parent_positions_new_n2_expanded |
|
|
| pos_children_orig_n2_batched = pos_children_orig_n2.unsqueeze(0).expand(B, -1, -1) |
| R_init_for_n2 = R_init_all[:, self._warp_stage2_n2_to_stage1_indices, :, :] |
| R_init_n2_expanded = R_init_for_n2[:, self._warp_stage2_n2_R_repeat_indices, :, :] |
| pos_children_orig_n2_rotated = torch.bmm( |
| R_init_n2_expanded.reshape(-1, 3, 3), |
| pos_children_orig_n2_batched.reshape(-1, 3, 1), |
| ).reshape(B, -1, 3) |
|
|
| pos_children_orig_n2_flat = pos_children_orig_n2_rotated.reshape(-1, 3) |
| pos_children_new_n2_flat = pos_children_new_n2.reshape(-1, 3) |
|
|
| offsets_batched_n2 = self._warp_stage2_n2_offsets.unsqueeze(0) + torch.arange( |
| B, device=device, dtype=torch.int32 |
| ).unsqueeze(1) * len(self._warp_stage2_n2_child_flat) |
| offsets_batched_n2 = offsets_batched_n2.flatten() |
| counts_batched_n2 = self._warp_stage2_n2_counts.unsqueeze(0).expand(B, -1).flatten() |
| else: |
| pos_children_orig_n2_flat = torch.empty((0, 3), dtype=dtype, device=device) |
| pos_children_new_n2_flat = torch.empty((0, 3), dtype=dtype, device=device) |
| offsets_batched_n2 = torch.empty((0,), dtype=torch.int32, device=device) |
| counts_batched_n2 = torch.empty((0,), dtype=torch.int32, device=device) |
|
|
| align_rot_n1_flat, align_rot_n2_flat = parallel_rodrigues_kabsch_warp( |
| tgt_vecs_n1, |
| src_vecs_n1, |
| pos_children_new_n2_flat, |
| pos_children_orig_n2_flat, |
| offsets_batched_n2, |
| counts_batched_n2, |
| method=self.rotation_method, |
| ) |
|
|
| align_rot_n1 = align_rot_n1_flat.reshape(B, num_n1, 3, 3) |
| align_rot_n2 = align_rot_n2_flat.reshape(B, num_n2, 3, 3) |
|
|
| |
| R_new = R.clone() |
|
|
| if R_init_all is not None: |
| stage1_joints = self._warp_stage1_joint_indices |
| R_new[:, stage1_joints, :, :] = torch.bmm( |
| R_init_all.reshape(B * len(stage1_joints), 3, 3), |
| R[:, stage1_joints, :, :].reshape(B * len(stage1_joints), 3, 3), |
| ).reshape(B, len(stage1_joints), 3, 3) |
|
|
| if align_rot_n1 is not None: |
| stage2_n1_joints = self._warp_stage2_n1_joint_indices |
| R_init_for_n1 = R_init_all[:, self._warp_stage2_n1_to_stage1_indices, :, :] |
|
|
| temp = torch.bmm( |
| align_rot_n1.reshape(B * len(stage2_n1_joints), 3, 3), |
| R_init_for_n1.reshape(B * len(stage2_n1_joints), 3, 3), |
| ).reshape(B, len(stage2_n1_joints), 3, 3) |
|
|
| R_new[:, stage2_n1_joints, :, :] = torch.bmm( |
| temp.reshape(B * len(stage2_n1_joints), 3, 3), |
| R[:, stage2_n1_joints, :, :].reshape(B * len(stage2_n1_joints), 3, 3), |
| ).reshape(B, len(stage2_n1_joints), 3, 3) |
|
|
| if align_rot_n2 is not None: |
| stage2_n2_joints = self._warp_stage2_n2_joint_indices |
| R_init_for_n2 = R_init_all[:, self._warp_stage2_n2_to_stage1_indices, :, :] |
|
|
| temp = torch.bmm( |
| align_rot_n2.reshape(B * len(stage2_n2_joints), 3, 3), |
| R_init_for_n2.reshape(B * len(stage2_n2_joints), 3, 3), |
| ).reshape(B, len(stage2_n2_joints), 3, 3) |
|
|
| R_new[:, stage2_n2_joints, :, :] = torch.bmm( |
| temp.reshape(B * len(stage2_n2_joints), 3, 3), |
| R[:, stage2_n2_joints, :, :].reshape(B * len(stage2_n2_joints), 3, 3), |
| ).reshape(B, len(stage2_n2_joints), 3, 3) |
|
|
| if len(self._warp_unskinned_end_joints) > 0: |
| R_new[:, self._warp_unskinned_end_joints, :, :] = R_new[ |
| :, self._warp_unskinned_end_joint_parents, :, : |
| ] |
|
|
| if len(self._warp_frozen_joints) > 0: |
| R_parents = R_new[:, self._warp_frozen_parents, :, :] |
| R_bind_local = bind_local[0, self._warp_frozen_joints, :3, :3] |
|
|
| num_frozen = len(self._warp_frozen_joints) |
| R_frozen = torch.bmm( |
| R_parents.reshape(B * num_frozen, 3, 3), |
| R_bind_local.unsqueeze(0).expand(B, -1, -1, -1).reshape(B * num_frozen, 3, 3), |
| ).reshape(B, num_frozen, 3, 3) |
|
|
| R_new[:, self._warp_frozen_joints, :, :] = R_frozen |
|
|
| world_bind_pose = SE3_from_Rt(R_new, t) |
| return world_bind_pose[0] if added_batch else world_bind_pose |
|
|