Christen Millerdurai commited on
Commit
f0a5ba2
·
1 Parent(s): 6d8feaa
Files changed (1) hide show
  1. egoforce_runtime_patches.py +51 -0
egoforce_runtime_patches.py CHANGED
@@ -1,8 +1,11 @@
1
  from __future__ import annotations
2
 
 
3
  import importlib
4
  import sys
 
5
  import types
 
6
  from typing import Any
7
 
8
 
@@ -302,7 +305,55 @@ def _ensure_module(module_name: str) -> types.ModuleType:
302
  return module
303
 
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  def apply_runtime_patches() -> None:
 
 
306
  try:
307
  mmcv = importlib.import_module("mmcv")
308
  except ImportError:
 
1
  from __future__ import annotations
2
 
3
+ import functools
4
  import importlib
5
  import sys
6
+ import tempfile
7
  import types
8
+ from pathlib import Path
9
  from typing import Any
10
 
11
 
 
305
  return module
306
 
307
 
308
+ _GRADIO_CSS_PATCH_PATH: Path | None = None
309
+
310
+
311
+ def _normalize_gradio_css_paths(css_paths: Any) -> list[str]:
312
+ if css_paths is None:
313
+ return []
314
+ if isinstance(css_paths, (str, Path)):
315
+ return [str(css_paths)]
316
+ return [str(path) for path in css_paths]
317
+
318
+
319
+ def _persist_egoforce_gradio_css(css: str) -> str:
320
+ global _GRADIO_CSS_PATCH_PATH
321
+
322
+ if _GRADIO_CSS_PATCH_PATH is None:
323
+ _GRADIO_CSS_PATCH_PATH = Path(tempfile.gettempdir()) / "egoforce-gradio-launch.css"
324
+ _GRADIO_CSS_PATCH_PATH.write_text(css, encoding="utf-8")
325
+ return str(_GRADIO_CSS_PATCH_PATH)
326
+
327
+
328
+ def _patch_gradio_launch() -> None:
329
+ try:
330
+ import gradio as gr
331
+ except ImportError:
332
+ return
333
+
334
+ launch_method = getattr(gr.Blocks, "launch", None)
335
+ if launch_method is None or getattr(launch_method, "__egoforce_runtime_patch__", False):
336
+ return
337
+
338
+ @functools.wraps(launch_method)
339
+ def patched_launch(self: Any, *args: Any, **kwargs: Any) -> Any:
340
+ css = kwargs.get("css")
341
+ if isinstance(css, str) and css.strip() and (".egoforce-hero" in css or "#sample-video-carousel" in css):
342
+ css_path_entries = _normalize_gradio_css_paths(kwargs.get("css_paths"))
343
+ patched_css_path = _persist_egoforce_gradio_css(css)
344
+ if patched_css_path not in css_path_entries:
345
+ css_path_entries.append(patched_css_path)
346
+ kwargs["css_paths"] = css_path_entries
347
+ kwargs["css"] = None
348
+ return launch_method(self, *args, **kwargs)
349
+
350
+ setattr(patched_launch, "__egoforce_runtime_patch__", True)
351
+ gr.Blocks.launch = patched_launch
352
+
353
+
354
  def apply_runtime_patches() -> None:
355
+ _patch_gradio_launch()
356
+
357
  try:
358
  mmcv = importlib.import_module("mmcv")
359
  except ImportError: