import torch
import numpy as np


# Thresholds used for the precision@IoU dictionaries returned by cal_seg_iou /
# cal_seg_iou2. np.arange yields keys such as 0.6000000000000001, so look
# entries up by iterating the dict (or this same np.arange), not by literals
# like 0.6.
_PREC_THRESHOLDS = np.arange(0.5, 1, 0.05)


def _fast_hist(label_true, label_pred, n_class):
    """Build the ``(n_class, n_class)`` confusion matrix for one label map.

    Rows index the ground-truth class and columns the predicted class.
    Pixels whose ground-truth label lies outside ``[0, n_class)`` are ignored.

    Args:
        label_true: numpy array of ground-truth class ids.
        label_pred: numpy array of predicted class ids, indexable by the same
            boolean mask as ``label_true``. Float arrays holding whole numbers
            are accepted.
        n_class: number of classes.

    Returns:
        Integer ``np.ndarray`` of shape ``(n_class, n_class)``.

    Raises:
        ValueError: if a counted pixel has a predicted id outside
            ``[0, n_class)``. Such a value would otherwise be folded silently
            into a neighbouring row of the matrix.
    """
    mask = (label_true >= 0) & (label_true < n_class)
    true_ids = label_true[mask].astype(int)
    # np.bincount rejects float input, so predictions are cast as well.
    pred_ids = label_pred[mask].astype(int)
    if pred_ids.size and (pred_ids.min() < 0 or pred_ids.max() >= n_class):
        raise ValueError(
            f"label_pred contains class ids outside [0, {n_class})")
    # Encode each (true, pred) pair as the flat index true * n_class + pred.
    hist = np.bincount(
        n_class * true_ids + pred_ids, minlength=n_class ** 2
    ).reshape(n_class, n_class)
    return hist


def _accumulate_hist(label_trues, label_preds, n_class, bg_thre):
    """Sum per-image confusion matrices, dropping pixels with gt >= bg_thre."""
    hist = np.zeros((n_class, n_class), dtype=float)
    # zip stops at the shorter of the two iterables.
    for lt, lp in zip(label_trues, label_preds):
        keep = lt < bg_thre
        hist += _fast_hist(lt[keep], lp[keep], n_class)
    return hist


def _scores_from_hist(hist):
    """Derive (acc, acc_cls, mean_iu, fwavacc, iu) from a confusion matrix."""
    tp = np.diag(hist)
    gt_count = hist.sum(axis=1)
    pred_count = hist.sum(axis=0)
    total = hist.sum()
    # A class absent from both gt and prediction gives 0/0 = NaN. The warning
    # is silenced and np.nanmean skips those classes.
    with np.errstate(divide="ignore", invalid="ignore"):
        acc = tp.sum() / total
        acc_cls = np.nanmean(tp / gt_count)
        iu = tp / (gt_count + pred_count - tp)
        mean_iu = np.nanmean(iu)
        freq = gt_count / total
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc, iu


def label_accuracy_score(label_trues, label_preds, n_class, bg_thre=200):
    """Returns accuracy score evaluation result.

      - overall accuracy
      - mean accuracy
      - mean IU
      - fwavacc

    Args:
        label_trues: iterable of numpy ground-truth label arrays.
        label_preds: iterable of numpy predicted label arrays, paired with
            ``label_trues`` element by element.
        n_class: number of classes.
        bg_thre: pixels whose ground-truth label is ``>= bg_thre`` are
            excluded. Labels ``>= n_class`` are dropped regardless, so this
            only has an extra effect when ``bg_thre < n_class``.

    Returns:
        Tuple ``(acc, acc_cls, mean_iu, fwavacc)``.
    """
    acc, acc_cls, mean_iu, fwavacc, _ = label_confusion_matrix(
        label_trues, label_preds, n_class, bg_thre)
    return acc, acc_cls, mean_iu, fwavacc


def label_confusion_matrix(label_trues, label_preds, n_class, bg_thre=200):
    """Return accuracy scores plus the per-class IoU vector.

    Despite the name, this returns scores derived from the confusion matrix,
    not the matrix itself.

    Args:
        label_trues: iterable of numpy ground-truth label arrays. An earlier
            note listed the shapes (8, 256, 256) and (256, 256); presumably
            a batch of label maps and one map per iteration. Confirm against
            callers.
        label_preds: iterable of numpy predicted label arrays, paired with
            ``label_trues`` element by element.
        n_class: number of classes.
        bg_thre: pixels whose ground-truth label is ``>= bg_thre`` are
            excluded.

    Returns:
        Tuple ``(acc, acc_cls, mean_iu, fwavacc, iu)``, where ``iu`` is a
        length-``n_class`` array that is NaN for classes never seen in either
        gt or prediction.
    """
    hist = _accumulate_hist(label_trues, label_preds, n_class, bg_thre)
    return _scores_from_hist(hist)


def cal_seg_iou_loss(gt, pred, trsh=0.5):
    """Per-sample, per-channel IoU of ``pred > trsh`` against ``gt > 0``.

    Counts are summed over axes 2 and 3, so both inputs must be at least 4-D
    (presumably (N, C, H, W); confirm with callers). Despite the name, the
    returned values are IoUs, not a loss. The 1e-10 smoothing makes an empty
    prediction on an empty target score 1.0.

    Returns:
        Array of IoUs with the summed axes removed.
    """
    pred_mask = np.asarray(pred > trsh)
    gt_mask = np.asarray(gt > 0.)
    n_inter = np.sum(np.logical_and(pred_mask, gt_mask), axis=(2, 3))
    n_union = np.sum(np.logical_or(pred_mask, gt_mask), axis=(2, 3))
    return (n_inter + 1e-10) / (n_union + 1e-10)


def cal_seg_iou(gt, pred, trsh=0.5):
    """IoU of one mask pair plus precision@IoU flags.

    ``gt`` and ``pred`` only need to broadcast; e.g. gt (1, 428, 640) against
    pred (428, 640).

    Returns:
        ``(iou, prec)``, where ``prec`` maps each threshold in
        ``np.arange(0.5, 1, 0.05)`` to 1.0 if ``iou > threshold`` else 0.0.
    """
    iou, prec, _, _ = cal_seg_iou2(gt, pred, trsh)
    return iou, prec


def cal_seg_iou2(gt, pred, trsh=0.5):
    """Same as ``cal_seg_iou`` but also returns the raw pixel counts.

    The IoU is computed over all elements of ``pred > trsh`` and ``gt > 0``
    (after broadcasting). The 1e-10 smoothing makes two empty masks score 1.0.

    Returns:
        ``(iou, prec, n_intersection, n_union)``.
    """
    pred_mask = np.asarray(pred > trsh)
    gt_mask = np.asarray(gt > 0.)
    n_inter = np.sum(np.logical_and(pred_mask, gt_mask))
    n_union = np.sum(np.logical_or(pred_mask, gt_mask))
    iou = (n_inter + 1e-10) / (n_union + 1e-10)
    prec = {thresh: float(iou > thresh) for thresh in _PREC_THRESHOLDS}
    return iou, prec, n_inter, n_union