# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import torch.nn.functional as F
from prettytable import PrettyTable

from sapiens.engine.evaluators import BaseEvaluator
from sapiens.registry import MODELS

from ..datasets.seg.seg_dome_dataset import DOME_CLASSES_29


class SegEvaluator(BaseEvaluator):
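    """Distributed semantic-segmentation evaluator.

    `process` accumulates per-image intersection/union histograms and reduces
    them across ranks at every step; `evaluate` turns the accumulated totals
    into mIoU/mDice/mFscore-style metrics on the main process.
    """
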
    def __init__(
        self,
        class_names="dome29",
        ignore_index: int = 255,
        iou_metrics: Sequence[str] = ("mIoU",),
        nan_to_num: Optional[int] = None,
        beta: int = 1,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.class_names = (
            self.extract_class(DOME_CLASSES_29) if class_names == "dome29" else None
        )
        self.metrics = list(iou_metrics)
        self.nan_to_num = nan_to_num
        self.beta = beta

    def extract_class(self, class_names):
        return [class_info["name"] for _, class_info in class_names.items()]

    def process(self, pred_logits, data_samples: dict, accelerator=None):
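        """Accumulate intersection/union areas for one batch.

        `pred_logits` is a (B, C, H, W) tensor; each entry of `data_samples`
        carries the matching ground-truth mask under ``gt_seg``. Per-image
        areas are packed and gathered across ranks so that only the main
        process stores the global per-step totals in ``self.results``.
        """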
        assert accelerator is not None, "evaluation process expects an accelerator"
        num_classes = pred_logits.shape[1]
        ai_list, au_list, apl_list, al_list = [], [], [], []
        for i in range(len(pred_logits)):
            pred_logit = pred_logits[i]  # C x H x W
            gt_label = data_samples[i]["gt_seg"].squeeze()  # H x W
            if pred_logit.shape[1:] != gt_label.shape:
                pred_logit = F.interpolate(
                    input=pred_logit.unsqueeze(0),
                    size=gt_label.shape,
                    mode="bilinear",
                    align_corners=False,
                    antialias=False,
                ).squeeze(0)
            pred_label = pred_logit.argmax(dim=0)  # H x W
            a_i, a_u, a_pl, a_l = self.intersect_and_union(
                pred_label, gt_label, num_classes, self.ignore_index
            )
            ai_list.append(a_i)
            au_list.append(a_u)
            apl_list.append(a_pl)
            al_list.append(a_l)

        # Local per-batch tensors: (B_local, C)
        ai = torch.stack(ai_list, dim=0)
        au = torch.stack(au_list, dim=0)
        apl = torch.stack(apl_list, dim=0)
        al = torch.stack(al_list, dim=0)

        # Pack as (B_local, 4, C) so gather concatenates along the batch dim.
        pack = torch.stack([ai, au, apl, al], dim=1)  # (B_local, 4, C)
        gpack = accelerator.gather_for_metrics(pack)  # (B_global_this_step, 4, C)
        batch_tot = gpack.sum(dim=0)  # (4, C) global for this step
        ai_g, au_g, apl_g, al_g = batch_tot[0], batch_tot[1], batch_tot[2], batch_tot[3]

        # Only rank-0 appends real totals for this batch
        if accelerator.is_main_process:
            self.results.append((ai_g, au_g, apl_g, al_g))

    def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]:
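        """Reduce accumulated areas into summary metrics on the main process.

        Non-main ranks reset their state and return an empty dict; rank 0
        sums the per-step totals, derives the requested metrics, optionally
        logs a per-class table, and returns the percentage summary.
        """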
        assert accelerator is not None, "evaluation aggregation expects an accelerator"
        if not accelerator.is_main_process:
            self.reset()
            return {}
        if not self.results:
            if logger is not None:
                logger.info("No results to evaluate.")
            return {}

        per_field = list(zip(*self.results))  # [(ai_b), (au_b), (apl_b), (al_b)]
        totals = [torch.stack(x, dim=0).sum(dim=0) for x in per_field]
        (
            total_area_intersect,
            total_area_union,
            total_area_pred_label,
            total_area_label,
        ) = totals  # tensors already reduced across ranks
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect,
            total_area_union,
            total_area_pred_label,
            total_area_label,
            self.metrics,
            self.nan_to_num,
            self.beta,
        )
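        # nanmean ignores NaN entries (classes never seen in this split),
        # unless nan_to_num has already replaced them upstream.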
        ret_metrics_summary = OrderedDict(
            {
                ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
                for ret_metric, ret_metric_value in ret_metrics.items()
            }
        )
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            if key == "aAcc":
                metrics[key] = val
            else:
                metrics["m" + key] = val

        # Per-class table
        ret_metrics.pop("aAcc", None)
        ret_metrics_class = OrderedDict(
            {
                ret_metric: np.round(ret_metric_value * 100, 2)
                for ret_metric, ret_metric_value in ret_metrics.items()
            }
        )
        if self.class_names is not None:
            ret_metrics_class.update({"Class": self.class_names})
            ret_metrics_class.move_to_end("Class", last=False)
        if logger is not None:
            class_table_data = PrettyTable()
            for key, val in ret_metrics_class.items():
                class_table_data.add_column(key, val)
            logger.info("\n" + class_table_data.get_string())

        self.reset()
        return metrics

    def intersect_and_union(
        self,
        pred_label: torch.Tensor,
        label: torch.Tensor,
        num_classes: int,
        ignore_index: int,
    ):
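        """Compute per-class intersection/union areas for one prediction.

        Pixels labelled `ignore_index` are masked out first; `torch.histc`
        then counts, per class, the correctly predicted pixels, all predicted
        pixels, and all ground-truth pixels, from which the union follows.
        """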
        mask = label != ignore_index
        pred_label = pred_label[mask]
        label = label[mask]
        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=num_classes, min=0, max=num_classes - 1
        )
        area_pred_label = torch.histc(
            pred_label.float(), bins=num_classes, min=0, max=num_classes - 1
        )
        area_label = torch.histc(
            label.float(), bins=num_classes, min=0, max=num_classes - 1
        )
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    def total_area_to_metrics(
        self,
        total_area_intersect: torch.Tensor,
        total_area_union: torch.Tensor,
        total_area_pred_label: torch.Tensor,
        total_area_label: torch.Tensor,
        metrics: Sequence[str] = ("mIoU",),
        nan_to_num: Optional[int] = None,
        beta: int = 1,
    ):
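        """Convert accumulated per-class areas into metric arrays.

        Returns an OrderedDict mapping metric names ("aAcc", "IoU", "Acc",
        "Dice", "Fscore", ...) to per-class numpy arrays (aAcc is a scalar);
        NaNs from empty classes are optionally replaced via `nan_to_num`.
        """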
        def f_score(precision, recall, beta=1):
            score = (
                (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)
            )
            return score

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ["mIoU", "mDice", "mFscore"]
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f"metrics {metrics} is not supported")

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({"aAcc": all_acc})
        for metric in metrics:
            if metric == "mIoU":
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics["IoU"] = iou
                ret_metrics["Acc"] = acc
            elif metric == "mDice":
                dice = (
                    2
                    * total_area_intersect
                    / (total_area_pred_label + total_area_label)
                )
                acc = total_area_intersect / total_area_label
                ret_metrics["Dice"] = dice
                ret_metrics["Acc"] = acc
            elif metric == "mFscore":
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor(
                    [f_score(x[0], x[1], beta) for x in zip(precision, recall)]
                )
                ret_metrics["Fscore"] = f_value
                ret_metrics["Precision"] = precision
                ret_metrics["Recall"] = recall

        ret_metrics = {
            metric: (
                value.detach().cpu().numpy()
                if isinstance(value, torch.Tensor)
                else value
            )
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict(
                {
                    metric: np.nan_to_num(metric_value, nan=nan_to_num)
                    for metric, metric_value in ret_metrics.items()
                }
            )
        return ret_metrics
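

if __name__ == "__main__":
    # Minimal single-process smoke test -- a sketch, not part of the original
    # module. It assumes `accelerate` is installed and that this file is run
    # as a module (e.g. `python -m <package>.seg_evaluator`, path hypothetical)
    # so the relative import of DOME_CLASSES_29 resolves. Shapes are arbitrary.
    import logging

    from accelerate import Accelerator

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("seg_evaluator_demo")

    accelerator = Accelerator()
    evaluator = SegEvaluator(class_names="dome29", iou_metrics=["mIoU"])

    num_classes = 29  # assumed to match the dome29 class list
    pred_logits = torch.randn(2, num_classes, 64, 64)  # B x C x H x W
    data_samples = [
        {"gt_seg": torch.randint(0, num_classes, (1, 128, 128))} for _ in range(2)
    ]

    # One `process` call per batch, then a single `evaluate` at the end.
    evaluator.process(pred_logits, data_samples, accelerator=accelerator)
    print(evaluator.evaluate(logger=logger, accelerator=accelerator))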