| import torch |
| import torch.distributed as dist |
|
|
| import os |
| import os.path as osp |
| import numpy as np |
| import cv2 |
| import copy |
| import json |
|
|
| from ..log_service import print_log |
|
|
def singleton(class_):
    """Wrap ``class_`` so that it is instantiated at most once.

    The returned callable builds ``class_(*args, **kwargs)`` on its first
    call and hands back that same object on every later call; arguments
    passed to later calls are ignored.

    Args:
        class_: callable,
            the class (or factory) to turn into a singleton.

    Returns:
        A function that returns the single shared instance.
    """
    instances = {}

    def getinstance(*args, **kwargs):
        # Fast path: the instance was already built by an earlier call.
        if class_ in instances:
            return instances[class_]
        created = class_(*args, **kwargs)
        instances[class_] = created
        return created

    return getinstance
|
|
@singleton
class get_evaluator(object):
    """Registry and factory for evaluator classes.

    Evaluator classes are stored by name through ``register`` and are
    instantiated by calling the registry with a pipeline configuration.
    """

    def __init__(self):
        # Maps a registered name to the evaluator class (or factory).
        self.evaluator = {}

    def register(self, evaf, name):
        """Store ``evaf`` under ``name``, replacing any previous entry.

        Args:
            evaf: callable,
                the evaluator class or factory.
            name: str,
                the key later used as the ``type`` of a configuration.
        """
        self.evaluator[name] = evaf

    @staticmethod
    def _lazy_import(t):
        """Import the sibling module associated with evaluator type ``t``.

        The import is done only for its side effect (presumably the module
        registers its evaluator at import time -- verify against the
        ``eva_*`` modules). Types not listed here trigger no import.
        """
        if t == 'null':
            from . import eva_null
        elif t == 'miou':
            from . import eva_miou
        elif t == 'psnr':
            from . import eva_psnr
        elif t == 'ssim':
            from . import eva_ssim
        elif t == 'lpips':
            from . import eva_lpips
        elif t == 'fid':
            from . import eva_fid

    def _create(self, t, **kwargs):
        """Instantiate the evaluator registered as ``t`` with ``kwargs``.

        Raises:
            KeyError: if nothing is registered under ``t``.
        """
        self._lazy_import(t)
        try:
            factory = self.evaluator[t]
        except KeyError:
            raise KeyError(
                'evaluator type {!r} is not registered'.format(t)) from None
        return factory(**kwargs)

    def __call__(self, pipeline_cfg=None):
        """Build the evaluator(s) described by ``pipeline_cfg``.

        Args:
            pipeline_cfg: None, a single configuration object, or a list of
                configuration objects. Each configuration exposes ``type``
                (registered name) and ``args`` (keyword arguments mapping).

        Returns:
            * the evaluator registered as ``'null'`` when ``pipeline_cfg``
              is None;
            * a single evaluator when ``pipeline_cfg`` is not a list;
            * ``None`` for an empty list;
            * otherwise a ``compose`` wrapping one evaluator per entry.

        Raises:
            KeyError: if a requested type is not registered.
        """
        if pipeline_cfg is None:
            return self._create('null')

        if not isinstance(pipeline_cfg, list):
            return self._create(pipeline_cfg.type, **pipeline_cfg.args)

        evaluator = [self._create(ci.type, **ci.args) for ci in pipeline_cfg]
        if not evaluator:
            return None
        return compose(evaluator)
|
|
def register(name):
    """Return a class decorator that registers the class under ``name``.

    The decorated class is stored in the shared ``get_evaluator`` registry
    and returned unchanged.

    Args:
        name: str,
            the key under which the class is registered.
    """
    def wrapper(class_):
        # The registry is looked up at decoration time, not earlier.
        registry = get_evaluator()
        registry.register(class_, name)
        return class_

    return wrapper
|
|
class base_evaluator(object):
    """Base class for evaluators that exchange data between distributed ranks.

    Provides broadcast-based helpers (``sync`` / ``sync_``), a helper that
    re-interleaves per-rank results (``zipzap_arrange``) and JSON saving of
    ``self.final``. ``add_batch``, ``compute`` and ``clear_data`` raise
    ``NotImplementedError`` here and must be provided by subclasses.
    """

    # NumPy dtypes that are transported as torch.int64 / torch.float64.
    # ``np.bool_`` is used instead of ``np.bool``, which does not exist in
    # NumPy 1.24 - 1.26.
    _INT_DTYPES = (
        int, np.bool_, np.uint8, np.int8, np.int16, np.int32, np.int64)
    _FLOAT_DTYPES = (float, np.float16, np.float32, np.float64)

    def __init__(self,
                 **args):
        '''
        Args:
            **args: accepted for compatibility and ignored here.
                The total number of samples (used in distributed sync)
                is not taken from the constructor; it starts as None
                and is set through ``set_sample_n``.
        Raises:
            ValueError: if ``torch.distributed`` is not available.
        '''
        if not dist.is_available():
            raise ValueError('torch.distributed is not available')
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.sample_n = None
        self.final = {}

    def sync(self, data):
        """
        Gather ``data`` from every rank by broadcasting it rank by rank.

        Args:
            data: None, str, np.ndarray, or a list/tuple of those,
                the data needs to be broadcasted
        Returns:
            None when ``data`` is None; otherwise a list with one entry
            per rank. For list/tuple input, entry ``r`` is the list of
            elements as held by rank ``r``.
        """
        if data is None:
            return None

        if isinstance(data, tuple):
            data = list(data)

        if isinstance(data, list):
            # Each element becomes [value_on_rank0, value_on_rank1, ...];
            # transposing yields one list of elements per rank.
            data_list = [self.sync(datai) for datai in data]
            return [[*i] for i in zip(*data_list)]

        return [
            self.sync_(data, ranki)
            for ranki in range(self.world_size)
        ]

    @classmethod
    def _torch_dtype_for(cls, dt):
        """Return the torch dtype used to transport arrays of dtype ``dt``.

        Raises:
            ValueError: if ``dt`` is neither a supported integer/bool dtype
                nor a supported float dtype.
        """
        if dt in cls._INT_DTYPES:
            return torch.int64
        if dt in cls._FLOAT_DTYPES:
            return torch.float64
        raise ValueError('unsupported array dtype: {}'.format(dt))

    def sync_(self, data, rank):
        """
        Broadcast ``data`` from rank ``rank`` to all ranks.

        The transfer is three broadcasts: number of dimensions, shape,
        then the values. The receiving side derives the element type from
        its own local ``data``, so every rank has to pass a value of the
        same type (and, for arrays, dtype).

        Args:
            data: str or np.ndarray,
                the local value; it is what gets sent when
                ``rank == self.rank``.
            rank: int,
                the source rank of the broadcast.
        Returns:
            On the source rank, ``data`` itself; on other ranks, the
            received value as np.ndarray or str (matching the local type).
        Raises:
            ValueError: if ``data`` is neither str nor np.ndarray, or the
                array dtype is unsupported.
        """
        t = type(data)
        is_broadcast = rank == self.rank

        if t is np.ndarray:
            dtrans = data
            dt = data.dtype
            dtt = self._torch_dtype_for(dt)
        elif t is str:
            # Strings travel as an int64 array of code points.
            dtrans = np.array(
                [ord(c) for c in data],
                dtype=np.int64
            )
            dt = np.int64
            dtt = torch.int64
        else:
            raise ValueError(
                'cannot sync data of type {}'.format(t.__name__))

        # NOTE(review): tensors are moved to device index ``self.rank``;
        # presumably one device per rank on a single node -- confirm.
        if is_broadcast:
            ndim = torch.tensor(len(dtrans.shape)).long().to(self.rank)
            dist.broadcast(ndim, src=rank)

            shape = torch.tensor(list(dtrans.shape)).long().to(self.rank)
            dist.broadcast(shape, src=rank)

            payload = torch.tensor(dtrans, dtype=dtt).to(self.rank)
            dist.broadcast(payload, src=rank)
            return data

        # Receiving side: allocate buffers matching what the source sends.
        ndim = torch.tensor(0).long().to(self.rank)
        dist.broadcast(ndim, src=rank)
        ndim = ndim.item()

        shape = torch.zeros(ndim).long().to(self.rank)
        dist.broadcast(shape, src=rank)
        shape = shape.to('cpu').tolist()

        payload = torch.zeros(shape, dtype=dtt).to(self.rank)
        dist.broadcast(payload, src=rank)
        received = payload.to('cpu').numpy().astype(dt)

        if t is np.ndarray:
            return received
        # t is str here: rebuild the string from its code points.
        return ''.join([chr(c) for c in received])

    def zipzap_arrange(self, data):
        '''
        Order the data so it range like this:
            input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...]

        Items are taken round-robin across the inputs; an input that has
        run out of items is skipped, so inputs of different lengths are
        supported and every item appears exactly once.

        Args:
            data: list of lists, or list of np.ndarray that share their
                trailing dimensions (interleaving is along axis 0).
        Returns:
            A list (list input) or np.ndarray (array input).
        Raises:
            NotImplementedError: if the first entry is neither a list nor
                an np.ndarray.
        '''
        if isinstance(data[0], list):
            maxlen = max(len(i) for i in data)
            return [
                datai[idx]
                for idx in range(maxlen)
                for datai in data
                if idx < len(datai)
            ]

        elif isinstance(data[0], np.ndarray):
            lengths = [i.shape[0] for i in data]
            maxlen = max(lengths)
            datai_shape = data[0].shape[1:]
            # Pad shorter arrays to a common length so they can be stacked.
            padded = [
                np.concatenate(
                    [datai,
                     np.zeros((maxlen - datai.shape[0], *datai_shape),
                              dtype=datai.dtype)],
                    axis=0)
                if datai.shape[0] < maxlen else datai
                for datai in data
            ]
            # Stacking on axis 1 then flattening interleaves the inputs.
            stacked = np.stack(padded, axis=1).reshape(-1, *datai_shape)
            # Drop exactly the rows that came from padding.
            valid = np.arange(maxlen)[:, None] < np.asarray(lengths)[None, :]
            return stacked[valid.reshape(-1)]

        else:
            raise NotImplementedError

    def add_batch(self, **args):
        """Accumulate one batch of results; must be implemented by subclasses."""
        raise NotImplementedError

    def set_sample_n(self, sample_n):
        """Record the total number of samples in ``self.sample_n``."""
        self.sample_n = sample_n

    def compute(self):
        """Compute the metric; must be implemented by subclasses."""
        raise NotImplementedError

    def isbetter(self, old, new):
        """Return whether ``new`` beats ``old``; here, larger is better."""
        return new > old

    def one_line_summary(self):
        """Log a one-line summary; the base version logs a fixed header."""
        print_log('Evaluator display')

    def save(self, path):
        """
        Write ``self.final`` as JSON to ``<path>/result.json``.

        Args:
            path: str,
                the output directory; created if it does not exist.
        """
        # exist_ok avoids a check-then-create race when several
        # processes save to the same directory.
        os.makedirs(path, exist_ok=True)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        """Reset accumulated data; must be implemented by subclasses."""
        raise NotImplementedError
|
|
class compose(object):
    """Run several evaluators as one.

    Every call is forwarded to each evaluator in ``pipeline`` in order.
    Members are expected to expose ``symbol`` and ``final`` attributes,
    which ``compute`` uses to key and collect the results.
    """

    def __init__(self, pipeline):
        """
        Args:
            pipeline: list,
                the evaluators to run, in order.
        """
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        """Forward the batch, unchanged, to every evaluator."""
        for pi in self.pipeline:
            pi.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        """Record ``sample_n`` here and on every evaluator."""
        self.sample_n = sample_n
        for pi in self.pipeline:
            pi.set_sample_n(sample_n)

    def compute(self):
        """
        Compute every evaluator.

        Returns:
            dict mapping each evaluator's ``symbol`` to the value returned
            by its ``compute()``. Each evaluator's ``final`` is also stored
            in ``self.final`` under the same key.
        """
        rv = {}
        for pi in self.pipeline:
            rv[pi.symbol] = pi.compute()
            self.final[pi.symbol] = pi.final
        return rv

    def isbetter(self, old, new):
        """
        Majority vote over the evaluators' ``isbetter(old, new)``.

        The same ``old`` and ``new`` are passed to every evaluator.

        Returns:
            True when strictly more than half of the evaluators report an
            improvement; False otherwise, including for an empty pipeline.
        """
        votes = sum(1 for pi in self.pipeline if pi.isbetter(old, new))
        # Integer form of votes / len > 0.5; also safe when len is 0.
        return 2 * votes > len(self.pipeline)

    def one_line_summary(self):
        """Ask every evaluator to log its one-line summary."""
        for pi in self.pipeline:
            pi.one_line_summary()

    def save(self, path):
        """
        Write ``self.final`` as JSON to ``<path>/result.json``.

        Args:
            path: str,
                the output directory; created if it does not exist.
        """
        # exist_ok avoids a check-then-create race when several
        # processes save to the same directory.
        os.makedirs(path, exist_ok=True)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        """Ask every evaluator to drop its accumulated data."""
        for pi in self.pipeline:
            pi.clear_data()
|
|