| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import torch |
|
|
|
|
| MaskType = Union[np.ndarray, torch.Tensor] |
|
|
|
|
| def _to_numpy_mask(mask: MaskType) -> np.ndarray: |
| """ |
| Convert assorted mask formats to a 2D numpy boolean array. |
| """ |
| if isinstance(mask, torch.Tensor): |
| mask_np = mask.detach().cpu().numpy() |
| else: |
| mask_np = np.asarray(mask) |
|
|
| |
| while mask_np.ndim > 2 and mask_np.shape[0] == 1: |
| mask_np = np.squeeze(mask_np, axis=0) |
| if mask_np.ndim > 2 and mask_np.shape[-1] == 1: |
| mask_np = np.squeeze(mask_np, axis=-1) |
|
|
| if mask_np.ndim != 2: |
| raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}") |
|
|
| return mask_np.astype(bool) |
|
|
|
|
| def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]: |
| """ |
| Compute a bounding box for a 2D boolean mask. |
| """ |
| if not mask.any(): |
| return None |
| rows, cols = np.nonzero(mask) |
| y_min, y_max = rows.min(), rows.max() |
| x_min, x_max = cols.min(), cols.max() |
| return x_min, y_min, x_max, y_max |
|
|
|
|
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten nested per-frame segmentation data into batched lists suitable
    for predicate models or downstream visualizations. Mirrors the notebook
    helper but is robust to differing mask dtypes/shapes.

    Args:
        video_id: Identifier attached to every emitted record.
        segments: Mapping frame_id -> {object_id -> mask}.
        bbox_min_dim: Objects whose bounding box is narrower than this in
            either dimension are dropped.

    Returns:
        Dict with keys "object_ids", "masks", "bboxes", and "pairs"
        (all ordered object pairs per frame among surviving objects).
    """
    object_keys: List[Tuple[int, int, int]] = []
    flat_masks: List[np.ndarray] = []
    flat_bboxes: List[Tuple[int, int, int, int]] = []
    pair_keys: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, objects_in_frame in segments.items():
        kept_ids: List[int] = []
        for obj_id, raw in objects_in_frame.items():
            mask_2d = _to_numpy_mask(raw)
            box = _mask_to_bbox(mask_2d)
            if box is None:
                # Empty mask: nothing to batch for this object.
                continue

            left, top, right, bottom = box
            # Skip degenerate detections whose box is thinner than the
            # minimum size along either axis.
            if abs(bottom - top) < bbox_min_dim or abs(right - left) < bbox_min_dim:
                continue

            kept_ids.append(obj_id)
            object_keys.append((video_id, frame_id, obj_id))
            flat_masks.append(mask_2d)
            flat_bboxes.append(box)

        # All ordered pairs of distinct surviving objects in this frame.
        pair_keys.extend(
            (video_id, frame_id, (src, dst))
            for src in kept_ids
            for dst in kept_ids
            if src != dst
        )

    return {
        "object_ids": object_keys,
        "masks": flat_masks,
        "bboxes": flat_bboxes,
        "pairs": pair_keys,
    }
|
|
|
|
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Filter object pairs per frame.

    If `interested_object_pairs` is provided, only emit those combinations
    when both objects are present in a frame; otherwise emit all
    permutations (i, j) with i != j for each frame.

    Args:
        batched_object_ids: (video_id, frame_id, object_id) triples.
        interested_object_pairs: Optional (src, dst) pairs to restrict to.
            A provided-but-empty iterable means "no pairs of interest" and
            yields an empty result.

    Returns:
        List of (video_id, frame_id, (src, dst)) tuples.
    """
    frame_to_objects: Dict[Tuple[int, int], set] = defaultdict(set)
    for vid, fid, oid in batched_object_ids:
        frame_to_objects[(vid, fid)].add(oid)

    # Materialize once so a generator argument can be iterated per frame.
    interested = (
        list(interested_object_pairs)
        if interested_object_pairs is not None
        else None
    )

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (vid, fid), object_ids in frame_to_objects.items():
        # Compare against None, not truthiness: previously an empty (but
        # provided) `interested_object_pairs` fell through to the
        # all-permutations branch instead of emitting nothing.
        if interested is not None:
            for src, dst in interested:
                if src in object_ids and dst in object_ids:
                    valid_pairs.append((vid, fid, (src, dst)))
        else:
            for src in object_ids:
                for dst in object_ids:
                    if src == dst:
                        continue
                    valid_pairs.append((vid, fid, (src, dst)))

    return valid_pairs
|
|