"""

FlashInfer KV Cache Manager: Two-Stream Paged Design.

Two logical streams sharing one physical page pool per layer:

  Patch stream (recyclable):
    - page_size = patches_per_frame  (256 for 224×224; 972 for 504×378)
    - Exactly 1 patch page per frame
    - Scale frames  → scale_patch_pages (never evicted; capped at scale_frames)
    - Recent frames → live_window_patch_pages (evicted when > sliding_window)

  Special stream (append-only, never recycled):
    - num_special_tokens (6) special tokens per frame
    - Packed contiguously: one special page holds floor(page_size/6) frames,
      e.g. page_size=256 → 42 frames per special page, 4 slots wasted
    - Specials are written for EVERY frame (including scale + window frames),
      not just evicted ones.

Physical layout per block:
    kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
      Pages 0 .. max_patch_pages-1        : patch page pool (recyclable)
      Pages max_patch_pages .. max_pages-1: special page pool (append-only)
      dim 1: 0=K  1=V

Attention computation:
    visible = scale_patch_pages + live_window_patch_pages + all_special_pages
    Special pages are placed LAST, so paged_kv_last_page_len naturally
    describes the partial special tail without a custom mask.

    plan() is called ONCE per frame step (when block_idx == 0).
    run() is called per layer, reusing the same plan.  All layers at the
    same frame step have identical page structures (same page IDs in the same
    positions), so reusing the plan across layers is correct.

Public API is drop-in compatible with the previous FlashInferKVCacheManager:
    append_frame(block_idx, k, v)
    evict_frames(block_idx, scale_frames, sliding_window, ...)
    compute_attention(block_idx, q) -> out
    reset()
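
The pool-sizing and packing arithmetic above can be traced numerically.
A minimal sketch for the 224×224 configuration; the constants mirror the
defaults described in this docstring:

```python
import math

tokens_per_frame, num_special = 262, 6
patches_per_frame = tokens_per_frame - num_special        # 256
page_size = patches_per_frame                             # exact (FA2 backend)

# Special-stream packing: floor(page_size / 6) frames per special page.
frames_per_special_page = page_size // num_special        # 42
wasted_slots = page_size - frames_per_special_page * num_special  # 4

# Patch pool: scale + window + headroom; special pool sized from max_total_frames.
max_patch_pages = 8 + 64 + 16                             # 88
max_special_pages = math.ceil(2048 * num_special / page_size) + 16  # 48 + 16

print(frames_per_special_page, wasted_slots, max_patch_pages, max_special_pages)
```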

"""

import collections
import math
from typing import List

import torch
from torch import Tensor

try:
    import flashinfer
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False


class FlashInferKVCacheManager:
    """

    Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).



    Args:

        num_blocks:          Number of Transformer blocks (one cache per block).

        max_num_frames:      Maximum frames held in the KV window at once

                             (scale_frames + sliding_window + headroom).

        tokens_per_frame:    Total tokens per frame = patches + specials (e.g. 262).

        num_heads:           Number of KV heads (= QO heads; MHA assumed).

        head_dim:            Head dimension (64 for ViT-L).

        dtype:               Storage dtype (bfloat16 / float16).

        device:              CUDA device.

        num_special_tokens:  Special tokens per frame: camera + register×N + scale (6).

        scale_frames:        Number of always-resident scale frames (8).

        sliding_window:      Sliding window size (64).

        max_total_frames:    Upper bound on total frames ever processed; used to

                             pre-allocate the special page pool (default 2048).

    """

    def __init__(
        self,
        num_blocks: int,
        max_num_frames: int,
        tokens_per_frame: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        num_special_tokens: int = 6,
        scale_frames: int = 8,
        sliding_window: int = 64,
        max_total_frames: int = 2048,
        force_fp32: bool = False,
        fa3: bool = False,
    ):
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        self.num_blocks = num_blocks
        self.num_special_tokens = num_special_tokens         # 6
        self.patches_per_frame = tokens_per_frame - num_special_tokens  # 256 / 999 / ...
        # Use exact page_size = patches_per_frame to eliminate zero-padded slots.
        # FA2 (backend="fa2") supports non-power-of-2 page sizes.
        # FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
        p = self.patches_per_frame
        if fa3:
            # Round up to next power-of-2 for FA3 SM90 kernel requirement.
            # e.g. 999 → 1024 (25 zero-padded slots per patch page)
            self.page_size = 1 << (p - 1).bit_length()
        else:
            self.page_size = p  # exact: no zero padding in patch pages
        self.scale_frames = scale_frames                     # 8
        self.sliding_window = sliding_window                 # 64
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.tokens_per_frame = tokens_per_frame

        assert self.patches_per_frame > 0, (
            f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
        )
        assert self.page_size > 0

        # force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
        # instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
        # in fp32 for accuracy comparison.  Storage dtype is also kept as fp32 in this mode.
        self.force_fp32 = force_fp32
        if force_fp32:
            self.dtype = torch.float32
        else:
            if dtype == torch.float32:
                dtype = torch.bfloat16
            self.dtype = dtype
        self.device = device

        # ── Page pool sizing ─────────────────────────────────────────────────
        # Patch: scale + window + 16 headroom  (pages recycled → fixed count)
        max_patch_pages = scale_frames + sliding_window + 16   # e.g. 88
        # Special: enough for max_total_frames × 6 tokens, plus 16 headroom
        max_special_pages = (
            math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
        )
        self.max_patch_pages = max_patch_pages
        self.max_num_pages = max_patch_pages + max_special_pages

        # ── Physical paged KV caches ─────────────────────────────────────────
        # Shape per block: [max_num_pages, 2, page_size, H, D]   (NHD, K=dim0, V=dim1)
        self.kv_caches: List[Tensor] = [
            torch.zeros(
                self.max_num_pages, 2, self.page_size, num_heads, head_dim,
                dtype=self.dtype, device=device,
            )
            for _ in range(num_blocks)
        ]

        # ── Per-block state ──────────────────────────────────────────────────
        # Patch pages (IDs 0 .. max_patch_pages-1)
        self.scale_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.live_window_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.free_patch_pages: List[List[int]] = [
            list(range(max_patch_pages)) for _ in range(num_blocks)
        ]

        # Special pages (IDs max_patch_pages .. max_num_pages-1)
        self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
        self.free_special_pages: List[List[int]] = [
            list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
        ]
        self.special_token_count: List[int] = [0] * num_blocks

        # Frame counter per block (determines scale vs window routing)
        self.frame_count: List[int] = [0] * num_blocks

        # Deferred eviction support for flow-based keyframe selection.
        # When True, evict_frames() becomes a no-op; caller must later call
        # execute_deferred_eviction() or rollback_last_frame().
        self._defer_eviction: bool = False

        # ── FlashInfer wrapper ───────────────────────────────────────────────
        # plan() is called once per frame step (block_idx == 0).
        # run() is called per layer, reusing the same aux structures.
        # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
        # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
        # FlashInfer 0.2.5 at 518×378 resolution.
        _fi_backend = "fa3" if fa3 else "fa2"
        self.workspace_buffer = torch.zeros(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            kv_layout="NHD",
            backend=_fi_backend,
        )

        # plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
        self._qo_indptr = torch.tensor(
            [0, tokens_per_frame], dtype=torch.int32, device=device
        )

    # =========================================================================
    # Public API  (drop-in compatible with previous FlashInferKVCacheManager)
    # =========================================================================

    def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
        """

        Append one frame's K/V tensors to the two-stream cache.



        Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]

        i.e. specials come first (matching stream.py's patch_start_idx convention).



        Args:

            block_idx: Block/layer index (0 … num_blocks-1).

            k: [tokens_per_frame, H, D]  NHD layout.

            v: [tokens_per_frame, H, D]  NHD layout.

        """
        n = self.num_special_tokens  # 6
        sp_k    = k[:n].to(self.dtype)      # [6,   H, D]
        patch_k = k[n:].to(self.dtype)     # [256, H, D]
        sp_v    = v[:n].to(self.dtype)
        patch_v = v[n:].to(self.dtype)

        assert patch_k.shape[0] == self.patches_per_frame, (
            f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
            f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
        )

        self._write_patch_page(block_idx, patch_k, patch_v)
        self._write_special_tokens(block_idx, sp_k, sp_v)
        self.frame_count[block_idx] += 1

    def evict_frames(
        self,
        block_idx: int,
        scale_frames: int,
        sliding_window: int,
        cross_frame_special: bool = True,
        include_scale_frames: bool = True,
        camera_only: bool = False,
        num_register_tokens: int = 4,
    ) -> None:
        """

        Evict old window patch pages (recycle to free list).



        Special pages are NEVER evicted.

        Scale pages are NEVER evicted.

        Only live_window_patch_pages beyond `sliding_window` are recycled.



        When ``_defer_eviction`` is True, this method is a no-op.  The caller

        is expected to later call ``execute_deferred_eviction()`` (keep frame)

        or ``rollback_last_frame()`` (discard frame).

        """
        if self._defer_eviction:
            return
        while len(self.live_window_patch_pages[block_idx]) > sliding_window:
            old_page = self.live_window_patch_pages[block_idx].popleft()
            self.free_patch_pages[block_idx].append(old_page)

    def execute_deferred_eviction(
        self,
        block_idx: int,
        scale_frames: int,
        sliding_window: int,
        **kwargs,
    ) -> None:
        """Run the eviction that was skipped while ``_defer_eviction`` was True."""
        while len(self.live_window_patch_pages[block_idx]) > sliding_window:
            old_page = self.live_window_patch_pages[block_idx].popleft()
            self.free_patch_pages[block_idx].append(old_page)

    def rollback_last_frame(self, block_idx: int) -> None:
        """Undo the most recent ``append_frame()`` for *block_idx*.



        This reverses all three sub-operations of ``append_frame``:

        patch page allocation, special-token write, and frame_count increment.

        It must be called **before** any eviction for that frame (i.e. while

        ``_defer_eviction`` is True or before ``evict_frames`` is called).
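
Step 2 can be checked with back-of-envelope arithmetic. A standalone
sketch, assuming the default page_size=256 and 6 specials per frame:

```python
import math

page_size, n = 256, 6
special_token_count = 258                 # e.g. 43 frames x 6 specials
num_special_pages = math.ceil(special_token_count / page_size)            # 2 pages

new_count = special_token_count - n       # undo one frame's specials: 252
new_num_pages = math.ceil(new_count / page_size) if new_count > 0 else 0  # 1 page
pages_freed = num_special_pages - new_num_pages   # the tail page gets recycled

print(new_count, new_num_pages, pages_freed)
```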

        """
        assert self.frame_count[block_idx] > 0, (
            f"block {block_idx}: cannot rollback, frame_count is 0"
        )

        # 1) Undo patch page ── pop from whichever deque it was routed to.
        if self.frame_count[block_idx] > self.scale_frames:
            page_id = self.live_window_patch_pages[block_idx].pop()
        else:
            page_id = self.scale_patch_pages[block_idx].pop()
        self.free_patch_pages[block_idx].append(page_id)

        # 2) Undo special tokens
        n = self.num_special_tokens
        new_count = self.special_token_count[block_idx] - n
        assert new_count >= 0, (
            f"block {block_idx}: special_token_count underflow "
            f"({self.special_token_count[block_idx]} - {n})"
        )
        new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0
        while len(self.all_special_pages[block_idx]) > new_num_pages:
            freed = self.all_special_pages[block_idx].pop()
            self.free_special_pages[block_idx].append(freed)
        self.special_token_count[block_idx] = new_count

        # 3) Decrement frame count
        self.frame_count[block_idx] -= 1

    def _gather_kv(self, block_idx: int):
        """

        Gather all visible K and V tokens from the paged cache into dense tensors.



        Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only

        supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.



        Returns:

            k_flat: [kv_len, H, D]  — all visible K tokens concatenated

            v_flat: [kv_len, H, D]  — all visible V tokens concatenated

        """
        visible  = self.build_visible_page_table(block_idx)
        last_len = self.compute_last_page_len(block_idx)
        P = self.page_size

        parts_k, parts_v = [], []
        for i, pid in enumerate(visible):
            n = last_len if (i == len(visible) - 1) else P
            parts_k.append(self.kv_caches[block_idx][pid, 0, :n])  # [n, H, D]
            parts_v.append(self.kv_caches[block_idx][pid, 1, :n])

        k_flat = torch.cat(parts_k, dim=0)  # [kv_len, H, D]
        v_flat = torch.cat(parts_v, dim=0)
        return k_flat, v_flat

    def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
        """

        Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.



        When self.force_fp32 is True, gathers all visible K/V into dense tensors

        and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.

        This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.



        plan() is called once per frame step (when block_idx == 0).

        All layers at the same step share the same visible page structure,

        so the plan is reused by calling run() with each layer's kv_cache.



        Args:

            block_idx: Block/layer index.

            q: [q_len, H, D]  NHD layout (q_len = tokens_per_frame = 262).



        Returns:

            out: [q_len, H, D]

        """
        if self.frame_count[block_idx] == 0:
            # No KV present yet (should not occur in normal usage after append_frame)
            return torch.zeros_like(q)

        if self.force_fp32:
            # ── fp32 gather+SDPA path ─────────────────────────────────────────
            # Gather visible K/V from paged cache and run SDPA in fp32.
            # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
            # q_len, H, D → 1, H, q_len, D  (SDPA expects BHsD layout)
            import torch.nn.functional as F_nn
            k_flat, v_flat = self._gather_kv(block_idx)
            q_b = q.float().permute(1, 0, 2).unsqueeze(0)      # [1, H, q_len, D]
            k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
            v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
            out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
            return out.squeeze(0).permute(1, 0, 2).to(q.dtype) # [q_len, H, D]

        if block_idx == 0:
            # ── Plan once per frame step ──────────────────────────────────────
            # Build visible page table from block 0's state.
            # All blocks have identical page structures, so this plan is valid
            # for all subsequent run() calls (block_idx = 1, 2, ...).
            visible  = self.build_visible_page_table(0)
            last_len = self.compute_last_page_len(0)

            assert visible, "visible page table is empty after append_frame"
            assert 1 <= last_len <= self.page_size, (
                f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
            )

            paged_kv_indices       = torch.tensor(visible, dtype=torch.int32, device=self.device)
            paged_kv_indptr        = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
            paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)

            self.prefill_wrapper.plan(
                self._qo_indptr,
                paged_kv_indptr,
                paged_kv_indices,
                paged_kv_last_page_len,
                num_qo_heads      = self.num_heads,
                num_kv_heads      = self.num_heads,
                head_dim_qk       = self.head_dim,
                page_size         = self.page_size,
                causal            = False,          # custom page ordering; no causal mask
                pos_encoding_mode = "NONE",         # RoPE applied externally before append
                q_data_type       = self.dtype,
            )

        # ── Run attention for this layer ──────────────────────────────────────
        # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
        return self.prefill_wrapper.run(
            q              = q.to(self.dtype).contiguous(),
            paged_kv_cache = self.kv_caches[block_idx],
        )  # → [q_len, H, D]

    def reset(self) -> None:
        """Reset all per-block state for a new sequence."""
        for i in range(self.num_blocks):
            self.scale_patch_pages[i].clear()
            self.live_window_patch_pages[i].clear()
            self.all_special_pages[i].clear()
            self.free_patch_pages[i]   = list(range(self.max_patch_pages))
            self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
            self.special_token_count[i] = 0
            self.frame_count[i] = 0

    # =========================================================================
    # Helper methods
    # =========================================================================

    def build_visible_page_table(self, block_idx: int) -> List[int]:
        """

        Return page IDs in strict order: scale → window → special.



        Placing special pages last means only the final page may be partially

        full, so paged_kv_last_page_len = compute_last_page_len() is sufficient

        without a custom attention mask.
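
The ordering can be illustrated with hypothetical page IDs, assuming
max_patch_pages = 88 (so special-page IDs start at 88):

```python
scale   = [0, 1, 2]        # scale patch pages (never evicted)
window  = [5, 6, 7]        # live sliding-window patch pages
special = [88, 89]         # special pages, always last
visible = scale + window + special

assert visible == [0, 1, 2, 5, 6, 7, 88, 89]
# Only visible[-1] may be partially filled, which is exactly the page
# that paged_kv_last_page_len describes.
```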

        """
        return (
            list(self.scale_patch_pages[block_idx])       +
            list(self.live_window_patch_pages[block_idx]) +
            list(self.all_special_pages[block_idx])
        )

    def compute_last_page_len(self, block_idx: int) -> int:
        """

        Valid token count in the last page of the visible sequence.



        - No special pages      → last page is a patch page.

                                  Returns patches_per_frame (real tokens written),

                                  which may be < page_size when page_size was rounded

                                  up to a power of 2.

        - Special tail partial  → special_token_count % page_size.

        - Special tail exactly full → page_size.
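
Once special pages exist, the last two cases reduce to one modular
expression. A sketch assuming the default page_size=256 and 6 specials
per frame:

```python
page_size, n = 256, 6
for frames, expected in [(1, 6), (100, 88), (128, 256)]:
    count = frames * n                          # total special tokens written
    tail = count % page_size
    last_len = page_size if tail == 0 else tail
    assert last_len == expected                 # 6; 600 % 256 = 88; 768 % 256 = 0, so 256
```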

        """
        if not self.all_special_pages[block_idx]:
            # Last page is a patch page.  We wrote patches_per_frame tokens (0..P-1);
            # positions P..page_size-1 are zero padding.  Tell FlashInfer the true
            # valid count so it doesn't read beyond the real tokens.
            return self.patches_per_frame

        tail = self.special_token_count[block_idx] % self.page_size
        return self.page_size if tail == 0 else tail

    # ── Internal write helpers ────────────────────────────────────────────────

    def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
        """

        Allocate one free patch page and write patches_per_frame patch tokens.



        Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids

        the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.

        kv_caches layout: [max_num_pages, 2, page_size, H, D]  (NHD, K=0, V=1).

        patch_k/v fill exactly one full page (patches_per_frame == page_size).



        Routes to scale_patch_pages if still filling scale quota,

        otherwise to live_window_patch_pages.



        Returns:

            page_id: Physical page index used.

        """
        assert self.free_patch_pages[block_idx], (
            f"block {block_idx}: patch page pool exhausted — "
            f"scale={len(self.scale_patch_pages[block_idx])}, "
            f"window={len(self.live_window_patch_pages[block_idx])}, "
            f"free={len(self.free_patch_pages[block_idx])}"
        )

        page_id = self.free_patch_pages[block_idx].pop()

        # Direct slice write: positions 0..patches_per_frame-1.
        # When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224),
        # this is equivalent to a full-page write.  When page_size > patches_per_frame
        # (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999),
        # positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init).
        P = self.patches_per_frame
        self.kv_caches[block_idx][page_id, 0, :P] = patch_k  # K
        self.kv_caches[block_idx][page_id, 1, :P] = patch_v  # V

        if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
            self.scale_patch_pages[block_idx].append(page_id)
        else:
            self.live_window_patch_pages[block_idx].append(page_id)

        return page_id

    def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
        """

        Append num_special_tokens (6) special tokens to the special stream.



        Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,

        tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch

        overhead of flashinfer.page.append_paged_kv_cache.



        Handles page-boundary crossing: if 6 tokens straddle two pages, performs

        two slice writes (rare — page_size=256 >> 6).
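
The boundary-crossing split can be traced standalone. A sketch with
page_size=256: a tail page that already holds 252 tokens forces a 4 + 2
split across two pages:

```python
page_size, n = 256, 6
special_token_count = 252             # tail page has 4 free slots left
remaining, writes = n, []
while remaining > 0:
    tail_offset = special_token_count % page_size
    # tail_offset == 0 is where the real code allocates a fresh page
    space = page_size - tail_offset
    write_n = min(remaining, space)
    writes.append(write_n)
    special_token_count += write_n
    remaining -= write_n

assert writes == [4, 2]               # two slice writes across the boundary
```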

        """
        remaining = self.num_special_tokens   # 6
        written   = 0

        while remaining > 0:
            tail_offset = self.special_token_count[block_idx] % self.page_size

            if tail_offset == 0:
                # Current tail page is full (or no page exists) — allocate a new one
                assert self.free_special_pages[block_idx], (
                    f"block {block_idx}: special page pool exhausted at "
                    f"special_token_count={self.special_token_count[block_idx]}. "
                    f"Increase max_total_frames."
                )
                new_page = self.free_special_pages[block_idx].pop()
                self.all_special_pages[block_idx].append(new_page)

            tail_page = self.all_special_pages[block_idx][-1]
            space     = self.page_size - tail_offset   # free slots in tail page
            write_n   = min(remaining, space)

            # Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
            # shape: [page_size, H, D];  slice [tail_offset:tail_offset+write_n, :, :]
            end = tail_offset + write_n
            self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
            self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]

            self.special_token_count[block_idx] += write_n
            written   += write_n
            remaining -= write_n

    # ── Legacy property (used by stream.py) ──────────────────────────────────

    @property
    def num_frames(self) -> int:
        """Number of frames appended to block 0 (representative)."""
        return self.frame_count[0] if self.frame_count else 0


# =============================================================================
# Sanity check
# =============================================================================

def _sanity_check():
    """

    Minimal smoke test.

    Run with:  python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"

    """
    if not torch.cuda.is_available():
        print("[sanity_check] CUDA not available; skipping.")
        return
    device = torch.device("cuda")

    tokens_per_frame  = 262   # 256 patch + 6 special (224×224)
    num_special       = 6
    patches_per_frame = tokens_per_frame - num_special  # 256
    page_size         = patches_per_frame               # 256

    mgr = FlashInferKVCacheManager(
        num_blocks         = 2,
        max_num_frames     = 88,
        tokens_per_frame   = tokens_per_frame,
        num_heads          = 16,
        head_dim           = 64,
        dtype              = torch.bfloat16,
        device             = device,
        num_special_tokens = num_special,
        scale_frames       = 8,
        sliding_window     = 64,
        max_total_frames   = 200,
    )

    def make_kv():
        k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        return k, v

    def make_q():
        return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)

    for block in range(2):
        for t in range(100):
            k, v = make_kv()
            mgr.append_frame(block, k, v)
            mgr.evict_frames(block, scale_frames=8, sliding_window=64)

        # ── Page count checks ───────────────────────────────────────────────
        n_scale  = len(mgr.scale_patch_pages[block])
        n_window = len(mgr.live_window_patch_pages[block])
        n_spec   = len(mgr.all_special_pages[block])
        sp_count = mgr.special_token_count[block]

        assert n_scale  == 8,  f"block {block}: scale pages = {n_scale},  expected 8"
        assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
        # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
        expected_spec_pages = math.ceil(100 * num_special / page_size)
        assert n_spec == expected_spec_pages, (
            f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
        )
        assert sp_count == 100 * num_special, (
            f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
        )

        # ── last_page_len ────────────────────────────────────────────────────
        last_len = mgr.compute_last_page_len(block)
        tail = sp_count % page_size
        expected_len = page_size if tail == 0 else tail
        assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"

        # ── visible page table order ─────────────────────────────────────────
        visible = mgr.build_visible_page_table(block)
        assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
        for pid in visible[:n_scale + n_window]:
            assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
        for pid in visible[n_scale + n_window:]:
            assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"

        # ── forward pass: plan() once for block 0, run() for both blocks ─────
        if block == 1:
            # Simulate the actual calling pattern: plan on block 0, run on both
            q0 = make_q()
            out0 = mgr.compute_attention(0, q0)   # triggers plan()
            q1 = make_q()
            out1 = mgr.compute_attention(1, q1)   # reuses plan, different kv_cache
            assert out0.shape == (tokens_per_frame, 16, 64)
            assert out1.shape == (tokens_per_frame, 16, 64)

        print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
              f"special_pages={n_spec}, special_tokens={sp_count}, "
              f"last_page_len={last_len}")

    mgr.reset()
    assert mgr.frame_count[0] == 0
    print("\n[sanity_check] All assertions passed.")


if __name__ == "__main__":
    _sanity_check()