Spaces:
Paused
Paused
"""
GCTBase - Base class for GCT model implementations.
Provides shared functionality:
- Prediction heads (camera, depth, point)
- Forward pass structure
- Model hub mixin (PyTorchModelHubMixin)
"""
import logging
import numpy as np
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union
from huggingface_hub import PyTorchModelHubMixin
from lingbot_map.heads.dpt_head import DPTHead
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3
logger = logging.getLogger(__name__)
class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
    """
    Base class for GCT model implementations.

    Handles shared components:
    - Prediction heads (camera, depth, point)
    - Forward pass structure
    - Input normalization

    Subclasses must implement:
    - _build_aggregator(): Create mode-specific aggregator
    - _build_camera_head(): Create mode-specific camera head
    - _aggregate_features(): Run the aggregator over the input frames
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Camera head sliding window
        enable_camera_sliding_window: bool = False,
        # 3D RoPE
        enable_3d_rope: bool = False,
        # Context Parallelism (kept for checkpoint compatibility but not used)
        enable_ulysses_cp: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        pred_normalization_detach_scale: bool = False,
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()
        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embed = patch_embed
        self.disable_global_rope = disable_global_rope
        # Context parallelism is unconditionally disabled in this standalone
        # package; the constructor argument is accepted only so configs saved
        # with CP enabled still deserialize.
        self.enable_ulysses_cp = False
        self.enable_normalize = enable_normalize
        self.pred_normalization = pred_normalization
        self.pred_normalization_detach_scale = pred_normalization_detach_scale
        self.use_gradient_checkpoint = use_gradient_checkpoint
        # Head flags
        self.enable_camera = enable_camera
        self.enable_point = enable_point
        self.enable_local_point = enable_local_point
        self.enable_depth = enable_depth
        self.enable_track = enable_track
        self.enable_camera_sliding_window = enable_camera_sliding_window
        self.enable_3d_rope = enable_3d_rope
        # Build aggregator (subclass-specific)
        self.aggregator = self._build_aggregator()
        # Build prediction heads; disabled heads stay None and their
        # _predict_* methods return empty dicts.
        self.camera_head = self._build_camera_head() if enable_camera else None
        self.point_head = self._build_point_head() if enable_point else None
        self.local_point_head = self._build_local_point_head() if enable_local_point else None
        self.depth_head = self._build_depth_head() if enable_depth else None

    @abstractmethod
    def _build_aggregator(self) -> nn.Module:
        """Create the mode-specific feature aggregator.

        Previously a silent ``pass`` stub: a subclass that forgot to override
        it got ``self.aggregator = None`` and a confusing failure later.
        Marked abstract so the omission fails at instantiation time instead.
        """
        raise NotImplementedError

    @abstractmethod
    def _build_camera_head(self) -> nn.Module:
        """Create the mode-specific camera pose head."""
        raise NotImplementedError

    def _build_depth_head(self) -> nn.Module:
        """Build the DPT depth head: 1 depth channel + 1 confidence channel."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=2,
            activation="exp",
            conf_activation="expp1"
        )

    def _make_point_dpt_head(self) -> nn.Module:
        """Shared constructor for point heads: 3 coordinates + 1 confidence."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )

    def _build_point_head(self) -> nn.Module:
        """Build the world-frame point head (identical architecture to the local one)."""
        return self._make_point_dpt_head()

    def _build_local_point_head(self) -> nn.Module:
        """Build the camera-frame point head (identical architecture to the world one)."""
        return self._make_point_dpt_head()

    def _normalize_input(self, images: torch.Tensor, query_points: Optional[torch.Tensor] = None):
        """Promote unbatched inputs to batched form.

        Args:
            images: [S, 3, H, W] or [B, S, 3, H, W].
            query_points: optional [N, 2] or [B, N, 2].

        Returns:
            (images, query_points) with a leading batch dim of 1 added where missing.
        """
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
        return images, query_points

    @abstractmethod
    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        view_graphs: Optional[torch.Tensor] = None,
        causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
        ordered_video: Optional[torch.Tensor] = None,
        is_cp_sliced: bool = False,
    ) -> tuple:
        """Run the aggregator over the input frames.

        Must return ``(aggregated_tokens_list, patch_start_idx)`` —
        ``forward`` unpacks exactly this 2-tuple, so a ``pass`` stub
        (returning None) would crash there; marked abstract to enforce
        implementation by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def _run_dpt_head(head: nn.Module, aggregated_tokens_list: list,
                      images: torch.Tensor, patch_start_idx: int) -> tuple:
        """Run a DPT-style head in fp32 with autocast disabled.

        Heads are numerically sensitive, so tokens and images are cast to
        float32 and CUDA autocast is turned off for the call.
        """
        tokens_fp32 = [t.float() for t in aggregated_tokens_list]
        with torch.amp.autocast('cuda', enabled=False):
            return head(
                tokens_fp32,
                images=images.float(),
                patch_start_idx=patch_start_idx
            )

    def _predict_camera(
        self,
        aggregated_tokens_list: list,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Predict camera pose encodings.

        Returns an empty dict when the camera head is disabled; otherwise
        {"pose_enc": final iteration, "pose_enc_list": all iterations}.
        ``gather_outputs`` is accepted for interface parity with the other
        predictors but is currently unused here.
        """
        if self.camera_head is None:
            return {}
        # fp32 + autocast off for numerical stability of pose regression.
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        # Sliding-window attention in the camera head is opt-in; -1 disables it.
        camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1
        with torch.amp.autocast('cuda', enabled=False):
            pose_enc_list = self.camera_head(
                aggregated_tokens_list_fp32,
                mask=mask,
                causal_inference=causal_inference,
                num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
                sliding_window_size=camera_sliding_window,
                num_frame_per_block=num_frame_per_block,
            )
        return {
            "pose_enc": pose_enc_list[-1],
            "pose_enc_list": pose_enc_list,
        }

    def _predict_depth(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Predict per-pixel depth and confidence (empty dict if head disabled)."""
        if self.depth_head is None:
            return {}
        depth, depth_conf = self._run_dpt_head(
            self.depth_head, aggregated_tokens_list, images, patch_start_idx
        )
        return {"depth": depth, "depth_conf": depth_conf}

    def _predict_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Predict world-frame 3D points and confidence (empty dict if head disabled)."""
        if self.point_head is None:
            return {}
        pts3d, pts3d_conf = self._run_dpt_head(
            self.point_head, aggregated_tokens_list, images, patch_start_idx
        )
        return {"world_points": pts3d, "world_points_conf": pts3d_conf}

    def _predict_local_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Predict camera-frame 3D points and confidence (empty dict if head disabled)."""
        if self.local_point_head is None:
            return {}
        pts3d, pts3d_conf = self._run_dpt_head(
            self.local_point_head, aggregated_tokens_list, images, patch_start_idx
        )
        return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}

    def _unproject_depth_to_world(
        self,
        depth: torch.Tensor,
        pose_enc: torch.Tensor,
    ) -> torch.Tensor:
        """Unproject depth maps to world-frame 3D points.

        Args:
            depth: [B, S, H, W, 1] depth maps.
            pose_enc: camera pose encoding, decoded via
                pose_encoding_to_extri_intri into extrinsics/intrinsics.

        Returns:
            [B, S, H, W, 3] world-frame points.
        """
        B, S, H, W, _ = depth.shape
        device = depth.device
        dtype = depth.dtype
        image_size_hw = (H, W)
        extrinsics, intrinsics = pose_encoding_to_extri_intri(
            pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
        )
        # Lift [*, 3, 4] world->camera extrinsics to homogeneous 4x4, then
        # invert to get camera->world transforms.
        extrinsics_flat = extrinsics.view(B * S, 3, 4)
        extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
        extrinsics_4x4[:, :3, :] = extrinsics_flat
        extrinsics_4x4[:, 3, 3] = 1.0
        c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)
        # Dense homogeneous pixel grid (x, y, 1) shared by every frame.
        y_grid, x_grid = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing='ij'
        )
        pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)
        # Back-project: pixels -> camera rays -> scale by depth -> world frame.
        intrinsics_inv = torch.inverse(intrinsics)
        camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
        camera_points = camera_coords * depth
        ones = torch.ones_like(camera_points[..., :1])
        camera_points_h = torch.cat([camera_points, ones], dim=-1)
        world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)
        return world_points_h[..., :3]

    def forward(
        self,
        images: torch.Tensor,
        query_points: Optional[torch.Tensor] = None,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        ordered_video: Optional[torch.Tensor] = None,
        gather_outputs: bool = True,
        point_masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the GCT model.
        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            query_points: Optional query points [N, 2] or [B, N, 2]
        Returns:
            Dictionary containing predictions:
            - pose_enc: Camera pose encoding [B, S, 9]
            - depth: Depth maps [B, S, H, W, 1]
            - depth_conf: Depth confidence [B, S, H, W]
            - world_points: 3D world coordinates [B, S, H, W, 3]
            - world_points_conf: Point confidence [B, S, H, W]
        """
        images, query_points = self._normalize_input(images, query_points)
        aggregated_tokens_list, patch_start_idx = self._aggregate_features(
            images,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        predictions = {}
        # NOTE(review): ``ordered_video`` is forwarded as the camera head's
        # ``mask`` while the ``mask`` parameter itself goes unused — looks
        # intentional but worth confirming against the camera head's contract.
        predictions.update(self._predict_camera(
            aggregated_tokens_list,
            mask=ordered_video,
            causal_inference=causal_inference,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_depth(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_local_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        # Echo inputs back at inference time so downstream visualization has
        # the (normalized-batch) images alongside the predictions.
        if not self.training:
            predictions["images"] = images
        return predictions