| from __future__ import annotations |
|
|
| import functools |
| import importlib |
| import sys |
| import tempfile |
| import types |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
def _torchvision_nms():
    """Lazily resolve torchvision's NMS op so torchvision is not a hard import-time dependency."""
    import torchvision.ops as tv_ops

    return tv_ops.nms
|
|
|
|
def _torchvision_roi_align():
    """Lazily resolve torchvision's functional ``roi_align``."""
    import torchvision.ops as tv_ops

    return tv_ops.roi_align
|
|
|
|
def _torchvision_roi_align_module():
    """Lazily resolve torchvision's ``RoIAlign`` module class."""
    import torchvision.ops as tv_ops

    return tv_ops.RoIAlign
|
|
|
|
def _torchvision_roi_pool_module():
    """Lazily resolve torchvision's ``RoIPool`` module class."""
    import torchvision.ops as tv_ops

    return tv_ops.RoIPool
|
|
|
|
| def _multiscale_deformable_attention_class(): |
| import torch.nn as nn |
|
|
| class MultiScaleDeformableAttention(nn.Module): |
| """Import-only fallback for mmdet registries when running with mmcv-lite.""" |
|
|
| def __init__(self, *args: Any, **kwargs: Any) -> None: |
| super().__init__() |
|
|
| def init_weights(self) -> None: |
| return None |
|
|
| def forward(self, *args: Any, **kwargs: Any) -> Any: |
| raise RuntimeError( |
| "MultiScaleDeformableAttention requires full mmcv with compiled ops. " |
| "The EgoForce demo uses RTMDet and should not execute this layer." |
| ) |
|
|
| return MultiScaleDeformableAttention |
|
|
|
|
| def _unsupported_module_class(name: str): |
| import torch.nn as nn |
|
|
| class UnsupportedMMCVOp(nn.Module): |
| def __init__(self, *args: Any, **kwargs: Any) -> None: |
| super().__init__() |
|
|
| def init_weights(self) -> None: |
| return None |
|
|
| def forward(self, *args: Any, **kwargs: Any) -> Any: |
| raise RuntimeError(f"{name} requires full mmcv with compiled ops and is not used by the EgoForce RTMDet demo.") |
|
|
| UnsupportedMMCVOp.__name__ = name |
| return UnsupportedMMCVOp |
|
|
|
|
| def _unsupported_function(name: str): |
| def unsupported(*args: Any, **kwargs: Any) -> Any: |
| raise RuntimeError(f"{name} requires full mmcv with compiled ops and is not used by the EgoForce RTMDet demo.") |
|
|
| unsupported.__name__ = name |
| return unsupported |
|
|
|
|
| def _bbox_overlaps( |
| bboxes1: Any, |
| bboxes2: Any, |
| mode: str = "iou", |
| aligned: bool = False, |
| offset: int = 0, |
| eps: float = 1e-6, |
| ) -> Any: |
| import torch |
|
|
| if bboxes1.numel() == 0 or bboxes2.numel() == 0: |
| if aligned: |
| return bboxes1.new_zeros((bboxes1.shape[0],)) |
| return bboxes1.new_zeros((bboxes1.shape[0], bboxes2.shape[0])) |
|
|
| if aligned: |
| lt = torch.maximum(bboxes1[:, :2], bboxes2[:, :2]) |
| rb = torch.minimum(bboxes1[:, 2:], bboxes2[:, 2:]) |
| wh = (rb - lt + offset).clamp(min=0) |
| overlap = wh[:, 0] * wh[:, 1] |
| area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset) |
| area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset) |
| else: |
| lt = torch.maximum(bboxes1[:, None, :2], bboxes2[None, :, :2]) |
| rb = torch.minimum(bboxes1[:, None, 2:], bboxes2[None, :, 2:]) |
| wh = (rb - lt + offset).clamp(min=0) |
| overlap = wh[..., 0] * wh[..., 1] |
| area1 = ((bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset))[:, None] |
| area2 = ((bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset))[None, :] |
|
|
| if mode == "iof": |
| union = area1 |
| elif mode == "giou": |
| union = area1 + area2 - overlap |
| if aligned: |
| enclosed_lt = torch.minimum(bboxes1[:, :2], bboxes2[:, :2]) |
| enclosed_rb = torch.maximum(bboxes1[:, 2:], bboxes2[:, 2:]) |
| else: |
| enclosed_lt = torch.minimum(bboxes1[:, None, :2], bboxes2[None, :, :2]) |
| enclosed_rb = torch.maximum(bboxes1[:, None, 2:], bboxes2[None, :, 2:]) |
| enclosed_wh = (enclosed_rb - enclosed_lt + offset).clamp(min=0) |
| enclosed_area = enclosed_wh[..., 0] * enclosed_wh[..., 1] |
| iou = overlap / union.clamp(min=eps) |
| return iou - (enclosed_area - union) / enclosed_area.clamp(min=eps) |
| else: |
| union = area1 + area2 - overlap |
|
|
| return overlap / union.clamp(min=eps) |
|
|
|
|
def _nms(
    boxes: Any,
    scores: Any,
    iou_threshold: float,
    offset: int = 0,
    score_threshold: float = 0,
    max_num: int = -1,
) -> tuple[Any, Any]:
    """Torchvision-backed replacement for ``mmcv.ops.nms``.

    Returns ``(dets, keep)`` where ``dets`` rows are
    ``[x1, y1, x2, y2, score]`` for the kept boxes and ``keep`` holds their
    indices into the original input. ``offset`` is accepted for signature
    compatibility but unused, as in the original shim.
    """
    import torch

    # Empty input: empty detections with the right column count.
    if boxes.numel() == 0 or scores.numel() == 0:
        empty_keep = torch.empty((0,), dtype=torch.long, device=scores.device)
        empty_dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
        return empty_dets, empty_keep

    # Optional score pre-filter; remember the surviving original indices.
    if score_threshold > 0:
        mask = scores > score_threshold
        index_map = torch.nonzero(mask, as_tuple=False).squeeze(1)
        candidate_boxes, candidate_scores = boxes[mask], scores[mask]
    else:
        index_map = torch.arange(scores.numel(), device=scores.device)
        candidate_boxes, candidate_scores = boxes, scores

    selected = _torchvision_nms()(candidate_boxes, candidate_scores, float(iou_threshold))
    if max_num > 0:
        selected = selected[:max_num]

    keep = index_map[selected]
    dets = torch.cat((candidate_boxes[selected], candidate_scores[selected, None]), dim=1)
    return dets, keep
|
|
|
|
def _batched_nms(
    boxes: Any,
    scores: Any,
    idxs: Any,
    nms_cfg: dict[str, Any] | None,
    class_agnostic: bool = False,
) -> tuple[Any, Any]:
    """Torchvision-backed replacement for ``mmcv.ops.batched_nms``.

    Boxes with different ``idxs`` (class ids) never suppress each other
    unless ``class_agnostic`` is set; this is achieved by shifting each
    class's boxes into a disjoint coordinate range before running NMS.
    With ``nms_cfg=None`` all boxes are kept, sorted by score.
    """
    import torch

    # Empty input: empty detections with the right column count.
    if boxes.numel() == 0 or scores.numel() == 0:
        no_keep = torch.empty((0,), dtype=torch.long, device=scores.device)
        no_dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
        return no_dets, no_keep

    if nms_cfg is None:
        order = scores.argsort(descending=True)
        return torch.cat((boxes[order], scores[order, None]), dim=1), order

    cfg = dict(nms_cfg)  # copy so the pops below don't mutate the caller's dict
    iou_threshold = cfg.pop("iou_threshold", cfg.pop("iou_thr", 0.5))
    score_threshold = cfg.pop("score_threshold", 0)
    max_num = cfg.pop("max_num", -1)

    if class_agnostic:
        shifted_boxes = boxes
    else:
        # Offset each class into its own coordinate range so classes
        # cannot suppress one another during a single NMS pass.
        max_coordinate = boxes.max()
        per_class_shift = idxs.to(boxes) * (max_coordinate + boxes.new_tensor(1))
        shifted_boxes = boxes + per_class_shift[:, None]

    # Optional score pre-filter; remember the surviving original indices.
    if score_threshold > 0:
        mask = scores > score_threshold
        index_map = torch.nonzero(mask, as_tuple=False).squeeze(1)
        shifted_boxes = shifted_boxes[mask]
        nms_scores = scores[mask]
    else:
        index_map = torch.arange(scores.numel(), device=scores.device)
        nms_scores = scores

    selected = _torchvision_nms()(shifted_boxes, nms_scores, float(iou_threshold))
    if max_num > 0:
        selected = selected[:max_num]

    # Detections are reported in the original (un-shifted) coordinates.
    keep = index_map[selected]
    dets = torch.cat((boxes[keep], scores[keep, None]), dim=1)
    return dets, keep
|
|
|
|
| def _nms_match(dets: Any, iou_threshold: float) -> list[Any]: |
| """Pure PyTorch fallback for mmcv.ops.nms_match import paths.""" |
| import torch |
|
|
| if dets.numel() == 0: |
| return [] |
|
|
| boxes = dets[:, :4] |
| scores = dets[:, 4] |
| order = scores.argsort(descending=True) |
| groups = [] |
|
|
| while order.numel() > 0: |
| current = order[0] |
| if order.numel() == 1: |
| groups.append(current.reshape(1)) |
| break |
|
|
| rest = order[1:] |
| lt = torch.maximum(boxes[current, :2], boxes[rest, :2]) |
| rb = torch.minimum(boxes[current, 2:], boxes[rest, 2:]) |
| wh = (rb - lt).clamp(min=0) |
| inter = wh[:, 0] * wh[:, 1] |
| current_area = (boxes[current, 2] - boxes[current, 0]).clamp(min=0) * ( |
| boxes[current, 3] - boxes[current, 1] |
| ).clamp(min=0) |
| rest_area = (boxes[rest, 2] - boxes[rest, 0]).clamp(min=0) * ( |
| boxes[rest, 3] - boxes[rest, 1] |
| ).clamp(min=0) |
| iou = inter / (current_area + rest_area - inter).clamp(min=1e-6) |
| matched = rest[iou > float(iou_threshold)] |
| groups.append(torch.cat((current.reshape(1), matched))) |
| order = rest[iou <= float(iou_threshold)] |
|
|
| return groups |
|
|
|
|
| def _point_sample(input: Any, points: Any, align_corners: bool = False, **kwargs: Any) -> Any: |
| import torch.nn.functional as F |
|
|
| add_dim = False |
| if points.dim() == 3: |
| add_dim = True |
| points = points.unsqueeze(2) |
|
|
| output = F.grid_sample(input, points.mul(2).sub(1), align_corners=align_corners, **kwargs) |
| if add_dim: |
| output = output.squeeze(3) |
| return output |
|
|
|
|
| def _rel_roi_point_to_rel_img_point(rois: Any, rel_roi_points: Any, img_shape: Any) -> Any: |
| x1, y1, x2, y2 = rois[:, 1], rois[:, 2], rois[:, 3], rois[:, 4] |
| roi_w = (x2 - x1).clamp(min=1) |
| roi_h = (y2 - y1).clamp(min=1) |
| img_h, img_w = img_shape[:2] |
| rel_img_points = rel_roi_points.clone() |
| rel_img_points[..., 0] = (x1[:, None] + rel_roi_points[..., 0] * roi_w[:, None]) / float(img_w) |
| rel_img_points[..., 1] = (y1[:, None] + rel_roi_points[..., 1] * roi_h[:, None]) / float(img_h) |
| return rel_img_points |
|
|
|
|
| def _sigmoid_focal_loss( |
| pred: Any, |
| target: Any, |
| gamma: float = 2.0, |
| alpha: float = 0.25, |
| weight: Any = None, |
| reduction: str = "mean", |
| ) -> Any: |
| import torch.nn.functional as F |
|
|
| pred_sigmoid = pred.sigmoid() |
| target = target.type_as(pred) |
| pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) |
| focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma) |
| loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight |
| if weight is not None: |
| loss = loss * weight |
| if reduction == "sum": |
| return loss.sum() |
| if reduction == "mean": |
| return loss.mean() |
| return loss |
|
|
|
|
| def _ensure_module(module_name: str) -> types.ModuleType: |
| module = sys.modules.get(module_name) |
| if module is None: |
| module = types.ModuleType(module_name) |
| sys.modules[module_name] = module |
| return module |
|
|
|
|
# Cached temp-file path where EgoForce's launch CSS is persisted.
# Set lazily by _persist_egoforce_gradio_css and reused across launches.
_GRADIO_CSS_PATCH_PATH: Path | None = None
|
|
|
|
| def _normalize_gradio_css_paths(css_paths: Any) -> list[str]: |
| if css_paths is None: |
| return [] |
| if isinstance(css_paths, (str, Path)): |
| return [str(css_paths)] |
| return [str(path) for path in css_paths] |
|
|
|
|
def _persist_egoforce_gradio_css(css: str) -> str:
    """Persist *css* to a stable temp file and return that file's path as a string.

    The path is cached in ``_GRADIO_CSS_PATCH_PATH`` so repeated launches
    reuse one file; contents are rewritten on every call to stay current.
    """
    global _GRADIO_CSS_PATCH_PATH

    target = _GRADIO_CSS_PATCH_PATH
    if target is None:
        target = Path(tempfile.gettempdir()) / "egoforce-gradio-launch.css"
        _GRADIO_CSS_PATCH_PATH = target
    target.write_text(css, encoding="utf-8")
    return str(target)
|
|
|
|
| def _patch_gradio_launch() -> None: |
| try: |
| import gradio as gr |
| except ImportError: |
| return |
|
|
| launch_method = getattr(gr.Blocks, "launch", None) |
| if launch_method is None or getattr(launch_method, "__egoforce_runtime_patch__", False): |
| return |
|
|
| @functools.wraps(launch_method) |
| def patched_launch(self: Any, *args: Any, **kwargs: Any) -> Any: |
| css = kwargs.get("css") |
| if isinstance(css, str) and css.strip() and (".egoforce-hero" in css or "#sample-video-carousel" in css): |
| css_path_entries = _normalize_gradio_css_paths(kwargs.get("css_paths")) |
| patched_css_path = _persist_egoforce_gradio_css(css) |
| if patched_css_path not in css_path_entries: |
| css_path_entries.append(patched_css_path) |
| kwargs["css_paths"] = css_path_entries |
| kwargs["css"] = None |
| return launch_method(self, *args, **kwargs) |
|
|
| setattr(patched_launch, "__egoforce_runtime_patch__", True) |
| gr.Blocks.launch = patched_launch |
|
|
|
|
def apply_runtime_patches() -> None:
    """Install EgoForce's runtime patches.

    Applies the Gradio launch CSS patch, then builds a synthetic
    ``mmcv.ops`` module tree in ``sys.modules`` backed by torchvision /
    pure-PyTorch fallbacks so mmdet imports succeed under mmcv-lite
    (no compiled extensions). Ops with no pure-Python equivalent are
    installed as stubs that raise on execution.
    """
    _patch_gradio_launch()


    # mmcv itself may be absent; the ops shims are installed regardless so
    # that `import mmcv.ops` import paths resolve.
    try:
        mmcv = importlib.import_module("mmcv")
    except ImportError:
        mmcv = None


    # Create (or reuse) the top-level mmcv.ops module.
    ops_module = sys.modules.get("mmcv.ops")
    if ops_module is None:
        ops_module = types.ModuleType("mmcv.ops")
        sys.modules["mmcv.ops"] = ops_module
    ops_module.__path__ = []  # mark as a package so submodule imports work


    # Submodules that downstream code imports from directly.
    nms_module = _ensure_module("mmcv.ops.nms")


    roi_align_module = _ensure_module("mmcv.ops.roi_align")
    deform_conv_module = _ensure_module("mmcv.ops.deform_conv")
    modulated_deform_conv_module = _ensure_module("mmcv.ops.modulated_deform_conv")
    carafe_module = _ensure_module("mmcv.ops.carafe")
    merge_cells_module = _ensure_module("mmcv.ops.merge_cells")
    multi_scale_deform_attn_module = _ensure_module("mmcv.ops.multi_scale_deform_attn")


    # Ops with no pure-PyTorch equivalent become import-safe stubs that
    # raise RuntimeError if executed.
    deform_conv2d = _unsupported_function("deform_conv2d")
    DeformConv2d = _unsupported_module_class("DeformConv2d")
    ModulatedDeformConv2d = _unsupported_module_class("ModulatedDeformConv2d")
    MaskedConv2d = _unsupported_module_class("MaskedConv2d")
    CornerPool = _unsupported_module_class("CornerPool")
    CARAFEPack = _unsupported_module_class("CARAFEPack")
    GlobalPoolingCell = _unsupported_module_class("GlobalPoolingCell")
    SumCell = _unsupported_module_class("SumCell")
    ConcatCell = _unsupported_module_class("ConcatCell")
    MultiScaleDeformableAttention = _multiscale_deformable_attention_class()


    # Wire the working fallbacks and the stubs onto mmcv.ops ...
    ops_module.nms = _nms
    ops_module.batched_nms = _batched_nms
    ops_module.nms_match = _nms_match
    ops_module.point_sample = _point_sample
    ops_module.rel_roi_point_to_rel_img_point = _rel_roi_point_to_rel_img_point
    ops_module.sigmoid_focal_loss = _sigmoid_focal_loss
    ops_module.bbox_overlaps = _bbox_overlaps
    ops_module.roi_align = _torchvision_roi_align()
    ops_module.RoIAlign = _torchvision_roi_align_module()
    ops_module.RoIPool = _torchvision_roi_pool_module()
    ops_module.deform_conv2d = deform_conv2d
    ops_module.DeformConv2d = DeformConv2d
    ops_module.ModulatedDeformConv2d = ModulatedDeformConv2d
    ops_module.MaskedConv2d = MaskedConv2d
    ops_module.CornerPool = CornerPool
    ops_module.CARAFEPack = CARAFEPack
    ops_module.GlobalPoolingCell = GlobalPoolingCell
    ops_module.SumCell = SumCell
    ops_module.ConcatCell = ConcatCell
    ops_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
    # ... and mirror them onto the specific submodules mmdet imports from.
    nms_module.nms = _nms
    nms_module.batched_nms = _batched_nms
    roi_align_module.roi_align = ops_module.roi_align
    roi_align_module.RoIAlign = ops_module.RoIAlign
    deform_conv_module.deform_conv2d = deform_conv2d
    deform_conv_module.DeformConv2d = DeformConv2d
    modulated_deform_conv_module.ModulatedDeformConv2d = ModulatedDeformConv2d
    carafe_module.CARAFEPack = CARAFEPack
    merge_cells_module.GlobalPoolingCell = GlobalPoolingCell
    merge_cells_module.SumCell = SumCell
    merge_cells_module.ConcatCell = ConcatCell
    multi_scale_deform_attn_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention


    # mmcv.cnn.bricks.transformer re-exports the attention class; fill it in
    # only if the installed mmcv doesn't already provide one.
    try:
        transformer_module = importlib.import_module("mmcv.cnn.bricks.transformer")
        if not hasattr(transformer_module, "MultiScaleDeformableAttention"):
            transformer_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
    except ImportError:
        pass


    # Attach the synthetic ops package to the real mmcv module when present.
    if mmcv is not None:
        mmcv.ops = ops_module
|
|