File size: 13,031 Bytes
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5a4d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5a4d7
 
 
 
 
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base skeleton class: hierarchy, joint metadata, and helpers for kinematics and motion."""

from pathlib import Path
from typing import Optional

import torch

from kimodo.assets import skeleton_asset_path

from .kinematics import fk
from .transforms import (
    from_standard_tpose,
    global_rots_to_local_rots,
    to_standard_tpose,
)


class SkeletonBase(torch.nn.Module):
    """Base class that stores a skeleton hierarchy and helper metadata.

    Subclasses define the static joint layout (joint names and parent links) and semantic groups
    (feet, hands, hips). This class builds index mappings, parent tensors, and convenience helpers
    used by kinematics, constraints, and motion conversion utilities.

    Besides the class attributes below, subclasses must also provide ``left_foot_joint_names``,
    ``right_foot_joint_names``, ``left_hand_joint_names`` and ``right_hand_joint_names``: ordered
    joint-name lists whose first entry is treated as the base of that chain when expanding
    rotation constraints.
    """

    # these should be defined in the subclass
    name = None
    bone_order_names_with_parents = None  # ordered (joint_name, parent_name or None) pairs
    bone_order_names_no_root = None
    root_idx = None
    foot_joint_names = None
    foot_joint_idx = None
    hip_joint_names = None  # in order [right, left]
    hip_joint_idx = None  # in order [right, left]

    def __init__(
        self,
        folder: Optional[str] = None,
        name: Optional[str] = None,
        load: bool = True,
        **kwargs,  # to catch addition args in configs
    ):
        """Initialize a skeleton instance and optional neutral-pose assets.

        Args:
            folder: Folder containing serialized skeleton assets (for example
                `joints.p` and optional `standard_t_pose_global_offsets_rots.p`).
                Defaults to the repository asset folder for this skeleton's name.
            name: Optional runtime name used to validate subclass compatibility.
            load: Whether to load tensor assets from `folder`.
            **kwargs: Unused extra config keys kept for config compatibility.

        Raises:
            ValueError: If loaded neutral joints do not match the joint layout, or the
                neutral root joint is not at the origin.
        """
        super().__init__()

        if name is not None:
            # Check that the name is not too far from the actual skeleton class name
            assert self.name in name
            self.name = name

        if folder is None:
            # Take the skeleton asset folder of the repo from the name
            # in case we don't override it
            folder = str(skeleton_asset_path(self.name))

        self.folder = folder

        self.dim = len(self.bone_order_names_with_parents)

        if load and folder is not None:
            self._load_assets(Path(folder))

        self.bone_order_names = [x for x, _ in self.bone_order_names_with_parents]

        self.bone_parents = dict(self.bone_order_names_with_parents)
        self.bone_index = {x: idx for idx, x in enumerate(self.bone_order_names)}
        self.bone_order_names_index = self.bone_index

        # create the parents tensor on the fly (-1 marks the root)
        joint_parents = torch.tensor(
            [-1 if (y := self.bone_parents[x]) is None else self.bone_index[y] for x in self.bone_order_names]
        )
        self.register_buffer("joint_parents", joint_parents, persistent=False)

        self.nbjoints = len(self.bone_order_names)

        # check lengths
        assert self.nbjoints == len(self.joint_parents)

        root_indices = torch.where(joint_parents == -1)[0]
        assert len(root_indices) == 1  # should be one root only
        self.root_idx = root_indices[0].item()

        # Validate the loaded neutral pose against the layout. Registered buffers live in
        # ``self._buffers`` rather than ``self.__dict__``, so look them up via attribute access.
        neutral_joints = getattr(self, "neutral_joints", None)
        if neutral_joints is not None:
            if len(neutral_joints) != self.nbjoints:
                raise ValueError(
                    f"{self.__class__.__name__}: neutral joints define {len(neutral_joints)} joints"
                    f" but the skeleton layout has {self.nbjoints} (folder={self.folder})"
                )
            if not bool((neutral_joints[self.root_idx] == 0).all()):
                raise ValueError(
                    f"{self.__class__.__name__}: the neutral root joint must be at the origin"
                    f" (folder={self.folder})"
                )

        # remove the root
        self.bone_order_names_no_root = (
            self.bone_order_names[: self.root_idx] + self.bone_order_names[self.root_idx + 1 :]
        )

        self.foot_joint_names = self.left_foot_joint_names + self.right_foot_joint_names
        self.foot_joint_names_index = {x: idx for idx, x in enumerate(self.foot_joint_names)}

        self.left_foot_joint_idx = [self.bone_order_names.index(j) for j in self.left_foot_joint_names]
        self.right_foot_joint_idx = [self.bone_order_names.index(j) for j in self.right_foot_joint_names]
        self.foot_joint_idx = self.left_foot_joint_idx + self.right_foot_joint_idx

        self.hip_joint_idx = [self.bone_order_names.index(hip_joint) for hip_joint in self.hip_joint_names]

    def _find_fallback_folder(self, class_default_name):
        """Return the first known asset folder whose `joints.p` fits this layout, or None.

        Candidates are tried in priority order: the class default name, the runtime name,
        then a fixed list of known skeletons. Folders whose neutral joints have a different
        joint count than this layout are skipped, since they belong to another skeleton.
        """
        candidate_names = []
        if class_default_name:
            candidate_names.append(str(class_default_name))
        if self.name:
            candidate_names.append(str(self.name))
        # Robust fallback for renamed runtime names in model configs
        candidate_names.extend(["somaskel30", "somaskel77", "g1skel34", "smplx22"])
        # dict.fromkeys removes duplicates while preserving priority order.
        for candidate in dict.fromkeys(candidate_names):
            fallback_folder = skeleton_asset_path(candidate)
            joints_path = fallback_folder / "joints.p"
            if not joints_path.exists():
                continue
            if len(torch.load(joints_path).squeeze()) != self.dim:
                continue
            print(
                "[kimodo][skeleton][init][fallback]"
                f" class={self.__class__.__name__} candidate={candidate} path={fallback_folder}"
            )
            return fallback_folder
        return None

    def _load_assets(self, pfolder: Path) -> None:
        """Load neutral-pose tensors from `pfolder` and register them as non-persistent buffers.

        `joints.p` is required; `bvh_joints.p`, `standard_t_pose_global_offsets_rots.p` and
        `rest_pose_local_rot.p` are registered only when present. If `pfolder` lacks
        `joints.p`, a compatible fallback folder is searched and `self.folder` is updated.
        """
        class_default_name = getattr(self.__class__, "name", None)
        print(
            "[kimodo][skeleton][init][entry]"
            f" class={self.__class__.__name__} name={self.name} class_default={class_default_name}"
            f" folder={pfolder} joints_exists={(pfolder / 'joints.p').exists()}"
        )
        if not (pfolder / "joints.p").exists():
            fallback_folder = self._find_fallback_folder(class_default_name)
            if fallback_folder is not None:
                pfolder = fallback_folder
                self.folder = str(pfolder)
        try:
            neutral_joints = torch.load(pfolder / "joints.p").squeeze()
        except Exception as error:
            # Diagnostics only; guard iterdir so a non-directory path cannot mask `error`.
            dir_entries = sorted(p.name for p in pfolder.iterdir()) if pfolder.is_dir() else []
            print(
                "[kimodo][skeleton][init][error]"
                f" class={self.__class__.__name__} resolved_folder={pfolder}"
                f" dir_exists={pfolder.exists()}"
                f" dir_entries={dir_entries}"
                f" error={type(error).__name__}: {error}"
            )
            raise
        self.register_buffer("neutral_joints", neutral_joints, persistent=False)

        # Optional assets, registered (in this order) only when their file exists.
        optional_buffers = (
            ("bvh_joints.p", "bvh_neutral_joints"),
            ("standard_t_pose_global_offsets_rots.p", "global_rot_offsets"),
            # Useful for g1, where the rest pose is not zero
            ("rest_pose_local_rot.p", "rest_pose_local_rot"),
        )
        for filename, buffer_name in optional_buffers:
            asset_path = pfolder / filename
            if asset_path.exists():
                self.register_buffer(buffer_name, torch.load(asset_path).squeeze(), persistent=False)

        print(
            "[kimodo][skeleton][init][exit]"
            f" class={self.__class__.__name__} resolved_folder={pfolder}"
            f" neutral_shape={tuple(self.neutral_joints.shape)}"
        )

    def _expand_ee_group(self, ee_names):
        """Expand one list of base EE group names into flat (rot_names, pos_names) lists.

        Position constraints cover the whole chain of each group; rotation constraints cover
        only its first joint. "Hips" maps to the root joint for both.

        Raises:
            ValueError: If a name is not one of LeftFoot, RightFoot, LeftHand, RightHand, Hips.
        """
        pelvis_name = self.bone_order_names[self.root_idx]
        chains = {
            "LeftFoot": self.left_foot_joint_names,
            "RightFoot": self.right_foot_joint_names,
            "LeftHand": self.left_hand_joint_names,
            "RightHand": self.right_hand_joint_names,
            "Hips": [pelvis_name],
        }
        rot_names = []
        pos_names = []
        for jname in ee_names:
            if jname not in chains:
                raise ValueError(f"Unknown end-effector group {jname!r}; expected one of {list(chains)}")
            chain = chains[jname]
            rot_names += chain[:1]  # base of the chain
            pos_names += chain
        return rot_names, pos_names

    def expand_joint_names(self, joint_names):
        """Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand, Hips] into the actual
        joint names to constrain in position and rotation.

        Args:
            joint_names: list of base EE names to constrain (a single keyframe)

        Returns:
            rot_joint_names: flat list of joint names to constrain rotations
            pos_joint_names: flat list of joint names to constrain positions

        Raises:
            ValueError: If an entry is not a known base EE name.
        """
        return self._expand_ee_group(joint_names)

    def expand_joint_names_batched(self, joint_names):
        """Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand, Hips] into the actual
        joint names to constrain in position and rotation, one keyframe at a time.

        Args:
            joint_names: list (one entry per keyframe) of lists of base EE names to constrain

        Returns:
            rot_joint_names: list of list of joint names to constrain rotations
            pos_joint_names: list of list of joint names to constrain positions

        Raises:
            ValueError: If an entry is not a known base EE name.
        """
        expanded = [self._expand_ee_group(key_joint_names) for key_joint_names in joint_names]
        rot_joint_names = [rot for rot, _ in expanded]
        pos_joint_names = [pos for _, pos in expanded]
        return rot_joint_names, pos_joint_names

    def __repr__(self):
        if self.folder is None:
            return f"{self.__class__.__name__}()"
        return f'{self.__class__.__name__}(folder="{self.folder}")'

    @property
    def device(self):
        """Device where neutral-joint buffers are stored.

        Returns 'cpu' if neutral_joints is not present.
        """
        if getattr(self, "neutral_joints", None) is None:
            return "cpu"
        return self.neutral_joints.device

    def fk(self, local_joint_rots: torch.Tensor, root_positions: torch.Tensor):
        """Run forward kinematics for this skeleton layout.

        Args:
            local_joint_rots: Local joint rotation matrices with shape
                `(..., J, 3, 3)`.
            root_positions: Root translations with shape `(..., 3)`.

        Returns:
            Tuple of `(global_joint_rots, posed_joints, posed_joints_norootpos)`.
        """
        global_joint_rots, posed_joints, posed_joints_norootpos = fk(local_joint_rots, root_positions, self)
        return global_joint_rots, posed_joints, posed_joints_norootpos

    def to_standard_tpose(self, local_rot_mats: torch.Tensor):
        """Convert local rotations into the skeleton's standard T-pose frame."""
        return to_standard_tpose(local_rot_mats, self)

    def from_standard_tpose(self, local_rot_mats: torch.Tensor):
        """Convert local rotations from the skeleton's standard T-pose frame."""
        return from_standard_tpose(local_rot_mats, self)

    def global_rots_to_local_rots(self, global_joint_rots: torch.Tensor):
        """Convert global joint rotations to local rotations for this hierarchy."""
        return global_rots_to_local_rots(global_joint_rots, self)

    def get_skel_slice(self, skeleton: "SkeletonBase"):
        """Build index mapping from another skeleton into this skeleton order.

        Args:
            skeleton: Source skeleton whose joint order is used by input tensors.

        Returns:
            A list of source indices ordered as `self.bone_order_names`.

        Raises:
            ValueError: If at least one required joint is missing from `skeleton`.
        """
        missing = [x for x in self.bone_order_names if x not in skeleton.bone_index]
        if missing:
            raise ValueError(f"The current skeleton contains joints that are not in the input: {missing}")
        return [skeleton.bone_index[x] for x in self.bone_order_names]