| import pickle |
| import sys |
| import os |
|
|
| sys.path.append(os.getcwd()) |
|
|
| import json |
| from glob import glob |
| from data_utils.utils import * |
| import torch.utils.data as data |
| from data_utils.consts import speaker_id |
| from data_utils.lower_body import count_part |
| import random |
| from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d |
|
|
# PCA basis for the hand poses: each row of the basis maps one PCA coefficient
# to the 45-dim (15 joints x 3) axis-angle pose of one hand.
with open('data_utils/hand_component.json') as file_obj:
    comp = json.load(file_obj)
left_hand_c = np.asarray(comp['left'])
right_hand_c = np.asarray(comp['right'])
|
|
|
|
def to3d(data):
    # Expand the 12-dim PCA hand coefficients (dims 75:87 and 87:99) into full
    # 45-dim axis-angle poses, yielding 75 + 45 + 45 = 165 dims per frame.
    left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
    right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
    data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
    return data
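# Shape check (illustrative, assuming the JSON basis is 45-dim per hand, as the
# 55-joint reshape below implies): a batch of two 99-dim PCA frames expands to
# the 165-dim axis-angle layout:
#   to3d(np.zeros((2, 99))).shape == (2, 165)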
|
|
|
|
class SmplxDataset():
    '''
    Create a dataset for each segment, to be concatenated later.
    '''
|
|
| def __init__(self, |
| data_root, |
| speaker, |
| motion_fn, |
| audio_fn, |
| audio_sr, |
| fps, |
| feat_method='mel_spec', |
| audio_feat_dim=64, |
| audio_feat_win_size=None, |
| |
| train=True, |
| load_all=False, |
| split_trans_zero=False, |
| limbscaling=False, |
| num_frames=25, |
| num_pre_frames=25, |
| num_generate_length=25, |
| context_info=False, |
| convert_to_6d=False, |
| expression=False, |
| config=None, |
| am=None, |
| am_sr=None, |
| whole_video=False |
| ): |
|
|
| self.data_root = data_root |
| self.speaker = speaker |
|
|
| self.feat_method = feat_method |
| self.audio_fn = audio_fn |
| self.audio_sr = audio_sr |
| self.fps = fps |
| self.audio_feat_dim = audio_feat_dim |
| self.audio_feat_win_size = audio_feat_win_size |
| self.context_info = context_info |
| self.convert_to_6d = convert_to_6d |
| self.expression = expression |
|
|
| self.train = train |
| self.load_all = load_all |
| self.split_trans_zero = split_trans_zero |
| self.limbscaling = limbscaling |
| self.num_frames = num_frames |
| self.num_pre_frames = num_pre_frames |
| self.num_generate_length = num_generate_length |
| |
|
|
| self.config = config |
| self.am_sr = am_sr |
| self.whole_video = whole_video |
| load_mode = self.config.dataset_load_mode |
|
|
| if load_mode == 'pickle': |
| raise NotImplementedError |
|
|
        elif load_mode == 'csv':
            # Despite the mode name, this branch reads a single pickled annotation file.
            with open(data_root, 'rb') as f:
                data = pickle.load(f)
            self.data = data[0]
            if self.load_all:
                self._load_npz_all()
|
|
        elif load_mode == 'json':
            # Despite the mode name, this branch globs annotation .pkl files.
            self.annotations = glob(data_root + '/*pkl')
            if len(self.annotations) == 0:
                raise FileNotFoundError(data_root + ' contains no .pkl files')
| self.annotations = sorted(self.annotations) |
| self.img_name_list = self.annotations |
|
|
| if self.load_all: |
| self._load_them_all(am, am_sr, motion_fn) |
|
|
| def _load_npz_all(self): |
| self.loaded_data = {} |
| self.complete_data = [] |
| data = self.data |
| shape = data['body_pose_axis'].shape[0] |
| self.betas = data['betas'] |
| self.img_name_list = [] |
| for index in range(shape): |
| img_name = f'{index:6d}' |
| self.img_name_list.append(img_name) |
|
|
| jaw_pose = data['jaw_pose'][index] |
| leye_pose = data['leye_pose'][index] |
| reye_pose = data['reye_pose'][index] |
| global_orient = data['global_orient'][index] |
| body_pose = data['body_pose_axis'][index] |
| left_hand_pose = data['left_hand_pose'][index] |
| right_hand_pose = data['right_hand_pose'][index] |
|
|
| full_body = np.concatenate( |
| (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose)) |
| assert full_body.shape[0] == 99 |
            if self.convert_to_6d:
                # to3d expects a batch dimension, so add one; then convert the 55
                # axis-angle joints to 6-D rotations (55 * 6 = 330 dims).
                full_body = to3d(full_body[np.newaxis, :]).squeeze(0)
                full_body = torch.from_numpy(full_body).reshape(-1, 3)
                full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body)).reshape(-1)
                full_body = np.asarray(full_body)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))

            else:
                full_body = to3d(full_body[np.newaxis, :]).squeeze(0)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))
|
|
| self.loaded_data[img_name] = full_body.reshape(-1) |
| self.complete_data.append(full_body.reshape(-1)) |
|
|
| self.complete_data = np.array(self.complete_data) |
|
|
| if self.audio_feat_win_size is not None: |
| self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0) |
| |
| else: |
| if self.feat_method == 'mel_spec': |
| self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim) |
| elif self.feat_method == 'mfcc': |
| self.audio_feat = get_mfcc(self.audio_fn, |
| smlpx=True, |
| sr=self.audio_sr, |
| n_mfcc=self.audio_feat_dim, |
| win_size=self.audio_feat_win_size |
| ) |
|
|
| def _load_them_all(self, am, am_sr, motion_fn): |
| self.loaded_data = {} |
| self.complete_data = [] |
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)
|
|
| self.betas = np.array(data['betas']) |
|
|
| jaw_pose = np.array(data['jaw_pose']) |
| leye_pose = np.array(data['leye_pose']) |
| reye_pose = np.array(data['reye_pose']) |
| global_orient = np.array(data['global_orient']).squeeze() |
| body_pose = np.array(data['body_pose_axis']) |
| left_hand_pose = np.array(data['left_hand_pose']) |
| right_hand_pose = np.array(data['right_hand_pose']) |
|
|
| full_body = np.concatenate( |
| (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1) |
| assert full_body.shape[1] == 99 |
|
|
|
|
        if self.convert_to_6d:
            full_body = to3d(full_body)
            full_body = torch.from_numpy(full_body)
            # (T, 165) axis-angle -> (T, 55, 3) -> 6-D rotations -> (T, 330)
            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
            full_body = np.asarray(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)

        else:
            full_body = to3d(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)
|
|
        self.complete_data = np.array(full_body)
|
|
| if self.audio_feat_win_size is not None: |
| self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0) |
        else:
| self.audio_feat = get_mfcc_ta(self.audio_fn, |
| smlpx=True, |
| fps=30, |
| sr=self.audio_sr, |
| n_mfcc=self.audio_feat_dim, |
| win_size=self.audio_feat_win_size, |
| type=self.feat_method, |
| am=am, |
| am_sr=am_sr, |
| encoder_choice=self.config.Model.encoder_choice, |
| ) |
| |
| |
|
|
| def get_dataset(self, normalization=False, normalize_stats=None, split='train'): |
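        # Build an inner torch Dataset over window start indices. `child` stands in
        # for the inner self so the enclosing SmplxDataset stays reachable as `self`.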
|
|
| class __Worker__(data.Dataset): |
| def __init__(child, index_list, normalization, normalize_stats, split='train') -> None: |
| super().__init__() |
| child.index_list = index_list |
| child.normalization = normalization |
| child.normalize_stats = normalize_stats |
| child.split = split |
|
|
| def __getitem__(child, index): |
| num_generate_length = self.num_generate_length |
| num_pre_frames = self.num_pre_frames |
| seq_len = num_generate_length + num_pre_frames |
| |
|
|
                index = child.index_list[index]
                # Jitter the window start by 0 or 3 frames; fall back to the
                # original index if the jittered window would run past the end.
                index_new = index + random.randrange(0, 5, 3)
                if index_new + seq_len > self.complete_data.shape[0]:
                    index_new = index
                index = index_new
|
|
                if child.split in ['val', 'pre', 'test'] or self.whole_video:
                    index = 0
                    seq_len = self.complete_data.shape[0]
                assert index + seq_len <= self.complete_data.shape[0]

                seq_data = np.array(self.complete_data[index:(index + seq_len), :])
|
|
| ''' |
| audio feature, |
| ''' |
| if not self.context_info: |
| if not self.whole_video: |
| audio_feat = self.audio_feat[index:index + seq_len, ...] |
| if audio_feat.shape[0] < seq_len: |
| audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]], |
| mode='reflect') |
|
|
| assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim |
| else: |
| audio_feat = self.audio_feat |
|
|
                else:
                    if self.audio_feat_win_size is None:
                        # context window: current sequence plus preceding frames
                        audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
                        if audio_feat.shape[0] < seq_len + num_pre_frames:
                            audio_feat = np.pad(audio_feat,
                                                [[0, seq_len + num_pre_frames - audio_feat.shape[0]], [0, 0]],
                                                mode='constant')

                        assert audio_feat.shape[0] == seq_len + num_pre_frames and audio_feat.shape[
                            1] == self.audio_feat_dim
|
|
                if child.normalization:
                    # normalize the pose dims (330 under the 6-D representation)
                    data_mean = child.normalize_stats['mean'].reshape(1, -1)
                    data_std = child.normalize_stats['std'].reshape(1, -1)
                    seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
                if child.split in ['train', 'test']:
                    if self.convert_to_6d:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas,
                                'aud_file': self.audio_fn,
                            }
                        else:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'nzero': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    else:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :165].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 165:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'aud_file': self.audio_fn,
                                'betas': self.betas
                            }
                        else:
                            data_sample = {
                                'poses': seq_data.astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    return data_sample
                else:
                    # val/pre splits assume the 6-D + expression layout (330 pose dims)
                    data_sample = {
                        'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                        'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                        'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                        'aud_file': self.audio_fn,
                        'speaker': speaker_id[self.speaker],
                        'betas': self.betas
                    }
                    return data_sample
| def __len__(child): |
| return len(child.index_list) |
|
|
| if split == 'train': |
| index_list = list( |
| range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames, |
| 6)) |
        elif split in ['val', 'test']:
            index_list = [0]
        if self.whole_video:
            index_list = [0]
| self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split) |
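        # The worker is stored on the instance rather than returned; a typical
        # caller (sketch) would do:
        #   dataset.get_dataset(normalization=False, split='train')
        #   loader = data.DataLoader(dataset.all_dataset, batch_size=32, shuffle=True)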
|
|
| def __len__(self): |
| return len(self.img_name_list) |
|
|