File size: 25,938 Bytes
9dad400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
"""Inference utilities for MolmoAct2"""

from dataclasses import dataclass
from typing import Any, Iterable, Optional, Sequence, Tuple

import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig


@dataclass
class _ActionFlowInputs:
    """Bundle of inputs handed to the action-flow sampling loop.

    The helpers in this module read ``context.kv_contexts``,
    ``context.cross_mask``, ``context.self_mask``, ``context.valid_action`` and
    ``context.rope_cache``, and for every entry of ``modulations`` the
    attributes ``conditioning``, ``block_modulations`` and
    ``final_modulation``. The concrete classes behind ``context`` and the
    modulation entries are defined outside this file.
    """

    # Trajectory tensor; its device is what gates CUDA-graph use.
    trajectory: torch.Tensor
    # Context object (see the class docstring for the attributes that are read).
    context: Any
    # Sequence of modulation objects — presumably one per flow step; TODO confirm.
    modulations: Sequence[Any]
    # Optional tensor; ``None`` is accepted everywhere it is consumed here.
    action_dim_is_pad: Optional[torch.Tensor]


@dataclass
class _ActionFlowCudaGraph:
    """A captured action-flow CUDA graph together with its static buffers."""

    # Signature tuple the graph was captured for; compared to decide on reuse.
    key: Tuple[Any, ...]
    # The captured graph that gets replayed.
    graph: torch.cuda.CUDAGraph
    # Static input buffers that fresh inputs are copied into before a replay.
    static_inputs: _ActionFlowInputs
    # Value returned by the captured call; refreshed by every replay.
    output: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphLayerStage:
    """Static outputs of one captured "pre-attention" stage of a layer.

    Instances are built positionally from the 4-tuple returned by the
    pre-layer computation, so the field order below is significant.
    """

    # Block input kept for the residual connection after attention.
    residual: torch.Tensor
    # Query states after rotary embedding.
    query: torch.Tensor
    # Key states after rotary embedding (written into the KV cache).
    key: torch.Tensor
    # Value states (written into the KV cache).
    value: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphPostStage:
    """A captured "post-attention" graph and the buffer it reads from."""

    # Graph replayed after the layer's attention has been computed eagerly.
    graph: torch.cuda.CUDAGraph
    # Static buffer the eager attention result is copied into before replay.
    attn_context: torch.Tensor


@dataclass
class _DepthDecodeCudaGraph:
    """Everything needed to replay one single-token decode step."""

    # Key the graphs were built for; a mismatch triggers a rebuild.
    cache_key: Tuple[Any, ...]
    # Graph covering the token embedding and the pre-attention part of layer 0.
    pre_graph: torch.cuda.CUDAGraph
    # Static (1, 1) buffer holding the token id to decode.
    token_ids: torch.Tensor
    # Static rotary cos/sin buffers for the current position.
    cos: torch.Tensor
    sin: torch.Tensor
    # ``arange(max_cache_len)``; sliced to obtain the cache position.
    positions: torch.Tensor
    # One pre-attention stage per layer (index == layer index).
    stages: Sequence[_DepthDecodeCudaGraphLayerStage]
    # One post-attention graph per layer (index == layer index).
    post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
    # Output of the last captured graph (final hidden state after ``ln_f``).
    output: torch.Tensor


@dataclass
class _DepthDecodeCudaGraphSpec:
    """Config-derived facts about the transformer, computed once and reused."""

    # Whether the model configuration is supported by the graph decode path.
    eligible: bool
    # Config values that form the leading element of the graph cache key.
    cache_key_prefix: Tuple[Any, ...]
    # Number of transformer layers (one pre stage / post graph per layer).
    num_hidden_layers: int
    # Per-head feature size, used to shape the rotary and attention buffers.
    head_dim: int
    # Number of query heads, used to shape the attention-context buffers.
    num_attention_heads: int


def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
    """Return the number of cached positions as a plain ``int``.

    Args:
        past_key_values: Cache object exposing ``get_seq_length()``, or
            ``None`` when there is no cache yet.

    Returns:
        ``0`` for a missing cache; otherwise the reported sequence length,
        unwrapped with ``.item()`` first if it is a tensor.
    """
    if past_key_values is None:
        return 0
    length = past_key_values.get_seq_length()
    return int(length.item() if torch.is_tensor(length) else length)


def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
    """Return the cache's maximum length as a plain ``int``.

    Args:
        past_key_values: Cache object exposing ``get_max_cache_shape()``, or
            ``None`` when there is no cache.

    Returns:
        ``-1`` when there is no cache or when the cache reports ``None``
        (some cache implementations use ``None`` for "no fixed limit");
        otherwise the reported value, unwrapped with ``.item()`` first if it
        is a tensor.
    """
    if past_key_values is None:
        return -1
    max_len = past_key_values.get_max_cache_shape()
    if max_len is None:
        # ``int(None)`` would raise; fold "no limit" into the same sentinel
        # that is already used for a missing cache.
        return -1
    if torch.is_tensor(max_len):
        return int(max_len.item())
    return int(max_len)


def _iter_cache_key_values(
    past_key_values: Cache,
) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
    """Yield one ``(keys, values)`` pair per layer of a cache.

    Two layouts are understood:

    * an object with a ``layers`` attribute that is not ``None``: each layer's
      ``keys`` / ``values`` attributes are yielded, with ``None`` standing in
      for an attribute the layer does not have;
    * anything else is iterated directly and each item is indexed with
      ``[0]`` and ``[1]``.
    """
    layer_objects = getattr(past_key_values, "layers", None)
    if layer_objects is None:
        # Fallback layout: the cache itself iterates over (key, value) items.
        for entry in past_key_values:
            yield entry[0], entry[1]
    else:
        for layer_obj in layer_objects:
            layer_keys = getattr(layer_obj, "keys", None)
            layer_values = getattr(layer_obj, "values", None)
            yield layer_keys, layer_values


class _DepthDecodeStaticLayerCache:
    """Key/value store for one layer, backed by fixed-size buffers.

    The buffers are allocated on the first ``update`` call (shape, dtype and
    device are taken from that call's tensors) and new states are appended
    along the sequence axis (dim ``-2``) at ``cumulative_length``.
    """

    is_compileable = False
    is_sliding = False

    def __init__(self, max_cache_len: int) -> None:
        """Create an empty layer cache able to hold ``max_cache_len`` positions."""
        self.max_cache_len = int(max_cache_len)
        self.cumulative_length = 0
        self.keys: Optional[torch.Tensor] = None
        self.values: Optional[torch.Tensor] = None

    def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Allocate uninitialised key/value buffers matching the given states."""
        batch_size, num_heads = key_states.shape[:2]

        def _buffer_like(states: torch.Tensor) -> torch.Tensor:
            # Batch and head counts always come from ``key_states``; only the
            # last dimension, dtype and device follow ``states``.
            shape = (batch_size, num_heads, self.max_cache_len, states.shape[-1])
            return torch.empty(shape, dtype=states.dtype, device=states.device)

        self.keys = _buffer_like(key_states)
        self.values = _buffer_like(value_states)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append new states and return views over everything stored so far.

        Extra positional/keyword arguments are accepted and ignored; the
        write offset is always ``cumulative_length``.

        Raises:
            RuntimeError: if the append would exceed ``max_cache_len``.
        """
        if self.keys is None:
            self._allocate(key_states, value_states)

        write_from = self.cumulative_length
        write_to = write_from + key_states.shape[-2]
        if write_to > self.max_cache_len:
            raise RuntimeError(
                f"KV cache length {write_to} exceeds max_cache_len={self.max_cache_len}."
            )

        window = slice(write_from, write_to)
        self.keys[:, :, window, :].copy_(key_states)
        self.values[:, :, window, :].copy_(value_states)
        self.cumulative_length = write_to

        filled = slice(None, write_to)
        return self.keys[:, :, filled, :], self.values[:, :, filled, :]

    def get_seq_length(self) -> int:
        """Return how many positions have been written."""
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Return ``-1``.

        NOTE(review): the buffers are bounded by ``max_cache_len``, yet this
        reports ``-1`` — presumably so callers treat the cache as unbounded;
        confirm against the callers.
        """
        return -1

    def reset(self) -> None:
        """Rewind the write offset to zero; allocated buffers are kept."""
        self.cumulative_length = 0


class _DepthDecodeStaticCache(Cache):
    """``Cache`` made of one ``_DepthDecodeStaticLayerCache`` per decoder layer."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
        """Build the per-layer caches.

        Args:
            config: Model config; its decoder text config supplies
                ``num_hidden_layers``.
            max_cache_len: Capacity (in positions) given to every layer cache.
        """
        decoder_config = config.get_text_config(decoder=True)
        layer_caches = [
            _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
            for _ in range(decoder_config.num_hidden_layers)
        ]
        super().__init__(layers=layer_caches)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the number of positions stored in layer ``layer_idx``."""
        layer_cache = self.layers[layer_idx]
        return layer_cache.get_seq_length()

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Return whatever layer ``layer_idx`` reports as its maximum shape."""
        layer_cache = self.layers[layer_idx]
        return layer_cache.get_max_cache_shape()

    def reset(self) -> None:
        """Reset every layer cache."""
        for layer_cache in self.layers:
            layer_cache.reset()


class ActionCudaGraphManager:
    """Captures the action-flow sampling loop in a CUDA graph and replays it.

    Only one captured graph is kept. It is re-captured whenever the signature
    of the inputs (see ``_cuda_graph_key``) differs from the one it was
    captured for.
    """

    def __init__(self, model: Any) -> None:
        """Store the model; graph use starts enabled with nothing captured."""
        self.model = model
        self.enabled = True
        self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None

    def set_enabled(self, enabled: bool) -> None:
        """Turn CUDA-graph use on or off."""
        self.enabled = bool(enabled)

    def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
        """Tell whether ``inputs`` may go through the CUDA-graph path.

        Requires the manager to be enabled, both the model and its action
        expert to be in eval mode, and every tensor reachable from ``inputs``
        to live on a CUDA device.
        """
        if not self.enabled:
            return False
        model = self.model
        if model.training or model._require_action_expert().training:
            return False
        if inputs.trajectory.device.type != "cuda":
            return False

        def _graph_tensors():
            # Lazily enumerate every tensor the captured loop would read.
            yield inputs.trajectory
            for key_tensor, value_tensor in inputs.context.kv_contexts:
                yield key_tensor
                yield value_tensor
            optional_tensors = (
                inputs.context.cross_mask,
                inputs.context.self_mask,
                inputs.context.valid_action,
                inputs.action_dim_is_pad,
            )
            for candidate in optional_tensors:
                if candidate is not None:
                    yield candidate
            if inputs.context.rope_cache is not None:
                yield from inputs.context.rope_cache
            for step in inputs.modulations:
                yield step.conditioning
                for block_modulation in step.block_modulations:
                    yield from block_modulation
                yield from step.final_modulation

        for tensor in _graph_tensors():
            if tensor.device.type != "cuda":
                return False
        return True

    def run_action_flow(
        self,
        inputs: _ActionFlowInputs,
        steps: int,
        run_loop,
    ) -> torch.Tensor:
        """Run ``run_loop(inputs, steps)`` through a (possibly cached) CUDA graph.

        Args:
            inputs: Inputs for this call.
            steps: Step count forwarded to ``run_loop``; part of the graph key.
            run_loop: Callable invoked as ``run_loop(static_inputs, steps)``
                during warm-up and capture.

        Returns:
            A clone of the graph's output buffer, so the result survives
            later replays.

        NOTE(review): on a cache hit only the trajectory, the context and
        ``action_dim_is_pad`` are refreshed (see ``_copy_inputs_``); the
        modulation tensors keep the values they had at capture time.
        Presumably they depend only on the step schedule — confirm.
        """
        graph_key = _cuda_graph_key(inputs, steps)
        entry = self.action_flow_graph
        needs_capture = entry is None or entry.key != graph_key

        if needs_capture:
            static_inputs = _clone_static_inputs(inputs)

            def _captured_call():
                return run_loop(static_inputs, steps)

            def _restore_trajectory():
                # Put the caller's trajectory back into the static buffer
                # between the warm-up run and the capture.
                static_inputs.trajectory.copy_(inputs.trajectory)

            graph, output = _capture_cuda_graph(
                _captured_call,
                inputs.trajectory.device,
                after_warmup=_restore_trajectory,
            )
            entry = _ActionFlowCudaGraph(
                key=graph_key,
                graph=graph,
                static_inputs=static_inputs,
                output=output,
            )
            self.action_flow_graph = entry
        else:
            _copy_inputs_(entry.static_inputs, inputs)

        entry.graph.replay()
        return entry.output.clone()


class DepthDecodeCudaGraphManager:
    """Single-token decode path built from a chain of captured CUDA graphs.

    Each transformer layer is split around its attention call. Everything
    before attention (norm, fused QKV projection, optional QK norm, rotary
    embedding) and everything after it (output projection, MLP, residuals) is
    captured in CUDA graphs. The KV-cache update and
    ``scaled_dot_product_attention`` run eagerly between two replays. The
    chain is: one "pre" graph for layer 0, one graph per layer boundary
    (post of layer ``i`` fused with pre of layer ``i + 1``), and one final
    graph (post of the last layer followed by ``ln_f``).
    """

    def __init__(self, model: Any) -> None:
        """Store the model and start enabled with no captured graph."""
        self.model = model
        # ``model.model`` is the backbone whose ``transformer`` is replayed.
        self.backbone = model.model
        self.enabled = True
        # Most recently built graph chain; only one is kept at a time.
        self.graph: Optional[_DepthDecodeCudaGraph] = None
        # Filled lazily by ``_depth_decode_spec``.
        self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None

    def set_enabled(self, enabled: bool) -> None:
        """Turn the CUDA-graph decode path on or off."""
        self.enabled = bool(enabled)

    def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
        """Create the static KV cache type that ``can_use`` requires."""
        return _DepthDecodeStaticCache(
            config=self.model.config.text_config,
            max_cache_len=max_cache_len,
        )

    def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
        """Return the config-derived spec, computing it on first use."""
        static = self.graph_spec
        if static is None:
            cfg = self.backbone.transformer.config
            rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
            static = _DepthDecodeCudaGraphSpec(
                # The hand-written layer code below only implements the
                # pre-norm, default-RoPE, SDPA configuration.
                eligible=(
                    not cfg.norm_after
                    and cfg.rope_scaling_layers is None
                    and getattr(rotary_emb, "rope_type", None) == "default"
                    and cfg._attn_implementation == "sdpa"
                ),
                cache_key_prefix=(
                    cfg.hidden_size,
                    cfg.num_attention_heads,
                    cfg.num_key_value_heads,
                    cfg.head_dim,
                    cfg.num_hidden_layers,
                    cfg.use_qk_norm,
                    cfg.qk_norm_type,
                    cfg._attn_implementation,
                ),
                num_hidden_layers=cfg.num_hidden_layers,
                head_dim=cfg.head_dim,
                num_attention_heads=cfg.num_attention_heads,
            )
            self.graph_spec = static
        return static

    def can_use(
        self,
        next_input_ids: torch.Tensor,
        *,
        past_key_values: Cache,
        attention_bias: torch.Tensor,
    ) -> bool:
        """Tell whether this decode step may go through the graph path.

        All of the following must hold: the manager is enabled; neither the
        model nor the transformer is in training mode; ``next_input_ids`` is
        a CUDA tensor of shape ``(1, 1)``; ``past_key_values`` is a
        ``_DepthDecodeStaticCache``; ``attention_bias`` is a tensor on the
        same device as ``next_input_ids``; and the model configuration is
        eligible (see ``_depth_decode_spec``).
        """
        if (
            not self.enabled
            or self.model.training
            or self.backbone.transformer.training
        ):
            return False
        if next_input_ids.device.type != "cuda":
            return False
        # Only batch size 1 with a single new token is supported.
        if (
            next_input_ids.ndim != 2
            or next_input_ids.shape[0] != 1
            or next_input_ids.shape[1] != 1
        ):
            return False
        if not isinstance(past_key_values, _DepthDecodeStaticCache):
            return False
        if (
            not torch.is_tensor(attention_bias)
            or attention_bias.device != next_input_ids.device
        ):
            return False
        return self._depth_decode_spec().eligible

    def _depth_decode_key(
        self,
        next_input_ids: torch.Tensor,
        attention_bias: torch.Tensor,
    ) -> Tuple[Any, ...]:
        """Build the key that decides whether the captured graphs can be reused.

        The key combines the config prefix, the device, the ``lm_head``
        weight dtype and the last dimension of ``attention_bias`` (which is
        used as the maximum cache length when building the graphs).
        """
        device = next_input_ids.device
        return (
            self._depth_decode_spec().cache_key_prefix,
            device.type,
            device.index,
            self.model.lm_head.weight.dtype,
            attention_bias.shape[-1],
        )

    def _select_depth_decode_rope(
        self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
    ) -> None:
        """Copy the rotary cos/sin entries for position ``past_length`` in place.

        The values are read from the rotary embedding's ``_pos_cos_cache`` /
        ``_pos_sin_cache`` and written into the static ``cos`` / ``sin``
        buffers so the captured graphs see them on the next replay.
        """
        emb = self.backbone.transformer.rotary_emb
        cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
        sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])

    def _depth_decode_pre_layer(
        self,
        layer_idx: int,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute everything in layer ``layer_idx`` that precedes attention.

        Returns:
            ``(residual, query, key, value)`` where ``residual`` is the
            un-normalised block input, and query/key/value are laid out with
            the head axis at dim 1; rotary embedding has been applied to
            query and key.
        """
        block = self.backbone.transformer.blocks[layer_idx]
        attention = block.self_attn
        residual = hidden_states
        hidden_states = block.attn_norm(hidden_states)

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attention.head_dim)
        # Fused projection; ``fused_dims`` gives the q/k/v split sizes.
        qkv = attention.att_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
        value_states = value_states.view(hidden_shape)

        # QK norm is applied either on the flat projection or, for the
        # "qwen3" variant, after reshaping into heads.
        apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
        norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"

        if apply_qk_norm and not norm_after_view:
            query_states = attention.q_norm(query_states)
            key_states = attention.k_norm(key_states)

        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)

        if norm_after_view:
            query_states = attention.q_norm(query_states)
            key_states = attention.k_norm(key_states)

        # Move the head axis in front of the sequence axis.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        query_states, key_states = _apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        return residual, query_states, key_states, value_states

    def _depth_decode_pre0(
        self,
        token_ids: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed ``token_ids`` and run the pre-attention part of layer 0."""
        inputs_embeds = self.model._embed_base_tokens(token_ids)
        return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)

    def _depth_decode_post_layer(
        self,
        layer_idx: int,
        residual: torch.Tensor,
        attn_context: torch.Tensor,
    ) -> torch.Tensor:
        """Compute everything in layer ``layer_idx`` that follows attention.

        Applies the attention output projection, the first residual add, the
        feed-forward norm and MLP, and the second residual add.
        """
        block = self.backbone.transformer.blocks[layer_idx]
        attention = block.self_attn
        input_shape = residual.shape[:-1]
        # Merge the head and head_dim axes back into one feature axis.
        attn_output = attn_context.reshape(*input_shape, -1).contiguous()
        attn_output = attention.attn_out(attn_output)
        hidden_states = residual + block.dropout(attn_output)

        residual = hidden_states
        hidden_states = block.ff_norm(hidden_states)
        hidden_states = block.mlp(hidden_states)
        hidden_states = residual + block.dropout(hidden_states)
        return hidden_states

    def _depth_decode_post_and_pre_next(
        self,
        layer_idx: int,
        residual: torch.Tensor,
        attn_context: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Finish layer ``layer_idx`` and start layer ``layer_idx + 1``.

        This is the unit captured for every layer boundary, so one replay
        carries the hidden state from one attention call to the next.
        """
        hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
        return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)

    def _depth_decode_last_post(
        self,
        layer_idx: int,
        residual: torch.Tensor,
        attn_context: torch.Tensor,
    ) -> torch.Tensor:
        """Finish the last layer and apply the final norm ``ln_f``."""
        hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
        return self.backbone.transformer.ln_f(hidden_states)

    def _build_depth_decode_graph(
        self,
        next_input_ids: torch.Tensor,
        *,
        past_length: int,
        attention_bias: torch.Tensor,
    ) -> _DepthDecodeCudaGraph:
        """Allocate the static buffers and capture the whole graph chain.

        Args:
            next_input_ids: Token to decode; seeds the static token buffer.
            past_length: Number of positions already cached; selects the
                rotary entries used while capturing.
            attention_bias: Its last dimension is taken as the maximum cache
                length.
        """
        text_config = self.backbone.transformer.config
        device = next_input_ids.device
        dtype = self.model.lm_head.weight.dtype
        static = self._depth_decode_spec()
        num_layers = static.num_hidden_layers
        head_dim = static.head_dim
        max_cache_len = int(attention_bias.shape[-1])
        # The rope cache must cover at least every cache position.
        max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
        self.backbone.transformer.prepare_rope_cache(
            device=device, max_seq_len=max_rope_len
        )

        # Static input buffers referenced by the captured graphs.
        token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
        cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
        sin = torch.empty_like(cos)
        positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
        context_shape = (1, 1, static.num_attention_heads, head_dim)

        # Seed the buffers so warm-up and the first replay see real inputs.
        token_ids.copy_(next_input_ids)
        self._select_depth_decode_rope(cos, sin, past_length=past_length)

        pre_graph, pre_output = _capture_cuda_graph(
            lambda: self._depth_decode_pre0(token_ids, cos, sin),
            device,
        )
        stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
        post_graphs = []
        for layer_idx in range(num_layers - 1):
            stage = stages[-1]
            attn_context = torch.empty(context_shape, device=device, dtype=dtype)
            graph, output = _capture_cuda_graph(
                # Default arguments bind this iteration's values to the lambda.
                lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
                    self._depth_decode_post_and_pre_next(
                        layer_idx,
                        stage.residual,
                        attn_context,
                        cos,
                        sin,
                    )
                ),
                device,
            )
            post_graphs.append(
                _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
            )
            stages.append(_DepthDecodeCudaGraphLayerStage(*output))

        # The last layer has no successor: its post graph ends with ``ln_f``.
        last_stage = stages[-1]
        last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
        last_graph, last_output = _capture_cuda_graph(
            lambda: self._depth_decode_last_post(
                num_layers - 1,
                last_stage.residual,
                last_attn_context,
            ),
            device,
        )
        post_graphs.append(
            _DepthDecodeCudaGraphPostStage(
                graph=last_graph, attn_context=last_attn_context
            )
        )
        return _DepthDecodeCudaGraph(
            cache_key=self._depth_decode_key(next_input_ids, attention_bias),
            pre_graph=pre_graph,
            token_ids=token_ids,
            cos=cos,
            sin=sin,
            positions=positions,
            stages=tuple(stages),
            post_graphs=tuple(post_graphs),
            output=last_output,
        )

    def _get_depth_decode_graph(
        self,
        next_input_ids: torch.Tensor,
        *,
        past_length: int,
        attention_bias: torch.Tensor,
    ) -> _DepthDecodeCudaGraph:
        """Return a graph chain whose static inputs are set up for this step.

        If no chain exists or its key differs, a new one is built (which also
        seeds the static inputs). Otherwise the existing chain is reused and
        only its token and rotary buffers are refreshed.
        """
        key = self._depth_decode_key(next_input_ids, attention_bias)
        decode_graph = self.graph
        if decode_graph is None or decode_graph.cache_key != key:
            decode_graph = self._build_depth_decode_graph(
                next_input_ids,
                past_length=past_length,
                attention_bias=attention_bias,
            )
            self.graph = decode_graph
        else:
            decode_graph.token_ids.copy_(next_input_ids)
            self._select_depth_decode_rope(
                decode_graph.cos, decode_graph.sin, past_length=past_length
            )
        return decode_graph

    def _run_depth_decode_attention_core(
        self,
        layer_idx: int,
        stage: _DepthDecodeCudaGraphLayerStage,
        *,
        past_key_values: Cache,
        attention_bias: torch.Tensor,
        cache_position: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Eagerly update the KV cache and run attention for one layer.

        Returns:
            The attention output with dims 1 and 2 swapped back, ready to be
            copied into the layer's static ``attn_context`` buffer.
        """
        attention = self.backbone.transformer.blocks[layer_idx].self_attn
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(
            stage.key,
            stage.value,
            layer_idx,
            cache_kwargs,
        )
        # Repeat the KV heads so their count matches the query heads.
        key_states = _repeat_kv(key_states, attention.num_key_value_groups)
        value_states = _repeat_kv(value_states, attention.num_key_value_groups)
        attn_output = F.scaled_dot_product_attention(
            stage.query,
            key_states,
            value_states,
            attn_mask=attention_bias,
            dropout_p=0.0,
            is_causal=False,
        )
        return attn_output.transpose(1, 2)

    def run(
        self,
        next_input_ids: torch.Tensor,
        *,
        past_key_values: Cache,
        attention_bias: torch.Tensor,
        past_length: int,
    ) -> Tuple[torch.Tensor, Cache]:
        """Decode one token using the captured graphs.

        Args:
            next_input_ids: ``(1, 1)`` token id tensor.
            past_key_values: KV cache, updated in place for every layer.
            attention_bias: Full bias tensor; the row for the new position
                and the columns up to it are sliced out for attention.
            past_length: Number of positions already in the cache.

        Returns:
            ``(hidden_states, past_key_values)``. ``hidden_states`` is the
            graph chain's static output buffer; no clone is made here.
            NOTE(review): it will presumably be overwritten by the next
            replay, so callers should consume or copy it first — confirm.
        """
        end = past_length + 1
        decode_graph = self._get_depth_decode_graph(
            next_input_ids,
            past_length=past_length,
            attention_bias=attention_bias,
        )
        cache_position = decode_graph.positions[past_length:end]
        attention_bias_q = attention_bias[:, :, past_length:end, :end]

        decode_graph.pre_graph.replay()

        # Alternate eager attention with replay of the following graph.
        for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
            attn_context = self._run_depth_decode_attention_core(
                layer_idx,
                decode_graph.stages[layer_idx],
                past_key_values=past_key_values,
                attention_bias=attention_bias_q,
                cache_position=cache_position,
                cos=decode_graph.cos,
                sin=decode_graph.sin,
            )
            post_graph.attn_context.copy_(attn_context)
            post_graph.graph.replay()

        return decode_graph.output, past_key_values


def _cuda_graph_tensor_signature(
    tensor: Optional[torch.Tensor],
) -> Optional[Tuple[Any, ...]]:
    """Describe a tensor's layout as a hashable tuple.

    Returns:
        ``None`` for a ``None`` input, otherwise
        ``(shape, strides, str(dtype), str(device))``. Tensor values are not
        part of the signature.
    """
    if tensor is not None:
        shape = tuple(tensor.shape)
        strides = tuple(tensor.stride())
        return (shape, strides, str(tensor.dtype), str(tensor.device))
    return None


def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
    """Build the layout signature of a context object.

    The result has five entries, in this order: the per-pair signatures of
    ``kv_contexts``, then ``cross_mask``, ``self_mask``, ``valid_action``, and
    finally ``rope_cache`` (``None`` when the context has no rope cache).
    """
    describe = _cuda_graph_tensor_signature

    kv_signature = tuple(
        (describe(key), describe(value)) for key, value in context.kv_contexts
    )
    cross_signature = describe(context.cross_mask)
    self_signature = describe(context.self_mask)
    valid_signature = describe(context.valid_action)

    if context.rope_cache is None:
        rope_signature = None
    else:
        rope_signature = tuple(describe(entry) for entry in context.rope_cache)

    return (
        kv_signature,
        cross_signature,
        self_signature,
        valid_signature,
        rope_signature,
    )


def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
    """Build the layout signature of a sequence of modulation objects.

    Each entry contributes a triple: the signature of ``conditioning``, a
    tuple of per-block signature tuples for ``block_modulations``, and a
    signature tuple for ``final_modulation``.
    """
    describe = _cuda_graph_tensor_signature
    per_step = []
    for step in modulations:
        conditioning_signature = describe(step.conditioning)
        block_signatures = tuple(
            tuple(describe(entry) for entry in block_modulation)
            for block_modulation in step.block_modulations
        )
        final_signature = tuple(describe(entry) for entry in step.final_modulation)
        per_step.append((conditioning_signature, block_signatures, final_signature))
    return tuple(per_step)


def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
    """Build the cache key for an action-flow CUDA graph.

    The key captures the layout (not the values) of every input tensor plus
    the step count, so a captured graph is reused only for calls whose inputs
    have identical shapes, strides, dtypes and devices.
    """
    trajectory_signature = _cuda_graph_tensor_signature(inputs.trajectory)
    context_signature = _cuda_graph_context_signature(inputs.context)
    modulation_signature = _cuda_graph_modulation_signature(inputs.modulations)
    pad_signature = _cuda_graph_tensor_signature(inputs.action_dim_is_pad)
    return (
        trajectory_signature,
        context_signature,
        modulation_signature,
        pad_signature,
        int(steps),
    )


def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Copy ``tensor`` into a freshly allocated buffer of its own.

    The copy keeps the source's shape, dtype and device. Its strides are kept
    too, except when the source is a broadcast view (a dimension of size > 1
    with stride 0). Reproducing those strides would make several elements of
    the new buffer share one memory location, and ``copy_`` refuses to write
    into such a tensor, so the buffer is allocated densely in that case.

    Args:
        tensor: Tensor to copy, or ``None``.

    Returns:
        The independent copy, or ``None`` if ``tensor`` is ``None``.
    """
    if tensor is None:
        return None
    sizes = tuple(tensor.shape)
    strides = tuple(tensor.stride())
    is_broadcast_view = any(
        size > 1 and stride == 0 for size, stride in zip(sizes, strides)
    )
    if is_broadcast_view:
        # e.g. a mask created with ``.expand(...)``: use a dense layout.
        static = torch.empty(sizes, device=tensor.device, dtype=tensor.dtype)
    else:
        static = torch.empty_strided(
            sizes,
            strides,
            device=tensor.device,
            dtype=tensor.dtype,
        )
    static.copy_(tensor)
    return static


def _clone_static_context(context: Any) -> Any:
    """Return a new context of the same class holding static tensor copies.

    The context class is constructed with the keyword arguments
    ``kv_contexts``, ``cross_mask``, ``self_mask``, ``valid_action`` and
    ``rope_cache``. ``None`` fields stay ``None``.
    """
    clone = _clone_static_tensor

    source_rope = context.rope_cache
    if source_rope is None:
        static_rope = None
    else:
        static_rope = tuple(clone(entry) for entry in source_rope)

    static_kv = tuple((clone(key), clone(value)) for key, value in context.kv_contexts)

    return context.__class__(
        kv_contexts=static_kv,
        cross_mask=clone(context.cross_mask),
        self_mask=clone(context.self_mask),
        valid_action=clone(context.valid_action),
        rope_cache=static_rope,
    )


def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
    """Return a tuple of modulation objects holding static tensor copies.

    Every entry is rebuilt through its own class with the keyword arguments
    ``conditioning``, ``block_modulations`` and ``final_modulation``.
    """
    clone = _clone_static_tensor
    static_steps = []
    for step in modulations:
        conditioning = clone(step.conditioning)
        block_modulations = tuple(
            tuple(clone(entry) for entry in block_modulation)
            for block_modulation in step.block_modulations
        )
        final_modulation = tuple(clone(entry) for entry in step.final_modulation)
        static_steps.append(
            step.__class__(
                conditioning=conditioning,
                block_modulations=block_modulations,
                final_modulation=final_modulation,
            )
        )
    return tuple(static_steps)


def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
    """Deep-copy every tensor of ``inputs`` into static buffers for capture."""
    static_trajectory = _clone_static_tensor(inputs.trajectory)
    static_context = _clone_static_context(inputs.context)
    static_modulations = _clone_static_modulations(inputs.modulations)
    static_pad_mask = _clone_static_tensor(inputs.action_dim_is_pad)
    return _ActionFlowInputs(
        trajectory=static_trajectory,
        context=static_context,
        modulations=static_modulations,
        action_dim_is_pad=static_pad_mask,
    )


def _copy_context_(dst: Any, src: Any) -> None:
    """Copy every tensor of ``src`` into the matching tensor of ``dst`` in place.

    ``kv_contexts`` and ``rope_cache`` are walked pairwise with ``zip``, so
    extra entries on either side are ignored. Optional fields are skipped
    when they are ``None`` on ``src``.
    """
    for dst_pair, src_pair in zip(dst.kv_contexts, src.kv_contexts):
        dst_key, dst_value = dst_pair
        src_key, src_value = src_pair
        dst_key.copy_(src_key)
        dst_value.copy_(src_value)

    for field in ("cross_mask", "self_mask", "valid_action"):
        source = getattr(src, field)
        if source is not None:
            getattr(dst, field).copy_(source)

    source_rope = src.rope_cache
    if source_rope is not None:
        for target, source in zip(dst.rope_cache, source_rope):
            target.copy_(source)


def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
    """Refresh the static buffers in ``dst`` with the values from ``src``.

    Copies the trajectory, the context tensors and, when present on ``src``,
    ``action_dim_is_pad``.

    NOTE(review): ``modulations`` are not copied here, so ``dst`` keeps the
    modulation values it was created with. Presumably they are identical
    across calls that share a graph key — confirm.
    """
    dst.trajectory.copy_(src.trajectory)
    _copy_context_(dst.context, src.context)
    pad_mask = src.action_dim_is_pad
    if pad_mask is None:
        return
    dst.action_dim_is_pad.copy_(pad_mask)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    With the last dimension split at ``size // 2`` into ``(lower, upper)``,
    the result is ``concat(-upper, lower)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return torch.cat((-upper, lower), dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            so that it broadcasts against ``q`` and ``k``.
        sin: Sine table, treated like ``cos``.
        unsqueeze_dim: Position of the inserted singleton axis.

    Returns:
        ``(q_embed, k_embed)``, each computed as
        ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        return (states * cos_b) + (_rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)


def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        hidden_states: 4-D tensor ``(batch, num_key_value_heads, slen, head_dim)``.
        n_rep: Number of copies per head; ``1`` returns the input unchanged.

    Returns:
        Tensor of shape ``(batch, num_key_value_heads * n_rep, slen, head_dim)``
        in which the copies of one head are adjacent.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        expanded = hidden_states.unsqueeze(2).expand(
            batch, kv_heads, n_rep, seq_len, head_dim
        )
        return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states


def _capture_cuda_graph(
    fn,
    device: torch.device,
    *,
    after_warmup=None,
) -> Tuple[torch.cuda.CUDAGraph, Any]:
    """Warm ``fn`` up once on a side stream, then capture it in a CUDA graph.

    Args:
        fn: Zero-argument callable to capture. It is invoked twice: once for
            warm-up and once under capture.
        device: CUDA device whose current stream is synchronised with the
            warm-up stream.
        after_warmup: Optional zero-argument callable run after the warm-up
            and before the capture (e.g. to restore inputs that the warm-up
            run may have modified).

    Returns:
        ``(graph, output)`` where ``output`` is whatever ``fn`` returned
        during capture.
    """
    # Warm-up runs on its own stream, ordered after work already queued on
    # the current stream.
    warmup_stream = torch.cuda.Stream(device=device)
    warmup_stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(warmup_stream):
        fn()
    # Make the current stream wait for the warm-up before continuing.
    torch.cuda.current_stream(device).wait_stream(warmup_stream)
    if after_warmup is not None:
        after_warmup()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = fn()
    return graph, output