# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence
import numpy as np
import torch
import torch.nn.functional as F
from prettytable import PrettyTable
from sapiens.engine.evaluators import BaseEvaluator
from sapiens.registry import MODELS
from ..datasets.seg.seg_dome_dataset import DOME_CLASSES_29


@MODELS.register_module()
class SegEvaluator(BaseEvaluator):
    """Distributed semantic-segmentation evaluator.

    Accumulates per-image intersection/union pixel counts, reduces them across
    ranks through the accelerator, and reports mIoU/mDice/mFscore summaries.
    """

def __init__(
self,
class_names="dome29",
ignore_index: int = 255,
iou_metrics: List[str] = ["mIoU"],
nan_to_num: Optional[int] = None,
beta: int = 1,
):
super().__init__()
self.ignore_index = ignore_index
self.class_names = (
self.extract_class(DOME_CLASSES_29) if class_names == "dome29" else None
)
self.metrics = iou_metrics
self.nan_to_num = nan_to_num
self.beta = beta
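
    # Construction sketch (assumption, not from this file): with the registry
    # pattern used here, the evaluator is typically built from a config dict,
    # e.g. MODELS.build(dict(type="SegEvaluator", iou_metrics=["mIoU"])).
    # The dict keys mirror the __init__ arguments above.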

    def extract_class(self, class_names):
        # Pull the human-readable name of each class from the dataset metadata.
        return [class_info["name"] for class_info in class_names.values()]

    @torch.no_grad()
    def process(self, pred_logits, data_samples: Sequence[dict], accelerator=None):
        """Accumulate per-class IoU statistics for one batch across all ranks."""
        assert accelerator is not None, "evaluation process expects an accelerator"
num_classes = pred_logits.shape[1]
ai_list, au_list, apl_list, al_list = [], [], [], []
for i in range(len(pred_logits)):
pred_logit = pred_logits[i] # C x H x W
gt_label = data_samples[i]["gt_seg"].squeeze() # H x W
            if pred_logit.shape[1:] != gt_label.shape:  # compare spatial dims (H, W)
pred_logit = F.interpolate(
input=pred_logit.unsqueeze(0),
size=gt_label.shape,
mode="bilinear",
align_corners=False,
antialias=False,
).squeeze(0)
pred_label = pred_logit.argmax(dim=0) # H x W
a_i, a_u, a_pl, a_l = self.intersect_and_union(
pred_label, gt_label, num_classes, self.ignore_index
)
ai_list.append(a_i)
au_list.append(a_u)
apl_list.append(a_pl)
al_list.append(a_l)
# Local per-batch tensors: (B_local, C)
ai = torch.stack(ai_list, dim=0)
au = torch.stack(au_list, dim=0)
apl = torch.stack(apl_list, dim=0)
al = torch.stack(al_list, dim=0)
# Pack as (B_local, 4, C) so gather concatenates along the batch dim.
pack = torch.stack([ai, au, apl, al], dim=1) # (B_local, 4, C)
gpack = accelerator.gather_for_metrics(pack) # (B_global_this_step, 4, C)
batch_tot = gpack.sum(dim=0) # (4, C) global for this step
ai_g, au_g, apl_g, al_g = batch_tot[0], batch_tot[1], batch_tot[2], batch_tot[3]
# Only rank-0 appends real totals for this batch
if accelerator.is_main_process:
self.results.append((ai_g, au_g, apl_g, al_g))
return
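
    # Shape walk-through (illustrative): with 2 processes, B_local = 4 and
    # C = 29, `pack` is (4, 4, 29) on each rank; gather_for_metrics
    # concatenates along dim 0 into (8, 4, 29), and summing over dim 0 yields
    # the (4, 29) global per-class totals for this step.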

    def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]:
        """Reduce the accumulated statistics into summary metrics on rank 0."""
        assert accelerator is not None, "evaluation aggregation expects an accelerator"
if not accelerator.is_main_process:
self.reset()
return {}
if not self.results:
if logger is not None:
logger.info("No results to evaluate.")
return {}
        per_field = list(zip(*self.results))  # 4 tuples of per-batch tensors: (ai, au, apl, al)
totals = [torch.stack(x, dim=0).sum(dim=0) for x in per_field]
(
total_area_intersect,
total_area_union,
total_area_pred_label,
total_area_label,
) = totals # tensors already reduced across ranks
ret_metrics = self.total_area_to_metrics(
total_area_intersect,
total_area_union,
total_area_pred_label,
total_area_label,
self.metrics,
self.nan_to_num,
self.beta,
)
ret_metrics_summary = OrderedDict(
{
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
}
)
metrics = dict()
for key, val in ret_metrics_summary.items():
if key == "aAcc":
metrics[key] = val
else:
metrics["m" + key] = val
        # Per-class table (aAcc is a scalar, so drop it from the per-class view).
        ret_metrics.pop("aAcc", None)
ret_metrics_class = OrderedDict(
{
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
}
)
if self.class_names is not None:
ret_metrics_class.update({"Class": self.class_names})
ret_metrics_class.move_to_end("Class", last=False)
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
        if logger is not None:
            logger.info("\n" + class_table_data.get_string())
self.reset()
return metrics

    def intersect_and_union(
        self,
        pred_label: torch.Tensor,
        label: torch.Tensor,
        num_classes: int,
        ignore_index: int,
    ):
        """Per-class intersection/union/pred/label pixel counts for one image."""
mask = label != ignore_index
pred_label = pred_label[mask]
label = label[mask]
intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=num_classes, min=0, max=num_classes - 1
        )
        area_pred_label = torch.histc(
            pred_label.float(), bins=num_classes, min=0, max=num_classes - 1
        )
        area_label = torch.histc(
            label.float(), bins=num_classes, min=0, max=num_classes - 1
        )
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label
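
    # Worked example (illustrative): num_classes = 3, after masking
    #   pred_label = [0, 1, 1, 2], label = [0, 1, 2, 2]
    # -> area_intersect = [1, 1, 1], area_pred_label = [1, 2, 1],
    #    area_label = [1, 1, 2], area_union = [1, 2, 2],
    # so per-class IoU would be [1.0, 0.5, 0.5] once totals are accumulated.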

    def total_area_to_metrics(
        self,
        total_area_intersect: torch.Tensor,
        total_area_union: torch.Tensor,
        total_area_pred_label: torch.Tensor,
        total_area_label: torch.Tensor,
        metrics: List[str] = ["mIoU"],
        nan_to_num: Optional[int] = None,
        beta: int = 1,
    ):
        """Convert accumulated pixel counts into per-class metric arrays."""
        def f_score(precision, recall, beta=1):
            # F-beta score: weights recall `beta` times as much as precision.
            return (
                (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)
            )
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ["mIoU", "mDice", "mFscore"]
if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f"metrics {metrics} are not supported")
all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({"aAcc": all_acc})
for metric in metrics:
if metric == "mIoU":
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics["IoU"] = iou
ret_metrics["Acc"] = acc
elif metric == "mDice":
dice = (
2
* total_area_intersect
/ (total_area_pred_label + total_area_label)
)
acc = total_area_intersect / total_area_label
ret_metrics["Dice"] = dice
ret_metrics["Acc"] = acc
elif metric == "mFscore":
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor(
[f_score(x[0], x[1], beta) for x in zip(precision, recall)]
)
ret_metrics["Fscore"] = f_value
ret_metrics["Precision"] = precision
ret_metrics["Recall"] = recall
ret_metrics = {
metric: (
value.detach().cpu().numpy()
if isinstance(value, torch.Tensor)
else value
)
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict(
{
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
}
)
return ret_metrics
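

# Minimal single-process smoke test (a sketch, not part of the training
# pipeline). `_StubAccelerator` is an illustrative stand-in for a real
# `accelerate.Accelerator`: with one process there is nothing to gather, so
# `gather_for_metrics` is the identity and this rank is the main process.
if __name__ == "__main__":
    import logging

    class _StubAccelerator:
        is_main_process = True

        def gather_for_metrics(self, tensor):
            return tensor

    logging.basicConfig(level=logging.INFO)
    evaluator = SegEvaluator(class_names="dome29", iou_metrics=["mIoU"])
    num_classes = len(evaluator.class_names)  # 29 for the dome29 label set
    pred_logits = torch.randn(2, num_classes, 64, 64)  # B x C x H x W
    data_samples = [
        {"gt_seg": torch.randint(0, num_classes, (64, 64))} for _ in range(2)
    ]
    stub = _StubAccelerator()
    evaluator.process(pred_logits, data_samples, accelerator=stub)
    print(evaluator.evaluate(logger=logging.getLogger(__name__), accelerator=stub))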