AuralSAM2 / avs.code /v1s.code /utils /foreground_iou.py
yyliu01's picture
Upload folder using huggingface_hub
c6dfc69 verified
import numpy
import torch
class AverageMeter:
    """Track running totals and sample counts for a fixed set of named metrics.

    Every key passed to the constructor owns a two-element list
    ``[total, count]``. ``add`` grows the total and bumps the count,
    ``get`` reports ``total / count``, and ``pop`` reads and/or clears slots.
    """

    def __init__(self, *keys):
        """Create one zeroed ``[total, count]`` slot for each name in ``keys``."""
        self.__data = {name: [0.0, 0] for name in keys}

    def add(self, dict):
        """Record one sample per entry of the given mapping.

        Args:
            dict: mapping of metric name -> value. The parameter name shadows
                the builtin but is kept so keyword callers keep working.

        Raises:
            KeyError: if a name was not registered at construction time.
        """
        for name, value in dict.items():
            slot = self.__data[name]
            slot[0] += value
            slot[1] += 1

    def get(self, *keys):
        """Return the mean (total / count) of the requested metrics.

        A single key yields the bare value; any other number of keys
        (including none) yields a tuple in the order requested. A slot that
        has no samples divides by a zero count.
        """
        means = tuple(
            self.__data[name][0] / self.__data[name][1] for name in keys
        )
        return means[0] if len(keys) == 1 else means

    def get_entire_dict_for_ddp_calculation(self):
        """Return the internal ``{name: [total, count]}`` mapping itself (not a copy)."""
        return self.__data

    def pop(self, key=None):
        """Clear one slot and return its mean, or clear every slot.

        Args:
            key: metric name to read and clear; ``None`` clears all metrics.

        Returns:
            The mean of ``key`` before clearing, or ``None`` when ``key`` is
            ``None``.
        """
        if key is not None:
            # Read first so a missing key or empty slot fails before any reset.
            mean = self.get(key)
            self.__data[key] = [0.0, 0]
            return mean
        for name in list(self.__data):
            self.__data[name] = [0.0, 0]
        return None
class ForegroundIoU(AverageMeter):
    """Running average of per-batch foreground IoU under the key ``'foreground_iou'``."""

    def __init__(self):
        super().__init__('foreground_iou')

    def calculate_iou(self, pred, target, eps=1e-7, get_entire_list=False):
        r"""Score one batch of masks and fold the batch IoU into the running average.

        For a sample whose target sums to zero, the score is the count of
        positions where ``(1 - target) * (1 - pred)`` is non-zero divided by
        the number of pixels, instead of intersection / union.

        Args:
            pred: ``torch.long`` tensor with three dimensions (documented
                upstream as ``[N x H x W]`` hard masks; the ``1 - mask``
                arithmetic presumes values in ``{0, 1}`` -- not validated).
            target: ``torch.long`` tensor with the same shape as ``pred``.
            eps: small constant added to the denominator before dividing.
            get_entire_list: when true, return the raw accumulator mapping
                instead of the running mean.

        Returns:
            The mean of all batch IoUs recorded since construction or the last
            reset, or the internal ``{name: [total, count]}`` mapping when
            ``get_entire_list`` is true.

        Raises:
            AssertionError: if the inputs are not 3-D, differ in shape, or are
                not both ``torch.long``.
        """
        assert len(pred.shape) == 3 and pred.shape == target.shape, 'shape mismatch.'
        assert pred.dtype is torch.long and target.dtype is torch.long, 'type mismatch.'

        batch_size = pred.size(0)
        pixels_per_mask = pred.size(-1) * pred.size(-2)
        spatial_dims = (1, 2)

        # Samples whose ground truth contains no foreground at all.
        target_is_empty = target.sum(dim=spatial_dims) == 0

        overlap = (pred * target).sum(dim=spatial_dims)
        combined = torch.maximum(pred, target).sum(dim=spatial_dims)
        background_agreement = ((1 - target) * (1 - pred)).sum(dim=spatial_dims)

        # Empty-target samples are scored on background agreement over all pixels.
        numerator = torch.where(target_is_empty, background_agreement, overlap)
        denominator = torch.where(
            target_is_empty,
            torch.full_like(combined, pixels_per_mask),
            combined,
        )

        batch_iou = (numerator / (denominator + eps)).sum() / batch_size
        self.add({'foreground_iou': batch_iou})

        if get_entire_list:
            return self.get_entire_dict_for_ddp_calculation()
        return self.get('foreground_iou')

    def reset(self):
        """Discard all accumulated state by re-running the base initialiser."""
        super().__init__('foreground_iou')