| import os, sys, logging, argparse |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| import utilspp as utpp |
| from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss |
| from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20 |
| from data.loader import GET_TestLoader |
| from data.dutils import resize |
|
|
class MetricListEvaluator():
    '''
    Accumulate a list of metrics over batches and report final values.

    Supported metric-name formats:
    - CSI, HSS with a threshold (e.g. `csi-84`, `hss-84`)
    - CSI-pooled with a pooling radius (e.g. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR

    Integer thresholds are interpreted on the 0-255 pixel scale and
    rescaled to [0, 1].
    '''
    def __init__(self, metric_list):
        # Maps metric name -> [metric_fn, running accumulator, extra kwargs].
        self.metric_holder = {}
        self.batch_count = 0
        for metric_name in metric_list:
            threshold = ''
            radius = ''
            if '-' in metric_name:
                metric, threshold = metric_name.split('-')
                if '_' in metric:
                    metric, radius = metric.split('_')
                    radius = int(radius)
            # Digit thresholds are on the 0-255 scale; rescale to [0, 1].
            threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)


    def init_metric(self, metric_name, **kwarg):
        '''
        Return a list of three items in order:
        - the function to call during eval
        - the value(s) to keep track of
        - a dict of any additional item to pass into the function
        '''
        if metric_name.split('-')[0] in ['csi', 'hss']:
            # Accumulate confusion-matrix counts (four entries from utpp.tfpn).
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]
        elif '_' in metric_name.split('-')[0]:
            # Pooled variant, e.g. 'csi_4-84' pools with radius 4.
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # Plain metrics (mae/mse/ssim/psnr/...) resolved by name from the
            # functions imported at module level.
            # NOTE(review): eval() on a caller-supplied name executes arbitrary
            # code; only pass trusted metric names here.
            return [eval(metric_name), 0, {}]


    def eval(self, y_pred, y):
        '''Evaluate one batch and add each metric's result to its accumulator.'''
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            # Fix: original used `temp is list`, an identity check against the
            # `list` type object that is always False for a list instance.
            if isinstance(temp, list):
                temp = np.array(temp)
            # isinstance (not type equality) also accepts Tensor subclasses.
            elif isinstance(temp, torch.Tensor):
                temp = temp.detach().cpu().numpy()
            metric[1] += temp

    def get_results(self):
        '''Finalize and return {metric_name: value} over all evaluated batches.'''
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]
            # CSI/HSS are computed once at the end from the accumulated counts.
            if metric[0] is utpp.tfpn:
                metric_name, threshold = key.split('-')
                val = eval(metric_name)(*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                metric_name, info = key.split('_')
                val = eval(metric_name)(*list(metric[1]))
            else:
                # Plain metrics are summed per batch; average over batches.
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')
    parser.add_argument('--out_len', type=int, required=True, help='The actual prediction length')

    # Fix: the original help string used invalid escapes '\{ \}' (a
    # SyntaxWarning on modern Python); braces need no escaping in a str.
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename with included { }')
    parser.add_argument('--ens_no', default=1, type=int, help='Total ensemble number')

    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')

    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to be evaluated, separated by character /')

    parser.add_argument('--print_every', type=int, default=100, help='The number of steps to log the training loss')
    args = parser.parse_args()


    # Log next to the ensemble files.  Fix: os.path.dirname preserves a
    # leading '/'; the original split-on-'/' + os.path.join dropped it for
    # absolute paths (os.path.join('', ...) loses the root).
    logfile_name = os.path.join(os.path.dirname(args.e_file), 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Steps: {args.step}')


    # Resolve the dataset config (e.g. SEVIR_13_12) by name from the
    # configs imported at module level.
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)


    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # CLI --metrics (if given) overrides the dataset's default metric list.
    metric_list = dataset_meta['metrics']
    if args.metrics is not None:
        metric_list = args.metrics.lower().split('/')
        logging.info(f'Overwriting metrics list with: {metric_list}')
    evaluator = MetricListEvaluator(metric_list)


    for e in range(args.ens_no):
        # One .npy of stacked predictions per ensemble member; e_file is a
        # format template with a '{}' placeholder for the member index.
        prediction = np.load(args.e_file.format(str(e)))
        prediction = torch.tensor(prediction, device=device)

        step = 1
        # SEVIR/HKO-7 loaders are stateful samplers and must be rewound per
        # ensemble member; the meteo loader is an iterator and is not reset.
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            loader.reset()

        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:]
            elif dataset_meta['dataset'] == 'HKO-7':
                setattr(args, 'seq_len', dataset_meta['seq_len'])
                try:
                    data = loader.sample(batch_size=args.batch_size)
                # Fix: the exception was bound to `e`, shadowing the ensemble
                # index; Python deletes the except target on exit, so the next
                # ensemble iteration's e_file.format(str(e)) raised NameError.
                except Exception as exc:
                    logging.error(exc)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)
                except Exception as exc:  # same shadowing fix as above
                    logging.error(exc)
                    break


            with torch.no_grad():
                y = y.to(device)
                # Slice this step's batch out of the preloaded predictions.
                y_pred = prediction[(step-1)*args.batch_size:step*args.batch_size]


                # Match spatial resolution of the ground truth to the
                # prediction before scoring.
                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])


                # Clamp both to the valid [0, 1] range before scoring.
                y, y_pred = y.clamp(0, 1), y_pred.clamp(0, 1)
                evaluator.eval(y_pred, y)

            if step == 1 or step % args.print_every == 0:
                logging.info(f'E_ID:{e} -> {step} Steps evaluated')
            step += 1

    # Metrics are accumulated across all ensemble members and reported once.
    final_results = evaluator.get_results()
    for k, v in final_results.items():
        logging.info(f'{k}: {v}')