Christen Millerdurai committed on
Commit b59067d · 1 Parent(s): 761864b
Files changed (2)
  1. app.py +134 -0
  2. requirements.txt +2 -0
app.py CHANGED
@@ -92,6 +92,7 @@ def ensure_egoforce_repo() -> Path:
     demo_entrypoint = EGOFORCE_ROOT / "demo" / "run_app.py"
     if demo_entrypoint.exists():
         patch_upstream_gradio_for_zerogpu(demo_entrypoint)
+        patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
         return EGOFORCE_ROOT

     if EGOFORCE_ROOT.exists() and any(EGOFORCE_ROOT.iterdir()):
@@ -114,6 +115,7 @@ def ensure_egoforce_repo() -> Path:
         raise RuntimeError(f"EgoForce demo entrypoint not found at {demo_entrypoint}")

     patch_upstream_gradio_for_zerogpu(demo_entrypoint)
+    patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
     return EGOFORCE_ROOT


@@ -149,9 +151,141 @@ def patch_upstream_gradio_for_zerogpu(demo_entrypoint: Path) -> None:
             1,
         )

+    if "def load_gradio_hero_css():\n" not in source:
+        marker = "GRADIO_HERO_CSS_PATH = ASSETS_CSS_DIR / \"gradio_hero.css\"\n"
+        if marker not in source:
+            raise RuntimeError(f"Could not locate CSS path constant in {demo_entrypoint}")
+        source = source.replace(
+            marker,
+            (
+                marker +
+                "\n"
+                "@lru_cache(maxsize=1)\n"
+                "def load_gradio_hero_css():\n"
+                "    if not GRADIO_HERO_CSS_PATH.exists():\n"
+                "        return None\n"
+                "    return GRADIO_HERO_CSS_PATH.read_text(encoding=\"utf-8\")\n"
+            ),
+            1,
+        )
+
+    source = source.replace("        css=load_gradio_hero_css(),\n    ) as app:\n", "    ) as app:\n")
+
+    launch_css_marker = "        server_port=args.server_port,\n"
+    launch_css_line = "        css=load_gradio_hero_css(),\n"
+    if launch_css_line not in source:
+        if launch_css_marker not in source:
+            raise RuntimeError(f"Could not locate Gradio launch arguments in {demo_entrypoint}")
+        source = source.replace(launch_css_marker, launch_css_marker + launch_css_line, 1)
+
     demo_entrypoint.write_text(source, encoding="utf-8")


+def patch_upstream_tensorrt_fallback(repo_root: Path) -> None:
+    inference_path = repo_root / "demo" / "inference.py"
+    demo_utils_path = repo_root / "demo" / "demo_utils.py"
+
+    inference_source = inference_path.read_text(encoding="utf-8")
+    if "TORCH_TENSORRT_IMPORT_ERROR = None\n" not in inference_source:
+        import_marker = "import torch\nimport torch_tensorrt\n\n"
+        if import_marker not in inference_source:
+            raise RuntimeError(f"Could not locate torch_tensorrt import in {inference_path}")
+        inference_source = inference_source.replace(
+            import_marker,
+            (
+                "import torch\n"
+                "\n"
+                "try:\n"
+                "    import torch_tensorrt\n"
+                "    TORCH_TENSORRT_IMPORT_ERROR = None\n"
+                "except Exception as exc:\n"
+                "    torch_tensorrt = None\n"
+                "    TORCH_TENSORRT_IMPORT_ERROR = exc\n"
+                "    print(f\"Torch-TensorRT unavailable: {exc}. Falling back to PyTorch inference.\", flush=True)\n"
+                "\n"
+            ),
+            1,
+        )
+
+    runtime_marker = (
+        "torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
+        "torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
+    )
+    if runtime_marker in inference_source:
+        inference_source = inference_source.replace(
+            runtime_marker,
+            (
+                "if torch_tensorrt is not None:\n"
+                "    torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
+                "    torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
+            ),
+            1,
+        )
+    inference_path.write_text(inference_source, encoding="utf-8")
+
+    demo_utils_source = demo_utils_path.read_text(encoding="utf-8")
+    if "Torch-TensorRT backend unavailable" not in demo_utils_source:
+        old_compile_function = """def compile_to_tensorrt(model, device):
+    x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
+    x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)
+
+    with torch.inference_mode():
+        model = model.to(device).half()
+        x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
+        model = torch.jit.trace(model, (x1, x2, x3, x4), strict=False)
+
+    backend_kwargs = {
+        "enabled_precisions": {torch.half},
+        "min_block_size": 2,
+        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
+        "optimization_level": 5,
+        "use_python_runtime": False,
+    }
+
+    model = torch.compile(model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
+    with torch.no_grad():
+        model(x1, x2, x3, x4)  # compiled on first run
+
+    return model
+"""
+        new_compile_function = """def compile_to_tensorrt(model, device):
+    try:
+        import torch_tensorrt  # noqa: F401
+    except Exception as exc:
+        print(f"Torch-TensorRT backend unavailable: {exc}. Using PyTorch model.", flush=True)
+        return model.to(device).half()
+
+    x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
+    x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)
+
+    with torch.inference_mode():
+        fallback_model = model.to(device).half()
+        x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
+        traced_model = torch.jit.trace(fallback_model, (x1, x2, x3, x4), strict=False)
+
+    backend_kwargs = {
+        "enabled_precisions": {torch.half},
+        "min_block_size": 2,
+        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
+        "optimization_level": 5,
+        "use_python_runtime": False,
+    }
+
+    try:
+        compiled_model = torch.compile(traced_model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
+        with torch.no_grad():
+            compiled_model(x1, x2, x3, x4)  # compiled on first run
+        return compiled_model
+    except Exception as exc:
+        print(f"Torch-TensorRT compile failed: {exc}. Using PyTorch model.", flush=True)
+        return fallback_model
+"""
+        if old_compile_function not in demo_utils_source:
+            raise RuntimeError(f"Could not locate compile_to_tensorrt in {demo_utils_path}")
+        demo_utils_source = demo_utils_source.replace(old_compile_function, new_compile_function, 1)
+        demo_utils_path.write_text(demo_utils_source, encoding="utf-8")
+
+
 def package_available(module_name: str) -> bool:
     return importlib.util.find_spec(module_name) is not None

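The core of the change is a guarded-import-plus-fallback pattern that `patch_upstream_tensorrt_fallback` splices into `demo/inference.py` and `demo/demo_utils.py`. A minimal standalone sketch of that pattern follows; the `maybe_compile` helper and its signature are illustrative only, not part of the upstream code:

# Sketch of the fallback pattern this commit injects: treat torch_tensorrt
# as an optional accelerator and degrade to plain PyTorch when either the
# import or the compile step fails. maybe_compile is hypothetical; the real
# patch rewrites compile_to_tensorrt in demo/demo_utils.py as shown above.
import torch

try:
    import torch_tensorrt  # TensorRT-backed torch.compile backend
    TORCH_TENSORRT_IMPORT_ERROR = None
except Exception as exc:  # missing wheel, CUDA/driver mismatch, etc.
    torch_tensorrt = None
    TORCH_TENSORRT_IMPORT_ERROR = exc


def maybe_compile(model: torch.nn.Module, example_inputs: tuple) -> torch.nn.Module:
    """Return a TensorRT-compiled model when possible, else the model as-is."""
    if torch_tensorrt is None:
        return model  # import failed; reason kept in TORCH_TENSORRT_IMPORT_ERROR
    try:
        compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
        with torch.no_grad():
            compiled(*example_inputs)  # warm-up call triggers compilation
        return compiled
    except Exception as exc:
        print(f"Torch-TensorRT compile failed: {exc}. Using PyTorch model.", flush=True)
        return model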
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 --extra-index-url https://miropsota.github.io/torch_packages_builder
+--find-links https://download.pytorch.org/whl/torch-tensorrt

 numpy==1.26.4
 torch==2.8.0
 torchvision==0.23.0
+torch_tensorrt==2.8.0+cu128
 pytorch3d==0.7.9+pt2.8.0cu128
 opencv-python==4.11.0.86
 pillow==11.3.0
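Since torch_tensorrt wheels are built against a specific Torch release, the two pins need to agree. A quick post-install sanity check (a sketch, not part of this commit) could be:

# Sketch: verify the pins resolved consistently after
# `pip install -r requirements.txt`. torch and torch_tensorrt must share
# the same base version (2.8.0 here), otherwise the guarded import patched
# into demo/inference.py trips and the demo silently runs without TensorRT.
from importlib.metadata import PackageNotFoundError, version

def base_version(v: str) -> str:
    return v.split("+", 1)[0]  # drop local tags like +cu128

try:
    torch_v = version("torch")            # pinned: 2.8.0
    trt_v = version("torch_tensorrt")     # pinned: 2.8.0+cu128
except PackageNotFoundError as exc:
    raise SystemExit(f"missing pinned package: {exc}")

assert base_version(trt_v) == base_version(torch_v), f"{trt_v} != {torch_v}"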