| |
| import torch |
| import utils.rotation_conversions as geometry |
|
|
|
|
| from models.smpl import SMPL, JOINTSTYPE_ROOT |
| |
# Values accepted for the ``jointstype`` argument of ``Rotation2xyz.__call__``;
# the chosen value is used there as the key into the SMPL model's output.
JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]
|
|
|
|
class Rotation2xyz:
    """Callable that turns batched pose rotations into 3D positions via SMPL.

    The instance owns one SMPL model (built in ``__init__``) and, when
    called, converts the input rotation representation to rotation matrices,
    runs the SMPL model on them and returns the requested output
    (``jointstype``) re-laid-out like the input.
    """

    def __init__(self, device, dataset='amass'):
        """Create the SMPL model in eval mode and move it to ``device``.

        Args:
            device: Torch device the SMPL model is moved to.
            dataset: Dataset name; only stored on the instance here.
        """
        self.device = device
        self.dataset = dataset
        self.smpl_model = SMPL().eval().to(device)

    def __call__(self, x, mask, pose_rep, translation, glob,
                 jointstype, vertstrans, betas=None, beta=0,
                 glob_rot=None, get_rotations_back=False, **kwargs):
        """Convert poses ``x`` to xyz positions.

        Args:
            x: 4-D tensor indexed as (nsamples, njoints, feats, time). When
                ``translation`` is truthy, the last entry along dim 1 is a
                translation row of which the first three features are used.
            mask: Boolean tensor of shape (nsamples, time) selecting the
                frames to convert. ``None`` selects every frame.
            pose_rep: ``"xyz"`` returns ``x`` untouched; otherwise one of
                ``"rotvec"``, ``"rotmat"``, ``"rotquat"``, ``"rot6d"``.
            translation: Whether ``x`` carries a trailing translation row.
            glob: Whether the first joint of ``x`` is the global orientation.
                When falsy, ``glob_rot`` is used instead.
            jointstype: One of ``JOINTSTYPES``; key of the SMPL output that
                is returned.
            vertstrans: Whether to add the translation (shifted so that its
                first frame is at the origin) to the output.
            betas: Shape coefficients handed to the SMPL model. When
                ``None``, zeros are used with ``beta`` written at index 1.
            beta: Value stored at ``betas[:, 1]`` when ``betas`` is ``None``.
            glob_rot: Single axis-angle rotation used as global orientation
                when ``glob`` is falsy.
            get_rotations_back: Also return the rotation matrices and the
                global orientation that were fed to the SMPL model.
            **kwargs: Accepted and ignored.

        Returns:
            ``x_xyz`` indexed as (nsamples, n_outputs, 3, time), or the tuple
            ``(x_xyz, rotations, global_orient)`` when ``get_rotations_back``
            is truthy. For ``pose_rep == "xyz"`` the input ``x`` itself.

        Raises:
            TypeError: ``glob`` is falsy and ``glob_rot`` is ``None``.
            NotImplementedError: unknown ``jointstype`` or ``pose_rep``.
        """
        if pose_rep == "xyz":
            return x

        if mask is None:
            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)

        if not glob and glob_rot is None:
            raise TypeError("You must specify global rotation if glob is False")

        if jointstype not in JOINTSTYPES:
            raise NotImplementedError("This jointstype is not implemented.")

        if translation:
            # Last row along dim 1 is the translation; the rest are rotations.
            x_translations = x[:, -1, :3]
            x_rotations = x[:, :-1]
        else:
            x_rotations = x

        # Bring time next to the batch dim so ``mask`` can index frames.
        x_rotations = x_rotations.permute(0, 3, 1, 2)
        nsamples, time, njoints, _ = x_rotations.shape

        # Convert only the frames selected by the mask to rotation matrices.
        if pose_rep == "rotvec":
            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
        elif pose_rep == "rotmat":
            rotations = x_rotations[mask].view(-1, njoints, 3, 3)
        elif pose_rep == "rotquat":
            rotations = geometry.quaternion_to_matrix(x_rotations[mask])
        elif pose_rep == "rot6d":
            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
        else:
            raise NotImplementedError("No geometry for this one.")

        if not glob:
            # ``as_tensor`` with an explicit dtype keeps the global
            # orientation in the same precision as the pose rotations even
            # when ``glob_rot`` is a float64 array or a list of ints, and
            # avoids the copy-construct warning when it is already a tensor.
            global_orient = torch.as_tensor(glob_rot, dtype=rotations.dtype,
                                            device=x.device)
            global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
            # One copy of the fixed orientation per selected frame.
            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
        else:
            global_orient = rotations[:, 0]
            rotations = rotations[:, 1:]

        if betas is None:
            betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
                                dtype=rotations.dtype, device=rotations.device)
            betas[:, 1] = beta

        out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)

        # Pick the requested output of the body model.
        joints = out[jointstype]

        # Frames outside the mask stay zero; selected frames get the result.
        x_xyz = torch.zeros(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
        x_xyz[mask] = joints

        # Back to the input layout: (nsamples, n_outputs, 3, time).
        x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()

        # Express every output relative to the root joint (not for vertices).
        if jointstype != "vertices":
            rootindex = JOINTSTYPE_ROOT[jointstype]
            x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]

        if translation and vertstrans:
            # Shift the trajectory so that its first frame is at the origin.
            x_translations = x_translations - x_translations[:, :, [0]]

            # Add the translation to every joint/vertex.
            x_xyz = x_xyz + x_translations[:, None, :, :]

        if get_rotations_back:
            return x_xyz, rotations, global_orient
        return x_xyz
|
|