| from flax import config |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as cp |
| from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor |
| from typing import Dict, List, Tuple, Optional, Any, Union |
| import numpy as np |
| import os |
| import cv2 |
| from collections import defaultdict |
| import builtins |
| import sys |
| from laser.models import llava_clip_model_v3 |
| sys.modules["llava_clip_model_v3"] = llava_clip_model_v3 |
| from safetensors.torch import load_file |
|
|
| import inspect |
| from transformers.models.clip import modeling_clip |
| import transformers |
| from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
| from .vine_config import VineConfig |
| from laser.models.model_utils import ( |
| extract_single_object, |
| extract_object_subject, |
| crop_image_contain_bboxes, |
| segment_list |
| ) |
| from .flattening import ( |
| extract_valid_object_pairs, |
| flatten_segments_for_batch, |
| ) |
|
|
| from .vis_utils import save_mask_one_image |
|
|
class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model

    This model processes videos along with categorical, unary, and binary keywords
    to return probability distributions over those keywords for detected objects
    and their relationships in the video.

    Three models are instantiated from ``config.model_name`` (one each for
    categorical, unary and binary keywords); they share a single tokenizer and
    a single image processor.
    """

    config_class = VineConfig

    def __init__(self, config: VineConfig, epoch: int = 0):
        """
        Build the tokenizer, processor and three sub-models, then load VINE weights.

        Args:
            config: Model configuration. ``model_name``, ``use_hf_repo`` and
                ``_device`` are read unconditionally; ``model_repo``/``model_file``
                or ``local_dir``/``local_filename`` are read depending on
                ``use_hf_repo``.
            epoch: Epoch suffix used when ``local_dir`` points at a directory of
                ``*.{epoch}.model`` checkpoints. Ignored for every other source.
        """
        super().__init__(config)

        self.config = config
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        # No default is supplied: a config lacking ``_device`` raises AttributeError.
        self._device = getattr(config, "_device")

        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            # Padding is requested in _extract_text_features, so a pad token must
            # exist; fall back to the unknown token, then to end-of-sequence.
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Both loaders report failure through their return value and a printed
        # message; the result is not inspected, so on failure the model simply
        # keeps the base weights loaded above.
        if config.use_hf_repo:
            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
        else:
            self._load_local_pretrained_vine_weights(
                config.local_dir, config.local_filename, epoch=epoch
            )

        self.to(self._device)

    def _load_huggingface_vine_weights(
        self, model_repo: str, model_file: Optional[str] = None
    ) -> bool:
        """
        Load pretrained VINE weights from HuggingFace Hub.

        Args:
            model_repo: Repository id handed to ``snapshot_download``.
            model_file: Passed to ``snapshot_download`` as ``revision`` ("main" when
                falsy). The weights file that is read is always
                ``model.safetensors`` inside the snapshot.
                NOTE(review): the name suggests a file name but it is used as a
                revision - confirm against the config's intent.

        Returns:
            True when the state dict was loaded (non-strict), False when any
            exception occurred; in that case the current weights are kept.
        """
        try:
            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
            repo_path = snapshot_download(model_repo, revision=model_file or "main")
            weights = load_file(os.path.join(repo_path, "model.safetensors"))
            self.load_state_dict(weights, strict=False)
            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
            return True
        except Exception as e:
            # Deliberate best-effort: report the problem and keep the base weights.
            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
            print("Using base CLIP models instead")
            return False

    def _copy_submodel_weights(self, source: Any) -> None:
        """Copy the state dict of every sub-model that ``source`` exposes into this model."""
        for name in ("clip_cate_model", "clip_unary_model", "clip_binary_model"):
            if hasattr(source, name):
                getattr(self, name).load_state_dict(getattr(source, name).state_dict())

    def _load_local_pretrained_vine_weights(
        self,
        local_dir: Optional[str],
        local_filename: Optional[str] = None,
        epoch: int = 0,
    ) -> bool:
        """
        Load pretrained VINE weights from a saved .pt file or ensemble format.

        Supported sources, selected from the resolved path:
            * ``*.pkl``: a pickled model object whose ``clip_*_model`` attributes
              are copied over. A dict payload is not handled here and falls
              through to the unsupported-format path.
            * ``*.pt`` / ``*.pth``: a state dict loaded strictly into this model.
            * a directory: the first (sorted) file ending in ``.{epoch}.model``,
              treated like the ``.pkl`` case.

        Args:
            local_dir: Directory of the checkpoint, or the full path when
                ``local_filename`` is not given. May be None for a bare file name.
            local_filename: Optional file name joined onto ``local_dir``.
            epoch: Epoch suffix used only for the directory source.

        Returns:
            True when weights were loaded, False otherwise (a message is printed).
        """
        if local_filename:
            # ``local_dir`` is None for a bare file name; os.path.join rejects None.
            full_path = os.path.join(local_dir or "", local_filename)
        else:
            full_path = local_dir

        if not full_path:
            print("✗ No local VINE weights path configured")
            print("Using base CLIP models instead")
            return False

        if full_path.endswith(".pkl"):
            print(f"Loading VINE weights from: {full_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects; only use
            # with checkpoints from a trusted source.
            loaded_vine_model = torch.load(
                full_path, map_location=self._device, weights_only=False
            )

            print(f"Loaded state type: {type(loaded_vine_model)}")
            if not isinstance(loaded_vine_model, dict):
                self._copy_submodel_weights(loaded_vine_model)
                return True

        elif full_path.endswith((".pt", ".pth")):
            state = torch.load(full_path, map_location=self._device, weights_only=True)
            print(f"Loaded state type: {type(state)}")
            self.load_state_dict(state)
            return True

        if os.path.isdir(full_path):
            suffix = f".{epoch}.model"
            # Sorted so the choice is deterministic when several files match.
            model_files = sorted(f for f in os.listdir(full_path) if f.endswith(suffix))
            if not model_files:
                print(f"No model file found for epoch {epoch} in {full_path}")
                return False

            model_file = os.path.join(full_path, model_files[0])
            print(f"Loading VINE weights from: {model_file}")
            # The attribute access below needs the pickled model object, which
            # torch.load only returns with weights_only=False (the default became
            # True in PyTorch 2.6). SECURITY: trusted checkpoints only.
            pretrained_model = torch.load(
                model_file, map_location="cpu", weights_only=False
            )
            self._copy_submodel_weights(pretrained_model)
            print("✓ Loaded all sub-model weights from ensemble format")
            return True

        print("Unsupported format for pretrained_vine_path")
        return False

    @classmethod
    def from_pretrained_vine(
        cls,
        model_path: str,
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs
    ):
        """
        Create VineModel from pretrained VINE weights.

        ``model_path`` is treated as a HuggingFace repo id when it contains "/"
        and does not exist on disk; otherwise as a local directory or file.

        Args:
            model_path: Path to pretrained VINE model
            config: Optional config, will create default if None. A supplied
                config is modified in place to point at ``model_path``.
            epoch: Epoch number to load (used for directory checkpoints)
            **kwargs: Additional arguments forwarded to the constructor

        Returns:
            VineModel instance with loaded weights

        Raises:
            ValueError: If ``model_path`` is empty or None.
        """
        if not model_path:
            raise ValueError(
                "model_path must be a non-empty local path or HuggingFace repo id"
            )

        is_hub_repo = "/" in model_path and not os.path.exists(model_path)
        is_directory = not is_hub_repo and os.path.isdir(model_path)
        # Only meaningful for the single-file case; empty components become None.
        parent_dir = os.path.dirname(model_path) or None
        file_name = os.path.basename(model_path) or None

        if config is None:
            if is_hub_repo:
                config = VineConfig(use_hf_repo=True, model_repo=model_path)
            elif is_directory:
                config = VineConfig(use_hf_repo=False, local_dir=model_path)
            else:
                config = VineConfig(
                    use_hf_repo=False,
                    local_dir=parent_dir,
                    local_filename=file_name,
                )
        elif is_hub_repo:
            config.use_hf_repo = True
            config.model_repo = model_path
            config.model_file = None
            config.local_dir = None
            config.local_filename = None
        else:
            config.use_hf_repo = False
            config.local_dir = model_path if is_directory else parent_dir
            config.local_filename = None if is_directory else file_name

        # Weight loading happens inside the constructor.
        return cls(config, epoch=epoch, **kwargs)

    def _text_features_checkpoint(self, model, tokens):
        """Extract text features with gradient checkpointing."""
        token_keys = list(tokens.keys())

        # checkpoint() forwards positional arguments, so the tokenizer output is
        # flattened to positionals here and rebuilt into keywords inside.
        def get_text_features_wrapped(*inputs):
            return model.get_text_features(**dict(zip(token_keys, inputs)))

        token_values = [tokens[key] for key in token_keys]
        return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)

    def _image_features_checkpoint(self, model, images):
        """Extract image features with gradient checkpointing."""
        return cp.checkpoint(model.get_image_features, images, use_reentrant=False)

    def clip_sim(self, model, nl_feat, img_feat):
        """
        Return scaled cosine similarities between image and text features.

        Both inputs are L2-normalised along the last dimension; the result is
        ``img_feat @ nl_feat.T``, multiplied by ``exp(model.logit_scale)`` when the
        model has that attribute.
        """
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logits = torch.matmul(img_feat, nl_feat.T)
        if hasattr(model, "logit_scale"):
            logits = logits * model.logit_scale.exp()
        return logits

    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Forward pass of the VINE model.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names to classify objects
            unary_keywords: Optional list of unary predicates (actions on single objects)
            binary_keywords: Optional list of binary predicates (relations between objects)
            object_pairs: Optional list of (obj1_id, obj2_id) pairs for binary classification
            return_flattened_segments: Include ``flattened_segments`` in the result;
                defaults to ``config.return_flattened_segments``.
            return_valid_pairs: Include ``valid_pairs`` and ``valid_pairs_metadata``;
                defaults to ``config.return_valid_pairs``.
            interested_object_pairs: Pairs used to filter ``valid_pairs``; defaults
                to ``config.interested_object_pairs`` when empty or None.
            debug_visualizations: Resolved against the instance default but not
                consumed anywhere in this method.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with keys:
                ``categorical_probs``: ``{0: {(obj_id, keyword): prob}}``
                ``unary_probs``: ``{0: {(frame_id, obj_id, keyword): prob}}``
                ``binary_probs``: ``[{(frame_id, (obj1_id, obj2_id), keyword): prob}]``
                ``dummy_prob``: 1 / (length of the longest keyword list)
            plus ``flattened_segments`` / ``valid_pairs`` / ``valid_pairs_metadata``
            when requested.
        """
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = self.config.return_flattened_segments
        if return_valid_pairs is None:
            return_valid_pairs = self.config.return_valid_pairs
        if not interested_object_pairs:
            interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or []
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations

        # An empty keyword list is replaced by one empty-string placeholder so the
        # text encoder always receives input; placeholders never reach the output.
        dummy_str = ""
        if not categorical_keywords:
            categorical_keywords = [dummy_str]
        if not unary_keywords:
            unary_keywords = [dummy_str]
        if not binary_keywords:
            binary_keywords = [dummy_str]

        # A keyword family is scored only when its first keyword is not the placeholder.
        score_unary = unary_keywords[0] != dummy_str
        score_binary = binary_keywords[0] != dummy_str

        categorical_features = self._extract_text_features(
            self.clip_cate_model, categorical_keywords
        )
        unary_features = self._extract_text_features(
            self.clip_unary_model, unary_keywords
        )
        binary_features = self._extract_text_features(
            self.clip_binary_model, binary_keywords
        )

        categorical_probs = {}
        unary_probs = {}
        binary_probs = {}

        num_frames = len(video_frames)

        # ---- Per-object categorical and unary scores ----
        for frame_id, frame_masks in masks.items():
            if frame_id >= num_frames:
                continue

            frame = self._frame_to_numpy(video_frames[frame_id])
            frame_bboxes = bboxes.get(frame_id, {})

            for obj_id, mask in frame_masks.items():
                # Objects without a bounding box in this frame are skipped.
                if obj_id not in frame_bboxes:
                    continue

                mask_np = self._mask_to_numpy(mask)
                obj_image = extract_single_object(
                    frame, mask_np, alpha=self.config.alpha
                )
                obj_features = self._extract_image_features(
                    self.clip_cate_model, obj_image
                )

                cat_similarities = self.clip_sim(
                    self.clip_cate_model, categorical_features, obj_features
                )
                # One host transfer per object instead of one .item() per keyword.
                cat_row = F.softmax(cat_similarities, dim=-1)[0].tolist()

                # NOTE(review): the key carries no frame id, so a later frame
                # overwrites an earlier frame's value for the same object -
                # confirm this is intended.
                for keyword, prob in zip(categorical_keywords, cat_row):
                    if keyword != dummy_str:
                        categorical_probs[(obj_id, keyword)] = prob

                if score_unary:
                    # NOTE(review): the image embedding comes from clip_cate_model
                    # while the text embedding and logit scale come from
                    # clip_unary_model - confirm the mix is intended.
                    unary_similarities = self.clip_sim(
                        self.clip_unary_model, unary_features, obj_features
                    )
                    unary_row = F.softmax(unary_similarities, dim=-1)[0].tolist()

                    for keyword, prob in zip(unary_keywords, unary_row):
                        if keyword != dummy_str:
                            unary_probs[(frame_id, obj_id, keyword)] = prob

        # ---- Per-pair binary scores ----
        if score_binary and len(object_pairs) > 0:
            for obj1_id, obj2_id in object_pairs:
                for frame_id, frame_masks in masks.items():
                    if frame_id >= num_frames:
                        continue

                    frame_bboxes = bboxes.get(frame_id, {})
                    both_present = (
                        obj1_id in frame_masks and obj2_id in frame_masks
                        and obj1_id in frame_bboxes and obj2_id in frame_bboxes
                    )
                    if not both_present:
                        continue

                    bbox1 = frame_bboxes[obj1_id]
                    bbox2 = frame_bboxes[obj2_id]

                    # Pairs whose boxes do not overlap are skipped. The test runs
                    # before any image work so rejected pairs cost nothing.
                    if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \
                            bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]:
                        continue

                    frame = self._frame_to_numpy(video_frames[frame_id])
                    mask1_np = self._mask_to_numpy(frame_masks[obj1_id])
                    mask2_np = self._mask_to_numpy(frame_masks[obj2_id])

                    # The helper is given masks with a trailing channel axis.
                    pair_image = extract_object_subject(
                        frame, mask1_np[..., None], mask2_np[..., None],
                        alpha=self.config.alpha,
                        white_alpha=self.config.white_alpha
                    )
                    cropped_image = crop_image_contain_bboxes(
                        pair_image, [bbox1, bbox2], f"frame_{frame_id}"
                    )

                    pair_features = self._extract_image_features(
                        self.clip_binary_model, cropped_image
                    )
                    binary_similarities = self.clip_sim(
                        self.clip_binary_model, binary_features, pair_features
                    )
                    binary_row = F.softmax(binary_similarities, dim=-1)[0].tolist()

                    for keyword, prob in zip(binary_keywords, binary_row):
                        if keyword != dummy_str:
                            binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = prob

        dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords))

        result: Dict[str, Any] = {
            "categorical_probs": {0: categorical_probs},
            "unary_probs": {0: unary_probs},
            "binary_probs": [binary_probs],
            "dummy_prob": dummy_prob
        }

        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                # An empty request is passed on as None and reported as "all_pairs".
                interested_pairs = interested_object_pairs if interested_object_pairs else None
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {
                        "pair_source": "filtered",
                        "requested_pairs": interested_pairs,
                    }

        return result

    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a frame tensor/array to a contiguous numpy array."""
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Convert a mask tensor/array to a 2D boolean numpy array.

        A 3D mask with a singleton first axis, or else a singleton last axis, has
        that axis removed.

        Raises:
            ValueError: If the mask is not 2D after that step.
        """
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)

        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)

        if mask_np.ndim != 2:
            raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")

        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords):
        """Extract text features for given keywords."""
        # Every keyword is padded/truncated to 75 tokens and moved to the model device.
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding='max_length'
        ).to(self._device)

        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        """
        Extract image features for given image.

        Numpy inputs are cast to uint8 and, when they have three channels in the
        last axis, passed through ``cv2.COLOR_BGR2RGB`` before preprocessing
        (presumably because upstream images are BGR - verify against callers).
        Other input types go to the processor unchanged.
        """
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        inputs = self.clip_processor(
            images=image,
            return_tensors="pt"
        ).to(self._device)

        return self._image_features_checkpoint(model, inputs['pixel_values'])

    @staticmethod
    def _top_k_by_group(entries, top_k: int) -> Dict[Any, List[Tuple[float, str]]]:
        """Group ``(group_key, label, prob)`` entries and keep the ``top_k`` highest per group."""
        grouped: Dict[Any, List[Tuple[float, str]]] = {}
        for group_key, label, prob in entries:
            grouped.setdefault(group_key, []).append((prob, label))
        # Tuples sort by probability first, then by label on ties.
        return {
            group_key: sorted(scored, reverse=True)[:top_k]
            for group_key, scored in grouped.items()
        }

    @staticmethod
    def _best_confidence(grouped: Dict[Any, List[Tuple[float, str]]]) -> float:
        """Return the highest probability across all groups, or 0.0 when there is none."""
        return max(
            (max((prob for prob, _ in scored), default=0.0) for scored in grouped.values()),
            default=0.0,
        )

    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        High-level prediction method that returns formatted results.

        Runs ``forward`` under ``torch.no_grad()`` and regroups its flat
        probability maps into per-object / per-pair top-k lists.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names
            unary_keywords: Optional list of unary predicates
            binary_keywords: Optional list of binary predicates
            object_pairs: Optional list of object pairs for binary relations
            return_top_k: Number of top predictions to return
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Forwarded to ``forward`` unchanged

        Returns:
            Dict with:
                ``categorical_predictions``: ``obj_id -> [(prob, category), ...]``
                ``unary_predictions``: ``(frame_id, obj_id) -> [(prob, predicate), ...]``
                ``binary_predictions``: ``(frame_id, obj_pair) -> [(prob, predicate), ...]``
                ``confidence_scores``: highest probability per family (0.0 if empty)
            plus any of ``flattened_segments`` / ``valid_pairs`` /
            ``valid_pairs_metadata`` that ``forward`` produced.
        """
        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        formatted_categorical = self._top_k_by_group(
            (
                (obj_id, category, prob)
                for (obj_id, category), prob in outputs["categorical_probs"][0].items()
            ),
            return_top_k,
        )

        formatted_unary = self._top_k_by_group(
            (
                ((frame_id, obj_id), predicate, prob)
                for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items()
            ),
            return_top_k,
        )

        formatted_binary: Dict[Any, List[Tuple[float, str]]] = {}
        if len(outputs["binary_probs"]) > 0:
            formatted_binary = self._top_k_by_group(
                (
                    ((frame_id, obj_pair), predicate, prob)
                    for (frame_id, obj_pair, predicate), prob
                    in outputs["binary_probs"][0].items()
                ),
                return_top_k,
            )

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            "confidence_scores": {
                "categorical": self._best_confidence(formatted_categorical),
                "unary": self._best_confidence(formatted_unary),
                "binary": self._best_confidence(formatted_binary),
            }
        }

        # Optional extras are passed through only when forward produced them.
        for optional_key in ("flattened_segments", "valid_pairs", "valid_pairs_metadata"):
            if optional_key in outputs:
                result[optional_key] = outputs[optional_key]

        return result
|
|