from __future__ import annotations

import functools
import importlib
import sys
import tempfile
import types
from pathlib import Path
from typing import Any

"""Runtime compatibility patches for running mmdet/RTMDet on mmcv-lite.

mmcv-lite ships without compiled CUDA/C++ ops, so importing ``mmcv.ops``
(or modules that re-export it) fails.  :func:`apply_runtime_patches`
installs a synthetic ``mmcv.ops`` package whose members are either

* thin torchvision wrappers (``nms``, ``roi_align``, ...),
* pure-PyTorch reimplementations (``bbox_overlaps``, ``nms_match``, ...), or
* import-only stubs that raise :class:`RuntimeError` when actually executed
  (deformable convs, CARAFE, corner pooling, ...).

It also patches ``gradio.Blocks.launch`` so EgoForce's inline CSS is served
as a css file (``css_paths``) rather than inline ``css``.
"""


def _torchvision_nms():
    """Return torchvision's compiled ``nms`` (imported lazily)."""
    from torchvision.ops import nms

    return nms


def _torchvision_roi_align():
    """Return torchvision's functional ``roi_align`` (imported lazily)."""
    from torchvision.ops import roi_align

    return roi_align


def _torchvision_roi_align_module():
    """Return torchvision's ``RoIAlign`` module class (imported lazily)."""
    from torchvision.ops import RoIAlign

    return RoIAlign


def _torchvision_roi_pool_module():
    """Return torchvision's ``RoIPool`` module class (imported lazily)."""
    from torchvision.ops import RoIPool

    return RoIPool


def _multiscale_deformable_attention_class():
    """Build an import-only MultiScaleDeformableAttention stand-in.

    mmdet registers this class at import time; the stub keeps registry
    imports working but raises if the layer is ever run (RTMDet never
    executes it).
    """
    import torch.nn as nn

    class MultiScaleDeformableAttention(nn.Module):
        """Import-only fallback for mmdet registries when running with mmcv-lite."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__()

        def init_weights(self) -> None:
            return None

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            raise RuntimeError(
                "MultiScaleDeformableAttention requires full mmcv with compiled ops. "
                "The EgoForce demo uses RTMDet and should not execute this layer."
            )

    return MultiScaleDeformableAttention


def _unsupported_module_class(name: str):
    """Return an ``nn.Module`` subclass named *name* that raises on forward.

    Used for mmcv ops that need compiled extensions; the class is importable
    and constructible (so registries work) but cannot be executed.
    """
    import torch.nn as nn

    class UnsupportedMMCVOp(nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__()

        def init_weights(self) -> None:
            return None

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            raise RuntimeError(f"{name} requires full mmcv with compiled ops and is not used by the EgoForce RTMDet demo.")

    # Rename so error messages / repr show the op being stubbed out.
    UnsupportedMMCVOp.__name__ = name
    return UnsupportedMMCVOp


def _unsupported_function(name: str):
    """Return a function named *name* that always raises RuntimeError."""

    def unsupported(*args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(f"{name} requires full mmcv with compiled ops and is not used by the EgoForce RTMDet demo.")

    unsupported.__name__ = name
    return unsupported


def _bbox_overlaps(
    bboxes1: Any,
    bboxes2: Any,
    mode: str = "iou",
    aligned: bool = False,
    offset: int = 0,
    eps: float = 1e-6,
) -> Any:
    """Pure-PyTorch replacement for ``mmcv.ops.bbox_overlaps``.

    Args:
        bboxes1: ``(m, 4)`` boxes in ``(x1, y1, x2, y2)`` format.
        bboxes2: ``(n, 4)`` boxes; must have ``m == n`` when ``aligned``.
        mode: ``"iou"``, ``"iof"`` (intersection over first area) or ``"giou"``.
        aligned: if True, compare boxes pairwise -> ``(m,)``; otherwise
            all-pairs -> ``(m, n)``.
        offset: 0 or 1; 1 uses the inclusive-pixel convention ``w = x2-x1+1``.
        eps: clamp floor to avoid division by zero.

    Returns:
        Overlap tensor of shape ``(m,)`` (aligned) or ``(m, n)``.
    """
    import torch

    # Empty inputs: return a correctly-shaped zero tensor on the same device.
    if bboxes1.numel() == 0 or bboxes2.numel() == 0:
        if aligned:
            return bboxes1.new_zeros((bboxes1.shape[0],))
        return bboxes1.new_zeros((bboxes1.shape[0], bboxes2.shape[0]))

    if aligned:
        lt = torch.maximum(bboxes1[:, :2], bboxes2[:, :2])
        rb = torch.minimum(bboxes1[:, 2:], bboxes2[:, 2:])
        wh = (rb - lt + offset).clamp(min=0)
        overlap = wh[:, 0] * wh[:, 1]
        area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset)
        area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset)
    else:
        # Broadcast to (m, n, 2) for the all-pairs case.
        lt = torch.maximum(bboxes1[:, None, :2], bboxes2[None, :, :2])
        rb = torch.minimum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
        wh = (rb - lt + offset).clamp(min=0)
        overlap = wh[..., 0] * wh[..., 1]
        area1 = ((bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset))[:, None]
        area2 = ((bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset))[None, :]

    if mode == "iof":
        union = area1
    elif mode == "giou":
        union = area1 + area2 - overlap
        # GIoU additionally needs the smallest enclosing box of each pair.
        if aligned:
            enclosed_lt = torch.minimum(bboxes1[:, :2], bboxes2[:, :2])
            enclosed_rb = torch.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
        else:
            enclosed_lt = torch.minimum(bboxes1[:, None, :2], bboxes2[None, :, :2])
            enclosed_rb = torch.maximum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
        enclosed_wh = (enclosed_rb - enclosed_lt + offset).clamp(min=0)
        enclosed_area = enclosed_wh[..., 0] * enclosed_wh[..., 1]
        iou = overlap / union.clamp(min=eps)
        return iou - (enclosed_area - union) / enclosed_area.clamp(min=eps)
    else:
        union = area1 + area2 - overlap
    return overlap / union.clamp(min=eps)


def _nms(
    boxes: Any,
    scores: Any,
    iou_threshold: float,
    offset: int = 0,
    score_threshold: float = 0,
    max_num: int = -1,
) -> tuple[Any, Any]:
    """``mmcv.ops.nms``-compatible wrapper around torchvision NMS.

    Returns:
        ``(dets, keep)`` where ``dets`` is ``(k, 5)`` (boxes + score column)
        and ``keep`` holds indices into the ORIGINAL ``boxes``/``scores``.

    Note: ``offset`` is accepted for signature compatibility but ignored by
    the torchvision backend.
    """
    import torch

    if boxes.numel() == 0 or scores.numel() == 0:
        keep = torch.empty((0,), dtype=torch.long, device=scores.device)
        dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
        return dets, keep

    # Pre-filter by score, remembering original positions so `keep` stays
    # expressed in the caller's index space.
    if score_threshold > 0:
        valid = scores > score_threshold
        original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
        filtered_boxes = boxes[valid]
        filtered_scores = scores[valid]
    else:
        original_indices = torch.arange(scores.numel(), device=scores.device)
        filtered_boxes = boxes
        filtered_scores = scores

    keep_local = _torchvision_nms()(filtered_boxes, filtered_scores, float(iou_threshold))
    if max_num > 0:
        # torchvision returns indices sorted by descending score, so a prefix
        # slice keeps the top-scoring detections.
        keep_local = keep_local[:max_num]
    keep = original_indices[keep_local]
    dets = torch.cat((filtered_boxes[keep_local], filtered_scores[keep_local, None]), dim=1)
    return dets, keep


def _batched_nms(
    boxes: Any,
    scores: Any,
    idxs: Any,
    nms_cfg: dict[str, Any] | None,
    class_agnostic: bool = False,
) -> tuple[Any, Any]:
    """``mmcv.ops.batched_nms``-compatible class-aware NMS.

    Class separation uses the standard coordinate-offset trick: boxes of
    different ``idxs`` are translated so far apart that they can never
    overlap, allowing a single torchvision NMS call.

    Args:
        nms_cfg: mmcv-style dict; ``iou_threshold``/``iou_thr``,
            ``score_threshold`` and ``max_num`` are honoured, other keys
            (e.g. ``type``) are ignored.  ``None`` skips NMS entirely and
            just sorts by score.
    """
    import torch

    if boxes.numel() == 0 or scores.numel() == 0:
        keep = torch.empty((0,), dtype=torch.long, device=scores.device)
        dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
        return dets, keep

    if nms_cfg is None:
        order = scores.argsort(descending=True)
        return torch.cat((boxes[order], scores[order, None]), dim=1), order

    nms_cfg = dict(nms_cfg)  # defensive copy; caller's dict must not mutate
    # BUGFIX: the previous nested pop `nms_cfg.pop("iou_threshold",
    # nms_cfg.pop("iou_thr", 0.5))` evaluated the inner pop eagerly, removing
    # "iou_thr" even when "iou_threshold" was present.  Look keys up lazily.
    if "iou_threshold" in nms_cfg:
        iou_threshold = nms_cfg.pop("iou_threshold")
    else:
        iou_threshold = nms_cfg.pop("iou_thr", 0.5)
    score_threshold = nms_cfg.pop("score_threshold", 0)
    max_num = nms_cfg.pop("max_num", -1)

    if class_agnostic:
        boxes_for_nms = boxes
    else:
        # Shift each class into its own disjoint coordinate region.
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + boxes.new_tensor(1))
        boxes_for_nms = boxes + offsets[:, None]

    if score_threshold > 0:
        valid = scores > score_threshold
        original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
        boxes_for_nms = boxes_for_nms[valid]
        scores_for_nms = scores[valid]
    else:
        original_indices = torch.arange(scores.numel(), device=scores.device)
        scores_for_nms = scores

    keep_local = _torchvision_nms()(boxes_for_nms, scores_for_nms, float(iou_threshold))
    if max_num > 0:
        keep_local = keep_local[:max_num]
    keep = original_indices[keep_local]
    # Return the UN-offset boxes to the caller.
    dets = torch.cat((boxes[keep], scores[keep, None]), dim=1)
    return dets, keep


def _nms_match(dets: Any, iou_threshold: float) -> list[Any]:
    """Pure PyTorch fallback for mmcv.ops.nms_match import paths.

    Greedily groups detections: each group is led by the highest-scoring
    remaining box and contains every box overlapping it above
    ``iou_threshold``.

    Args:
        dets: ``(n, 5)`` tensor of ``(x1, y1, x2, y2, score)``.

    Returns:
        List of 1-D index tensors; element 0 of each group is the keeper.
    """
    import torch

    if dets.numel() == 0:
        return []
    boxes = dets[:, :4]
    scores = dets[:, 4]
    order = scores.argsort(descending=True)
    groups = []
    while order.numel() > 0:
        current = order[0]
        if order.numel() == 1:
            groups.append(current.reshape(1))
            break
        rest = order[1:]
        lt = torch.maximum(boxes[current, :2], boxes[rest, :2])
        rb = torch.minimum(boxes[current, 2:], boxes[rest, 2:])
        wh = (rb - lt).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]
        current_area = (boxes[current, 2] - boxes[current, 0]).clamp(min=0) * (
            boxes[current, 3] - boxes[current, 1]
        ).clamp(min=0)
        rest_area = (boxes[rest, 2] - boxes[rest, 0]).clamp(min=0) * (
            boxes[rest, 3] - boxes[rest, 1]
        ).clamp(min=0)
        iou = inter / (current_area + rest_area - inter).clamp(min=1e-6)
        matched = rest[iou > float(iou_threshold)]
        groups.append(torch.cat((current.reshape(1), matched)))
        order = rest[iou <= float(iou_threshold)]
    return groups


def _point_sample(input: Any, points: Any, align_corners: bool = False, **kwargs: Any) -> Any:
    """``mmcv.ops.point_sample`` fallback via ``F.grid_sample``.

    ``points`` are in [0, 1] and are mapped to grid_sample's [-1, 1] range.
    A 3-D ``points`` tensor (batch, num_points, 2) is temporarily expanded
    to 4-D and the extra dimension squeezed from the output.
    """
    import torch.nn.functional as F

    add_dim = False
    if points.dim() == 3:
        add_dim = True
        points = points.unsqueeze(2)
    output = F.grid_sample(input, points.mul(2).sub(1), align_corners=align_corners, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


def _rel_roi_point_to_rel_img_point(rois: Any, rel_roi_points: Any, img_shape: Any) -> Any:
    """Convert RoI-relative [0, 1] points to image-relative [0, 1] points.

    Args:
        rois: ``(n, 5)`` tensor ``(batch_idx, x1, y1, x2, y2)``.
        rel_roi_points: ``(n, p, 2)`` points relative to each RoI.
        img_shape: ``(h, w, ...)`` image shape.
    """
    x1, y1, x2, y2 = rois[:, 1], rois[:, 2], rois[:, 3], rois[:, 4]
    # Clamp degenerate RoIs to 1px so the division stays well-defined.
    roi_w = (x2 - x1).clamp(min=1)
    roi_h = (y2 - y1).clamp(min=1)
    img_h, img_w = img_shape[:2]
    rel_img_points = rel_roi_points.clone()
    rel_img_points[..., 0] = (x1[:, None] + rel_roi_points[..., 0] * roi_w[:, None]) / float(img_w)
    rel_img_points[..., 1] = (y1[:, None] + rel_roi_points[..., 1] * roi_h[:, None]) / float(img_h)
    return rel_img_points


def _sigmoid_focal_loss(
    pred: Any,
    target: Any,
    gamma: float = 2.0,
    alpha: float = 0.25,
    weight: Any = None,
    reduction: str = "mean",
) -> Any:
    """Pure-PyTorch sigmoid focal loss (Lin et al., RetinaNet).

    Args:
        pred: raw logits.
        target: binary targets, same shape as ``pred``.
        gamma: focusing exponent on the misclassification probability.
        alpha: positive-class balance factor.
        weight: optional per-element weight.
        reduction: ``"mean"`` | ``"sum"`` | anything else -> no reduction.
    """
    import torch.nn.functional as F

    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt = probability assigned to the WRONG class; high pt => hard example.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight
    if weight is not None:
        loss = loss * weight
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss


def _ensure_module(module_name: str) -> types.ModuleType:
    """Return ``sys.modules[module_name]``, creating an empty module if absent."""
    module = sys.modules.get(module_name)
    if module is None:
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
    return module


# Cached path of the CSS file handed to gradio; reused across launches.
_GRADIO_CSS_PATCH_PATH: Path | None = None


def _normalize_gradio_css_paths(css_paths: Any) -> list[str]:
    """Coerce gradio's ``css_paths`` argument (None | path | iterable) to list[str]."""
    if css_paths is None:
        return []
    if isinstance(css_paths, (str, Path)):
        return [str(css_paths)]
    return [str(path) for path in css_paths]


def _persist_egoforce_gradio_css(css: str) -> str:
    """Write *css* to a stable temp file and return its path.

    NOTE(review): the path is fixed and predictable inside the shared temp
    dir; acceptable for a demo, but a hostile local user could pre-create
    the file/symlink.  Confirm this is acceptable for the deployment target.
    """
    global _GRADIO_CSS_PATCH_PATH
    if _GRADIO_CSS_PATCH_PATH is None:
        _GRADIO_CSS_PATCH_PATH = Path(tempfile.gettempdir()) / "egoforce-gradio-launch.css"
    _GRADIO_CSS_PATCH_PATH.write_text(css, encoding="utf-8")
    return str(_GRADIO_CSS_PATCH_PATH)


def _patch_gradio_launch() -> None:
    """Monkey-patch ``gradio.Blocks.launch`` to serve EgoForce CSS via css_paths.

    Only rewrites calls whose inline ``css`` contains EgoForce markers; the
    patch is idempotent (guarded by ``__egoforce_runtime_patch__``) and a
    no-op when gradio is not installed.
    """
    try:
        import gradio as gr
    except ImportError:
        return
    launch_method = getattr(gr.Blocks, "launch", None)
    if launch_method is None or getattr(launch_method, "__egoforce_runtime_patch__", False):
        return

    @functools.wraps(launch_method)
    def patched_launch(self: Any, *args: Any, **kwargs: Any) -> Any:
        css = kwargs.get("css")
        # Only intercept EgoForce's own stylesheet, identified by selectors.
        if isinstance(css, str) and css.strip() and (".egoforce-hero" in css or "#sample-video-carousel" in css):
            css_path_entries = _normalize_gradio_css_paths(kwargs.get("css_paths"))
            patched_css_path = _persist_egoforce_gradio_css(css)
            if patched_css_path not in css_path_entries:
                css_path_entries.append(patched_css_path)
            kwargs["css_paths"] = css_path_entries
            kwargs["css"] = None
        return launch_method(self, *args, **kwargs)

    setattr(patched_launch, "__egoforce_runtime_patch__", True)
    gr.Blocks.launch = patched_launch


def apply_runtime_patches() -> None:
    """Install all runtime shims: gradio launch patch + synthetic ``mmcv.ops``.

    Creates (or reuses) ``mmcv.ops`` and its submodules in ``sys.modules``,
    populating them with torchvision wrappers, pure-PyTorch fallbacks and
    raise-on-use stubs so that mmdet's import-time registrations succeed
    under mmcv-lite.  Idempotent with respect to already-present modules.
    """
    _patch_gradio_launch()
    try:
        mmcv = importlib.import_module("mmcv")
    except ImportError:
        mmcv = None

    ops_module = sys.modules.get("mmcv.ops")
    if ops_module is None:
        ops_module = types.ModuleType("mmcv.ops")
        sys.modules["mmcv.ops"] = ops_module
        # Empty __path__ marks it as a package so submodule imports resolve.
        ops_module.__path__ = []

    # Submodules that mmdet imports from directly.
    nms_module = _ensure_module("mmcv.ops.nms")
    roi_align_module = _ensure_module("mmcv.ops.roi_align")
    deform_conv_module = _ensure_module("mmcv.ops.deform_conv")
    modulated_deform_conv_module = _ensure_module("mmcv.ops.modulated_deform_conv")
    carafe_module = _ensure_module("mmcv.ops.carafe")
    merge_cells_module = _ensure_module("mmcv.ops.merge_cells")
    multi_scale_deform_attn_module = _ensure_module("mmcv.ops.multi_scale_deform_attn")

    # Raise-on-use stubs for ops needing compiled extensions.
    deform_conv2d = _unsupported_function("deform_conv2d")
    DeformConv2d = _unsupported_module_class("DeformConv2d")
    ModulatedDeformConv2d = _unsupported_module_class("ModulatedDeformConv2d")
    MaskedConv2d = _unsupported_module_class("MaskedConv2d")
    CornerPool = _unsupported_module_class("CornerPool")
    CARAFEPack = _unsupported_module_class("CARAFEPack")
    GlobalPoolingCell = _unsupported_module_class("GlobalPoolingCell")
    SumCell = _unsupported_module_class("SumCell")
    ConcatCell = _unsupported_module_class("ConcatCell")
    MultiScaleDeformableAttention = _multiscale_deformable_attention_class()

    ops_module.nms = _nms
    ops_module.batched_nms = _batched_nms
    ops_module.nms_match = _nms_match
    ops_module.point_sample = _point_sample
    ops_module.rel_roi_point_to_rel_img_point = _rel_roi_point_to_rel_img_point
    ops_module.sigmoid_focal_loss = _sigmoid_focal_loss
    ops_module.bbox_overlaps = _bbox_overlaps
    ops_module.roi_align = _torchvision_roi_align()
    ops_module.RoIAlign = _torchvision_roi_align_module()
    ops_module.RoIPool = _torchvision_roi_pool_module()
    ops_module.deform_conv2d = deform_conv2d
    ops_module.DeformConv2d = DeformConv2d
    ops_module.ModulatedDeformConv2d = ModulatedDeformConv2d
    ops_module.MaskedConv2d = MaskedConv2d
    ops_module.CornerPool = CornerPool
    ops_module.CARAFEPack = CARAFEPack
    ops_module.GlobalPoolingCell = GlobalPoolingCell
    ops_module.SumCell = SumCell
    ops_module.ConcatCell = ConcatCell
    ops_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention

    nms_module.nms = _nms
    nms_module.batched_nms = _batched_nms
    roi_align_module.roi_align = ops_module.roi_align
    roi_align_module.RoIAlign = ops_module.RoIAlign
    deform_conv_module.deform_conv2d = deform_conv2d
    deform_conv_module.DeformConv2d = DeformConv2d
    modulated_deform_conv_module.ModulatedDeformConv2d = ModulatedDeformConv2d
    carafe_module.CARAFEPack = CARAFEPack
    merge_cells_module.GlobalPoolingCell = GlobalPoolingCell
    merge_cells_module.SumCell = SumCell
    merge_cells_module.ConcatCell = ConcatCell
    multi_scale_deform_attn_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention

    # If full mmcv's transformer bricks exist but lack MSDA, backfill the stub.
    try:
        transformer_module = importlib.import_module("mmcv.cnn.bricks.transformer")
        if not hasattr(transformer_module, "MultiScaleDeformableAttention"):
            transformer_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
    except ImportError:
        pass

    if mmcv is not None:
        mmcv.ops = ops_module