| from __future__ import annotations |
|
|
| import importlib.util |
| import os |
| import re |
| import runpy |
| import shlex |
| import shutil |
| import site |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
# Directory containing this launcher script (the Space's working tree root).
SPACE_ROOT = Path(__file__).resolve().parent
# Upstream EgoForce git repository URL and ref to clone (both env-overridable).
EGOFORCE_REPO_URL = os.environ.get("EGOFORCE_REPO_URL", "https://github.com/dfki-av/EgoForce")
EGOFORCE_REF = os.environ.get("EGOFORCE_REF", "main")
# Where the EgoForce checkout lives; defaults to a subdirectory next to this script.
EGOFORCE_ROOT = Path(os.environ.get("EGOFORCE_ROOT", SPACE_ROOT / "EgoForce")).resolve()
# Hugging Face Hub model repo that hosts the `_DATA` asset tree (weights etc.).
EGOFORCE_ASSETS_REPO_ID = os.environ.get("EGOFORCE_ASSETS_REPO_ID", "chris10/EgoForce")
# ZeroGPU @spaces.GPU decorator knobs injected into the upstream Gradio app.
ZEROGPU_DURATION_SECONDS = os.environ.get("ZEROGPU_DURATION_SECONDS", "210")
ZEROGPU_SIZE = os.environ.get("ZEROGPU_SIZE", "large")
# Sentinel CSS comment used to find/replace the Space's CSS patch idempotently.
HERO_CSS_SPACE_PATCH_MARKER = "/* EgoForce Space patch: dual-theme hero styling */"
# Dual-theme CSS override block appended to EgoForce's gradio_hero.css.
# Defines light-theme custom properties, then re-defines them for dark mode via
# three triggers — the prefers-color-scheme media query, `.dark` class
# selectors, and `[data-theme="dark"]` attribute selectors — so theming holds
# across different Gradio/host dark-mode mechanisms. NOTE: this is an f-string,
# so literal CSS braces are doubled ({{ / }}); only the marker is interpolated.
HERO_CSS_SPACE_PATCH = f"""
{HERO_CSS_SPACE_PATCH_MARKER}
:root,
.gradio-container {{
  --egoforce-bg-start: #fcf8f1;
  --egoforce-bg-end: #f3ead8;
  --egoforce-bg-accent: rgba(201, 95, 73, 0.14);
  --egoforce-card-bg: linear-gradient(168deg, rgba(255, 251, 245, 0.96), rgba(246, 238, 225, 0.94));
  --egoforce-card-border: rgba(173, 128, 93, 0.28);
  --egoforce-shadow: 0 20px 52px rgba(111, 76, 44, 0.14);
  --egoforce-text-primary: #1f2937;
  --egoforce-text-secondary: #374151;
  --egoforce-text-muted: #4b5563;
  --egoforce-link-bg: rgba(255, 255, 255, 0.82);
  --egoforce-link-border: rgba(128, 91, 58, 0.35);
  --egoforce-link-text: #1f2937;
  --egoforce-link-hover-bg: rgba(201, 95, 73, 0.15);
  --egoforce-link-hover-border: rgba(201, 95, 73, 0.48);
}}

@media (prefers-color-scheme: dark) {{
  :root,
  .gradio-container {{
    --egoforce-bg-start: #0b1020;
    --egoforce-bg-end: #121a2f;
    --egoforce-bg-accent: rgba(56, 189, 248, 0.18);
    --egoforce-card-bg: linear-gradient(165deg, rgba(17, 24, 39, 0.94), rgba(30, 41, 59, 0.9));
    --egoforce-card-border: rgba(125, 211, 252, 0.26);
    --egoforce-shadow: 0 24px 56px rgba(2, 8, 23, 0.52);
    --egoforce-text-primary: #e5e7eb;
    --egoforce-text-secondary: #d1d5db;
    --egoforce-text-muted: #cbd5e1;
    --egoforce-link-bg: rgba(15, 23, 42, 0.62);
    --egoforce-link-border: rgba(125, 211, 252, 0.4);
    --egoforce-link-text: #e2e8f0;
    --egoforce-link-hover-bg: rgba(14, 116, 144, 0.35);
    --egoforce-link-hover-border: rgba(125, 211, 252, 0.65);
  }}
}}

html.dark,
body.dark,
.dark,
html[data-theme="dark"],
body[data-theme="dark"] {{
  --egoforce-bg-start: #0b1020;
  --egoforce-bg-end: #121a2f;
  --egoforce-bg-accent: rgba(56, 189, 248, 0.18);
  --egoforce-card-bg: linear-gradient(165deg, rgba(17, 24, 39, 0.94), rgba(30, 41, 59, 0.9));
  --egoforce-card-border: rgba(125, 211, 252, 0.26);
  --egoforce-shadow: 0 24px 56px rgba(2, 8, 23, 0.52);
  --egoforce-text-primary: #e5e7eb;
  --egoforce-text-secondary: #d1d5db;
  --egoforce-text-muted: #cbd5e1;
  --egoforce-link-bg: rgba(15, 23, 42, 0.62);
  --egoforce-link-border: rgba(125, 211, 252, 0.4);
  --egoforce-link-text: #e2e8f0;
  --egoforce-link-hover-bg: rgba(14, 116, 144, 0.35);
  --egoforce-link-hover-border: rgba(125, 211, 252, 0.65);
}}

body,
.gradio-container {{
  background-image:
    radial-gradient(circle at center top, var(--egoforce-bg-accent), rgba(0, 0, 0, 0) 36%),
    linear-gradient(180deg, var(--egoforce-bg-start) 0%, var(--egoforce-bg-end) 100%) !important;
  color: var(--egoforce-text-primary) !important;
}}

.egoforce-hero-card,
.prose .egoforce-hero .egoforce-hero-card {{
  background: var(--block-background-fill) !important;
  background-image: linear-gradient(
    168deg,
    color-mix(in srgb, var(--block-background-fill) 94%, #ffffff 6%),
    color-mix(in srgb, var(--block-background-fill) 96%, #000000 4%)
  ) !important;
  border: 1px solid var(--block-border-color) !important;
  box-shadow: var(--egoforce-shadow) !important;
}}

.egoforce-hero-title,
.prose .egoforce-hero .egoforce-hero-title {{
  font-size: clamp(2.6rem, 6vw, 4.8rem) !important;
  line-height: 1 !important;
  color: var(--egoforce-text-primary) !important;
}}

.egoforce-brand-black,
.prose .egoforce-hero .egoforce-brand-black {{
  color: var(--body-text-color) !important;
  font-variant-caps: small-caps;
  letter-spacing: 0.015em;
}}

.egoforce-brand-force,
.prose .egoforce-hero .egoforce-brand-force {{
  color: #b42018 !important;
  font-variant-caps: small-caps;
  letter-spacing: 0.015em;
}}

.egoforce-hero-subtitle,
.egoforce-hero-authors,
.egoforce-hero-venue,
.prose .egoforce-hero .egoforce-hero-subtitle,
.prose .egoforce-hero .egoforce-hero-authors,
.prose .egoforce-hero .egoforce-hero-venue {{
  color: var(--body-text-color) !important;
}}

.egoforce-hero-affiliations,
.egoforce-hero-caption,
.prose .egoforce-hero .egoforce-hero-affiliations,
.prose .egoforce-hero .egoforce-hero-caption {{
  color: var(--body-text-color-subdued) !important;
}}

.egoforce-hero-icon,
.prose .egoforce-hero img.egoforce-hero-icon {{
  height: clamp(3.2rem, 7vw, 5.4rem) !important;
  width: auto !important;
  max-width: none !important;
}}

.egoforce-hero-link,
.prose .egoforce-hero .egoforce-hero-link {{
  background: var(--egoforce-link-bg) !important;
  border-color: var(--egoforce-link-border) !important;
  color: var(--egoforce-link-text) !important;
}}

.egoforce-hero-link:hover,
.prose .egoforce-hero .egoforce-hero-link:hover {{
  background: var(--egoforce-link-hover-bg) !important;
  border-color: var(--egoforce-link-hover-border) !important;
}}

.egoforce-hero-link svg,
.prose .egoforce-hero .egoforce-hero-link svg {{
  color: var(--egoforce-link-text) !important;
}}

@media (prefers-color-scheme: dark) {{
  body,
  .gradio-container {{
    background-image:
      radial-gradient(circle at center top, rgba(56, 189, 248, 0.18), rgba(0, 0, 0, 0) 36%),
      linear-gradient(180deg, #0b1020 0%, #121a2f 100%) !important;
    color: #e5e7eb !important;
  }}

  .egoforce-hero-card,
  .prose .egoforce-hero .egoforce-hero-card {{
    background: linear-gradient(165deg, rgba(17, 24, 39, 0.94), rgba(30, 41, 59, 0.9)) !important;
    border-color: rgba(125, 211, 252, 0.26) !important;
    box-shadow: 0 24px 56px rgba(2, 8, 23, 0.52) !important;
  }}

  .egoforce-hero-title,
  .egoforce-hero-subtitle,
  .egoforce-hero-authors,
  .egoforce-hero-affiliations,
  .egoforce-hero-venue,
  .egoforce-hero-caption,
  .prose .egoforce-hero .egoforce-hero-title,
  .prose .egoforce-hero .egoforce-hero-subtitle,
  .prose .egoforce-hero .egoforce-hero-authors,
  .prose .egoforce-hero .egoforce-hero-affiliations,
  .prose .egoforce-hero .egoforce-hero-venue,
  .prose .egoforce-hero .egoforce-hero-caption {{
    color: #e5e7eb !important;
  }}

  .egoforce-brand-black,
  .prose .egoforce-hero .egoforce-brand-black {{
    color: #e5e7eb !important;
  }}

  .egoforce-brand-force,
  .prose .egoforce-hero .egoforce-brand-force {{
    color: #b42018 !important;
  }}

  .egoforce-hero-link,
  .prose .egoforce-hero .egoforce-hero-link {{
    background: rgba(15, 23, 42, 0.62) !important;
    border-color: rgba(125, 211, 252, 0.4) !important;
    color: #e2e8f0 !important;
  }}

  .egoforce-hero-link:hover,
  .prose .egoforce-hero .egoforce-hero-link:hover {{
    background: rgba(14, 116, 144, 0.35) !important;
    border-color: rgba(125, 211, 252, 0.65) !important;
  }}

  .egoforce-hero-link svg,
  .prose .egoforce-hero .egoforce-hero-link svg {{
    color: #e2e8f0 !important;
  }}
}}

html.dark .egoforce-hero-card,
body.dark .egoforce-hero-card,
html[data-theme="dark"] .egoforce-hero-card,
body[data-theme="dark"] .egoforce-hero-card,
.dark .egoforce-hero-card,
html.dark .prose .egoforce-hero .egoforce-hero-card,
body.dark .prose .egoforce-hero .egoforce-hero-card,
html[data-theme="dark"] .prose .egoforce-hero .egoforce-hero-card,
body[data-theme="dark"] .prose .egoforce-hero .egoforce-hero-card,
.dark .prose .egoforce-hero .egoforce-hero-card {{
  background: linear-gradient(165deg, rgba(17, 24, 39, 0.94), rgba(30, 41, 59, 0.9)) !important;
  border-color: rgba(125, 211, 252, 0.26) !important;
  box-shadow: 0 24px 56px rgba(2, 8, 23, 0.52) !important;
}}

html.dark .gradio-container,
body.dark .gradio-container,
html[data-theme="dark"] .gradio-container,
body[data-theme="dark"] .gradio-container,
.dark .gradio-container {{
  background-image:
    radial-gradient(circle at center top, rgba(56, 189, 248, 0.18), rgba(0, 0, 0, 0) 36%),
    linear-gradient(180deg, #0b1020 0%, #121a2f 100%) !important;
  color: #e5e7eb !important;
}}

html.dark .egoforce-hero-title,
html.dark .egoforce-hero-subtitle,
html.dark .egoforce-hero-authors,
html.dark .egoforce-hero-affiliations,
html.dark .egoforce-hero-venue,
html.dark .egoforce-hero-caption,
body.dark .egoforce-hero-title,
body.dark .egoforce-hero-subtitle,
body.dark .egoforce-hero-authors,
body.dark .egoforce-hero-affiliations,
body.dark .egoforce-hero-venue,
body.dark .egoforce-hero-caption,
html[data-theme="dark"] .egoforce-hero-title,
html[data-theme="dark"] .egoforce-hero-subtitle,
html[data-theme="dark"] .egoforce-hero-authors,
html[data-theme="dark"] .egoforce-hero-affiliations,
html[data-theme="dark"] .egoforce-hero-venue,
html[data-theme="dark"] .egoforce-hero-caption,
body[data-theme="dark"] .egoforce-hero-title,
body[data-theme="dark"] .egoforce-hero-subtitle,
body[data-theme="dark"] .egoforce-hero-authors,
body[data-theme="dark"] .egoforce-hero-affiliations,
body[data-theme="dark"] .egoforce-hero-venue,
body[data-theme="dark"] .egoforce-hero-caption,
.dark .egoforce-hero-title,
.dark .egoforce-hero-subtitle,
.dark .egoforce-hero-authors,
.dark .egoforce-hero-affiliations,
.dark .egoforce-hero-venue,
.dark .egoforce-hero-caption {{
  color: #e5e7eb !important;
}}

html.dark .egoforce-brand-black,
body.dark .egoforce-brand-black,
html[data-theme="dark"] .egoforce-brand-black,
body[data-theme="dark"] .egoforce-brand-black,
.dark .egoforce-brand-black {{
  color: #e5e7eb !important;
}}

html.dark .egoforce-brand-force,
body.dark .egoforce-brand-force,
html[data-theme="dark"] .egoforce-brand-force,
body[data-theme="dark"] .egoforce-brand-force,
.dark .egoforce-brand-force {{
  color: #b42018 !important;
}}

html.dark .egoforce-hero-link,
body.dark .egoforce-hero-link,
html[data-theme="dark"] .egoforce-hero-link,
body[data-theme="dark"] .egoforce-hero-link,
.dark .egoforce-hero-link {{
  background: rgba(15, 23, 42, 0.62) !important;
  border-color: rgba(125, 211, 252, 0.4) !important;
  color: #e2e8f0 !important;
}}
""".strip()
|
|
|
|
def run_command(command: list[str], cwd: Path | None = None) -> None:
    """Echo *command* (shell-quoted) and execute it, failing fast on error.

    Args:
        command: Argument vector to run (no shell involved).
        cwd: Optional working directory for the child process.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    print(f"+ {shlex.join(command)}", flush=True)
    workdir = str(cwd) if cwd else None
    subprocess.run(command, cwd=workdir, check=True)
|
|
|
|
def configure_runtime_environment() -> None:
    """Seed process environment defaults for CUDA builds and headless running.

    Existing environment values always win (everything uses setdefault), so
    operators can override any knob from the Space settings.
    """
    os.environ.setdefault("EGOFORCE_ROOT", str(EGOFORCE_ROOT))
    configure_cuda_environment()
    os.environ.setdefault("FORCE_CUDA", "1")
    configure_torch_cuda_arch_list()

    # Serialize native extension builds and force headless/quiet behavior.
    defaults = {
        "MAX_JOBS": "1",
        "CMAKE_BUILD_PARALLEL_LEVEL": "1",
        "NINJAFLAGS": "-j1",
        "MAKEFLAGS": "-j1",
        "PYOPENGL_PLATFORM": "egl",
        "MPLBACKEND": "Agg",
        "TOKENIZERS_PARALLELISM": "false",
        "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1",
    }
    for key, value in defaults.items():
        os.environ.setdefault(key, value)
|
|
|
|
def configure_torch_cuda_arch_list() -> None:
    """Derive TORCH_CUDA_ARCH_LIST from the ACCELERATOR hint unless already set.

    The first matching substring wins; unknown accelerators default to 9.0
    (same default the original logic used for the final else branch).
    """
    if os.environ.get("TORCH_CUDA_ARCH_LIST"):
        return

    accelerator = os.environ.get("ACCELERATOR", "").lower()
    hint_table = (
        (("zero", "h200"), "9.0"),
        (("t4",), "7.5"),
        (("a100",), "8.0"),
        (("a10g",), "8.6"),
        (("l4", "l40"), "8.9"),
    )
    arch_list = "9.0"
    for hints, arch in hint_table:
        if any(hint in accelerator for hint in hints):
            arch_list = arch
            break

    os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
|
|
|
|
def candidate_site_packages() -> list[Path]:
    """Return site-packages directories: global ones first, then the user dir."""
    roots = list(map(Path, site.getsitepackages()))
    user_root = site.getusersitepackages()
    return roots + ([Path(user_root)] if user_root else [])
|
|
|
|
def find_python_cuda_home() -> Path | None:
    """Locate a pip-installed nvidia cuda_nvcc toolchain, or None if absent."""
    for root in candidate_site_packages():
        toolkit = root / "nvidia" / "cuda_nvcc"
        nvcc = toolkit / "bin" / "nvcc"
        if nvcc.exists():
            return toolkit
    return None
|
|
|
|
def configure_cuda_environment() -> None:
    """Pick a CUDA_HOME that actually has nvcc and prepend its bin dir to PATH.

    Preference order: an existing CUDA_HOME with a working nvcc, then a
    pip-installed cuda_nvcc package, then the declared CUDA_HOME as-is, and
    finally /usr/local/cuda.
    """
    declared = os.environ.get("CUDA_HOME")
    if declared and (Path(declared) / "bin" / "nvcc").exists():
        cuda_home = Path(declared)
    else:
        fallback = Path(declared or "/usr/local/cuda")
        cuda_home = find_python_cuda_home() or fallback
        os.environ["CUDA_HOME"] = str(cuda_home)

    bin_dir = cuda_home / "bin"
    if bin_dir.exists():
        os.environ["PATH"] = f"{bin_dir}:{os.environ.get('PATH', '')}"
|
|
|
|
def ensure_egoforce_repo() -> Path:
    """Ensure a patched EgoForce checkout exists at EGOFORCE_ROOT and return it.

    Reuses an existing checkout when demo/run_app.py is present; otherwise does
    a shallow git clone of EGOFORCE_REPO_URL (at EGOFORCE_REF when set) and
    initializes submodules. The Space patches are (re)applied in both paths, so
    repeated calls are idempotent.

    Returns:
        The resolved EGOFORCE_ROOT path.

    Raises:
        RuntimeError: if EGOFORCE_ROOT is a non-empty directory without the
            demo entrypoint, or the clone did not produce demo/run_app.py.
    """
    demo_entrypoint = EGOFORCE_ROOT / "demo" / "run_app.py"
    # Fast path: a previous run already cloned the repo — just refresh patches.
    if demo_entrypoint.exists():
        patch_upstream_gradio_for_zerogpu(demo_entrypoint)
        patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
        patch_upstream_gradio_hero_css(EGOFORCE_ROOT)
        return EGOFORCE_ROOT

    # A non-empty directory without the entrypoint is ambiguous; refuse to
    # clone over it rather than risk mixing two trees.
    if EGOFORCE_ROOT.exists() and any(EGOFORCE_ROOT.iterdir()):
        raise RuntimeError(
            f"{EGOFORCE_ROOT} exists, but demo/run_app.py was not found. "
            "Delete that directory or set EGOFORCE_ROOT to a clean location."
        )

    EGOFORCE_ROOT.parent.mkdir(parents=True, exist_ok=True)
    # Shallow clone keeps the download small; --branch also accepts tags.
    command = ["git", "clone", "--depth", "1"]
    if EGOFORCE_REF:
        command.extend(["--branch", EGOFORCE_REF])
    command.extend([EGOFORCE_REPO_URL, str(EGOFORCE_ROOT)])
    run_command(command)

    if (EGOFORCE_ROOT / ".gitmodules").exists():
        run_command(["git", "submodule", "update", "--init", "--recursive"], cwd=EGOFORCE_ROOT)

    if not demo_entrypoint.exists():
        raise RuntimeError(f"EgoForce demo entrypoint not found at {demo_entrypoint}")

    patch_upstream_gradio_for_zerogpu(demo_entrypoint)
    patch_upstream_tensorrt_fallback(EGOFORCE_ROOT)
    patch_upstream_gradio_hero_css(EGOFORCE_ROOT)
    return EGOFORCE_ROOT
|
|
|
|
def patch_upstream_gradio_for_zerogpu(demo_entrypoint: Path) -> None:
    """Rewrite demo/run_app.py in place so the upstream app runs on ZeroGPU.

    Every edit is a marker-guarded text substitution, so re-running on an
    already-patched file is a no-op (idempotent).

    Raises:
        RuntimeError: when an expected anchor ("import torch" or the
            process_video definition) is missing from the file.
    """
    source = demo_entrypoint.read_text(encoding="utf-8")

    # 1) Hook our runtime patches in immediately after torch is imported.
    if "from egoforce_runtime_patches import apply_runtime_patches\n" not in source:
        if "import torch\n" not in source:
            raise RuntimeError(f"Could not insert runtime patches in {demo_entrypoint}")
        source = source.replace(
            "import torch\n",
            "import torch\nfrom egoforce_runtime_patches import apply_runtime_patches\napply_runtime_patches()\n",
            1,
        )

    # 2) Ensure the `spaces` package (ZeroGPU decorator) is imported.
    if "import spaces\n" not in source:
        if "import torch\n" not in source:
            raise RuntimeError(f"Could not insert ZeroGPU import in {demo_entrypoint}")
        source = source.replace("import torch\n", "import spaces\nimport torch\n", 1)

    if "def process_video(\n" not in source:
        raise RuntimeError(f"Could not locate process_video in {demo_entrypoint}")

    # 3) Strip any existing @spaces.GPU decorator on process_video so the
    # decorator can be re-applied with this Space's duration/size settings.
    source = re.sub(
        r"@spaces\.GPU\([^\n]*\)\n(?=def process_video\()",
        "",
        source,
        count=1,
    )

    desired_decorator = (
        "@spaces.GPU("
        f"duration={int(ZEROGPU_DURATION_SECONDS)}, "
        f"size={ZEROGPU_SIZE!r}"
        ")\n"
    )
    if desired_decorator not in source:
        source = source.replace("def process_video(\n", f"{desired_decorator}def process_video(\n", 1)

    # 4) Remove the previously-injected CSS loader and its css= usage; the hero
    # CSS is now patched directly into the upstream stylesheet instead.
    injected_css_loader = (
        "\n"
        "@lru_cache(maxsize=1)\n"
        "def load_gradio_hero_css():\n"
        "    if not GRADIO_HERO_CSS_PATH.exists():\n"
        "        return None\n"
        "    return GRADIO_HERO_CSS_PATH.read_text(encoding=\"utf-8\")\n"
    )
    if injected_css_loader in source:
        source = source.replace(injected_css_loader, "", 1)

    source = source.replace("        css=load_gradio_hero_css(),\n", "")

    demo_entrypoint.write_text(source, encoding="utf-8")
|
|
|
|
def patch_upstream_gradio_hero_css(repo_root: Path) -> None:
    """Append (or refresh) the Space's dual-theme CSS patch in gradio_hero.css.

    A marker comment separates upstream CSS from the Space patch, so a previous
    patch is dropped before the current one is appended — making this idempotent.
    """
    css_path = repo_root / "assets" / "css" / "gradio_hero.css"
    if not css_path.exists():
        return

    existing = css_path.read_text(encoding="utf-8")
    if HERO_CSS_SPACE_PATCH_MARKER in existing:
        # Keep only the upstream portion that precedes the previous patch.
        existing, _, _ = existing.partition(HERO_CSS_SPACE_PATCH_MARKER)

    css_path.write_text(f"{existing.rstrip()}\n\n{HERO_CSS_SPACE_PATCH}\n", encoding="utf-8")
|
|
|
|
def patch_upstream_tensorrt_fallback(repo_root: Path) -> None:
    """Patch the upstream demo so Torch-TensorRT is optional instead of required.

    Two files are edited in place, each guarded by a marker string so repeated
    runs are idempotent:

    * demo/inference.py — wraps `import torch_tensorrt` in try/except and guards
      its runtime-mode calls, so a missing/broken TensorRT install degrades to
      plain PyTorch inference instead of crashing at import time.
    * demo/demo_utils.py — replaces compile_to_tensorrt with a version that
      falls back to the half-precision PyTorch model when the TensorRT import
      or compilation fails.

    Raises:
        RuntimeError: when an expected anchor snippet is not found in a file.
    """
    inference_path = repo_root / "demo" / "inference.py"
    demo_utils_path = repo_root / "demo" / "demo_utils.py"

    inference_source = inference_path.read_text(encoding="utf-8")
    # Guarded import: only rewrite if our marker variable is not yet present.
    if "TORCH_TENSORRT_IMPORT_ERROR = None\n" not in inference_source:
        import_marker = "import torch\nimport torch_tensorrt\n\n"
        if import_marker not in inference_source:
            raise RuntimeError(f"Could not locate torch_tensorrt import in {inference_path}")
        inference_source = inference_source.replace(
            import_marker,
            (
                "import torch\n"
                "\n"
                "try:\n"
                "    import torch_tensorrt\n"
                "    TORCH_TENSORRT_IMPORT_ERROR = None\n"
                "except Exception as exc:\n"
                "    torch_tensorrt = None\n"
                "    TORCH_TENSORRT_IMPORT_ERROR = exc\n"
                "    print(f\"Torch-TensorRT unavailable: {exc}. Falling back to PyTorch inference.\", flush=True)\n"
                "\n"
            ),
            1,
        )

    # Guard the module-level runtime-mode calls so they are skipped when the
    # import above failed (torch_tensorrt is None).
    runtime_marker = (
        "torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
        "torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
    )
    if runtime_marker in inference_source:
        inference_source = inference_source.replace(
            runtime_marker,
            (
                "if torch_tensorrt is not None:\n"
                "    torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n"
                "    torch_tensorrt.runtime.set_cudagraphs_mode(True)\n"
            ),
            1,
        )
    inference_path.write_text(inference_source, encoding="utf-8")

    demo_utils_source = demo_utils_path.read_text(encoding="utf-8")
    # Replace compile_to_tensorrt wholesale; the fallback message doubles as
    # the idempotency marker.
    if "Torch-TensorRT backend unavailable" not in demo_utils_source:
        old_compile_function = """def compile_to_tensorrt(model, device):
    x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
    x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)

    with torch.inference_mode():
        model = model.to(device).half()
        x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
        model = torch.jit.trace(model, (x1, x2, x3, x4), strict=False)

    backend_kwargs = {
        "enabled_precisions": {torch.half},
        "min_block_size": 2,
        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
        "optimization_level": 5,
        "use_python_runtime": False,
    }

    model = torch.compile(model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
    with torch.no_grad():
        model(x1, x2, x3, x4) # compiled on first run

    return model
"""
        new_compile_function = """def compile_to_tensorrt(model, device):
    try:
        import torch_tensorrt  # noqa: F401
    except Exception as exc:
        print(f"Torch-TensorRT backend unavailable: {exc}. Using PyTorch model.", flush=True)
        return model.to(device).half()

    x1, x2, x3, x4 = torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2]), torch.rand([2, 1, 3, 224, 224]), torch.rand([2, 1, 3, 6, 2])
    x1, x2, x3, x4 = x1.to(device), x2.to(device), x3.to(device), x4.to(device)

    with torch.inference_mode():
        fallback_model = model.to(device).half()
        x1, x2, x3, x4 = x1.half(), x2.half(), x3.half(), x4.half()
        traced_model = torch.jit.trace(fallback_model, (x1, x2, x3, x4), strict=False)

    backend_kwargs = {
        "enabled_precisions": {torch.half},
        "min_block_size": 2,
        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
        "optimization_level": 5,
        "use_python_runtime": False,
    }

    try:
        compiled_model = torch.compile(traced_model, backend="torch_tensorrt", options=backend_kwargs, dynamic=False,)
        with torch.no_grad():
            compiled_model(x1, x2, x3, x4) # compiled on first run
        return compiled_model
    except Exception as exc:
        print(f"Torch-TensorRT compile failed: {exc}. Using PyTorch model.", flush=True)
        return fallback_model
"""
        if old_compile_function not in demo_utils_source:
            raise RuntimeError(f"Could not locate compile_to_tensorrt in {demo_utils_path}")
        demo_utils_source = demo_utils_source.replace(old_compile_function, new_compile_function, 1)
        demo_utils_path.write_text(demo_utils_source, encoding="utf-8")
|
|
|
|
def package_available(module_name: str) -> bool:
    """Return True if *module_name* is importable in the current environment.

    Robustness fix: importlib.util.find_spec raises ModuleNotFoundError for a
    dotted name whose parent package is missing, and ValueError for malformed
    names — treat both as "not available" instead of crashing the bootstrap.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ModuleNotFoundError, ValueError):
        return False
|
|
|
|
def pip_install(requirement: str, *extra_args: str) -> None:
    """Install *requirement* into the current interpreter via `python -m pip`.

    Args:
        requirement: Any pip requirement specifier (name, path, VCS URL).
        extra_args: Additional pip flags appended after the requirement.
    """
    base = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-cache-dir",
        "--disable-pip-version-check",
    ]
    run_command([*base, requirement, *extra_args])
|
|
|
|
def ensure_runtime_python_packages(repo_root: Path) -> None:
    """Install EgoForce's extra Python dependencies that are not yet importable.

    Each (module, requirement, pip-flags) entry is skipped when the module can
    already be imported, so repeated runs only install what is missing.
    """
    thirdparty = repo_root / "thirdparty"
    install_plan = [
        ("anycalib", "git+https://github.com/javrtg/AnyCalib.git", ("--no-build-isolation",)),
        ("chumpy", "git+https://github.com/mattloper/chumpy.git", ("--no-build-isolation",)),
        ("datapipes", str(thirdparty / "datapipes"), ()),
        ("mmdet", str(thirdparty / "mmdetection"), ("--no-build-isolation", "--no-deps")),
    ]

    for module_name, requirement, extra_args in install_plan:
        if not package_available(module_name):
            pip_install(requirement, *extra_args)
|
|
|
|
def validate_assets(repo_root: Path) -> None:
    """Check that all required files exist under *repo_root*/_DATA.

    Raises:
        RuntimeError: listing every required asset path that is absent.
    """
    data_dir = repo_root / "_DATA"
    required_names = (
        "model_weights.pth",
        "epoch_460.pth",
        "detector.torchscript",
        "mano",
    )
    missing = [str(data_dir / name) for name in required_names if not (data_dir / name).exists()]
    if missing:
        raise RuntimeError(f"Missing required EgoForce assets: {missing}")
|
|
|
|
def ensure_egoforce_assets(repo_root: Path) -> None:
    """Ensure the _DATA model assets exist under *repo_root*, downloading if not.

    Skips the download when validate_assets already passes. Otherwise fetches
    the `_DATA/**` tree from the EGOFORCE_ASSETS_REPO_ID model repo on the Hub
    into a scratch directory, copies it into place, and re-validates.

    Raises:
        RuntimeError: if the downloaded snapshot lacks _DATA, or validation
            still fails after the copy.
    """
    # Local import so merely importing this module does not need huggingface_hub.
    from huggingface_hub import snapshot_download

    try:
        validate_assets(repo_root)
        print(f"Using existing EgoForce assets at {repo_root / '_DATA'}", flush=True)
        return
    except RuntimeError:
        # Assets missing or incomplete — fall through to the download path.
        pass

    target_data_dir = repo_root / "_DATA"
    cache_root = repo_root / ".hf-download"

    # Start from a clean scratch dir so stale partial downloads cannot leak in.
    if cache_root.exists():
        shutil.rmtree(cache_root)

    print(f"Downloading EgoForce assets from {EGOFORCE_ASSETS_REPO_ID}", flush=True)
    snapshot_path = Path(
        snapshot_download(
            repo_id=EGOFORCE_ASSETS_REPO_ID,
            repo_type="model",
            allow_patterns=["_DATA/**"],
            local_dir=cache_root,
        )
    )
    source_data_dir = snapshot_path / "_DATA"
    if not source_data_dir.exists():
        raise RuntimeError(f"Downloaded snapshot did not contain {source_data_dir}")

    # Merge into any partially-present _DATA rather than failing on existing files.
    shutil.copytree(source_data_dir, target_data_dir, dirs_exist_ok=True)
    shutil.rmtree(cache_root, ignore_errors=True)
    validate_assets(repo_root)
    print(f"Downloaded EgoForce assets to {target_data_dir}", flush=True)
|
|
|
|
def launch_upstream_gradio(repo_root: Path) -> None:
    """Run the upstream demo in-process, emulating `python demo/run_app.py`.

    Rewrites sys.argv and sys.path and chdirs into the repo so the upstream
    script's relative paths and imports resolve, then executes it via runpy
    with run_name="__main__" so its entry-point guard fires. Does not return
    until the app exits.
    """
    demo_entrypoint = repo_root / "demo" / "run_app.py"
    demo_dir = demo_entrypoint.parent
    # Spaces set PORT; fall back to Gradio's env var, then Gradio's default port.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = os.environ.get("PORT") or os.environ.get("GRADIO_SERVER_PORT") or "7860"

    os.chdir(repo_root)
    # Put the repo and demo dirs at the FRONT of sys.path (moving them there if
    # already present) so upstream modules shadow same-named installed packages.
    for import_path in (repo_root, demo_dir):
        import_path_string = str(import_path)
        if import_path_string in sys.path:
            sys.path.remove(import_path_string)
        sys.path.insert(0, import_path_string)

    # Fake the CLI invocation the upstream script expects to parse.
    sys.argv = [
        str(demo_entrypoint),
        "--server-name",
        server_name,
        "--server-port",
        str(server_port),
    ]

    if os.environ.get("GRADIO_SHARE", "").lower() in {"1", "true", "yes"}:
        sys.argv.append("--share")

    runpy.run_path(str(demo_entrypoint), run_name="__main__")
|
|
|
|
def main() -> None:
    """End-to-end Space bootstrap: env setup, repo + deps + assets, then launch."""
    configure_runtime_environment()
    repo_root = ensure_egoforce_repo()
    ensure_runtime_python_packages(repo_root)
    ensure_egoforce_assets(repo_root)
    launch_upstream_gradio(repo_root)
|
|
|
|
# Script entry point when executed directly (e.g. as the Space's app file).
if __name__ == "__main__":
    main()
|
|