| import torch |
| from transformers import PretrainedConfig |
| from typing import List, Optional, Dict, Any, Tuple |
|
|
|
|
class VineConfig(PretrainedConfig):
    """
    Configuration class for VINE (Video Understanding with Natural Language) model.

    VINE is a video understanding model that processes categorical (object class names),
    unary keywords (actions on one object), and binary keywords (relations between two objects),
    and returns probability distributions over all of them when passed a video.

    Args:
        model_name (str): The CLIP model name to use as backbone.
            Default: "openai/clip-vit-base-patch32"
        hidden_dim (int): Hidden dimension size. Default: 768
        use_hf_repo (bool): Selects which weight-location fields are kept. When True,
            ``model_repo`` / ``model_file`` are stored and ``local_dir`` /
            ``local_filename`` are forced to None; when False, the reverse. Default: False
        model_repo (str, optional): Repository identifier; only stored when
            ``use_hf_repo`` is True. Default: None
        model_file (str, optional): Weight file name (presumably inside ``model_repo``);
            only stored when ``use_hf_repo`` is True. Default: None
        local_dir (str, optional): Local directory; only stored when ``use_hf_repo``
            is False. Default: None
        local_filename (str, optional): Weight file name (presumably inside
            ``local_dir``); only stored when ``use_hf_repo`` is False. Default: None
        num_top_pairs (int): Number of top object pairs to consider. Default: 18
        segmentation_method (str): Segmentation method to use ("sam2" or
            "grounding_dino_sam2"). Default: "grounding_dino_sam2"
        box_threshold (float): Box threshold for Grounding DINO. Default: 0.35
        text_threshold (float): Text threshold for Grounding DINO. Default: 0.25
        target_fps (int): Target FPS for video processing. Default: 1
        alpha (float): Alpha value for object extraction. Default: 0.5
        white_alpha (float): White alpha value for background blending. Default: 0.8
        topk_cate (int): Top-k categories to return. Default: 3
        multi_class (bool): Whether to use multi-class classification. Default: False
        output_logit (bool): Whether to output logits instead of probabilities. Default: False
        max_video_length (int): Maximum number of frames to process. Default: 100
        bbox_min_dim (int): Minimum bounding box dimension. Default: 5
        visualize (bool): Whether to visualize results. Default: False
        visualization_dir (str, optional): Directory to save visualizations. Default: None
        return_flattened_segments (bool): Whether to return flattened segments. Default: False
        return_valid_pairs (bool): Whether to return valid object pairs. Default: False
        interested_object_pairs (List[Tuple[int, int]], optional): List of interested
            object pairs. None is stored as an empty list. Default: None
        debug_visualizations (bool): Whether to save debug visualizations. Default: False
        device (str | int, optional): Device selection, stored as ``self._device``.
            An int ``n`` becomes ``"cuda:n"`` when CUDA is available, otherwise ``"cpu"``;
            a non-empty string is used as given; None selects ``"cuda"`` if CUDA is
            available, otherwise ``"cpu"``. Default: None
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,

        use_hf_repo: bool = False,
        model_repo: Optional[str] = None,
        model_file: Optional[str] = None,
        local_dir: Optional[str] = None,
        local_filename: Optional[str] = None,

        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        max_video_length: int = 100,
        bbox_min_dim: int = 5,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        device: Optional[str | int] = None,
        **kwargs
    ):
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo
        # The two weight-location modes are mutually exclusive: the fields of the
        # unused mode are nulled so a serialized config never carries both.
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename
        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        # Normalize None to a fresh list (avoids a shared mutable default).
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations

        cuda_available = torch.cuda.is_available()
        # isinstance, not `device is int`: the identity test compares the value to the
        # type object and is never true, which made device=0 fall through the `or`
        # below to plain "cuda" and stored other ints (e.g. 1) unconverted.
        # Checking ints first also keeps a falsy 0 from being treated as "unset".
        if isinstance(device, int):
            self._device = f"cuda:{device}" if cuda_available else "cpu"
        else:
            self._device = device or ("cuda" if cuda_available else "cpu")

        super().__init__(**kwargs)
|
|