"""
AggregatorBase - Base class for all Aggregator implementations.

Provides shared functionality:
- Patch embedding (DINOv2 or conv)
- Special tokens (camera, register, scale)
- Block building
- Common forward pass structure

Subclasses implement mode-specific attention logic.
"""

import logging
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List

from lingbot_map.layers import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


def slice_expand_and_flatten(token, B, S, first_num_frame=1):
    """
    Helper function to slice, expand and flatten tokens.

    Args:
        token: Token tensor [1, 2, N, C] where first index is for first frames
        B: Batch size
        S: Sequence length
        first_num_frame: Number of frames to use first token for

    Returns:
        Flattened tokens [B*S, N, C]
    """
    # token shape: [1, 2, N, C] -> expand to [B, S, N, C].
    # The first entry covers the leading `first_num_frame` frames, the second covers the rest.
    token_first = token[:, :1].expand(B, first_num_frame, -1, -1)     # [B, first_num_frame, N, C]
    token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)  # [B, S - first_num_frame, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)      # [B, S, N, C]

    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])

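# Illustrative example (not part of the library API): the shape contract of
# ``slice_expand_and_flatten``. The dummy token below is hypothetical.
#
#   token = torch.zeros(1, 2, 5, 64)                  # [1, 2, N=5, C=64]
#   flat = slice_expand_and_flatten(token, B=2, S=7)  # frame 0 uses token[:, 0],
#                                                     # frames 1..6 use token[:, 1]
#   assert flat.shape == (2 * 7, 5, 64)               # -> [B*S, N, C]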

class AggregatorBase(nn.Module, ABC):
    """
    Base class for all Aggregator implementations.

    Handles shared components:
    - Patch embedding (DINOv2 or conv)
    - Special tokens (camera, register, optionally scale)
    - Block creation (frame + global)
    - RoPE (2D rotary position embeddings)
    - Common forward pass scaffolding

    Subclasses must implement:
    - _build_blocks(): Create frame_blocks and global_blocks
    - _setup_special_tokens(): Create camera/register (and optional scale) tokens
    - _prepare_special_tokens(): Expand special tokens for a forward pass
    - _process_global_attention(): Mode-specific cross-frame attention logic
    """

    def __init__(
        self,
        # Architecture parameters
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        # Block configuration
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        qk_norm=True,
        init_values=0.01,
        # Patch embedding
        patch_embed="dinov2_vitl14_reg",
        pretrained_path=None,
        # Attention pattern
        aa_order=["frame", "global"],
        aa_block_size=1,
        # RoPE
        rope_freq=100,
        disable_global_rope=False,
        # Gradient checkpointing
        use_reentrant: bool = False,
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.disable_global_rope = disable_global_rope
        self.use_reentrant = use_reentrant
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # Ulysses context parallelism (CP) is disabled

        logger.info(f"pretrained_path: {self.pretrained_path}")

        # Validate depth
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        # Build patch embedding
        self._build_patch_embed(
            patch_embed=patch_embed,
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            embed_dim=embed_dim,
            pretrained_path=pretrained_path
        )

        # Initialize RoPE
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Build blocks (frame + global)
        self._build_blocks(
            block_fn=block_fn,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )

        # Setup special tokens (camera, register, optionally scale)
        self._setup_special_tokens()

        # Register normalization constants
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

        # Initialize from DINO checkpoint if available
        if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
            self._init_blocks_from_dino(self._dino_checkpoint)
            del self._dino_checkpoint  # Free memory

    def _build_patch_embed(
        self,
        patch_embed: str,
        img_size: int,
        patch_size: int,
        num_register_tokens: int,
        embed_dim: int,
        pretrained_path: Optional[str],
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
    ):
        """
        Build patch embedding layer.

        Supports:
        - "conv": Simple convolutional patch embedding
        - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim
            )
            self._dino_checkpoint = None

        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            if patch_embed not in vit_models:
                raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Load pretrained weights
            try:
                ckpt = torch.load(pretrained_path)
                del ckpt['pos_embed']
                logger.info("Loading pretrained weights for DINOv2")
                missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
                logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

                # Store checkpoint for block initialization
                self._dino_checkpoint = ckpt
            except Exception as e:
                logger.warning(f"Failed to load pretrained weights: {e}")
                self._dino_checkpoint = None

            # Disable gradients for mask token
            if hasattr(self.patch_embed, "mask_token"):
                self.patch_embed.mask_token.requires_grad_(False)

    @abstractmethod
    def _build_blocks(
        self,
        block_fn,
        depth: int,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        proj_bias: bool,
        ffn_bias: bool,
        init_values: float,
        qk_norm: bool,
    ):
        """
        Build frame_blocks and global_blocks.

        Subclasses implement mode-specific block creation.

        Must create:
        - self.frame_blocks: nn.ModuleList of frame attention blocks
        - self.global_blocks: nn.ModuleList of global attention blocks
        """
        pass

    @abstractmethod
    def _setup_special_tokens(self):
        """
        Setup camera token, register tokens, and optionally scale token.

        Subclasses implement mode-specific token initialization.

        Must create:
        - self.camera_token
        - self.register_token
        - self.scale_token (optional, for causal mode)
        - self.patch_start_idx
        - self.num_special_tokens
        """
        pass
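
    # Illustrative sketch of a subclass hook (hypothetical token shapes and
    # counts; each mode defines its own). The [1, 2, N, C] layout matches what
    # slice_expand_and_flatten expects:
    #
    #   def _setup_special_tokens(self):
    #       self.camera_token = nn.Parameter(torch.randn(1, 2, 1, self.embed_dim))
    #       self.register_token = nn.Parameter(
    #           torch.randn(1, 2, self.num_register_tokens, self.embed_dim))
    #       self.num_special_tokens = 1 + self.num_register_tokens
    #       self.patch_start_idx = self.num_special_tokens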

    def _init_blocks_from_dino(self, dino_ckpt: dict):
        """
        Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.

        Args:
            dino_ckpt: Checkpoint dictionary from DINOv2 model
        """
        logger.info("Initializing blocks from DINOv2 pretrained weights")

        # Extract block keys
        dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
        if not dino_block_keys:
            logger.warning("No 'blocks' found in DINO checkpoint")
            return

        # Get block indices
        block_indices = set()
        for key in dino_block_keys:
            parts = key.split('.')
            if len(parts) > 1 and parts[1].isdigit():
                block_indices.add(int(parts[1]))

        num_dino_blocks = len(block_indices)
        print(f"Found {num_dino_blocks} blocks in DINO checkpoint")

        # Initialize frame_blocks
        for i, block in enumerate(self.frame_blocks):
            dino_block_idx = i % num_dino_blocks
            block_state_dict = {}
            prefix = f'blocks.{dino_block_idx}.'
            for key, value in dino_ckpt.items():
                if key.startswith(prefix):
                    new_key = key[len(prefix):]
                    block_state_dict[new_key] = value

            if block_state_dict:
                missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for first block to avoid spam
                    print(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

        # Initialize global_blocks
        for i, block in enumerate(self.global_blocks):
            dino_block_idx = i % num_dino_blocks
            block_state_dict = {}
            prefix = f'blocks.{dino_block_idx}.'
            for key, value in dino_ckpt.items():
                if key.startswith(prefix):
                    new_key = key[len(prefix):]
                    block_state_dict[new_key] = value

            if block_state_dict:
                missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for first block to avoid spam
                    print(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

        logger.info("Successfully initialized blocks from DINOv2 weights")

    def _embed_images(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int, int]:
        """
        Embed images and prepare for attention processing.

        Handles:
        - Image normalization
        - Patch embedding
        - Special token concatenation
        - Position embedding

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)

        Returns:
            (tokens, B, S_local, S_global, P, C):
                tokens: Embedded tokens [B*S, P, C]
                B: Batch size
                S_local: Local sequence length
                S_global: Global sequence length (equal to S_local; no CP slicing)
                P: Number of tokens per frame (patches + special tokens)
                C: Embedding dimension
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images
        images = (images - self._resnet_mean) / self._resnet_std

        # No CP slicing: S_local == S_global
        S_local = S
        S_global = S

        # Reshape for patch embedding [B*S, C, H, W]
        images = images.view(B * S, C_in, H, W)

        # Patch embedding
        patch_tokens = self.patch_embed(images)
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P_patch, C = patch_tokens.shape

        # Prepare special tokens
        special_tokens = self._prepare_special_tokens(
            B, S_local, S_global, C,
            num_frame_for_scale=num_frame_for_scale
        )

        # Concatenate special tokens + patch tokens
        tokens = torch.cat([special_tokens, patch_tokens], dim=1)

        _, P, C = tokens.shape

        return tokens, B, S_local, S_global, P, C
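
    # Shape example (the exact special-token count is subclass-specific): with
    # the default 518x518 input and patch_size=14, the patch grid is 37x37 =
    # 1369 patch tokens per frame, so P = patch_start_idx + 1369 and the
    # returned tokens have shape [B*S, P, embed_dim].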

    @abstractmethod
    def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
        """
        Prepare special tokens (camera, register, optionally scale).

        Subclasses implement mode-specific token preparation.

        Args:
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            C: Embedding dimension
            **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)

        Returns:
            Special tokens [B*S, N_special, C]
        """
        pass

    def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
        """
        Get 2D position embeddings for RoPE.

        Args:
            B: Batch size
            S: Sequence length
            H: Image height
            W: Image width
            device: Device to create positions on

        Returns:
            Position tensor [B*S, P, 2] or None if rope is disabled
        """
        if self.rope is None:
            return None

        # Get patch positions
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)

        # Add offset for patch tokens (skip special tokens at pos=0)
        if self.patch_start_idx > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
            pos = torch.cat([pos_special, pos], dim=1)

        return pos
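
    # Layout example for the positions above (assuming PositionGetter yields
    # 0-indexed (row, col) grid coordinates): special tokens sit at (0, 0) and
    # patch positions are shifted by +1, so a 37x37 grid yields per-frame
    # positions [(0, 0)] * patch_start_idx followed by (1, 1) ... (37, 37),
    # keeping the special tokens off the rotary grid.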

    def _process_frame_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S: int,
        P: int,
        C: int,
        frame_idx: int,
        pos: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process frame attention blocks.

        Frame attention operates independently per frame (no cross-frame communication).
        Tokens stay in shape [B*S, P, C].

        Args:
            tokens: Input tokens [B*S, P, C]
            B: Batch size
            S: Sequence length
            P: Tokens per frame
            C: Embedding dimension
            frame_idx: Current frame block index
            pos: Position embeddings [B*S, P, 2]

        Returns:
            (tokens, frame_idx, intermediates):
                tokens: Output tokens [B*S, P, C]
                frame_idx: Updated frame block index
                intermediates: List of intermediate outputs [B, S, P, C]
        """
        # Ensure correct shape
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B * S, P, 2)

        intermediates = []

        # Process blocks
        for i in range(self.aa_block_size):
            if self.training and self.use_gradient_checkpoint:
                from torch.utils.checkpoint import checkpoint
                tokens = checkpoint(
                    self.frame_blocks[frame_idx],
                    tokens,
                    pos,
                    False,  # enable_ulysses_cp (always False)
                    use_reentrant=self.use_reentrant
                )
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    @abstractmethod
    def _process_global_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process global (cross-frame) attention blocks.

        Subclasses implement mode-specific attention logic.

        Args:
            tokens: Input tokens
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Tokens per frame
            C: Embedding dimension
            global_idx: Current global block index
            pos: Position embeddings
            **kwargs: Mode-specific parameters

        Returns:
            (tokens, global_idx, intermediates):
                tokens: Output tokens
                global_idx: Updated global block index
                intermediates: List of intermediate outputs
        """
        pass

    def forward(
        self,
        images: torch.Tensor,
        selected_idx: Optional[List[int]] = None,
        # Mode-specific parameters
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Forward pass.

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            selected_idx: Which block indices to output (None = all)
            num_frame_for_scale: Number of frames for scale estimation (causal mode)
            sliding_window_size: Sliding window size in blocks (causal mode)
            num_frame_per_block: Number of frames per processing block (causal mode)

        Returns:
            (output_list, patch_start_idx):
                output_list: List of block outputs [B, S, P, 2C]
                patch_start_idx: Index where patch tokens start
        """
        B, S_input, _, H, W = images.shape

        # Embed images
        tokens, B, S_local, S_global, P, C = self._embed_images(
            images,
            num_frame_for_scale=num_frame_for_scale,
        )

        # Get position embeddings
        pos_local = self._get_positions(B, S_local, H, W, device=images.device)
        pos_global = self._get_positions(B, S_global, H, W, device=images.device)

        # Alternating attention
        frame_idx = 0
        global_idx = 0
        output_list = []

        for block_group_idx in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S_local, P, C, frame_idx, pos=pos_local
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S_local, S_global, P, C, global_idx,
                        pos=pos_global,
                        num_frame_for_scale=num_frame_for_scale,
                        sliding_window_size=sliding_window_size,
                        num_frame_per_block=num_frame_per_block,
                        image_height=H,
                        image_width=W,
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # Collect outputs
            if selected_idx is None or block_group_idx in selected_idx:
                for i in range(len(frame_intermediates)):
                    # Concatenate frame and global intermediates [B, S, P, 2C]
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)

        return output_list, self.patch_start_idx
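
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A concrete subclass supplies the abstract
# hooks; `DenseAggregator`, the Block kwargs, and the sizes below are
# hypothetical.
# ---------------------------------------------------------------------------
#
#   class DenseAggregator(AggregatorBase):
#       def _build_blocks(self, block_fn, depth, embed_dim, num_heads, mlp_ratio,
#                         qkv_bias, proj_bias, ffn_bias, init_values, qk_norm):
#           self.frame_blocks = nn.ModuleList(
#               block_fn(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                        qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias,
#                        init_values=init_values, qk_norm=qk_norm, rope=self.rope)
#               for _ in range(depth))
#           self.global_blocks = nn.ModuleList(...)  # same construction
#       # ... plus _setup_special_tokens, _prepare_special_tokens,
#       #     and _process_global_attention
#
#   model = DenseAggregator(img_size=518, patch_size=14, embed_dim=1024, depth=24)
#   images = torch.rand(1, 8, 3, 518, 518)          # [B, S, 3, H, W] in [0, 1]
#   output_list, patch_start_idx = model(images)
#   # With selected_idx=None and aa_block_size=1, output_list holds
#   # aa_block_num * aa_block_size = depth tensors of shape [B, S, P, 2 * embed_dim].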