| import sys |
| import os |
| sys.path.append(os.getcwd()) |
| import os |
| from tqdm import tqdm |
| from data_utils.utils import * |
| import torch.utils.data as data |
| from data_utils.mesh_dataset import SmplxDataset |
| from transformers import Wav2Vec2Processor |
|
|
|
|
class MultiVidData():
    """Aggregate per-clip ``SmplxDataset`` instances for a set of speakers.

    Three loading strategies are selected by ``config.dataset_load_mode``:

    * ``'pickle'`` -- unpickle a previously saved ``{key: SmplxDataset}`` dict
      from ``<split><config.Data.pklname>``.
    * ``'csv'``    -- rebuild clips from a CSV interval table via ``csv_parse``
      (NOTE: the actual CSV read is missing in this version, see below).
    * ``'json'``   -- scan ``data_root/<speaker>/<video>/<split>/<clip>`` for
      paired ``<clip>.wav`` / ``<clip>.pkl`` files, then cache the resulting
      dict as a pickle for later ``'pickle'``-mode runs.
    """

    def __init__(self,
                 data_root,
                 speakers,
                 split='train',
                 limbscaling=False,
                 normalization=False,
                 norm_method='new',
                 split_trans_zero=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=None,
                 aud_feat_win_size=None,
                 aud_feat_dim=64,
                 feat_method='mel_spec',
                 context_info=False,
                 smplx=False,
                 audio_sr=16000,
                 convert_to_6d=False,
                 expression=False,
                 config=None
                 ):
        self.data_root = data_root
        self.speakers = speakers
        self.split = split
        # 'pre' is treated as an alias of the training split.
        if split == 'pre':
            self.split = 'train'
        self.norm_method = norm_method
        self.normalization = normalization
        self.limbscaling = limbscaling
        self.convert_to_6d = convert_to_6d
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        # Default the generation length to the conditioning window length.
        if num_generate_length is None:
            self.num_generate_length = num_frames
        else:
            self.num_generate_length = num_generate_length
        self.split_trans_zero = split_trans_zero

        dataset = SmplxDataset

        # Always create all_dataset_list: get_dataset() appends to it
        # unconditionally, so creating it only in the non-split case (as the
        # original code did) raised AttributeError when split_trans_zero=True.
        self.all_dataset_list = []
        if self.split_trans_zero:
            self.trans_dataset_list = []
            self.zero_dataset_list = []
        self.dataset = {}
        self.complete_data = []
        self.config = config
        load_mode = self.config.dataset_load_mode

        if load_mode == 'pickle':
            import pickle

            # Read-only load; the original opened with 'rb+' (read/write)
            # although it never wrote. 'with' guarantees the handle closes
            # even if unpickling fails.
            with open(self.split + config.Data.pklname, 'rb') as f:
                self.dataset = pickle.load(f)
            for key in self.dataset:
                self.complete_data.append(self.dataset[key].complete_data)

        elif load_mode == 'csv':
            # The experiment config and the CSV parser live under
            # config_root_path; make them importable before use.
            try:
                sys.path.append(self.config.config_root_path)
                from config import config_path
                from csv_parser import csv_parse
            except ImportError as e:
                print(f'err: {e}')
                raise ImportError('config root path error...')

            for speaker_name in self.speakers:
                # NOTE(review): df_intervals is never actually loaded -- the
                # CSV read appears to have been removed, so the next line
                # raises TypeError at runtime. TODO: restore the interval
                # table load (presumably reading config_path into a DataFrame)
                # before using this mode.
                df_intervals = None
                df_intervals = df_intervals[df_intervals['speaker'] == speaker_name]
                df_intervals = df_intervals[df_intervals['dataset'] == self.split]

                print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
                for iter_index, (_, interval) in tqdm(
                    enumerate(df_intervals.iterrows()), desc=f'load {speaker_name}'
                ):
                    # csv_parse flattens one interval row into the full set of
                    # clip metadata and derived file-system paths.
                    (
                        interval_index,
                        interval_speaker,
                        interval_video_fn,
                        interval_id,

                        start_time,
                        end_time,
                        duration_time,
                        start_time_10,
                        over_flow_flag,
                        short_dur_flag,

                        big_video_dir,
                        small_video_dir_name,
                        speaker_video_path,

                        voca_basename,
                        json_basename,
                        wav_basename,
                        voca_top_clip_path,
                        voca_json_clip_path,
                        voca_wav_clip_path,

                        audio_output_fn,
                        image_output_path,
                        pifpaf_output_path,
                        mp_output_path,
                        op_output_path,
                        deca_output_path,
                        pixie_output_path,
                        cam_output_path,
                        ours_output_path,
                        merge_output_path,
                        multi_output_path,
                        gt_output_path,
                        ours_images_path,
                        pkl_fil_path,
                    ) = csv_parse(interval)

                    # Skip clips whose motion pickle or audio file is missing.
                    if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
                        continue

                    key = f'{interval_video_fn}/{small_video_dir_name}'
                    self.dataset[key] = dataset(
                        data_root=pkl_fil_path,
                        speaker=speaker_name,
                        audio_fn=audio_output_fn,
                        audio_sr=audio_sr,
                        fps=num_frames,
                        feat_method=feat_method,
                        audio_feat_dim=aud_feat_dim,
                        train=(self.split == 'train'),
                        load_all=True,
                        split_trans_zero=self.split_trans_zero,
                        limbscaling=self.limbscaling,
                        num_frames=self.num_frames,
                        num_pre_frames=self.num_pre_frames,
                        num_generate_length=self.num_generate_length,
                        audio_feat_win_size=aud_feat_win_size,
                        context_info=context_info,
                        convert_to_6d=convert_to_6d,
                        expression=expression,
                        config=self.config
                    )
                    self.complete_data.append(self.dataset[key].complete_data)

        elif load_mode == 'json':
            # Pre-trained phoneme-level acoustic model handed to every clip
            # dataset as the audio feature extractor.
            am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
            am_sr = 16000

            for speaker_name in self.speakers:
                speaker_root = os.path.join(self.data_root, speaker_name)

                videos = [v for v in os.listdir(speaker_root)]
                print(videos)

                # haode/huaide count usable vs. broken clips (missing wav/pkl).
                haode = huaide = 0

                for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
                    source_vid = vid
                    vid_pth = os.path.join(speaker_root, source_vid, self.split)
                    # NOTE(review): smplx is documented as a bool default but is
                    # compared against the string 'pose' here -- confirm the
                    # intended caller contract.
                    if smplx == 'pose':
                        seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
                    else:
                        try:
                            seqs = [s for s in os.listdir(vid_pth)]
                        except OSError:
                            # Split directory absent for this video -- skip it.
                            # (Original used a bare except, which also swallowed
                            # KeyboardInterrupt.)
                            continue

                    for s in seqs:
                        seq_root = os.path.join(vid_pth, s)
                        key = seq_root
                        # Each clip directory must contain <clip>.wav and <clip>.pkl.
                        audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
                        motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
                        if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
                            huaide = huaide + 1
                            continue

                        self.dataset[key] = dataset(
                            data_root=seq_root,
                            speaker=speaker_name,
                            motion_fn=motion_fname,
                            audio_fn=audio_fname,
                            audio_sr=audio_sr,
                            fps=num_frames,
                            feat_method=feat_method,
                            audio_feat_dim=aud_feat_dim,
                            train=(self.split == 'train'),
                            load_all=True,
                            split_trans_zero=self.split_trans_zero,
                            limbscaling=self.limbscaling,
                            num_frames=self.num_frames,
                            num_pre_frames=self.num_pre_frames,
                            num_generate_length=self.num_generate_length,
                            audio_feat_win_size=aud_feat_win_size,
                            context_info=context_info,
                            convert_to_6d=convert_to_6d,
                            expression=expression,
                            config=self.config,
                            am=am,
                            am_sr=am_sr,
                            whole_video=config.Data.whole_video
                        )
                        self.complete_data.append(self.dataset[key].complete_data)
                        haode = haode + 1
                print("huaide:{}, haode:{}".format(huaide, haode))
            import pickle

            # Cache the freshly built dataset so subsequent runs can use
            # load_mode='pickle'. 'with' closes the handle even on dump failure.
            with open(self.split + config.Data.pklname, 'wb') as f:
                pickle.dump(self.dataset, f)

        # Stack every clip's frames into one array (used for global statistics).
        self.complete_data = np.concatenate(self.complete_data, axis=0)

        self.normalize_stats = {}
        # Normalisation statistics; expected to be assigned by the caller
        # before get_dataset() when normalization is enabled.
        self.data_mean = None
        self.data_std = None

    def get_dataset(self):
        """Materialise each clip's torch dataset and concatenate them.

        Clips with fewer frames than ``num_generate_length`` are skipped.
        Populates ``self.all_dataset`` (or ``self.trans_dataset`` /
        ``self.zero_dataset`` when ``split_trans_zero`` is set) and stamps
        ``self.normalize_stats`` with the current mean/std.
        """
        self.normalize_stats['mean'] = self.data_mean
        self.normalize_stats['std'] = self.data_std

        for key in list(self.dataset.keys()):
            # Too short to cut even one generation window from -- skip.
            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
                continue
            self.dataset[key].num_generate_length = self.num_generate_length
            self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
            self.all_dataset_list.append(self.dataset[key].all_dataset)

        if self.split_trans_zero:
            # NOTE(review): nothing in this class ever fills the trans/zero
            # lists, so ConcatDataset would receive an empty list here --
            # confirm whether split_trans_zero is still a supported mode.
            self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
            self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
        else:
            self.all_dataset = data.ConcatDataset(self.all_dataset_list)
|
|
|
|
|
|
|
|