| import os |
| import sys |
|
|
| from torch.optim.lr_scheduler import StepLR |
|
|
| sys.path.append(os.getcwd()) |
|
|
| from nets.layers import * |
| from nets.base import TrainWrapperBaseClass |
| from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G |
| from nets.spg.vqvae_1d import VQVAE as s2g_body |
| from nets.utils import parse_audio, denormalize |
| from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta |
| import numpy as np |
| import torch.optim as optim |
| import torch.nn.functional as F |
| from sklearn.preprocessing import normalize |
|
|
| from data_utils.lower_body import c_index, c_index_3d, c_index_6d |
|
|
|
|
class TrainWrapper(TrainWrapperBaseClass):
    '''
    A wrapper that receives a batch from data_utils, computes the VQ-VAE
    reconstruction losses and steps the generator optimizer(s).

    When ``config.Model.composition`` is set, two VQ-VAEs are trained
    separately: ``g_body`` on the first ``each_dim[1]`` pose channels and
    ``g_hand`` on the remaining ones. Otherwise a single VQ-VAE ``g`` models
    all selected channels.
    '''

    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0

        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.expression = self.config.Data.pose.expression
        self.epoch = 0
        # Fills self.each_dim, which sizes the VQ-VAE inputs below.
        self.init_params()
        self.num_classes = 4
        self.composition = self.config.Model.composition

        # Every VQ-VAE shares these hyper-parameters; only the input width differs.
        vq_kwargs = dict(
            embedding_dim=64,
            num_embeddings=config.Model.code_num,
            num_hiddens=1024,
            num_residual_layers=2,
            num_residual_hiddens=512,
        )
        if self.composition:
            # Construction order (body, then hand) is kept as before so that
            # parameter initialisation consumes the RNG identically.
            self.g_body = s2g_body(self.each_dim[1], **vq_kwargs).to(self.device)
            self.g_hand = s2g_body(self.each_dim[2], **vq_kwargs).to(self.device)
        else:
            self.g = s2g_body(self.each_dim[1] + self.each_dim[2], **vq_kwargs).to(self.device)

        # This wrapper trains no discriminator; state_dict() stores None for it.
        self.discriminator = None

        # Indices of the pose channels (dim 1 of the raw pose tensor) to model.
        self.c_index = c_index_6d if self.convert_to_6d else c_index_3d

        super().__init__(args, config)

    def init_optimizer(self):
        '''Create one Adam optimizer per generator, all using the generator learning rate.'''
        print('using Adam')
        lr = self.config.Train.learning_rate.generator_learning_rate
        if self.composition:
            self.g_body_optimizer = optim.Adam(self.g_body.parameters(), lr=lr, betas=[0.9, 0.999])
            self.g_hand_optimizer = optim.Adam(self.g_hand.parameters(), lr=lr, betas=[0.9, 0.999])
        else:
            self.g_optimizer = optim.Adam(self.g.parameters(), lr=lr, betas=[0.9, 0.999])

    def state_dict(self):
        '''Return generator and optimizer states for checkpointing.

        The 'discriminator' and 'discriminator_optim' entries are None when no
        discriminator is attached.
        '''
        if self.composition:
            model_state = {
                'g_body': self.g_body.state_dict(),
                'g_body_optim': self.g_body_optimizer.state_dict(),
                'g_hand': self.g_hand.state_dict(),
                'g_hand_optim': self.g_hand_optimizer.state_dict(),
            }
        else:
            model_state = {
                'g': self.g.state_dict(),
                'g_optim': self.g_optimizer.state_dict(),
            }
        has_discriminator = self.discriminator is not None
        model_state['discriminator'] = (
            self.discriminator.state_dict() if has_discriminator else None
        )
        model_state['discriminator_optim'] = (
            self.discriminator_optimizer.state_dict() if has_discriminator else None
        )
        return model_state

    def init_params(self):
        '''Compute the channel layout of the pose vector.

        Sets:
            dim_list: start offsets of the jaw, eye, body, hand and face segments.
            full_dim: total pose channels (jaw + eye + body + hand; face excluded).
            pose: full_dim divided by the channels per joint (3, or 6 with convert_to_6d).
            each_dim: widths [jaw, eye + body, hand, face].
        '''
        # Channels per joint double when poses are converted to 6D.
        scale = 2 if self.config.Data.pose.convert_to_6d else 1

        # Jaw, eyes and global orientation contribute no channels in this layout.
        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(0 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)
        expression = 100 if self.expression else 0

        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]

    def __call__(self, bat):
        '''Run one training step on a batch.

        Args:
            bat: batch dict. ``bat['poses']`` is indexed as (B, C, T), with the
                channel axis at dim 1 (T presumably frames, to be confirmed).
                Audio features in the batch are not used by this model.

        Returns:
            (total_loss, loss_dict). total_loss is always None, because the
            optimizers are already stepped inside vq_train and no
            differentiable loss is handed back. loss_dict maps prefixed loss
            names ('b'/'h'/'g' + key) to Python floats.
        '''
        self.global_step += 1

        total_loss = None
        loss_dict = {}

        poses = bat['poses'].to(self.device).to(torch.float32)
        poses = poses[:, self.c_index, :]
        gt_poses = poses.permute(0, 2, 1)  # -> (B, T, C)

        loss = 0
        if self.composition:
            body_dim = self.each_dim[1]
            loss_dict, loss = self.vq_train(gt_poses[..., :body_dim], 'b', self.g_body, loss_dict, loss)
            loss_dict, loss = self.vq_train(gt_poses[..., body_dim:], 'h', self.g_hand, loss_dict, loss)
        else:
            loss_dict, loss = self.vq_train(gt_poses, 'g', self.g, loss_dict, loss)

        return total_loss, loss_dict

    def vq_train(self, gt, name, model, dict, total_loss, pre=None):
        '''Reconstruct ``gt`` with ``model`` and take one optimizer step.

        Args:
            gt: target poses, passed to the model as ``gt_poses``.
            name: 'b' (body), 'h' (hand) or 'g' (single joint model). Selects
                the optimizer and prefixes the logged loss keys.
            model: the VQ-VAE being trained.
            dict: loss-log mapping, updated in place and returned. The name
                shadows the builtin but is kept for backward compatibility.
            total_loss: running numeric loss; this step's scalar loss is added to it.
            pre: optional previous-frame context, forwarded to the model and to get_loss.

        Returns:
            (dict, total_loss + this step's loss value).

        Raises:
            ValueError: if ``name`` is not one of 'b', 'h', 'g'.
        '''
        optimizer_names = {'b': 'g_body_optimizer', 'h': 'g_hand_optimizer', 'g': 'g_optimizer'}
        # Validate before the forward pass so a bad name fails fast.
        if name not in optimizer_names:
            raise ValueError("model's name must be 'b', 'h' or 'g'")
        optimizer = getattr(self, optimizer_names[name])

        e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
        loss, step_losses = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for key, value in step_losses.items():
            dict[name + key] = value.item()
        return dict, total_loss + loss.item()

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 e_q_loss,
                 pre=None
                 ):
        '''Compute the generator loss.

        Terms:
            rec_loss: mean absolute error between prediction and target.
            velocity_loss: mean absolute error between differences of
                consecutive entries along dim 1.
            f0_vel: only when ``pre`` is given. Mean absolute error between the
                first-step differences of prediction and target, both taken
                relative to ``pre[:, -1]``.
            e_q_loss: the quantisation loss reported by the model, added as is.

        Returns:
            (gen_loss, loss_dict), where loss_dict holds rec_loss,
            velocity_loss and, when ``pre`` is given, f0_vel.
        '''
        loss_dict = {}

        rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
        v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
        v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
        velocity_loss = torch.mean(torch.abs(v_pr - v_gt))

        if pre is None:
            f0_vel = 0
        else:
            v0_pr = pred_poses[:, 0] - pre[:, -1]
            v0_gt = gt_poses[:, 0] - pre[:, -1]
            f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))

        gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel

        loss_dict['rec_loss'] = rec_loss
        loss_dict['velocity_loss'] = velocity_loss
        if pre is not None:
            loss_dict['f0_vel'] = f0_vel

        return gen_loss, loss_dict

    def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False,
                       continuity=False, id=None, fps=15, sr=22000, smooth=False, **kwargs):
        '''
        Reconstruct poses from ``initial_pose`` with the trained VQ-VAE(s).

        Args:
            aud_fn: audio file path, or a precomputed feature tensor. Features
                are computed/moved to the device, but the reconstruction below
                does not consume them.
            initial_pose: (B, C, T) tensor, normalized. Required, because the
                VQ-VAE encodes and decodes these poses.
            norm_stats: (mean, std), required when pose normalization is enabled.
            continuity: with composition, decode five consecutive 60-step
                windows, feeding each window's last prediction as ``pre_state``
                to the next.
            id: class one-hot; defaults to class 0.
            smooth: blend frames 149..158 of batch item 0 toward their
                predecessor.
            exp, var, w_pre, kwargs: accepted for interface compatibility; unused.

        Returns:
            numpy array. The (B, T, C) prediction is concatenated along
            axis 1 over the batch, giving (T, B*C).

        Raises:
            ValueError: if ``initial_pose`` is None.
        '''
        assert self.args.infer, "train mode"
        if self.composition:
            self.g_body.eval()
            self.g_hand.eval()
        else:
            self.g.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        if initial_pose is None:
            raise ValueError("initial_pose is required: the VQ-VAE reconstructs from ground-truth poses")
        gt = initial_pose.to(self.device).to(torch.float32)
        B = gt.shape[0]

        # NOTE(review): audio features are prepared but not used by the
        # reconstruction below; kept so an invalid aud_fn still fails loudly.
        if isinstance(aud_fn, torch.Tensor):
            # .to() avoids torch.tensor()'s copy-construct warning on tensor input.
            aud_feat = aud_fn.to(device=self.device, dtype=torch.float32)
        else:
            aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
            aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
            aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)

        if id is None:
            id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)

        with torch.no_grad():
            gt_poses = gt[:, self.c_index].permute(0, 2, 1)  # -> (B, T, C)
            body_dim = self.each_dim[1]
            if self.composition:
                if continuity:
                    pred_poses_body = []
                    pred_poses_hand = []
                    pre_b = None
                    pre_h = None
                    for i in range(5):
                        window = slice(i * 60, (i + 1) * 60)
                        _, pred_body = self.g_body(gt_poses=gt_poses[:, window, :body_dim], pre_state=pre_b)
                        pre_b = pred_body[..., -1:].transpose(1, 2)
                        pred_poses_body.append(pred_body)
                        _, pred_hand = self.g_hand(gt_poses=gt_poses[:, window, body_dim:], pre_state=pre_h)
                        pre_h = pred_hand[..., -1:].transpose(1, 2)
                        pred_poses_hand.append(pred_hand)

                    # Window outputs are joined along dim 2, so the inference
                    # output appears to be laid out as (B, C, T). TODO confirm
                    # against the VQVAE eval-mode forward.
                    pred_poses_body = torch.cat(pred_poses_body, dim=2)
                    pred_poses_hand = torch.cat(pred_poses_hand, dim=2)
                else:
                    _, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :body_dim], id=id)
                    _, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., body_dim:], id=id)
                pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
            else:
                _, pred_poses = self.g(gt_poses=gt_poses, id=id)
            pred_poses = pred_poses.transpose(1, 2).cpu().numpy()
        output = pred_poses

        if self.config.Data.pose.normalization:
            output = denormalize(output, data_mean, data_std)

        if smooth:
            # Ramp-blend frames 149..158 of batch item 0 toward the previous frame.
            lamda = 0.8
            smooth_f = 10
            frame = 149
            for i in range(smooth_f):
                f = frame + i
                l = lamda * (i + 1) / smooth_f
                output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]

        output = np.concatenate(output, axis=1)

        return output

    def load_state_dict(self, state_dict):
        '''Load generator weights only; optimizer and discriminator states are not restored.'''
        if self.composition:
            self.g_body.load_state_dict(state_dict['g_body'])
            self.g_hand.load_state_dict(state_dict['g_hand'])
        else:
            self.g.load_state_dict(state_dict['g'])
|
|