| import os |
| import sys |
|
|
| import torch |
| import torchvision |
| from fvcore.nn import FlopCountAnalysis |
| from torch import nn |
|
|
| sys.path.append("vision/references/segmentation") |
| from transforms import Compose |
| from coco_utils import ConvertCocoPolysToMask |
| from coco_utils import FilterAndRemapCocoCategories |
| from coco_utils import _coco_remove_images_without_annotations |
| from utils import ConfusionMatrix |
|
|
|
|
class NanSafeConfusionMatrix(ConfusionMatrix):
    """Confusion matrix whose summary metrics report 0 wherever they would be NaN.

    A NaN appears when a ratio's denominator is zero, for example a class that has
    no entries in its row (or in both its row and column, for IoU). ``torch.nan_to_num``
    replaces those NaNs with 0.0 so that downstream averaging does not propagate NaN.
    """

    def __init__(self, num_classes):
        # Thin pass-through; all state (including ``self.mat``) is set up by the base class.
        super().__init__(num_classes=num_classes)

    def compute(self):
        """Return ``(acc_global, acc, intersection_over_unions)`` from ``self.mat``.

        NOTE(review): assumes rows index ground truth and columns index predictions,
        as in the torchvision reference ``ConfusionMatrix``. Confirm this against ``update``.

        Returns:
            acc_global: scalar tensor, trace / total count.
            acc: 1-D tensor, per-row diagonal / row total.
            intersection_over_unions: 1-D tensor, diagonal / (row + column - diagonal).
        """
        counts = self.mat.float()
        hits = torch.diag(counts)
        row_totals = counts.sum(1)
        col_totals = counts.sum(0)

        overall = hits.sum() / counts.sum()
        per_class = hits / row_totals
        union = row_totals + col_totals - hits
        iou = hits / union

        # Zero-denominator ratios come out as 0/0 = NaN; report those as 0.
        return torch.nan_to_num(overall), torch.nan_to_num(per_class), torch.nan_to_num(iou)
|
|
|
|
def flops_calculation_function(model: nn.Module, input_sample: torch.Tensor) -> float:
    """Return the per-sample FLOP count of ``model`` on ``input_sample``, in millions.

    The count comes from fvcore's ``FlopCountAnalysis``, so the unit follows fvcore's
    own counting convention. The total is divided by ``input_sample.shape[0]``, which
    is taken to be the batch dimension.

    Side effect: ``model.eval()`` is called and the model is NOT switched back to its
    previous mode afterwards.

    Args:
        model: Module to analyse.
        input_sample: Batched input tensor passed to the model during tracing.

    Returns:
        FLOPs per sample divided by 1e6.
    """
    eval_model = model.eval()
    analysis = FlopCountAnalysis(model=eval_model, inputs=input_sample)

    # Silence fvcore's warnings about operators it cannot count and about
    # submodules that were never invoked during the trace.
    analysis.unsupported_ops_warnings(False)
    analysis.uncalled_modules_warnings(False)

    total_flops = analysis.total()
    per_sample = total_flops / input_sample.shape[0]
    return per_sample / 1e6
|
|
|
|
def get_coco(root, image_set, transforms, use_v2=False, use_orig=False):
    """Build a COCO 2017 segmentation dataset labelled with VOC or COCO classes.

    Args:
        root: Directory containing the ``train2017`` / ``val2017`` image folders and
            the ``annotations`` folder.
        image_set: ``"train"`` or ``"val"``. Any other value raises ``KeyError``.
        transforms: Joint image/target transform applied after the COCO-to-segmentation
            conversion.
        use_v2: Build the sample pipeline with ``v2_extras.CocoDetectionToVOCSegmentation``
            and ``wrap_dataset_for_transforms_v2``. This path does not use the category
            list built here.
        use_orig: Use category ids ``0..80`` (``list(range(81))``) instead of the 21-entry
            VOC subset. Only supported on the non-v2 path.

    Returns:
        A ``torchvision.datasets.CocoDetection`` (wrapped for transforms v2 when
        ``use_v2``). For ``"train"``, the dataset is additionally passed through
        ``_coco_remove_images_without_annotations`` with the selected category list.

    Raises:
        ValueError: If ``use_v2`` and ``use_orig`` are both set. The v2 path ignores the
            category list, so ``use_orig`` would be silently dropped and the train split
            would be filtered with a list that does not match the produced labels.
        KeyError: If ``image_set`` is not ``"train"`` or ``"val"``.
    """
    if use_v2 and use_orig:
        raise ValueError(
            "use_orig=True is not supported with use_v2=True: the v2 pipeline "
            "always converts targets with CocoDetectionToVOCSegmentation."
        )

    paths = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    if use_orig:
        # NOTE(review): range(81) keeps category ids 0..80 with an identity remap.
        # COCO category ids are presumably the standard non-contiguous 1..90 set, in
        # which case ids 81..90 are filtered out and unused ids occupy label slots.
        # Confirm this matches the label space of the models being evaluated.
        classes_list = list(range(81))
    else:
        # Index 0 is background; remaining entries are COCO category ids remapped to
        # contiguous labels 1..20 by FilterAndRemapCocoCategories(remap=True).
        classes_list = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    img_folder, ann_file = paths[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    # Both branches convert COCO detection samples into segmentation samples, each
    # through its own transform stack.
    if use_v2:
        import v2_extras
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
    else:
        transforms = Compose(
            [FilterAndRemapCocoCategories(classes_list, remap=True), ConvertCocoPolysToMask(), transforms]
        )
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, classes_list)

    return dataset
|
|