| ''' |
| dutils.py |
| A utility library for customized data loading functions |
| ''' |
| import os |
| import gzip |
| import numpy as np |
| import pandas as pd |
|
|
| import os |
| import cv2 |
| from typing import List, Union, Dict, Sequence |
| import numpy as np |
| import numpy.random as nprand |
| import datetime |
| import pandas as pd |
| import h5py |
| import torch |
| import torch.nn.functional as F |
| from torch.nn.functional import avg_pool2d |
| import random |
| from torchvision import transforms as T |
| from torchvision import datasets |
| from torch.utils.data import Dataset, DataLoader |
| from PIL import Image |
|
|
# Root directories for the raw datasets; relative to the working directory.
# Override per deployment if data lives elsewhere.
SEVIR_ROOT_DIR = "data/SEVIR"
METEO_FILE_DIR = "data/meteonet"
|
|
def resize(seq, size):
    """Bilinearly resize a 5D sequence tensor to `size`, clamping to [0, 1].

    The input is assumed shaped (N, T, 1, H, W) — TODO confirm with callers;
    the singleton dim 2 is squeezed for interpolation and restored afterwards.
    """
    flat = seq.squeeze(dim=2)
    flat = F.interpolate(flat, size=size, mode='bilinear', align_corners=False)
    flat = flat.clamp(0, 1)
    return flat.unsqueeze(2)
|
|
| |
| |
| |
def pixel_to_dBZ_nonlinear(img):
    """Decode the nonlinear (arctan) pixel encoding into dBZ.

    Accepts either [0, 255] or [0, 1] input; an image whose mean exceeds 1
    is treated as byte-scaled and normalized first (heuristic).
    """
    ashift, afact = 31.0, 4.0
    atan_dBZ_min, atan_dBZ_max = -1.482, 1.412
    # Heuristic: mean > 1 implies a [0, 255]-scaled image.
    if img.mean() > 1.0:
        img = img / 255.0
    span = atan_dBZ_max - atan_dBZ_min
    return np.tan(img * span + atan_dBZ_min) * afact + ashift
|
|
| def dbZ_to_pixel_nonlinear(dbZ): |
| ''' |
| [0, 80] dBZ => [0, 255] OR [0, 1] pixel |
| ''' |
| ashift = 31.0 |
| afact = 4.0 |
| atan_dBZ_min = -1.482 |
| atan_dBZ_max = 1.412 |
| dbZ_adjusted = (dbZ - ashift) / afact |
| return (np.arctan(dbZ_adjusted) - atan_dBZ_min) / (atan_dBZ_max - atan_dBZ_min) |
|
|
def dbZ_to_pixel(dbZ):
    '''
    dBZ => [0, 1] pixel (linear encoding, quantized to 1/255 steps).

    Note: the previous docstring claimed a [0, 80] dBZ domain, but the
    linear map sends -10 dBZ -> 0 and 60 dBZ -> 1; values are rounded
    to the nearest 8-bit level before normalization.
    '''
    return np.floor((dbZ + 10) * 255 / 70 + 0.5) / 255.0
|
|
def pixel_to_dBZ(pixel):
    '''
    [0, 255] (or [0, 1]) pixel => [-10, 60] dBZ (linear decoding).

    Note: the previous docstring claimed a [0, 80] dBZ range, but
    70 * pixel - 10 maps [0, 1] onto [-10, 60].
    '''
    # Heuristic: mean > 1 implies a [0, 255]-scaled image; normalize first.
    if pixel.mean() > 1.0:
        pixel = pixel / 255.0
    return (70 * pixel) - 10
|
|
def nonlinear_to_linear(im):
    """Re-encode an arctan-encoded image into the linear pixel encoding."""
    dbz = pixel_to_dBZ_nonlinear(im)
    return dbZ_to_pixel(dbz)
|
|
def nonlinear_to_linear_batched(seq, datetime):
    """Convert batch items dated 2016 or later from the nonlinear to the
    linear pixel encoding; earlier items pass through unchanged.

    NOTE(review): the `datetime` parameter shadows the stdlib module of the
    same name inside this function; kept for backward compatibility.
    The result is clipped to [0, 1].
    """
    converted = np.zeros_like(seq)
    for i, (frames, stamps) in enumerate(zip(seq, datetime)):
        # Only post-2016 SEVIR data uses the nonlinear encoding.
        if stamps[0].year >= 2016:
            converted[i] = nonlinear_to_linear(frames)
        else:
            converted[i] = frames
    return np.clip(converted, 0.0, 1.0)
|
|
def linear_to_nonlinear(im):
    """Re-encode a linearly-encoded image into the arctan pixel encoding."""
    dbz = pixel_to_dBZ(im)
    return dbZ_to_pixel_nonlinear(dbz)
|
|
def linear_to_nonlinear_batched(seq, datetime):
    """Convert batch items dated before 2016 from the linear to the
    nonlinear pixel encoding; later items pass through unchanged.

    NOTE(review): the `datetime` parameter shadows the stdlib module of the
    same name inside this function; kept for backward compatibility.
    The result is clipped to [0, 1].
    """
    converted = np.zeros_like(seq)
    for i, (frames, stamps) in enumerate(zip(seq, datetime)):
        # Only pre-2016 SEVIR data uses the linear encoding.
        if stamps[0].year < 2016:
            converted[i] = linear_to_nonlinear(frames)
        else:
            converted[i] = frames
    return np.clip(converted, 0.0, 1.0)
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
| |
# Image types available in SEVIR events.
SEVIR_DATA_TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght']
# On-disk dtypes of the raw HDF5 arrays, per image type.
SEVIR_RAW_DTYPES = {'vis': np.int16,
                    'ir069': np.int16,
                    'ir107': np.int16,
                    'vil': np.uint8,
                    'lght': np.int16}
# Lightning frame timestamps in seconds: every 5 minutes from -120 to +120
# minutes around the event. NOTE: name keeps the historical "LIGHTING"
# spelling for backward compatibility.
LIGHTING_FRAME_TIMES = np.arange(- 120.0, 125.0, 5) * 60
# Rasterized grid size (H, W) for lightning data.
SEVIR_DATA_SHAPE = {'lght': (48, 48), }
# Per-type affine normalization used by rescale='sevir':
# normalized = scale * (raw + offset).
PREPROCESS_SCALE_SEVIR = {'vis': 1,
                          'ir069': 1 / 1174.68,
                          'ir107': 1 / 2562.43,
                          'vil': 1 / 47.54,
                          'lght': 1 / 0.60517}
PREPROCESS_OFFSET_SEVIR = {'vis': 0,
                           'ir069': 3683.58,
                           'ir107': 1552.80,
                           'vil': - 33.44,
                           'lght': - 0.02990}
# rescale='01': only 'vil' is actually rescaled (bytes -> [0, 1]).
PREPROCESS_SCALE_01 = {'vis': 1,
                       'ir069': 1,
                       'ir107': 1,
                       'vil': 1 / 255,
                       'lght': 1}
PREPROCESS_OFFSET_01 = {'vis': 0,
                        'ir069': 0,
                        'ir107': 0,
                        'vil': 0,
                        'lght': 0}
|
|
| |
# Default catalog CSV and data directory, derived from the root dir.
SEVIR_CATALOG = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv")
SEVIR_DATA_DIR = os.path.join(SEVIR_ROOT_DIR, "data")
# Number of frames stored per raw SEVIR event.
SEVIR_RAW_SEQ_LEN = 49

# Date boundaries used to split events into train/val/test partitions.
SEVIR_TRAIN_VAL_SPLIT_DATE = datetime.datetime(2019, 1, 1)
SEVIR_TRAIN_TEST_SPLIT_DATE = datetime.datetime(2019, 6, 1)
|
|
def change_layout_np(data,
                     in_layout='NHWT', out_layout='NHWT',
                     ret_contiguous=False):
    """Convert a numpy array between axis layouts via the canonical 'NHWT'.

    Layout letters: batch 'N', time 'T', channel 'C', height 'H', width 'W'.
    5D inputs with a 'C' axis are assumed to carry a singleton channel,
    which is dropped on input and re-added for output layouts with 'C'.

    Parameters
    ----------
    data : np.ndarray
        Array whose axes follow `in_layout`.
    in_layout, out_layout : str
        One of 'NHWT', 'NTHW', 'NWHT', 'NTCHW', 'NTHWC', 'NTWHC', 'TNHW',
        'TNCHW' (input); output supports the same set.
    ret_contiguous : bool
        If True, return a C-contiguous copy.

    Returns
    -------
    np.ndarray

    Raises
    ------
    NotImplementedError
        If either layout is not supported.
    """
    # Step 1: normalize the input to 'NHWT'.
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NWHT':
        data = np.transpose(data,
                            axes=(0, 2, 1, 3))
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # drop singleton channel
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0]  # drop singleton channel
        data = np.transpose(data,
                            axes=(0, 2, 3, 1))
    elif in_layout == 'NTWHC':
        data = data[:, :, :, :, 0]  # drop singleton channel
        data = np.transpose(data,
                            axes=(0, 3, 2, 1))
    elif in_layout == 'TNHW':
        data = np.transpose(data,
                            axes=(1, 2, 3, 0))
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :]  # drop singleton channel
        data = np.transpose(data,
                            axes=(1, 2, 3, 0))
    else:
        raise NotImplementedError

    # Step 2: convert 'NHWT' to the requested output layout.
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
    elif out_layout == 'NWHT':
        data = np.transpose(data,
                            axes=(0, 2, 1, 3))
    elif out_layout == 'NTCHW':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=2)  # restore singleton channel
    elif out_layout == 'NTHWC':
        data = np.transpose(data,
                            axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=-1)  # restore singleton channel
    elif out_layout == 'NTWHC':
        data = np.transpose(data,
                            axes=(0, 3, 2, 1))
        data = np.expand_dims(data, axis=-1)  # restore singleton channel
    elif out_layout == 'TNHW':
        data = np.transpose(data,
                            axes=(3, 0, 1, 2))
    elif out_layout == 'TNCHW':
        data = np.transpose(data,
                            axes=(3, 0, 1, 2))
        data = np.expand_dims(data, axis=2)  # restore singleton channel
    else:
        raise NotImplementedError
    if ret_contiguous:
        # BUG FIX: np.ndarray has no `ascontiguousarray` method; the original
        # `data.ascontiguousarray()` raised AttributeError whenever
        # ret_contiguous=True. Use the module-level function instead.
        data = np.ascontiguousarray(data)
    return data
|
|
def change_layout_torch(data,
                        in_layout='NHWT', out_layout='NHWT',
                        ret_contiguous=False):
    """Convert a torch tensor between axis layouts via the canonical 'NHWT'.

    Layout letters: batch 'N', time 'T', channel 'C', height 'H', width 'W'.
    5D inputs with a 'C' axis are assumed to carry a singleton channel,
    which is dropped on input and re-added for output layouts with 'C'.

    Consistency fix: 'NWHT' and 'NTWHC' are now supported in both
    directions, mirroring `change_layout_np` (backward compatible).

    Parameters
    ----------
    data : torch.Tensor
    in_layout, out_layout : str
    ret_contiguous : bool
        If True, return a contiguous tensor.

    Returns
    -------
    torch.Tensor

    Raises
    ------
    NotImplementedError
        If either layout is not supported.
    """
    # Step 1: normalize the input to 'NHWT'.
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NWHT':
        data = data.permute(0, 2, 1, 3)
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # drop singleton channel
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0]  # drop singleton channel
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NTWHC':
        data = data[:, :, :, :, 0]  # drop singleton channel
        data = data.permute(0, 3, 2, 1)
    elif in_layout == 'TNHW':
        data = data.permute(1, 2, 3, 0)
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :]  # drop singleton channel
        data = data.permute(1, 2, 3, 0)
    else:
        raise NotImplementedError

    # Step 2: convert 'NHWT' to the requested output layout.
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = data.permute(0, 3, 1, 2)
    elif out_layout == 'NWHT':
        data = data.permute(0, 2, 1, 3)
    elif out_layout == 'NTCHW':
        data = data.permute(0, 3, 1, 2)
        data = torch.unsqueeze(data, dim=2)  # restore singleton channel
    elif out_layout == 'NTHWC':
        data = data.permute(0, 3, 1, 2)
        data = torch.unsqueeze(data, dim=-1)  # restore singleton channel
    elif out_layout == 'NTWHC':
        data = data.permute(0, 3, 2, 1)
        data = torch.unsqueeze(data, dim=-1)  # restore singleton channel
    elif out_layout == 'TNHW':
        data = data.permute(3, 0, 1, 2)
    elif out_layout == 'TNCHW':
        data = data.permute(3, 0, 1, 2)
        data = torch.unsqueeze(data, dim=2)  # restore singleton channel
    else:
        raise NotImplementedError
    if ret_contiguous:
        data = data.contiguous()
    return data
|
|
| class SEVIRDataLoader: |
| r""" |
| DataLoader that loads SEVIR sequences, and spilts each event |
| into segments according to specified sequence length. |
| |
| Event Frames: |
| [-----------------------raw_seq_len----------------------] |
| [-----seq_len-----] |
| <--stride-->[-----seq_len-----] |
| <--stride-->[-----seq_len-----] |
| ... |
| """ |
    def __init__(self,
                 data_types: Sequence[str] = None,
                 seq_len: int = 49,
                 raw_seq_len: int = 49,
                 sample_mode: str = 'sequent',
                 stride: int = 12,
                 batch_size: int = 1,
                 layout: str = 'NHWT',
                 num_shard: int = 1,
                 rank: int = 0,
                 split_mode: str = "uneven",
                 sevir_catalog: Union[str, pd.DataFrame] = None,
                 sevir_data_dir: str = None,
                 start_date: datetime.datetime = None,
                 end_date: datetime.datetime = None,
                 datetime_filter=None,
                 catalog_filter='default',
                 shuffle: bool = False,
                 shuffle_seed: int = 1,
                 output_type=np.float32,
                 preprocess: bool = True,
                 rescale_method: str = '01',
                 downsample_dict: Dict[str, Sequence[int]] = None,
                 verbose: bool = False):
        r"""
        Parameters
        ----------
        data_types
            A subset of SEVIR_DATA_TYPES.
        seq_len
            The length of the data sequences. Should be smaller than the max length raw_seq_len.
        raw_seq_len
            The length of the raw data sequences.
        sample_mode
            'random' or 'sequent'
        stride
            Useful when sample_mode == 'sequent'
            stride must not be smaller than out_len to prevent data leakage in testing.
        batch_size
            Number of sequences in one batch.
        layout
            str: consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'
            The layout of sampled data. Raw data layout is 'NHWT'.
            valid layout: 'NHWT', 'NTHW', 'NTCHW', 'TNHW', 'TNCHW'.
        num_shard
            Split the whole dataset into num_shard parts for distributed training.
        rank
            Rank of the current process within num_shard.
        split_mode: str
            if 'ceil', all `num_shard` dataloaders have the same length = ceil(total_len / num_shard).
            Different dataloaders may have some duplicated data batches, if the total size of datasets is not divided by num_shard.
            if 'floor', all `num_shard` dataloaders have the same length = floor(total_len / num_shard).
            The last several data batches may be wasted, if the total size of datasets is not divided by num_shard.
            if 'uneven', the last datasets has larger length when the total length is not divided by num_shard.
            The uneven split leads to synchronization error in dist.all_reduce() or dist.barrier().
            See related issue: https://github.com/pytorch/pytorch/issues/33148
            Notice: this also affects the behavior of `self.use_up`.
        sevir_catalog
            Name of SEVIR catalog CSV file.
        sevir_data_dir
            Directory path to SEVIR data.
        start_date
            Start time of SEVIR samples to generate.
        end_date
            End time of SEVIR samples to generate.
        datetime_filter
            function
            Mask function applied to time_utc column of catalog (return true to keep the row).
            Pass function of the form lambda t : COND(t)
            Example: lambda t: np.logical_and(t.dt.hour>=13,t.dt.hour<=21)  # Generate only day-time events
        catalog_filter
            function or None or 'default'
            Mask function applied to entire catalog dataframe (return true to keep row).
            Pass function of the form lambda catalog: COND(catalog)
            Example: lambda c: [s[0]=='S' for s in c.id]  # Generate only the 'S' events
        shuffle
            bool, If True, data samples are shuffled before each epoch.
        shuffle_seed
            int, Seed to use for shuffling.
        output_type
            np.dtype, dtype of generated tensors
        preprocess
            bool, If True, self.preprocess_data_dict(data_dict) is called before each sample generated
        downsample_dict:
            dict, downsample_dict.keys() == data_types. downsample_dict[key] is a Sequence of (t_factor, h_factor, w_factor),
            representing the downsampling factors of all dimensions.
        verbose
            bool, verbose when opening raw data files
        """
        super(SEVIRDataLoader, self).__init__()
        # Fall back to module-level defaults when paths are not given.
        if sevir_catalog is None:
            sevir_catalog = SEVIR_CATALOG
        if sevir_data_dir is None:
            sevir_data_dir = SEVIR_DATA_DIR
        if data_types is None:
            data_types = SEVIR_DATA_TYPES
        else:
            assert set(data_types).issubset(SEVIR_DATA_TYPES)

        # Static per-type metadata (raw dtypes, lightning frame times/grid).
        self._dtypes = SEVIR_RAW_DTYPES
        self.lght_frame_times = LIGHTING_FRAME_TIMES
        self.data_shape = SEVIR_DATA_SHAPE

        self.raw_seq_len = raw_seq_len
        assert seq_len <= self.raw_seq_len, f'seq_len must not be larger than raw_seq_len = {raw_seq_len}, got {seq_len}.'
        self.seq_len = seq_len
        assert sample_mode in ['random', 'sequent'], f'Invalid sample_mode = {sample_mode}, must be \'random\' or \'sequent\'.'
        self.sample_mode = sample_mode
        self.stride = stride
        self.batch_size = batch_size
        valid_layout = ('NHWT', 'NTHW', 'NTCHW', 'NTHWC', 'TNHW', 'TNCHW')
        if layout not in valid_layout:
            raise ValueError(f'Invalid layout = {layout}! Must be one of {valid_layout}.')
        self.layout = layout
        self.num_shard = num_shard
        self.rank = rank
        valid_split_mode = ('ceil', 'floor', 'uneven')
        if split_mode not in valid_split_mode:
            raise ValueError(f'Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}.')
        self.split_mode = split_mode
        self._samples = None        # event table, built by _compute_samples()
        self._hdf_files = {}        # filename -> open h5py.File handle
        self.data_types = data_types
        # The catalog may be passed either as a CSV path or a ready DataFrame.
        if isinstance(sevir_catalog, str):
            self.catalog = pd.read_csv(sevir_catalog, parse_dates=['time_utc'], low_memory=False)
        else:
            self.catalog = sevir_catalog
        self.sevir_data_dir = sevir_data_dir
        self.datetime_filter = datetime_filter
        self.catalog_filter = catalog_filter
        self.start_date = start_date
        self.end_date = end_date
        self.shuffle = shuffle
        self.shuffle_seed = int(shuffle_seed)
        self.output_type = output_type
        self.preprocess = preprocess
        self.downsample_dict = downsample_dict
        self.rescale_method = rescale_method
        self.verbose = verbose

        # Restrict the catalog to the requested time window / datetime mask.
        if self.start_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc > self.start_date]
        if self.end_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc <= self.end_date]
        if self.datetime_filter:
            self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)]

        # 'default' keeps only events with no missing frames.
        if self.catalog_filter is not None:
            if self.catalog_filter == 'default':
                self.catalog_filter = lambda c: c.pct_missing == 0
            self.catalog = self.catalog[self.catalog_filter(self.catalog)]

        # Build the event table, open HDF5 handles, and reset iteration state.
        self._compute_samples()
        self._open_files(verbose=self.verbose)
        self.reset()
|
|
    def _compute_samples(self):
        """
        Computes the list of samples in catalog to be used. This sets self._samples
        """
        imgt = self.data_types
        imgts = set(imgt)
        # Keep only catalog rows whose img_type is among the requested types.
        filtcat = self.catalog[ np.logical_or.reduce([self.catalog.img_type==i for i in imgt]) ]
        # Keep only events for which every requested type is available.
        filtcat = filtcat.groupby('id').filter(lambda x: imgts.issubset(set(x['img_type'])))
        # Require exactly one row per requested type (drop duplicates).
        filtcat = filtcat.groupby('id').filter(lambda x: x.shape[0]==len(imgt))
        # Collapse each event into a single row with per-type filename/index columns.
        self._samples = filtcat.groupby('id').apply(lambda df: self._df_to_series(df,imgt) )
        if self.shuffle:
            self.shuffle_samples()
|
|
| def shuffle_samples(self): |
| self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) |
|
|
| def _df_to_series(self, df, imgt): |
| d = {} |
| df = df.set_index('img_type') |
| for i in imgt: |
| s = df.loc[i] |
| idx = s.file_index if i != 'lght' else s.id |
| d.update({f'{i}_filename': [s.file_name], |
| f'{i}_index': [idx]}) |
|
|
| return pd.DataFrame(d) |
|
|
    def _open_files(self, verbose=True):
        """
        Opens HDF files

        Collects the unique filenames referenced by self._samples for every
        requested data type and opens each once in read-only mode, keyed by
        filename in self._hdf_files. Handles stay open until self.close().
        """
        imgt = self.data_types
        hdf_filenames = []
        for t in imgt:
            hdf_filenames += list(np.unique( self._samples[f'{t}_filename'].values ))
        self._hdf_files = {}
        for f in hdf_filenames:
            if verbose:
                print('Opening HDF5 file for reading', f)
            self._hdf_files[f] = h5py.File(self.sevir_data_dir + '/' + f, 'r')
|
|
| def close(self): |
| """ |
| Closes all open file handles |
| """ |
| for f in self._hdf_files: |
| self._hdf_files[f].close() |
| self._hdf_files = {} |
|
|
| @property |
| def num_seq_per_event(self): |
| return 1 + (self.raw_seq_len - self.seq_len) // self.stride |
|
|
| @property |
| def total_num_seq(self): |
| """ |
| The total number of sequences within each shard. |
| Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. |
| """ |
| return int(self.num_seq_per_event * self.num_event) |
|
|
| @property |
| def total_num_event(self): |
| """ |
| The total number of events in the whole dataset, before split into different shards. |
| """ |
| return int(self._samples.shape[0]) |
|
|
| @property |
| def start_event_idx(self): |
| """ |
| The event idx used in certain rank should satisfy event_idx >= start_event_idx |
| """ |
| return self.total_num_event // self.num_shard * self.rank |
|
|
| @property |
| def end_event_idx(self): |
| """ |
| The event idx used in certain rank should satisfy event_idx < end_event_idx |
| |
| """ |
| if self.split_mode == 'ceil': |
| _last_start_event_idx = self.total_num_event // self.num_shard * (self.num_shard - 1) |
| _num_event = self.total_num_event - _last_start_event_idx |
| return self.start_event_idx + _num_event |
| elif self.split_mode == 'floor': |
| return self.total_num_event // self.num_shard * (self.rank + 1) |
| else: |
| if self.rank == self.num_shard - 1: |
| return self.total_num_event |
| else: |
| return self.total_num_event // self.num_shard * (self.rank + 1) |
|
|
| @property |
| def num_event(self): |
| """ |
| The number of events split into each rank |
| """ |
| return self.end_event_idx - self.start_event_idx |
|
|
    def _read_data(self, row, data):
        """
        Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_len).

        Parameters
        ----------
        row
            A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index.
        data
            Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_len).

        Returns
        -------
        data
            Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_len).
        """
        # Recover the image types from the row's column names,
        # e.g. 'vil_filename' -> 'vil'.
        imgtyps = np.unique([x.split('_')[0] for x in list(row.keys())])
        for t in imgtyps:
            fname = row[f'{t}_filename']
            idx = row[f'{t}_index']
            t_slice = slice(0, None)  # read the full time axis
            if t == 'lght':
                # Lightning is stored as an N x 5 event list keyed by event id;
                # rasterize it into a (1, H, W, T) grid of strike counts.
                lght_data = self._hdf_files[fname][idx][:]
                data_i = self._lght_to_grid(lght_data, t_slice)
            else:
                data_i = self._hdf_files[fname][t][idx:idx + 1, :, :, t_slice]
            # Append along the batch axis (or start the entry).
            data[t] = np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i
        return data
|
|
    def _lght_to_grid(self, data, t_slice=slice(0, None)):
        """
        Converts Nx5 lightning data matrix into a 2D grid of pixel counts

        Columns 3 and 4 hold pixel x/y coordinates and column 0 the strike
        time (in seconds relative to the event, matching
        self.lght_frame_times) — presumably per the SEVIR lightning format;
        confirm against the dataset docs. Returns shape (1, H, W, T) with
        T = all frames, or (1, H, W, 1) when t_slice selects a single frame.
        """
        out_size = (*self.data_shape['lght'], len(self.lght_frame_times)) if t_slice.stop is None else (*self.data_shape['lght'], 1)
        if data.shape[0] == 0:
            # No strikes at all: empty grid.
            return np.zeros((1,) + out_size, dtype=np.float32)

        # Drop strikes that fall outside the spatial grid.
        x, y = data[:, 3], data[:, 4]
        m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]])
        data = data[m, :]
        if data.shape[0] == 0:
            return np.zeros((1,) + out_size, dtype=np.float32)

        t = data[:, 0]
        if t_slice.stop is not None:
            # Single-frame request: keep only strikes inside that frame's window.
            if t_slice.stop > 0:
                if t_slice.stop < len(self.lght_frame_times):
                    tm = np.logical_and(t >= self.lght_frame_times[t_slice.stop - 1],
                                        t < self.lght_frame_times[t_slice.stop])
                else:
                    # Last frame: everything at or after the final timestamp.
                    tm = t >= self.lght_frame_times[-1]
            else:
                tm = np.logical_and(t >= self.lght_frame_times[0], t < self.lght_frame_times[1])
            data = data[tm, :]
            z = np.zeros(data.shape[0], dtype=np.int64)
        else:
            # Bin each strike into its frame index; clamp pre-window strikes to frame 0.
            z = np.digitize(t, self.lght_frame_times) - 1
            z[z == -1] = 0

        x = data[:, 3].astype(np.int64)
        y = data[:, 4].astype(np.int64)

        # Count strikes per (y, x, frame) cell via flat-index bincount.
        k = np.ravel_multi_index(np.array([y, x, z]), out_size)
        n = np.bincount(k, minlength=np.prod(out_size))
        return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :]
|
|
    def _old_save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        This method does not save .h5 dataset correctly. There are some batches missed due to unknown error.
        E.g., the first converted .h5 file `SEVIR_VIL_RANDOMEVENTS_2017_0501_0831.h5` only has batch_dim = 1414,
        while it should be 1440 in the original .h5 file.

        Deprecated: kept for reference only; prefer `save_downsampled_dataset`.
        """
        import os
        from skimage.measure import block_reduce
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        sample_counter = 0
        for index, row in self._samples.iterrows():
            if verbose:
                print(f"Downsampling {sample_counter}-th data item.", end='\r')
            for data_type in self.data_types:
                fname = row[f'{data_type}_filename']
                idx = row[f'{data_type}_index']
                t_slice = slice(0, None)
                if data_type == 'lght':
                    lght_data = self._hdf_files[fname][idx][:]
                    data_i = self._lght_to_grid(lght_data, t_slice)
                else:
                    data_i = self._hdf_files[fname][data_type][idx:idx + 1, :, :, t_slice]
                # Subsample the time axis by the temporal factor.
                t_slice = [slice(None, None), ] * 4
                t_slice[-1] = slice(None, None, downsample_dict[data_type][0])
                data_i = data_i[tuple(t_slice)]
                # Max-pool the spatial axes by (h_factor, w_factor).
                data_i = block_reduce(data_i,
                                      block_size=(1, *downsample_dict[data_type][1:], 1),
                                      func=np.max)
                new_file_path = os.path.join(save_dir, fname)
                if not os.path.exists(new_file_path):
                    if not os.path.exists(os.path.dirname(new_file_path)):
                        os.makedirs(os.path.dirname(new_file_path))
                    # First write: create a batch-resizable dataset.
                    with h5py.File(new_file_path, 'w') as hf:
                        hf.create_dataset(
                            data_type, data=data_i,
                            maxshape=(None, *data_i.shape[1:]))
                else:
                    # Subsequent writes: append along the batch axis.
                    with h5py.File(new_file_path, 'a') as hf:
                        hf[data_type].resize((hf[data_type].shape[0] + data_i.shape[0]), axis=0)
                        hf[data_type][-data_i.shape[0]:] = data_i
            sample_counter += 1
|
|
    def save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        Downsample each opened HDF5 file whole (one dataset per file) and
        write the results under `save_dir`, preserving the relative layout.

        Parameters
        ----------
        save_dir
        downsample_dict: Dict[Sequence[int]]
            Notice that this is different from `self.downsample_dict`, which is used during runtime.
        """
        import os
        from skimage.measure import block_reduce
        from ...utils.utils import path_splitall
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        for fname, hdf_file in self._hdf_files.items():
            if verbose:
                print(f"Downsampling data in {fname}.")
            # The first path component names the data type (e.g. 'vil/...').
            data_type = path_splitall(fname)[0]
            if data_type == 'lght':
                # Lightning is event-list data, not gridded; no downsampling path.
                raise NotImplementedError
            else:
                data_i = self._hdf_files[fname][data_type]
                # Subsample the time axis by the temporal factor.
                t_slice = [slice(None, None), ] * 4
                t_slice[-1] = slice(None, None, downsample_dict[data_type][0])
                data_i = data_i[tuple(t_slice)]
                # Max-pool the spatial axes by (h_factor, w_factor).
                data_i = block_reduce(data_i,
                                      block_size=(1, *downsample_dict[data_type][1:], 1),
                                      func=np.max)
                new_file_path = os.path.join(save_dir, fname)
                if not os.path.exists(os.path.dirname(new_file_path)):
                    os.makedirs(os.path.dirname(new_file_path))
                # Write the entire downsampled array in one shot.
                with h5py.File(new_file_path, 'w') as hf:
                    hf.create_dataset(
                        data_type, data=data_i,
                        maxshape=(None, *data_i.shape[1:]))
|
|
| @property |
| def sample_count(self): |
| """ |
| Record how many times self.__next__() is called. |
| """ |
| return self._sample_count |
|
|
| def inc_sample_count(self): |
| self._sample_count += 1 |
|
|
| @property |
| def curr_event_idx(self): |
| return self._curr_event_idx |
|
|
| @property |
| def curr_seq_idx(self): |
| """ |
| Used only when self.sample_mode == 'sequent' |
| """ |
| return self._curr_seq_idx |
|
|
| def set_curr_event_idx(self, val): |
| self._curr_event_idx = val |
|
|
| def set_curr_seq_idx(self, val): |
| """ |
| Used only when self.sample_mode == 'sequent' |
| """ |
| self._curr_seq_idx = val |
|
|
| def reset(self, shuffle: bool = None): |
| self.set_curr_event_idx(val=self.start_event_idx) |
| self.set_curr_seq_idx(0) |
| self._sample_count = 0 |
| if shuffle is None: |
| shuffle = self.shuffle |
| if shuffle: |
| self.shuffle_samples() |
|
|
| def __len__(self): |
| """ |
| Used only when self.sample_mode == 'sequent' |
| """ |
| return self.total_num_seq // self.batch_size |
|
|
| @property |
| def use_up(self): |
| """ |
| Check if dataset is used up in 'sequent' mode. |
| """ |
| if self.sample_mode == 'random': |
| return False |
| else: |
| |
| curr_event_remain_seq = self.num_seq_per_event - self.curr_seq_idx |
| all_remain_seq = curr_event_remain_seq + ( |
| self.end_event_idx - self.curr_event_idx - 1) * self.num_seq_per_event |
| if self.split_mode == "floor": |
| |
| return all_remain_seq < self.batch_size |
| else: |
| return all_remain_seq <= 0 |
|
|
    def _load_event_batch(self, event_idx, event_batch_size):
        """
        Loads a selected batch of events (not batch of sequences) into memory.

        Parameters
        ----------
        idx
        event_batch_size
            event_batch[i] = all_type_i_available_events[idx:idx + event_batch_size]
        Returns
        -------
        event_batch
            list of event batches.
            event_batch[i] is the event batch of the i-th data type.
            Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_len)
        """
        event_idx_slice_end = event_idx + event_batch_size
        pad_size = 0
        # Clip the slice at the shard boundary; remember how much to zero-pad.
        if event_idx_slice_end > self.end_event_idx:
            pad_size = event_idx_slice_end - self.end_event_idx
            event_idx_slice_end = self.end_event_idx
        pd_batch = self._samples.iloc[event_idx:event_idx_slice_end]
        data = {}
        for index, row in pd_batch.iterrows():
            data = self._read_data(row, data)
        if pad_size > 0:
            # Pad the batch axis with zeros so every batch has the same size.
            event_batch = []
            for t in self.data_types:
                pad_shape = [pad_size, ] + list(data[t].shape[1:])
                data_pad = np.concatenate((data[t].astype(self.output_type),
                                           np.zeros(pad_shape, dtype=self.output_type)),
                                          axis=0)
                event_batch.append(data_pad)
        else:
            event_batch = [data[t].astype(self.output_type) for t in self.data_types]
        return event_batch
|
|
| def __iter__(self): |
| return self |
|
|
    def __next__(self):
        """Return the next batch as a dict of torch tensors.

        In 'sequent' mode raises StopIteration once the shard is used up;
        'random' mode never stops. Preprocessing (rescale + layout change)
        and optional downsampling are applied per configuration.
        """
        if self.sample_mode == 'random':
            self.inc_sample_count()
            ret_dict = self._random_sample()
        else:
            if self.use_up:
                raise StopIteration
            else:
                self.inc_sample_count()
                ret_dict = self._sequent_sample()
        # Convert numpy entries to detached torch tensors before preprocessing.
        ret_dict = self.data_dict_to_tensor(data_dict=ret_dict,
                                            data_types=self.data_types)
        if self.preprocess:
            ret_dict = self.preprocess_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 layout=self.layout,
                                                 rescale=self.rescale_method)
        if self.downsample_dict is not None:
            ret_dict = self.downsample_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 factors_dict=self.downsample_dict,
                                                 layout=self.layout)
        return ret_dict
|
|
| def __getitem__(self, index): |
| data_dict = self._idx_sample(index=index) |
| return data_dict |
|
|
| @staticmethod |
| def preprocess_data_dict(data_dict, data_types=None, layout='NHWT', rescale='01'): |
| """ |
| Parameters |
| ---------- |
| data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
| data_types: Sequence[str] |
| The data types that we want to rescale. This mainly excludes "mask" from preprocessing. |
| layout: str |
| consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W' |
| rescale: str |
| 'sevir': use the offsets and scale factors in original implementation. |
| '01': scale all values to range 0 to 1, currently only supports 'vil' |
| Returns |
| ------- |
| data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
| preprocessed data |
| """ |
| if rescale == 'sevir': |
| scale_dict = PREPROCESS_SCALE_SEVIR |
| offset_dict = PREPROCESS_OFFSET_SEVIR |
| elif rescale == '01': |
| scale_dict = PREPROCESS_SCALE_01 |
| offset_dict = PREPROCESS_OFFSET_01 |
| else: |
| raise ValueError(f'Invalid rescale option: {rescale}.') |
| if data_types is None: |
| data_types = data_dict.keys() |
| for key, data in data_dict.items(): |
| if key in data_types: |
| if isinstance(data, np.ndarray): |
| data = scale_dict[key] * ( |
| data.astype(np.float32) + |
| offset_dict[key]) |
| data = change_layout_np(data=data, |
| in_layout='NHWT', |
| out_layout=layout) |
| elif isinstance(data, torch.Tensor): |
| data = scale_dict[key] * ( |
| data.float() + |
| offset_dict[key]) |
| data = change_layout_torch(data=data, |
| in_layout='NHWT', |
| out_layout=layout) |
| data_dict[key] = data |
| return data_dict |
|
|
| @staticmethod |
| def process_data_dict_back(data_dict, data_types=None, rescale='01'): |
| """ |
| Parameters |
| ---------- |
| data_dict |
| each data_dict[key] is a torch.Tensor. |
| rescale |
| str: |
| 'sevir': data are scaled using the offsets and scale factors in original implementation. |
| '01': data are all scaled to range 0 to 1, currently only supports 'vil' |
| Returns |
| ------- |
| data_dict |
| each data_dict[key] is the data processed back in torch.Tensor. |
| """ |
| if rescale == 'sevir': |
| scale_dict = PREPROCESS_SCALE_SEVIR |
| offset_dict = PREPROCESS_OFFSET_SEVIR |
| elif rescale == '01': |
| scale_dict = PREPROCESS_SCALE_01 |
| offset_dict = PREPROCESS_OFFSET_01 |
| else: |
| raise ValueError(f'Invalid rescale option: {rescale}.') |
| if data_types is None: |
| data_types = data_dict.keys() |
| for key in data_types: |
| data = data_dict[key] |
| data = data.float() / scale_dict[key] - offset_dict[key] |
| data_dict[key] = data |
| return data_dict |
|
|
| @staticmethod |
| def data_dict_to_tensor(data_dict, data_types=None): |
| """ |
| Convert each element in data_dict to torch.Tensor (copy without grad). |
| """ |
| ret_dict = {} |
| if data_types is None: |
| data_types = data_dict.keys() |
| for key, data in data_dict.items(): |
| if key in data_types: |
| if isinstance(data, torch.Tensor): |
| ret_dict[key] = data.detach().clone() |
| elif isinstance(data, np.ndarray): |
| ret_dict[key] = torch.from_numpy(data) |
| else: |
| raise ValueError(f"Invalid data type: {type(data)}. Should be torch.Tensor or np.ndarray") |
| else: |
| ret_dict[key] = data |
| return ret_dict |
|
|
    @staticmethod
    def downsample_data_dict(data_dict, data_types=None, factors_dict=None, layout='NHWT'):
        """
        Parameters
        ----------
        data_dict: Dict[str, Union[np.array, torch.Tensor]]
        factors_dict: Optional[Dict[str, Sequence[int]]]
            each element `factors` is a Sequence of int, representing (t_factor, h_factor, w_factor)

        Returns
        -------
        downsampled_data_dict: Dict[str, torch.Tensor]
            Modify on a deep copy of data_dict instead of directly modifying the original data_dict
        """
        if factors_dict is None:
            factors_dict = {}
        if data_types is None:
            data_types = data_dict.keys()
        # Work on tensor copies so the caller's dict is left untouched.
        downsampled_data_dict = SEVIRDataLoader.data_dict_to_tensor(
            data_dict=data_dict,
            data_types=data_types)
        for key, data in data_dict.items():
            factors = factors_dict.get(key, None)
            if factors is not None:
                # Bring the entry to 'NTHW' so pooling acts on (H, W).
                downsampled_data_dict[key] = change_layout_torch(
                    data=downsampled_data_dict[key],
                    in_layout=layout,
                    out_layout='NTHW')
                # Subsample the time axis by t_factor.
                t_slice = [slice(None, None), ] * 4
                t_slice[1] = slice(None, None, factors[0])
                downsampled_data_dict[key] = downsampled_data_dict[key][tuple(t_slice)]
                # Average-pool the spatial axes by (h_factor, w_factor).
                downsampled_data_dict[key] = avg_pool2d(
                    input=downsampled_data_dict[key],
                    kernel_size=(factors[1], factors[2]))
                # Restore the caller's layout.
                downsampled_data_dict[key] = change_layout_torch(
                    data=downsampled_data_dict[key],
                    in_layout='NTHW',
                    out_layout=layout)

        return downsampled_data_dict
|
|
| def _random_sample(self): |
| """ |
| Returns |
| ------- |
| ret_dict |
| dict. ret_dict.keys() == self.data_types. |
| If self.preprocess == False: |
| ret_dict[imgt].shape == (batch_size, height, width, seq_len) |
| """ |
| num_sampled = 0 |
| event_idx_list = nprand.randint(low=self.start_event_idx, |
| high=self.end_event_idx, |
| size=self.batch_size) |
| seq_idx_list = nprand.randint(low=0, |
| high=self.num_seq_per_event, |
| size=self.batch_size) |
| seq_slice_list = [slice(seq_idx * self.stride, |
| seq_idx * self.stride + self.seq_len) |
| for seq_idx in seq_idx_list] |
| ret_dict = {} |
| while num_sampled < self.batch_size: |
| event = self._load_event_batch(event_idx=event_idx_list[num_sampled], |
| event_batch_size=1) |
| for imgt_idx, imgt in enumerate(self.data_types): |
| sampled_seq = event[imgt_idx][[0, ], :, :, seq_slice_list[num_sampled]] |
| if imgt in ret_dict: |
| ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq), |
| axis=0) |
| else: |
| ret_dict.update({imgt: sampled_seq}) |
| return ret_dict |
|
|
    def _sequent_sample(self):
        """
        Returns
        -------
        ret_dict: Dict
            `ret_dict.keys()` contains `self.data_types`.
            `ret_dict["mask"]` is a list of bool, indicating if the data entry is real or padded.
            If self.preprocess == False:
                ret_dict[imgt].shape == (batch_size, height, width, seq_len)
        """
        assert not self.use_up, 'Data loader used up! Reset it to reuse.'
        event_idx = self.curr_event_idx
        seq_idx = self.curr_seq_idx
        num_sampled = 0
        # Enumerate the next batch_size (event, window) cursor positions,
        # wrapping to the next event when an event's windows are exhausted.
        sampled_idx_list = []
        while num_sampled < self.batch_size:
            sampled_idx_list.append({'event_idx': event_idx,
                                     'seq_idx': seq_idx})
            seq_idx += 1
            if seq_idx >= self.num_seq_per_event:
                event_idx += 1
                seq_idx = 0
            num_sampled += 1

        start_event_idx = sampled_idx_list[0]['event_idx']
        event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1

        # Load every event the batch touches in one go (zero-padded past the shard end).
        event_batch = self._load_event_batch(event_idx=start_event_idx,
                                             event_batch_size=event_batch_size)
        ret_dict = {"mask": []}
        all_no_pad_flag = True
        for sampled_idx in sampled_idx_list:
            batch_slice = [sampled_idx['event_idx'] - start_event_idx, ]
            seq_slice = slice(sampled_idx['seq_idx'] * self.stride,
                              sampled_idx['seq_idx'] * self.stride + self.seq_len)
            for imgt_idx, imgt in enumerate(self.data_types):
                sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice]
                if imgt in ret_dict:
                    ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                    axis=0)
                else:
                    ret_dict.update({imgt: sampled_seq})
            # Entries beyond the shard end are zero-padding, not real data.
            no_pad_flag = sampled_idx['event_idx'] < self.end_event_idx
            if not no_pad_flag:
                all_no_pad_flag = False
            ret_dict["mask"].append(no_pad_flag)
        if all_no_pad_flag:
            # mask is None when the whole batch is real data.
            ret_dict["mask"] = None
        # Persist the cursors for the next call.
        self.set_curr_event_idx(event_idx)
        self.set_curr_seq_idx(seq_idx)
        return ret_dict
|
|
| def _idx_sample(self, index): |
| """ |
| Parameters |
| ---------- |
| index |
| The index of the batch to sample. |
| Returns |
| ------- |
| ret_dict |
| dict. ret_dict.keys() == self.data_types. |
| If self.preprocess == False: |
| ret_dict[imgt].shape == (batch_size, height, width, seq_len) |
| """ |
| event_idx = (index * self.batch_size) // self.num_seq_per_event |
| seq_idx = (index * self.batch_size) % self.num_seq_per_event |
| num_sampled = 0 |
| sampled_idx_list = [] |
| while num_sampled < self.batch_size: |
| sampled_idx_list.append({'event_idx': event_idx, |
| 'seq_idx': seq_idx}) |
| seq_idx += 1 |
| if seq_idx >= self.num_seq_per_event: |
| event_idx += 1 |
| seq_idx = 0 |
| num_sampled += 1 |
|
|
| start_event_idx = sampled_idx_list[0]['event_idx'] |
| event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1 |
|
|
| event_batch = self._load_event_batch(event_idx=start_event_idx, |
| event_batch_size=event_batch_size) |
| ret_dict = {} |
| for sampled_idx in sampled_idx_list: |
| batch_slice = [sampled_idx['event_idx'] - start_event_idx, ] |
| seq_slice = slice(sampled_idx['seq_idx'] * self.stride, |
| sampled_idx['seq_idx'] * self.stride + self.seq_len) |
| for imgt_idx, imgt in enumerate(self.data_types): |
| sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice] |
| if imgt in ret_dict: |
| ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq), |
| axis=0) |
| else: |
| ret_dict.update({imgt: sampled_seq}) |
|
|
| ret_dict = self.data_dict_to_tensor(data_dict=ret_dict, |
| data_types=self.data_types) |
| if self.preprocess: |
| ret_dict = self.preprocess_data_dict(data_dict=ret_dict, |
| data_types=self.data_types, |
| layout=self.layout, |
| rescale=self.rescale_method) |
|
|
| if self.downsample_dict is not None: |
| ret_dict = self.downsample_data_dict(data_dict=ret_dict, |
| data_types=self.data_types, |
| factors_dict=self.downsample_dict, |
| layout=self.layout) |
| return ret_dict |
|
|
|
|
class SEVIRDataIterator():
    '''
    Thin wrapper around SEVIRDataLoader exposing a sample() method.
    All constructor arguments are forwarded to the inner SEVIRDataLoader.
    If you expect a pythonic iterator, use SEVIRDataLoader instead.
    '''
    def __init__(self, **kwargs):
        self.loader = SEVIRDataLoader(**kwargs)
        # Default matches the loader's usual random sampling behavior.
        self.sample_mode = kwargs.get('sample_mode', 'random')

    def reset(self):
        # Delegate: rewind the wrapped loader.
        self.loader.reset()

    def sample(self, batch_size=None):
        '''
        Return the next batch. `batch_size` is accepted for interface
        compatibility but ignored (the loader fixes it at construction).
        In 'random' mode the loader is transparently reset on exhaustion.
        '''
        batch = next(self.loader, None)
        if batch is None and self.sample_mode == 'random':
            self.loader.reset()
            batch = next(self.loader, None)
        return batch

    def __len__(self):
        """
        Used only when self.sample_mode == 'sequent'
        """
        return len(self.loader)
|
|
| |
| |
| |
| |
| |
| |
|
|
class Meteo(Dataset):
    """
    Meteonet radar sequence dataset backed by a single HDF5 file.

    Each item is a frame sequence split into (input, target) clips at
    `in_len`. Pixel values are divided by `pixel_scale` before resizing.
    """

    def __init__(self, data_path, img_size, type='train', trans=None, in_len=-1):
        """
        Parameters
        ----------
        data_path : path to the Meteonet HDF5 file.
        img_size : side length frames are resized to.
        type : one of 'train' / 'test' / 'val' ('val' reuses the test split).
        trans : optional transform; defaults to a plain resize.
        in_len : number of leading frames returned as the input clip.
        """
        super().__init__()
        # Normalization constant applied to raw pixel values.
        # NOTE(review): presumably a dBZ-style scale with max 70 — confirm
        # against the file's value range.
        self.pixel_scale = 70.0
        self.data_path = data_path
        self.img_size = img_size
        self.in_len = in_len
        assert type in ['train', 'test', 'val']
        # The validation split is served from the test group.
        self.type = 'test' if type == 'val' else type
        # The split's length is stored in the file under '<split>_len'.
        with h5py.File(data_path, 'r') as f:
            self.all_len = int(f[f'{self.type}_len'][()])
        if trans is None:
            self.transform = T.Compose([
                T.Resize((img_size, img_size)),
            ])
        else:
            self.transform = trans

    def __len__(self):
        return self.all_len

    def sample(self):
        """Return one item chosen uniformly at random."""
        return self[np.random.randint(0, self.all_len)]

    def __getitem__(self, index):
        # Frames are keyed by their string index inside the split group.
        with h5py.File(self.data_path, 'r') as f:
            raw = f[self.type][str(index)][()]
        seq = torch.from_numpy(raw).float().squeeze()
        # Normalize, resize, then restore a singleton channel axis.
        seq = self.transform(seq / self.pixel_scale).unsqueeze(1)
        # Split into (input clip, target clip) at in_len.
        return seq[:self.in_len], seq[self.in_len:]
| |
|
|
def load_meteonet(batch_size, val_batch_size, in_len, train=False, num_workers=0, img_size=128):
    """
    Build Meteonet dataloaders from the HDF5 file under METEO_FILE_DIR.

    Returns
    -------
    (train_loader, valid_loader) when train=True, otherwise
    (None, test_loader).
    """
    h5_path = os.path.join(METEO_FILE_DIR, "meteo.h5")
    if not train:
        # Evaluation-only path: a single ordered test loader.
        test_loader = torch.utils.data.DataLoader(
            Meteo(h5_path, img_size, 'test', in_len=in_len),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return None, test_loader
    # Training path: shuffled train loader plus an ordered validation
    # loader; both drop the last incomplete batch.
    train_loader = torch.utils.data.DataLoader(
        Meteo(h5_path, img_size, 'train', in_len=in_len),
        batch_size=batch_size, shuffle=True, drop_last=True,
        num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(
        Meteo(h5_path, img_size, 'val', in_len=in_len),
        batch_size=val_batch_size, shuffle=False, drop_last=True,
        num_workers=num_workers)
    return train_loader, valid_loader