import numpy
import torch


class AverageMeter:
    """Accumulate a running sum and sample count for each of a fixed set of keys.

    Internal state maps every key to a two-element list ``[running_sum, count]``.
    """

    def __init__(self, *keys):
        # Every key starts with an empty accumulator: sum 0.0, count 0.
        self.__data = {key: [0.0, 0] for key in keys}

    def add(self, dict):
        """Add one sample per entry of ``dict`` (key -> value) to that key's accumulator.

        Raises ``KeyError`` for a key that was not declared in ``__init__``.
        """
        for key, value in dict.items():
            entry = self.__data[key]
            entry[0] += value
            entry[1] += 1

    def get(self, *keys):
        """Return the mean (``sum / count``) of the requested key(s).

        A single key yields a single value; any other number of keys yields a
        tuple of values in the order requested. A key with no samples yet raises
        ``ZeroDivisionError``.
        """
        means = [self.__data[key][0] / self.__data[key][1] for key in keys]
        if len(keys) == 1:
            return means[0]
        return tuple(means)

    def get_entire_dict_for_ddp_calculation(self):
        """Return the internal ``{key: [sum, count]}`` dict itself (not a copy).

        NOTE(review): the name suggests this is reduced across processes by the
        caller (DDP); confirm against call sites.
        """
        return self.__data

    def pop(self, key=None):
        """Reset accumulators.

        With ``key`` given, return that key's current mean and then reset only it.
        With ``key=None``, reset every key and return ``None``.
        """
        if key is not None:
            current = self.get(key)
            self.__data[key] = [0.0, 0]
            return current
        for name in self.__data:
            self.__data[name] = [0.0, 0]


class ForegroundIoU(AverageMeter):
    """Running average of per-batch mean foreground IoU under the key ``'foreground_iou'``."""

    def __init__(self):
        super().__init__('foreground_iou')

    def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False):
        r"""Score one batch of hard masks and fold it into the running average.

        param (both hard mask):
            pred:   size [N x H x W], type: torch.long
            target: size [N x H x W], type: torch.long
            eps:    added to each per-sample union before dividing
            get_entire_list: if True, return the internal state dict instead of
                the running mean

        For a sample whose ``target`` sums to zero, the score is instead the
        number of pixels where ``(1 - target) * (1 - pred)`` is nonzero, divided
        by ``H * W`` (plus ``eps``). For binary masks, that is the fraction of
        pixels where both masks are empty.

        output:
            The running mean of the per-batch mean IoU over all calls since the
            last reset. If ``get_entire_list`` is True, the internal
            ``{'foreground_iou': [sum, count]}`` dict is returned instead.
        """
        assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.'
        assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.'

        batch_size = pred.size(0)
        pixel_count = pred.size(-2) * pred.size(-1)
        per_sample = (1, 2)

        empty_target = target.sum(dim=per_sample) == 0

        overlap = (pred * target).sum(dim=per_sample)
        covered = torch.max(pred, target).sum(dim=per_sample)
        both_empty = ((1 - target) * (1 - pred)).sum(dim=per_sample)

        # Samples with no foreground in the target are scored on their
        # background agreement over the whole image instead.
        overlap = torch.where(empty_target, both_empty, overlap)
        covered = covered.masked_fill(empty_target, pixel_count)

        batch_iou = (overlap / (covered + eps)).sum() / batch_size
        self.add({'foreground_iou': batch_iou})

        if get_entire_list:
            return self.get_entire_dict_for_ddp_calculation()
        return self.get('foreground_iou')

    def reset(self):
        """Discard all accumulated samples by re-initialising the meter state."""
        super().__init__('foreground_iou')