| import sys |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from lib.utils.utils_smpl import SMPL |
| from lib.utils.utils_mesh import rotation_matrix_to_angle_axis, rot6d_to_rotmat |
|
|
class SMPLRegressor(nn.Module):
    """Regress SMPL pose/shape parameters and the resulting mesh from joint features.

    Two parallel branches share the same input features:

    * the pose branch runs per frame (``N*T`` rows) and predicts a residual on
      top of the mean pose loaded from ``self.smpl.smpl_mean_params``;
    * the shape branch first average-pools the features over the ``T`` axis,
      predicts one residual on the mean shape per sequence, and repeats it
      for every frame of that sequence.

    The predicted parameters are then passed through the SMPL layer, and 3D
    joints are regressed from the vertices with ``smpl.J_regressor_h36m``.

    Args:
        args: Config object; only ``args.data_root`` is read here and handed
            to ``SMPL`` as its first argument.
        dim_rep: Per-joint feature width of the input.
        num_joints: Number of joints in the input features.
        hidden_dim: Width of the hidden layer in both branches.
        dropout_ratio: Dropout probability applied to the input of each branch.
    """

    def __init__(self, args, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.):
        super().__init__()
        # 24 rotations, 6 values each; decoded by rot6d_to_rotmat in forward().
        param_pose_dim = 24 * 6

        # NOTE: layer creation order is kept as-is so that seeded weight
        # initialisation and state_dict key order do not change.
        self.dropout = nn.Dropout(p=dropout_ratio)
        self.fc1 = nn.Linear(num_joints * dim_rep, hidden_dim)
        # Output size (None, 1): keep the first pooled axis, collapse the last one.
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))
        self.fc2 = nn.Linear(num_joints * dim_rep, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.head_pose = nn.Linear(hidden_dim, param_pose_dim)
        self.head_shape = nn.Linear(hidden_dim, 10)
        # Small gain: the heads start close to zero, so the initial prediction
        # is approximately the mean pose/shape added in forward().
        nn.init.xavier_uniform_(self.head_pose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.head_shape.weight, gain=0.01)

        self.smpl = SMPL(
            args.data_root,
            batch_size=64,
            create_transl=False,
        )

        # Mean parameters act as the constant term of both residual heads.
        mean_params = np.load(self.smpl.smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.J_regressor = self.smpl.J_regressor_h36m

    def forward(self, feat, init_pose=None, init_shape=None):
        """Predict SMPL parameters, vertices and joints for every frame.

        Args:
            feat: Tensor of shape ``(N, T, J, C)``; ``J * C`` must equal the
                ``num_joints * dim_rep`` the module was built with.
            init_pose: Accepted for interface compatibility but currently
                ignored; the registered ``init_pose`` buffer is always used.
            init_shape: Accepted for interface compatibility but currently
                ignored; the registered ``init_shape`` buffer is always used.

        Returns:
            A one-element list holding a dict with, for each of the ``N*T``
            frames:

            * ``'theta'``: 72 pose values (from
              ``rotation_matrix_to_angle_axis``) concatenated with the shape
              vector;
            * ``'verts'``: SMPL vertices multiplied by 1000.0 (presumably
              metres to millimetres -- confirm against the SMPL wrapper);
            * ``'kp_3d'``: joints regressed from ``'verts'``.

        Raises:
            RuntimeError: If the SMPL model did not provide
                ``J_regressor_h36m``.
        """
        N, T, J, C = feat.shape
        NT = N * T
        feat = feat.reshape(N, T, -1)

        # Pose branch: one row per frame.
        feat_pose = self.dropout(feat.reshape(NT, -1))
        feat_pose = self.relu1(self.bn1(self.fc1(feat_pose)))

        # Shape branch: average over T, giving one row per sequence.
        feat_shape = self.pool2(feat.permute(0, 2, 1)).reshape(N, -1)
        feat_shape = self.dropout(feat_shape)
        feat_shape = self.relu2(self.bn2(self.fc2(feat_shape)))

        # Residual predictions on top of the mean parameters.
        pred_pose = self.head_pose(feat_pose) + self.init_pose.expand(NT, -1)
        pred_shape = self.head_shape(feat_shape) + self.init_shape.expand(N, -1)
        # Repeat each sequence-level shape for all T frames: (N, D) -> (N*T, D).
        pred_shape = pred_shape.unsqueeze(1).expand(N, T, -1).reshape(NT, -1)

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(-1, 24, 3, 3)
        pred_output = self.smpl(
            betas=pred_shape,
            body_pose=pred_rotmat[:, 1:],
            global_orient=pred_rotmat[:, 0].unsqueeze(1),
            pose2rot=False
        )
        pred_vertices = pred_output.vertices * 1000.0

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.J_regressor is None:
            raise RuntimeError('SMPL model does not provide J_regressor_h36m')
        # matmul broadcasts the 2-D regressor over the batch dimension, so no
        # per-batch expanded copy is needed.
        joint_regressor = self.J_regressor.to(pred_vertices.device)
        pred_joints = torch.matmul(joint_regressor, pred_vertices)

        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
        output = [{
            'theta': torch.cat([pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_3d': pred_joints,
        }]
        return output
|
|
class MeshRegressor(nn.Module):
    """Wrap a feature backbone with an ``SMPLRegressor`` head.

    Args:
        args: Config object forwarded unchanged to ``SMPLRegressor``.
        backbone: Module exposing ``get_representation(x)``.
        dim_rep: Per-joint feature width expected by the head.
        num_joints: Number of joints; also used to reshape backbone features.
        hidden_dim: Hidden width of the head.
        dropout_ratio: Dropout probability used inside the head.
    """

    def __init__(self, args, backbone, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.5):
        super().__init__()
        self.backbone = backbone
        self.feat_J = num_joints
        self.head = SMPLRegressor(
            args,
            dim_rep=dim_rep,
            num_joints=num_joints,
            hidden_dim=hidden_dim,
            dropout_ratio=dropout_ratio,
        )

    def forward(self, x, init_pose=None, init_shape=None, n_iter=3):
        '''
            Input: (N x T x 17 x 3)

            Returns the list of dicts produced by the head, with every entry
            reshaped from a flat N*T leading axis back to (N, T, ...).
            ``init_pose``, ``init_shape`` and ``n_iter`` are accepted but not used.
        '''
        batch, frames, _, _ = x.shape

        # Backbone features are regrouped per joint before entering the head.
        rep = self.backbone.get_representation(x)
        rep = rep.reshape([batch, frames, self.feat_J, -1])
        outputs = self.head(rep)

        # Trailing dimensions each entry keeps after the (batch, frames) split.
        trailing_dims = {
            'theta': (-1,),
            'verts': (-1, 3),
            'kp_3d': (-1, 3),
        }
        for entry in outputs:
            for key, tail in trailing_dims.items():
                entry[key] = entry[key].reshape(batch, frames, *tail)
        return outputs