Christen Millerdurai commited on
Commit ·
b59067d
1
Parent(s): 761864b
bug fix
Browse files- app.py +134 -0
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -92,6 +92,7 @@ def ensure_egoforce_repo() -> Path:
|
|
| 92 |
demo_entrypoint = EGOFORCE_ROOT / "demo" / "run_app.py"
|
| 93 |
if demo_entrypoint.exists():
|
| 94 |
patch_upstream_gradio_for_zerogpu(demo_entrypoint)
|
|
|
|
| 95 |
return EGOFORCE_ROOT
|
| 96 |
|
| 97 |
if EGOFORCE_ROOT.exists() and any(EGOFORCE_ROOT.iterdir()):
|
|
@@ -114,6 +115,7 @@ def ensure_egoforce_repo() -> Path:
|
|
| 114 |
raise RuntimeError(f"EgoForce demo entrypoint not found at {demo_entrypoint}")
|
| 115 |
|
| 116 |
patch_upstream_gradio_for_zerogpu(demo_entrypoint)
|
|
|
|
| 117 |
return EGOFORCE_ROOT
|
| 118 |
|
| 119 |
|
|
@@ -149,9 +151,141 @@ def patch_upstream_gradio_for_zerogpu(demo_entrypoint: Path) -> None:
|
|
| 149 |
1,
|
| 150 |
)
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
demo_entrypoint.write_text(source, encoding="utf-8")
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
def package_available(module_name: str) -> bool:
|
| 156 |
return importlib.util.find_spec(module_name) is not None
|
| 157 |
|
|
|
|
| 92 |
demo_entrypoint = EGOFORCE_ROOT / "demo" / "run_app.py"
|
| 93 |
if demo_entrypoint.exists():
|
| 94 |
patch_upstream_gradio_for_zerogpu(demo_entrypoint)
|
| 95 |
+
patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
|
| 96 |
return EGOFORCE_ROOT
|
| 97 |
|
| 98 |
if EGOFORCE_ROOT.exists() and any(EGOFORCE_ROOT.iterdir()):
|
|
|
|
| 115 |
raise RuntimeError(f"EgoForce demo entrypoint not found at {demo_entrypoint}")
|
| 116 |
|
| 117 |
patch_upstream_gradio_for_zerogpu(demo_entrypoint)
|
| 118 |
+
patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
|
| 119 |
return EGOFORCE_ROOT
|
| 120 |
|
| 121 |
|
|
|
|
| 151 |
1,
|
| 152 |
)
|
| 153 |
|
| 154 |
+
if "def load_gradio_hero_css():\n" not in source:
|
| 155 |
+
marker = "GRADIO_HERO_CSS_PATH = ASSETS_CSS_DIR / \"gradio_hero.css\"\n"
|
| 156 |
+
if marker not in source:
|
| 157 |
+
raise RuntimeError(f"Could not locate CSS path constant in {demo_entrypoint}")
|
| 158 |
+
source = source.replace(
|
| 159 |
+
marker,
|
| 160 |
+
(
|
| 161 |
+
marker +
|
| 162 |
+
"\n"
|
| 163 |
+
"@lru_cache(maxsize=1)\n"
|
| 164 |
+
"def load_gradio_hero_css():\n"
|
| 165 |
+
" if not GRADIO_HERO_CSS_PATH.exists():\n"
|
| 166 |
+
" return None\n"
|
| 167 |
+
" return GRADIO_HERO_CSS_PATH.read_text(encoding=\"utf-8\")\n"
|
| 168 |
+
),
|
| 169 |
+
1,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
source = source.replace(" css=load_gradio_hero_css(),\n ) as app:\n", " ) as app:\n")
|
| 173 |
+
|
| 174 |
+
launch_css_marker = " server_port=args.server_port,\n"
|
| 175 |
+
launch_css_line = " css=load_gradio_hero_css(),\n"
|
| 176 |
+
if launch_css_line not in source:
|
| 177 |
+
if launch_css_marker not in source:
|
| 178 |
+
raise RuntimeError(f"Could not locate Gradio launch arguments in {demo_entrypoint}")
|
| 179 |
+
source = source.replace(launch_css_marker, launch_css_marker + launch_css_line, 1)
|
| 180 |
+
|
| 181 |
demo_entrypoint.write_text(source, encoding="utf-8")
|
| 182 |
|
| 183 |
|
| 184 |
+
def patch_upstream_tensorrt_fallback(repo_root: Path) -> None:
|
| 185 |
+
inference_path = repo_root / "demo" / "inference.py"
|
| 186 |
+
demo_utils_path = repo_root / "demo" / "demo_utils.py"
|
| 187 |
+
|
| 188 |
+
inference_source = inference_path.read_text(encoding="utf-8")
|
| 189 |
+
if "TORCH_TENSORRT_IMPORT_ERROR = None\n" not in inference_source:
|
| 190 |
+
import_marker = "import torch\nimport torch_tensorrt\n\n"
|
| 191 |
+
if import_marker not in inference_source:
|
| 192 |
+
raise RuntimeError(f"Could not locate torch_tensorrt import in {inference_path}")
|
| 193 |
+
inference_source = inference_source.replace(
|
| 194 |
+
import_marker,
|
| 195 |
+
(
|
| 196 |
+
"import torch\n"
|
| 197 |
+
"\n"
|
| 198 |
+
"try:\n"
|
| 199 |
+
" import torch_tensorrt\n"
|
| 200 |
+
" TORCH_TENSORRT_IMPORT_ERROR = None\n"
|
| 201 |
+
"except Exception as exc:\n"
|
| 202 |
+
" torch_tensorrt = None\n"
|
| 203 |
+
" TORCH_TENSORRT_IMPORT_ERROR = exc\n"
|
| 204 |
+
" print(f\"Torch-TensorRT unavailable: {exc}. Falling back to PyTorch inference.\", flush=True)\n"
|
| 205 |
+
"\n"
|
| 206 |
+
),
|
| 207 |
+
1,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
runtime_marker = (
|
| 211 |
+
"torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
|
| 212 |
+
"torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
|
| 213 |
+
)
|
| 214 |
+
if runtime_marker in inference_source:
|
| 215 |
+
inference_source = inference_source.replace(
|
| 216 |
+
runtime_marker,
|
| 217 |
+
(
|
| 218 |
+
"if torch_tensorrt is not None:\n"
|
| 219 |
+
" torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
|
| 220 |
+
" torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
|
| 221 |
+
),
|
| 222 |
+
1,
|
| 223 |
+
)
|
| 224 |
+
inference_path.write_text(inference_source, encoding="utf-8")
|
| 225 |
+
|
| 226 |
+
demo_utils_source = demo_utils_path.read_text(encoding="utf-8")
|
| 227 |
+
if "Torch-TensorRT backend unavailable" not in demo_utils_source:
|
| 228 |
+
old_compile_function = """def compile_to_tensorrt(model, device):
|
| 229 |
+
x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
|
| 230 |
+
x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)
|
| 231 |
+
|
| 232 |
+
with torch.inference_mode():
|
| 233 |
+
model = model.to(device).half()
|
| 234 |
+
x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
|
| 235 |
+
model = torch.jit.trace(model, (x1, x2, x3, x4), strict=False)
|
| 236 |
+
|
| 237 |
+
backend_kwargs = {
|
| 238 |
+
"enabled_precisions": {torch.half},
|
| 239 |
+
"min_block_size": 2,
|
| 240 |
+
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
|
| 241 |
+
"optimization_level": 5,
|
| 242 |
+
"use_python_runtime": False,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
model = torch.compile(model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
|
| 246 |
+
with torch.no_grad():
|
| 247 |
+
model(x1, x2, x3, x4) # compiled on first run
|
| 248 |
+
|
| 249 |
+
return model
|
| 250 |
+
"""
|
| 251 |
+
new_compile_function = """def compile_to_tensorrt(model, device):
|
| 252 |
+
try:
|
| 253 |
+
import torch_tensorrt # noqa: F401
|
| 254 |
+
except Exception as exc:
|
| 255 |
+
print(f"Torch-TensorRT backend unavailable: {exc}. Using PyTorch model.", flush=True)
|
| 256 |
+
return model.to(device).half()
|
| 257 |
+
|
| 258 |
+
x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
|
| 259 |
+
x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)
|
| 260 |
+
|
| 261 |
+
with torch.inference_mode():
|
| 262 |
+
fallback_model = model.to(device).half()
|
| 263 |
+
x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
|
| 264 |
+
traced_model = torch.jit.trace(fallback_model, (x1, x2, x3, x4), strict=False)
|
| 265 |
+
|
| 266 |
+
backend_kwargs = {
|
| 267 |
+
"enabled_precisions": {torch.half},
|
| 268 |
+
"min_block_size": 2,
|
| 269 |
+
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
|
| 270 |
+
"optimization_level": 5,
|
| 271 |
+
"use_python_runtime": False,
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
compiled_model = torch.compile(traced_model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
|
| 276 |
+
with torch.no_grad():
|
| 277 |
+
compiled_model(x1, x2, x3, x4) # compiled on first run
|
| 278 |
+
return compiled_model
|
| 279 |
+
except Exception as exc:
|
| 280 |
+
print(f"Torch-TensorRT compile failed: {exc}. Using PyTorch model.", flush=True)
|
| 281 |
+
return fallback_model
|
| 282 |
+
"""
|
| 283 |
+
if old_compile_function not in demo_utils_source:
|
| 284 |
+
raise RuntimeError(f"Could not locate compile_to_tensorrt in {demo_utils_path}")
|
| 285 |
+
demo_utils_source = demo_utils_source.replace(old_compile_function, new_compile_function, 1)
|
| 286 |
+
demo_utils_path.write_text(demo_utils_source, encoding="utf-8")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
def package_available(module_name: str) -> bool:
|
| 290 |
return importlib.util.find_spec(module_name) is not None
|
| 291 |
|
requirements.txt
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
--extra-index-url https://miropsota.github.io/torch_packages_builder
|
|
|
|
| 2 |
|
| 3 |
numpy==1.26.4
|
| 4 |
torch==2.8.0
|
| 5 |
torchvision==0.23.0
|
|
|
|
| 6 |
pytorch3d==0.7.9+pt2.8.0cu128
|
| 7 |
opencv-python==4.11.0.86
|
| 8 |
pillow==11.3.0
|
|
|
|
| 1 |
--extra-index-url https://miropsota.github.io/torch_packages_builder
|
| 2 |
+
--find-links https://download.pytorch.org/whl/torch-tensorrt
|
| 3 |
|
| 4 |
numpy==1.26.4
|
| 5 |
torch==2.8.0
|
| 6 |
torchvision==0.23.0
|
| 7 |
+
torch_tensorrt==2.8.0+cu128
|
| 8 |
pytorch3d==0.7.9+pt2.8.0cu128
|
| 9 |
opencv-python==4.11.0.86
|
| 10 |
pillow==11.3.0
|