Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| # References: | |
| # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py | |
| # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py | |
| import logging | |
| import os | |
| import math | |
| import warnings | |
| import torch | |
| from torch import Tensor | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from lingbot_map.layers.rope import apply_rotary_emb | |
| from einops import rearrange | |
| # FlashInfer imports (optional - for paged attention) | |
| try: | |
| import flashinfer | |
| FLASHINFER_AVAILABLE = True | |
| except ImportError: | |
| FLASHINFER_AVAILABLE = False | |
| print("flashinfer not available") | |
| try: | |
| from torchtitan.distributed.sequence_parallel import ( | |
| gather_seq_scatter_heads, | |
| gather_heads_scatter_seq, | |
| pad_tensor, | |
| slice_input_tensor_scale_grad, | |
| gather_outputs, | |
| ) | |
| except ImportError: | |
| print("torchtitan not available for ulysses cp") | |
| def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int): | |
| """Gather sequence dimension and scatter head dimension for Q, K, V tensors.""" | |
| q = gather_seq_scatter_heads(q, seq_dim, head_dim) | |
| k = gather_seq_scatter_heads(k, seq_dim, head_dim) | |
| v = gather_seq_scatter_heads(v, seq_dim, head_dim) | |
| return q, k, v | |
| from typing_extensions import List | |
| from typing import Optional, Tuple | |
| class Attention(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| qkv_bias: bool = True, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0.0, | |
| proj_drop: float = 0.0, | |
| norm_layer: nn.Module = nn.LayerNorm, | |
| qk_norm: bool = False, | |
| fused_attn: bool = True, # use F.scaled_dot_product_attention or not | |
| rope=None, | |
| ) -> None: | |
| super().__init__() | |
| assert dim % num_heads == 0, "dim should be divisible by num_heads" | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.scale = self.head_dim**-0.5 | |
| self.fused_attn = fused_attn | |
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
| self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.attn_drop = nn.Dropout(attn_drop) | |
| self.proj = nn.Linear(dim, dim, bias=proj_bias) | |
| self.proj_drop = nn.Dropout(proj_drop) | |
| self.rope = rope | |
| def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor: | |
| B, N, C = x.shape | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if enable_ulysses_cp: | |
| q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) | |
| if self.rope is not None: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| if self.fused_attn: | |
| x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) | |
| else: | |
| q = q * self.scale | |
| attn = q @ k.transpose(-2, -1) | |
| attn = attn.softmax(dim=-1) | |
| attn = self.attn_drop(attn) | |
| x = attn @ v | |
| if enable_ulysses_cp: | |
| x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) | |
| x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |
| class CausalAttention(nn.Module): | |
| """ | |
| Causal self-attention module with KV cache support for streaming inference. | |
| Used by CasualBlockCamera in camera_head.py. | |
| """ | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| qkv_bias: bool = True, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0.0, | |
| proj_drop: float = 0.0, | |
| norm_layer: nn.Module = nn.LayerNorm, | |
| qk_norm: bool = False, | |
| fused_attn: bool = True, # use F.scaled_dot_product_attention or not | |
| rope=None, | |
| elementwise_attn_output_gate=False, | |
| # KV cache eviction parameters (matching build_attn_mask) | |
| kv_cache_sliding_window: int =64, | |
| kv_cache_scale_frames: int = 8, | |
| kv_cache_cross_frame_special: bool = True, | |
| kv_cache_include_scale_frames: bool = True, | |
| kv_cache_camera_only: bool = False, # If True, only cache camera token (no scale token) | |
| ) -> None: | |
| super().__init__() | |
| assert dim % num_heads == 0, "dim should be divisible by num_heads" | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.scale = self.head_dim**-0.5 | |
| self.fused_attn = fused_attn | |
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
| self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.attn_drop = nn.Dropout(attn_drop) | |
| self.proj = nn.Linear(dim, dim, bias=proj_bias) | |
| self.proj_drop = nn.Dropout(proj_drop) | |
| self.rope = rope | |
| self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None | |
| # Store KV cache eviction parameters | |
| self.kv_cache_sliding_window = kv_cache_sliding_window | |
| self.kv_cache_scale_frames = kv_cache_scale_frames | |
| self.kv_cache_cross_frame_special = kv_cache_cross_frame_special | |
| self.kv_cache_include_scale_frames = kv_cache_include_scale_frames | |
| self.kv_cache_camera_only = kv_cache_camera_only | |
| def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor: | |
| B, N, C = x.shape | |
| # Calculate special token indices | |
| camera_token_idx = 0 | |
| scale_token_idx = camera_token_idx + num_register_tokens + 1 # camera + register tokens + scale | |
| # [3, B, num_heads, N, head_dim] | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) | |
| if self.gate_proj is not None: | |
| gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) | |
| if kv_cache is None: | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if enable_ulysses_cp: | |
| q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) | |
| N = q.shape[2] # Update N after gather | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif enable_3d_rope and pos is not None: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| with torch.no_grad(): | |
| block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]] | |
| if block_mask.dim() == 2: | |
| block_mask = block_mask.unsqueeze(0).unsqueeze(0) # [1, 1, N, N] | |
| block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1]) | |
| video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device) # [1, 1, N, N] | |
| video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1]) | |
| mask = block_mask | ~video_mask | |
| # Apply sliding window mask if sliding_window_size > 0 | |
| # sliding_window_size is in units of num_frame_per_block | |
| if sliding_window_size > 0 and frame_seqlen is not None: | |
| # Create sliding window mask: each frame can only attend to frames within the window | |
| num_frames = N // frame_seqlen | |
| sliding_mask = torch.zeros_like(mask, dtype=torch.bool) | |
| for i in range(num_frames): | |
| q_start = i * frame_seqlen | |
| q_end = (i + 1) * frame_seqlen | |
| # Calculate the window start: sliding_window_size is in units of num_frame_per_block | |
| # So the actual window size in frames is sliding_window_size * num_frame_per_block | |
| window_size_in_frames = sliding_window_size * num_frame_per_block | |
| window_start_frame = max(0, i - window_size_in_frames + 1) | |
| k_start = window_start_frame * frame_seqlen | |
| k_end = (i + 1) * frame_seqlen # Can attend up to current frame (causal) | |
| sliding_mask[:, :, q_start:q_end, k_start:k_end] = True | |
| # Combine with existing mask: both masks need to allow attention | |
| mask = mask & sliding_mask | |
| # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames | |
| if num_frame_for_scale > 0: | |
| for i in range(num_frames): | |
| q_start = i * frame_seqlen | |
| q_end = (i + 1) * frame_seqlen | |
| # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask) | |
| mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True | |
| ## global attention for the first num_frame_for_scale frames | |
| if num_frame_for_scale > 0: | |
| mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True | |
| if self.fused_attn: | |
| x = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| attn_mask=mask | |
| ) | |
| else: | |
| # Apply RoPE to current k before caching | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif enable_3d_rope and pos is not None: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| # Check if we should skip appending to cache (non-keyframe in keyframe mode) | |
| skip_append = kv_cache.get("_skip_append", False) | |
| k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| if not skip_append: | |
| # KEYFRAME: store in cache (original behavior) | |
| if kv_cache[f"k_{global_idx}"] is None: | |
| kv_cache[f"k_{global_idx}"] = k_reshaped | |
| kv_cache[f"v_{global_idx}"] = v_reshaped | |
| else: | |
| num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3] | |
| k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2) | |
| kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2) | |
| # Apply sliding window eviction BEFORE attention to match causal_3drope behavior | |
| # This ensures current frame only attends to frames within the sliding window | |
| self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx) | |
| # Retrieve full k, v from cache (already RoPE-applied, already evicted) | |
| k = kv_cache[f"k_{global_idx}"].clone() | |
| v = kv_cache[f"v_{global_idx}"].clone() | |
| else: | |
| # NON-KEYFRAME: attend to [cached + current] without storing in cache | |
| if kv_cache[f"k_{global_idx}"] is not None: | |
| k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2) | |
| v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2) | |
| else: | |
| k = k_reshaped | |
| v = v_reshaped | |
| a, b, c, d, e = k.shape | |
| k = k.reshape(a, b, c*d, e) | |
| v = v.reshape(a, b, c*d, e) | |
| # Prepend special tokens (camera + scale) from evicted frames if they exist | |
| if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None: | |
| special_k = kv_cache[f"k_{global_idx}_special"] # [B, H, num_evicted_frames, 2, D] | |
| special_v = kv_cache[f"v_{global_idx}_special"] | |
| sa, sb, sc, sd, se = special_k.shape | |
| special_k = special_k.reshape(sa, sb, sc * sd, se) # [B, H, num_evicted*2, D] | |
| special_v = special_v.reshape(sa, sb, sc * sd, se) | |
| # Prepend special tokens (older frames first) | |
| k = torch.cat([special_k, k], dim=2) | |
| v = torch.cat([special_v, v], dim=2) | |
| # Note: k from cache is already RoPE-applied, no need to apply again | |
| if self.fused_attn: | |
| # Use mask-based SDPA to ensure same kernel as batch mode | |
| # The causal constraint is enforced by KV cache contents, not by mask | |
| mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device) | |
| x = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| attn_mask=mask, | |
| ) | |
| if self.gate_proj is not None: | |
| x = x * torch.sigmoid(gate_score) | |
| if enable_ulysses_cp: | |
| x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) | |
| # Use actual dimensions from attention output, not original input C | |
| # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim] | |
| x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |
| def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx): | |
| """ | |
| Apply sliding window eviction to KV cache BEFORE attention. | |
| This ensures current frame only attends to frames within the sliding window, | |
| matching the behavior of causal_3drope's attention mask. | |
| """ | |
| sliding_window_frames = self.kv_cache_sliding_window | |
| scale_frames = self.kv_cache_scale_frames | |
| if kv_cache[f"k_{global_idx}"].shape[3] > 1: | |
| num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2] | |
| if num_cached_frames > sliding_window_frames + scale_frames: | |
| evict_start = scale_frames | |
| evict_end = num_cached_frames - sliding_window_frames | |
| if evict_end > evict_start: | |
| evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :] | |
| evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :] | |
| if self.kv_cache_cross_frame_special: | |
| if self.kv_cache_camera_only: | |
| # Only keep camera token | |
| new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() | |
| new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() | |
| else: | |
| # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior | |
| # Special tokens are in range [camera_token_idx, scale_token_idx+1) | |
| new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() | |
| new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() | |
| if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None: | |
| kv_cache[f"k_{global_idx}_special"] = new_special_k | |
| kv_cache[f"v_{global_idx}_special"] = new_special_v | |
| else: | |
| kv_cache[f"k_{global_idx}_special"] = torch.cat( | |
| [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2) | |
| kv_cache[f"v_{global_idx}_special"] = torch.cat( | |
| [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2) | |
| if self.kv_cache_include_scale_frames: | |
| kv_cache[f"k_{global_idx}"] = torch.cat([ | |
| kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :], | |
| kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| ], dim=2) | |
| kv_cache[f"v_{global_idx}"] = torch.cat([ | |
| kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :], | |
| kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| ], dim=2) | |
| else: | |
| kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| class FlashInferAttention(Attention): | |
| """ | |
| FlashInfer variant of the GCT attention layer. | |
| Uses FlashInferKVCacheManager for paged KV cache storage and | |
| FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper). | |
| Supports the same optimized token layout and KV cache streaming inference. | |
| """ | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| qkv_bias: bool = True, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0.0, | |
| proj_drop: float = 0.0, | |
| norm_layer: nn.Module = nn.LayerNorm, | |
| qk_norm: bool = False, | |
| fused_attn: bool = True, | |
| rope=None, | |
| # KV cache eviction parameters | |
| kv_cache_sliding_window: int = 64, | |
| kv_cache_scale_frames: int = 8, | |
| kv_cache_cross_frame_special: bool = True, | |
| kv_cache_include_scale_frames: bool = True, | |
| kv_cache_camera_only: bool = False, | |
| ) -> None: | |
| if not FLASHINFER_AVAILABLE: | |
| raise RuntimeError("FlashInfer is not available. Please install flashinfer.") | |
| super().__init__( | |
| dim=dim, | |
| num_heads=num_heads, | |
| qkv_bias=qkv_bias, | |
| proj_bias=proj_bias, | |
| attn_drop=attn_drop, | |
| proj_drop=proj_drop, | |
| norm_layer=norm_layer, | |
| qk_norm=qk_norm, | |
| fused_attn=fused_attn, | |
| rope=rope, | |
| ) | |
| # Store KV cache eviction parameters | |
| self.kv_cache_sliding_window = kv_cache_sliding_window | |
| self.kv_cache_scale_frames = kv_cache_scale_frames | |
| self.kv_cache_cross_frame_special = kv_cache_cross_frame_special | |
| self.kv_cache_include_scale_frames = kv_cache_include_scale_frames | |
| self.kv_cache_camera_only = kv_cache_camera_only | |
| def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple: | |
| """Fused pre-attention ops for single-frame streaming (Phase 2). | |
| Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to | |
| [tpf, H, D] format ready for append_frame + compute_attention. | |
| Extracted as a method so torch.compile can capture all pre-attn ops as one | |
| CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 -> | |
| squeeze/permute/contiguous x3). | |
| """ | |
| B, N, C = x.shape | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim] | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif self.rope is not None: # enable_3d_rope=True | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode) | |
| q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous() | |
| k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous() | |
| v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous() | |
| return q_nhd, k_nhd, v_nhd | |
| def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, | |
| num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False, | |
| # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None) | |
| kv_cache=None, global_idx=0, num_frame_per_block=1, | |
| num_frame_for_scale=-1, num_register_tokens=4) -> Tensor: | |
| """ | |
| Forward pass with FlashInfer paged KV cache and attention. | |
| Args: | |
| x: Input tensor [B, N, C] | |
| kv_cache: FlashInferKVCacheManager instance or None (batch mode) | |
| global_idx: Block index for per-block cache access | |
| """ | |
| from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager | |
| B, N, C = x.shape | |
| # Detect if using optimized layout | |
| using_optimized_layout = (num_patches is not None and num_special is not None | |
| and num_frames is not None) | |
| # ========== Batch Mode (no KV cache manager) ========== | |
| if not isinstance(kv_cache, FlashInferKVCacheManager): | |
| # [3, B, num_heads, N, head_dim] | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim] | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if enable_ulysses_cp: | |
| if using_optimized_layout: | |
| boundary = num_frames * num_patches | |
| q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :] | |
| q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :] | |
| q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv( | |
| q_patch, k_patch, v_patch, seq_dim=2, head_dim=1 | |
| ) | |
| q_special, k_special, v_special = gather_seq_scatter_heads_qkv( | |
| q_special, k_special, v_special, seq_dim=2, head_dim=1 | |
| ) | |
| q = torch.cat([q_patch, q_special], dim=2) | |
| k = torch.cat([k_patch, k_special], dim=2) | |
| v = torch.cat([v_patch, v_special], dim=2) | |
| else: | |
| q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif self.rope is not None and enable_3d_rope: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| # Batch mode: use SDPA for numerical consistency with SDPA variant | |
| x = F.scaled_dot_product_attention( | |
| q, k, v, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| ) | |
| if enable_ulysses_cp: | |
| if using_optimized_layout: | |
| seq_global = x.shape[2] | |
| seq_local = num_frames * (num_patches + num_special) | |
| cp_size = seq_global // seq_local | |
| boundary_global = num_frames * cp_size * num_patches | |
| x_patch = x[:, :, :boundary_global, :] | |
| x_special = x[:, :, boundary_global:, :] | |
| x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1) | |
| x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1) | |
| x = torch.cat([x_patch, x_special], dim=2) | |
| else: | |
| x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) | |
| x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) | |
| # ========== Streaming Mode (with FlashInferKVCacheManager) ========== | |
| else: | |
| manager = kv_cache # FlashInferKVCacheManager | |
| # Phase 1 (scale frames): num_frames > 1 — multi-frame batch | |
| # Phase 2 (streaming): num_frames == 1 — single frame | |
| is_multi_frame = (num_frames is not None and num_frames > 1) | |
| if is_multi_frame: | |
| # Phase 1: compute full self-attention via SDPA (all frames attend to each other), | |
| # then append each frame's K/V to the paged cache one at a time. | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| # Apply RoPE before caching (RoPE baked into K before append) | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif self.rope is not None and enable_3d_rope: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| x = F.scaled_dot_product_attention( | |
| q, k, v, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| ) | |
| x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) | |
| # Append each frame's K/V to the paged cache individually. | |
| tpf = manager.tokens_per_frame | |
| k_all = k.squeeze(0).permute(1, 0, 2) # [num_frames*tpf, H, D] | |
| v_all = v.squeeze(0).permute(1, 0, 2) | |
| for f_idx in range(num_frames): | |
| s = f_idx * tpf | |
| manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous()) | |
| manager.evict_frames( | |
| block_idx=global_idx, | |
| scale_frames=self.kv_cache_scale_frames, | |
| sliding_window=self.kv_cache_sliding_window, | |
| cross_frame_special=self.kv_cache_cross_frame_special, | |
| include_scale_frames=self.kv_cache_include_scale_frames, | |
| camera_only=self.kv_cache_camera_only, | |
| num_register_tokens=num_register_tokens, | |
| ) | |
| else: | |
| # Phase 2: single-frame streaming via FlashInfer paged attention. | |
| q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope) | |
| # 1. Append to paged cache | |
| manager.append_frame(global_idx, k_nhd, v_nhd) | |
| # 2. Apply sliding window eviction | |
| manager.evict_frames( | |
| block_idx=global_idx, | |
| scale_frames=self.kv_cache_scale_frames, | |
| sliding_window=self.kv_cache_sliding_window, | |
| cross_frame_special=self.kv_cache_cross_frame_special, | |
| include_scale_frames=self.kv_cache_include_scale_frames, | |
| camera_only=self.kv_cache_camera_only, | |
| num_register_tokens=num_register_tokens, | |
| ) | |
| # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper | |
| x = manager.compute_attention(global_idx, q_nhd) | |
| # Convert back: [tpf, H, D] -> [B, tpf, C]. | |
| x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |
| class SDPAAttention(Attention): | |
| """ | |
| SDPA variant for streaming inference. | |
| Uses F.scaled_dot_product_attention with dict-based KV cache. | |
| No FlashInfer dependency required — works on any CUDA GPU. | |
| """ | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| qkv_bias: bool = True, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0.0, | |
| proj_drop: float = 0.0, | |
| norm_layer: nn.Module = nn.LayerNorm, | |
| qk_norm: bool = False, | |
| fused_attn: bool = True, | |
| rope=None, | |
| kv_cache_sliding_window: int = 64, | |
| kv_cache_scale_frames: int = 8, | |
| kv_cache_cross_frame_special: bool = True, | |
| kv_cache_include_scale_frames: bool = True, | |
| kv_cache_camera_only: bool = False, | |
| ) -> None: | |
| super().__init__( | |
| dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, | |
| attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, | |
| qk_norm=qk_norm, fused_attn=fused_attn, rope=rope, | |
| ) | |
| self.kv_cache_sliding_window = kv_cache_sliding_window | |
| self.kv_cache_scale_frames = kv_cache_scale_frames | |
| self.kv_cache_cross_frame_special = kv_cache_cross_frame_special | |
| self.kv_cache_include_scale_frames = kv_cache_include_scale_frames | |
| self.kv_cache_camera_only = kv_cache_camera_only | |
| def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, | |
| num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False, | |
| kv_cache=None, global_idx=0, num_frame_per_block=1, | |
| num_frame_for_scale=-1, num_register_tokens=4) -> Tensor: | |
| B, N, C = x.shape | |
| using_optimized_layout = (num_patches is not None and num_special is not None | |
| and num_frames is not None) | |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv.unbind(0) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| # ========== Batch Mode (no KV cache) ========== | |
| if kv_cache is None: | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif self.rope is not None and enable_3d_rope: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| x = F.scaled_dot_product_attention( | |
| q, k, v, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| ) | |
| x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) | |
| # ========== Streaming Mode (with KV cache dict) ========== | |
| else: | |
| if self.rope is not None and not enable_3d_rope: | |
| q = self.rope(q, pos) | |
| k = self.rope(k, pos) | |
| elif self.rope is not None and enable_3d_rope: | |
| q = apply_rotary_emb(q, pos) | |
| k = apply_rotary_emb(k, pos) | |
| camera_token_idx = 0 | |
| scale_token_idx = camera_token_idx + num_register_tokens + 1 | |
| if kv_cache[f"k_{global_idx}"] is None: | |
| kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block, | |
| N // num_frame_per_block, self.head_dim) | |
| kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block, | |
| N // num_frame_per_block, self.head_dim) | |
| else: | |
| num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3] | |
| kv_cache[f"k_{global_idx}"] = torch.cat(( | |
| kv_cache[f"k_{global_idx}"], | |
| k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| ), dim=2) | |
| kv_cache[f"v_{global_idx}"] = torch.cat(( | |
| kv_cache[f"v_{global_idx}"], | |
| v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) | |
| ), dim=2) | |
| self._apply_kv_cache_eviction( | |
| kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens | |
| ) | |
| k_cached = kv_cache[f"k_{global_idx}"].clone() | |
| v_cached = kv_cache[f"v_{global_idx}"].clone() | |
| a, b, c, d, e = k_cached.shape | |
| k_full = k_cached.reshape(a, b, c * d, e) | |
| v_full = v_cached.reshape(a, b, c * d, e) | |
| if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None: | |
| special_k = kv_cache[f"k_{global_idx}_special"] | |
| special_v = kv_cache[f"v_{global_idx}_special"] | |
| sa, sb, sc, sd, se = special_k.shape | |
| k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2) | |
| v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2) | |
| q_seq_len = q.shape[2] | |
| x = F.scaled_dot_product_attention( | |
| q, k_full, v_full, | |
| dropout_p=self.attn_drop.p if self.training else 0.0, | |
| ) | |
| x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |
| def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens): | |
| """Apply sliding window eviction to KV cache.""" | |
| sliding_window_frames = self.kv_cache_sliding_window | |
| scale_frames = self.kv_cache_scale_frames | |
| if kv_cache[f"k_{global_idx}"].shape[3] > 1: | |
| num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2] | |
| if num_cached_frames > sliding_window_frames + scale_frames: | |
| evict_start = scale_frames | |
| evict_end = num_cached_frames - sliding_window_frames | |
| if evict_end > evict_start: | |
| evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :] | |
| evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :] | |
| if self.kv_cache_cross_frame_special: | |
| if self.kv_cache_camera_only: | |
| new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() | |
| new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() | |
| else: | |
| new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() | |
| new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() | |
| if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None: | |
| kv_cache[f"k_{global_idx}_special"] = new_special_k | |
| kv_cache[f"v_{global_idx}_special"] = new_special_v | |
| else: | |
| kv_cache[f"k_{global_idx}_special"] = torch.cat( | |
| [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2) | |
| kv_cache[f"v_{global_idx}_special"] = torch.cat( | |
| [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2) | |
| if self.kv_cache_include_scale_frames: | |
| kv_cache[f"k_{global_idx}"] = torch.cat([ | |
| kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :], | |
| kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| ], dim=2) | |
| kv_cache[f"v_{global_idx}"] = torch.cat([ | |
| kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :], | |
| kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| ], dim=2) | |
| else: | |
| kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |
| kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] | |