# EgoForce / egoforce_runtime_patches.py
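"""Runtime patches for the EgoForce demo.

Replaces ``mmcv.ops`` with torchvision and pure-PyTorch fallbacks so the
RTMDet demo can run against an mmcv build without compiled ops, and reroutes
inline Gradio launch CSS through ``css_paths``.
"""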
from __future__ import annotations
import functools
import importlib
import sys
import tempfile
import types
from pathlib import Path
from typing import Any
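# Lazy torchvision imports: deferred to call time so merely importing this
# module does not require torchvision to be installed.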
def _torchvision_nms():
from torchvision.ops import nms
return nms
def _torchvision_roi_align():
from torchvision.ops import roi_align
return roi_align
def _torchvision_roi_align_module():
from torchvision.ops import RoIAlign
return RoIAlign
def _torchvision_roi_pool_module():
from torchvision.ops import RoIPool
return RoIPool
def _multiscale_deformable_attention_class():
import torch.nn as nn
class MultiScaleDeformableAttention(nn.Module):
"""Import-only fallback for mmdet registries when running with mmcv-lite."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
def init_weights(self) -> None:
return None
def forward(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
"MultiScaleDeformableAttention requires full mmcv with compiled ops. "
"The EgoForce demo uses RTMDet and should not execute this layer."
)
return MultiScaleDeformableAttention
def _unsupported_module_class(name: str):
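    """Build a placeholder ``nn.Module`` subclass named ``name``.

    The stub accepts any constructor arguments but raises if its ``forward``
    is ever executed.
    """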
import torch.nn as nn
class UnsupportedMMCVOp(nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
def init_weights(self) -> None:
return None
        def forward(self, *args: Any, **kwargs: Any) -> Any:
            raise RuntimeError(
                f"{name} requires full mmcv with compiled ops and is not "
                "used by the EgoForce RTMDet demo."
            )
    # Rename the stub so registry and traceback messages name the real op.
    UnsupportedMMCVOp.__name__ = name
    UnsupportedMMCVOp.__qualname__ = name
    return UnsupportedMMCVOp
def _unsupported_function(name: str):
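    """Build a stub function named ``name`` that raises whenever it is called."""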
    def unsupported(*args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            f"{name} requires full mmcv with compiled ops and is not "
            "used by the EgoForce RTMDet demo."
        )
    unsupported.__name__ = name
    unsupported.__qualname__ = name
    return unsupported
def _bbox_overlaps(
bboxes1: Any,
bboxes2: Any,
mode: str = "iou",
aligned: bool = False,
offset: int = 0,
eps: float = 1e-6,
) -> Any:
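    """Pure-PyTorch fallback for ``mmcv.ops.bbox_overlaps``.

    Supports ``mode`` in {"iou", "iof", "giou"}. With ``aligned=True`` the
    two box sets are compared row by row; otherwise all pairs are compared.
    An ``offset`` of 1 treats box coordinates as pixel-inclusive.
    """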
    import torch
    assert mode in ("iou", "iof", "giou"), f"Unsupported mode {mode!r}"
if bboxes1.numel() == 0 or bboxes2.numel() == 0:
if aligned:
return bboxes1.new_zeros((bboxes1.shape[0],))
return bboxes1.new_zeros((bboxes1.shape[0], bboxes2.shape[0]))
if aligned:
lt = torch.maximum(bboxes1[:, :2], bboxes2[:, :2])
rb = torch.minimum(bboxes1[:, 2:], bboxes2[:, 2:])
wh = (rb - lt + offset).clamp(min=0)
overlap = wh[:, 0] * wh[:, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset)
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset)
else:
lt = torch.maximum(bboxes1[:, None, :2], bboxes2[None, :, :2])
rb = torch.minimum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
wh = (rb - lt + offset).clamp(min=0)
overlap = wh[..., 0] * wh[..., 1]
area1 = ((bboxes1[:, 2] - bboxes1[:, 0] + offset) * (bboxes1[:, 3] - bboxes1[:, 1] + offset))[:, None]
area2 = ((bboxes2[:, 2] - bboxes2[:, 0] + offset) * (bboxes2[:, 3] - bboxes2[:, 1] + offset))[None, :]
if mode == "iof":
union = area1
elif mode == "giou":
union = area1 + area2 - overlap
if aligned:
enclosed_lt = torch.minimum(bboxes1[:, :2], bboxes2[:, :2])
enclosed_rb = torch.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
else:
enclosed_lt = torch.minimum(bboxes1[:, None, :2], bboxes2[None, :, :2])
enclosed_rb = torch.maximum(bboxes1[:, None, 2:], bboxes2[None, :, 2:])
enclosed_wh = (enclosed_rb - enclosed_lt + offset).clamp(min=0)
enclosed_area = enclosed_wh[..., 0] * enclosed_wh[..., 1]
iou = overlap / union.clamp(min=eps)
return iou - (enclosed_area - union) / enclosed_area.clamp(min=eps)
else:
union = area1 + area2 - overlap
return overlap / union.clamp(min=eps)
def _nms(
boxes: Any,
scores: Any,
iou_threshold: float,
offset: int = 0,
score_threshold: float = 0,
max_num: int = -1,
) -> tuple[Any, Any]:
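    """torchvision-backed fallback for ``mmcv.ops.nms``.

    Returns ``(dets, keep)`` where ``dets`` stacks ``[x1, y1, x2, y2, score]``
    and ``keep`` indexes into the original ``boxes``. Note that ``offset`` is
    accepted only for signature compatibility; torchvision's NMS assumes 0.
    """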
import torch
if boxes.numel() == 0 or scores.numel() == 0:
keep = torch.empty((0,), dtype=torch.long, device=scores.device)
dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
return dets, keep
if score_threshold > 0:
valid = scores > score_threshold
original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
filtered_boxes = boxes[valid]
filtered_scores = scores[valid]
else:
original_indices = torch.arange(scores.numel(), device=scores.device)
filtered_boxes = boxes
filtered_scores = scores
keep_local = _torchvision_nms()(filtered_boxes, filtered_scores, float(iou_threshold))
if max_num > 0:
keep_local = keep_local[:max_num]
keep = original_indices[keep_local]
dets = torch.cat((filtered_boxes[keep_local], filtered_scores[keep_local, None]), dim=1)
return dets, keep
def _batched_nms(
boxes: Any,
scores: Any,
idxs: Any,
nms_cfg: dict[str, Any] | None,
class_agnostic: bool = False,
) -> tuple[Any, Any]:
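    """torchvision-backed fallback for ``mmcv.ops.batched_nms``.

    Unless the NMS is class agnostic, boxes are shifted by a per-class offset
    so one NMS pass never suppresses boxes across different ``idxs`` values.
    Keys of ``nms_cfg`` other than the ones popped below are ignored.
    """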
import torch
if boxes.numel() == 0 or scores.numel() == 0:
keep = torch.empty((0,), dtype=torch.long, device=scores.device)
dets = torch.cat((boxes.reshape(0, boxes.shape[-1]), scores.reshape(0, 1)), dim=1)
return dets, keep
if nms_cfg is None:
order = scores.argsort(descending=True)
return torch.cat((boxes[order], scores[order, None]), dim=1), order
nms_cfg = dict(nms_cfg)
iou_threshold = nms_cfg.pop("iou_threshold", nms_cfg.pop("iou_thr", 0.5))
score_threshold = nms_cfg.pop("score_threshold", 0)
    max_num = nms_cfg.pop("max_num", -1)
    # Match mmcv's batched_nms, which allows nms_cfg to override class_agnostic.
    class_agnostic = nms_cfg.pop("class_agnostic", class_agnostic)
    if class_agnostic:
boxes_for_nms = boxes
else:
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + boxes.new_tensor(1))
boxes_for_nms = boxes + offsets[:, None]
if score_threshold > 0:
valid = scores > score_threshold
original_indices = torch.nonzero(valid, as_tuple=False).squeeze(1)
boxes_for_nms = boxes_for_nms[valid]
scores_for_nms = scores[valid]
else:
original_indices = torch.arange(scores.numel(), device=scores.device)
scores_for_nms = scores
keep_local = _torchvision_nms()(boxes_for_nms, scores_for_nms, float(iou_threshold))
if max_num > 0:
keep_local = keep_local[:max_num]
keep = original_indices[keep_local]
dets = torch.cat((boxes[keep], scores[keep, None]), dim=1)
return dets, keep
def _nms_match(dets: Any, iou_threshold: float) -> list[Any]:
"""Pure PyTorch fallback for mmcv.ops.nms_match import paths."""
import torch
if dets.numel() == 0:
return []
boxes = dets[:, :4]
scores = dets[:, 4]
order = scores.argsort(descending=True)
groups = []
while order.numel() > 0:
current = order[0]
if order.numel() == 1:
groups.append(current.reshape(1))
break
rest = order[1:]
lt = torch.maximum(boxes[current, :2], boxes[rest, :2])
rb = torch.minimum(boxes[current, 2:], boxes[rest, 2:])
wh = (rb - lt).clamp(min=0)
inter = wh[:, 0] * wh[:, 1]
current_area = (boxes[current, 2] - boxes[current, 0]).clamp(min=0) * (
boxes[current, 3] - boxes[current, 1]
).clamp(min=0)
rest_area = (boxes[rest, 2] - boxes[rest, 0]).clamp(min=0) * (
boxes[rest, 3] - boxes[rest, 1]
).clamp(min=0)
iou = inter / (current_area + rest_area - inter).clamp(min=1e-6)
matched = rest[iou > float(iou_threshold)]
groups.append(torch.cat((current.reshape(1), matched)))
order = rest[iou <= float(iou_threshold)]
return groups
def _point_sample(input: Any, points: Any, align_corners: bool = False, **kwargs: Any) -> Any:
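    """Fallback for ``mmcv.ops.point_sample``: sample ``input`` at normalized
    ``points`` (in ``[0, 1]``) by remapping them to ``grid_sample`` range.
    """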
import torch.nn.functional as F
add_dim = False
if points.dim() == 3:
add_dim = True
points = points.unsqueeze(2)
output = F.grid_sample(input, points.mul(2).sub(1), align_corners=align_corners, **kwargs)
if add_dim:
output = output.squeeze(3)
return output
def _rel_roi_point_to_rel_img_point(rois: Any, rel_roi_points: Any, img_shape: Any) -> Any:
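    """Simplified fallback for ``mmcv.ops.rel_roi_point_to_rel_img_point``.

    ``rois`` are ``[batch_idx, x1, y1, x2, y2]`` and the returned points are
    normalized by the image height and width taken from ``img_shape``.
    """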
x1, y1, x2, y2 = rois[:, 1], rois[:, 2], rois[:, 3], rois[:, 4]
roi_w = (x2 - x1).clamp(min=1)
roi_h = (y2 - y1).clamp(min=1)
img_h, img_w = img_shape[:2]
rel_img_points = rel_roi_points.clone()
rel_img_points[..., 0] = (x1[:, None] + rel_roi_points[..., 0] * roi_w[:, None]) / float(img_w)
rel_img_points[..., 1] = (y1[:, None] + rel_roi_points[..., 1] * roi_h[:, None]) / float(img_h)
return rel_img_points
def _sigmoid_focal_loss(
pred: Any,
target: Any,
gamma: float = 2.0,
alpha: float = 0.25,
weight: Any = None,
reduction: str = "mean",
) -> Any:
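    """Pure-PyTorch sigmoid focal loss, replacing the compiled mmcv op."""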
import torch.nn.functional as F
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight
if weight is not None:
loss = loss * weight
if reduction == "sum":
return loss.sum()
if reduction == "mean":
return loss.mean()
return loss
def _ensure_module(module_name: str) -> types.ModuleType:
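    """Return ``sys.modules[module_name]``, creating an empty module if absent."""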
module = sys.modules.get(module_name)
if module is None:
module = types.ModuleType(module_name)
sys.modules[module_name] = module
return module
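# Temp-file path holding the demo's launch CSS once it has been moved from
# the ``css`` kwarg to ``css_paths``.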
_GRADIO_CSS_PATCH_PATH: Path | None = None
def _normalize_gradio_css_paths(css_paths: Any) -> list[str]:
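    """Coerce Gradio's ``css_paths`` value to a list of path strings."""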
if css_paths is None:
return []
if isinstance(css_paths, (str, Path)):
return [str(css_paths)]
return [str(path) for path in css_paths]
def _persist_egoforce_gradio_css(css: str) -> str:
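    """Write ``css`` to a stable temp file and return the file's path."""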
global _GRADIO_CSS_PATCH_PATH
if _GRADIO_CSS_PATCH_PATH is None:
_GRADIO_CSS_PATCH_PATH = Path(tempfile.gettempdir()) / "egoforce-gradio-launch.css"
_GRADIO_CSS_PATCH_PATH.write_text(css, encoding="utf-8")
return str(_GRADIO_CSS_PATCH_PATH)
def _patch_gradio_launch() -> None:
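    """Wrap ``gr.Blocks.launch`` so EgoForce's inline CSS is passed as a file.

    When the ``css`` kwarg contains the demo's known selectors, it is written
    to a temp file and handed to Gradio via ``css_paths`` instead.
    """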
try:
import gradio as gr
except ImportError:
return
launch_method = getattr(gr.Blocks, "launch", None)
if launch_method is None or getattr(launch_method, "__egoforce_runtime_patch__", False):
return
@functools.wraps(launch_method)
def patched_launch(self: Any, *args: Any, **kwargs: Any) -> Any:
css = kwargs.get("css")
        if isinstance(css, str) and css.strip() and (
            ".egoforce-hero" in css or "#sample-video-carousel" in css
        ):
css_path_entries = _normalize_gradio_css_paths(kwargs.get("css_paths"))
patched_css_path = _persist_egoforce_gradio_css(css)
if patched_css_path not in css_path_entries:
css_path_entries.append(patched_css_path)
kwargs["css_paths"] = css_path_entries
kwargs["css"] = None
return launch_method(self, *args, **kwargs)
setattr(patched_launch, "__egoforce_runtime_patch__", True)
gr.Blocks.launch = patched_launch
def apply_runtime_patches() -> None:
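    """Entry point: patch Gradio's launch, then install mmcv.ops fallbacks."""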
_patch_gradio_launch()
try:
mmcv = importlib.import_module("mmcv")
except ImportError:
mmcv = None
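    # mmcv may be missing or built without compiled ops (mmcv-lite); in both
    # cases a synthetic ``mmcv.ops`` package is registered for the stubs.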
ops_module = sys.modules.get("mmcv.ops")
if ops_module is None:
ops_module = types.ModuleType("mmcv.ops")
sys.modules["mmcv.ops"] = ops_module
ops_module.__path__ = []
nms_module = _ensure_module("mmcv.ops.nms")
roi_align_module = _ensure_module("mmcv.ops.roi_align")
deform_conv_module = _ensure_module("mmcv.ops.deform_conv")
modulated_deform_conv_module = _ensure_module("mmcv.ops.modulated_deform_conv")
carafe_module = _ensure_module("mmcv.ops.carafe")
merge_cells_module = _ensure_module("mmcv.ops.merge_cells")
multi_scale_deform_attn_module = _ensure_module("mmcv.ops.multi_scale_deform_attn")
deform_conv2d = _unsupported_function("deform_conv2d")
DeformConv2d = _unsupported_module_class("DeformConv2d")
ModulatedDeformConv2d = _unsupported_module_class("ModulatedDeformConv2d")
MaskedConv2d = _unsupported_module_class("MaskedConv2d")
CornerPool = _unsupported_module_class("CornerPool")
CARAFEPack = _unsupported_module_class("CARAFEPack")
GlobalPoolingCell = _unsupported_module_class("GlobalPoolingCell")
SumCell = _unsupported_module_class("SumCell")
ConcatCell = _unsupported_module_class("ConcatCell")
MultiScaleDeformableAttention = _multiscale_deformable_attention_class()
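    # Publish every fallback on ``mmcv.ops`` itself and on the submodules
    # that downstream code imports from directly.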
ops_module.nms = _nms
ops_module.batched_nms = _batched_nms
ops_module.nms_match = _nms_match
ops_module.point_sample = _point_sample
ops_module.rel_roi_point_to_rel_img_point = _rel_roi_point_to_rel_img_point
ops_module.sigmoid_focal_loss = _sigmoid_focal_loss
ops_module.bbox_overlaps = _bbox_overlaps
ops_module.roi_align = _torchvision_roi_align()
ops_module.RoIAlign = _torchvision_roi_align_module()
ops_module.RoIPool = _torchvision_roi_pool_module()
ops_module.deform_conv2d = deform_conv2d
ops_module.DeformConv2d = DeformConv2d
ops_module.ModulatedDeformConv2d = ModulatedDeformConv2d
ops_module.MaskedConv2d = MaskedConv2d
ops_module.CornerPool = CornerPool
ops_module.CARAFEPack = CARAFEPack
ops_module.GlobalPoolingCell = GlobalPoolingCell
ops_module.SumCell = SumCell
ops_module.ConcatCell = ConcatCell
ops_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
nms_module.nms = _nms
nms_module.batched_nms = _batched_nms
roi_align_module.roi_align = ops_module.roi_align
roi_align_module.RoIAlign = ops_module.RoIAlign
deform_conv_module.deform_conv2d = deform_conv2d
deform_conv_module.DeformConv2d = DeformConv2d
modulated_deform_conv_module.ModulatedDeformConv2d = ModulatedDeformConv2d
carafe_module.CARAFEPack = CARAFEPack
merge_cells_module.GlobalPoolingCell = GlobalPoolingCell
merge_cells_module.SumCell = SumCell
merge_cells_module.ConcatCell = ConcatCell
multi_scale_deform_attn_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
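    # Some mmdet code paths import MultiScaleDeformableAttention from
    # mmcv.cnn.bricks.transformer rather than mmcv.ops, so mirror it there
    # when that module is importable.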
try:
transformer_module = importlib.import_module("mmcv.cnn.bricks.transformer")
if not hasattr(transformer_module, "MultiScaleDeformableAttention"):
transformer_module.MultiScaleDeformableAttention = MultiScaleDeformableAttention
except ImportError:
pass
if mmcv is not None:
mmcv.ops = ops_module
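
# Minimal usage sketch (assuming this module is importable as
# ``egoforce_runtime_patches``): apply the patches before importing mmdet so
# the stubbed ``mmcv.ops`` entries are already present in ``sys.modules``:
#
#     import egoforce_runtime_patches
#     egoforce_runtime_patches.apply_runtime_patches()
#     import mmdet  # resolves mmcv.ops imports against the stubs above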