Christen Millerdurai commited on
Commit
1f5020c
·
1 Parent(s): f869f30

changed to mmcv-lite

Browse files
Files changed (1) hide show
  1. egoforce_runtime_patches.py +80 -0
egoforce_runtime_patches.py CHANGED
@@ -12,6 +12,74 @@ def _torchvision_nms():
12
  return nms
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def _nms(
16
  boxes: Any,
17
  scores: Any,
@@ -104,16 +172,28 @@ def apply_runtime_patches() -> None:
104
  if ops_module is None:
105
  ops_module = types.ModuleType("mmcv.ops")
106
  sys.modules["mmcv.ops"] = ops_module
 
107
 
108
  nms_module = sys.modules.get("mmcv.ops.nms")
109
  if nms_module is None:
110
  nms_module = types.ModuleType("mmcv.ops.nms")
111
  sys.modules["mmcv.ops.nms"] = nms_module
112
 
 
 
 
 
 
113
  ops_module.nms = _nms
114
  ops_module.batched_nms = _batched_nms
 
 
 
 
115
  nms_module.nms = _nms
116
  nms_module.batched_nms = _batched_nms
 
 
117
 
118
  if mmcv is not None:
119
  mmcv.ops = ops_module
 
12
  return nms
13
 
14
 
15
+ def _torchvision_roi_align():
16
+ from torchvision.ops import roi_align
17
+
18
+ return roi_align
19
+
20
+
21
+ def _torchvision_roi_align_module():
22
+ from torchvision.ops import RoIAlign
23
+
24
+ return RoIAlign
25
+
26
+
27
+ def _torchvision_roi_pool_module():
28
+ from torchvision.ops import RoIPool
29
+
30
+ return RoIPool
31
+
32
+
33
+ def _bbox_overlaps(
34
+ bboxes1: Any,
35
+ bboxes2: Any,
36
+ mode: str = "iou",
37
+ aligned: bool = False,
38
+ offset: int = 0,
39
+ eps: float = 1e-6,
40
+ ) -> Any:
41
+ import torch
42
+
43
+ if bboxes1.numel() == 0 or bboxes2.numel() == 0:
44
+ if aligned:
45
+ return bboxes1.new_zeros((bboxes1.shape[0],))
46
+ return bboxes1.new_zeros((bboxes1.shape[0], bboxes2.shape[0]))
47
+
48
+ if aligned:
49
+ lt = torch.maximum(bboxes1[:, :2], bboxes2[:, :2])
50
+ rb = torch.minimum(bboxes1[:, 2:], bboxes2[:, 2:])
51
+ wh = (rb - lt + offset).clamp(min=0)
52
+ overlap = wh[:, 0] * wh[:, 1]
53
+ area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset)
54
+ area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset)
55
+ else:
56
+ lt = torch.maximum(bboxes1[:, None, :2], bboxes2[None, :, :2])
57
+ rb = torch.minimum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
58
+ wh = (rb - lt + offset).clamp(min=0)
59
+ overlap = wh[..., 0] * wh[..., 1]
60
+ area1 = ((bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset))[:, None]
61
+ area2 = ((bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset))[None, :]
62
+
63
+ if mode == "iof":
64
+ union = area1
65
+ elif mode == "giou":
66
+ union = area1 + area2 - overlap
67
+ if aligned:
68
+ enclosed_lt = torch.minimum(bboxes1[:, :2], bboxes2[:, :2])
69
+ enclosed_rb = torch.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
70
+ else:
71
+ enclosed_lt = torch.minimum(bboxes1[:, None, :2], bboxes2[None, :, :2])
72
+ enclosed_rb = torch.maximum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
73
+ enclosed_wh = (enclosed_rb - enclosed_lt + offset).clamp(min=0)
74
+ enclosed_area = enclosed_wh[..., 0] * enclosed_wh[..., 1]
75
+ iou = overlap / union.clamp(min=eps)
76
+ return iou - (enclosed_area - union) / enclosed_area.clamp(min=eps)
77
+ else:
78
+ union = area1 + area2 - overlap
79
+
80
+ return overlap / union.clamp(min=eps)
81
+
82
+
83
  def _nms(
84
  boxes: Any,
85
  scores: Any,
 
172
  if ops_module is None:
173
  ops_module = types.ModuleType("mmcv.ops")
174
  sys.modules["mmcv.ops"] = ops_module
175
+ ops_module.__path__ = []
176
 
177
  nms_module = sys.modules.get("mmcv.ops.nms")
178
  if nms_module is None:
179
  nms_module = types.ModuleType("mmcv.ops.nms")
180
  sys.modules["mmcv.ops.nms"] = nms_module
181
 
182
+ roi_align_module = sys.modules.get("mmcv.ops.roi_align")
183
+ if roi_align_module is None:
184
+ roi_align_module = types.ModuleType("mmcv.ops.roi_align")
185
+ sys.modules["mmcv.ops.roi_align"] = roi_align_module
186
+
187
  ops_module.nms = _nms
188
  ops_module.batched_nms = _batched_nms
189
+ ops_module.bbox_overlaps = _bbox_overlaps
190
+ ops_module.roi_align = _torchvision_roi_align()
191
+ ops_module.RoIAlign = _torchvision_roi_align_module()
192
+ ops_module.RoIPool = _torchvision_roi_pool_module()
193
  nms_module.nms = _nms
194
  nms_module.batched_nms = _batched_nms
195
+ roi_align_module.roi_align = ops_module.roi_align
196
+ roi_align_module.RoIAlign = ops_module.RoIAlign
197
 
198
  if mmcv is not None:
199
  mmcv.ops = ops_module