| import logging |
| import pickle |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader, Dataset |
# Public API of this module: only the loader factory is exported via star-import.
__all__ = ['MMDataLoader']

# Shared 'MMSA' logger used for dataset-loading messages.
logger = logging.getLogger('MMSA')
|
|
class MMDataset(Dataset):
    """Multimodal (text / audio / vision) dataset loaded from pickled features.

    ``args`` is a config mapping. Keys read by this class:
    ``dataset_name``, ``featurePath``, ``feature_T``, ``feature_A``,
    ``feature_V``, ``feature_dims``, ``need_data_aligned`` and, optionally,
    ``use_bert``, ``need_normalized`` and ``seq_lens``.

    Side effect: when a per-modality override file (``feature_T`` /
    ``feature_A`` / ``feature_V``) is given, the matching entry of
    ``args['feature_dims']`` is overwritten in place.
    """

    def __init__(self, args, mode='train'):
        """Load the ``mode`` split of ``args['dataset_name']``.

        Args:
            args: config mapping (see class docstring).
            mode: split key inside the feature pickle, e.g. 'train',
                'valid' or 'test'.

        Raises:
            KeyError: if ``args['dataset_name']`` is missing or unsupported.
        """
        self.mode = mode
        self.args = args
        dataset_map = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
            # __init_sims existed but was never registered, so 'sims' was unreachable.
            'sims': self.__init_sims,
        }
        dataset_name = args['dataset_name']
        if dataset_name not in dataset_map:
            raise KeyError(
                f"Unsupported dataset_name {dataset_name!r}; "
                f"expected one of {sorted(dataset_map)}"
            )
        dataset_map[dataset_name]()

    def __flag(self, key):
        """Return ``self.args[key]`` if the key is present, else ``False``."""
        return key in self.args and self.args[key]

    @staticmethod
    def __load_pickle(path):
        """Unpickle and return the object stored at ``path``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only point the
        config at trusted feature files.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __init_mosi(self):
        """Populate features, labels and lengths for ``self.mode``.

        Base features come from ``args['featurePath']``. A non-empty
        ``feature_T`` / ``feature_A`` / ``feature_V`` path replaces that
        modality with the one stored in the given pickle and updates the
        corresponding ``args['feature_dims']`` entry.
        """
        split = self.__load_pickle(self.args['featurePath'])[self.mode]
        use_bert = self.__flag('use_bert')
        self.text = split['text_bert' if use_bert else 'text'].astype(np.float32)
        self.vision = split['vision'].astype(np.float32)
        self.audio = split['audio'].astype(np.float32)
        self.raw_text = split['raw_text']
        self.ids = split['id']

        # Optional per-modality overrides. The override splits are kept so the
        # unaligned-length lookup below reads from the same file.
        split_A = None
        split_V = None
        if self.args['feature_T'] != "":
            split_T = self.__load_pickle(self.args['feature_T'])[self.mode]
            if use_bert:
                self.text = split_T['text_bert'].astype(np.float32)
                # Hard-coded width; presumably the BERT-base hidden size — TODO confirm.
                self.args['feature_dims'][0] = 768
            else:
                self.text = split_T['text'].astype(np.float32)
                self.args['feature_dims'][0] = self.text.shape[2]
        if self.args['feature_A'] != "":
            split_A = self.__load_pickle(self.args['feature_A'])[self.mode]
            self.audio = split_A['audio'].astype(np.float32)
            self.args['feature_dims'][1] = self.audio.shape[2]
        if self.args['feature_V'] != "":
            split_V = self.__load_pickle(self.args['feature_V'])[self.mode]
            self.vision = split_V['vision'].astype(np.float32)
            self.args['feature_dims'][2] = self.vision.shape[2]

        self.labels = {
            'M': np.array(split['regression_labels']).astype(np.float32)
        }
        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        if not self.args['need_data_aligned']:
            if split_A is not None:
                self.audio_lengths = list(split_A['audio_lengths'])
            else:
                self.audio_lengths = split['audio_lengths']
            if split_V is not None:
                self.vision_lengths = list(split_V['vision_lengths'])
            else:
                self.vision_lengths = split['vision_lengths']

        # -inf entries in the audio features are replaced with 0.
        self.audio[self.audio == -np.inf] = 0

        if self.__flag('need_normalized'):
            self.__normalize()

    def __init_mosei(self):
        """MOSEI uses the same pickle layout and loader as MOSI."""
        return self.__init_mosi()

    def __init_sims(self):
        """SIMS is loaded with the MOSI loader (only the 'M' labels are read)."""
        return self.__init_mosi()

    def __truncate(self):
        """Cut text/audio/vision to the lengths in ``self.args['seq_lens']``.

        Not invoked by ``__init__``. For each sample, a window of ``length``
        steps along axis 1 is kept, starting at the first frame that is not
        all zeros. If a full window would not fit from that frame, the last
        ``length`` steps are kept instead (the whole sequence when ``length``
        exceeds it).

        NOTE(review): ``seq_lens`` is unpacked as (text, audio, video). For
        BERT-style text, axis 1 may not be the time axis — confirm before use.
        """
        def do_truncate(modal_features, length):
            seq_len = modal_features.shape[1]
            if length == seq_len:
                return modal_features
            # Latest start index that still yields a full `length`-step window.
            last_start = max(seq_len - length, 0)
            truncated_feature = []
            for instance in modal_features:
                # A frame with any non-zero value counts as real (non-padding) data.
                real_frames = np.flatnonzero(np.any(instance != 0, axis=-1))
                first_real = real_frames[0] if real_frames.size else seq_len
                start = min(first_real, last_start)
                # Previously a fixed 20 steps were sliced here, ignoring `length`.
                truncated_feature.append(instance[start:start + length])
            return np.array(truncated_feature)

        text_length, audio_length, video_length = self.args['seq_lens']
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        """Collapse vision/audio to their mean over axis 1 and zero out NaNs.

        The mean includes every step along axis 1, padding frames included.
        """
        self.vision = np.mean(self.vision, axis=1, keepdims=True)
        self.audio = np.mean(self.audio, axis=1, keepdims=True)
        self.vision[np.isnan(self.vision)] = 0
        self.audio[np.isnan(self.audio)] = 0

    def __len__(self):
        """Number of samples, taken from the 'M' label array."""
        return len(self.labels['M'])

    def get_seq_len(self):
        """Return ``(text_len, audio_len, vision_len)``.

        With ``use_bert`` the text length is read from axis 2 instead of
        axis 1 (presumably BERT text is stored as (N, 3, L) — TODO confirm).
        """
        text_axis = 2 if self.__flag('use_bert') else 1
        return (self.text.shape[text_axis], self.audio.shape[1], self.vision.shape[1])

    def get_feature_dim(self):
        """Return the size of axis 2 of the text, audio and vision arrays."""
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors plus metadata.

        Per-sample ``audio_lengths`` / ``vision_lengths`` are included only
        when ``args['need_data_aligned']`` is falsy.
        """
        sample = {
            'raw_text': self.raw_text[index],
            'text': torch.Tensor(self.text[index]),
            'audio': torch.Tensor(self.audio[index]),
            'vision': torch.Tensor(self.vision[index]),
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()},
        }
        if not self.args['need_data_aligned']:
            sample['audio_lengths'] = self.audio_lengths[index]
            sample['vision_lengths'] = self.vision_lengths[index]
        return sample
|
|
def MMDataLoader(args, num_workers):
    """Build one ``DataLoader`` per split: 'train', 'valid' and 'test'.

    Each split is an ``MMDataset`` constructed from the same ``args``. If
    ``args`` already contains ``'seq_lens'``, that entry is overwritten with
    the sequence lengths reported by the training split.

    Note: every loader, including 'valid' and 'test', is created with
    ``shuffle=True``.

    Args:
        args: config mapping; ``args['batch_size']`` sets the batch size.
        num_workers: worker count passed to each ``DataLoader``.

    Returns:
        dict mapping split name to its ``DataLoader``.
    """
    datasets = {}
    for split_name in ('train', 'valid', 'test'):
        datasets[split_name] = MMDataset(args, mode=split_name)

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    loaders = {}
    for split_name, split_dataset in datasets.items():
        loaders[split_name] = DataLoader(
            split_dataset,
            batch_size=args['batch_size'],
            num_workers=num_workers,
            shuffle=True,
        )
    return loaders
|
|