ASethi04 commited on
Commit
1472aed
·
verified ·
1 Parent(s): 298a67b

Initial upload of the VINE model architecture and weights

Browse files
Files changed (5) hide show
  1. config.json +6 -2
  2. flattening.py +124 -0
  3. model.safetensors +3 -0
  4. vine_model.py +702 -0
  5. vis_utils.py +941 -0
config.json CHANGED
@@ -1,9 +1,12 @@
1
  {
2
- "_attn_implementation_autoset": true,
3
  "_device": "cuda",
4
  "alpha": 0.5,
 
 
 
5
  "auto_map": {
6
- "AutoConfig": "vine_config.VineConfig"
 
7
  },
8
  "bbox_min_dim": 5,
9
  "box_threshold": 0.35,
@@ -26,6 +29,7 @@
26
  "target_fps": 1,
27
  "text_threshold": 0.25,
28
  "topk_cate": 3,
 
29
  "transformers_version": "4.46.2",
30
  "use_hf_repo": true,
31
  "visualization_dir": null,
 
1
  {
 
2
  "_device": "cuda",
3
  "alpha": 0.5,
4
+ "architectures": [
5
+ "VineModel"
6
+ ],
7
  "auto_map": {
8
+ "AutoConfig": "vine_config.VineConfig",
9
+ "AutoModel": "vine_model.VineModel"
10
  },
11
  "bbox_min_dim": 5,
12
  "box_threshold": 0.35,
 
29
  "target_fps": 1,
30
  "text_threshold": 0.25,
31
  "topk_cate": 3,
32
+ "torch_dtype": "float32",
33
  "transformers_version": "4.46.2",
34
  "use_hf_repo": true,
35
  "visualization_dir": null,
flattening.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ MaskType = Union[np.ndarray, torch.Tensor]
11
+
12
+
13
+ def _to_numpy_mask(mask: MaskType) -> np.ndarray:
14
+ """
15
+ Convert assorted mask formats to a 2D numpy boolean array.
16
+ """
17
+ if isinstance(mask, torch.Tensor):
18
+ mask_np = mask.detach().cpu().numpy()
19
+ else:
20
+ mask_np = np.asarray(mask)
21
+
22
+ # Remove singleton dimensions at the front/back
23
+ while mask_np.ndim > 2 and mask_np.shape[0] == 1:
24
+ mask_np = np.squeeze(mask_np, axis=0)
25
+ if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
26
+ mask_np = np.squeeze(mask_np, axis=-1)
27
+
28
+ if mask_np.ndim != 2:
29
+ raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
30
+
31
+ return mask_np.astype(bool)
32
+
33
+
34
+ def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
35
+ """
36
+ Compute a bounding box for a 2D boolean mask.
37
+ """
38
+ if not mask.any():
39
+ return None
40
+ rows, cols = np.nonzero(mask)
41
+ y_min, y_max = rows.min(), rows.max()
42
+ x_min, x_max = cols.min(), cols.max()
43
+ return x_min, y_min, x_max, y_max
44
+
45
+
46
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten nested segmentation data into batched lists suitable for predicate
    models or downstream visualizations.

    Objects with an empty mask, or a bounding box narrower/shorter than
    ``bbox_min_dim`` pixels, are dropped. For each frame, every ordered pair
    (i, j) with i != j of the surviving objects is emitted.

    Returns a dict with keys "object_ids", "masks", "bboxes" and "pairs".
    """
    batched_object_ids: List[Tuple[int, int, int]] = []
    batched_masks: List[np.ndarray] = []
    batched_bboxes: List[Tuple[int, int, int, int]] = []
    frame_pairs: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, frame_objects in segments.items():
        kept_ids: List[int] = []
        for object_id, raw_mask in frame_objects.items():
            # Normalize the mask to a 2D boolean array (tensor or array in).
            if isinstance(raw_mask, torch.Tensor):
                arr = raw_mask.detach().cpu().numpy()
            else:
                arr = np.asarray(raw_mask)
            while arr.ndim > 2 and arr.shape[0] == 1:
                arr = np.squeeze(arr, axis=0)
            if arr.ndim > 2 and arr.shape[-1] == 1:
                arr = np.squeeze(arr, axis=-1)
            if arr.ndim != 2:
                raise ValueError(f"Expected mask to be 2D after squeezing, got shape {arr.shape}")
            mask = arr.astype(bool)

            rows, cols = np.nonzero(mask)
            if rows.size == 0:
                continue  # empty mask -> no bounding box
            x_min, y_min, x_max, y_max = cols.min(), rows.min(), cols.max(), rows.max()
            if abs(y_max - y_min) < bbox_min_dim or abs(x_max - x_min) < bbox_min_dim:
                continue  # box too small to be useful

            kept_ids.append(object_id)
            batched_object_ids.append((video_id, frame_id, object_id))
            batched_masks.append(mask)
            batched_bboxes.append((x_min, y_min, x_max, y_max))

        # All ordered pairs of surviving objects within this frame.
        frame_pairs.extend(
            (video_id, frame_id, (i, j))
            for i in kept_ids
            for j in kept_ids
            if i != j
        )

    return {
        "object_ids": batched_object_ids,
        "masks": batched_masks,
        "bboxes": batched_bboxes,
        "pairs": frame_pairs,
    }
92
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Filter object pairs per frame.

    When a non-empty `interested_object_pairs` is given, only those (src, dst)
    combinations are emitted, and only for frames where both objects appear.
    Otherwise (None *or* an empty iterable), all ordered permutations (i, j)
    with i != j are emitted per frame.
    """
    # Group object ids by their (video, frame) key.
    per_frame: Dict[Tuple[int, int], set] = defaultdict(set)
    for video_id, frame_id, object_id in batched_object_ids:
        per_frame[(video_id, frame_id)].add(object_id)

    wanted = None if interested_object_pairs is None else list(interested_object_pairs)

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (video_id, frame_id), objects in per_frame.items():
        if wanted:
            # Truthiness on purpose: an empty list behaves like None.
            valid_pairs.extend(
                (video_id, frame_id, (src, dst))
                for src, dst in wanted
                if src in objects and dst in objects
            )
        else:
            valid_pairs.extend(
                (video_id, frame_id, (src, dst))
                for src in objects
                for dst in objects
                if src != dst
            )

    return valid_pairs
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c91c273c5f61b7f17fc6cc265e14bb78ed134c71d7b54611208420fcbe4f81de
3
+ size 1815491340
vine_model.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax import config
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as cp
6
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
7
+ from typing import Dict, List, Tuple, Optional, Any, Union
8
+ import numpy as np
9
+ import os
10
+ import cv2
11
+ from collections import defaultdict
12
+ import builtins
13
+ import sys
14
+ from laser.models import llava_clip_model_v3
15
+ sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
16
+ from safetensors.torch import load_file
17
+
18
+ import inspect
19
+ from transformers.models.clip import modeling_clip
20
+ import transformers
21
+ from huggingface_hub import snapshot_download
22
+
23
+
24
+
25
+
26
+ from .vine_config import VineConfig
27
+ from laser.models.model_utils import (
28
+ extract_single_object,
29
+ extract_object_subject,
30
+ crop_image_contain_bboxes,
31
+ segment_list
32
+ )
33
+ from .flattening import (
34
+ extract_valid_object_pairs,
35
+ flatten_segments_for_batch,
36
+ )
37
+
38
+ from .vis_utils import save_mask_one_image
39
+
40
+ class VineModel(PreTrainedModel):
41
+ """
42
+ VINE (Video Understanding with Natural Language) Model
43
+
44
+ This model processes videos along with categorical, unary, and binary keywords
45
+ to return probability distributions over those keywords for detected objects
46
+ and their relationships in the video.
47
+ """
48
+
49
+ config_class = VineConfig
50
+
51
def __init__(self, config: VineConfig):
    """Build the VINE model.

    Creates a CLIP tokenizer/processor plus three CLIP backbones (one each
    for categorical, unary and binary predicates), then overlays pretrained
    VINE weights from the HuggingFace Hub or from a local checkpoint.

    Args:
        config: VineConfig carrying model names, weight locations and flags.
    """
    super().__init__(config)

    self.config = config
    self.visualize = getattr(config, "visualize", False)
    self.visualization_dir = getattr(config, "visualization_dir", None)
    self.debug_visualizations = getattr(config, "debug_visualizations", False)
    # Fall back to CPU when the config carries no device; the previous
    # getattr without a default raised AttributeError instead.
    self._device = getattr(config, "_device", "cpu")

    # Initialize CLIP components.
    self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if self.clip_tokenizer.pad_token is None:
        # Prefer the unk token as padding; fall back to EOS.
        self.clip_tokenizer.pad_token = (
            self.clip_tokenizer.unk_token
            if self.clip_tokenizer.unk_token
            else self.clip_tokenizer.eos_token
        )
    self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
    self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
    self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
    self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

    # Overlay pretrained VINE weights on top of the base CLIP models.
    if config.use_hf_repo:
        self._load_huggingface_vine_weights(config.model_repo, config.model_file)
    else:
        self._load_local_pretrained_vine_weights(config.local_dir, config.local_filename)

    # Move all sub-modules to the configured device.
    self.to(self._device)
85
def _load_huggingface_vine_weights(self, model_repo: str, model_file: Optional[str] = None):
    """Fetch and load pretrained VINE weights from the HuggingFace Hub.

    NOTE(review): ``model_file`` is passed as the snapshot ``revision``
    (a git ref), not as a filename — confirm this is intentional.

    Returns:
        True when the weights were loaded; False on any failure, in which
        case the base CLIP weights are kept.
    """
    try:
        print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
        repo_path = snapshot_download(model_repo, revision=model_file or "main")
        state_dict = load_file(os.path.join(repo_path, "model.safetensors"))
        # strict=False: tolerate keys present on only one side.
        self.load_state_dict(state_dict, strict=False)
        print("✓ Successfully loaded VINE weights from HuggingFace Hub")
        return True
    except Exception as exc:
        print(f"✗ Error loading VINE weights from HuggingFace Hub: {exc}")
        print("Using base CLIP models instead")
        return False
101
def _load_local_pretrained_vine_weights(self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0):
    """Load pretrained VINE weights from a local checkpoint.

    Supported on-disk formats:
      * ``.pkl``          — a pickled model-like object whose CLIP sub-models
                            are copied over.
      * ``.pt`` / ``.pth`` — a plain state dict loaded into ``self``.
      * a directory        — containing ``*.{epoch}.model`` ensemble files.

    Args:
        local_dir: Directory (or full path when ``local_filename`` is None).
        local_filename: Optional filename joined onto ``local_dir``.
        epoch: Epoch suffix used to select an ensemble file in a directory.

    Returns:
        True when weights were loaded, False otherwise.
    """
    full_path = os.path.join(local_dir, local_filename) if local_filename else local_dir

    if full_path.endswith(".pkl"):
        print(f"Loading VINE weights from: {full_path}")
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        loaded_vine_model = torch.load(full_path, map_location=self._device, weights_only=False)
        print(f"Loaded state type: {type(loaded_vine_model)}")
        if not isinstance(loaded_vine_model, dict):
            self._copy_clip_submodels(loaded_vine_model)
            return True
        # NOTE(review): a .pkl holding a plain dict falls through to the
        # "unsupported format" path below — confirm this is intended.

    elif full_path.endswith(".pt") or full_path.endswith(".pth"):
        state = torch.load(full_path, map_location=self._device, weights_only=True)
        print(f"Loaded state type: {type(state)}")
        self.load_state_dict(state)
        return True

    # Directory + epoch ensemble format.
    if os.path.isdir(full_path):
        model_files = [f for f in os.listdir(full_path) if f.endswith(f'.{epoch}.model')]
        if model_files:
            model_file = os.path.join(full_path, model_files[0])
            print(f"Loading VINE weights from: {model_file}")
            # Conversion from a PredicateModel-like object to VineModel:
            # copy only the sub-models the source actually defines.
            pretrained_model = torch.load(model_file, map_location="cpu")
            self._copy_clip_submodels(pretrained_model)
            print("✓ Loaded all sub-model weights from ensemble format")
            return True
        print(f"No model file found for epoch {epoch} in {full_path}")
        return False

    print("Unsupported format for pretrained_vine_path")
    return False

def _copy_clip_submodels(self, source) -> None:
    """Copy CLIP sub-model weights from a loaded model-like object,
    skipping any sub-model the source does not define."""
    for attr in ("clip_cate_model", "clip_unary_model", "clip_binary_model"):
        if hasattr(source, attr):
            getattr(self, attr).load_state_dict(getattr(source, attr).state_dict())
248
@classmethod
def from_pretrained_vine(
    cls,
    model_path: str,
    config: Optional[VineConfig] = None,
    epoch: int = 0,
    **kwargs
):
    """
    Create a VineModel from pretrained VINE weights.

    Args:
        model_path: HF repo id or local file/directory of the weights.
        config: Optional config; a default one is created when None.
        epoch: Epoch number (currently unused by this method — kept for
            interface compatibility).
        **kwargs: Extra arguments forwarded to the constructor.

    Returns:
        VineModel instance with loaded weights.
    """
    # Heuristic: a path containing "/" that does not exist on disk is
    # treated as a HuggingFace repo id; everything else is local.
    looks_like_repo = bool(model_path) and "/" in model_path and not os.path.exists(model_path)

    if config is None:
        if looks_like_repo:
            config = VineConfig(use_hf_repo=True, model_repo=model_path)
        elif os.path.isdir(model_path):
            config = VineConfig(use_hf_repo=False, local_dir=model_path)
        else:
            config = VineConfig(
                use_hf_repo=False,
                local_dir=os.path.dirname(model_path) or None,
                local_filename=os.path.basename(model_path) or None,
            )
    elif looks_like_repo:
        # Rewrite the provided config to point at the requested repo.
        config.use_hf_repo = True
        config.model_repo = model_path
        config.model_file = None
        config.local_dir = None
        config.local_filename = None
    else:
        # Rewrite the provided config to point at the local path.
        config.use_hf_repo = False
        if os.path.isdir(model_path):
            config.local_dir = model_path
            config.local_filename = None
        else:
            config.local_dir = os.path.dirname(model_path) or None
            config.local_filename = os.path.basename(model_path) or None

    # The constructor loads the weights as a side effect.
    return cls(config, **kwargs)
306
+ def _text_features_checkpoint(self, model, tokens):
307
+ """Extract text features with gradient checkpointing."""
308
+ token_keys = list(tokens.keys())
309
+
310
+ def get_text_features_wrapped(*inputs):
311
+ kwargs = {key: value for key, value in zip(token_keys, inputs)}
312
+ return model.get_text_features(**kwargs)
313
+
314
+ token_values = [tokens[key] for key in token_keys]
315
+ return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)
316
+
317
+ def _image_features_checkpoint(self, model, images):
318
+ """Extract image features with gradient checkpointing."""
319
+ return cp.checkpoint(model.get_image_features, images, use_reentrant=False)
320
+
321
def clip_sim(self, model, nl_feat, img_feat):
    """Cosine-similarity logits between image and text features.

    Both feature matrices are L2-normalized along the last axis, then
    multiplied (images x texts). When the model exposes a ``logit_scale``
    parameter, the logits are scaled by its exp(), as in CLIP.
    """
    # Plain division (no eps clamp) to match CLIP's own normalization.
    img_unit = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
    txt_unit = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
    logits = img_unit @ txt_unit.T
    if hasattr(model, "logit_scale"):
        logits = logits * model.logit_scale.exp()
    return logits
329
def forward(
    self,
    video_frames: torch.Tensor,
    masks: Dict[int, Dict[int, torch.Tensor]],
    bboxes: Dict[int, Dict[int, List]],
    categorical_keywords: List[str],
    unary_keywords: Optional[List[str]] = None,
    binary_keywords: Optional[List[str]] = None,
    object_pairs: Optional[List[Tuple[int, int]]] = None,
    return_flattened_segments: Optional[bool] = None,
    return_valid_pairs: Optional[bool] = None,
    interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
    debug_visualizations: Optional[bool] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Forward pass of the VINE model.

    Args:
        video_frames: Tensor of shape (num_frames, height, width, 3).
        masks: Dict mapping frame_id -> object_id -> mask tensor.
        bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2].
        categorical_keywords: Category names to classify objects against.
        unary_keywords: Optional unary predicates (single-object actions).
        binary_keywords: Optional binary predicates (object relations).
        object_pairs: Optional (obj1_id, obj2_id) pairs for binary scoring.
        return_flattened_segments / return_valid_pairs: Override the
            corresponding config flags when not None.
        interested_object_pairs: Optional subset of pairs to keep when
            computing valid pairs.
        debug_visualizations: Resolved against config but not otherwise
            used by this method.

    Returns:
        Dict with probability maps for categorical, unary and binary
        predictions (plus optional flattened segments / valid pairs).
    """
    # Resolve optional arguments against config-level defaults.
    if unary_keywords is None:
        unary_keywords = []
    if binary_keywords is None:
        binary_keywords = []
    if object_pairs is None:
        object_pairs = []
    if return_flattened_segments is None:
        return_flattened_segments = self.config.return_flattened_segments
    if return_valid_pairs is None:
        return_valid_pairs = self.config.return_valid_pairs
    if interested_object_pairs is None or len(interested_object_pairs) == 0:
        interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or []
    if debug_visualizations is None:
        debug_visualizations = self.debug_visualizations

    # CLIP cannot embed an empty keyword list, so substitute a placeholder
    # that is filtered back out of the results below.
    dummy_str = ""
    categorical_keywords = categorical_keywords or [dummy_str]
    unary_keywords = unary_keywords or [dummy_str]
    binary_keywords = binary_keywords or [dummy_str]

    # One text-feature bank per predicate family.
    categorical_features = self._extract_text_features(self.clip_cate_model, categorical_keywords)
    unary_features = self._extract_text_features(self.clip_unary_model, unary_keywords)
    binary_features = self._extract_text_features(self.clip_binary_model, binary_keywords)

    categorical_probs: Dict = {}
    unary_probs: Dict = {}
    binary_probs: Dict = {}

    # ---- per-object categorical / unary predictions ---------------------
    for frame_id, frame_masks in masks.items():
        if frame_id >= len(video_frames):
            continue  # mask refers to a frame we were not given

        frame_np = self._frame_to_numpy(video_frames[frame_id])
        frame_bboxes = bboxes.get(frame_id, {})

        for obj_id, raw_mask in frame_masks.items():
            if obj_id not in frame_bboxes:
                continue  # no bounding box for this object in this frame

            obj_image = extract_single_object(
                frame_np, self._mask_to_numpy(raw_mask), alpha=self.config.alpha
            )
            obj_features = self._extract_image_features(self.clip_cate_model, obj_image)

            # NOTE(review): categorical results are keyed by object only,
            # so a later frame overwrites earlier frames for the same object.
            cat_probs = F.softmax(
                self.clip_sim(self.clip_cate_model, categorical_features, obj_features),
                dim=-1,
            )
            for idx, keyword in enumerate(categorical_keywords):
                if keyword != dummy_str:
                    categorical_probs[(obj_id, keyword)] = cat_probs[0, idx].item()

            if unary_keywords[0] != dummy_str:
                un_probs = F.softmax(
                    self.clip_sim(self.clip_unary_model, unary_features, obj_features),
                    dim=-1,
                )
                for idx, keyword in enumerate(unary_keywords):
                    if keyword != dummy_str:
                        unary_probs[(frame_id, obj_id, keyword)] = un_probs[0, idx].item()

    # ---- pairwise binary predictions ------------------------------------
    if binary_keywords[0] != dummy_str and object_pairs:
        for obj1_id, obj2_id in object_pairs:
            for frame_id, frame_masks in masks.items():
                if frame_id >= len(video_frames):
                    continue
                frame_bboxes = bboxes.get(frame_id, {})
                if not (obj1_id in frame_masks and obj2_id in frame_masks
                        and obj1_id in frame_bboxes and obj2_id in frame_bboxes):
                    continue

                frame_np = self._frame_to_numpy(video_frames[frame_id])
                mask1_np = self._mask_to_numpy(frame_masks[obj1_id])
                mask2_np = self._mask_to_numpy(frame_masks[obj2_id])

                # Composite image highlighting subject and object together.
                pair_image = extract_object_subject(
                    frame_np, mask1_np[..., None], mask2_np[..., None],
                    alpha=self.config.alpha,
                    white_alpha=self.config.white_alpha
                )

                bbox1 = frame_bboxes[obj1_id]
                bbox2 = frame_bboxes[obj2_id]

                # Skip pairs whose bounding boxes do not overlap at all.
                if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \
                        bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]:
                    continue

                cropped_image = crop_image_contain_bboxes(
                    pair_image, [bbox1, bbox2], f"frame_{frame_id}"
                )
                pair_features = self._extract_image_features(self.clip_binary_model, cropped_image)
                bin_probs = F.softmax(
                    self.clip_sim(self.clip_binary_model, binary_features, pair_features),
                    dim=-1,
                )
                for idx, keyword in enumerate(binary_keywords):
                    if keyword != dummy_str:
                        binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = bin_probs[0, idx].item()

    # Uniform placeholder probability kept for backward compatibility.
    dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords))

    result: Dict[str, Any] = {
        "categorical_probs": {0: categorical_probs},  # single video -> id 0
        "unary_probs": {0: unary_probs},
        "binary_probs": [binary_probs],  # list format for compatibility
        "dummy_prob": dummy_prob,
    }

    if return_flattened_segments or return_valid_pairs:
        flattened = flatten_segments_for_batch(
            video_id=0,
            segments=masks,
            bbox_min_dim=self.config.bbox_min_dim,
        )
        if return_flattened_segments:
            result["flattened_segments"] = flattened
        if return_valid_pairs:
            wanted = interested_object_pairs or None
            result["valid_pairs"] = extract_valid_object_pairs(flattened["object_ids"], wanted)
            if wanted is None:
                # All pairs were auto-generated rather than user-requested.
                result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
            else:
                result["valid_pairs_metadata"] = {"pair_source": "filtered", "requested_pairs": wanted}

    return result
533
+ def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
534
+ """Convert a frame tensor/array to a contiguous numpy array."""
535
+ if torch.is_tensor(frame):
536
+ frame_np = frame.detach().cpu().numpy()
537
+ else:
538
+ frame_np = np.asarray(frame)
539
+ return np.ascontiguousarray(frame_np)
540
+
541
+ def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
542
+ """Convert a mask tensor/array to a 2D boolean numpy array."""
543
+ if torch.is_tensor(mask):
544
+ mask_np = mask.detach().cpu().numpy()
545
+ else:
546
+ mask_np = np.asarray(mask)
547
+
548
+ if mask_np.ndim == 3:
549
+ if mask_np.shape[0] == 1:
550
+ mask_np = mask_np.squeeze(0)
551
+ elif mask_np.shape[2] == 1:
552
+ mask_np = mask_np.squeeze(2)
553
+
554
+ if mask_np.ndim != 2:
555
+ raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
556
+
557
+ return mask_np.astype(bool, copy=False)
558
+
559
def _extract_text_features(self, model, keywords):
    """Tokenize *keywords* and return checkpointed text features from *model*.

    Tokens are padded/truncated to a fixed length of 75 (presumably CLIP's
    77-token context minus special tokens — TODO confirm) and moved to the
    model device before encoding.
    """
    tokens = self.clip_tokenizer(
        keywords,
        return_tensors="pt",
        max_length=75,
        truncation=True,
        padding='max_length',
    ).to(self._device)
    return self._text_features_checkpoint(model, tokens)
571
def _extract_image_features(self, model, image):
    """Preprocess *image* with the CLIP processor and return checkpointed
    image features from *model*.

    NOTE(review): any uint8-coerced 3-channel numpy input is unconditionally
    converted BGR -> RGB, i.e. callers are expected to supply OpenCV-style
    BGR frames — confirm against the frame producers.
    """
    if isinstance(image, np.ndarray):
        if image.dtype != np.uint8:
            image = image.astype(np.uint8)
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # CLIP processor handles resize/normalize and returns pixel_values.
    inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
    return self._image_features_checkpoint(model, inputs['pixel_values'])
588
# TODO: return masks and bboxes and their corresponding index
def predict(
    self,
    video_frames: torch.Tensor,
    masks: Dict[int, Dict[int, torch.Tensor]],
    bboxes: Dict[int, Dict[int, List]],
    categorical_keywords: List[str],
    unary_keywords: Optional[List[str]] = None,
    binary_keywords: Optional[List[str]] = None,
    object_pairs: Optional[List[Tuple[int, int]]] = None,
    return_top_k: int = 3,
    return_flattened_segments: Optional[bool] = None,
    return_valid_pairs: Optional[bool] = None,
    interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
    debug_visualizations: Optional[bool] = None,
) -> Dict[str, Any]:
    """
    High-level prediction method that returns formatted results.

    Runs ``forward`` under ``torch.no_grad()`` and regroups its raw
    probability maps into per-object / per-frame top-k lists.

    Args:
        video_frames: Tensor of shape (num_frames, height, width, 3).
        masks: Dict mapping frame_id -> object_id -> mask tensor.
        bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2].
        categorical_keywords: Category names.
        unary_keywords: Optional unary predicates.
        binary_keywords: Optional binary predicates.
        object_pairs: Optional object pairs for binary relations.
        return_top_k: Number of top predictions to keep per key.
        return_flattened_segments: Include flattened mask/bbox lists.
        return_valid_pairs: Compute valid object pairs per frame.
        interested_object_pairs: Optional subset of pairs to track.
        debug_visualizations: Forwarded to ``forward``.

    Returns:
        Dict with formatted top-k predictions and confidence scores.
    """
    with torch.no_grad():
        outputs = self.forward(
            video_frames=video_frames,
            masks=masks,
            bboxes=bboxes,
            categorical_keywords=categorical_keywords,
            unary_keywords=unary_keywords,
            binary_keywords=binary_keywords,
            object_pairs=object_pairs,
            return_flattened_segments=return_flattened_segments,
            return_valid_pairs=return_valid_pairs,
            interested_object_pairs=interested_object_pairs,
            debug_visualizations=debug_visualizations,
        )

    def _top_k(grouped):
        # Sort each group's (prob, label) tuples descending, keep top-k.
        return {key: sorted(vals, reverse=True)[:return_top_k] for key, vals in grouped.items()}

    # Categorical: group by object id.
    cat_grouped = defaultdict(list)
    for (obj_id, category), prob in outputs["categorical_probs"][0].items():
        cat_grouped[obj_id].append((prob, category))
    formatted_categorical = _top_k(cat_grouped)

    # Unary: group by (frame, object).
    unary_grouped = defaultdict(list)
    for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
        unary_grouped[(frame_id, obj_id)].append((prob, predicate))
    formatted_unary = _top_k(unary_grouped)

    # Binary: group by (frame, object pair).
    binary_grouped = defaultdict(list)
    if len(outputs["binary_probs"]) > 0:
        for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
            binary_grouped[(frame_id, obj_pair)].append((prob, predicate))
    formatted_binary = _top_k(binary_grouped)

    def _best(preds_by_key):
        # Highest probability across all groups, 0.0 when empty.
        return max(
            (max((p for p, _ in preds), default=0.0) for preds in preds_by_key.values()),
            default=0.0,
        )

    result: Dict[str, Any] = {
        "categorical_predictions": formatted_categorical,
        "unary_predictions": formatted_unary,
        "binary_predictions": formatted_binary,
        "confidence_scores": {
            "categorical": _best(formatted_categorical),
            "unary": _best(formatted_unary),
            "binary": _best(formatted_binary),
        },
    }

    # Pass through optional forward outputs untouched.
    for key in ("flattened_segments", "valid_pairs", "valid_pairs_metadata"):
        if key in outputs:
            result[key] = outputs[key]

    return result
vis_utils.py ADDED
@@ -0,0 +1,941 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import random
7
+ import math
8
+ from matplotlib.patches import Rectangle
9
+ import itertools
10
+ from typing import Any, Dict, List, Tuple, Optional, Union
11
+
12
+ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
13
+
14
+ ########################################################################################
15
+ ########## Visualization Library ########
16
+ ########################################################################################
17
+ # This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
18
+ #
19
+ # Conventions (RGB frames, pixel coords):
20
+ # - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
21
+ # - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
22
+ # - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
23
+ #
24
+ # Per-frame stores use one of:
25
+ # - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
26
+ # - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
27
+ #
28
+ # Renderer inputs/outputs:
29
+ # 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
30
+ # - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
31
+ # - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
32
+ #
33
+ # 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
34
+ # - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
35
+ #
36
+ # 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
37
+ # -> List[np.ndarray] (the "all" view)
38
+ # - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
39
+ # - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
40
+ # - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
41
+ # - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
42
+ #
43
+ # Ground-truth helpers used by plotting utilities:
44
+ # - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
45
+ #
46
+ # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
47
+ ########################################################################################
48
+
49
def clean_label(label):
    """Normalize a label by mapping underscores and slashes to spaces."""
    return label.translate(str.maketrans({"_": " ", "/": " "}))
52
+
53
+ # Should be performed somewhere else I believe
54
def format_cate_preds(cate_preds):
    """Group (obj_id, label) -> prob model outputs per object.

    Returns a dict obj_id -> [(cleaned_label, prob), ...] sorted by
    probability, highest first.
    """
    grouped = {}
    for (oid, label), prob in cate_preds.items():
        # Normalize the predicted label text before grouping.
        grouped.setdefault(oid, []).append((clean_label(label), prob))
    for preds in grouped.values():
        preds.sort(key=lambda pair: pair[1], reverse=True)
    return grouped
66
+
67
def format_binary_cate_preds(binary_preds):
    """Flatten binary predictions into (frame_id, subj, obj, relation, score) tuples.

    Keys are expected as (frame_id, (subject, object), predicted_relation);
    malformed keys are reported and skipped. The result is sorted by score,
    highest first.

    BUG FIX: the list was previously sorted by index 3 (the relation string)
    instead of index 4 (the confidence score). Also narrowed the broad
    ``except Exception`` to the unpacking errors that can actually occur.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        # Expect key format: (frame_id, (subject, object), predicted_relation)
        try:
            f_id, (subj, obj), pred_rel = key
        except (TypeError, ValueError):
            print("Skipping key with unexpected format:", key)
            continue
        frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
    # Sort by confidence score, descending.
    frame_binary_preds.sort(key=lambda x: x[4], reverse=True)
    return frame_binary_preds
79
+
80
# Shared OpenCV font used by every text-drawing helper in this module.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
81
+
82
+
83
+ def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
84
+ if mask is None:
85
+ return None
86
+ if isinstance(mask, torch.Tensor):
87
+ mask_np = mask.detach().cpu().numpy()
88
+ else:
89
+ mask_np = np.asarray(mask)
90
+ if mask_np.ndim == 0:
91
+ return None
92
+ if mask_np.ndim == 3:
93
+ mask_np = np.squeeze(mask_np)
94
+ if mask_np.ndim != 2:
95
+ return None
96
+ if mask_np.dtype == bool:
97
+ return mask_np
98
+ return mask_np > 0
99
+
100
+
101
+ def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
102
+ if bbox is None:
103
+ return None
104
+ if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
105
+ x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
106
+ elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
107
+ x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
108
+ else:
109
+ return None
110
+ x1 = int(np.clip(round(x1), 0, width - 1))
111
+ y1 = int(np.clip(round(y1), 0, height - 1))
112
+ x2 = int(np.clip(round(x2), 0, width - 1))
113
+ y2 = int(np.clip(round(y2), 0, height - 1))
114
+ if x2 <= x1 or y2 <= y1:
115
+ return None
116
+ return (x1, y1, x2, y2)
117
+
118
+
119
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Return a stable per-object color in BGR order for OpenCV drawing.

    Delegates to ``get_color`` (defined elsewhere in the project); assumes it
    yields float channels in [0, 1] — TODO confirm against its definition.
    """
    color = get_color(obj_id)
    rgb = [int(np.clip(c, 0.0, 1.0) * 255) for c in color[:3]]
    # Swap RGB -> BGR for OpenCV.
    return (rgb[2], rgb[1], rgb[0])
123
+
124
+
125
+ def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
126
+ return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
127
+
128
+
129
def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    """Draw a stack of text lines with filled backgrounds onto ``image`` in place.

    The block grows upward from ``anchor`` by default, or downward when
    ``direction == "down"``. All coordinates are clamped to the image bounds;
    ``color`` is the object's BGR color, lightened for the text background.
    """
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x, y = anchor
    x = int(np.clip(x, 0, img_w - 1))
    y_cursor = int(np.clip(y, 0, img_h - 1))
    bg_color = _background_color(color)

    if direction == "down":
        # Stack lines below the anchor, advancing the cursor after each row.
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
            if bottom_y <= top_y:
                # Ran off the bottom of the image; stop drawing further lines.
                break
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = bottom_y
    else:
        # Default: stack lines above the anchor, moving the cursor upward.
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            top_y = max(y_cursor - th - baseline - 6, 0)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            bottom_y = min(top_y + th + baseline + 6, img_h - 1)
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = top_y
174
+
175
+
176
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw a single text label centered at ``center`` with a filled background.

    Used for relation labels placed at the midpoint of a line between two
    boxes. Coordinates are clamped so the label stays within the image.
    """
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    # Background rectangle centered on (cx, cy) with a small padding margin.
    left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
197
+
198
+
199
+ def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
200
+ if isinstance(store, dict):
201
+ frame_entry = store.get(frame_idx, {})
202
+ elif isinstance(store, list) and 0 <= frame_idx < len(store):
203
+ frame_entry = store[frame_idx]
204
+ else:
205
+ frame_entry = {}
206
+ if isinstance(frame_entry, dict):
207
+ return frame_entry
208
+ if isinstance(frame_entry, list):
209
+ return {i: value for i, value in enumerate(frame_entry)}
210
+ return {}
211
+
212
+
213
+ def _label_anchor_and_direction(
214
+ bbox: Tuple[int, int, int, int],
215
+ position: str,
216
+ ) -> Tuple[Tuple[int, int], str]:
217
+ x1, y1, x2, y2 = bbox
218
+ if position == "bottom":
219
+ return (x1, y2), "down"
220
+ return (x1, y1), "up"
221
+
222
+
223
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw a colored bbox on ``image`` in place with a "#<id> <title>" header.

    Extra ``sub_lines`` are stacked with the header; ``label_position``
    selects whether the label block sits above ("top") or below ("bottom")
    the box.
    """
    color = _object_color_bgr(obj_id)
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    head = title if title else f"#{obj_id}"
    # Prepend the object id unless the caller's title already carries a "#" prefix.
    if not head.startswith("#"):
        head = f"#{obj_id} {head}"
    lines = [head]
    if sub_lines:
        lines.extend(sub_lines)
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, lines, anchor, color, direction=direction)
241
+
242
+
243
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Overlay SAM masks (and mask-derived boxes) on each frame.

    Two passes per frame: first every mask is alpha-blended onto the image,
    then boxes/labels are drawn so they are not washed out by the blending.
    Frames that are None are skipped, so the output can be shorter than the
    input. Returns new RGB frames; inputs are not modified.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        # Blend in float to avoid uint8 clipping artifacts during accumulation.
        overlay = frame_bgr.astype(np.float32)
        masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)

        # Pass 1: translucent mask fills.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            color = _object_color_bgr(obj_id)
            alpha = 0.45
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]

        # Pass 2: boxes derived from the masks, plus optional DINO labels.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            bbox = mask_to_bbox(mask_np)
            bbox = _sanitize_bbox(bbox, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
286
+
287
+
288
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Draw GroundingDINO boxes (and optional labels) onto each frame.

    Frames that are None are skipped. Returns new RGB frames; inputs are
    not modified.
    """
    labels = dino_labels or {}
    frame_list = frames if isinstance(frames, list) else list(frames)
    rendered: List[np.ndarray] = []

    for idx, raw_frame in enumerate(frame_list):
        if raw_frame is None:
            continue
        canvas = cv2.cvtColor(np.asarray(raw_frame), cv2.COLOR_RGB2BGR)
        canvas_h, canvas_w = canvas.shape[:2]

        for obj_id, raw_bbox in _extract_frame_entities(bboxes, idx).items():
            box = _sanitize_bbox(raw_bbox, canvas_w, canvas_h)
            if not box:
                continue
            name = labels.get(obj_id)
            _draw_bbox_with_label(canvas, box, obj_id, title=f"{name}" if name else None)

        rendered.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    return rendered
316
+
317
+
318
def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> Dict[str, List[np.ndarray]]:
    """Render four parallel views of VINE predictions for every frame.

    Returns a dict with keys "object" (boxes + category labels), "unary"
    (plus per-object unary labels and mask overlays), "binary" (plus
    relation lines between object pairs), and "all" (everything combined).
    Each value is a list of RGB frames. Frames that are None are skipped.
    """
    frame_groups: Dict[str, List[np.ndarray]] = {
        "object": [],
        "unary": [],
        "binary": [],
        "all": [],
    }
    frames_iterable = frames if isinstance(frames, list) else list(frames)

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}

        # One independent canvas per view.
        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()
        binary_bgr = base_bgr.copy()
        all_bgr = base_bgr.copy()

        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}

        # Stage 1: collect sanitized boxes, category titles and unary label
        # strings per object (no drawing yet).
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox
            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            title_parts = []
            if cat_label:
                if cat_prob is not None:
                    title_parts.append(f"{cat_label} {cat_prob:.2f}")
                else:
                    title_parts.append(cat_label)
            titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
            unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
            unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
            unary_lines_lookup[obj_id] = unary_lines

        # Stage 2: translucent mask overlays, only for objects that have
        # unary predictions, on the "unary" and "all" views.
        for obj_id, bbox in bbox_lookup.items():
            unary_lines = unary_lines_lookup.get(obj_id, [])
            if not unary_lines:
                continue
            mask_raw = frame_masks.get(obj_id)
            mask_np = _to_numpy_mask(mask_raw)
            if mask_np is None or not np.any(mask_np):
                continue
            color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            alpha = 0.45
            for target in (unary_bgr, all_bgr):
                target_vals = target[mask_np].astype(np.float32)
                blended = (1.0 - alpha) * target_vals + alpha * color
                target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)

        # Stage 3: boxes + labels on every view (after blending so text
        # stays crisp); unary label blocks hang below the box.
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            _draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
            _draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)

        # Stage 4: relation lines between subject/object pairs, labelled with
        # the top-1 relation, on the "binary" and "all" views.
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            if not subj_bbox or not obj_bbox:
                continue
            start, end = relation_line(subj_bbox, obj_bbox)
            # Line color: average of the two objects' colors.
            color = tuple(int(c) for c in np.clip(
                (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
                 np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
                0, 255
            ))
            prob, relation = relation_preds[0]
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
            cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
            _draw_centered_label(binary_bgr, label_text, mid_point, color)
            _draw_centered_label(all_bgr, label_text, mid_point, color)

        frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))

    return frame_groups
426
+
427
+
428
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> List[np.ndarray]:
    """Convenience wrapper returning only the combined ("all") VINE view."""
    frame_groups = render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks,
    )
    return frame_groups.get("all", [])
444
+
445
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Choose a box color and caption per object from prediction correctness.

    Green = top-1 matches ground truth, orange = ground truth appears within
    the top-k predictions, red = miss or no prediction. Returns parallel
    lists of BGR colors and "ID:<id>/P:<pred>/GT:<gt>" strings.
    """
    all_colors = []
    all_texts = []
    for obj_id, _bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if not preds:
            top1 = "N/A"
            box_color = (0, 0, 255)  # bright red if no prediction
        else:
            top1 = preds[0][0]
            gt_lower = gt_label.lower()
            topk_lower = [label.lower() for label, _ in preds[:topk_object]]
            if top1.lower() == gt_lower:
                box_color = (0, 255, 0)  # bright green for correct
            elif gt_lower in topk_lower:
                box_color = (0, 165, 255)  # bright orange for partial match
            else:
                box_color = (0, 0, 255)  # bright red for incorrect
        all_colors.append(box_color)
        all_texts.append(f"ID:{obj_id}/P:{top1}/GT:{gt_label}")
    return all_colors, all_texts
468
+
469
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw one colored box plus caption banner per object onto ``frame_img``.

    Modifies and returns ``frame_img`` (BGR). ``gt_labels`` supplies the
    (obj_id, bbox, gt_label) triples; colors/texts come from
    ``color_for_cate_correctness`` and are index-aligned.
    """
    for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
        x1, y1, x2, y2 = (int(coord) for coord in bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        # Filled banner sized to the caption, drawn just above the box.
        (tw, th), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1)
        cv2.putText(frame_img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 0), 1, cv2.LINE_AA)

    return frame_img
480
+
481
def get_white_pane(pane_height,
                   pane_width=600,
                   header_height=50,
                   header_font=cv2.FONT_HERSHEY_SIMPLEX,
                   header_font_scale=0.7,
                   header_thickness=2,
                   header_color=(0, 0, 0)):
    """Create a white side pane with "Binary Predictions" / "Ground Truth" headers.

    The left (predictions) column takes 60% of the width, the right 40%.

    BUG FIX: the headers were previously drawn onto ``.copy()`` slices that
    were then discarded, so the returned pane never contained any text. They
    are now drawn directly onto the returned pane, with the right-hand header
    offset by the column split.
    """
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)

    # Predictions column is wider (60% vs. 40%).
    left_width = int(pane_width * 0.6)

    cv2.putText(white_pane, "Binary Predictions", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
    cv2.putText(white_pane, "Ground Truth", (left_width + 10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)

    return white_pane
503
+
504
+ # This is for ploting binary prediction results with frame-based scene graphs
505
def plot_binary_sg(frame_img,
                   white_pane,
                   bin_preds,
                   gt_relations,
                   topk_binary,
                   header_height=50,
                   indicator_size=20,
                   pane_width=600):
    """Render top-k binary predictions (left) and ground truth (right) beside the frame.

    ``bin_preds`` comes from ``format_binary_cate_preds`` as
    (frame_id, subj, obj, relation, score) 5-tuples; legacy 4-tuples
    (subj, relation, obj, score) are still accepted. ``gt_relations`` are
    (subject, object, relation) triples. Returns the frame and text pane
    stacked horizontally.

    BUG FIX: the loop previously unpacked 4-tuples and raised ValueError on
    the 5-tuples actually produced by ``format_binary_cate_preds``.
    """
    line_height = 30  # vertical spacing per line
    x_text = 10  # left margin for text
    y_text_left = header_height + 10  # starting y for left pane text
    y_text_right = header_height + 10  # starting y for right pane text

    # Left section: top-k binary predictions.
    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    for pred in bin_preds[:topk_binary]:
        if len(pred) == 5:
            _f_id, subj, obj, pred_rel, score = pred
        else:
            subj, pred_rel, obj, score = pred
        # Green indicator when the triplet matches any ground-truth relation.
        correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
                      for gt in gt_relations)
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
                      (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_left += line_height

    # Right section: ground truth binary relations.
    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(right_pane, text, (x_text, y_text_right + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_right += line_height

    # Combine the two text panes and then with the frame image.
    combined_pane = np.hstack((left_pane, right_pane))
    combined_image = np.hstack((frame_img, combined_pane))
    return combined_image
549
+
550
def visualized_frame(frame_img,
                     bboxes,
                     object_ids,
                     gt_labels,
                     cate_preds,
                     binary_preds,
                     gt_relations,
                     topk_object,
                     topk_binary,
                     phase="unary"):
    """Return the combined annotated frame as an image (in BGR).

    ``phase == "unary"`` draws per-object boxes colored by prediction
    correctness; otherwise a binary-relation text pane is appended.

    BUG FIX (unary phase): the (obj_id, bbox, gt_label) triples assembled in
    ``objs`` were built but never used; the raw ``gt_labels`` strings were
    passed downstream, which does not match the triple shape expected by
    ``color_for_cate_correctness`` and ``plot_unary``.
    """
    if phase == "unary":
        # --- Process Object Predictions (for overlaying bboxes) ---
        objs = []
        for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
            objs.append((obj_id, bbox, clean_label(gt_label)))

        formatted_cate_preds = format_cate_preds(cate_preds)
        all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
        return plot_unary(frame_img, objs, all_colors, all_texts)

    # --- Process Binary Predictions & Ground Truth for the Text Pane ---
    formatted_binary_preds = format_binary_cate_preds(binary_preds)

    # Clean ground truth relations for display/comparison.
    gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel))
                    for s, o, rel in gt_relations]

    pane_width = 600  # widened pane for more horizontal space
    pane_height = frame_img.shape[0]
    header_height = 50  # extra room for the column headers
    white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)

    return plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
594
+
595
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay a single translucent object mask on a matplotlib axis.

    Accepts (H, W), (1, H, W) or (H, W, 1) masks. When ``det_class`` is
    given, a bounding box around the mask's nonzero area is drawn with the
    class text. Colors are deterministic per ``obj_id`` unless
    ``random_color`` is set.
    """
    # Ensure mask is a numpy array
    mask = np.array(mask)
    # Handle different mask shapes
    if mask.ndim == 3:
        # (1, H, W) -> (H, W)
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        # (H, W, 1) -> (H, W)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)
    # Now mask should be (H, W)
    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        # Deterministic color: stride the rainbow colormap by obj_id.
        cmap = plt.get_cmap("gist_rainbow")
        cmap_idx = 0 if obj_id is None else obj_id
        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5
        color = np.array(color)

    # Expand mask to (H, W, 1) for broadcasting
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)

    # draw a box around the mask with the det_class as the label
    if not det_class is None:
        # Find the bounding box coordinates
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3]
            )
            ax.add_patch(rect)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1
            )
    ax.imshow(mask_image)
649
+
650
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk.

    ``frame_image`` may be a tensor or array-like; ``masks`` may be a dict
    of obj_id -> mask or a list (index used as obj_id). Returns
    ``save_path`` after writing the figure and closing it.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Move tensors off-device and into numpy for matplotlib.
    frame_np = (
        frame_image.detach().cpu().numpy()
        if torch.is_tensor(frame_image)
        else np.asarray(frame_image)
    )
    frame_np = np.ascontiguousarray(frame_np)

    if isinstance(masks, dict):
        mask_iter = masks.items()
    else:
        # Lists are treated as implicitly keyed by index.
        mask_iter = enumerate(masks)

    prepared_masks = {
        obj_id: (
            mask.detach().cpu().numpy()
            if torch.is_tensor(mask)
            else np.asarray(mask)
        )
        for obj_id, mask in mask_iter
    }

    ax.imshow(frame_np)
    ax.axis("off")

    for obj_id, mask_np in prepared_masks.items():
        show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)

    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path
684
+
685
def get_video_masks_visualization(video_tensor,
                                  video_masks,
                                  video_id,
                                  video_save_base_dir,
                                  oid_class_pred=None,
                                  sample_rate=1):
    """Save a per-frame mask visualization for a video under <base_dir>/<video_id>/.

    BUG FIX: the figure produced for each frame was never written to the
    computed ``save_path`` (and never closed, leaking matplotlib figures).
    Frames are now saved and closed. The previously-unused ``sample_rate``
    argument now subsamples frames (every Nth frame); the default of 1
    keeps the old every-frame behavior.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if sample_rate > 1 and frame_id % sample_rate != 0:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        fig, _ax = get_mask_one_image(image, masks, oid_class_pred)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
704
+
705
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Build a matplotlib figure overlaying each object's mask on the frame.

    Returns (fig, ax); the caller is responsible for saving and closing the
    figure. When ``oid_class_pred`` is given, each mask is labelled with
    "<obj_id>. <predicted class>".
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    ax.imshow(frame_image)
    ax.axis('off')

    # Accept a plain list of masks by treating the index as the object id.
    if type(masks) == list:
        masks = dict(enumerate(masks))

    for obj_id, mask in masks.items():
        det_class = None if oid_class_pred is None else f"{obj_id}. {oid_class_pred[obj_id]}"
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    return fig, ax
723
+
724
def save_video(frames, output_filename, output_fps):
    """Write a sequence of frames to ``output_filename`` as an H.264 video.

    BUG FIXES: the frame size was read as ``frames.shape[:2]``, which for a
    stacked (T, H, W, 3) array is (T, H) rather than (H, W) and fails
    entirely for lists; and the loop called an undefined
    ``get_visualized_frame``. The function now measures the first frame and
    writes the provided frames directly, converting RGB -> BGR for OpenCV
    (this module's renderers return RGB frames — confirm if callers pass
    BGR).
    """
    num_frames = len(frames)
    if num_frames == 0:
        print("No frames to write; skipping", output_filename)
        return
    frame_h, frame_w = np.asarray(frames[0]).shape[:2]

    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))

    print(f"Processing {num_frames} frames...")
    for i in range(num_frames):
        vis_frame = cv2.cvtColor(np.asarray(frames[i]), cv2.COLOR_RGB2BGR)
        out.write(vis_frame)
        if i % 10 == 0:
            print(f"Processed frame {i+1}/{num_frames}")

    out.release()
    print(f"Video saved as {output_filename}")
743
+
744
+
745
def list_depth(lst):
    """Return the nesting depth of a (possibly tensor-valued) nested list.

    Non-list, non-tensor values are depth 0; an empty list or a 0-dim
    tensor counts as depth 1; otherwise 1 + the deepest element.
    """
    if not isinstance(lst, (list, torch.Tensor)):
        return 0
    empty_list = isinstance(lst, list) and len(lst) == 0
    scalar_tensor = isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])
    if empty_list or scalar_tensor:
        return 1
    return 1 + max(list_depth(item) for item in lst)
753
+
754
def normalize_prompt(points, labels):
    """Lift depth-3 point/label prompts to depth 4 by adding a per-object axis.

    Prompts already at depth 4 (or any other depth) pass through unchanged.
    """
    if list_depth(points) == 3:
        points = torch.stack([pt.unsqueeze(0) for pt in points])
        labels = torch.stack([lb.unsqueeze(0) for lb in labels])
    return points, labels
759
+
760
+
761
def show_box(box, ax, object_id):
    """Draw one [x1, y1, x2, y2] bounding box on a matplotlib axis.

    The edge color is deterministic per ``object_id``; empty boxes are
    ignored.
    """
    if len(box) == 0:
        return

    cmap = plt.get_cmap("gist_rainbow")
    color_idx = 0 if object_id is None else object_id
    color = list(cmap((color_idx * 47) % 256))

    x0, y0 = box[0], box[1]
    width, height = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), width, height,
                               edgecolor=color, facecolor=(0, 0, 0, 0), lw=2))
772
+
773
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter point prompts: green plus = positive, red square = negative.

    Marker edges are tinted by ``object_id`` so points can be matched to
    their object's mask color. No-op when there are no labels.
    """
    if len(labels) == 0:
        return

    positive = coords[labels == 1]
    negative = coords[labels == 0]

    colormap = plt.get_cmap("gist_rainbow")
    idx = object_id if object_id is not None else 0
    edge_color = list(colormap((idx * 47) % 256))

    ax.scatter(positive[:, 0], positive[:, 1], color='green', marker='P',
               s=marker_size, edgecolor=edge_color, linewidth=1.25)
    ax.scatter(negative[:, 0], negative[:, 1], color='red', marker='s',
               s=marker_size, edgecolor=edge_color, linewidth=1.25)
+
787
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Render box/point prompts over one frame and save the figure to disk.

    Args:
        frame_image: Image array displayable by matplotlib.
        boxes: Tensor of boxes indexed by object id, dict of
            object id -> box, or an empty list when there are no boxes.
        points: Per-object point prompts (normalized via normalize_prompt).
        labels: Per-object positive/negative labels for the points.
        save_path: Destination image file path.

    Raises:
        TypeError: If ``boxes`` is of an unsupported type.
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    points, labels = normalize_prompt(points, labels)

    # Draw bounding boxes; object id comes from position (tensor) or key (dict).
    if isinstance(boxes, torch.Tensor):
        for object_id, box in enumerate(boxes):
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, dict):
        for object_id, box in boxes.items():
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, list) and len(boxes) == 0:
        pass
    else:
        # Previously a bare ``raise Exception()``; still an Exception
        # subclass, but now says what actually went wrong.
        raise TypeError(f"Unsupported boxes type: {type(boxes)!r}")

    # Draw the point prompts per object.
    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if len(point_ls) != 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

    # Save and release the figure.
    plt.savefig(save_path)
    plt.close()
+
819
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
    """Save one prompt-visualization image per frame of a video.

    Frames with no recorded boxes/points/labels are rendered with empty
    prompt lists. Output files are ``<base_dir>/<video_id>/<frame_id>.jpg``.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    # makedirs with exist_ok=True already tolerates an existing directory.
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        frame_boxes = video_boxes[frame_id] if frame_id in video_boxes else []
        frame_points = video_points[frame_id] if frame_id in video_points else []
        frame_labels = video_labels[frame_id] if frame_id in video_labels else []

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, frame_boxes, frame_points, frame_labels, save_path)
+
838
+
839
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate = 1):
    """Save a mask-overlay image for (a random sample of) video frames.

    Args:
        video_tensor: Iterable of frames.
        video_masks: Dict of frame_id -> masks for that frame.
        video_id: Sub-directory name for this video's outputs.
        video_save_base_dir: Root output directory.
        oid_class_pred: Accepted but not used in this body —
            NOTE(review): possibly meant to be forwarded to
            save_mask_one_image; confirm against its signature.
        sample_rate: Probability of visualizing each frame (1 = all frames).
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Randomly skip frames to honor the sampling rate.
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        target_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, video_masks[frame_id], target_path)
+
854
+
855
+
856
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return a stable RGBA color (numpy array) for an object id.

    Args:
        obj_id: Object id used to pick a hue; ``None`` maps to id 0.
        cmap_name: Matplotlib colormap name.
        alpha: Opacity of the returned color.

    Returns:
        np.ndarray of shape (4,): RGBA with the requested alpha.
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = list(cmap((cmap_idx * 47) % 256))
    # BUG FIX: the alpha parameter was accepted but ignored (hard-coded 0.5);
    # the default keeps behavior identical for existing callers.
    color[3] = alpha
    color = np.array(color)
    return color
+
864
+
865
+ def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
866
+ return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
867
+
868
+
869
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Compute integer endpoints for a line joining two box centers.

    When the two centers coincide (within 1e-3), the second endpoint is
    shifted right so the resulting segment always has a nonzero span.
    """
    # Centers of each box (helper inlined for self-containment).
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0

    same_x = math.isclose(cx1, cx2, abs_tol=1e-3)
    same_y = math.isclose(cy1, cy2, abs_tol=1e-3)
    if same_x and same_y:
        # Nudge by 5% of the target box width, at least one pixel.
        nudge = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
        cx2 += nudge

    start = (int(round(cx1)), int(round(cy1)))
    end = (int(round(cx2)), int(round(cy2)))
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end
+
888
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Overlay relation-annotated masks on one frame.

    Only objects participating in at least one predicted relation are drawn;
    each relation is rendered as a line between the two objects' boxes with
    the relation text placed at the line's midpoint.

    Args:
        frame_image: Image array displayable by matplotlib.
        masks: Mapping of object id -> binary mask.
        rel_pred_ls: Mapping of (from_obj_id, to_obj_id) -> relation text.
            ``None`` is treated as "no relations".

    Returns:
        (fig, ax): The matplotlib figure and axis.
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    # BUG FIX: the default rel_pred_ls=None previously crashed on .items().
    if rel_pred_ls is None:
        rel_pred_ls = {}

    all_objs_to_show = set()
    all_lines_to_show = []

    # Collect, per predicted relation, the connecting segment plus the
    # colors of its endpoint objects and the relation text.
    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)

        from_mask = masks[from_obj_id]
        bbox1 = mask_to_bbox(from_mask)
        to_mask = masks[to_obj_id]
        bbox2 = mask_to_bbox(to_mask)

        c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)

        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        all_lines_to_show.append((c1, c2, face_color, line_color, rel_text))

    # Draw only the masks that participate in some relation.
    masks_to_show = {oid: masks[oid] for oid in all_objs_to_show}
    for obj_id, mask in masks_to_show.items():
        show_mask(mask, ax, obj_id=obj_id, random_color=False)

    # Draw relation lines with their labels at each midpoint.
    for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
        plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
            alpha=1
        )

    return fig, ax