| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, List, Union |
|
|
| import numpy as np |
| import torch |
| from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector |
| from monai.inferers.inferer import Inferer |
| from torch import Tensor |
|
|
|
|
class RetinaNetInferer(Inferer):
    """
    RetinaNet Inferer: runs detection inference through a ``RetinaNetDetector``.

    For each call it either forwards the network directly on the whole input or asks the
    detector to use its sliding-window inferer. The choice depends on
    ``force_sliding_window`` and on the input size.

    Args:
        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
            map into boxes and classification scores.
        force_sliding_window: whether to force using a SlidingWindowInferer to do the inference.
            If False, will check the input spatial size to decide whether to simply
            forward the network or using SlidingWindowInferer.
            If True, will force using SlidingWindowInferer to do the inference.

    Any extra positional and keyword arguments given to :meth:`__call__` (after ``inputs``
    and ``network``) are forwarded to ``detector``.
    """

    def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None:
        super().__init__()
        self.detector = detector
        self.force_sliding_window = force_sliding_window

        # Voxel count of one sliding window (product of the detector inferer's ``roi_size``),
        # or None when the detector has no inferer exposing ``roi_size``.
        # It is computed only here, so replacing ``detector.inferer`` after construction
        # is not reflected in the size check done in ``__call__``.
        self.sliding_window_size = None
        detector_inferer = self.detector.inferer
        if detector_inferer is not None and hasattr(detector_inferer, "roi_size"):
            self.sliding_window_size = np.prod(detector_inferer.roi_size)

    def __call__(
        self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> Any:
        """Unified callable function API of Inferers.

        Note that this mutates ``self.detector``:
        - its ``network`` is replaced by ``network``;
        - its ``training`` flag is set to match the network's.

        Args:
            inputs: model input data for inference, either a Tensor or a list of Tensors.
                Each item obtained by iterating ``inputs`` is indexed as ``data_i[0, ...]``
                before its voxels are counted. Presumably this drops the channel axis;
                confirm against the detector's expected layout.
            network: target detection network to execute inference.
                Supports a callable that fulfills the requirements of ``network`` in
                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
            args: optional positional args to be passed to ``self.detector``.
            kwargs: optional keyword args to be passed to ``self.detector``. Must not contain
                ``use_inferer``, which this method always supplies itself.

        Returns:
            Whatever ``self.detector(...)`` returns.
        """
        # Plug the caller's network into the detector and mirror its train/eval flag.
        self.detector.network = network
        self.detector.training = self.detector.network.training

        # If every image has fewer voxels than one sliding window, forward the network directly.
        # Otherwise (or when forced), use the sliding-window inferer.
        # The generator lets all() stop at the first image that is not smaller than the window.
        use_inferer = self.force_sliding_window or (
            self.sliding_window_size is not None
            and not all(data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs)
        )

        return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)
|
|