| import os |
| import cv2 |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch |
| import random |
| import math |
| from matplotlib.patches import Rectangle |
| import itertools |
| from typing import Any, Dict, List, Tuple, Optional, Union |
|
|
| from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def clean_label(label):
    """Normalize a label by turning underscores and slashes into spaces.

    Args:
        label: Raw label string, e.g. ``"traffic_light"`` or ``"cup/mug"``.

    Returns:
        The label with every ``"_"`` and every ``"/"`` replaced by one space.
    """
    for separator in ("_", "/"):
        label = label.replace(separator, " ")
    return label
|
|
| |
def format_cate_preds(cate_preds):
    """Group category predictions per object and rank them by probability.

    Args:
        cate_preds: Mapping of ``(object_id, label) -> probability``.

    Returns:
        Dict mapping each object id to a list of ``(cleaned_label, probability)``
        pairs ordered from most to least probable.
    """
    grouped = {}
    for (object_id, raw_label), probability in cate_preds.items():
        grouped.setdefault(object_id, []).append((clean_label(raw_label), probability))
    for candidates in grouped.values():
        candidates.sort(key=lambda pair: pair[1], reverse=True)
    return grouped
|
|
def format_binary_cate_preds(binary_preds):
    """Flatten binary (relation) predictions into score-ranked tuples.

    Args:
        binary_preds: Mapping whose keys are
            ``(frame_id, (subject_id, object_id), relation)`` and whose values
            are prediction scores.

    Returns:
        List of ``(frame_id, subject_id, object_id, relation, score)`` tuples
        sorted by score, highest first. Keys without the expected shape are
        reported on stdout and skipped.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        try:
            f_id, (subj, obj), pred_rel = key
        except (TypeError, ValueError):
            # Wrong arity or a non-iterable (subject, object) pair: report and skip.
            print("Skipping key with unexpected format:", key)
            continue
        frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
    # Rank by the score (index 4), not by the relation string at index 3.
    frame_binary_preds.sort(key=lambda entry: entry[4], reverse=True)
    return frame_binary_preds
|
|
# Font face shared by the OpenCV text-drawing helpers in this module.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
|
|
|
| def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]: |
| if mask is None: |
| return None |
| if isinstance(mask, torch.Tensor): |
| mask_np = mask.detach().cpu().numpy() |
| else: |
| mask_np = np.asarray(mask) |
| if mask_np.ndim == 0: |
| return None |
| if mask_np.ndim == 3: |
| mask_np = np.squeeze(mask_np) |
| if mask_np.ndim != 2: |
| return None |
| if mask_np.dtype == bool: |
| return mask_np |
| return mask_np > 0 |
|
|
|
|
| def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]: |
| if bbox is None: |
| return None |
| if isinstance(bbox, (list, tuple)) and len(bbox) >= 4: |
| x1, y1, x2, y2 = [float(b) for b in bbox[:4]] |
| elif isinstance(bbox, np.ndarray) and bbox.size >= 4: |
| x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]] |
| else: |
| return None |
| x1 = int(np.clip(round(x1), 0, width - 1)) |
| y1 = int(np.clip(round(y1), 0, height - 1)) |
| x2 = int(np.clip(round(x2), 0, width - 1)) |
| y2 = int(np.clip(round(y2), 0, height - 1)) |
| if x2 <= x1 or y2 <= y1: |
| return None |
| return (x1, y1, x2, y2) |
|
|
|
|
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Return the palette colour of ``obj_id`` as a BGR triple of 0-255 ints.

    The RGB channels of ``get_color(obj_id)`` are clipped to ``[0, 1]``,
    scaled by 255, truncated to int, and reordered to blue/green/red.
    """
    red, green, blue = (int(np.clip(channel, 0.0, 1.0) * 255) for channel in get_color(obj_id)[:3])
    return (blue, green, red)
|
|
|
|
| def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]: |
| return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color) |
|
|
|
|
def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    """Draw a stack of filled text boxes attached to ``anchor`` (in place).

    Each entry of ``lines`` is drawn in black on a lightened ``color`` box.
    With ``direction="down"`` boxes stack downward from ``anchor`` (6 px gap
    before each box) and stop once they run out of vertical room. With any
    other direction they stack upward, each box directly above the previous
    one, so the first line sits nearest the anchor. Boxes that would overflow
    the right border are shifted left so the text stays visible.

    Args:
        image: Image drawn on in place.
        lines: Text lines; non-strings are converted with ``str``.
        anchor: ``(x, y)`` attachment point, clipped into the image.
        color: Base colour; the box fill is ``_background_color(color)``.
        font_scale: OpenCV font scale.
        thickness: OpenCV text thickness.
        direction: ``"down"`` or anything else for upward stacking.
    """
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x = int(np.clip(anchor[0], 0, img_w - 1))
    y_cursor = int(np.clip(anchor[1], 0, img_h - 1))
    bg_color = _background_color(color)
    stack_down = direction == "down"

    for text in lines:
        text = str(text)
        (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
        box_w = tw + 8
        box_h = th + baseline + 6
        # Keep the box inside the right border when it fits at all.
        left_x = max(0, min(x, img_w - 1 - box_w))
        right_x = min(left_x + box_w, img_w - 1)
        if stack_down:
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + box_h, 0, img_h - 1))
            if bottom_y <= top_y:
                # No vertical room left below.
                break
        else:
            top_y = max(y_cursor - box_h, 0)
            bottom_y = min(top_y + box_h, img_h - 1)
        cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
        text_x = left_x + 4
        text_y = min(bottom_y - baseline - 2, img_h - 1)
        cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
        y_cursor = bottom_y if stack_down else top_y
|
|
|
|
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw ``text`` in a filled box roughly centred on ``center`` (in place).

    The box fill is ``_background_color(color)`` and the text is black. When
    the centred box would cross an image border, it is shifted back inside
    (provided it fits), instead of being truncated at the edge.

    Args:
        image: Image drawn on in place.
        text: Label text; converted with ``str``.
        center: Desired ``(x, y)`` centre, clipped into the image.
        color: Base colour for the box fill.
        font_scale: OpenCV font scale.
        thickness: OpenCV text thickness.
    """
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    box_w = tw + 8
    box_h = th + baseline + 6
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    # Upper clip bounds leave room for the whole box; max(..., 0) handles boxes wider/taller than the image.
    left_x = int(np.clip(cx - tw // 2 - 4, 0, max(img_w - 1 - box_w, 0)))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, max(img_h - 1 - box_h, 0)))
    right_x = min(left_x + box_w, img_w - 1)
    bottom_y = min(top_y + box_h, img_h - 1)
    cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
|
|
|
| def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]: |
| if isinstance(store, dict): |
| frame_entry = store.get(frame_idx, {}) |
| elif isinstance(store, list) and 0 <= frame_idx < len(store): |
| frame_entry = store[frame_idx] |
| else: |
| frame_entry = {} |
| if isinstance(frame_entry, dict): |
| return frame_entry |
| if isinstance(frame_entry, list): |
| return {i: value for i, value in enumerate(frame_entry)} |
| return {} |
|
|
|
|
| def _label_anchor_and_direction( |
| bbox: Tuple[int, int, int, int], |
| position: str, |
| ) -> Tuple[Tuple[int, int], str]: |
| x1, y1, x2, y2 = bbox |
| if position == "bottom": |
| return (x1, y2), "down" |
| return (x1, y1), "up" |
|
|
|
|
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw an object's box outline plus a label block (in place).

    The header line is ``"#<obj_id>"`` when no title is given. A title that
    already starts with ``"#"`` is used verbatim. Any other title becomes
    ``"#<obj_id> <title>"``. Optional ``sub_lines`` follow the header.
    """
    colour = _object_color_bgr(obj_id)
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), colour, 2)

    id_tag = f"#{obj_id}"
    if not title:
        header = id_tag
    elif title.startswith("#"):
        header = title
    else:
        header = f"{id_tag} {title}"

    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, [header, *(sub_lines or [])], anchor, colour, direction=direction)
|
|
|
|
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Overlay per-object masks (alpha 0.45) and their labelled boxes on each frame.

    Args:
        frames: RGB frames as a list or stacked array; ``None`` entries are skipped.
        sam_masks: Per-frame ``{obj_id: mask}`` store (see ``_extract_frame_entities``).
            Masks that are empty, not 2-D after squeezing, or whose size differs
            from the frame are skipped.
        dino_labels: Optional ``{obj_id: label}`` used as box titles.

    Returns:
        Annotated RGB frames, one per non-``None`` input frame.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}
    alpha = 0.45

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
        frame_h, frame_w = frame_bgr.shape[:2]
        overlay = frame_bgr.astype(np.float32)

        # Convert each mask once. A mask of a different size would raise
        # IndexError in the boolean indexing below, so such masks are skipped.
        valid_masks = []
        for obj_id, mask in _extract_frame_entities(sam_masks, frame_idx).items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or mask_np.shape != (frame_h, frame_w) or not np.any(mask_np):
                continue
            valid_masks.append((obj_id, mask_np))

        for obj_id, mask_np in valid_masks:
            color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * color

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)

        for obj_id, mask_np in valid_masks:
            bbox = _sanitize_bbox(mask_to_bbox(mask_np), frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            _draw_bbox_with_label(annotated, bbox, obj_id, title=f"{label}" if label else None)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
|
|
|
|
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Draw labelled bounding boxes on each frame.

    Args:
        frames: RGB frames as a list or stacked array; ``None`` entries are skipped.
        bboxes: Per-frame ``{obj_id: (x1, y1, x2, y2)}`` store; boxes that are
            invalid after ``_sanitize_bbox`` are skipped.
        dino_labels: Optional ``{obj_id: label}`` used as box titles.

    Returns:
        Annotated RGB frames, one per non-``None`` input frame.
    """
    sequence = frames if isinstance(frames, list) else list(frames)
    labels = dino_labels or {}
    rendered: List[np.ndarray] = []

    for idx, raw_frame in enumerate(sequence):
        if raw_frame is None:
            continue
        canvas = cv2.cvtColor(np.asarray(raw_frame), cv2.COLOR_RGB2BGR)
        height, width = canvas.shape[:2]
        for obj_id, raw_box in _extract_frame_entities(bboxes, idx).items():
            box = _sanitize_bbox(raw_box, width, height)
            if box:
                name = labels.get(obj_id)
                _draw_bbox_with_label(canvas, box, obj_id, title=f"{name}" if name else None)
        rendered.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    return rendered
|
|
|
|
def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> Dict[str, List[np.ndarray]]:
    """Render four annotated views of every frame.

    Views (keys of the returned dict, each a list of RGB frames):
        ``"object"``: boxes titled with the category label/probability.
        ``"unary"``: boxes, unary predictions listed below each box, and a
            mask overlay for objects that have unary predictions.
        ``"binary"``: boxes plus, for each subject/object pair, a line between
            box centres labelled with the first relation prediction.
        ``"all"``: unary and binary annotations combined.

    Args:
        frames: RGB frames as a list or stacked array; ``None`` entries are skipped.
        bboxes: Per-frame ``{obj_id: (x1, y1, x2, y2)}`` store.
        cat_label_lookup: ``{obj_id: (label, prob)}``.
        unary_lookup: ``{frame_idx: {obj_id: [(prob, label), ...]}}``.
        binary_lookup: ``{frame_idx: [((subj_id, obj_id), [(prob, relation), ...]), ...]}``;
            only the first relation of each pair is drawn.
        masks: Optional per-frame ``{obj_id: mask}`` store for the unary overlay.
            Masks whose size differs from the frame are skipped.
    """
    frame_groups: Dict[str, List[np.ndarray]] = {
        "object": [],
        "unary": [],
        "binary": [],
        "all": [],
    }
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    alpha = 0.45

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        base_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        # _extract_frame_entities already returns {} for a None store.
        frame_masks = _extract_frame_entities(masks, frame_idx)

        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()
        binary_bgr = base_bgr.copy()
        all_bgr = base_bgr.copy()

        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}
        colors_lookup: Dict[int, Tuple[int, int, int]] = {}

        # Collect per-object geometry, colour, title and unary text once.
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox
            colors_lookup[obj_id] = _object_color_bgr(obj_id)
            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            if cat_label and cat_prob is not None:
                titles_lookup[obj_id] = f"{cat_label} {cat_prob:.2f}"
            elif cat_label:
                titles_lookup[obj_id] = str(cat_label)
            else:
                titles_lookup[obj_id] = None
            unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
            unary_lines_lookup[obj_id] = [f"{label} {prob:.2f}" for prob, label in unary_preds]

        # Mask overlay, only for objects that carry unary predictions.
        for obj_id in bbox_lookup:
            if not unary_lines_lookup.get(obj_id):
                continue
            mask_np = _to_numpy_mask(frame_masks.get(obj_id))
            # A differently sized mask would raise IndexError in the boolean indexing.
            if mask_np is None or mask_np.shape != (frame_h, frame_w) or not np.any(mask_np):
                continue
            color = np.array(colors_lookup[obj_id], dtype=np.float32)
            for target in (unary_bgr, all_bgr):
                blended = (1.0 - alpha) * target[mask_np].astype(np.float32) + alpha * color
                target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)

        # Boxes with titles on every view; unary text below the box on unary/all.
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            for target in (objects_bgr, unary_bgr, binary_bgr, all_bgr):
                _draw_bbox_with_label(target, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                for target in (unary_bgr, all_bgr):
                    _draw_label_block(target, unary_lines, anchor, colors_lookup[obj_id], direction=direction)

        # Relation lines and labels on binary/all.
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            if not subj_bbox or not obj_bbox:
                continue
            start, end = relation_line(subj_bbox, obj_bbox)
            mixed = (
                np.array(colors_lookup[subj_id], dtype=np.float32)
                + np.array(colors_lookup[obj_id], dtype=np.float32)
            ) / 2.0
            color = tuple(int(c) for c in np.clip(mixed, 0, 255))
            prob, relation = relation_preds[0]
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            for target in (binary_bgr, all_bgr):
                cv2.line(target, start, end, color, 6, cv2.LINE_AA)
                _draw_centered_label(target, label_text, mid_point, color)

        frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))

    return frame_groups
|
|
|
|
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> List[np.ndarray]:
    """Return only the combined ``"all"`` view from ``render_vine_frame_sets``."""
    frame_sets = render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks=masks,
    )
    return frame_sets.get("all", [])
| |
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Colour-code each ground-truth object by how well its category was predicted.

    Colours are BGR tuples: green ``(0, 255, 0)`` when the top-1 prediction
    equals the ground truth, orange ``(0, 165, 255)`` when the ground truth is
    among the first ``topk_object`` predictions, red ``(0, 0, 255)`` otherwise
    or when the object has no prediction. Comparisons ignore case.

    Args:
        obj_pred_dict: ``{obj_id: [(label, prob), ...]}``, best prediction first.
        gt_labels: Iterable of ``(obj_id, bbox, gt_label)`` triples.
        topk_object: Number of leading predictions checked for the orange case.

    Returns:
        ``(colors, texts)`` aligned with ``gt_labels``; texts read
        ``"ID:<id>/P:<top-1 or N/A>/GT:<gt_label>"``.
    """
    green, orange, red = (0, 255, 0), (0, 165, 255), (0, 0, 255)
    colors, texts = [], []
    for obj_id, _bbox, gt_label in gt_labels:
        ranked = obj_pred_dict.get(obj_id, [])
        if not ranked:
            best, colour = "N/A", red
        else:
            best, _best_prob = ranked[0]
            candidates = [pred[0] for pred in ranked[:topk_object]]
            wanted = gt_label.lower()
            if best.lower() == wanted:
                colour = green
            elif wanted in [c.lower() for c in candidates]:
                colour = orange
            else:
                colour = red
        colors.append(colour)
        texts.append(f"ID:{obj_id}/P:{best}/GT:{gt_label}")
    return colors, texts
|
|
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw each ground-truth box with a coloured, filled caption (in place).

    Args:
        frame_img: Image array drawn on in place.
        gt_labels: ``(obj_id, bbox, gt_label)`` triples; ``bbox`` is ``(x1, y1, x2, y2)``.
        all_colors: Box/caption colours aligned with ``gt_labels``.
        all_texts: Caption strings aligned with ``gt_labels``.

    Returns:
        ``frame_img`` (the same array, modified in place).
    """
    font, font_scale, font_thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    for (_obj_id, bbox, _gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        (tw, th), baseline = cv2.getTextSize(label_text, font, font_scale, font_thickness)
        caption_h = th + baseline + 4
        # Caption normally sits above the box. A box touching the top edge
        # would push it off-image, so draw it just inside the box instead.
        caption_bottom = y1 if y1 - caption_h >= 0 else y1 + caption_h
        cv2.rectangle(frame_img, (x1, caption_bottom - caption_h), (x1 + tw, caption_bottom), box_color, -1)
        cv2.putText(frame_img, label_text, (x1, caption_bottom - 2), font,
                    font_scale, (0, 0, 0), font_thickness, cv2.LINE_AA)
    return frame_img
|
|
def get_white_pane(pane_height,
                   pane_width=600,
                   header_height=50,
                   header_font=cv2.FONT_HERSHEY_SIMPLEX,
                   header_font_scale=0.7,
                   header_thickness=2,
                   header_color=(0, 0, 0)):
    """Create a white side panel titled "Binary Predictions" | "Ground Truth".

    The panel is split at ``int(pane_width * 0.6)``. The left part gets the
    "Binary Predictions" header and the right part "Ground Truth". Both
    headers are drawn at x=10 with their baseline at ``header_height - 30``.

    Returns:
        A ``(pane_height, pane_width, 3)`` uint8 image.
    """
    white_pane = np.full((pane_height, pane_width, 3), 255, dtype=np.uint8)

    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    cv2.putText(left_pane, "Binary Predictions", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
    cv2.putText(right_pane, "Ground Truth", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)

    # The headers were drawn on copies; reassemble them into the returned pane.
    return np.hstack((left_pane, right_pane))
|
|
| |
def plot_binary_sg(frame_img,
                   white_pane,
                   bin_preds,
                   gt_relations,
                   topk_binary,
                   header_height=50,
                   indicator_size=20,
                   pane_width=600):
    """Append a panel listing top-k relation predictions and ground truth.

    The left part lists up to ``topk_binary`` predictions, each prefixed by a
    green square (matches a ground-truth triple) or a red square. The right
    part lists the ground-truth triples. Rows start at ``header_height + 10``
    and are 30 px apart.

    Args:
        frame_img: Frame image; must have the same height as ``white_pane``.
        white_pane: Background panel, split at ``int(pane_width * 0.6)``.
        bin_preds: Predictions as ``(subj, relation, obj, score)`` 4-tuples, or
            as ``(frame_id, subj, obj, relation, score)`` 5-tuples such as
            those produced by ``format_binary_cate_preds``.
        gt_relations: ``(subj, obj, relation)`` triples; other lengths are not listed.
        topk_binary: Number of predictions to list.
        header_height: Height reserved above the first row.
        indicator_size: Side of the correctness square, in pixels.
        pane_width: Width used to compute the left/right split.

    Returns:
        ``frame_img`` and the filled panel concatenated horizontally.
    """
    line_height = 30
    x_text = 10
    y_text_left = header_height + 10
    y_text_right = header_height + 10

    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    for pred in bin_preds[:topk_binary]:
        if len(pred) == 5:
            _frame_id, subj, obj, pred_rel, score = pred
        else:
            subj, pred_rel, obj, score = pred
        # Compare ids as strings: visualized_frame stringifies ground-truth ids,
        # while predicted ids may still be ints.
        correct = any(
            len(gt) >= 3
            and str(subj) == str(gt[0])
            and str(pred_rel).lower() == str(gt[2]).lower()
            and str(obj) == str(gt[1])
            for gt in gt_relations
        )
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
                      (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_left += line_height

    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(right_pane, text, (x_text, y_text_right + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_right += line_height

    combined_pane = np.hstack((left_pane, right_pane))
    return np.hstack((frame_img, combined_pane))
|
|
def visualized_frame(frame_img,
                     bboxes,
                     object_ids,
                     gt_labels,
                     cate_preds,
                     binary_preds,
                     gt_relations,
                     topk_object,
                     topk_binary,
                     phase="unary"):
    """Annotate one frame with either category (unary) or relation (binary) results.

    Args:
        frame_img: Image array; drawn on in place in the unary phase.
        bboxes: Per-object ``(x1, y1, x2, y2)`` boxes aligned with ``object_ids``.
        object_ids: Per-object ``(_, frame_id, obj_id)`` triples.
        gt_labels: Ground-truth category strings aligned with ``object_ids``.
        cate_preds: ``{(obj_id, label): prob}`` category predictions.
        binary_preds: ``{(frame_id, (subj, obj), relation): score}`` predictions.
        gt_relations: ``(subj, obj, relation)`` ground-truth triples.
        topk_object: Number of top category predictions checked for a near miss.
        topk_binary: Number of relation predictions listed.
        phase: ``"unary"`` for category boxes; any other value for the relation panel.

    Returns:
        Unary phase: the annotated frame. Otherwise: the frame with a 600-px
        panel of relation predictions and ground truth appended on the right.
    """
    if phase == "unary":
        # The unary helpers expect (obj_id, bbox, cleaned_label) triples, not bare labels.
        objs = [
            (obj_id, bbox, clean_label(gt_label))
            for (_, _f_id, obj_id), bbox, gt_label in zip(object_ids, bboxes, gt_labels)
        ]
        formatted_cate_preds = format_cate_preds(cate_preds)
        all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
        return plot_unary(frame_img, objs, all_colors, all_texts)

    # Rank by score, then convert the 5-tuples to the (subj, rel, obj, score)
    # rows plot_binary_sg lists, cleaned the same way as the ground truth so
    # the correctness check compares like with like.
    ranked = sorted(format_binary_cate_preds(binary_preds), key=lambda p: p[4], reverse=True)
    pred_rows = [
        (clean_label(str(subj)), clean_label(str(rel)), clean_label(str(obj)), score)
        for _f_id, subj, obj, rel, score in ranked
    ]
    gt_rows = [(clean_label(str(s)), clean_label(str(o)), clean_label(str(rel))) for s, o, rel in gt_relations]

    pane_width = 600
    header_height = 50
    white_pane = get_white_pane(frame_img.shape[0], pane_width, header_height=header_height)
    return plot_binary_sg(frame_img, white_pane, pred_rows, gt_rows, topk_binary,
                          header_height=header_height, pane_width=pane_width)
|
|
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay one mask on a matplotlib axes.

    Args:
        mask: 2-D mask, or 3-D with a singleton first or last axis. May be a
            NumPy array, torch tensor or array-like. Numeric masks are clipped
            to ``[0, 1]`` so 0/255 masks use the intended overlay alpha.
        ax: Target matplotlib axes.
        obj_id: Selects the overlay colour via ``get_color`` (ignored when
            ``random_color``).
        det_class: When given, draw the mask's bounding rectangle and this text above it.
        random_color: Use a random RGB colour with alpha 0.8 instead.
    """
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    mask = np.array(mask)

    if mask.ndim == 3:
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)

    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if mask.dtype != bool:
        # Values above 1 would scale the RGBA overlay past 1 and get clipped
        # by imshow to fully opaque.
        mask = np.clip(mask, 0, 1)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = get_color(obj_id)

    mask_image = mask[..., None] * color.reshape(1, 1, -1)

    if det_class is not None:
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3],
            )
            ax.add_patch(rect)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1,
            )
    ax.imshow(mask_image)
|
|
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk.

    Args:
        frame_image: Frame as a NumPy array, torch tensor, or array-like accepted by ``imshow``.
        masks: ``{obj_id: mask}`` dict, or a sequence of masks whose positions are used as ids.
        save_path: Output image path.

    Returns:
        ``save_path``.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    try:
        frame_np = (
            frame_image.detach().cpu().numpy()
            if torch.is_tensor(frame_image)
            else np.asarray(frame_image)
        )
        ax.imshow(np.ascontiguousarray(frame_np))
        ax.axis("off")

        mask_items = masks.items() if isinstance(masks, dict) else enumerate(masks)
        for obj_id, mask in mask_items:
            mask_np = mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
            show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)

        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return save_path
| |
def get_video_masks_visualization(video_tensor,
                                  video_masks,
                                  video_id,
                                  video_save_base_dir,
                                  oid_class_pred=None,
                                  sample_rate=1):
    """Save a labelled mask overlay per frame to ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``.

    Masks are drawn with ``get_mask_one_image``, so each one is annotated
    with ``oid_class_pred[obj_id]`` when ``oid_class_pred`` is given.

    Args:
        video_tensor: Iterable of frames.
        video_masks: ``{frame_id: masks}``; frames without an entry are reported and skipped.
        video_id: Sub-directory name for this video.
        video_save_base_dir: Base output directory.
        oid_class_pred: Optional class-name lookup indexed by object id.
        sample_rate: Probability of keeping each frame. The default 1 keeps
            every frame without consuming random numbers.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if sample_rate < 1 and random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        fig, _ax = get_mask_one_image(image, video_masks[frame_id], oid_class_pred)
        try:
            fig.savefig(save_path)
        finally:
            # Close every figure; otherwise one stays open per frame.
            plt.close(fig)
|
|
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Create a figure showing ``frame_image`` with every mask overlaid.

    Args:
        frame_image: Image accepted by ``Axes.imshow``.
        masks: ``{obj_id: mask}`` dict, or a list/tuple of masks whose positions are used as ids.
        oid_class_pred: Optional lookup indexed by object id; when given,
            each mask is annotated with ``"<obj_id>. <class>"``.

    Returns:
        ``(fig, ax)``. The caller is responsible for saving and closing ``fig``.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    try:
        ax.imshow(frame_image)
        ax.axis('off')

        if isinstance(masks, (list, tuple)):
            masks = dict(enumerate(masks))

        for obj_id, mask in masks.items():
            det_class = None if oid_class_pred is None else f"{obj_id}. {oid_class_pred[obj_id]}"
            show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
    except Exception:
        # Do not leak the figure if drawing fails; the caller never receives it.
        plt.close(fig)
        raise

    return fig, ax
|
|
def save_video(frames, output_filename, output_fps, rgb_input=False):
    """Write a sequence of same-sized frames to a video file with OpenCV.

    Args:
        frames: Non-empty sequence of frames, or a ``(T, H, W, C)`` array.
            Frames are written as given, and OpenCV treats them as BGR. Pass
            ``rgb_input=True`` to convert RGB frames first.
        output_filename: Destination path.
        output_fps: Frames per second of the output.
        rgb_input: Convert each frame from RGB to BGR before writing.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If the video writer cannot be opened (e.g. the ``avc1``
            codec is not available in this OpenCV build).
    """
    num_frames = len(frames)
    if num_frames == 0:
        raise ValueError("save_video: no frames to write")

    # Take the size from one frame: for a stacked (T, H, W, C) array,
    # frames.shape[:2] would be (T, H) rather than (H, W).
    frame_h, frame_w = np.asarray(frames[0]).shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
    if not out.isOpened():
        # VideoWriter.write() does not report failures, so check up front.
        raise RuntimeError(f"Could not open video writer for {output_filename} (fourcc 'avc1')")

    print(f"Processing {num_frames} frames...")
    try:
        for i in range(num_frames):
            vis_frame = np.ascontiguousarray(frames[i])
            if rgb_input:
                vis_frame = cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR)
            out.write(vis_frame)
            if i % 10 == 0:
                print(f"Processed frame {i+1}/{num_frames}")
    finally:
        out.release()
    print(f"Video saved as {output_filename}")
| |
|
|
def list_depth(lst):
    """Return the nesting depth of nested lists and/or torch tensors.

    Non-list, non-tensor values have depth 0. Empty lists, empty tensors and
    0-d tensors have depth 1. Otherwise the depth is one plus the deepest
    element, with tensors iterated along their first dimension. For example,
    ``[1, 2]`` has depth 1 and a 2-D tensor has depth 3.
    """
    if not isinstance(lst, (list, torch.Tensor)):
        return 0
    if isinstance(lst, torch.Tensor) and lst.dim() == 0:
        return 1
    if len(lst) == 0:
        # Also covers tensors with a zero-length first dimension, which would
        # otherwise pass an empty sequence to max().
        return 1
    return 1 + max(list_depth(item) for item in lst)
| |
def normalize_prompt(points, labels):
    """Add a singleton axis to depth-3 point prompts.

    When ``list_depth(points) == 3``, each entry of ``points`` and ``labels``
    is unsqueezed at dim 0 and the results are stacked into tensors.
    Otherwise both inputs are returned unchanged.

    Returns:
        ``(points, labels)``.
    """
    if list_depth(points) != 3:
        return points, labels
    stacked_points = torch.stack([point.unsqueeze(0) for point in points])
    stacked_labels = torch.stack([label.unsqueeze(0) for label in labels])
    return stacked_points, stacked_labels
|
|
|
|
def show_box(box, ax, object_id):
    """Draw an unfilled 2-px rectangle for ``box`` = ``(x0, y0, x1, y1)`` on ``ax``.

    The edge colour comes from the ``gist_rainbow`` colormap at index
    ``(object_id * 47) % 256``, with ``None`` treated as 0. Empty boxes are ignored.
    """
    if len(box) == 0:
        return
    palette = plt.get_cmap("gist_rainbow")
    palette_idx = 0 if object_id is None else object_id
    edge = list(palette((palette_idx * 47) % 256))
    left, top = box[0], box[1]
    width, height = box[2] - left, box[3] - top
    ax.add_patch(plt.Rectangle((left, top), width, height, edgecolor=edge, facecolor=(0, 0, 0, 0), lw=2))
| |
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter prompt points: label 1 as green plus markers, label 0 as red squares.

    Marker edges use the ``gist_rainbow`` colour at index
    ``(object_id * 47) % 256``, with ``None`` treated as 0. Nothing is drawn
    when ``labels`` is empty.
    """
    if len(labels) == 0:
        return
    palette = plt.get_cmap("gist_rainbow")
    palette_idx = 0 if object_id is None else object_id
    edge = list(palette((palette_idx * 47) % 256))
    for wanted_label, face, marker in ((1, 'green', 'P'), (0, 'red', 's')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=face, marker=marker,
                   s=marker_size, edgecolor=edge, linewidth=1.25)
| |
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Save ``frame_image`` with its box and point prompts drawn on top.

    Args:
        frame_image: Image accepted by ``Axes.imshow``.
        boxes: A tensor of boxes (ids = positions), a ``{object_id: box}`` dict
            (``None`` values skipped), or an empty list. Boxes are ``(x0, y0, x1, y1)``.
        points: Per-object point coordinates, normalized with ``normalize_prompt``.
        labels: Per-object point labels aligned with ``points``.
        save_path: Output image path.

    Raises:
        TypeError: If ``boxes`` is not a tensor, a dict, or an empty list.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    try:
        ax.imshow(frame_image)
        ax.axis('off')

        points, labels = normalize_prompt(points, labels)
        if isinstance(boxes, torch.Tensor):
            box_items = enumerate(boxes)
        elif isinstance(boxes, dict):
            box_items = boxes.items()
        elif isinstance(boxes, list) and len(boxes) == 0:
            box_items = ()
        else:
            raise TypeError(f"Unsupported boxes container: {type(boxes).__name__}")

        for object_id, box in box_items:
            if box is not None:
                show_box(box.cpu() if torch.is_tensor(box) else box, ax, object_id=object_id)

        for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
            if len(point_ls) != 0:
                show_points(
                    point_ls.cpu() if torch.is_tensor(point_ls) else point_ls,
                    label_ls.cpu() if torch.is_tensor(label_ls) else label_ls,
                    ax,
                    object_id=object_id,
                )

        fig.savefig(save_path)
    finally:
        # Close this figure explicitly, including on the error path.
        plt.close(fig)
| |
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
    """Save a prompt overlay per frame to ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``.

    Frames absent from ``video_boxes``, ``video_points`` or ``video_labels``
    get an empty list for that prompt type.
    """
    output_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        frame_boxes = video_boxes[frame_id] if frame_id in video_boxes else []
        frame_points = video_points[frame_id] if frame_id in video_points else []
        frame_labels = video_labels[frame_id] if frame_id in video_labels else []
        target_path = os.path.join(output_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, frame_boxes, frame_points, frame_labels, target_path)
| |
|
|
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate=1):
    """Save a mask overlay per frame to ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``.

    Each frame is skipped when ``random.random() > sample_rate``. Frames
    without an entry in ``video_masks`` are reported and skipped.
    ``oid_class_pred`` is accepted but not used here.
    """
    output_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if random.random() > sample_rate:
            continue
        if frame_id in video_masks:
            target_path = os.path.join(output_dir, f"{frame_id}.jpg")
            save_mask_one_image(image, video_masks[frame_id], target_path)
        else:
            print("No mask for Frame", frame_id)
| |
|
|
|
|
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return a deterministic RGBA colour for ``obj_id``.

    Ids are spread over the colormap by stepping 47 entries per id, modulo the
    colormap's number of entries (256 for ``gist_rainbow``), so that
    small-``N`` colormaps cycle instead of saturating at their last colour.

    Args:
        obj_id: Integer object id; ``None`` is treated as 0.
        cmap_name: Matplotlib colormap name.
        alpha: Alpha channel of the returned colour.

    Returns:
        ``np.ndarray`` of shape ``(4,)``: RGB from the colormap plus ``alpha``.
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = np.array(cmap((cmap_idx * 47) % cmap.N))
    color[3] = alpha
    return color
| |
| |
| def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]: |
| return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0) |
|
|
|
|
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return integer endpoints joining the centres of two boxes.

    If the centres coincide (within 1e-3 px on both axes), the end point is
    moved right by ``max(1, 5% of bbox2's width)``. If rounding still gives
    identical endpoints, the end point is moved one more pixel right, so the
    segment always has non-zero length.
    """
    sx, sy = _bbox_center(bbox1)
    ex, ey = _bbox_center(bbox2)
    same_x = math.isclose(sx, ex, abs_tol=1e-3)
    if same_x and math.isclose(sy, ey, abs_tol=1e-3):
        ex += max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
    start = (int(round(sx)), int(round(sy)))
    end = (int(round(ex)), int(round(ey)))
    if end == start:
        end = (end[0] + 1, end[1])
    return start, end
|
|
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Plot the masks of related objects and a labelled line per relation.

    Args:
        frame_image: Image accepted by ``Axes.imshow``.
        masks: Mapping or sequence indexable by object id, giving each object's mask.
        rel_pred_ls: ``{(from_obj_id, to_obj_id): relation_text}``. ``None``
            or empty draws only the frame.

    Returns:
        ``(fig, ax)``. The caller is responsible for saving and closing ``fig``.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    try:
        ax.imshow(frame_image)
        ax.axis('off')

        # dict used as an insertion-ordered set, so masks are drawn in a stable order.
        objs_to_show = {}
        lines_to_show = []
        for (from_obj_id, to_obj_id), rel_text in (rel_pred_ls or {}).items():
            objs_to_show.setdefault(from_obj_id, None)
            objs_to_show.setdefault(to_obj_id, None)

            bbox1 = mask_to_bbox(masks[from_obj_id])
            bbox2 = mask_to_bbox(masks[to_obj_id])
            if bbox1 is None or bbox2 is None:
                continue

            # Join the box centres. shortest_line_between_bboxes, which this
            # code previously called, is neither defined nor imported here.
            start, end = relation_line(bbox1, bbox2)
            line_color = get_color(from_obj_id)
            face_color = get_color(to_obj_id)
            lines_to_show.append((start, end, face_color, line_color, rel_text))

        for obj_id in objs_to_show:
            show_mask(masks[obj_id], ax, obj_id=obj_id, random_color=False)

        for (from_x, from_y), (to_x, to_y), face_color, line_color, rel_text in lines_to_show:
            ax.plot([from_x, to_x], [from_y, to_y], color=line_color, linestyle='-', linewidth=3)
            ax.text(
                (from_x + to_x) / 2 - 5,
                (from_y + to_y) / 2,
                rel_text,
                color="white",
                fontsize=6,
                backgroundcolor=np.array(line_color),
                bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
                alpha=1,
            )
    except Exception:
        # Do not leak the figure if drawing fails; the caller never receives it.
        plt.close(fig)
        raise

    return fig, ax
|
|