| import torch |
| import numpy |
|
|
|
|
class BinaryMIoU:
    """Running pixel accuracy and per-class IoU for two-class segmentation.

    Every call folds the counts of one sample (or batch) into running
    totals and reports both the cumulative metrics and the metrics of the
    sample that was just passed in.

    Pixels whose target value equals ``ignore_index`` are excluded from
    every count.  Passing ``ignore_index=None`` disables the exclusion.

    NOTE(review): predictions are presumably hard class indices (0 or 1),
    not logits or probabilities -- nothing here applies a threshold or an
    argmax; confirm against the caller.
    """

    def __init__(self, ignore_index):
        self.num_classes = 2
        self.ignore_index = ignore_index
        # Running totals accumulated over every call so far.
        self.inter, self.union = 0, 0
        self.correct, self.label = 0, 0
        # Latest cumulative metrics; replaced on every update.
        self.iou = numpy.zeros(self.num_classes)
        self.acc = 0.0

    def get_metric_results(self, curr_correct_, curr_label_, curr_inter_, curr_union_):
        """Add the given counts to the running totals and return the metrics.

        Returns:
            ``(iou, acc)`` computed from the accumulated totals, both
            rounded to 4 decimals.  ``iou`` has one entry per class.
        """
        self.correct = self.correct + curr_correct_
        self.label = self.label + curr_label_
        self.inter = self.inter + curr_inter_
        self.union = self.union + curr_union_
        # numpy.spacing(1) is a tiny epsilon guarding against division by zero.
        self.acc = 1.0 * self.correct / (numpy.spacing(1) + self.label)
        self.iou = 1.0 * self.inter / (numpy.spacing(1) + self.union)
        return numpy.round(self.iou, 4), numpy.round(self.acc, 4)

    @staticmethod
    def get_current_image_results(curr_correct_, curr_label_, curr_inter_, curr_union_):
        """Return unrounded ``(iou, acc)`` for the given counts alone."""
        curr_acc = 1.0 * curr_correct_ / (numpy.spacing(1) + curr_label_)
        curr_iou = 1.0 * curr_inter_ / (numpy.spacing(1) + curr_union_)
        return curr_iou, curr_acc

    def __call__(self, x, y):
        """Update the running totals with prediction ``x`` and target ``y``.

        Returns:
            ``((cumulative_iou, cumulative_acc), (sample_iou, sample_acc))``.
        """
        curr_correct, curr_label, curr_inter, curr_union = self.calculate_current_sample(x, y)
        return (self.get_metric_results(curr_correct, curr_label, curr_inter, curr_union),
                self.get_current_image_results(curr_correct, curr_label, curr_inter, curr_union))

    def calculate_current_sample(self, output, target):
        """Count correct/labeled pixels and per-class intersection/union.

        ``target`` is left untouched: the ignore mask is applied to a
        converted copy.  Converting to a signed integer type first also
        keeps the ``-1`` marker representable when the target is stored
        in an unsigned dtype such as uint8.

        Returns:
            ``[correct, labeled, inter, union]`` as numpy values; ``inter``
            and ``union`` hold one entry per class.
        """
        labels = target.long()
        if self.ignore_index is not None:
            # masked_fill is out-of-place, so the caller's tensor is never
            # modified even when ``target`` was already int64.
            labels = labels.masked_fill(labels == self.ignore_index, -1)

        prediction = output.detach()
        correct, labeled = self.batch_pix_accuracy(prediction, labels)
        inter, union = self.batch_intersection_union(prediction, labels, self.num_classes)
        return [numpy.round(correct, 5), numpy.round(labeled, 5),
                numpy.round(inter, 5), numpy.round(union, 5)]

    @staticmethod
    def batch_pix_accuracy(predict, target):
        """Return ``(pixel_correct, pixel_labeled)`` as numpy values.

        Labels are shifted by one so that ignored pixels (``-1``) become
        ``0`` and can be filtered out with ``target > 0``.
        """
        predict = predict.int() + 1
        target = target.int() + 1

        labeled_mask = target > 0
        pixel_labeled = labeled_mask.sum()
        pixel_correct = ((predict == target) * labeled_mask).sum()
        assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
        return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy()

    @staticmethod
    def batch_intersection_union(predict, target, num_class):
        """Return per-class ``(area_inter, area_union)`` as numpy arrays.

        Both inputs are cast to integer labels, matching
        ``batch_pix_accuracy``, then shifted by one so ignored pixels
        (``-1``) land on ``0``, outside the histogram range
        ``[1, num_class]``.
        """
        predict = predict.long() + 1
        target = target.long() + 1

        # Zero out predictions on ignored pixels so they are not counted.
        predict = predict * (target > 0).long()
        # Keep the (shifted) class id only where prediction and target agree.
        intersection = predict * (predict == target).long()

        # One histogram bin per class; value 0 falls outside the range.
        area_inter = torch.histc(intersection.float(), bins=num_class, min=1, max=num_class)
        area_pred = torch.histc(predict.float(), bins=num_class, min=1, max=num_class)
        area_lab = torch.histc(target.float(), bins=num_class, min=1, max=num_class)
        area_union = area_pred + area_lab - area_inter
        assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
        return area_inter.cpu().numpy(), area_union.cpu().numpy()
|
|
|
|