| """ | |
| GCTStream - Streaming GCT with KV cache for online inference. | |
| Provides streaming inference functionality: | |
| - Temporal causal attention with KV cache | |
| - Sliding window support | |
| - Efficient frame-by-frame processing | |
| - 3D RoPE support for temporal consistency | |
| """ | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| from typing import Optional, Dict, Any, List | |
| from tqdm.auto import tqdm | |
| from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat | |
| from lingbot_map.heads.camera_head import CameraCausalHead | |
| from lingbot_map.models.gct_base import GCTBase | |
| from lingbot_map.aggregator.stream import AggregatorStream | |
| from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri | |
| from lingbot_map.utils.geometry import closed_form_inverse_se3 | |
| logger = logging.getLogger(__name__) | |


def _compute_flow_magnitude(
    cur_pose_enc: torch.Tensor,
    kf_pose_enc: torch.Tensor,
    cur_depth: torch.Tensor,
    image_size_hw: tuple,
    stride: int = 8,
) -> float:
    """Compute mean optical flow magnitude induced by camera motion.

    Projects current frame pixels into the last keyframe camera using the
    current depth map and both frames' poses, then returns the average
    pixel displacement (L2 norm of flow) over valid pixels.

    Args:
        cur_pose_enc: Current frame pose encoding [B, 1, 9].
        kf_pose_enc: Last keyframe pose encoding [B, 1, 9].
        cur_depth: Current frame depth map [B, 1, H, W, 1].
        image_size_hw: (H, W) of the depth map.
        stride: Subsampling stride for efficiency.

    Returns:
        Mean flow magnitude in pixels (scalar float).
    """
    H, W = image_size_hw
    device = cur_pose_enc.device
    dtype = cur_depth.dtype
    # The last-keyframe encoding may live on an offload device (see
    # inference_streaming's output_device handling); bring it to the
    # current frame's device before decoding.
    kf_pose_enc = kf_pose_enc.to(device)
    cur_ext, cur_intr = pose_encoding_to_extri_intri(
        cur_pose_enc, image_size_hw=image_size_hw
    )
    kf_ext, kf_intr = pose_encoding_to_extri_intri(
        kf_pose_enc, image_size_hw=image_size_hw
    )
    B = cur_ext.shape[0]
    cur_ext = cur_ext[:, 0]
    cur_intr = cur_intr[:, 0]
    kf_ext = kf_ext[:, 0]
    kf_intr = kf_intr[:, 0]
    # Subsample the depth map and build the matching pixel grid.
    depth = cur_depth[:, 0, ::stride, ::stride, 0].to(dtype)
    Hs, Ws = depth.shape[1], depth.shape[2]
    v_coords = torch.arange(0, H, stride, device=device, dtype=dtype)
    u_coords = torch.arange(0, W, stride, device=device, dtype=dtype)
    v_grid, u_grid = torch.meshgrid(v_coords, u_coords, indexing='ij')
    ones = torch.ones_like(u_grid)
    pixel_coords = torch.stack([u_grid, v_grid, ones], dim=-1)
    # Unproject pixels into the current camera frame, then into world space.
    intr_inv = torch.inverse(cur_intr)
    cam_coords = torch.einsum('bij,hwj->bhwi', intr_inv, pixel_coords)
    cam_pts = cam_coords * depth.unsqueeze(-1)
    c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
    c2w[:, :3, :] = cur_ext
    c2w[:, 3, 3] = 1.0
    ones_hw = torch.ones(B, Hs, Ws, 1, device=device, dtype=dtype)
    cam_pts_h = torch.cat([cam_pts, ones_hw], dim=-1)
    world_pts = torch.einsum('bij,bhwj->bhwi', c2w, cam_pts_h)[..., :3]
    # Reproject world points into the keyframe camera.
    kf_c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
    kf_c2w[:, :3, :] = kf_ext
    kf_c2w[:, 3, 3] = 1.0
    kf_w2c = closed_form_inverse_se3(kf_c2w)
    world_pts_h = torch.cat([world_pts, ones_hw], dim=-1)
    kf_cam_pts = torch.einsum('bij,bhwj->bhwi', kf_w2c, world_pts_h)[..., :3]
    z = kf_cam_pts[..., 2:3].clamp(min=1e-6)
    kf_cam_norm = kf_cam_pts / z
    kf_pixels = torch.einsum('bij,bhwj->bhwi', kf_intr, kf_cam_norm)[..., :2]
    orig_pixels = torch.stack([u_grid, v_grid], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
    flow = kf_pixels - orig_pixels
    # Valid pixels: positive depth in both cameras.
    valid = (depth > 1e-6) & (kf_cam_pts[..., 2] > 1e-6)
    flow_mag = flow.norm(dim=-1)
    valid_count = valid.float().sum()
    if valid_count < 1:
        return 0.0
    mean_mag = (flow_mag * valid.float()).sum() / valid_count
    return mean_mag.item()
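
# How the flow magnitude gates keyframe selection (sketch of the rule used in
# GCTStream.inference_streaming below; threshold and gap are caller-supplied):
#
#   flow_mag = _compute_flow_magnitude(cur_pose_enc, last_kf_pose_enc,
#                                      cur_depth, (H, W))
#   is_keyframe = flow_mag > flow_threshold or frames_since_kf >= max_gap
#
# Near-static frames stay out of the KV cache; large motion or a long gap
# since the last keyframe forces a new one.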


class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
    - AggregatorStream with KV cache support (FlashInfer backend)
    - CameraCausalHead for pose refinement
    - Sliding window attention for memory efficiency
    - Frame-by-frame streaming inference
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
        # Camera head iterative refinement (lower = faster inference; default 4)
        camera_num_iterations: int = 4,
    ):
| """ | |
| Initialize GCTStream. | |
| Args: | |
| img_size: Input image size | |
| patch_size: Patch size for embedding | |
| embed_dim: Embedding dimension | |
| patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.) | |
| pretrained_path: Path to pretrained DINOv2 weights | |
| disable_global_rope: Disable RoPE in global attention | |
| enable_camera/point/depth/track: Enable prediction heads | |
| enable_normalize: Enable normalization | |
| sliding_window_size: Sliding window size in blocks (-1 for full causal) | |
| num_frame_for_scale: Number of scale estimation frames | |
| num_random_frames: Number of random frames for long-range dependencies | |
| attend_to_special_tokens: Enable cross-frame special token attention | |
| attend_to_scale_frames: Whether to attend to scale frames | |
| enable_stream_inference: Enable streaming inference with KV cache | |
| enable_3d_rope: Enable 3D RoPE for temporal consistency | |
| max_frame_num: Maximum number of frames for 3D RoPE | |
| use_scale_token: Kept for checkpoint compatibility, ignored | |
| kv_cache_sliding_window: Sliding window size for KV cache eviction | |
| kv_cache_scale_frames: Number of scale frames to keep in KV cache | |
| kv_cache_cross_frame_special: Keep special tokens from evicted frames | |
| kv_cache_include_scale_frames: Include scale frames in KV cache | |
| kv_cache_camera_only: Only keep camera tokens from evicted frames | |
| """ | |
        # Store stream-specific parameters before calling super().__init__()
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa
        self.camera_num_iterations = camera_num_iterations
        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build streaming aggregator with KV cache support (FlashInfer backend).

        Returns:
            AggregatorStream module
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback)
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build causal camera head for streaming inference.

        Returns:
            CameraCausalHead module or None
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            num_iterations=self.camera_num_iterations,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run aggregator to get multi-scale features.

        Args:
            images: Input images [B, S, 3, H, W]
            num_frame_for_scale: Number of frames for scale estimation
            sliding_window_size: Override sliding window size
            num_frame_per_block: Number of frames per block

        Returns:
            (aggregated_tokens_list, patch_start_idx)
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            selected_idx=[4, 11, 17, 23],
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clean KV cache in aggregator.

        Call this method when starting a new video sequence to clear
        cached key-value pairs from previous sequences.
        """
        if hasattr(self.aggregator, 'clean_kv_cache'):
            self.aggregator.clean_kv_cache()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")
        if self.camera_head is not None and hasattr(self.camera_head, 'clean_kv_cache'):
            self.camera_head.clean_kv_cache()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set _skip_append flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers will attend to [cached_kv + current_kv]
        but will NOT store the current frame's KV in cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
            self.aggregator.kv_cache["_skip_append"] = skip
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            self.aggregator.kv_cache_manager._skip_append = skip
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_skip_append"] = skip

    # ── Flow-based keyframe helpers ────────────────────────────────────────

    def _set_defer_eviction(self, defer: bool):
        """Set defer-eviction flag on FlashInfer manager and SDPA caches.

        While True, eviction is suppressed so that rollback can cleanly undo
        the most recent append without having to restore evicted frames.
        """
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            self.aggregator.kv_cache_manager._defer_eviction = defer
        # SDPA aggregator cache (dict)
        if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
            self.aggregator.kv_cache["_defer_eviction"] = defer
        # Camera head SDPA caches
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_defer_eviction"] = defer

    def _rollback_last_frame(self):
        """Rollback the most recent frame from all caches.

        Undoes append_frame on FlashInfer manager (all blocks), trims the
        camera head SDPA cache, and decrements the aggregator frame counter.
        Must be called while eviction is still deferred.
        """
        # FlashInfer manager — rollback each transformer block
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            mgr = self.aggregator.kv_cache_manager
            for block_idx in range(mgr.num_blocks):
                mgr.rollback_last_frame(block_idx)
        # SDPA aggregator cache — trim last frame along dim=2
        if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
            kv = self.aggregator.kv_cache
            for key in list(kv.keys()):
                if key.startswith(("k_", "v_")) and kv[key] is not None and torch.is_tensor(kv[key]):
                    if kv[key].dim() >= 3 and kv[key].shape[2] > 1:
                        kv[key] = kv[key][:, :, :-1]
                    elif kv[key].dim() >= 3:
                        kv[key] = None
        # Camera head
        if self.camera_head is not None and hasattr(self.camera_head, 'rollback_last_frame'):
            self.camera_head.rollback_last_frame()
        # Aggregator frame counter (used for 3D RoPE temporal positions)
        self.aggregator.total_frames_processed -= 1

    def _execute_deferred_eviction(self):
        """Execute the eviction that was deferred during the last forward pass."""
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            mgr = self.aggregator.kv_cache_manager
            for block_idx in range(mgr.num_blocks):
                mgr.execute_deferred_eviction(
                    block_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                )
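
    # Flow-based decision protocol (as used by inference_streaming and
    # inference_windowed below): append first, decide afterwards.
    #
    #   self._set_defer_eviction(True)
    #   out = self.forward(frame, ...)         # KV appended, eviction held back
    #   self._set_defer_eviction(False)
    #   if is_keyframe:
    #       self._execute_deferred_eviction()  # commit: evict old frames now
    #   else:
    #       self._rollback_last_frame()        # discard this frame's KV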

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about current KV cache state.

        Returns:
            Dictionary with cache statistics:
            - num_cached_blocks: Number of blocks with cached KV
            - cache_memory_mb: Approximate memory usage in MB
        """
        if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}
        kv_cache = self.aggregator.kv_cache
        num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))
        # Estimate memory usage
        total_elements = 0
        for _, v in kv_cache.items():
            if v is not None and torch.is_tensor(v):
                total_elements += v.numel()
        # Assume bfloat16 (2 bytes per element)
        cache_memory_mb = (total_elements * 2) / (1024 * 1024)
        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(cache_memory_mb, 2),
        }
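
    # Example (illustrative values only): poll cache growth while streaming.
    #
    #   info = model.get_kv_cache_info()
    #   # -> {"num_cached_blocks": 24, "cache_memory_mb": 812.5}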

    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
        flow_threshold: float = 0.0,
        max_non_keyframe_gap: int = 30,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        This method enables efficient online inference by:
        1. Processing initial scale frames together (bidirectional attention via scale token)
        2. Processing remaining frames one-by-one with KV cache (causal streaming)

        Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe
        - Keyframes: KV is stored in cache (normal behavior)
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
        - All frames produce full predictions regardless of keyframe status
        - Slows KV cache memory growth to ~1/keyframe_interval of the per-frame rate

        Flow-based keyframe mode (flow_threshold > 0):
        - Takes precedence over keyframe_interval
        - Computes optical flow magnitude between current frame and last keyframe
        - Frame becomes a keyframe if flow exceeds the threshold or the gap
          exceeds max_non_keyframe_gap
        - Uses defer-eviction + rollback for non-keyframes

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            num_scale_frames: Number of initial frames for scale estimation.
                If None, uses self.num_frame_for_scale.
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. 1 = every frame is a keyframe
                (default, same as original behavior).
            output_device: Device to store output predictions on. If None, keeps
                them on the same device as the model. Set to torch.device('cpu')
                to offload predictions per-frame and avoid GPU OOM on long
                sequences.
            flow_threshold: Mean flow magnitude threshold (pixels) for flow-based
                keyframe selection. >0 enables flow-based mode (takes precedence
                over keyframe_interval).
            max_non_keyframe_gap: Max consecutive non-keyframe frames before
                forcing a keyframe (flow mode only).

        Returns:
            Dictionary containing predictions for all frames:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
        """
        # Normalize input shape
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, S, C, H, W = images.shape
        # Determine number of scale frames
        scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
        scale_frames = min(scale_frames, S)  # Cap to available frames

        # Helper to move tensor to output device
        def _to_out(t: torch.Tensor) -> torch.Tensor:
            if output_device is not None:
                return t.to(output_device)
            return t

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: Process scale frames together
        # These frames get bidirectional attention among themselves via scale token
        logger.info(f'Processing {scale_frames} scale frames...')
        scale_images = images[:, :scale_frames]
        scale_output = self.forward(
            scale_images,
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # Process all scale frames as one block
            causal_inference=True,
        )
        # Initialize output lists with scale frame predictions (offload if needed)
        all_pose_enc = [_to_out(scale_output["pose_enc"])]
        all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
        all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
        all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
        all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
        del scale_output

        # Phase 2: Process remaining frames one-by-one
        use_flow_keyframe = flow_threshold > 0.0
        # Flow state: last keyframe = last scale frame
        if use_flow_keyframe:
            last_kf_pose_enc = all_pose_enc[0][:, -1:]  # last scale frame
            last_kf_idx = scale_frames - 1
        pbar = tqdm(
            range(scale_frames, S),
            desc='Streaming inference',
            initial=scale_frames,
            total=S,
        )
        for i in pbar:
            frame_image = images[:, i:i+1]
            if use_flow_keyframe:
                # Flow-based: defer eviction, forward, then decide
                self._set_defer_eviction(True)
                frame_output = self.forward(
                    frame_image,
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
                self._set_defer_eviction(False)
                # Compute flow to decide keyframe
                cur_depth = frame_output.get("depth", None)
                if cur_depth is not None:
                    H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
                    flow_mag = _compute_flow_magnitude(
                        frame_output["pose_enc"], last_kf_pose_enc,
                        cur_depth, (H_pred, W_pred),
                    )
                else:
                    # No depth available to measure flow: force a keyframe.
                    flow_mag = flow_threshold + 1.0
                frames_since_kf = i - last_kf_idx
                is_keyframe = (
                    (i == scale_frames)  # first streaming frame
                    or (flow_mag > flow_threshold)
                    or (frames_since_kf >= max_non_keyframe_gap)
                )
                if is_keyframe:
                    self._execute_deferred_eviction()
                    last_kf_pose_enc = frame_output["pose_enc"]
                    last_kf_idx = i
                else:
                    self._rollback_last_frame()
            else:
                # Fixed-interval keyframe mode
                is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)
                if not is_keyframe:
                    self._set_skip_append(True)
                frame_output = self.forward(
                    frame_image,
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
                if not is_keyframe:
                    self._set_skip_append(False)

            all_pose_enc.append(_to_out(frame_output["pose_enc"]))
            if "depth" in frame_output:
                all_depth.append(_to_out(frame_output["depth"]))
            if "depth_conf" in frame_output:
                all_depth_conf.append(_to_out(frame_output["depth_conf"]))
            if "world_points" in frame_output:
                all_world_points.append(_to_out(frame_output["world_points"]))
            if "world_points_conf" in frame_output:
                all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
            del frame_output

        # Free GPU memory before concatenation
        if output_device is not None:
            # Move images to output device, then free GPU copy
            images_out = _to_out(images)
            del images
            # Clean KV cache (no longer needed after inference)
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate all predictions along sequence dimension
        predictions = {
            "pose_enc": torch.cat(all_pose_enc, dim=1),
        }
        del all_pose_enc
        if all_depth:
            predictions["depth"] = torch.cat(all_depth, dim=1)
            del all_depth
        if all_depth_conf:
            predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
            del all_depth_conf
        if all_world_points:
            predictions["world_points"] = torch.cat(all_world_points, dim=1)
            del all_world_points
        if all_world_points_conf:
            predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
            del all_world_points_conf

        # Store images for visualization
        predictions["images"] = images_out
        # Apply prediction normalization if enabled
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)
        return predictions
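
    # Usage sketch (hypothetical setup: `video` is a [S, 3, H, W] float tensor
    # in [0, 1]; weight loading is omitted):
    #
    #   model = GCTStream().eval().cuda()
    #   with torch.no_grad():
    #       preds = model.inference_streaming(
    #           video.cuda(),
    #           keyframe_interval=2,                # cache every 2nd frame's KV
    #           output_device=torch.device('cpu'),  # offload outputs per frame
    #       )
    #   preds["pose_enc"].shape  # -> [1, S, 9]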

    # ══════════════════════════════════════════════════════════════════════
    # Window stitching & cross-window alignment
    # ══════════════════════════════════════════════════════════════════════
    _FRAME_AXIS_KEYS = frozenset({
        "pose_enc", "depth", "depth_conf",
        "world_points", "world_points_conf",
        "frame_type", "is_keyframe",
    })

    def _stitch_windows(
        self,
        windows: List[Dict],
        window_size: int,
        overlap: int,
    ) -> Dict:
        """Concatenate per-window predictions while de-duplicating overlaps.

        For each temporal key the method builds a slice table first — every
        window contributes ``[0, effective_end)`` frames where
        ``effective_end = total_frames - overlap`` for non-final windows.
        Non-temporal entries simply keep the latest available value.
        """
        if len(windows) == 0:
            return {}
        if len(windows) == 1:
            return windows[0]
        n_win = len(windows)
        all_keys = list(windows[0].keys())
        stitched: Dict = {}
        for key in all_keys:
            values = [w.get(key) for w in windows]
            if all(v is None for v in values):
                continue
            # Non-temporal entries: take latest
            if key not in self._FRAME_AXIS_KEYS:
                stitched[key] = next(v for v in reversed(values) if v is not None)
                continue
            # Build slice table: (start, end) for each window's contribution
            slices = []
            for wi, tensor in enumerate(values):
                if tensor is None:
                    slices.append(None)
                    continue
                total = tensor.shape[1]
                is_last = (wi == n_win - 1)
                end = total if is_last else max(total - overlap, 0)
                slices.append((0, end) if end > 0 else None)
            parts = [
                values[i][:, s:e]
                for i, s_e in enumerate(slices)
                if s_e is not None
                for s, e in [s_e]
            ]
            if parts:
                stitched[key] = torch.cat(parts, dim=1)
            else:
                fallback = next((v for v in reversed(values) if v is not None), None)
                if fallback is not None:
                    stitched[key] = fallback
        return stitched

    @staticmethod
    def _depth_ratio_scale(
        anchor_depth: torch.Tensor,
        target_depth: torch.Tensor,
        batch_size: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Estimate per-batch scale as the median depth ratio anchor/target."""
        a = anchor_depth.to(torch.float32).reshape(batch_size, -1)
        t = target_depth.to(torch.float32).reshape(batch_size, -1)
        ok = torch.isfinite(a) & torch.isfinite(t) & (t.abs() > torch.finfo(torch.float32).eps)
        scales = []
        for b in range(batch_size):
            m = ok[b]
            if m.any():
                scales.append((a[b, m] / t[b, m]).median())
            else:
                scales.append(torch.tensor(1.0, device=device, dtype=torch.float32))
        return torch.stack(scales).clamp(min=1e-3, max=1e3)
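
    # Why the median: if anchor depth is uniformly 2x the target depth, the
    # estimate is exactly 2.0, and a few outlier pixels cannot drag the scale
    # the way a mean of ratios would.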

    @staticmethod
    def _pairwise_alignment(
        prev_pred: Dict,
        curr_pred: Dict,
        overlap: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Compute (scale, R, t) that maps *curr* into *prev*'s coordinate frame.

        Uses the first overlap frame of *curr* and the corresponding trailing
        frame of *prev* to establish the similarity transform.
        """
        unit_s = torch.ones(batch_size, device=device, dtype=dtype)
        eye_R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1).clone()
        zero_t = torch.zeros(batch_size, 3, device=device, dtype=dtype)
        if overlap <= 0:
            return unit_s, eye_R, zero_t
        pe_prev = prev_pred.get("pose_enc")
        pe_curr = curr_pred.get("pose_enc")
        if pe_prev is None or pe_curr is None:
            return unit_s, eye_R, zero_t
        idx_a = max(pe_prev.shape[1] - overlap, 0)
        # Decompose C2W: center ([:3]) + quaternion ([3:7])
        Ra = quat_to_mat(pe_prev[:, idx_a, 3:7])  # (B, 3, 3)
        ca = pe_prev[:, idx_a, :3]                # (B, 3)
        Rb = quat_to_mat(pe_curr[:, 0, 3:7])
        cb = pe_curr[:, 0, :3]
        R_ab = torch.bmm(Ra, Rb.transpose(1, 2))  # Ra = R_ab @ Rb
        # Scale from depth
        s_ab = unit_s.clone()
        da = prev_pred.get("depth")
        db = curr_pred.get("depth")
        if (da is not None and db is not None
                and da.shape[1] > idx_a and db.shape[1] > 0):
            s_ab = GCTStream._depth_ratio_scale(
                da[:, idx_a, ..., 0], db[:, 0, ..., 0],
                batch_size, device,
            ).to(dtype)
        # ca = s_ab * R_ab @ cb + t_ab  =>  t_ab = ca - s_ab * R_ab @ cb
        t_ab = ca - s_ab.unsqueeze(-1) * torch.bmm(R_ab, cb.unsqueeze(-1)).squeeze(-1)
        return s_ab, R_ab.to(dtype), t_ab.to(dtype)
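
    # The returned triple defines the similarity map
    #     x_prev = s_ab * (R_ab @ x_curr) + t_ab,
    # which _warp_predictions applies to poses, depths, and world points.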

    @staticmethod
    def _warp_predictions(
        pred: Dict,
        R: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
        batch_size: int,
    ) -> Dict:
        """Apply a similarity transform (s, R, t) to one window's predictions."""
        warped: Dict = {}
        # Pose encoding: center + quaternion + intrinsics
        pe = pred.get("pose_enc")
        if pe is not None:
            nf = pe.shape[1]
            local_rot = quat_to_mat(pe[:, :, 3:7])
            local_ctr = pe[:, :, :3]
            R_exp = R[:, None].expand(-1, nf, -1, -1)
            new_rot = torch.matmul(R_exp, local_rot)
            new_ctr = (
                s.view(batch_size, 1, 1) * torch.matmul(R_exp, local_ctr.unsqueeze(-1)).squeeze(-1)
                + t.view(batch_size, 1, 3)
            )
            out_pe = pe.clone()
            out_pe[:, :, :3] = new_ctr
            out_pe[:, :, 3:7] = mat_to_quat(new_rot)
            warped["pose_enc"] = out_pe
        else:
            warped["pose_enc"] = None
        # Depth: scale by s
        d = pred.get("depth")
        if d is not None:
            warped["depth"] = d * s.view(batch_size, 1, 1, 1, 1)
        else:
            warped["depth"] = None
        # World points: p_global = s * R @ p_local + t
        wp = pred.get("world_points")
        if wp is not None:
            b, nf, h, w, _ = wp.shape
            flat = wp.reshape(b, nf * h * w, 3)
            transformed = torch.bmm(flat, R.transpose(1, 2)) * s.view(b, 1, 1)
            transformed = transformed + t[:, None, :]
            warped["world_points"] = transformed.reshape(b, nf, h, w, 3)
        else:
            warped["world_points"] = None
        # Pass through all other keys untouched
        for k, v in pred.items():
            if k not in warped:
                warped[k] = v
        return warped

    def _align_and_stitch_windows(
        self,
        windows: List[Dict],
        scale_mode: str = 'median',
    ) -> Dict:
        """Bring all windows into the first window's coordinate frame, then stitch.

        Iterates over consecutive window pairs, estimates the pairwise
        scaled alignment, warps each window, and finally concatenates
        via :meth:`_stitch_windows`.
        """
        if len(windows) == 0:
            return {}
        if len(windows) == 1:
            out = windows[0].copy()
            out["alignment_mode"] = "scaled"
            return out
        # Discover batch / device / dtype from any available tensor
        ref = next(
            v
            for w in windows
            for k in ("pose_enc", "world_points", "depth")
            if (v := w.get(k)) is not None
        )
        dev, dt, nb = ref.device, ref.dtype, ref.shape[0]
        overlap = getattr(self, "_last_overlap_size", 0)
        win_sz = getattr(self, "_last_window_size", -1)
        warped_windows: List[Dict] = []
        per_window_scales: List[torch.Tensor] = []
        per_window_transforms: List[torch.Tensor] = []
        for idx, raw in enumerate(windows):
            if idx == 0:
                s_rel = torch.ones(nb, device=dev, dtype=dt)
                R_rel = torch.eye(3, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
                t_rel = torch.zeros(nb, 3, device=dev, dtype=dt)
            else:
                s_rel, R_rel, t_rel = self._pairwise_alignment(
                    warped_windows[-1], raw, overlap, nb, dev, dt,
                )
            per_window_scales.append(s_rel.clone())
            T = torch.eye(4, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
            T[:, :3, :3] = R_rel
            T[:, :3, 3] = t_rel
            per_window_transforms.append(T)
            warped_windows.append(
                self._warp_predictions(raw, R_rel, t_rel, s_rel, nb)
            )
        merged = self._stitch_windows(warped_windows, win_sz, overlap)
        # Attach alignment metadata
        if per_window_scales:
            merged["chunk_scales"] = torch.stack(per_window_scales, dim=1)
        if per_window_transforms:
            merged["chunk_transforms"] = torch.stack(per_window_transforms, dim=1)
        merged["alignment_mode"] = "scaled"
        return merged

    def inference_windowed(
        self,
        images: torch.Tensor,
        window_size: int = 16,
        overlap_size: Optional[int] = None,
        num_scale_frames: Optional[int] = None,
        scale_mode: str = 'median',
        output_device: Optional[torch.device] = None,
        keyframe_interval: int = 1,
        flow_threshold: float = 0.0,
        max_non_keyframe_gap: int = 30,
    ) -> Dict[str, torch.Tensor]:
        """
        Windowed inference with keyframe detection and cross-window alignment.

        Each window is processed independently with a fresh KV cache.
        Overlap frames between windows are the next window's scale frames
        (bidirectional attention), ensuring the highest quality predictions
        at alignment boundaries.

        ``window_size`` counts **keyframes** (frames stored in KV cache),
        including scale frames. When ``keyframe_interval > 1``, each window
        covers more actual frames than ``window_size``:

            actual_frames = scale_frames + (window_size - scale_frames) * keyframe_interval

        For example, with ``window_size=16``, 4 scale frames, and
        ``keyframe_interval=2``, each window spans 4 + 12 * 2 = 28 frames.

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W] in [0, 1].
            window_size: Number of **keyframes** per window (including scale
                frames). Directly controls KV cache memory.
            overlap_size: Number of overlapping frames between windows.
                Defaults to ``num_scale_frames`` (overlap = scale frames).
            num_scale_frames: Number of frames used as scale reference within
                each window. Defaults to ``self.num_frame_for_scale``.
            scale_mode: Scale estimation strategy for alignment.
            output_device: Device to store per-window outputs.
            keyframe_interval: Every N-th Phase 2 frame is a keyframe whose
                KV persists in cache. 1 = every frame (default).
            flow_threshold: Mean flow magnitude threshold (pixels) for
                flow-based keyframe selection. >0 enables flow-based mode
                (takes precedence over ``keyframe_interval``).
            max_non_keyframe_gap: Max consecutive non-keyframe frames before
                forcing a keyframe (flow mode only).

        Returns:
            Merged prediction dict with all frames.
        """
        use_flow_keyframe = flow_threshold > 0.0
        # Normalize input shape
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, S, C, H, W = images.shape
        ws = (num_scale_frames if num_scale_frames is not None
              else self.num_frame_for_scale)
        ws = min(ws, S)
        # overlap = scale_frames by default
        eff_overlap = min(overlap_size if overlap_size is not None else ws,
                          S - 1) if S > 1 else 0

        def _to_out(t: torch.Tensor) -> torch.Tensor:
            return t.to(output_device) if output_device is not None else t

        def _collect_frame(out, w_lists):
            w_lists['pose_enc'].append(_to_out(out["pose_enc"]))
            if "depth" in out:
                w_lists['depth'].append(_to_out(out["depth"]))
            if "depth_conf" in out:
                w_lists['depth_conf'].append(_to_out(out["depth_conf"]))
            if "world_points" in out:
                w_lists['world_points'].append(_to_out(out["world_points"]))
            if "world_points_conf" in out:
                w_lists['world_pts_conf'].append(_to_out(out["world_points_conf"]))

        def _make_window_pred(w_lists):
            pred: Dict = {"pose_enc": torch.cat(w_lists['pose_enc'], dim=1)}
            if w_lists['depth']:
                pred["depth"] = torch.cat(w_lists['depth'], dim=1)
            if w_lists['depth_conf']:
                pred["depth_conf"] = torch.cat(w_lists['depth_conf'], dim=1)
            if w_lists['world_points']:
                pred["world_points"] = torch.cat(w_lists['world_points'], dim=1)
            if w_lists['world_pts_conf']:
                pred["world_points_conf"] = torch.cat(w_lists['world_pts_conf'], dim=1)
            # Frame type: 0=scale, 1=keyframe, 2=non-keyframe
            ft = torch.tensor(w_lists['frame_type'], dtype=torch.uint8).unsqueeze(0)  # [1, T]
            pred["frame_type"] = ft
            pred["is_keyframe"] = (ft != 2)  # scale + keyframe = True
            return pred

        def _new_lists():
            return {
                'pose_enc': [], 'depth': [], 'depth_conf': [],
                'world_points': [], 'world_pts_conf': [],
                'frame_type': [],  # list of ints: 0=scale, 1=keyframe, 2=non-keyframe
            }
        # ================================================================
        # Flow-based mode: dynamic windows (can't precompute window list)
        # ================================================================
        if use_flow_keyframe:
            all_window_predictions: List[Dict] = []
            cursor = 0
            window_idx = 0
            pbar = tqdm(total=S, desc='Windowed inference (flow)', initial=0)
            while cursor < S:
                window_start = cursor
                window_scale = min(ws, S - cursor)
                # Fresh KV cache
                self.clean_kv_cache()
                # ---------- Phase 1: scale frames ----------
                scale_images = images[:, cursor:cursor + window_scale]
                scale_out = self.forward(
                    scale_images,
                    num_frame_for_scale=window_scale,
                    num_frame_per_block=window_scale,
                    causal_inference=True,
                )
                w_lists = _new_lists()
                _collect_frame(scale_out, w_lists)
                w_lists['frame_type'].extend([0] * window_scale)  # scale frames
                # Flow state: last keyframe = last scale frame
                last_kf_pose_enc = scale_out["pose_enc"][:, -1:]
                last_kf_local_idx = window_scale - 1
                del scale_out
                cursor += window_scale
                pbar.update(window_scale)
                # ---------- Phase 2: stream until enough keyframes ----------
                target_kf = window_size - window_scale  # keyframes to collect
                kf_count = 0
                while cursor < S and kf_count < target_kf:
                    frame_image = images[:, cursor:cursor + 1]
                    self._set_defer_eviction(True)
                    frame_out = self.forward(
                        frame_image,
                        num_frame_for_scale=window_scale,
                        num_frame_per_block=1,
                        causal_inference=True,
                    )
                    self._set_defer_eviction(False)
                    # Compute flow
                    cur_depth = frame_out.get("depth", None)
                    if cur_depth is not None:
                        H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
                        flow_mag = _compute_flow_magnitude(
                            frame_out["pose_enc"], last_kf_pose_enc,
                            cur_depth, (H_pred, W_pred),
                        )
                    else:
                        # No depth available to measure flow: force a keyframe.
                        flow_mag = flow_threshold + 1.0
                    local_idx = cursor - window_start  # frame index within this window
                    frames_since_kf = local_idx - last_kf_local_idx
                    is_keyframe = (
                        (kf_count == 0)  # first streaming frame
                        or (flow_mag > flow_threshold)
                        or (frames_since_kf >= max_non_keyframe_gap)
                    )
                    if is_keyframe:
                        self._execute_deferred_eviction()
                        last_kf_pose_enc = frame_out["pose_enc"]
                        last_kf_local_idx = local_idx
                        kf_count += 1
                        w_lists['frame_type'].append(1)  # keyframe
                    else:
                        self._rollback_last_frame()
                        w_lists['frame_type'].append(2)  # non-keyframe
                    _collect_frame(frame_out, w_lists)
                    del frame_out
                    cursor += 1
                    pbar.update(1)
                all_window_predictions.append(_make_window_pred(w_lists))
                window_idx += 1
                # Next window starts overlap_size frames back (= scale frames)
                if cursor < S:
                    cursor = max(cursor - eff_overlap, window_start + window_scale)
            pbar.close()
        # ================================================================
        # Fixed-interval / default mode: precomputable windows
        # ================================================================
        else:
            # Compute actual frames per window
            phase2_kf = max(window_size - ws, 0)
            kf_int = max(keyframe_interval, 1)
            phase2_frames = phase2_kf * kf_int
            actual_window_frames = ws + phase2_frames
            eff_window = min(actual_window_frames, S)
            step = max(eff_window - eff_overlap, 1)
            # Build window list
            if eff_window >= S:
                windows = [(0, S)]
            else:
                windows = []
                for start_idx in range(0, S, step):
                    end_idx = min(start_idx + eff_window, S)
                    if end_idx - start_idx >= eff_overlap or end_idx == S:
                        windows.append((start_idx, end_idx))
                    if end_idx == S:
                        break
            all_window_predictions: List[Dict] = []
            for start, end in tqdm(windows, desc='Windowed inference'):
                window_images = images[:, start:end]
                window_len = end - start
                # Fresh KV cache
                self.clean_kv_cache()
                window_scale = min(ws, window_len)
                # ---------- Phase 1: scale frames ----------
                scale_out = self.forward(
                    window_images[:, :window_scale],
                    num_frame_for_scale=window_scale,
                    num_frame_per_block=window_scale,
                    causal_inference=True,
                )
                w_lists = _new_lists()
                _collect_frame(scale_out, w_lists)
                w_lists['frame_type'].extend([0] * window_scale)  # scale frames
                del scale_out
                # ---------- Phase 2: stream remaining frames ----------
                for i in range(window_scale, window_len):
                    is_keyframe = (
                        kf_int <= 1
                        or ((i - window_scale) % kf_int == 0)
                    )
                    if not is_keyframe:
                        self._set_skip_append(True)
                    frame_out = self.forward(
                        window_images[:, i:i + 1],
                        num_frame_for_scale=window_scale,
                        num_frame_per_block=1,
                        causal_inference=True,
                    )
                    if not is_keyframe:
                        self._set_skip_append(False)
                    _collect_frame(frame_out, w_lists)
                    w_lists['frame_type'].append(1 if is_keyframe else 2)
                    del frame_out
                all_window_predictions.append(_make_window_pred(w_lists))

        # Store for merge helpers
        self._last_window_size = eff_overlap  # not used directly, but kept for compat
        self._last_overlap_size = eff_overlap
        # Align and stitch windows
        predictions = self._align_and_stitch_windows(
            all_window_predictions, scale_mode=scale_mode
        )
        predictions["images"] = _to_out(images)
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)
        return predictions
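

# Usage sketch for windowed inference (hypothetical setup, mirroring the
# streaming example above; `model` and `video` are assumed to exist):
#
#   preds = model.inference_windowed(
#       video.cuda(),
#       window_size=16,        # keyframes per window, incl. scale frames
#       flow_threshold=2.0,    # flow-based keyframe selection (pixels)
#       output_device=torch.device('cpu'),
#   )
#   preds["frame_type"]   # [1, S] uint8: 0=scale, 1=keyframe, 2=non-keyframe
#   preds["is_keyframe"]  # [1, S] bool mask (frame_type != 2)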