Spaces:
Running
Running
| """ | |
| GCTStream - Streaming GCT with KV cache for online inference. | |
| Provides streaming inference functionality: | |
| - Temporal causal attention with KV cache | |
| - Sliding window support | |
| - Efficient frame-by-frame processing | |
| - 3D RoPE support for temporal consistency | |
| """ | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| from typing import Optional, Dict, Any, List | |
| from tqdm.auto import tqdm | |
| from lingbot_map.heads.camera_head import CameraCausalHead | |
| from lingbot_map.models.gct_base import GCTBase | |
| from lingbot_map.aggregator.stream import AggregatorStream | |
# Module-level logger used by GCTStream for progress and warning messages.
logger = logging.getLogger(name=__name__)
class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
    - AggregatorStream with KV cache support (FlashInfer backend by default,
      PyTorch SDPA when ``use_sdpa=True``)
    - CameraCausalHead for pose refinement
    - Sliding window attention for memory efficiency
    - Frame-by-frame streaming inference (see ``inference_streaming``)
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
        # Camera head iterative refinement (lower = faster inference; default 4)
        camera_num_iterations: int = 4,
    ):
        """
        Initialize GCTStream.

        Args:
            img_size: Input image size.
            patch_size: Patch size for embedding.
            embed_dim: Embedding dimension; the camera head consumes
                ``2 * embed_dim`` features.
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.).
            pretrained_path: Path to pretrained DINOv2 weights (forwarded to the
                aggregator).
            disable_global_rope: Disable RoPE in global attention.
            enable_camera/point/local_point/depth/track: Enable the corresponding
                prediction heads (forwarded to GCTBase).
            enable_normalize: Normalization flag forwarded to GCTBase.
            pred_normalization: If True, ``inference_streaming`` passes its
                outputs through ``self._normalize_predictions``.
            sliding_window_size: Sliding window size in blocks (-1 for full causal).
            num_frame_for_scale: Number of scale estimation frames (also the
                default number of phase-1 frames in ``inference_streaming``).
            num_random_frames: Number of random frames for long-range dependencies.
            attend_to_special_tokens: Enable cross-frame special token attention.
            attend_to_scale_frames: Whether to attend to scale frames.
            enable_stream_inference: Enable streaming inference with KV cache.
            enable_3d_rope: Enable 3D RoPE in the aggregator.
            max_frame_num: Maximum number of frames for 3D RoPE (aggregator and
                camera head).
            enable_camera_3d_rope: Enable 3D RoPE inside the camera head
                (independent of ``enable_3d_rope``).
            camera_rope_theta: RoPE base frequency for the camera head.
            use_scale_token: Kept for checkpoint compatibility; ignored.
            kv_cache_sliding_window: Sliding window size for KV cache eviction.
            kv_cache_scale_frames: Number of scale frames to keep in KV cache.
            kv_cache_cross_frame_special: Keep special tokens from evicted frames.
            kv_cache_include_scale_frames: Include scale frames in KV cache.
            kv_cache_camera_only: Only keep camera tokens from evicted frames.
            use_sdpa: Use PyTorch SDPA instead of FlashInfer in the aggregator.
            use_gradient_checkpoint: Gradient checkpointing flag forwarded to
                GCTBase and the aggregator.
            camera_num_iterations: Number of iterative refinement steps in the
                camera head (lower = faster inference).
        """
        # These must be set before super().__init__(), because the base class
        # constructor calls _build_aggregator(), which reads them.
        # (Plain, non-Module attributes may be assigned before nn.Module.__init__.)
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa
        self.camera_num_iterations = camera_num_iterations
        # `use_scale_token` is intentionally unused (checkpoint compatibility only).

        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build the streaming aggregator with KV cache support.

        The attention backend is FlashInfer unless ``self.use_sdpa`` is True,
        in which case PyTorch SDPA is used instead.

        Returns:
            AggregatorStream module.
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback); exactly one is on.
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build the causal camera head used for streaming inference.

        This method always constructs a head. Whether it is called at all
        (e.g. when ``enable_camera=False``) is presumably decided by GCTBase,
        so ``self.camera_head`` may be None elsewhere in this class.

        Returns:
            CameraCausalHead module.
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            num_iterations=self.camera_num_iterations,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters (independent of the aggregator's)
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run the aggregator to get multi-scale features.

        Args:
            images: Input images, presumably [B, S, 3, H, W].
            num_frame_for_scale: Number of frames for scale estimation.
            sliding_window_size: Override sliding window size.
            num_frame_per_block: Number of frames per block.
            **kwargs: Accepted for interface compatibility with the caller;
                ignored here.

        Returns:
            (aggregated_tokens_list, patch_start_idx)
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            # Intermediate layers whose outputs feed the heads.
            # NOTE(review): assumes a 24-block aggregator; TODO confirm.
            selected_idx=[4, 11, 17, 23],
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clear the KV caches of the aggregator and the camera head.

        Call this when starting a new video sequence to drop cached key-value
        pairs from previous sequences. A disabled camera head (None) is
        skipped silently. A module without a ``clean_kv_cache`` method logs
        a warning.
        """
        aggregator_clean = getattr(self.aggregator, 'clean_kv_cache', None)
        if callable(aggregator_clean):
            aggregator_clean()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")

        if self.camera_head is None:
            # Camera head disabled: nothing to clean and nothing to warn about.
            return
        # Probe for the method we actually call (the original probed `kv_cache`).
        camera_clean = getattr(self.camera_head, 'clean_kv_cache', None)
        if callable(camera_clean):
            camera_clean()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set the ``_skip_append`` flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers attend to [cached_kv + current_kv]
        but do NOT store the current frame's KV in the cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        # Aggregator cache: a single dict-like object.
        aggregator_cache = getattr(self.aggregator, 'kv_cache', None)
        if aggregator_cache is not None:
            aggregator_cache["_skip_append"] = skip
        # Camera head cache: an iterable of dict-like caches, one flag each.
        if self.camera_head is not None:
            camera_cache = getattr(self.camera_head, 'kv_cache', None)
            if camera_cache is not None:
                for cache_dict in camera_cache:
                    cache_dict["_skip_append"] = skip

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about the aggregator's current KV cache state.

        Returns:
            Dictionary with cache statistics:
            - num_cached_blocks: Number of cache keys that start with 'k_' and
              do not end with '_special' (presumably one per attention block).
            - cache_memory_mb: Memory held by tensors stored directly in the
              cache, computed from each tensor's numel() * element_size()
              and rounded to 2 decimals. Views that share storage are
              counted separately.
        """
        kv_cache = getattr(self.aggregator, 'kv_cache', None)
        if kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

        num_cached = sum(
            1
            for key in kv_cache.keys()
            if isinstance(key, str) and key.startswith('k_') and not key.endswith('_special')
        )
        # Use the real element size instead of assuming 2-byte (bfloat16) storage.
        total_bytes = sum(
            value.numel() * value.element_size()
            for value in kv_cache.values()
            if torch.is_tensor(value)
        )
        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(total_bytes / (1024 * 1024), 2),
        }

    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        This method enables efficient online inference by:
        1. Processing the initial scale frames together as one block
           (bidirectional attention via the scale token).
        2. Processing the remaining frames one at a time with the KV cache
           (causal streaming).

        Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe.
        - Keyframes: KV is stored in cache (normal behavior).
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV,
          then discard). Caching is always restored afterwards, even if the
          forward pass raises.
        - All frames produce full predictions regardless of keyframe status.
        - Reduces KV cache memory growth by ~1/keyframe_interval.

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
            num_scale_frames: Number of initial frames for scale estimation.
                If None, uses self.num_frame_for_scale. Clamped to [1, S].
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. 1 = every frame is a keyframe
                (default, same as original behavior).
            output_device: Device to store output predictions on. If None, keeps
                them on the model's device. Set to torch.device('cpu') to
                offload predictions per frame and avoid GPU OOM on long
                sequences. When set, the KV cache is also cleared at the end.

        Returns:
            Dictionary of predictions concatenated along the sequence dim.
            Each prediction key is present only if forward() produced it:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
            - images: the (possibly offloaded) input images [B, S, 3, H, W]

        Raises:
            ValueError: If ``images`` is not 4-D/5-D or contains no frames.
        """
        # Per-frame prediction keys, collected and concatenated along dim=1.
        output_keys = ("pose_enc", "depth", "depth_conf", "world_points", "world_points_conf")

        # Normalize input shape to [B, S, C, H, W]
        if images.dim() == 4:
            images = images.unsqueeze(0)
        if images.dim() != 5:
            raise ValueError(
                "images must have shape [S, 3, H, W] or [B, S, 3, H, W], "
                f"got {tuple(images.shape)}"
            )
        num_frames = images.shape[1]
        if num_frames == 0:
            raise ValueError("inference_streaming requires at least one frame")

        # Determine number of scale frames: cap to the available frames, and
        # keep at least one so phase 1 never runs the model on an empty block.
        scale_frames = self.num_frame_for_scale if num_scale_frames is None else num_scale_frames
        scale_frames = max(1, min(scale_frames, num_frames))

        def _to_out(t: torch.Tensor) -> torch.Tensor:
            """Move a tensor to output_device (no-op when it is None)."""
            return t if output_device is None else t.to(output_device)

        collected: Dict[str, List[torch.Tensor]] = {key: [] for key in output_keys}

        def _collect(frame_output: Dict[str, torch.Tensor]) -> None:
            """Append (and optionally offload) every known key present in the output."""
            for key, chunks in collected.items():
                if key in frame_output:
                    chunks.append(_to_out(frame_output[key]))

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: process scale frames together as a single block.
        logger.info("Processing %d scale frames...", scale_frames)
        scale_output = self.forward(
            images[:, :scale_frames],
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # All scale frames form one block
            causal_inference=True,
        )
        _collect(scale_output)
        del scale_output

        # Phase 2: process the remaining frames one by one against the KV cache.
        with tqdm(
            range(scale_frames, num_frames),
            desc='Streaming inference',
            initial=scale_frames,
            total=num_frames,
        ) as pbar:
            for i in pbar:
                is_keyframe = keyframe_interval <= 1 or (i - scale_frames) % keyframe_interval == 0
                if not is_keyframe:
                    self._set_skip_append(True)
                try:
                    frame_output = self.forward(
                        images[:, i:i + 1],
                        num_frame_for_scale=scale_frames,  # Keep same for scale token logic
                        num_frame_per_block=1,  # Single frame per block
                        causal_inference=True,
                    )
                finally:
                    # Always restore normal caching so a failed non-keyframe
                    # does not leave the model silently discarding future KV.
                    if not is_keyframe:
                        self._set_skip_append(False)
                _collect(frame_output)
                del frame_output

        # Free device memory before concatenation when offloading.
        if output_device is not None:
            images_out = _to_out(images)
            # Drops only this function's reference; the caller's tensor survives.
            del images
            # KV cache is no longer needed after inference.
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate predictions along the sequence dimension, releasing each
        # key's per-frame chunks before moving on to the next key.
        predictions: Dict[str, torch.Tensor] = {}
        for key in output_keys:
            chunks = collected.pop(key)
            if chunks:
                predictions[key] = torch.cat(chunks, dim=1)
            del chunks

        # Store images for visualization
        predictions["images"] = images_out

        # Apply prediction normalization if enabled
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)
        return predictions