| import torch.nn.functional as F |
|
|
| from data import dutils |
| from nowcasting.hko_iterator import HKOIterator |
|
|
def GET_TrainLoader(meta: dict, param: dict, batch_size: int, in_len: int, out_len: int):
    """Build the ``(train_loader, test_loader)`` pair for the dataset named in ``meta``.

    Args:
        meta: Dataset metadata; ``meta['dataset']`` selects the dataset:
            ``'SEVIR'``, any name starting with ``'HKO'``, or ``'meteonet'``.
        param: Dataset-specific parameters. For HKO it must contain
            ``'pd_path'`` (path of the test-split pickle); for meteonet it is
            forwarded as keyword arguments to ``dutils.load_meteonet``.
            Not used for SEVIR.
        batch_size: Training batch size (not used for HKO). The SEVIR and
            meteonet evaluation loaders use ``min(batch_size, 8)``.
        in_len: Number of input frames per sample (SEVIR / HKO only).
        out_len: Number of target frames per sample (SEVIR / HKO only).

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``meta['dataset']`` is not a supported dataset name.
    """
    dataset = meta['dataset']

    if dataset == 'SEVIR':
        # Each sample holds the input frames followed by the target frames.
        total_seq_len = in_len + out_len
        shared = {
            'data_types': ['vil'],
            'layout': 'NTCHW',
            'seq_len': total_seq_len,
            'raw_seq_len': total_seq_len,
        }
        # Date-based split: training ends at the split date, testing starts there.
        train_loader = dutils.SEVIRDataIterator(
            **shared,
            start_date=None,
            end_date=dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
            batch_size=batch_size,
        )
        test_loader = dutils.SEVIRDataIterator(
            **shared,
            start_date=dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
            end_date=None,
            batch_size=min(batch_size, 8),  # evaluation batches capped at 8
        )
    elif dataset.startswith('HKO'):
        total_seq_len = in_len + out_len
        test_pkl_path = param['pd_path']
        # NOTE(review): the train pickle path is derived by replacing *every*
        # 'test' substring of the path with 'train', so directory names that
        # contain 'test' are rewritten too — confirm the path layout.
        train_pkl_path = test_pkl_path.replace('test', 'train')
        train_loader = HKOIterator(pd_path=train_pkl_path, sample_mode="random",
                                   seq_len=total_seq_len, stride=1)
        # Sequential evaluation windows advance by in_len frames.
        test_loader = HKOIterator(pd_path=test_pkl_path, sample_mode="sequent",
                                  seq_len=total_seq_len, stride=in_len)
    elif dataset == 'meteonet':
        train_loader, test_loader = dutils.load_meteonet(
            batch_size=batch_size, val_batch_size=min(batch_size, 8), train=True, **param)
    else:
        # Previously referenced an undefined name (``dataset_config``), which
        # turned this error into a NameError.
        raise ValueError(f'Undefined dataset config name: {dataset}')
    return train_loader, test_loader
|
|
def GET_TestLoader(meta: dict, param: dict, batch_size: int):
    """Build the evaluation data source for the dataset named in ``meta``.

    Args:
        meta: Dataset metadata; ``meta['dataset']`` selects the dataset:
            ``'SEVIR'``, any name starting with ``'HKO'``, or ``'meteonet'``.
        param: Keyword arguments forwarded to the dataset's loader
            (``dutils.SEVIRDataIterator``, ``HKOIterator`` or
            ``dutils.load_meteonet``).
        batch_size: Batch size passed to the SEVIR / meteonet loaders; not
            used for HKO.

    Returns:
        A ``dutils.SEVIRDataIterator`` for SEVIR, an ``HKOIterator`` for HKO,
        or, for meteonet, an iterator over the second loader returned by
        ``dutils.load_meteonet``.

    Raises:
        ValueError: If ``meta['dataset']`` is not a supported dataset name.
    """
    dataset = meta['dataset']

    if dataset == 'SEVIR':
        return dutils.SEVIRDataIterator(**param, batch_size=batch_size)
    if dataset.startswith('HKO'):
        return HKOIterator(**param)
    if dataset == 'meteonet':
        # NOTE(review): val_batch_size is fixed at 8 here, whereas
        # GET_TrainLoader uses min(batch_size, 8) — confirm this is intended.
        _, test_iter = dutils.load_meteonet(batch_size=batch_size, val_batch_size=8, train=False, **param)
        return iter(test_iter)
    # Previously referenced an undefined name (``dataset_config``), which
    # turned this error into a NameError.
    raise ValueError(f'Undefined dataset config name: {dataset}')