| |
| from typing import Sequence, Tuple |
|
|
| import torch |
| from mmcv.ops import batched_nms |
| from mmengine.structures import InstanceData |
|
|
| from mmdet.structures import DetDataSample, SampleList |
|
|
|
|
def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]):
    """Translate rotated boxes by a 2-D offset.

    Only the centre coordinates ``(x, y)`` move; width, height and angle are
    left untouched. The input tensor itself is not modified.

    Args:
        bboxes (Tensor): Rotated boxes of shape (n, 5), laid out as
            (x, y, w, h, a).
        offset (Sequence[int]): The ``(dx, dy)`` translation to add to the
            box centres.

    Returns:
        Tensor: A new tensor holding the translated boxes.
    """
    result = bboxes.clone()
    # Build the offset with the boxes' dtype/device so the add broadcasts
    # over the first two columns without any implicit conversion.
    delta = bboxes.new_tensor(offset)
    result[:, :2] = result[:, :2] + delta
    return result
|
|
|
|
def _shift_instance_masks(masks, offset, src_image_shape, shift_masks):
    """Paste one patch's instance masks into full-image canvases via sahi."""
    if not isinstance(masks, torch.Tensor):
        # Non-tensor masks (e.g. numpy arrays) are handed to sahi unchanged.
        return shift_masks(masks, offset, src_image_shape)
    if masks.size(0) == 0:
        # No instances in this patch: return an empty full-size tensor so the
        # spatial dims agree with other patches at concatenation time, and
        # avoid handing an empty batch to sahi.
        return masks.new_zeros((0, *src_image_shape))
    # sahi's ``shift_masks`` is documented to take numpy arrays, so round-trip
    # through host memory and restore the original device afterwards.
    shifted = shift_masks(masks.detach().cpu().numpy(), offset,
                          src_image_shape)
    return torch.as_tensor(shifted, device=masks.device)


def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift patch predictions into the original image and concatenate them.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results, each carrying ``pred_instances``.
        offsets (Sequence[Tuple[int, int]]): Positions of the top-left
            points of the patches in the large image, one per patch.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image's shape; used to size shifted masks.

    Returns:
        :obj:`InstanceData`: The shifted instances of all patches,
        concatenated in patch order.

    Raises:
        ImportError: If ``sahi`` is not installed.
        AssertionError: If ``det_data_samples`` and ``offsets`` differ in
            length.
        NotImplementedError: If the boxes have neither 4 nor 5 coordinates.
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError as err:
        raise ImportError('Please run "pip install -U sahi" '
                          'to install sahi first for large image inference.'
                          ) from err

    assert len(det_data_samples) == len(
        offsets), 'The `results` should has the ' 'same length with `offsets`.'

    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        # Work on a copy so the caller's patch results stay untouched.
        pred_inst = det_data_sample.pred_instances.clone()

        num_coords = pred_inst.bboxes.size(-1)
        if num_coords == 4:
            # Horizontal boxes: delegate to sahi.
            pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
        elif num_coords == 5:
            # Rotated boxes (x, y, w, h, a): only the centre moves.
            pred_inst.bboxes = shift_rbboxes(pred_inst.bboxes, offset)
        else:
            raise NotImplementedError

        # Masks are stored on ``pred_instances`` (not as a field of the data
        # sample itself), so that is where their presence must be checked.
        if 'masks' in pred_inst:
            pred_inst.masks = _shift_instance_masks(pred_inst.masks, offset,
                                                    src_image_shape,
                                                    shift_masks)

        shifted_predictions.append(pred_inst)

    return InstanceData.cat(shifted_predictions)
|
|
|
|
def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Combine per-patch detections into one result and de-duplicate by NMS.

    Every patch's predictions are first moved into the large image's
    coordinate frame. Overlapping duplicates are then removed with
    ``batched_nms``, using the instance labels as the grouping indices.

    Args:
        results (List[:obj:`DetDataSample`]): Per-patch results.
        offsets (Sequence[Tuple[int, int]]): Top-left position of each patch
            in the large image.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image's shape.
        nms_cfg (dict): NMS settings forwarded to ``batched_nms``, e.g. the
            NMS type and parameters such as ``iou_threshold``.

    Returns:
        :obj:`DetDataSample`: A clone of the first patch's sample whose
        ``pred_instances`` are replaced by the merged instances.
    """
    all_instances = shift_predictions(results, offsets, src_image_shape)

    # Only the indices of the surviving instances are needed; the returned
    # detections tensor is discarded.
    _, kept_indices = batched_nms(
        boxes=all_instances.bboxes,
        scores=all_instances.scores,
        idxs=all_instances.labels,
        nms_cfg=nms_cfg)

    merged = results[0].clone()
    merged.pred_instances = all_instances[kept_indices]
    return merged
|
|