Christen Millerdurai commited on
Commit
f869f30
·
1 Parent(s): 08639bd

changed to mmcv-lite

Browse files
Files changed (3) hide show
  1. app.py +10 -2
  2. egoforce_runtime_patches.py +119 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -120,6 +120,15 @@ def ensure_egoforce_repo() -> Path:
120
  def patch_upstream_gradio_for_zerogpu(demo_entrypoint: Path) -> None:
121
  source = demo_entrypoint.read_text(encoding="utf-8")
122
 
 
 
 
 
 
 
 
 
 
123
  if "import spaces\n" not in source:
124
  if "import torch\n" not in source:
125
  raise RuntimeError(f"Could not insert ZeroGPU import in {demo_entrypoint}")
@@ -164,12 +173,11 @@ def pip_install(requirement: str, *extra_args: str) -> None:
164
  def ensure_runtime_python_packages(repo_root: Path) -> None:
165
  datapipes_path = repo_root / "thirdparty" / "datapipes"
166
  install_plan = [
167
- ("mmcv", "mmcv==2.1.0", ("--no-build-isolation", "--no-deps")),
168
  ("anycalib", "git+https://github.com/javrtg/AnyCalib.git", ("--no-build-isolation",)),
169
  ("chumpy", "git+https://github.com/mattloper/chumpy.git", ("--no-build-isolation",)),
170
  ("pytorch3d", "git+https://github.com/facebookresearch/pytorch3d.git", ("--no-build-isolation",)),
171
  ("datapipes", str(datapipes_path), ()),
172
- ("mmdet", str(repo_root / "thirdparty" / "mmdetection"), ("--no-build-isolation",)),
173
  ]
174
 
175
  for module_name, requirement, extra_args in install_plan:
 
120
  def patch_upstream_gradio_for_zerogpu(demo_entrypoint: Path) -> None:
121
  source = demo_entrypoint.read_text(encoding="utf-8")
122
 
123
+ if "from egoforce_runtime_patches import apply_runtime_patches\n" not in source:
124
+ if "import torch\n" not in source:
125
+ raise RuntimeError(f"Could not insert runtime patches in {demo_entrypoint}")
126
+ source = source.replace(
127
+ "import torch\n",
128
+ "import torch\nfrom egoforce_runtime_patches import apply_runtime_patches\napply_runtime_patches()\n",
129
+ 1,
130
+ )
131
+
132
  if "import spaces\n" not in source:
133
  if "import torch\n" not in source:
134
  raise RuntimeError(f"Could not insert ZeroGPU import in {demo_entrypoint}")
 
173
  def ensure_runtime_python_packages(repo_root: Path) -> None:
174
  datapipes_path = repo_root / "thirdparty" / "datapipes"
175
  install_plan = [
 
176
  ("anycalib", "git+https://github.com/javrtg/AnyCalib.git", ("--no-build-isolation",)),
177
  ("chumpy", "git+https://github.com/mattloper/chumpy.git", ("--no-build-isolation",)),
178
  ("pytorch3d", "git+https://github.com/facebookresearch/pytorch3d.git", ("--no-build-isolation",)),
179
  ("datapipes", str(datapipes_path), ()),
180
+ ("mmdet", str(repo_root / "thirdparty" / "mmdetection"), ("--no-build-isolation", "--no-deps")),
181
  ]
182
 
183
  for module_name, requirement, extra_args in install_plan:
egoforce_runtime_patches.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import sys
5
+ import types
6
+ from typing import Any
7
+
8
+
9
+ def _torchvision_nms():
10
+ from torchvision.ops import nms
11
+
12
+ return nms
13
+
14
+
15
+ def _nms(
16
+ boxes: Any,
17
+ scores: Any,
18
+ iou_threshold: float,
19
+ offset: int = 0,
20
+ score_threshold: float = 0,
21
+ max_num: int = -1,
22
+ ) -> tuple[Any, Any]:
23
+ import torch
24
+
25
+ if boxes.numel() == 0 or scores.numel() == 0:
26
+ keep = torch.empty((0,), dtype=torch.long, device=scores.device)
27
+ dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
28
+ return dets, keep
29
+
30
+ if score_threshold > 0:
31
+ valid = scores > score_threshold
32
+ original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
33
+ filtered_boxes = boxes[valid]
34
+ filtered_scores = scores[valid]
35
+ else:
36
+ original_indices = torch.arange(scores.numel(), device=scores.device)
37
+ filtered_boxes = boxes
38
+ filtered_scores = scores
39
+
40
+ keep_local = _torchvision_nms()(filtered_boxes, filtered_scores, float(iou_threshold))
41
+ if max_num > 0:
42
+ keep_local = keep_local[:max_num]
43
+
44
+ keep = original_indices[keep_local]
45
+ dets = torch.cat((filtered_boxes[keep_local], filtered_scores[keep_local, None]), dim=1)
46
+ return dets, keep
47
+
48
+
49
+ def _batched_nms(
50
+ boxes: Any,
51
+ scores: Any,
52
+ idxs: Any,
53
+ nms_cfg: dict[str, Any] | None,
54
+ class_agnostic: bool = False,
55
+ ) -> tuple[Any, Any]:
56
+ import torch
57
+
58
+ if boxes.numel() == 0 or scores.numel() == 0:
59
+ keep = torch.empty((0,), dtype=torch.long, device=scores.device)
60
+ dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
61
+ return dets, keep
62
+
63
+ if nms_cfg is None:
64
+ order = scores.argsort(descending=True)
65
+ return torch.cat((boxes[order], scores[order, None]), dim=1), order
66
+
67
+ nms_cfg = dict(nms_cfg)
68
+ iou_threshold = nms_cfg.pop("iou_threshold", nms_cfg.pop("iou_thr", 0.5))
69
+ score_threshold = nms_cfg.pop("score_threshold", 0)
70
+ max_num = nms_cfg.pop("max_num", -1)
71
+
72
+ if class_agnostic:
73
+ boxes_for_nms = boxes
74
+ else:
75
+ max_coordinate = boxes.max()
76
+ offsets = idxs.to(boxes) * (max_coordinate + boxes.new_tensor(1))
77
+ boxes_for_nms = boxes + offsets[:, None]
78
+
79
+ if score_threshold > 0:
80
+ valid = scores > score_threshold
81
+ original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
82
+ boxes_for_nms = boxes_for_nms[valid]
83
+ scores_for_nms = scores[valid]
84
+ else:
85
+ original_indices = torch.arange(scores.numel(), device=scores.device)
86
+ scores_for_nms = scores
87
+
88
+ keep_local = _torchvision_nms()(boxes_for_nms, scores_for_nms, float(iou_threshold))
89
+ if max_num > 0:
90
+ keep_local = keep_local[:max_num]
91
+
92
+ keep = original_indices[keep_local]
93
+ dets = torch.cat((boxes[keep], scores[keep, None]), dim=1)
94
+ return dets, keep
95
+
96
+
97
+ def apply_runtime_patches() -> None:
98
+ try:
99
+ mmcv = importlib.import_module("mmcv")
100
+ except ImportError:
101
+ mmcv = None
102
+
103
+ ops_module = sys.modules.get("mmcv.ops")
104
+ if ops_module is None:
105
+ ops_module = types.ModuleType("mmcv.ops")
106
+ sys.modules["mmcv.ops"] = ops_module
107
+
108
+ nms_module = sys.modules.get("mmcv.ops.nms")
109
+ if nms_module is None:
110
+ nms_module = types.ModuleType("mmcv.ops.nms")
111
+ sys.modules["mmcv.ops.nms"] = nms_module
112
+
113
+ ops_module.nms = _nms
114
+ ops_module.batched_nms = _batched_nms
115
+ nms_module.nms = _nms
116
+ nms_module.batched_nms = _batched_nms
117
+
118
+ if mmcv is not None:
119
+ mmcv.ops = ops_module
requirements.txt CHANGED
@@ -22,6 +22,7 @@ pycocotools==2.0.10
22
  trimesh==4.11.3
23
  sortedcontainers==2.4.0
24
  openmim==0.3.9
 
25
  mmengine==0.10.7
26
  yapf==0.43.0
27
  lmdb==2.0.0
 
22
  trimesh==4.11.3
23
  sortedcontainers==2.4.0
24
  openmim==0.3.9
25
+ mmcv-lite==2.1.0
26
  mmengine==0.10.7
27
  yapf==0.43.0
28
  lmdb==2.0.0